From 017ef7aaa59f7f53029e7981aab22c37d1d8d36a Mon Sep 17 00:00:00 2001 From: Pabloader Date: Tue, 12 Nov 2024 13:32:35 +0000 Subject: [PATCH] AIStory: add basic horde support --- src/common/hooks/useInputCallback.ts | 16 + src/common/utils.ts | 21 +- src/games/ai/assets/style.css | 4 + src/games/ai/components/chat.tsx | 2 +- .../ai/components/header/connectionEditor.tsx | 146 ++++++++ .../ai/components/header/header.module.css | 7 + src/games/ai/components/header/header.tsx | 56 +-- src/games/ai/connection.ts | 352 ++++++++++++++++++ src/games/ai/contexts/llm.tsx | 176 +-------- src/games/ai/contexts/state.tsx | 104 ++++-- src/games/ai/huggingface.ts | 4 +- 11 files changed, 655 insertions(+), 233 deletions(-) create mode 100644 src/common/hooks/useInputCallback.ts create mode 100644 src/games/ai/components/header/connectionEditor.tsx create mode 100644 src/games/ai/connection.ts diff --git a/src/common/hooks/useInputCallback.ts b/src/common/hooks/useInputCallback.ts new file mode 100644 index 0000000..eff731e --- /dev/null +++ b/src/common/hooks/useInputCallback.ts @@ -0,0 +1,16 @@ +import { useCallback } from "preact/hooks"; + +export function useInputCallback(callback: (value: string) => T, deps: any[]): ((value: string | Event) => T) { + return useCallback((e: Event | string) => { + if (typeof e === 'string') { + return callback(e); + } else { + const { target } = e; + if (target && 'value' in target && typeof target.value === 'string') { + return callback(target.value); + } + } + + return callback(''); + }, deps); +} \ No newline at end of file diff --git a/src/common/utils.ts b/src/common/utils.ts index 6301880..3d00c8b 100644 --- a/src/common/utils.ts +++ b/src/common/utils.ts @@ -49,4 +49,23 @@ export const intHash = (seed: number, ...parts: number[]) => { h1 ^= Math.imul(h2 ^ (h2 >>> 13), 3266489909); return h1; }; -export const sinHash = (...data: number[]) => data.reduce((hash, n) => Math.sin((hash * 123.12 + n) * 756.12), 0) / 2 + 0.5; \ No newline at 
end of file +export const sinHash = (...data: number[]) => data.reduce((hash, n) => Math.sin((hash * 123.12 + n) * 756.12), 0) / 2 + 0.5; +export const throttle = function R>(func: F, ms: number): F { + let isThrottled = false; + let savedResult: R; + + const wrapper: F = function (...args: A) { + if (!isThrottled) { + savedResult = func.apply(this, args); + + isThrottled = true; + + setTimeout(function () { + isThrottled = false; + }, ms); + } + return savedResult; + } as F; + + return wrapper; +} \ No newline at end of file diff --git a/src/games/ai/assets/style.css b/src/games/ai/assets/style.css index c3b4611..bdcc6e5 100644 --- a/src/games/ai/assets/style.css +++ b/src/games/ai/assets/style.css @@ -32,6 +32,10 @@ select { outline: none; } +option, optgroup { + background-color: var(--backgroundColor); +} + textarea { resize: vertical; width: 100%; diff --git a/src/games/ai/components/chat.tsx b/src/games/ai/components/chat.tsx index 75ec7bb..f16859b 100644 --- a/src/games/ai/components/chat.tsx +++ b/src/games/ai/components/chat.tsx @@ -15,7 +15,7 @@ export const Chat = () => { const lastAssistantId = messages.findLastIndex(m => m.role === 'assistant'); useEffect(() => { - DOMTools.scrollDown(chatRef.current); + setTimeout(() => DOMTools.scrollDown(chatRef.current, false), 100); }, [messages.length, lastMessageContent]); return ( diff --git a/src/games/ai/components/header/connectionEditor.tsx b/src/games/ai/components/header/connectionEditor.tsx new file mode 100644 index 0000000..9fb6be4 --- /dev/null +++ b/src/games/ai/components/header/connectionEditor.tsx @@ -0,0 +1,146 @@ +import { useCallback, useContext, useEffect, useMemo, useState } from 'preact/hooks'; + +import styles from './header.module.css'; +import { Connection, HORDE_ANON_KEY, isHordeConnection, isKoboldConnection, type IConnection, type IHordeModel } from '../../connection'; +import { Instruct, StateContext } from '../../contexts/state'; +import { useInputState } from 
'@common/hooks/useInputState'; +import { useInputCallback } from '@common/hooks/useInputCallback'; +import { Huggingface } from '../../huggingface'; + +interface IProps { + connection: IConnection; + setConnection: (c: IConnection) => void; +} + +export const ConnectionEditor = ({ connection, setConnection }: IProps) => { + const [connectionUrl, setConnectionUrl] = useInputState(''); + const [apiKey, setApiKey] = useInputState(HORDE_ANON_KEY); + const [modelName, setModelName] = useInputState(''); + + const [modelTemplate, setModelTemplate] = useInputState(Instruct.CHATML); + const [hordeModels, setHordeModels] = useState([]); + const [contextLength, setContextLength] = useState(0); + + const backendType = useMemo(() => { + if (isKoboldConnection(connection)) return 'kobold'; + if (isHordeConnection(connection)) return 'horde'; + return 'unknown'; + }, [connection]); + + const urlValid = useMemo(() => contextLength > 0, [contextLength]); + + useEffect(() => { + if (isKoboldConnection(connection)) { + setConnectionUrl(connection.url); + } else if (isHordeConnection(connection)) { + setModelName(connection.model); + setApiKey(connection.apiKey || HORDE_ANON_KEY); + + Connection.getHordeModels() + .then(m => setHordeModels(Array.from(m.values()).sort((a, b) => a.name.localeCompare(b.name)))); + } + + Connection.getContextLength(connection).then(setContextLength); + Connection.getModelName(connection).then(setModelName); + }, [connection]); + + useEffect(() => { + if (modelName) { + Huggingface.findModelTemplate(modelName) + .then(template => { + if (template) { + setModelTemplate(template); + } + }); + } + }, [modelName]); + + const setInstruct = useInputCallback((instruct) => { + setConnection({ ...connection, instruct }); + }, [connection, setConnection]); + + const setBackendType = useInputCallback((type) => { + if (type === 'kobold') { + setConnection({ + instruct: connection.instruct, + url: connectionUrl, + }); + } else if (type === 'horde') { + setConnection({ 
+ instruct: connection.instruct, + apiKey, + model: modelName, + }); + } + }, [connection, setConnection, connectionUrl, apiKey, modelName]); + + const handleBlurUrl = useCallback(() => { + const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i; + const url = connectionUrl.replace(regex, 'http$1://$2'); + + setConnection({ + instruct: connection.instruct, + url, + }); + }, [connection, connectionUrl, setConnection]); + + const handleBlurHorde = useCallback(() => { + setConnection({ + instruct: connection.instruct, + apiKey, + model: modelName, + }); + }, [connection, apiKey, modelName, setConnection]); + + return ( +
+ + + {isKoboldConnection(connection) && } + {isHordeConnection(connection) && <> + + + + } +
+ ); +}; diff --git a/src/games/ai/components/header/header.module.css b/src/games/ai/components/header/header.module.css index b05142b..e28a5bc 100644 --- a/src/games/ai/components/header/header.module.css +++ b/src/games/ai/components/header/header.module.css @@ -44,4 +44,11 @@ textarea { overflow: hidden; } +} + +.connectionEditor { + display: flex; + flex-direction: row; + gap: 8px; + flex-wrap: wrap; } \ No newline at end of file diff --git a/src/games/ai/components/header/header.tsx b/src/games/ai/components/header/header.tsx index 2f2bfae..b4a9486 100644 --- a/src/games/ai/components/header/header.tsx +++ b/src/games/ai/components/header/header.tsx @@ -2,35 +2,29 @@ import { useCallback, useContext, useEffect, useMemo, useState } from "preact/ho import { useBool } from "@common/hooks/useBool"; import { Modal } from "@common/components/modal/modal"; -import { Instruct, StateContext } from "../../contexts/state"; +import { StateContext } from "../../contexts/state"; import { LLMContext } from "../../contexts/llm"; import { MiniChat } from "../minichat/minichat"; import { AutoTextarea } from "../autoTextarea"; +import { Ace } from "../ace"; +import { ConnectionEditor } from "./connectionEditor"; import styles from './header.module.css'; -import { Ace } from "../ace"; export const Header = () => { - const { modelName, modelTemplate, contextLength, promptTokens, blockConnection } = useContext(LLMContext); + const { contextLength, promptTokens, modelName } = useContext(LLMContext); const { - messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct, summarizePrompt, summaryEnabled, - setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct, setSummarizePrompt, setSummaryEnabled, + messages, connection, systemPrompt, lore, userPrompt, bannedWords, summarizePrompt, summaryEnabled, + setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct, setSummarizePrompt, setSummaryEnabled, 
setConnection, } = useContext(StateContext); + const connectionsOpen = useBool(); const loreOpen = useBool(); const promptsOpen = useBool(); const genparamsOpen = useBool(); const assistantOpen = useBool(); const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]); - const urlValid = useMemo(() => contextLength > 0, [contextLength]); - - const handleBlurUrl = useCallback(() => { - const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i - const normalizedConnectionUrl = connectionUrl.replace(regex, 'http$1://$2'); - setConnectionUrl(normalizedConnectionUrl); - blockConnection.setFalse(); - }, [connectionUrl, setConnectionUrl, blockConnection]); const handleAssistantAddSwipe = useCallback((answer: string) => { const index = messages.findLastIndex(m => m.role === 'assistant'); @@ -61,29 +55,13 @@ export const Header = () => { return (
- - +
+ +
- {promptTokens} / {contextLength} + {modelName} - {promptTokens} / {contextLength}
@@ -102,6 +80,10 @@ export const Header = () => { ❓
+ +

Connection settings

+ +

Lore Editor

{

Summary template


Instruct template

- +
( + obj != null && typeof obj === 'object' && 'url' in obj && typeof obj.url === 'string' +); + +export const isHordeConnection = (obj: unknown): obj is IHordeConnection => ( + obj != null && typeof obj === 'object' && 'model' in obj && typeof obj.model === 'string' +); + +export type IConnection = IKoboldConnection | IHordeConnection; + +interface IHordeWorker { + id: string; + models: string[]; + flagged: boolean; + online: boolean; + maintenance_mode: boolean; + max_context_length: number; + max_length: number; + performance: string; +} + +export interface IHordeModel { + name: string; + hordeNames: string[]; + maxLength: number; + maxContext: number; + workers: string[]; +} + +interface IHordeResult { + faulted: boolean; + done: boolean; + finished: number; + generations?: { + text: string; + }[]; +} + +const DEFAULT_GENERATION_SETTINGS = { + temperature: 0.8, + min_p: 0.1, + rep_pen: 1.08, + rep_pen_range: -1, + rep_pen_slope: 0.7, + top_k: 100, + top_p: 0.92, + banned_tokens: ['anticipat'], + max_length: 300, + trim_stop: true, + stop_sequence: ['[INST]', '[/INST]', '', '<|'], + dry_allowed_length: 5, + dry_multiplier: 0.8, + dry_base: 1, + dry_sequence_breakers: ["\n", ":", "\"", "*"], + dry_penalty_last_n: 0 +} + +const MIN_PERFORMANCE = 5.0; +const MIN_WORKER_CONTEXT = 8192; +const MAX_HORDE_LENGTH = 512; +const MAX_HORDE_CONTEXT = 32000; +export const HORDE_ANON_KEY = '0000000000'; + +export const normalizeModel = (model: string) => { + let currentModel = model.split(/[\\\/]/).at(-1); + currentModel = currentModel.split('::').at(0); + let normalizedModel: string; + + do { + normalizedModel = currentModel; + + currentModel = currentModel + .replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. 
-32k + .replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name + .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc + .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size + .replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw + .replace(/[ ._-]f(p|loat)?(8|16|32)/i, '') + .replace(/^(debug-?)+/i, '') + .trim(); + } while (normalizedModel !== currentModel); + + return normalizedModel + .replace(/[ _-]+/ig, '-') + .replace(/\.{2,}/, '-') + .replace(/[ ._-]+$/ig, '') + .trim(); +} + +export const approximateTokens = (prompt: string): number => + Math.round(prompt.split(/\s+/).length * 0.75); + +export type IGenerationSettings = Partial; + +export namespace Connection { + const AIHORDE = 'https://aihorde.net'; + + async function* generateKobold(url: string, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator { + const sse = new SSE(`${url}/api/extra/generate/stream`, { + payload: JSON.stringify({ + ...DEFAULT_GENERATION_SETTINGS, + ...extraSettings, + prompt, + }), + }); + + const messages: string[] = []; + const messageLock = new Lock(); + let end = false; + + sse.addEventListener('message', (e) => { + if (e.data) { + { + const { token, finish_reason } = JSON.parse(e.data); + messages.push(token); + + if (finish_reason && finish_reason !== 'null') { + end = true; + } + } + } + messageLock.release(); + }); + + const handleEnd = () => { + end = true; + messageLock.release(); + }; + + sse.addEventListener('error', handleEnd); + sse.addEventListener('abort', handleEnd); + sse.addEventListener('readystatechange', (e) => { + if (e.readyState === SSE.CLOSED) handleEnd(); + }); + + while (!end || messages.length) { + while (messages.length > 0) { + const message = messages.shift(); + if (message != null) { + try { + yield message; + } catch { } + } + } + if (!end) { + await messageLock.wait(); + } + } + + sse.close(); + } + + async function generateHorde(connection: 
Omit, prompt: string, extraSettings: IGenerationSettings = {}): Promise { + const models = await getHordeModels(); + const model = models.get(connection.model); + if (model) { + let maxLength = Math.min(model.maxLength, DEFAULT_GENERATION_SETTINGS.max_length); + if (extraSettings.max_length && extraSettings.max_length < maxLength) { + maxLength = extraSettings.max_length; + } + const requestData = { + prompt, + params: { + ...DEFAULT_GENERATION_SETTINGS, + ...extraSettings, + n: 1, + max_context_length: model.maxContext, + max_length: maxLength, + rep_pen_range: Math.min(model.maxContext, 4096), + }, + models: model.hordeNames, + workers: model.workers, + }; + + const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, { + method: 'POST', + body: JSON.stringify(requestData), + headers: { + 'Content-Type': 'application/json', + apikey: connection.apiKey || HORDE_ANON_KEY, + }, + }); + + if (!generateResponse.ok || Math.floor(generateResponse.status / 100) !== 2) { + throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`); + } + + const { id } = await generateResponse.json() as { id: string }; + const deleteRequest = () => fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method: 'DELETE' }) + .catch(e => console.error('Error deleting request', e)); + + while (true) { + await delay(2500); + + const retrieveResponse = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`); + if (!retrieveResponse.ok || Math.floor(retrieveResponse.status / 100) !== 2) { + deleteRequest(); + throw new Error(`Error retrieving generation: ${retrieveResponse.statusText}: ${await retrieveResponse.text()}`); + } + + const result: IHordeResult = await retrieveResponse.json(); + + if (result.done && result.generations?.length === 1) { + const { text } = result.generations[0]; + + return text; + } + } + } + + throw new Error(`Model ${connection.model} is offline`); + } + + export async function* 
generate(connection: IConnection, prompt: string, extraSettings: IGenerationSettings = {}) { + if (isKoboldConnection(connection)) { + yield* generateKobold(connection.url, prompt, extraSettings); + } else if (isHordeConnection(connection)) { + yield await generateHorde(connection, prompt, extraSettings); + } + } + + async function requestHordeModels(): Promise> { + try { + const response = await fetch(`${AIHORDE}/api/v2/workers?type=text`); + if (response.ok) { + const workers: IHordeWorker[] = await response.json(); + const goodWorkers = workers.filter(w => + w.online + && !w.maintenance_mode + && !w.flagged + && w.max_context_length >= MIN_WORKER_CONTEXT + && parseFloat(w.performance) >= MIN_PERFORMANCE + ); + + const models = new Map(); + + for (const worker of goodWorkers) { + for (const modelName of worker.models) { + const normName = normalizeModel(modelName.toLowerCase()); + let model = models.get(normName); + if (!model) { + model = { + hordeNames: [], + maxContext: MAX_HORDE_CONTEXT, + maxLength: MAX_HORDE_LENGTH, + name: normName, + workers: [] + } + } + + if (!model.hordeNames.includes(modelName)) { + model.hordeNames.push(modelName); + } + if (!model.workers.includes(worker.id)) { + model.workers.push(worker.id); + } + + model.maxContext = Math.min(model.maxContext, worker.max_context_length); + model.maxLength = Math.min(model.maxLength, worker.max_length); + + models.set(normName, model); + } + } + + return models; + } + } catch (e) { + console.error(e); + } + + return new Map(); + }; + + export const getHordeModels = throttle(requestHordeModels, 10000); + + export async function getModelName(connection: IConnection): Promise { + if (isKoboldConnection(connection)) { + try { + const response = await fetch(`${connection.url}/api/v1/model`); + if (response.ok) { + const { result } = await response.json(); + return result; + } + } catch (e) { + console.log('Error getting max tokens', e); + } + } else if (isHordeConnection(connection)) { + return 
connection.model; + } + + return ''; + } + + export async function getContextLength(connection: IConnection): Promise<number> { + if (isKoboldConnection(connection)) { + try { + const response = await fetch(`${connection.url}/api/extra/true_max_context_length`); + if (response.ok) { + const { value } = await response.json(); + return value; + } + } catch (e) { + console.log('Error getting max tokens', e); + } + } else if (isHordeConnection(connection)) { + const models = await getHordeModels(); + const model = models.get(connection.model); + if (model) { + return model.maxContext; + } + } + + return 0; + } + + export async function countTokens(connection: IConnection, prompt: string) { + if (isKoboldConnection(connection)) { + try { + const response = await fetch(`${connection.url}/api/extra/tokencount`, { + body: JSON.stringify({ prompt }), + headers: { 'Content-Type': 'application/json' }, + method: 'POST', + }); + if (response.ok) { + const { value } = await response.json(); + return value; + } + } catch (e) { + console.log('Error counting tokens', e); + } + } + + return approximateTokens(prompt); + } +} \ No newline at end of file diff --git a/src/games/ai/contexts/llm.tsx b/src/games/ai/contexts/llm.tsx index 9eb7294..4a7eaae 100644 --- a/src/games/ai/contexts/llm.tsx +++ b/src/games/ai/contexts/llm.tsx @@ -7,6 +7,7 @@ import { Instruct, StateContext } from "./state"; import { useBool } from "@common/hooks/useBool"; import { Template } from "@huggingface/jinja"; import { Huggingface } from "../huggingface"; +import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection"; interface ICompileArgs { keepUsers?: number; @@ -29,29 +30,8 @@ interface IContext { contextLength: number; } -const DEFAULT_GENERATION_SETTINGS = { - temperature: 0.8, - min_p: 0.1, - rep_pen: 1.08, - rep_pen_range: -1, - rep_pen_slope: 0.7, - top_k: 100, - top_p: 0.92, - banned_tokens: [], - max_length: 300, - trim_stop: true, - stop_sequence: ['[INST]', 
'[/INST]', '', '<|'], - dry_allowed_length: 5, - dry_multiplier: 0.8, - dry_base: 1, - dry_sequence_breakers: ["\n", ":", "\"", "*"], - dry_penalty_last_n: 0 -} - const MESSAGES_TO_KEEP = 10; -type IGenerationSettings = Partial; - interface IActions { compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise; generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator; @@ -60,32 +40,6 @@ interface IActions { } export type ILLMContext = IContext & IActions; -export const normalizeModel = (model: string) => { - let currentModel = model.split(/[\\\/]/).at(-1); - currentModel = currentModel.split('::').at(0); - let normalizedModel: string; - - do { - normalizedModel = currentModel; - - currentModel = currentModel - .replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k - .replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name - .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc - .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size - .replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw - .replace(/[ ._-]f(p|loat)?(8|16|32)/i, '') - .replace(/^(debug-?)+/i, '') - .trim(); - } while (normalizedModel !== currentModel); - - return normalizedModel - .replace(/[ _-]+/ig, '-') - .replace(/\.{2,}/, '-') - .replace(/[ ._-]+$/ig, '') - .trim(); -} - export const LLMContext = createContext({} as ILLMContext); const processing = { @@ -95,7 +49,7 @@ const processing = { export const LLMContextProvider = ({ children }: { children?: any }) => { const { - connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct, summarizePrompt, summaryEnabled, + connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled, setTriggerNext, addMessage, editMessage, editSummary, setInstruct, } = useContext(StateContext); @@ -118,38 +72,18 @@ export const LLMContextProvider = ({ 
children }: { children?: any }) => { }, [userPrompt]); const getContextLength = useCallback(async () => { - if (!connectionUrl || blockConnection.value) { + if (!connection || blockConnection.value) { return 0; } - try { - const response = await fetch(`${connectionUrl}/api/extra/true_max_context_length`); - if (response.ok) { - const { value } = await response.json(); - return value; - } - } catch (e) { - console.log('Error getting max tokens', e); - } - - return 0; - }, [connectionUrl, blockConnection.value]); + return Connection.getContextLength(connection); + }, [connection, blockConnection.value]); const getModelName = useCallback(async () => { - if (!connectionUrl || blockConnection.value) { + if (!connection || blockConnection.value) { return ''; } - try { - const response = await fetch(`${connectionUrl}/api/v1/model`); - if (response.ok) { - const { result } = await response.json(); - return result; - } - } catch (e) { - console.log('Error getting max tokens', e); - } - - return ''; - }, [connectionUrl, blockConnection.value]); + return Connection.getModelName(connection); + }, [connection, blockConnection.value]); const actions: IActions = useMemo(() => ({ compilePrompt: async (messages, { keepUsers } = {}) => { @@ -236,7 +170,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => { templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`; - const prompt = Huggingface.applyChatTemplate(instruct, templateMessages); + const prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages); return { prompt, isContinue, @@ -244,100 +178,30 @@ export const LLMContextProvider = ({ children }: { children?: any }) => { }; }, generate: async function* (prompt, extraSettings = {}) { - if (!connectionUrl) { - return; - } - try { generating.setTrue(); console.log('[LLM.generate]', prompt); - const sse = new SSE(`${connectionUrl}/api/extra/generate/stream`, { - payload: JSON.stringify({ - ...DEFAULT_GENERATION_SETTINGS, 
- banned_tokens: bannedWords.filter(w => w.trim()), - ...extraSettings, - prompt, - }), + yield* Connection.generate(connection, prompt, { + ...extraSettings, + banned_tokens: bannedWords.filter(w => w.trim()), }); - - const messages: string[] = []; - const messageLock = new Lock(); - let end = false; - - sse.addEventListener('message', (e) => { - if (e.data) { - { - const { token, finish_reason } = JSON.parse(e.data); - messages.push(token); - - if (finish_reason && finish_reason !== 'null') { - end = true; - } - } - } - messageLock.release(); - }); - - const handleEnd = () => { - end = true; - messageLock.release(); - }; - - sse.addEventListener('error', handleEnd); - sse.addEventListener('abort', handleEnd); - sse.addEventListener('readystatechange', (e) => { - if (e.readyState === SSE.CLOSED) handleEnd(); - }); - - while (!end || messages.length) { - while (messages.length > 0) { - const message = messages.shift(); - if (message != null) { - try { - yield message; - } catch { } - } - } - if (!end) { - await messageLock.wait(); - } - } - - sse.close(); } finally { generating.setFalse(); } }, summarize: async (message) => { const content = Huggingface.applyTemplate(summarizePrompt, { message }); - const prompt = Huggingface.applyChatTemplate(instruct, [{ role: 'user', content }]); + const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]); const tokens = await Array.fromAsync(actions.generate(prompt)); return MessageTools.trimSentence(tokens.join('')); }, countTokens: async (prompt) => { - if (!connectionUrl) { - return 0; - } - try { - const response = await fetch(`${connectionUrl}/api/extra/tokencount`, { - body: JSON.stringify({ prompt }), - headers: { 'Content-Type': 'applicarion/json' }, - method: 'POST', - }); - if (response.ok) { - const { value } = await response.json(); - return value; - } - } catch (e) { - console.log('Error counting tokens', e); - } - - return 0; + return await Connection.countTokens(connection, 
prompt); }, - }), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct, summarizePrompt]); + }), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]); useEffect(() => void (async () => { if (triggerNext && !generating.value) { @@ -356,7 +220,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => { editSummary(messageId, 'Generating...'); for await (const chunk of actions.generate(prompt)) { text += chunk; - setPromptTokens(promptTokens + Math.round(text.length * 0.25)); + setPromptTokens(promptTokens + approximateTokens(text)); editMessage(messageId, text.trim()); } @@ -397,7 +261,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => { getContextLength().then(setContextLength); getModelName().then(normalizeModel).then(setModelName); } - }, [connectionUrl, blockConnection.value]); + }, [connection, blockConnection.value]); useEffect(() => { setModelTemplate(''); @@ -431,16 +295,16 @@ export const LLMContextProvider = ({ children }: { children?: any }) => { useEffect(() => { calculateTokens(); - }, [messages, connectionUrl, blockConnection.value, instruct, /* systemPrompt, lore, userPrompt TODO debounce*/]); + }, [messages, connection, blockConnection.value, /* systemPrompt, lore, userPrompt TODO debounce*/]); useEffect(() => { try { - const hasTools = Huggingface.testToolCalls(instruct); + const hasTools = Huggingface.testToolCalls(connection.instruct); setHasToolCalls(hasTools); } catch { setHasToolCalls(false); } - }, [instruct]); + }, [connection.instruct]); const rawContext: IContext = { generating: generating.value, diff --git a/src/games/ai/contexts/state.tsx b/src/games/ai/contexts/state.tsx index 958a8b7..d7a076b 100644 --- a/src/games/ai/contexts/state.tsx +++ b/src/games/ai/contexts/state.tsx @@ -1,12 +1,13 @@ import { createContext } from "preact"; -import { useEffect, useMemo, useState } from "preact/hooks"; +import { useCallback, useEffect, useMemo, 
useState } from "preact/hooks"; import { MessageTools, type IMessage } from "../messages"; import { useInputState } from "@common/hooks/useInputState"; +import { type IConnection } from "../connection"; interface IContext { - connectionUrl: string; + currentConnection: number; + availableConnections: IConnection[]; input: string; - instruct: string; systemPrompt: string; lore: string; userPrompt: string; @@ -17,8 +18,14 @@ interface IContext { triggerNext: boolean; } +interface IComputableContext { + connection: IConnection; +} + interface IActions { - setConnectionUrl: (url: string | Event) => void; + setConnection: (connection: IConnection) => void; + setAvailableConnections: (connections: IConnection[]) => void; + setCurrentConnection: (connection: number) => void; setInput: (url: string | Event) => void; setInstruct: (template: string | Event) => void; setLore: (lore: string | Event) => void; @@ -49,11 +56,40 @@ export enum Instruct { MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + ''}}{%- endif %}{%- endfor %}`, + METHARME = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>' + message['content'] }}{% elif message['role'] == 'user' %}{{'<|user|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{'<|model|>' + message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% endif %}`, + GEMMA = `{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 
'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}`, ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\\n\\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\\n\\n' + message['content'] + '\\n\\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\\n\\n' + message['content'] + '\\n\\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\\n\\n' }}{% endif %}`, }; +const DEFAULT_CONTEXT: IContext = { + currentConnection: 0, + availableConnections: [{ + url: 'http://localhost:5001', + instruct: Instruct.CHATML, + }], + input: '', + systemPrompt: 'You are a creative writer. Write a story based on the world description below. Story should be adult and mature; and could include swearing, violence and unfairness. Portray characters realistically and stay in the lore.', + lore: '', + userPrompt: `{% if isStart -%} +Write a novel using information above as a reference. +{%- else -%} +Continue the story forward. +{%- endif %} + +{% if prompt -%} +What should happen next in your answer: {{ prompt | trim }} +{% endif %} +Remember that this story should be infinite and go forever. +Make sure to follow the world description and rules exactly. 
Avoid cliffhangers and pauses, be creative.`, + summarizePrompt: 'Shrink following text down, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.', + summaryEnabled: false, + bannedWords: [], + messages: [], + triggerNext: false, +}; + export const saveContext = (context: IContext) => { const contextToSave: Partial = { ...context }; delete contextToSave.triggerNext; @@ -62,30 +98,6 @@ export const saveContext = (context: IContext) => { } export const loadContext = (): IContext => { - const defaultContext: IContext = { - connectionUrl: 'http://localhost:5001', - input: '', - instruct: Instruct.CHATML, - systemPrompt: 'You are a creative writer. Write a story based on the world description below. Story should be adult and mature; and could include swearing, violence and unfairness. Portray characters realistically and stay in the lore.', - lore: '', - userPrompt: `{% if isStart -%} - Write a novel using information above as a reference. -{%- else -%} - Continue the story forward. -{%- endif %} - -{% if prompt -%} - This is the description of what I want to happen next: {{ prompt | trim }} -{% endif %} -Remember that this story should be infinite and go forever. -Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`, - summarizePrompt: 'Shrink following text down, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.', - summaryEnabled: false, - bannedWords: [], - messages: [], - triggerNext: false, - }; - let loadedContext: Partial = {}; try { @@ -95,18 +107,18 @@ Make sure to follow the world description and rules exactly. 
Avoid cliffhangers } } catch { } - return { ...defaultContext, ...loadedContext }; + return { ...DEFAULT_CONTEXT, ...loadedContext }; } -export type IStateContext = IContext & IActions; +export type IStateContext = IContext & IActions & IComputableContext; export const StateContext = createContext({} as IStateContext); export const StateContextProvider = ({ children }: { children?: any }) => { const loadedContext = useMemo(() => loadContext(), []); - const [connectionUrl, setConnectionUrl] = useInputState(loadedContext.connectionUrl); + const [currentConnection, setCurrentConnection] = useState(loadedContext.currentConnection); + const [availableConnections, setAvailableConnections] = useState(loadedContext.availableConnections); const [input, setInput] = useInputState(loadedContext.input); - const [instruct, setInstruct] = useInputState(loadedContext.instruct); const [lore, setLore] = useInputState(loadedContext.lore); const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt); const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt); @@ -115,10 +127,26 @@ export const StateContextProvider = ({ children }: { children?: any }) => { const [messages, setMessages] = useState(loadedContext.messages); const [summaryEnabled, setSummaryEnabled] = useState(loadedContext.summaryEnabled); + const connection = availableConnections[currentConnection] ?? 
DEFAULT_CONTEXT.availableConnections[0]; + const [triggerNext, setTriggerNext] = useState(false); + const [instruct, setInstruct] = useInputState(connection.instruct); + + const setConnection = useCallback((c: IConnection) => { + setAvailableConnections(availableConnections.map((ac, ai) => { + if (ai === currentConnection) { + return c; + } else { + return ac; + } + })); + }, [availableConnections, currentConnection]); + + useEffect(() => setConnection({ ...connection, instruct }), [instruct]); const actions: IActions = useMemo(() => ({ - setConnectionUrl, + setConnection, + setCurrentConnection, setInput, setInstruct, setSystemPrompt, @@ -127,7 +155,8 @@ export const StateContextProvider = ({ children }: { children?: any }) => { setLore, setTriggerNext, setSummaryEnabled, - setBannedWords: (words) => setBannedWords([...words]), + setBannedWords: (words) => setBannedWords(words.slice()), + setAvailableConnections: (connections) => setAvailableConnections(connections.slice()), setMessages: (newMessages) => setMessages(newMessages.slice()), addMessage: (content, role, triggerNext = false) => { @@ -198,10 +227,11 @@ export const StateContextProvider = ({ children }: { children?: any }) => { continueMessage: () => setTriggerNext(true), }), []); - const rawContext: IContext = { - connectionUrl, + const rawContext: IContext & IComputableContext = { + connection, + currentConnection, + availableConnections, input, - instruct, systemPrompt, lore, userPrompt, diff --git a/src/games/ai/huggingface.ts b/src/games/ai/huggingface.ts index 7f4ccf3..3fd6ddb 100644 --- a/src/games/ai/huggingface.ts +++ b/src/games/ai/huggingface.ts @@ -230,7 +230,9 @@ export namespace Huggingface { } export const findModelTemplate = async (modelName: string): Promise => { - const modelKey = modelName.toLowerCase(); + const modelKey = modelName.toLowerCase().trim(); + if (!modelKey) return ''; + let template = templateCache[modelKey] ?? null; if (template) {