// tsgames/src/games/ai-story/contexts/llm.tsx
import { createContext } from "preact";
import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
import { MessageTools, type IMessage } from "../tools/messages";
import { StateContext } from "./state";
import { useBool } from "@common/hooks/useBool";
import { Huggingface } from "../tools/huggingface";
import { Connection, type IGenerationSettings } from "../tools/connection";
import { throttle } from "@common/utils";
import { useAsyncEffect } from "@common/hooks/useAsyncEffect";
import { approximateTokens, normalizeModel } from "../tools/model";
/** Options controlling how the chat history is compiled into a prompt. */
interface ICompileArgs {
  // Keep at most this many of the most recent user messages as separate chat
  // turns; older turns are dropped and assistant content is merged instead.
  keepUsers?: number;
  // True when the trailing assistant message is being continued/regenerated
  // rather than followed by a brand-new reply.
  continueLast?: boolean;
}
/** Result of compiling the message history into a model-ready prompt. */
interface ICompiledPrompt {
  prompt: string;
  // The model should keep writing an existing, non-empty assistant message.
  isContinue: boolean;
  // The trailing assistant message is being regenerated from scratch.
  isRegen: boolean;
}
/** Read-only LLM state exposed through LLMContext. */
interface IContext {
  generating: boolean;
  modelName: string;
  // Whether the active instruct template supports tool calls.
  hasToolCalls: boolean;
  promptTokens: number;
  contextLength: number;
  // Cost units accumulated from Connection.generate for the current run.
  spentKudos: number;
}
// Number of most recent assistant messages whose full content (rather than
// their summaries) is kept when compiling the story-mode prompt.
const MESSAGES_TO_KEEP = 10;
/** Imperative LLM operations exposed through LLMContext. */
interface IActions {
  compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
  generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
  stopGeneration: () => void;
  summarize: (content: string) => Promise<string>;
  countTokens: (prompt: string) => Promise<number>;
}
export type ILLMContext = IContext & IActions;
export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
// Re-entrancy guards for the async tokenize/summarize effects.
// NOTE(review): module-level, so shared across all LLMContextProvider
// instances — presumably only one provider is ever mounted; confirm.
const processing = {
  tokenizing: false,
  summarizing: false,
}
/**
 * Provides LLM state (generation status, model info, token counts, spent cost)
 * and actions (prompt compilation, streaming generation, summarization) to the
 * ai-story game via LLMContext.
 */
