import Lock from "@common/lock";
import SSE from "@common/sse";
import { createContext, h, type ComponentChildren } from "preact";
import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
import { MessageTools, type IMessage } from "../messages";
import { Instruct, StateContext } from "./state";
import { useBool } from "@common/hooks/useBool";
import { Template } from "@huggingface/jinja";
import { Huggingface } from "../huggingface";
import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection";

/** Options accepted by `compilePrompt`. */
interface ICompileArgs {
    /**
     * When truthy, the prompt is built message-by-message and the oldest user
     * messages are dropped until only this many remain.
     */
    keepUsers?: number;
    /** Not read anywhere in this module. */
    raw?: boolean;
}

/** Result of `compilePrompt`. */
interface ICompiledPrompt {
    /** Chat-template output that is sent to the backend. */
    prompt: string;
    /** True when the last message is an assistant message that already has content. */
    isContinue: boolean;
    /** True when the last message is an assistant message whose current swipe is empty. */
    isRegen: boolean;
}

/** State half of the context value. */
interface IContext {
    /** True while `generate` is streaming. */
    generating: boolean;
    /** Boolean handle; while its value is true no connection metadata or token counts are fetched. */
    blockConnection: ReturnType<typeof useBool>;
    /** Normalized model name reported by the connection, or '' when unknown. */
    modelName: string;
    /** Template returned by `Huggingface.findModelTemplate` for `modelName`, or ''. */
    modelTemplate: string;
    /** Result of `Huggingface.testToolCalls` for the current instruct template. */
    hasToolCalls: boolean;
    /** Token count of the compiled prompt (approximated while generating). */
    promptTokens: number;
    /** Context length reported by the connection, or 0 when unknown. */
    contextLength: number;
}

/** Number of trailing assistant messages that keep their full content instead of their summary. */
const MESSAGES_TO_KEEP = 10;

/** Action half of the context value. */
interface IActions {
    /** Builds the final prompt string from the chat history. */
    compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
    /** Streams generated text chunks for `prompt`. */
    generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
    /** Generates a summary of `content` using the configured summarize prompt. */
    summarize: (content: string) => Promise<string>;
    /** Asks the connection how many tokens `prompt` occupies. */
    countTokens: (prompt: string) => Promise<number>;
}

export type ILLMContext = IContext & IActions;

export const LLMContext = createContext<ILLMContext>({} as ILLMContext);

/**
 * Module-level re-entrancy flags for the background token-count and
 * summarization passes. Being module-level, they are shared by every mounted
 * provider instance.
 */
const processing = {
    tokenizing: false,
    summarizing: false,
};

/**
 * Provides `LLMContext`: connection metadata (model name, template, context
 * length), prompt compilation, streaming generation, background summarization
 * and prompt token counting, all driven by `StateContext`.
 */
