1
0
Fork 0

AIStory: add basic horde support

This commit is contained in:
Pabloader 2024-11-12 13:32:35 +00:00
parent ece1621e73
commit 017ef7aaa5
11 changed files with 655 additions and 233 deletions

View File

@ -0,0 +1,16 @@
import { useCallback } from "preact/hooks";
/**
 * Wraps a string-consuming callback so it can also be wired directly to DOM
 * input/change events: an Event argument is reduced to its target's string
 * `value`, a plain string passes through, anything else yields ''.
 */
export function useInputCallback<T>(callback: (value: string) => T, deps: any[]): ((value: string | Event) => T) {
    return useCallback((input: Event | string) => {
        if (typeof input === 'string') {
            return callback(input);
        }
        const { target } = input;
        if (target != null && 'value' in target && typeof target.value === 'string') {
            return callback(target.value);
        }
        // Event without a usable string value (or no target at all).
        return callback('');
    }, deps);
}

View File

@ -50,3 +50,22 @@ export const intHash = (seed: number, ...parts: number[]) => {
return h1; return h1;
}; };
export const sinHash = (...data: number[]) => data.reduce((hash, n) => Math.sin((hash * 123.12 + n) * 756.12), 0) / 2 + 0.5; export const sinHash = (...data: number[]) => data.reduce((hash, n) => Math.sin((hash * 123.12 + n) * 756.12), 0) / 2 + 0.5;
/**
 * Rate-limits `func`: the first call runs it and caches the result; calls made
 * during the next `ms` milliseconds skip execution and return the cached
 * result. `this` and arguments are forwarded on each executing call.
 */
export const throttle = function <T, A extends unknown[], R, F extends (this: T, ...args: A) => R>(func: F, ms: number): F {
    let blocked = false;
    let lastResult: R;
    const throttled = function (this: T, ...args: A): R {
        if (blocked) {
            // Inside the cooldown window: serve the cached result.
            return lastResult;
        }
        lastResult = func.apply(this, args);
        blocked = true;
        setTimeout(() => {
            blocked = false;
        }, ms);
        return lastResult;
    };
    return throttled as F;
}

View File

@ -32,6 +32,10 @@ select {
outline: none; outline: none;
} }
/* Dropdown entries inherit the theme background instead of the UA default. */
option, optgroup {
    background-color: var(--backgroundColor);
}
textarea { textarea {
resize: vertical; resize: vertical;
width: 100%; width: 100%;

View File

@ -15,7 +15,7 @@ export const Chat = () => {
const lastAssistantId = messages.findLastIndex(m => m.role === 'assistant'); const lastAssistantId = messages.findLastIndex(m => m.role === 'assistant');
useEffect(() => { useEffect(() => {
DOMTools.scrollDown(chatRef.current); setTimeout(() => DOMTools.scrollDown(chatRef.current, false), 100);
}, [messages.length, lastMessageContent]); }, [messages.length, lastMessageContent]);
return ( return (

View File

@ -0,0 +1,146 @@
import { useCallback, useContext, useEffect, useMemo, useState } from 'preact/hooks';
import styles from './header.module.css';
import { Connection, HORDE_ANON_KEY, isHordeConnection, isKoboldConnection, type IConnection, type IHordeModel } from '../../connection';
import { Instruct, StateContext } from '../../contexts/state';
import { useInputState } from '@common/hooks/useInputState';
import { useInputCallback } from '@common/hooks/useInputCallback';
import { Huggingface } from '../../huggingface';
interface IProps {
    connection: IConnection;           // currently active backend connection
    setConnection: (c: IConnection) => void; // replaces the active connection
}

/**
 * Editor UI for the backend connection: lets the user switch between the
 * Kobold CPP and Horde backends, pick an instruct template, and edit
 * backend-specific fields (server URL, API key, Horde model).
 */
export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
    // Local draft state, synced from `connection` by the effect below.
    const [connectionUrl, setConnectionUrl] = useInputState('');
    const [apiKey, setApiKey] = useInputState(HORDE_ANON_KEY);
    const [modelName, setModelName] = useInputState('');
    const [modelTemplate, setModelTemplate] = useInputState(Instruct.CHATML);
    const [hordeModels, setHordeModels] = useState<IHordeModel[]>([]);
    const [contextLength, setContextLength] = useState<number>(0);
    // Discriminate the connection union for the backend <select>.
    const backendType = useMemo(() => {
        if (isKoboldConnection(connection)) return 'kobold';
        if (isHordeConnection(connection)) return 'horde';
        return 'unknown';
    }, [connection]);
    // A URL is considered valid once the backend reports a positive context length.
    const urlValid = useMemo(() => contextLength > 0, [contextLength]);
    // Pull draft fields (and, for Horde, the sorted model list) from the connection.
    useEffect(() => {
        if (isKoboldConnection(connection)) {
            setConnectionUrl(connection.url);
        } else if (isHordeConnection(connection)) {
            setModelName(connection.model);
            setApiKey(connection.apiKey || HORDE_ANON_KEY);
            Connection.getHordeModels()
                .then(m => setHordeModels(Array.from(m.values()).sort((a, b) => a.name.localeCompare(b.name))));
        }
        // Runs for both backends; getModelName overwrites the draft model name.
        Connection.getContextLength(connection).then(setContextLength);
        Connection.getModelName(connection).then(setModelName);
    }, [connection]);
    // Look up the model's native instruct template on Huggingface when known.
    useEffect(() => {
        if (modelName) {
            Huggingface.findModelTemplate(modelName)
                .then(template => {
                    if (template) {
                        setModelTemplate(template);
                    }
                });
        }
    }, [modelName]);
    const setInstruct = useInputCallback((instruct) => {
        setConnection({ ...connection, instruct });
    }, [connection, setConnection]);
    // Switching backend rebuilds the connection from the current draft fields,
    // keeping the instruct template.
    const setBackendType = useInputCallback((type) => {
        if (type === 'kobold') {
            setConnection({
                instruct: connection.instruct,
                url: connectionUrl,
            });
        } else if (type === 'horde') {
            setConnection({
                instruct: connection.instruct,
                apiKey,
                model: modelName,
            });
        }
    }, [connection, setConnection, connectionUrl, apiKey, modelName]);
    // On blur, normalize the URL (default scheme http://, strip trailing slash)
    // and commit it to the connection.
    const handleBlurUrl = useCallback(() => {
        const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i;
        const url = connectionUrl.replace(regex, 'http$1://$2');
        setConnection({
            instruct: connection.instruct,
            url,
        });
    }, [connection, connectionUrl, setConnection]);
    // Commit the Horde draft fields (API key / model) to the connection.
    const handleBlurHorde = useCallback(() => {
        setConnection({
            instruct: connection.instruct,
            apiKey,
            model: modelName,
        });
    }, [connection, apiKey, modelName, setConnection]);
    return (
        <div class={styles.connectionEditor}>
            <select value={backendType} onChange={setBackendType}>
                <option value='kobold'>Kobold CPP</option>
                <option value='horde'>Horde</option>
            </select>
            {/* Instruct template: native (from Huggingface), predefined, or custom. */}
            <select value={connection.instruct} onChange={setInstruct} title='Instruct template'>
                {modelName && modelTemplate && <optgroup label='Native model template'>
                    <option value={modelTemplate} title='Native for model'>{modelName}</option>
                </optgroup>}
                <optgroup label='Manual templates'>
                    {Object.entries(Instruct).map(([label, value]) => (
                        <option value={value} key={value}>
                            {label.toLowerCase()}
                        </option>
                    ))}
                </optgroup>
                <optgroup label='Custom'>
                    <option value={connection.instruct}>Custom</option>
                </optgroup>
            </select>
            {/* Kobold: server URL, colored by reachability (context length > 0). */}
            {isKoboldConnection(connection) && <input
                value={connectionUrl}
                onInput={setConnectionUrl}
                onBlur={handleBlurUrl}
                class={urlValid ? styles.valid : styles.invalid}
            />}
            {/* Horde: API key plus a model picker showing length/context limits. */}
            {isHordeConnection(connection) && <>
                <input
                    placeholder='Horde API key'
                    title='Horde API key'
                    value={apiKey}
                    onInput={setApiKey}
                    onBlur={handleBlurHorde}
                />
                <select
                    value={modelName}
                    onChange={setModelName}
                    onBlur={handleBlurHorde}
                    title='Horde model'
                >
                    {hordeModels.map((m) => (
                        <option value={m.name} key={m.name}>
                            {m.name} ({m.maxLength}/{m.maxContext})
                        </option>
                    ))}
                </select>
            </>}
        </div>
    );
};

View File

@ -45,3 +45,10 @@
overflow: hidden; overflow: hidden;
} }
} }
.connectionEditor {
display: flex;
flex-direction: row;
gap: 8px;
flex-wrap: wrap;
}