export const LLMContextProvider = ({ children }: { children?: any }) => {
  const {
    connection, messages, triggerNext, continueLast, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
    setTriggerNext, setContinueLast, addMessage, editMessage, editSummary, setTotalSpentKudos,
  } = useContext(StateContext);
  const generating = useBool(false);
  const [promptTokens, setPromptTokens] = useState(0);
  const [contextLength, setContextLength] = useState(0);
  const [modelName, setModelName] = useState('');
  const [hasToolCalls, setHasToolCalls] = useState(false);
  // Cost spent on the current run; reset at the start of each generation.
  const [spentKudos, setSpentKudos] = useState(0);
  // A positive context length is the signal that the backend is reachable.
  const isOnline = useMemo(() => contextLength > 0, [contextLength]);
  const actions: IActions = useMemo(() => ({
    /**
     * Flattens the message history into a single prompt string using the
     * connection's instruct template.
     *
     * Two modes:
     * - keepUsers set: chat-style — the most recent `keepUsers` user turns
     *   stay as separate messages; consecutive assistant turns are merged.
     * - otherwise: story-style — all assistant turns are merged into one
     *   narrative turn (older ones replaced by summaries when available),
     *   followed by the latest user message.
     */
    compilePrompt: async (messages, { keepUsers, continueLast = false } = {}) => {
      const lastMessage = messages.at(-1);
      const lastMessageContent = MessageTools.getSwipe(lastMessage)?.content;
      const isAssistantLast = lastMessage?.role === 'assistant';
      // continueLast on an assistant tail with content is a "continue"; an
      // empty assistant tail is always a "regen"; anything else is neither.
      let isRegen = continueLast;
      if (!isAssistantLast) {
        isRegen = false;
      } else if (!lastMessageContent) {
        isRegen = true;
      }
      const isContinue = isAssistantLast && !isRegen;
      // Exclude the message being continued/regenerated from the prompt body.
      const promptMessages = continueLast ? messages.slice(0, -1) : messages.slice();
      if (isContinue) {
        // Append a synthetic user turn asking the model to keep writing.
        promptMessages.push(MessageTools.create(Huggingface.applyTemplate(userPrompt, {})));
      }
      const userMessages = promptMessages.filter(m => m.role === 'user');
      const lastUserMessage = userMessages.at(-1);
      const firstUserMessage = userMessages.at(0);
      const templateMessages: Huggingface.ITemplateMessage[] = [
        { role: 'system', content: systemPrompt.trim() },
      ];
      if (keepUsers) {
        let usersRemaining = messages.filter(m => m.role === 'user').length;
        let wasStory = false;
        for (const message of messages) {
          const { role } = message;
          const swipe = MessageTools.getSwipe(message);
          let content = swipe?.content ?? '';
          if (role === 'user' && usersRemaining > keepUsers) {
            // Drop user turns older than the last `keepUsers`.
            usersRemaining--;
          } else if (role === 'assistant' && templateMessages.at(-1).role === 'assistant') {
            // Merge consecutive assistant turns into one story block.
            // NOTE(review): at(-1) is always defined here (the system message
            // is seeded above), but strict null checks would flag this access.
            wasStory = true;
            templateMessages.at(-1).content += '\n\n' + content;
          } else if (role === 'user' && !message.technical) {
            templateMessages.push({
              role: message.role,
              content: Huggingface.applyTemplate(userPrompt, { prompt: content, isStart: !wasStory }),
            });
          } else {
            if (role === 'assistant') {
              wasStory = true;
            }
            templateMessages.push({ role, content });
          }
        }
      } else {
        // Merge all assistant turns into one narrative; entries older than the
        // last MESSAGES_TO_KEEP use their summaries when available.
        const story = promptMessages.filter(m => m.role === 'assistant')
          .map((m, i, msgs) => {
            const swipe = MessageTools.getSwipe(m);
            if (!swipe) return '';
            let { content, summary } = swipe;
            if (summary && i < msgs.length - MESSAGES_TO_KEEP) {
              content = summary;
            }
            return content;
          }).join('\n\n');
        if (story.length > 0) {
          const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
          templateMessages.push({ role: 'user', content: Huggingface.applyTemplate(userPrompt, { prompt, isStart: true }) });
          templateMessages.push({ role: 'assistant', content: story });
        }
        let userMessage = MessageTools.getSwipe(lastUserMessage)?.content;
        if (!lastUserMessage?.technical && !isContinue && userMessage) {
          userMessage = Huggingface.applyTemplate(userPrompt, { prompt: userMessage, isStart: story.length === 0 });
        }
        if (userMessage) {
          templateMessages.push({ role: 'user', content: userMessage });
        }
      }
      // Guarantee a user turn right after the system message so the lore
      // prefix below always has a message to attach to.
      if (templateMessages[1]?.role !== 'user') {
        const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
        templateMessages.splice(1, 0, {
          role: 'user',
          content: Huggingface.applyTemplate(userPrompt, { prompt, isStart: true }),
        });
      }
      // Prepend the world lore to the first user turn.
      templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`;
      let prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages);
      if (isRegen) {
        // FIX: isRegen can be true precisely when lastMessageContent is falsy
        // (empty assistant tail); if it is undefined (no swipe), the bare
        // `+=` appended the literal string "undefined" to the prompt. The
        // consumer effect below already guards the same read with `?? ''`.
        prompt += lastMessageContent ?? '';
      }
      return {
        prompt,
        isContinue,
        isRegen,
      };
    },
    /**
     * Streams generated text chunks for `prompt`, accumulating the reported
     * cost into both the per-run and total counters. Errors other than
     * AbortError are surfaced to the user via alert().
     */
    generate: async function* (prompt, extraSettings = {}) {
      try {
        console.log('[LLM.generate]', prompt);
        setSpentKudos(0);
        for await (const { text, cost } of Connection.generate(connection, prompt, {
          ...extraSettings,
          banned_tokens: bannedWords.filter(w => w.trim()),
        })) {
          setSpentKudos(sk => sk + cost);
          setTotalSpentKudos(sk => sk + cost);
          yield text;
        }
      } catch (e) {
        if (e instanceof Error && e.name !== 'AbortError') {
          alert(e.message);
        } else {
          console.error('[LLM.generate]', e);
        }
      }
    },
    /**
     * Generates a summary of `message` via the summarize prompt template,
     * accumulating the cost. Returns '' on failure.
     */
    summarize: async (message) => {
      try {
        const content = Huggingface.applyTemplate(summarizePrompt, { message });
        const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
        console.log('[LLM.summarize]', prompt);
        // Drain the whole stream, concatenating text and summing cost.
        const tokens = await Array.fromAsync(Connection.generate(connection, prompt));
        const summary = tokens.reduce((sum, token) => ({
          text: sum.text + token.text,
          cost: sum.cost + token.cost,
        }), { text: '', cost: 0 });
        setSpentKudos(sk => sk + summary.cost);
        setTotalSpentKudos(sk => sk + summary.cost);
        return MessageTools.trimSentence(summary.text);
      } catch (e) {
        console.error('Error summarizing:', e);
        return '';
      }
    },
    countTokens: async (prompt) => {
      return await Connection.countTokens(connection, prompt);
    },
    stopGeneration: () => {
      Connection.stopGeneration();
    },
    // State setters used above are assumed stable and intentionally omitted.
  }), [connection, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt]);
  // Kicks off a generation when triggerNext is raised: compiles the prompt,
  // appends (or reuses) the assistant message, and streams chunks into it.
  useAsyncEffect(async () => {
    if (isOnline && triggerNext && !generating.value) {
      setTriggerNext(false);
      setContinueLast(false);
      let messageId = messages.length - 1;
      let text = '';
      const { prompt, isRegen } = await actions.compilePrompt(messages, { continueLast });
      if (isRegen) {
        // Regenerating: keep writing into the existing assistant message.
        text = MessageTools.getSwipe(messages.at(-1))?.content ?? '';
      } else {
        addMessage('', 'assistant');
        messageId++;
      }
      generating.setTrue();
      editSummary(messageId, 'Generating...');
      for await (const chunk of actions.generate(prompt)) {
        text += chunk;
        // NOTE(review): promptTokens here is the value captured when this
        // effect closure was created ([triggerNext, isOnline] deps), so the
        // live count is approximate and may use a stale base — confirm this
        // is acceptable for a progress display.
        setPromptTokens(promptTokens + approximateTokens(text));
        editMessage(messageId, text.trim());
      }
      generating.setFalse();
      text = MessageTools.trimSentence(text);
      editMessage(messageId, text);
      editSummary(messageId, '');
      MessageTools.playReady();
    }
  }, [triggerNext, isOnline]);
  // Background summarization: summarize any assistant message with multi-line
  // content and no summary yet. The module-level flag prevents overlapping
  // effect runs from double-summarizing.
  useAsyncEffect(async () => {
    if (isOnline && summaryEnabled && !processing.summarizing) {
      try {
        processing.summarizing = true;
        for (let id = 0; id < messages.length; id++) {
          const message = messages[id];
          const swipe = MessageTools.getSwipe(message);
          if (message.role === 'assistant' && swipe?.content?.includes('\n') && !swipe.summary) {
            const summary = await actions.summarize(swipe.content);
            editSummary(id, summary);
          }
        }
      } catch (e) {
        console.error(`Could not summarize`, e)
      } finally {
        processing.summarizing = false;
      }
    }
  }, [messages, summaryEnabled, isOnline]);
  // Refresh model metadata when the connection changes (throttled to 1/s).
  useEffect(throttle(() => {
    Connection.getContextLength(connection).then(setContextLength);
    Connection.getModelName(connection).then(normalizeModel).then(setModelName);
  }, 1000, true), [connection]);
  // Recount prompt tokens (throttled); skipped while generating, and guarded
  // by a module-level flag against concurrent runs.
  const calculateTokens = useCallback(throttle(async () => {
    if (isOnline && !processing.tokenizing && !generating.value) {
      try {
        processing.tokenizing = true;
        const { prompt } = await actions.compilePrompt(messages);
        const tokens = await actions.countTokens(prompt);
        setPromptTokens(tokens);
      } catch (e) {
        console.error(`Could not count tokens`, e)
      } finally {
        processing.tokenizing = false;
      }
    }
  }, 1000, true), [actions, messages, isOnline]);
  useEffect(() => {
    calculateTokens();
  }, [messages, connection, systemPrompt, lore, userPrompt, isOnline]);
  // Detect whether the instruct template supports tool calls.
  useEffect(() => {
    try {
      const hasTools = Huggingface.testToolCalls(connection.instruct);
      setHasToolCalls(hasTools);
    } catch {
      setHasToolCalls(false);
    }
  }, [connection.instruct]);
  const rawContext: IContext = {
    generating: generating.value,
    modelName,
    hasToolCalls,
    promptTokens,
    contextLength,
    spentKudos,
  };
  // Memoize on the individual field values so consumers only re-render when
  // one of them actually changes (the key set is fixed, so the deps array
  // length is stable across renders).
  const context = useMemo(() => rawContext, Object.values(rawContext));
  const value = useMemo(() => ({ ...context, ...actions }), [context, actions])
  return (
    <LLMContext.Provider value={value}>
      {children}
    </LLMContext.Provider>
  );
}