export const LLMContextProvider = ({ children }: { children?: ComponentChildren }) => {
    const {
        connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
        setTriggerNext, addMessage, editMessage, editSummary, setInstruct,
    } = useContext(StateContext);

    const generating = useBool(false);
    const blockConnection = useBool(false);
    const [promptTokens, setPromptTokens] = useState(0);
    const [contextLength, setContextLength] = useState(0);
    const [modelName, setModelName] = useState('');
    const [modelTemplate, setModelTemplate] = useState('');
    const [hasToolCalls, setHasToolCalls] = useState(false);

    // If the user prompt is not a valid template, fall back to emitting it verbatim.
    const userPromptTemplate = useMemo(() => {
        try {
            return new Template(userPrompt);
        } catch {
            return {
                render: () => userPrompt,
            };
        }
    }, [userPrompt]);

    const getContextLength = useCallback(async () => {
        if (!connection || blockConnection.value) {
            return 0;
        }
        return Connection.getContextLength(connection);
    }, [connection, blockConnection.value]);

    const getModelName = useCallback(async () => {
        if (!connection || blockConnection.value) {
            return '';
        }
        return Connection.getModelName(connection);
    }, [connection, blockConnection.value]);

    const actions: IActions = useMemo(() => ({
        compilePrompt: async (messages, { keepUsers } = {}) => {
            const promptMessages = messages.slice();
            const lastMessage = promptMessages.at(-1);
            const isAssistantLast = lastMessage?.role === 'assistant';
            const isRegen = isAssistantLast && !MessageTools.getSwipe(lastMessage)?.content;
            const isContinue = isAssistantLast && !isRegen;

            if (isContinue) {
                // Continuing a story: append a user turn rendered from the template without a prompt.
                promptMessages.push(MessageTools.create(userPromptTemplate.render({})));
            }

            const userMessages = promptMessages.filter(m => m.role === 'user');
            const lastUserMessage = userMessages.at(-1);
            const firstUserMessage = userMessages.at(0);

            const templateMessages: Huggingface.ITemplateMessage[] = [
                { role: 'system', content: systemPrompt.trim() },
            ];

            if (keepUsers) {
                // NOTE(review): this branch walks `messages`, not `promptMessages`, so the
                // continuation turn appended above is not part of the prompt here — confirm intended.
                let usersRemaining = messages.filter(m => m.role === 'user').length;
                let wasStory = false;

                for (const message of messages) {
                    const { role } = message;
                    const content = MessageTools.getSwipe(message)?.content ?? '';
                    // Never undefined in practice: the system message is always present.
                    const previous = templateMessages.at(-1);

                    if (role === 'user' && usersRemaining > keepUsers) {
                        // Drop the oldest user messages until only `keepUsers` are left.
                        usersRemaining--;
                    } else if (role === 'assistant' && previous?.role === 'assistant') {
                        // Merge consecutive assistant messages into a single turn.
                        wasStory = true;
                        previous.content += '\n\n' + content;
                    } else if (role === 'user' && !message.technical) {
                        templateMessages.push({
                            role: message.role,
                            content: userPromptTemplate.render({ prompt: content, isStart: !wasStory }),
                        });
                    } else {
                        if (role === 'assistant') {
                            wasStory = true;
                        }
                        templateMessages.push({ role, content });
                    }
                }
            } else {
                // Collapse all assistant messages into one story; older ones use their summary when available.
                const story = promptMessages
                    .filter(m => m.role === 'assistant')
                    .map((m, i, msgs) => {
                        const swipe = MessageTools.getSwipe(m);
                        if (!swipe) return '';

                        let { content, summary } = swipe;
                        if (summary && i < msgs.length - MESSAGES_TO_KEEP) {
                            content = summary;
                        }
                        return content;
                    })
                    .join('\n\n');

                if (story.length > 0) {
                    const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
                    templateMessages.push({ role: 'user', content: userPromptTemplate.render({ prompt, isStart: true }) });
                    templateMessages.push({ role: 'assistant', content: story });
                }

                let lastUserContent = MessageTools.getSwipe(lastUserMessage)?.content;
                // Technical messages and the continuation turn are already in final form.
                if (!lastUserMessage?.technical && !isContinue && lastUserContent) {
                    lastUserContent = userPromptTemplate.render({ prompt: lastUserContent, isStart: story.length === 0 });
                }

                if (lastUserContent) {
                    templateMessages.push({ role: 'user', content: lastUserContent });
                }
            }

            // Ensure a user turn directly follows the system message; the lore is prepended to it.
            if (templateMessages[1]?.role !== 'user') {
                const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
                templateMessages.splice(1, 0, {
                    role: 'user',
                    content: userPromptTemplate.render({ prompt, isStart: true }),
                });
            }

            templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`;

            const prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages);
            return {
                prompt,
                isContinue,
                isRegen,
            };
        },
        generate: async function* (prompt, extraSettings = {}) {
            try {
                generating.setTrue();
                console.log('[LLM.generate]', prompt);

                yield* Connection.generate(connection, prompt, {
                    ...extraSettings,
                    // Skip blank entries from the banned-words list.
                    banned_tokens: bannedWords.filter(w => w.trim()),
                });
            } finally {
                // Runs on completion, on error, and when the consumer stops iterating.
                generating.setFalse();
            }
        },
        summarize: async (message) => {
            const content = Huggingface.applyTemplate(summarizePrompt, { message });
            const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);

            const tokens = await Array.fromAsync(actions.generate(prompt));
            return MessageTools.trimSentence(tokens.join(''));
        },
        countTokens: async (prompt) => {
            return await Connection.countTokens(connection, prompt);
        },
    }), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);

    // Run one generation whenever `triggerNext` is raised.
    useEffect(() => void (async () => {
        if (!triggerNext || generating.value) {
            return;
        }
        setTriggerNext(false);

        let messageId = messages.length - 1;
        let text = '';
        let placeholderShown = false;

        try {
            const { prompt, isRegen } = await actions.compilePrompt(messages);

            // A regen reuses the existing (empty) assistant message; otherwise append a new one.
            if (!isRegen) {
                addMessage('', 'assistant');
                messageId++;
            }

            editSummary(messageId, 'Generating...');
            placeholderShown = true;

            for await (const chunk of actions.generate(prompt)) {
                text += chunk;
                setPromptTokens(promptTokens + approximateTokens(text));
                editMessage(messageId, text.trim());
            }

            text = MessageTools.trimSentence(text);
            editMessage(messageId, text);
            editSummary(messageId, '');
            placeholderShown = false;

            MessageTools.playReady();
        } catch (e) {
            console.error(`Could not generate`, e);
            // Do not leave the placeholder summary behind on failure.
            if (placeholderShown) {
                editSummary(messageId, '');
            }
        }
    })(), [triggerNext]);

    // Summarize multi-line assistant messages that have no summary yet, one at a time.
    useEffect(() => void (async () => {
        if (summaryEnabled && !generating.value && !processing.summarizing) {
            try {
                processing.summarizing = true;
                for (let id = 0; id < messages.length; id++) {
                    const message = messages[id];
                    const swipe = MessageTools.getSwipe(message);
                    if (message.role === 'assistant' && swipe?.content?.includes('\n') && !swipe.summary) {
                        const summary = await actions.summarize(swipe.content);
                        editSummary(id, summary);
                    }
                }
            } catch (e) {
                console.error(`Could not summarize`, e);
            } finally {
                processing.summarizing = false;
            }
        }
    })(), [messages]);

    // Refresh connection metadata whenever the connection changes or gets unblocked.
    useEffect(() => {
        // Set by the cleanup so results of a superseded request are ignored.
        let stale = false;

        if (!blockConnection.value) {
            setPromptTokens(0);
            setContextLength(0);
            setModelName('');

            getContextLength()
                .then((length: number) => {
                    if (!stale) setContextLength(length);
                })
                .catch((e: unknown) => console.error(`Could not get context length`, e));

            getModelName()
                .then(normalizeModel)
                .then((name: string) => {
                    if (!stale) setModelName(name);
                })
                .catch((e: unknown) => console.error(`Could not get model name`, e));
        }

        return () => {
            stale = true;
        };
    }, [connection, blockConnection.value]);

    // Look up the chat template for the current model; fall back to ChatML when none is found.
    useEffect(() => {
        let stale = false;
        setModelTemplate('');

        if (modelName) {
            Huggingface.findModelTemplate(modelName)
                .then((template) => {
                    if (stale) return;
                    if (template) {
                        setModelTemplate(template);
                        setInstruct(template);
                    } else {
                        setInstruct(Instruct.CHATML);
                    }
                })
                .catch((e: unknown) => console.error(`Could not find model template`, e));
        }

        return () => {
            stale = true;
        };
    }, [modelName]);

    const calculateTokens = useCallback(async () => {
        if (!processing.tokenizing && !blockConnection.value && !generating.value) {
            try {
                processing.tokenizing = true;
                const { prompt } = await actions.compilePrompt(messages);
                const tokens = await actions.countTokens(prompt);
                setPromptTokens(tokens);
            } catch (e) {
                console.error(`Could not count tokens`, e);
            } finally {
                processing.tokenizing = false;
            }
        }
    }, [actions, messages, blockConnection.value, generating.value]);

    useEffect(() => {
        calculateTokens();
    }, [messages, connection, blockConnection.value, /* systemPrompt, lore, userPrompt TODO debounce*/]);

    useEffect(() => {
        try {
            const hasTools = Huggingface.testToolCalls(connection.instruct);
            setHasToolCalls(hasTools);
        } catch {
            setHasToolCalls(false);
        }
        // Optional chaining: the deps array is evaluated during render, outside the try above.
    }, [connection?.instruct]);

    const rawContext: IContext = {
        generating: generating.value,
        blockConnection,
        modelName,
        modelTemplate,
        hasToolCalls,
        promptTokens,
        contextLength,
    };

    // Re-create the context object only when one of its field values changes.
    const context = useMemo(() => rawContext, Object.values(rawContext));
    const value = useMemo(() => ({ ...context, ...actions }), [context, actions]);

    return h(LLMContext.Provider, { value }, children);
};