View File

@ -2,35 +2,29 @@ import { useCallback, useContext, useEffect, useMemo, useState } from "preact/ho
import { useBool } from "@common/hooks/useBool"; import { useBool } from "@common/hooks/useBool";
import { Modal } from "@common/components/modal/modal"; import { Modal } from "@common/components/modal/modal";
import { Instruct, StateContext } from "../../contexts/state"; import { StateContext } from "../../contexts/state";
import { LLMContext } from "../../contexts/llm"; import { LLMContext } from "../../contexts/llm";
import { MiniChat } from "../minichat/minichat"; import { MiniChat } from "../minichat/minichat";
import { AutoTextarea } from "../autoTextarea"; import { AutoTextarea } from "../autoTextarea";
import { Ace } from "../ace";
import { ConnectionEditor } from "./connectionEditor";
import styles from './header.module.css'; import styles from './header.module.css';
import { Ace } from "../ace";
export const Header = () => { export const Header = () => {
const { modelName, modelTemplate, contextLength, promptTokens, blockConnection } = useContext(LLMContext); const { contextLength, promptTokens, modelName } = useContext(LLMContext);
const { const {
messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct, summarizePrompt, summaryEnabled, messages, connection, systemPrompt, lore, userPrompt, bannedWords, summarizePrompt, summaryEnabled,
setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct, setSummarizePrompt, setSummaryEnabled, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct, setSummarizePrompt, setSummaryEnabled, setConnection,
} = useContext(StateContext); } = useContext(StateContext);
const connectionsOpen = useBool();
const loreOpen = useBool(); const loreOpen = useBool();
const promptsOpen = useBool(); const promptsOpen = useBool();
const genparamsOpen = useBool(); const genparamsOpen = useBool();
const assistantOpen = useBool(); const assistantOpen = useBool();
const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]); const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]);
const urlValid = useMemo(() => contextLength > 0, [contextLength]);
const handleBlurUrl = useCallback(() => {
const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i
const normalizedConnectionUrl = connectionUrl.replace(regex, 'http$1://$2');
setConnectionUrl(normalizedConnectionUrl);
blockConnection.setFalse();
}, [connectionUrl, setConnectionUrl, blockConnection]);
const handleAssistantAddSwipe = useCallback((answer: string) => { const handleAssistantAddSwipe = useCallback((answer: string) => {
const index = messages.findLastIndex(m => m.role === 'assistant'); const index = messages.findLastIndex(m => m.role === 'assistant');
@ -61,29 +55,13 @@ export const Header = () => {
return ( return (
<div class={styles.header}> <div class={styles.header}>
<div class={styles.inputs}> <div class={styles.inputs}>
<input value={connectionUrl} <div class={styles.buttons}>
onInput={setConnectionUrl} <button class='icon' onClick={connectionsOpen.setTrue} title='Connection settings'>
onFocus={blockConnection.setTrue} 🔌
onBlur={handleBlurUrl} </button>
class={blockConnection.value ? '' : urlValid ? styles.valid : styles.invalid} </div>
/>
<select value={instruct} onChange={setInstruct} title='Instruct template'>
{modelName && modelTemplate && <optgroup label='Native model template'>
<option value={modelTemplate} title='Native for model'>{modelName}</option>
</optgroup>}
<optgroup label='Manual templates'>
{Object.entries(Instruct).map(([label, value]) => (
<option value={value} key={value}>
{label.toLowerCase()}
</option>
))}
</optgroup>
<optgroup label='Custom'>
<option value={instruct}>Custom</option>
</optgroup>
</select>
<div class={styles.info}> <div class={styles.info}>
{promptTokens} / {contextLength} {modelName} - {promptTokens} / {contextLength}
</div> </div>
</div> </div>
<div class={styles.buttons}> <div class={styles.buttons}>
@ -102,6 +80,10 @@ export const Header = () => {
</button> </button>
</div> </div>
<Modal open={connectionsOpen.value} onClose={connectionsOpen.setFalse}>
<h3 class={styles.modalTitle}>Connection settings</h3>
<ConnectionEditor connection={connection} setConnection={setConnection} />
</Modal>
<Modal open={loreOpen.value} onClose={loreOpen.setFalse}> <Modal open={loreOpen.value} onClose={loreOpen.setFalse}>
<h3 class={styles.modalTitle}>Lore Editor</h3> <h3 class={styles.modalTitle}>Lore Editor</h3>
<AutoTextarea <AutoTextarea
@ -135,12 +117,12 @@ export const Header = () => {
<h4 class={styles.modalTitle}>Summary template</h4> <h4 class={styles.modalTitle}>Summary template</h4>
<Ace value={summarizePrompt} onInput={setSummarizePrompt} /> <Ace value={summarizePrompt} onInput={setSummarizePrompt} />
<label> <label>
<input type='checkbox' checked={summaryEnabled} onChange={handleSetSummaryEnabled}/> <input type='checkbox' checked={summaryEnabled} onChange={handleSetSummaryEnabled} />
&nbsp;Enable summarization &nbsp;Enable summarization
</label> </label>
<hr /> <hr />
<h4 class={styles.modalTitle}>Instruct template</h4> <h4 class={styles.modalTitle}>Instruct template</h4>
<Ace value={instruct} onInput={setInstruct} /> <Ace value={connection.instruct} onInput={setInstruct} />
</div> </div>
</Modal> </Modal>
<MiniChat <MiniChat

352
src/games/ai/connection.ts Normal file
View File

@ -0,0 +1,352 @@
import Lock from "@common/lock";
import SSE from "@common/sse";
import { delay, throttle } from "@common/utils";
// Fields common to every backend connection.
interface IBaseConnection {
    instruct: string; // instruct/chat template applied when building prompts
}

// Connection to a Kobold CPP server.
interface IKoboldConnection extends IBaseConnection {
    url: string; // base URL of the Kobold HTTP API
}

// Connection to the AI Horde network.
interface IHordeConnection extends IBaseConnection {
    apiKey?: string; // optional Horde API key; the anonymous key is used when absent
    model: string;   // normalized model name to request from Horde workers
}
/** Type guard: a Kobold connection is any object carrying a string `url`. */
export const isKoboldConnection = (obj: unknown): obj is IKoboldConnection => {
    if (obj == null || typeof obj !== 'object') {
        return false;
    }
    return 'url' in obj && typeof obj.url === 'string';
};
/** Type guard: a Horde connection is any object carrying a string `model`. */
export const isHordeConnection = (obj: unknown): obj is IHordeConnection => {
    if (obj == null || typeof obj !== 'object') {
        return false;
    }
    return 'model' in obj && typeof obj.model === 'string';
};
// Any supported backend connection.
export type IConnection = IKoboldConnection | IHordeConnection;

// Worker entry as returned by the Horde workers endpoint (fields we consume).
interface IHordeWorker {
    id: string;
    models: string[];
    flagged: boolean;
    online: boolean;
    maintenance_mode: boolean;
    max_context_length: number;
    max_length: number;
    performance: string; // numeric string; parsed with parseFloat for filtering
}

// One normalized model aggregated across all eligible workers.
export interface IHordeModel {
    name: string;         // normalized model name (used as map key)
    hordeNames: string[]; // raw names as advertised by individual workers
    maxLength: number;    // most conservative generation length among workers
    maxContext: number;   // most conservative context size among workers
    workers: string[];    // ids of workers serving this model
}

// Generation status payload from the Horde status endpoint.
interface IHordeResult {
    faulted: boolean;
    done: boolean;
    finished: number;
    generations?: {
        text: string;
    }[];
}
// Default sampler settings shared by all backends; any field can be overridden
// per call via IGenerationSettings.
const DEFAULT_GENERATION_SETTINGS = {
    temperature: 0.8,
    min_p: 0.1,
    rep_pen: 1.08,
    rep_pen_range: -1,
    rep_pen_slope: 0.7,
    top_k: 100,
    top_p: 0.92,
    banned_tokens: ['anticipat'],
    max_length: 300,
    trim_stop: true,
    stop_sequence: ['[INST]', '[/INST]', '</s>', '<|'],
    dry_allowed_length: 5,
    dry_multiplier: 0.8,
    dry_base: 1,
    dry_sequence_breakers: ["\n", ":", "\"", "*"],
    dry_penalty_last_n: 0
}
// Minimum parsed `performance` for a Horde worker to be considered.
const MIN_PERFORMANCE = 5.0;
// Horde workers advertising less context than this are skipped.
const MIN_WORKER_CONTEXT = 8192;
// Upper bounds for the per-model limits aggregated from workers.
const MAX_HORDE_LENGTH = 512;
const MAX_HORDE_CONTEXT = 32000;
// Anonymous Horde API key, used whenever the user supplies none.
export const HORDE_ANON_KEY = '0000000000';
/**
 * Normalizes a model identifier to a canonical name: keeps only the last path
 * segment, drops any '::' qualifier, then repeatedly strips packaging noise
 * (context-length suffixes, quant names/sizes, gguf/ggml markers, bpw, float
 * precision, debug prefixes) until a fixed point, and finally collapses
 * separator runs into single hyphens.
 */
export const normalizeModel = (model: string) => {
    // split() always yields at least one element; the fallbacks only satisfy
    // strict null checks on .at().
    let currentModel = model.split(/[\\\/]/).at(-1) ?? model;
    currentModel = currentModel.split('::').at(0) ?? currentModel;
    let normalizedModel: string;
    // Suffixes can be stacked (e.g. "-GGUF.Q4_K_M"), so strip until stable.
    do {
        normalizedModel = currentModel;
        currentModel = currentModel
            .replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
            .replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name
            .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
            .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
            .replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
            .replace(/[ ._-]f(p|loat)?(8|16|32)/i, '') // remove float precision
            .replace(/^(debug-?)+/i, '') // remove debug prefixes
            .trim();
    } while (normalizedModel !== currentModel);
    return normalizedModel
        .replace(/[ _-]+/ig, '-')
        // 'g' flag added: every run of dots collapses, not just the first.
        .replace(/\.{2,}/g, '-')
        .replace(/[ ._-]+$/ig, '')
        .trim();
}
/**
 * Rough token-count estimate from the whitespace-separated word count
 * (0.75 tokens per word, as used elsewhere in this module when no tokenizer
 * endpoint is available). Trims first so blank or padded prompts are not
 * inflated by empty split fragments; an empty/blank prompt counts as 0.
 */
export const approximateTokens = (prompt: string): number => {
    const trimmed = prompt.trim();
    return trimmed ? Math.round(trimmed.split(/\s+/).length * 0.75) : 0;
};
export type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
/** Backend-agnostic generation API covering Kobold CPP and AI Horde. */
export namespace Connection {
    const AIHORDE = 'https://aihorde.net';

    /**
     * Streams tokens from a Kobold CPP server over SSE.
     * Yields each token as it arrives; finishes when the server reports a
     * finish_reason, or when the stream errors, aborts, or closes.
     */
    async function* generateKobold(url: string, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator<string> {
        const sse = new SSE(`${url}/api/extra/generate/stream`, {
            payload: JSON.stringify({
                ...DEFAULT_GENERATION_SETTINGS,
                ...extraSettings,
                prompt,
            }),
        });
        // Tokens are buffered by the SSE callback and drained by the loop
        // below; the lock wakes the loop whenever new data arrives.
        const messages: string[] = [];
        const messageLock = new Lock();
        let end = false;
        sse.addEventListener('message', (e) => {
            if (e.data) {
                {
                    const { token, finish_reason } = JSON.parse(e.data);
                    messages.push(token);
                    if (finish_reason && finish_reason !== 'null') {
                        end = true;
                    }
                }
            }
            messageLock.release();
        });
        const handleEnd = () => {
            end = true;
            messageLock.release();
        };
        sse.addEventListener('error', handleEnd);
        sse.addEventListener('abort', handleEnd);
        sse.addEventListener('readystatechange', (e) => {
            if (e.readyState === SSE.CLOSED) handleEnd();
        });
        while (!end || messages.length) {
            while (messages.length > 0) {
                const message = messages.shift();
                if (message != null) {
                    try {
                        yield message;
                    } catch { } // consumer may have bailed; keep draining
                }
            }
            if (!end) {
                await messageLock.wait();
            }
        }
        sse.close();
    }

    /**
     * Runs one non-streaming generation on AI Horde: submits an async request
     * restricted to the workers serving the selected model, then polls the
     * status endpoint every 2.5s until the job completes.
     * @throws when the model has no eligible workers, a request fails, or the
     *         job faults.
     */
    async function generateHorde(connection: Omit<IHordeConnection, keyof IBaseConnection>, prompt: string, extraSettings: IGenerationSettings = {}): Promise<string> {
        const models = await getHordeModels();
        const model = models.get(connection.model);
        if (model) {
            // Respect both the model's advertised limit and any caller override.
            let maxLength = Math.min(model.maxLength, DEFAULT_GENERATION_SETTINGS.max_length);
            if (extraSettings.max_length && extraSettings.max_length < maxLength) {
                maxLength = extraSettings.max_length;
            }
            const requestData = {
                prompt,
                params: {
                    ...DEFAULT_GENERATION_SETTINGS,
                    ...extraSettings,
                    n: 1,
                    max_context_length: model.maxContext,
                    max_length: maxLength,
                    rep_pen_range: Math.min(model.maxContext, 4096),
                },
                models: model.hordeNames,
                workers: model.workers,
            };
            const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, {
                method: 'POST',
                body: JSON.stringify(requestData),
                headers: {
                    'Content-Type': 'application/json',
                    apikey: connection.apiKey || HORDE_ANON_KEY,
                },
            });
            if (!generateResponse.ok || Math.floor(generateResponse.status / 100) !== 2) {
                throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`);
            }
            const { id } = await generateResponse.json() as { id: string };
            // Best-effort cancellation, used when polling fails.
            const deleteRequest = () => fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method: 'DELETE' })
                .catch(e => console.error('Error deleting request', e));
            while (true) {
                await delay(2500);
                const retrieveResponse = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`);
                if (!retrieveResponse.ok || Math.floor(retrieveResponse.status / 100) !== 2) {
                    deleteRequest();
                    throw new Error(`Error retrieving generation: ${retrieveResponse.statusText}: ${await retrieveResponse.text()}`);
                }
                const result: IHordeResult = await retrieveResponse.json();
                // A faulted job will never produce a generation — bail out
                // instead of polling forever.
                if (result.faulted) {
                    throw new Error(`Error retrieving generation: request ${id} faulted`);
                }
                if (result.done && result.generations?.length === 1) {
                    const { text } = result.generations[0];
                    return text;
                }
            }
        }
        throw new Error(`Model ${connection.model} is offline`);
    }

    /** Generates text from whichever backend the connection describes. */
    export async function* generate(connection: IConnection, prompt: string, extraSettings: IGenerationSettings = {}) {
        if (isKoboldConnection(connection)) {
            yield* generateKobold(connection.url, prompt, extraSettings);
        } else if (isHordeConnection(connection)) {
            // Horde has no streaming API: the whole result arrives at once.
            yield await generateHorde(connection, prompt, extraSettings);
        }
    }

    /**
     * Fetches the Horde text-worker list and aggregates it into a map of
     * normalized model name -> IHordeModel, keeping only online, healthy,
     * sufficiently fast workers with a usable context size.
     * Returns an empty map on any network or parse error.
     */
    async function requestHordeModels(): Promise<Map<string, IHordeModel>> {
        try {
            const response = await fetch(`${AIHORDE}/api/v2/workers?type=text`);
            if (response.ok) {
                const workers: IHordeWorker[] = await response.json();
                const goodWorkers = workers.filter(w =>
                    w.online
                    && !w.maintenance_mode
                    && !w.flagged
                    && w.max_context_length >= MIN_WORKER_CONTEXT
                    && parseFloat(w.performance) >= MIN_PERFORMANCE
                );
                const models = new Map<string, IHordeModel>();
                for (const worker of goodWorkers) {
                    for (const modelName of worker.models) {
                        const normName = normalizeModel(modelName.toLowerCase());
                        let model = models.get(normName);
                        if (!model) {
                            model = {
                                hordeNames: [],
                                maxContext: MAX_HORDE_CONTEXT,
                                maxLength: MAX_HORDE_LENGTH,
                                name: normName,
                                workers: []
                            }
                        }
                        if (!model.hordeNames.includes(modelName)) {
                            model.hordeNames.push(modelName);
                        }
                        if (!model.workers.includes(worker.id)) {
                            model.workers.push(worker.id);
                        }
                        // Advertise the most conservative limits among workers.
                        model.maxContext = Math.min(model.maxContext, worker.max_context_length);
                        model.maxLength = Math.min(model.maxLength, worker.max_length);
                        models.set(normName, model);
                    }
                }
                return models;
            }
        } catch (e) {
            console.error(e);
        }
        return new Map();
    };

    // The worker list changes slowly; cache results for 10s between fetches.
    export const getHordeModels = throttle(requestHordeModels, 10000);

    /** Resolves the connection's model name, or '' when it cannot be read. */
    export async function getModelName(connection: IConnection): Promise<string> {
        if (isKoboldConnection(connection)) {
            try {
                const response = await fetch(`${connection.url}/api/v1/model`);
                if (response.ok) {
                    const { result } = await response.json();
                    return result;
                }
            } catch (e) {
                console.log('Error getting model name', e);
            }
        } else if (isHordeConnection(connection)) {
            return connection.model;
        }
        return '';
    }

    /** Resolves the usable context length for the connection (0 on failure). */
    export async function getContextLength(connection: IConnection): Promise<number> {
        if (isKoboldConnection(connection)) {
            try {
                const response = await fetch(`${connection.url}/api/extra/true_max_context_length`);
                if (response.ok) {
                    const { value } = await response.json();
                    return value;
                }
            } catch (e) {
                console.log('Error getting max tokens', e);
            }
        } else if (isHordeConnection(connection)) {
            const models = await getHordeModels();
            const model = models.get(connection.model);
            if (model) {
                return model.maxContext;
            }
        }
        return 0;
    }

    /**
     * Counts prompt tokens: exact via the Kobold tokencount endpoint when
     * available, otherwise a word-count based approximation.
     */
    export async function countTokens(connection: IConnection, prompt: string): Promise<number> {
        if (isKoboldConnection(connection)) {
            try {
                const response = await fetch(`${connection.url}/api/extra/tokencount`, {
                    body: JSON.stringify({ prompt }),
                    headers: { 'Content-Type': 'application/json' },
                    method: 'POST',
                });
                if (response.ok) {
                    const { value } = await response.json();
                    return value;
                }
            } catch (e) {
                console.log('Error counting tokens', e);
            }
        }
        return approximateTokens(prompt);
    }
}

View File

@ -7,6 +7,7 @@ import { Instruct, StateContext } from "./state";
import { useBool } from "@common/hooks/useBool"; import { useBool } from "@common/hooks/useBool";
import { Template } from "@huggingface/jinja"; import { Template } from "@huggingface/jinja";
import { Huggingface } from "../huggingface"; import { Huggingface } from "../huggingface";
import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection";
interface ICompileArgs { interface ICompileArgs {
keepUsers?: number; keepUsers?: number;
@ -29,29 +30,8 @@ interface IContext {
contextLength: number; contextLength: number;
} }
const DEFAULT_GENERATION_SETTINGS = {
temperature: 0.8,
min_p: 0.1,
rep_pen: 1.08,
rep_pen_range: -1,
rep_pen_slope: 0.7,
top_k: 100,
top_p: 0.92,
banned_tokens: [],
max_length: 300,
trim_stop: true,
stop_sequence: ['[INST]', '[/INST]', '</s>', '<|'],
dry_allowed_length: 5,
dry_multiplier: 0.8,
dry_base: 1,
dry_sequence_breakers: ["\n", ":", "\"", "*"],
dry_penalty_last_n: 0
}
const MESSAGES_TO_KEEP = 10; const MESSAGES_TO_KEEP = 10;
type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
interface IActions { interface IActions {
compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>; compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>; generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
@ -60,32 +40,6 @@ interface IActions {
} }
export type ILLMContext = IContext & IActions; export type ILLMContext = IContext & IActions;
export const normalizeModel = (model: string) => {
let currentModel = model.split(/[\\\/]/).at(-1);
currentModel = currentModel.split('::').at(0);
let normalizedModel: string;
do {
normalizedModel = currentModel;
currentModel = currentModel
.replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
.replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name
.replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
.replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
.replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
.replace(/[ ._-]f(p|loat)?(8|16|32)/i, '')
.replace(/^(debug-?)+/i, '')
.trim();
} while (normalizedModel !== currentModel);
return normalizedModel
.replace(/[ _-]+/ig, '-')
.replace(/\.{2,}/, '-')
.replace(/[ ._-]+$/ig, '')
.trim();
}
export const LLMContext = createContext<ILLMContext>({} as ILLMContext); export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
const processing = { const processing = {
@ -95,7 +49,7 @@ const processing = {
export const LLMContextProvider = ({ children }: { children?: any }) => { export const LLMContextProvider = ({ children }: { children?: any }) => {
const { const {
connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct, summarizePrompt, summaryEnabled, connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
setTriggerNext, addMessage, editMessage, editSummary, setInstruct, setTriggerNext, addMessage, editMessage, editSummary, setInstruct,
} = useContext(StateContext); } = useContext(StateContext);
@ -118,38 +72,18 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
}, [userPrompt]); }, [userPrompt]);
const getContextLength = useCallback(async () => { const getContextLength = useCallback(async () => {
if (!connectionUrl || blockConnection.value) { if (!connection || blockConnection.value) {
return 0; return 0;
} }
try { return Connection.getContextLength(connection);
const response = await fetch(`${connectionUrl}/api/extra/true_max_context_length`); }, [connection, blockConnection.value]);
if (response.ok) {
const { value } = await response.json();
return value;
}
} catch (e) {
console.log('Error getting max tokens', e);
}
return 0;
}, [connectionUrl, blockConnection.value]);
const getModelName = useCallback(async () => { const getModelName = useCallback(async () => {
if (!connectionUrl || blockConnection.value) { if (!connection || blockConnection.value) {
return ''; return '';
} }
try { return Connection.getModelName(connection);
const response = await fetch(`${connectionUrl}/api/v1/model`); }, [connection, blockConnection.value]);
if (response.ok) {
const { result } = await response.json();
return result;
}
} catch (e) {
console.log('Error getting max tokens', e);
}
return '';
}, [connectionUrl, blockConnection.value]);
const actions: IActions = useMemo(() => ({ const actions: IActions = useMemo(() => ({
compilePrompt: async (messages, { keepUsers } = {}) => { compilePrompt: async (messages, { keepUsers } = {}) => {
@ -236,7 +170,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`; templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`;
const prompt = Huggingface.applyChatTemplate(instruct, templateMessages); const prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages);
return { return {
prompt, prompt,
isContinue, isContinue,
@ -244,100 +178,30 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
}; };
}, },
generate: async function* (prompt, extraSettings = {}) { generate: async function* (prompt, extraSettings = {}) {
if (!connectionUrl) {
return;
}
try { try {
generating.setTrue(); generating.setTrue();
console.log('[LLM.generate]', prompt); console.log('[LLM.generate]', prompt);
const sse = new SSE(`${connectionUrl}/api/extra/generate/stream`, { yield* Connection.generate(connection, prompt, {
payload: JSON.stringify({
...DEFAULT_GENERATION_SETTINGS,
banned_tokens: bannedWords.filter(w => w.trim()),
...extraSettings, ...extraSettings,
prompt, banned_tokens: bannedWords.filter(w => w.trim()),
}),
}); });
const messages: string[] = [];
const messageLock = new Lock();
let end = false;
sse.addEventListener('message', (e) => {
if (e.data) {
{
const { token, finish_reason } = JSON.parse(e.data);
messages.push(token);
if (finish_reason && finish_reason !== 'null') {
end = true;
}
}
}
messageLock.release();
});
const handleEnd = () => {
end = true;
messageLock.release();
};
sse.addEventListener('error', handleEnd);
sse.addEventListener('abort', handleEnd);
sse.addEventListener('readystatechange', (e) => {
if (e.readyState === SSE.CLOSED) handleEnd();
});
while (!end || messages.length) {
while (messages.length > 0) {
const message = messages.shift();
if (message != null) {
try {
yield message;
} catch { }
}
}
if (!end) {
await messageLock.wait();
}
}
sse.close();
} finally { } finally {
generating.setFalse(); generating.setFalse();
} }
}, },
summarize: async (message) => { summarize: async (message) => {
const content = Huggingface.applyTemplate(summarizePrompt, { message }); const content = Huggingface.applyTemplate(summarizePrompt, { message });
const prompt = Huggingface.applyChatTemplate(instruct, [{ role: 'user', content }]); const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
const tokens = await Array.fromAsync(actions.generate(prompt)); const tokens = await Array.fromAsync(actions.generate(prompt));
return MessageTools.trimSentence(tokens.join('')); return MessageTools.trimSentence(tokens.join(''));
}, },
countTokens: async (prompt) => { countTokens: async (prompt) => {
if (!connectionUrl) { return await Connection.countTokens(connection, prompt);
return 0;
}
try {
const response = await fetch(`${connectionUrl}/api/extra/tokencount`, {
body: JSON.stringify({ prompt }),
headers: { 'Content-Type': 'applicarion/json' },
method: 'POST',
});
if (response.ok) {
const { value } = await response.json();
return value;
}
} catch (e) {
console.log('Error counting tokens', e);
}
return 0;
}, },
}), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct, summarizePrompt]); }), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);
useEffect(() => void (async () => { useEffect(() => void (async () => {
if (triggerNext && !generating.value) { if (triggerNext && !generating.value) {
@ -356,7 +220,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
editSummary(messageId, 'Generating...'); editSummary(messageId, 'Generating...');
for await (const chunk of actions.generate(prompt)) { for await (const chunk of actions.generate(prompt)) {
text += chunk; text += chunk;
setPromptTokens(promptTokens + Math.round(text.length * 0.25)); setPromptTokens(promptTokens + approximateTokens(text));
editMessage(messageId, text.trim()); editMessage(messageId, text.trim());
} }
@ -397,7 +261,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
getContextLength().then(setContextLength); getContextLength().then(setContextLength);
getModelName().then(normalizeModel).then(setModelName); getModelName().then(normalizeModel).then(setModelName);
} }
}, [connectionUrl, blockConnection.value]); }, [connection, blockConnection.value]);
useEffect(() => { useEffect(() => {
setModelTemplate(''); setModelTemplate('');
@ -431,16 +295,16 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
useEffect(() => { useEffect(() => {
calculateTokens(); calculateTokens();
}, [messages, connectionUrl, blockConnection.value, instruct, /* systemPrompt, lore, userPrompt TODO debounce*/]); }, [messages, connection, blockConnection.value, /* systemPrompt, lore, userPrompt TODO debounce*/]);
useEffect(() => { useEffect(() => {
try { try {
const hasTools = Huggingface.testToolCalls(instruct); const hasTools = Huggingface.testToolCalls(connection.instruct);
setHasToolCalls(hasTools); setHasToolCalls(hasTools);
} catch { } catch {
setHasToolCalls(false); setHasToolCalls(false);
} }
}, [instruct]); }, [connection.instruct]);
const rawContext: IContext = { const rawContext: IContext = {
generating: generating.value, generating: generating.value,

View File

@ -1,12 +1,13 @@
import { createContext } from "preact"; import { createContext } from "preact";
import { useEffect, useMemo, useState } from "preact/hooks"; import { useCallback, useEffect, useMemo, useState } from "preact/hooks";
import { MessageTools, type IMessage } from "../messages"; import { MessageTools, type IMessage } from "../messages";
import { useInputState } from "@common/hooks/useInputState"; import { useInputState } from "@common/hooks/useInputState";
import { type IConnection } from "../connection";
interface IContext { interface IContext {
connectionUrl: string; currentConnection: number;
availableConnections: IConnection[];
input: string; input: string;
instruct: string;
systemPrompt: string; systemPrompt: string;
lore: string; lore: string;
userPrompt: string; userPrompt: string;
@ -17,8 +18,14 @@ interface IContext {
triggerNext: boolean; triggerNext: boolean;
} }
interface IComputableContext {
connection: IConnection;
}
interface IActions { interface IActions {
setConnectionUrl: (url: string | Event) => void; setConnection: (connection: IConnection) => void;
setAvailableConnections: (connections: IConnection[]) => void;
setCurrentConnection: (connection: number) => void;
setInput: (url: string | Event) => void; setInput: (url: string | Event) => void;
setInstruct: (template: string | Event) => void; setInstruct: (template: string | Event) => void;
setLore: (lore: string | Event) => void; setLore: (lore: string | Event) => void;
@ -49,11 +56,40 @@ export enum Instruct {
MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + '</s>'}}{%- endif %}{%- endfor %}`, MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + '</s>'}}{%- endif %}{%- endfor %}`,
METHARME = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>' + message['content'] }}{% elif message['role'] == 'user' %}{{'<|user|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{'<|model|>' + message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% endif %}`,
GEMMA = `{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}`, GEMMA = `{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}`,
ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\\n\\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\\n\\n' + message['content'] + '\\n\\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\\n\\n' + message['content'] + '\\n\\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\\n\\n' }}{% endif %}`, ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\\n\\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\\n\\n' + message['content'] + '\\n\\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\\n\\n' + message['content'] + '\\n\\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\\n\\n' }}{% endif %}`,
}; };
const DEFAULT_CONTEXT: IContext = {
currentConnection: 0,
availableConnections: [{
url: 'http://localhost:5001',
instruct: Instruct.CHATML,
}],
input: '',
systemPrompt: 'You are a creative writer. Write a story based on the world description below. Story should be adult and mature; and could include swearing, violence and unfairness. Portray characters realistically and stay in the lore.',
lore: '',
userPrompt: `{% if isStart -%}
Write a novel using information above as a reference.
{%- else -%}
Continue the story forward.
{%- endif %}
{% if prompt -%}
What should happen next in your answer: {{ prompt | trim }}
{% endif %}
Remember that this story should be infinite and go forever.
Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`,
summarizePrompt: 'Shrink following text down, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
summaryEnabled: false,
bannedWords: [],
messages: [],
triggerNext: false,
};
export const saveContext = (context: IContext) => { export const saveContext = (context: IContext) => {
const contextToSave: Partial<IContext> = { ...context }; const contextToSave: Partial<IContext> = { ...context };
delete contextToSave.triggerNext; delete contextToSave.triggerNext;
@ -62,30 +98,6 @@ export const saveContext = (context: IContext) => {
} }
export const loadContext = (): IContext => { export const loadContext = (): IContext => {
const defaultContext: IContext = {
connectionUrl: 'http://localhost:5001',
input: '',
instruct: Instruct.CHATML,
systemPrompt: 'You are a creative writer. Write a story based on the world description below. Story should be adult and mature; and could include swearing, violence and unfairness. Portray characters realistically and stay in the lore.',
lore: '',
userPrompt: `{% if isStart -%}
Write a novel using information above as a reference.
{%- else -%}
Continue the story forward.
{%- endif %}
{% if prompt -%}
This is the description of what I want to happen next: {{ prompt | trim }}
{% endif %}
Remember that this story should be infinite and go forever.
Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`,
summarizePrompt: 'Shrink following text down, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
summaryEnabled: false,
bannedWords: [],
messages: [],
triggerNext: false,
};
let loadedContext: Partial<IContext> = {}; let loadedContext: Partial<IContext> = {};
try { try {
@ -95,18 +107,18 @@ Make sure to follow the world description and rules exactly. Avoid cliffhangers
} }
} catch { } } catch { }
return { ...defaultContext, ...loadedContext }; return { ...DEFAULT_CONTEXT, ...loadedContext };
} }
export type IStateContext = IContext & IActions; export type IStateContext = IContext & IActions & IComputableContext;
export const StateContext = createContext<IStateContext>({} as IStateContext); export const StateContext = createContext<IStateContext>({} as IStateContext);
export const StateContextProvider = ({ children }: { children?: any }) => { export const StateContextProvider = ({ children }: { children?: any }) => {
const loadedContext = useMemo(() => loadContext(), []); const loadedContext = useMemo(() => loadContext(), []);
const [connectionUrl, setConnectionUrl] = useInputState(loadedContext.connectionUrl); const [currentConnection, setCurrentConnection] = useState<number>(loadedContext.currentConnection);
const [availableConnections, setAvailableConnections] = useState<IConnection[]>(loadedContext.availableConnections);
const [input, setInput] = useInputState(loadedContext.input); const [input, setInput] = useInputState(loadedContext.input);
const [instruct, setInstruct] = useInputState(loadedContext.instruct);
const [lore, setLore] = useInputState(loadedContext.lore); const [lore, setLore] = useInputState(loadedContext.lore);
const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt); const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt);
const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt); const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt);
@ -115,10 +127,26 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
const [messages, setMessages] = useState(loadedContext.messages); const [messages, setMessages] = useState(loadedContext.messages);
const [summaryEnabled, setSummaryEnabled] = useState(loadedContext.summaryEnabled); const [summaryEnabled, setSummaryEnabled] = useState(loadedContext.summaryEnabled);
const connection = availableConnections[currentConnection] ?? DEFAULT_CONTEXT.availableConnections[0];
const [triggerNext, setTriggerNext] = useState(false); const [triggerNext, setTriggerNext] = useState(false);
const [instruct, setInstruct] = useInputState(connection.instruct);
const setConnection = useCallback((c: IConnection) => {
setAvailableConnections(availableConnections.map((ac, ai) => {
if (ai === currentConnection) {
return c;
} else {
return ac;
}
}));
}, [availableConnections, currentConnection]);
useEffect(() => setConnection({ ...connection, instruct }), [instruct]);
const actions: IActions = useMemo(() => ({ const actions: IActions = useMemo(() => ({
setConnectionUrl, setConnection,
setCurrentConnection,
setInput, setInput,
setInstruct, setInstruct,
setSystemPrompt, setSystemPrompt,
@ -127,7 +155,8 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
setLore, setLore,
setTriggerNext, setTriggerNext,
setSummaryEnabled, setSummaryEnabled,
setBannedWords: (words) => setBannedWords([...words]), setBannedWords: (words) => setBannedWords(words.slice()),
setAvailableConnections: (connections) => setAvailableConnections(connections.slice()),
setMessages: (newMessages) => setMessages(newMessages.slice()), setMessages: (newMessages) => setMessages(newMessages.slice()),
addMessage: (content, role, triggerNext = false) => { addMessage: (content, role, triggerNext = false) => {
@ -198,10 +227,11 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
continueMessage: () => setTriggerNext(true), continueMessage: () => setTriggerNext(true),
}), []); }), []);
const rawContext: IContext = { const rawContext: IContext & IComputableContext = {
connectionUrl, connection,
currentConnection,
availableConnections,
input, input,
instruct,
systemPrompt, systemPrompt,
lore, lore,
userPrompt, userPrompt,

View File

@ -230,7 +230,9 @@ export namespace Huggingface {
} }
export const findModelTemplate = async (modelName: string): Promise<string | null> => { export const findModelTemplate = async (modelName: string): Promise<string | null> => {
const modelKey = modelName.toLowerCase(); const modelKey = modelName.toLowerCase().trim();
if (!modelKey) return '';
let template = templateCache[modelKey] ?? null; let template = templateCache[modelKey] ?? null;
if (template) { if (template) {