import { createContext, h } from "preact";
import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
import { MessageTools, type IMessage } from "../tools/messages";
import { StateContext } from "./state";
import { useBool } from "@common/hooks/useBool";
import { Huggingface } from "../tools/huggingface";
import { Connection, type IGenerationSettings } from "../tools/connection";
import { throttle } from "@common/utils";
import { useAsyncEffect } from "@common/hooks/useAsyncEffect";
import { approximateTokens, normalizeModel } from "../tools/model";

/** Options accepted by `IActions.compilePrompt`. */
interface ICompileArgs {
  /**
   * When truthy, the prompt is built turn-by-turn from the raw message list,
   * skipping the oldest user messages so that only the last `keepUsers` user
   * messages remain (assistant messages are all kept).
   */
  keepUsers?: number;
  /**
   * Drop the last message from the history. If it is an assistant message,
   * its text is appended raw after the chat template so the model continues it.
   */
  continueLast?: boolean;
}

/** Result of `IActions.compilePrompt`. */
interface ICompiledPrompt {
  /** Fully templated prompt text, ready to send to the connection. */
  prompt: string;
  /**
   * The history ended with a non-empty assistant message and `continueLast` was
   * not requested: a templated "continue" user turn was appended to the history.
   */
  isContinue: boolean;
  /**
   * The prompt ends with the last assistant message's text, so generated text
   * belongs to that same message rather than to a new one.
   */
  isRegen: boolean;
}

/** Reactive state exposed by `LLMContext`. */
interface IContext {
  generating: boolean;
  modelName: string;
  hasToolCalls: boolean;
  promptTokens: number;
  contextLength: number;
  spentKudos: number;
}

/**
 * Number of most recent assistant messages always sent verbatim; older ones
 * are replaced by their summary when one exists (non-`keepUsers` path only).
 */
const MESSAGES_TO_KEEP = 10;

/** Imperative operations exposed by `LLMContext`. */
interface IActions {
  compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
  generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
  stopGeneration: () => void;
  summarize: (content: string) => Promise<string>;
  countTokens: (prompt: string) => Promise<number>;
}

export type ILLMContext = IContext & IActions;

export const LLMContext = createContext({} as ILLMContext);

/**
 * Module-level re-entrancy guards so that token counting and the summarization
 * pass never overlap with themselves. Being module-scoped, they are shared by
 * every mounted provider.
 */
const processing = {
  tokenizing: false,
  summarizing: false,
};

/**
 * Provides `LLMContext` to its children.
 *
 * Reads the chat state from `StateContext` and, besides exposing the actions
 * above, runs these effects:
 * - when `triggerNext` is set (and the backend reports a context length > 0),
 *   generates an assistant reply, streaming it into the message list;
 * - when `summaryEnabled`, summarizes assistant messages whose content contains
 *   a newline and that have no summary yet;
 * - when `connection` changes, fetches the context length and model name;
 * - recounts prompt tokens when the messages or prompt settings change.
 */
