diff --git a/bun.lockb b/bun.lockb
index d9cef44..40e1e24 100755
Binary files a/bun.lockb and b/bun.lockb differ
diff --git a/package.json b/package.json
index b2190b0..404c108 100644
--- a/package.json
+++ b/package.json
@@ -14,6 +14,7 @@
     "@inquirer/select": "2.3.10",
     "ace-builds": "1.36.3",
     "classnames": "2.5.1",
+    "delay": "6.0.0",
     "preact": "10.22.0"
   },
   "devDependencies": {
diff --git a/src/common/hooks/useAsyncEffect.ts b/src/common/hooks/useAsyncEffect.ts
new file mode 100644
index 0000000..ef80bc5
--- /dev/null
+++ b/src/common/hooks/useAsyncEffect.ts
@@ -0,0 +1,4 @@
+import { useEffect } from "preact/hooks";
+
+export const useAsyncEffect = (fx: () => any, deps: any[]) =>
+  useEffect(() => void fx(), deps);
diff --git a/src/common/utils.ts b/src/common/utils.ts
index 3d00c8b..1deed16 100644
--- a/src/common/utils.ts
+++ b/src/common/utils.ts
@@ -1,4 +1,3 @@
-export const delay = async (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
 export const nextFrame = async (): Promise => new Promise((resolve) => requestAnimationFrame(resolve));
 export const randInt = (min: number, max: number) => Math.round(min + (max - min - 1) * Math.random());
@@ -50,20 +49,30 @@ export const intHash = (seed: number, ...parts: number[]) => {
   return h1;
 };
 export const sinHash = (...data: number[]) => data.reduce((hash, n) => Math.sin((hash * 123.12 + n) * 756.12), 0) / 2 + 0.5;
-export const throttle = function R>(func: F, ms: number): F {
+export const throttle = function R>(func: F, ms: number, trailing = false): F {
   let isThrottled = false;
   let savedResult: R;
+  let savedThis: T;
+  let savedArgs: A | undefined;
   const wrapper: F = function (...args: A) {
-    if (!isThrottled) {
+    if (isThrottled) {
+      savedThis = this;
+      savedArgs = args;
+    } else {
       savedResult = func.apply(this, args);
+      savedArgs = undefined;
       isThrottled = true;
       setTimeout(function () {
         isThrottled = false;
+        if (trailing && savedArgs) {
+          savedResult = wrapper.apply(savedThis, savedArgs);
+        }
       }, ms);
     }
+    return savedResult;
   } as F;
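A usage sketch of the new trailing flag on throttle (illustration only, not part of the diff): a call that arrives during the cooldown saves its arguments and is replayed once the timer fires, so the last call is no longer dropped.

```ts
import { throttle } from "@common/utils";

// Sketch: at most one call per second, but the final value still gets through.
const report = (value: number) => console.log("progress:", value);
const throttledReport = throttle(report, 1000, true);

throttledReport(10); // runs immediately
throttledReport(50); // throttled, arguments saved
throttledReport(90); // overwrites the saved arguments, replayed after ~1s with 90
```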
diff --git a/src/games/ai/components/input.tsx b/src/games/ai/components/input.tsx
index df02a3c..70beb82 100644
--- a/src/games/ai/components/input.tsx
+++ b/src/games/ai/components/input.tsx
@@ -5,7 +5,7 @@ import { AutoTextarea } from "./autoTextarea";
 export const Input = () => {
   const { input, setInput, addMessage, continueMessage } = useContext(StateContext);
-  const { generating } = useContext(LLMContext);
+  const { generating, stopGeneration } = useContext(LLMContext);
   const handleSend = useCallback(async () => {
     if (!generating) {
@@ -29,7 +29,10 @@
   return (
- + {generating ? : }
   );
 }
\ No newline at end of file
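The JSX markup in this hunk (and in the matching minichat.tsx hunk below) did not survive extraction; only the {generating ? ... : ...} expression is left. The apparent intent is to show a stop control in place of the send control while a generation is running. A hypothetical sketch, with every name invented for illustration:

```tsx
// Hypothetical sketch only: the real markup, component, and prop names were lost from the diff.
const SendOrStop = ({ generating, onSend, onStop }: {
  generating: boolean;
  onSend: () => void;
  onStop: () => void;
}) => (
  generating
    ? <button onClick={onStop}>Stop</button>
    : <button onClick={onSend}>Send</button>
);
```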
diff --git a/src/games/ai/components/minichat/minichat.tsx b/src/games/ai/components/minichat/minichat.tsx
index d096057..7e5b59f 100644
--- a/src/games/ai/components/minichat/minichat.tsx
+++ b/src/games/ai/components/minichat/minichat.tsx
@@ -16,7 +16,7 @@ interface IProps {
 }
 export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) => {
-  const { generating, generate, compilePrompt } = useContext(LLMContext);
+  const { generating, stopGeneration, generate, compilePrompt } = useContext(LLMContext);
   const [messages, setMessages] = useState([]);
   const ref = useRef(null);
@@ -105,9 +105,10 @@
- + {generating ? : }
diff --git a/src/games/ai/connection.ts b/src/games/ai/connection.ts
index 0f6daa5..dd111c8 100644
--- a/src/games/ai/connection.ts
+++ b/src/games/ai/connection.ts
@@ -1,6 +1,7 @@
 import Lock from "@common/lock";
 import SSE from "@common/sse";
-import { delay, throttle } from "@common/utils";
+import { throttle } from "@common/utils";
+import delay, { clearDelay } from "delay";
 interface IBaseConnection {
   instruct: string;
@@ -72,7 +73,7 @@
   dry_penalty_last_n: 0
 }
-const MIN_PERFORMANCE = 5.0;
+const MIN_PERFORMANCE = 2.0;
 const MIN_WORKER_CONTEXT = 8192;
 const MAX_HORDE_LENGTH = 512;
 const MAX_HORDE_CONTEXT = 32000;
@@ -88,7 +89,7 @@ export const normalizeModel = (model: string) => {
   currentModel = currentModel
     .replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
-    .replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name
+    .replace(/[ ._-](gptq|awq|exl2?|imat|i\d|h\d)/i, '') // remove quant name
     .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
     .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
     .replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
@@ -104,14 +105,15 @@
     .trim();
 }
-export const approximateTokens = (prompt: string): number =>
-  Math.round(prompt.split(/\s+/).length * 0.75);
+export const approximateTokens = (prompt: string): number => prompt.split(/[^a-z0-9]+/i).length;
 export type IGenerationSettings = Partial;
 export namespace Connection {
   const AIHORDE = 'https://aihorde.net';
+  let abortController = new AbortController();
   async function* generateKobold(url: string, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator {
     const sse = new SSE(`${url}/api/extra/generate/stream`, {
       payload: JSON.stringify({
@@ -144,12 +146,14 @@
       messageLock.release();
     };
+    abortController.signal.addEventListener('abort', handleEnd);
     sse.addEventListener('error', handleEnd);
     sse.addEventListener('abort', handleEnd);
     sse.addEventListener('readystatechange', (e) => {
       if (e.readyState === SSE.CLOSED) handleEnd();
     });
+
     while (!end || messages.length) {
       while (messages.length > 0) {
         const message = messages.shift();
@@ -189,6 +193,8 @@
       workers: model.workers,
     };
+    const { signal } = abortController;
+
     const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, {
       method: 'POST',
       body: JSON.stringify(requestData),
       headers: {
@@ -196,31 +202,44 @@
         'Content-Type': 'application/json',
         apikey: connection.apiKey || HORDE_ANON_KEY,
       },
+      signal,
     });
-    if (!generateResponse.ok || Math.floor(generateResponse.status / 100) !== 2) {
+    if (!generateResponse.ok || generateResponse.status >= 400) {
       throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`);
     }
     const { id } = await generateResponse.json() as { id: string };
-    const deleteRequest = () => fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method: 'DELETE' }) .catch(e => console.error('Error deleting request', e));
+    const request = async (method = 'GET'): Promise => {
+      const response = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method });
+      if (response.ok && response.status < 400) {
+        const result: IHordeResult = await response.json();
+        if (result.generations?.length === 1) {
+          const { text } = result.generations[0];
-    while (true) {
-      await delay(2500);
-
-      const retrieveResponse = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`);
-      if (!retrieveResponse.ok || Math.floor(retrieveResponse.status / 100) !== 2) {
-        deleteRequest();
-        throw new Error(`Error retrieving generation: ${retrieveResponse.statusText}: ${await retrieveResponse.text()}`);
+          return text;
+        }
+      } else {
+        throw new Error(await response.text());
       }
-      const result: IHordeResult = await retrieveResponse.json();
+      return null;
+    };
-    if (result.done && result.generations?.length === 1) {
-      const { text } = result.generations[0];
+    const deleteRequest = async () => (await request('DELETE')) ?? '';
-      return text;
+    while (true) {
+      try {
+        await delay(2500, { signal });
+
+        const text = await request();
+
+        if (text) {
+          return text;
+        }
+      } catch (e) {
+        console.error('Error in horde generation:', e);
+        return deleteRequest();
       }
     }
   }
@@ -236,15 +255,20 @@
     }
   }
+  export function stopGeneration() {
+    abortController.abort();
+    abortController = new AbortController(); // refresh
+  }
+
   async function requestHordeModels(): Promise> {
     try {
       const response = await fetch(`${AIHORDE}/api/v2/workers?type=text`);
       if (response.ok) {
         const workers: IHordeWorker[] = await response.json();
-        const goodWorkers = workers.filter(w =>
-          w.online
-          && !w.maintenance_mode
-          && !w.flagged
+        const goodWorkers = workers.filter(w =>
+          w.online
+          && !w.maintenance_mode
+          && !w.flagged
           && w.max_context_length >= MIN_WORKER_CONTEXT
           && parseFloat(w.performance) >= MIN_PERFORMANCE
         );
@@ -299,7 +323,7 @@
         return result;
       }
     } catch (e) {
-      console.log('Error getting max tokens', e);
+      console.error('Error getting max tokens', e);
     }
   } else if (isHordeConnection(connection)) {
     return connection.model;
@@ -317,7 +341,7 @@
         return value;
       }
     } catch (e) {
-      console.log('Error getting max tokens', e);
+      console.error('Error getting max tokens', e);
     }
   } else if (isHordeConnection(connection)) {
     const models = await getHordeModels();
@@ -343,7 +367,7 @@
         return value;
       }
     } catch (e) {
-      console.log('Error counting tokens', e);
+      console.error('Error counting tokens', e);
     }
   }
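The cancellation scheme in connection.ts above: a module-level AbortController whose signal is handed both to fetch and to the delay package, so stopGeneration() interrupts the Horde polling loop as well as in-flight requests. A minimal standalone sketch of the same pattern (the names here are simplified and, apart from stopGeneration, are not the module's actual API):

```ts
import delay from "delay";

let abortController = new AbortController();

export function stopGeneration() {
  abortController.abort();                  // rejects pending delay() calls and aborts fetches
  abortController = new AbortController();  // fresh controller for the next generation
}

// Sketch of a polling loop that the abort signal can break out of.
async function poll(url: string): Promise<string> {
  const { signal } = abortController;       // captured once, so a later reset does not affect this run
  while (true) {
    await delay(2500, { signal });          // throws AbortError when stopGeneration() runs
    const response = await fetch(url, { signal });
    if (response.ok) return response.text();
  }
}
```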
diff --git a/src/games/ai/contexts/llm.tsx b/src/games/ai/contexts/llm.tsx
index 4a7eaae..e9f85a9 100644
--- a/src/games/ai/contexts/llm.tsx
+++ b/src/games/ai/contexts/llm.tsx
@@ -1,13 +1,13 @@
-import Lock from "@common/lock";
-import SSE from "@common/sse";
 import { createContext } from "preact";
 import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
 import { MessageTools, type IMessage } from "../messages";
-import { Instruct, StateContext } from "./state";
+import { StateContext } from "./state";
 import { useBool } from "@common/hooks/useBool";
 import { Template } from "@huggingface/jinja";
 import { Huggingface } from "../huggingface";
 import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection";
+import { throttle } from "@common/utils";
+import { useAsyncEffect } from "@common/hooks/useAsyncEffect";
 interface ICompileArgs {
   keepUsers?: number;
@@ -22,9 +22,7 @@ interface ICompiledPrompt {
 interface IContext {
   generating: boolean;
-  blockConnection: ReturnType;
   modelName: string;
-  modelTemplate: string;
   hasToolCalls: boolean;
   promptTokens: number;
   contextLength: number;
@@ -35,6 +33,7 @@ const MESSAGES_TO_KEEP = 10;
 interface IActions {
   compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise;
   generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator;
+  stopGeneration: () => void;
   summarize: (content: string) => Promise;
   countTokens: (prompt: string) => Promise;
 }
@@ -50,15 +49,13 @@ const processing = {
 export const LLMContextProvider = ({ children }: { children?: any }) => {
   const { connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
-    setTriggerNext, addMessage, editMessage, editSummary, setInstruct,
+    setTriggerNext, addMessage, editMessage, editSummary,
   } = useContext(StateContext);
   const generating = useBool(false);
-  const blockConnection = useBool(false);
   const [promptTokens, setPromptTokens] = useState(0);
   const [contextLength, setContextLength] = useState(0);
   const [modelName, setModelName] = useState('');
-  const [modelTemplate, setModelTemplate] = useState('');
   const [hasToolCalls, setHasToolCalls] = useState(false);
   const userPromptTemplate = useMemo(() => {
@@ -71,20 +68,6 @@
     }
   }, [userPrompt]);
-  const getContextLength = useCallback(async () => {
-    if (!connection || blockConnection.value) {
-      return 0;
-    }
-    return Connection.getContextLength(connection);
-  }, [connection, blockConnection.value]);
-
-  const getModelName = useCallback(async () => {
-    if (!connection || blockConnection.value) {
-      return '';
-    }
-    return Connection.getModelName(connection);
-  }, [connection, blockConnection.value]);
-
   const actions: IActions = useMemo(() => ({
     compilePrompt: async (messages, { keepUsers } = {}) => {
       const promptMessages = messages.slice();
@@ -179,31 +162,43 @@
     },
     generate: async function* (prompt, extraSettings = {}) {
       try {
-        generating.setTrue();
         console.log('[LLM.generate]', prompt);
         yield* Connection.generate(connection, prompt, {
-          ...extraSettings,
+          ...extraSettings,
           banned_tokens: bannedWords.filter(w => w.trim()),
         });
-      } finally {
-        generating.setFalse();
+      } catch (e) {
+        if (e instanceof Error && e.name !== 'AbortError') {
+          alert(e.message);
+        } else {
+          console.error('[LLM.generate]', e);
+        }
       }
     },
     summarize: async (message) => {
-      const content = Huggingface.applyTemplate(summarizePrompt, { message });
-      const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
+      try {
+        const content = Huggingface.applyTemplate(summarizePrompt, { message });
+        const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
+        console.log('[LLM.summarize]', prompt);
-      const tokens = await Array.fromAsync(actions.generate(prompt));
+        const tokens = await Array.fromAsync(Connection.generate(connection, prompt, {}));
-      return MessageTools.trimSentence(tokens.join(''));
+        return MessageTools.trimSentence(tokens.join(''));
+      } catch (e) {
+        console.error('Error summarizing:', e);
+        return '';
+      }
     },
     countTokens: async (prompt) => {
       return await Connection.countTokens(connection, prompt);
     },
+    stopGeneration: () => {
+      Connection.stopGeneration();
+    },
   }), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);
-  useEffect(() => void (async () => {
+  useAsyncEffect(async () => {
     if (triggerNext && !generating.value) {
       setTriggerNext(false);
@@ -217,12 +212,14 @@
       messageId++;
     }
+    generating.setTrue();
     editSummary(messageId, 'Generating...');
     for await (const chunk of actions.generate(prompt)) {
       text += chunk;
       setPromptTokens(promptTokens + approximateTokens(text));
       editMessage(messageId, text.trim());
     }
+    generating.setFalse();
     text = MessageTools.trimSentence(text);
     editMessage(messageId, text);
@@ -230,10 +227,10 @@
       MessageTools.playReady();
     }
-  })(), [triggerNext]);
+  }, [triggerNext]);
-  useEffect(() => void (async () => {
-    if (summaryEnabled && !generating.value && !processing.summarizing) {
+  useAsyncEffect(async () => {
+    if (summaryEnabled && !processing.summarizing) {
       try {
         processing.summarizing = true;
         for (let id = 0; id < messages.length; id++) {
@@ -250,36 +247,15 @@
         processing.summarizing = false;
       }
     }
-  })(), [messages]);
+  }, [messages, summaryEnabled]);
-  useEffect(() => {
-    if (!blockConnection.value) {
-      setPromptTokens(0);
-      setContextLength(0);
-      setModelName('');
+  useEffect(throttle(() => {
+    Connection.getContextLength(connection).then(setContextLength);
+    Connection.getModelName(connection).then(normalizeModel).then(setModelName);
+  }, 1000, true), [connection]);
-      getContextLength().then(setContextLength);
-      getModelName().then(normalizeModel).then(setModelName);
-    }
-  }, [connection, blockConnection.value]);
-
-  useEffect(() => {
-    setModelTemplate('');
-    if (modelName) {
-      Huggingface.findModelTemplate(modelName)
-        .then((template) => {
-          if (template) {
-            setModelTemplate(template);
-            setInstruct(template);
-          } else {
-            setInstruct(Instruct.CHATML);
-          }
-        });
-    }
-  }, [modelName]);
-
-  const calculateTokens = useCallback(async () => {
-    if (!processing.tokenizing && !blockConnection.value && !generating.value) {
+  const calculateTokens = useCallback(throttle(async () => {
+    if (!processing.tokenizing && !generating.value) {
       try {
         processing.tokenizing = true;
         const { prompt } = await actions.compilePrompt(messages);
@@ -291,11 +267,11 @@
         processing.tokenizing = false;
       }
     }
-  }, [actions, messages, blockConnection.value]);
+  }, 1000, true), [actions, messages]);
   useEffect(() => {
     calculateTokens();
-  }, [messages, connection, blockConnection.value, /* systemPrompt, lore, userPrompt TODO debounce*/]);
+  }, [messages, connection, systemPrompt, lore, userPrompt]);
   useEffect(() => {
     try {
@@ -308,9 +284,7 @@
   const rawContext: IContext = {
     generating: generating.value,
-    blockConnection,
     modelName,
-    modelTemplate,
     hasToolCalls,
     promptTokens,
     contextLength,
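The llm.tsx changes lean on the two new utilities: useAsyncEffect wraps an async callback so it can be passed where useEffect expects a synchronous function, and throttle(fn, 1000, true) keeps connection probing and token counting to roughly one call per second while still running the last invocation. A condensed, illustrative sketch (fetchModelName is a placeholder standing in for the Connection calls):

```ts
import { useState } from "preact/hooks";
import { useAsyncEffect } from "@common/hooks/useAsyncEffect";

// Placeholder async probe, assumed for illustration only.
declare function fetchModelName(url: string): Promise<string>;

const useModelName = (url: string) => {
  const [name, setName] = useState("");

  // The effect body can be async without the `void (async () => { ... })()` wrapper.
  useAsyncEffect(async () => {
    setName(await fetchModelName(url));
  }, [url]);

  return name;
};
```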
diff --git a/src/games/ai/contexts/state.tsx b/src/games/ai/contexts/state.tsx
index d7a076b..aaa8334 100644
--- a/src/games/ai/contexts/state.tsx
+++ b/src/games/ai/contexts/state.tsx
@@ -83,7 +83,7 @@ What should happen next in your answer:
 {{ prompt | trim }}
 {% endif %}
 Remember that this story should be infinite and go forever. Make sure to follow the world description and rules exactly.
 Avoid cliffhangers and pauses, be creative.`,
-  summarizePrompt: 'Shrink following text down, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
+  summarizePrompt: 'Shrink following text down to one paragraph, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
   summaryEnabled: false,
   bannedWords: [],
   messages: [],
diff --git a/src/games/ai/huggingface.ts b/src/games/ai/huggingface.ts
index 3fd6ddb..630504c 100644
--- a/src/games/ai/huggingface.ts
+++ b/src/games/ai/huggingface.ts
@@ -1,6 +1,7 @@
 import { gguf } from '@huggingface/gguf';
 import * as hub from '@huggingface/hub';
 import { Template } from '@huggingface/jinja';
+import { normalizeModel } from './connection';
 export namespace Huggingface {
   export interface ITemplateMessage {
@@ -92,11 +93,12 @@
   const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise => {
     console.log(`[huggingface] searching config for '${modelName}'`);
+    const searchModel = normalizeModel(modelName);
-    const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName }, additionalFields: ['config'] }));
+    const hubModels = await Array.fromAsync(hub.listModels({ search: { query: searchModel }, additionalFields: ['config'] }));
     const models = hubModels.filter(m => {
       if (m.gated) return false;
-      if (!m.name.toLowerCase().includes(modelName.toLowerCase())) return false;
+      if (!normalizeModel(m.name).includes(searchModel)) return false;
       return true;
     }).sort((a, b) => b.downloads - a.downloads);
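The huggingface.ts change searches the Hub with the normalized model name and compares normalized names on both sides, so quantization and file-format suffixes no longer prevent a match. A sketch of the idea (the model names below are illustrative, and the exact output depends on the normalizeModel regexes in connection.ts):

```ts
import { normalizeModel } from "./connection";

// Illustrative: a locally reported quantized artifact name and a Hub repo it should match.
const reportedName = "Mistral-7B-Instruct-v0.3.Q4_K_M.gguf";
const hubRepoName = "Mistral-7B-Instruct-v0.3";

// Comparing normalized forms is intended to ignore quant-size / GGUF suffixes,
// where a raw substring check on the original names would fail.
const matches = normalizeModel(hubRepoName).includes(normalizeModel(reportedName));
console.log("match after normalization:", matches);
```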