import Lock from "@common/lock";
import SSE from "@common/sse";
import { GENERATION_SETTINGS } from "../const";
import { createContext, h } from "preact";
import type { ComponentChildren } from "preact";
import { useContext, useEffect, useMemo } from "preact/hooks";
import { MessageTools } from "../messages";
import { StateContext } from "./state";
import { useBool } from "@common/hooks/useBool";

/** Read-only state published through {@link LLMContext}. */
interface IContext {
    /** True between the start and the end of a `generate` call. */
    generating: boolean;
}

/** Operations published through {@link LLMContext}. */
interface IActions {
    /**
     * Streams a completion for `prompt`, yielding tokens as they arrive.
     * `extraSettings` is merged over `GENERATION_SETTINGS` in the request payload.
     */
    // NOTE(review): the type argument of `Partial` was reconstructed from how
    // `extraSettings` is spread over GENERATION_SETTINGS — confirm against history.
    generate: (prompt: string, extraSettings?: Partial<typeof GENERATION_SETTINGS>) => AsyncGenerator<string>;
    /** Resolves to the backend's token count for `prompt`, or 0 on any failure. */
    countTokens(prompt: string): Promise<number>;
    /** Resolves to the backend's reported max context length, or 0 on any failure. */
    getContextLength(): Promise<number>;
}

export type ILLMContext = IContext & IActions;

/** Context carrying generation state and actions; the default value is an empty placeholder. */
export const LLMContext = createContext({} as ILLMContext);

/**
 * Provides {@link LLMContext} to `children`.
 *
 * Besides exposing the actions, it watches `triggerNext` from `StateContext`:
 * when it is set and no generation is in flight, it compiles a prompt from the
 * current messages and streams the reply into the last (or a newly added)
 * assistant message.
 */
export const LLMContextProvider = ({ children }: { children?: ComponentChildren }) => {
    const { connectionUrl, messages, triggerNext, setTriggerNext, addMessage, editMessage } = useContext(StateContext);
    const generating = useBool(false);

    const actions: IActions = useMemo(() => ({
        generate: async function* (prompt, extraSettings = {}) {
            if (!connectionUrl) {
                return;
            }

            try {
                generating.setTrue();
                console.log('[LLM.generate]', prompt);

                const sse = new SSE(`${connectionUrl}/api/extra/generate/stream`, {
                    payload: JSON.stringify({
                        ...GENERATION_SETTINGS,
                        ...extraSettings,
                        prompt,
                    }),
                });

                try {
                    // Tokens received from the stream but not yet yielded.
                    const tokens: string[] = [];
                    const messageLock = new Lock();
                    let end = false;

                    sse.addEventListener('message', (e) => {
                        if (e.data) {
                            const { token, finish_reason } = JSON.parse(e.data);
                            tokens.push(token);

                            // The string 'null' is treated the same as "no finish reason".
                            if (finish_reason && finish_reason !== 'null') {
                                end = true;
                            }
                        }
                        // Wake the consumer loop below.
                        messageLock.release();
                    });

                    const handleEnd = () => {
                        end = true;
                        messageLock.release();
                    };

                    sse.addEventListener('error', handleEnd);
                    sse.addEventListener('abort', handleEnd);
                    sse.addEventListener('readystatechange', (e) => {
                        if (e.readyState === SSE.CLOSED) handleEnd();
                    });

                    // Keep draining until the stream has ended AND the buffer is empty.
                    while (!end || tokens.length) {
                        while (tokens.length > 0) {
                            const token = tokens.shift();
                            if (token != null) {
                                try {
                                    yield token;
                                } catch {
                                    // An error thrown into the generator by the consumer
                                    // is deliberately ignored so streaming continues.
                                }
                            }
                        }
                        if (!end) {
                            await messageLock.wait();
                        }
                    }
                } finally {
                    // Close the stream even when the consumer stops iterating early
                    // or an exception escapes the loop.
                    sse.close();
                }
            } finally {
                generating.setFalse();
            }
        },
        countTokens: async (prompt) => {
            if (!connectionUrl) {
                return 0;
            }
            try {
                const response = await fetch(`${connectionUrl}/api/extra/tokencount`, {
                    body: JSON.stringify({ prompt }),
                    headers: { 'Content-Type': 'application/json' },
                    method: 'POST',
                });
                if (response.ok) {
                    const { value } = await response.json();
                    return value;
                }
            } catch (e) {
                console.log('Error counting tokens', e);
            }

            return 0;
        },
        getContextLength: async () => {
            if (!connectionUrl) {
                return 0;
            }
            try {
                const response = await fetch(`${connectionUrl}/api/extra/true_max_context_length`);
                if (response.ok) {
                    const { value } = await response.json();
                    return value;
                }
            } catch (e) {
                console.log('Error getting max tokens', e);
            }

            return 0;
        },
    }), [connectionUrl]);

    useEffect(() => void (async () => {
        // `generating` is the useBool handle (always truthy); test its `.value`.
        if (triggerNext && !generating.value) {
            setTriggerNext(false);

            let messageId = messages.length - 1;
            let text = '';

            const { prompt, isRegen } = await MessageTools.compilePrompt(messages);

            // For a fresh reply, append an empty assistant message and stream into it;
            // for a regeneration, overwrite the current last message instead.
            if (!isRegen) {
                addMessage('', 'assistant');
                messageId++;
            }

            for await (const chunk of actions.generate(prompt)) {
                text += chunk;
                editMessage(messageId, text);
            }

            text = MessageTools.trimSentence(text);
            editMessage(messageId, text);

            MessageTools.playReady();
        }
    })(), [triggerNext, messages, generating.value]);

    const rawContext: IContext = {
        generating: generating.value,
    };

    // Re-create the state object only when one of its field values changes.
    const context = useMemo(() => rawContext, Object.values(rawContext));
    const value = useMemo(() => ({ ...context, ...actions }), [context, actions]);

    // Equivalent to <LLMContext.Provider value={value}>{children}</LLMContext.Provider>,
    // written with `h` so it is valid in both .ts and .tsx files.
    return h(LLMContext.Provider, { value }, children);
};