export const LLMContextProvider = ({ children }: { children?: any }) => {
  const {
    connection, messages, triggerNext, continueLast, lore, userPrompt, systemPrompt, bannedWords,
    summarizePrompt, summaryEnabled,
    setTriggerNext, setContinueLast, addMessage, editMessage, editSummary, setTotalSpentKudos,
  } = useContext(StateContext);

  const generating = useBool(false);
  const [promptTokens, setPromptTokens] = useState(0);
  const [contextLength, setContextLength] = useState(0);
  const [modelName, setModelName] = useState('');
  const [hasToolCalls, setHasToolCalls] = useState(false);
  const [spentKudos, setSpentKudos] = useState(0);

  // The backend is treated as reachable once it has reported a context length.
  const isOnline = useMemo(() => contextLength > 0, [contextLength]);

  const actions: IActions = useMemo(() => ({
    compilePrompt: async (messages, { keepUsers, continueLast = false } = {}) => {
      const lastMessage = messages.at(-1);
      const lastMessageContent = MessageTools.getSwipe(lastMessage)?.content;
      const isAssistantLast = lastMessage?.role === 'assistant';

      // "Regen": the prompt ends with the last assistant text so the model keeps
      // writing into it. Only possible when the last message is an assistant one;
      // an empty trailing assistant message is always filled in this way.
      let isRegen = continueLast;
      if (!isAssistantLast) {
        isRegen = false;
      } else if (!lastMessageContent) {
        isRegen = true;
      }
      // A finished assistant reply is last: request a new reply via a "continue" user turn.
      const isContinue = isAssistantLast && !isRegen;

      const promptMessages = continueLast ? messages.slice(0, -1) : messages.slice();
      if (isContinue) {
        promptMessages.push(MessageTools.create(Huggingface.applyTemplate(userPrompt, {})));
      }

      const userMessages = promptMessages.filter(m => m.role === 'user');
      const lastUserMessage = userMessages.at(-1);
      const firstUserMessage = userMessages.at(0);

      const templateMessages: Huggingface.ITemplateMessage[] = [
        { role: 'system', content: systemPrompt.trim() },
      ];

      if (keepUsers) {
        // NOTE(review): this branch iterates `messages`, not `promptMessages`, so the
        // continueLast/isContinue adjustments above are not reflected here — confirm intended.
        let usersRemaining = messages.filter(m => m.role === 'user').length;
        let wasStory = false;
        for (const message of messages) {
          const { role } = message;
          const swipe = MessageTools.getSwipe(message);
          const content = swipe?.content ?? '';
          const previous = templateMessages.at(-1);

          if (role === 'user' && usersRemaining > keepUsers) {
            // Drop the oldest user messages beyond the limit.
            usersRemaining--;
          } else if (role === 'assistant' && previous?.role === 'assistant') {
            // Merge consecutive assistant messages into one story turn.
            wasStory = true;
            previous.content += '\n\n' + content;
          } else if (role === 'user' && !message.technical) {
            templateMessages.push({
              role: message.role,
              content: Huggingface.applyTemplate(userPrompt, { prompt: content, isStart: !wasStory }),
            });
          } else {
            if (role === 'assistant') {
              wasStory = true;
            }
            templateMessages.push({ role, content });
          }
        }
      } else {
        // Collapse every assistant message into one story turn; all but the
        // last MESSAGES_TO_KEEP are replaced by their summary when available.
        const story = promptMessages
          .filter(m => m.role === 'assistant')
          .map((m, i, msgs) => {
            const swipe = MessageTools.getSwipe(m);
            if (!swipe) return '';
            const { content, summary } = swipe;
            return summary && i < msgs.length - MESSAGES_TO_KEEP ? summary : content;
          })
          .join('\n\n');

        if (story.length > 0) {
          const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
          templateMessages.push({ role: 'user', content: Huggingface.applyTemplate(userPrompt, { prompt, isStart: true }) });
          templateMessages.push({ role: 'assistant', content: story });
        }

        let userMessage = MessageTools.getSwipe(lastUserMessage)?.content;
        // Technical messages and the synthetic "continue" turn are already final text.
        if (!lastUserMessage?.technical && !isContinue && userMessage) {
          userMessage = Huggingface.applyTemplate(userPrompt, { prompt: userMessage, isStart: story.length === 0 });
        }
        if (userMessage) {
          templateMessages.push({ role: 'user', content: userMessage });
        }
      }

      // Ensure a user turn directly follows the system prompt; it carries the lore.
      if (templateMessages[1]?.role !== 'user') {
        const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
        templateMessages.splice(1, 0, {
          role: 'user',
          content: Huggingface.applyTemplate(userPrompt, { prompt, isStart: true }),
        });
      }
      templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`;

      let prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages);
      if (isRegen) {
        // The trailing assistant message may have no swipe at all; never append "undefined".
        prompt += lastMessageContent ?? '';
      }

      return {
        prompt,
        isContinue,
        isRegen,
      };
    },
    generate: async function* (prompt, extraSettings = {}) {
      try {
        console.log('[LLM.generate]', prompt);
        setSpentKudos(0);

        for await (const { text, cost } of Connection.generate(connection, prompt, {
          ...extraSettings,
          banned_tokens: bannedWords.filter(w => w.trim()),
        })) {
          setSpentKudos(sk => sk + cost);
          setTotalSpentKudos(sk => sk + cost);
          yield text;
        }
      } catch (e) {
        // Errors end the stream instead of propagating; aborts are only logged.
        if (e instanceof Error && e.name !== 'AbortError') {
          alert(e.message);
        } else {
          console.error('[LLM.generate]', e);
        }
      }
    },
    summarize: async (message) => {
      try {
        const content = Huggingface.applyTemplate(summarizePrompt, { message });
        const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
        console.log('[LLM.summarize]', prompt);

        const tokens = await Array.fromAsync(Connection.generate(connection, prompt));
        const summary = tokens.reduce((sum, token) => ({
          text: sum.text + token.text,
          cost: sum.cost + token.cost,
        }), { text: '', cost: 0 });

        setSpentKudos(sk => sk + summary.cost);
        setTotalSpentKudos(sk => sk + summary.cost);

        return MessageTools.trimSentence(summary.text);
      } catch (e) {
        // Best-effort: an empty string means "no summary".
        console.error('Error summarizing:', e);
        return '';
      }
    },
    countTokens: async (prompt) => {
      return await Connection.countTokens(connection, prompt);
    },
    stopGeneration: () => {
      Connection.stopGeneration();
    },
  }), [connection, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt]);

  // Generate the next assistant reply when requested. `generating.value` is a
  // dependency so that a trigger raised during a running generation is picked
  // up as soon as that generation finishes.
  useAsyncEffect(async () => {
    if (!isOnline || !triggerNext || generating.value) return;

    setTriggerNext(false);
    setContinueLast(false);

    let messageId = messages.length - 1;
    let text = '';

    const { prompt, isRegen } = await actions.compilePrompt(messages, { continueLast });
    if (isRegen) {
      // Keep writing into the existing last message.
      text = MessageTools.getSwipe(messages.at(-1))?.content ?? '';
    } else {
      addMessage('', 'assistant');
      messageId++;
    }

    generating.setTrue();
    editSummary(messageId, 'Generating...');
    try {
      for await (const chunk of actions.generate(prompt)) {
        text += chunk;
        // Base count captured when the effect started plus an estimate for the generated text.
        setPromptTokens(promptTokens + approximateTokens(text));
        editMessage(messageId, text.trim());
      }
    } finally {
      // Always release the flag and the placeholder summary, even if a callback
      // threw mid-stream; otherwise no further generation could ever start.
      generating.setFalse();
      text = MessageTools.trimSentence(text);
      editMessage(messageId, text);
      editSummary(messageId, '');
    }
    MessageTools.playReady();
  }, [triggerNext, isOnline, generating.value]);

  // Summarize assistant messages that contain a newline and lack a summary.
  useAsyncEffect(async () => {
    if (isOnline && summaryEnabled && !processing.summarizing) {
      try {
        processing.summarizing = true;
        for (let id = 0; id < messages.length; id++) {
          const message = messages[id];
          const swipe = MessageTools.getSwipe(message);
          if (message.role === 'assistant' && swipe?.content?.includes('\n') && !swipe.summary) {
            const summary = await actions.summarize(swipe.content);
            editSummary(id, summary);
          }
        }
      } catch (e) {
        console.error(`Could not summarize`, e)
      } finally {
        processing.summarizing = false;
      }
    }
  }, [messages, summaryEnabled, isOnline]);

  // Probe the backend whenever the connection changes.
  useEffect(throttle(() => {
    Connection.getContextLength(connection)
      .then(setContextLength)
      .catch((e: unknown) => console.error('Could not get context length', e));
    Connection.getModelName(connection)
      .then(normalizeModel)
      .then(setModelName)
      .catch((e: unknown) => console.error('Could not get model name', e));
  }, 1000, true), [connection]);

  const calculateTokens = useCallback(throttle(async () => {
    if (isOnline && !processing.tokenizing && !generating.value) {
      try {
        processing.tokenizing = true;
        const { prompt } = await actions.compilePrompt(messages);
        const tokens = await actions.countTokens(prompt);
        setPromptTokens(tokens);
      } catch (e) {
        console.error(`Could not count tokens`, e)
      } finally {
        processing.tokenizing = false;
      }
    }
  }, 1000, true), [actions, messages, isOnline]);

  useEffect(() => {
    calculateTokens();
  }, [messages, connection, systemPrompt, lore, userPrompt, isOnline]);

  useEffect(() => {
    try {
      const hasTools = Huggingface.testToolCalls(connection.instruct);
      setHasToolCalls(hasTools);
    } catch {
      setHasToolCalls(false);
    }
  }, [connection.instruct]);

  const rawContext: IContext = {
    generating: generating.value,
    modelName,
    hasToolCalls,
    promptTokens,
    contextLength,
    spentKudos,
  };

  // Keep a stable object identity until one of the fields actually changes.
  const context = useMemo(() => rawContext, Object.values(rawContext));

  const value = useMemo(() => ({ ...context, ...actions }), [context, actions]);

  // `h` works whether this file is compiled as .ts or .tsx.
  return h(LLMContext.Provider, { value }, children);
};