// tsgames/src/games/ai/contexts/llm.tsx
// (stray file-viewer header lines removed — they were capture residue, not source:
//  "1 / 0 / Fork 0 / 327 lines / 12 KiB / TypeScript")
import Lock from "@common/lock";
import SSE from "@common/sse";
import { createContext } from "preact";
import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
import { MessageTools, type IMessage } from "../messages";
import { Instruct, StateContext } from "./state";
import { useBool } from "@common/hooks/useBool";
import { Template } from "@huggingface/jinja";
import { Huggingface } from "../huggingface";
import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection";
// Options accepted by compilePrompt.
interface ICompileArgs {
// Keep the last N user messages verbatim; older user messages are dropped
// and intervening assistant messages are merged (sliding-window mode).
keepUsers?: number;
// NOTE(review): declared but never read — compilePrompt only destructures
// keepUsers (see its signature). Confirm whether `raw` is still needed.
raw?: boolean;
}
// Result of compiling the chat history into one LLM prompt string.
interface ICompiledPrompt {
prompt: string;
// True when the last message is a non-empty assistant message (continue it).
isContinue: boolean;
// True when the last message is an assistant message with an empty swipe
// (regenerate it in place instead of appending a new message).
isRegen: boolean;
}
// Read-only state half of the context value.
interface IContext {
generating: boolean;
blockConnection: ReturnType<typeof useBool>;
modelName: string;
modelTemplate: string;
hasToolCalls: boolean;
promptTokens: number;
contextLength: number;
}
// Assistant messages within this many of the end keep their full content;
// older ones are replaced by their summary when compiling the story prompt.
const MESSAGES_TO_KEEP = 10;
// Imperative actions half of the context value.
interface IActions {
compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
summarize: (content: string) => Promise<string>;
countTokens: (prompt: string) => Promise<number>;
}
export type ILLMContext = IContext & IActions;
export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
// Module-level re-entrancy guards. Unlike component state, these persist
// across renders and update synchronously, so overlapping effect runs can
// cheaply skip work that is already in flight (tokenizing / summarizing).
const processing = {
tokenizing: false,
summarizing: false,
}
// Provides LLM state and actions (prompt compilation, streaming generation,
// background summarization, token counting) to descendants via LLMContext.
// Pulls the chat state (messages, prompts, connection) from StateContext.
export const LLMContextProvider = ({ children }: { children?: any }) => {
const {
connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
setTriggerNext, addMessage, editMessage, editSummary, setInstruct,
} = useContext(StateContext);
const generating = useBool(false);
// While true, all backend calls (model name, context length, tokenizing) are suppressed.
const blockConnection = useBool(false);
const [promptTokens, setPromptTokens] = useState(0);
const [contextLength, setContextLength] = useState(0);
const [modelName, setModelName] = useState('');
const [modelTemplate, setModelTemplate] = useState('');
const [hasToolCalls, setHasToolCalls] = useState(false);
// Jinja template used to wrap user prompts. If the template string does not
// parse, fall back to an object whose render() just echoes the raw string,
// so callers never have to handle a parse failure.
const userPromptTemplate = useMemo(() => {
try {
return new Template(userPrompt)
} catch {
return {
render: () => userPrompt,
}
}
}, [userPrompt]);
// Backend context length; 0 when there is no connection or it is blocked.
const getContextLength = useCallback(async () => {
if (!connection || blockConnection.value) {
return 0;
}
return Connection.getContextLength(connection);
}, [connection, blockConnection.value]);
// Backend model name; '' when there is no connection or it is blocked.
const getModelName = useCallback(async () => {
if (!connection || blockConnection.value) {
return '';
}
return Connection.getModelName(connection);
}, [connection, blockConnection.value]);
const actions: IActions = useMemo(() => ({
// Compile the message history into a single prompt string via the
// connection's instruct (chat) template. Two modes:
//  - keepUsers set: sliding window keeping the last N user turns;
//  - otherwise: collapse all assistant output into one "story" message,
//    substituting summaries for messages older than MESSAGES_TO_KEEP.
// NOTE(review): ICompileArgs.raw is accepted by the type but not
// destructured here — confirm whether it was meant to be honored.
compilePrompt: async (messages, { keepUsers } = {}) => {
const promptMessages = messages.slice();
const lastMessage = promptMessages.at(-1);
const isAssistantLast = lastMessage?.role === 'assistant';
// Empty assistant swipe at the end => regenerate; non-empty => continue.
const isRegen = isAssistantLast && !MessageTools.getSwipe(lastMessage)?.content;
const isContinue = isAssistantLast && !isRegen;
if (isContinue) {
// Append a synthetic user turn asking the model to continue.
promptMessages.push(MessageTools.create(userPromptTemplate.render({})));
}
const userMessages = promptMessages.filter(m => m.role === 'user');
const lastUserMessage = userMessages.at(-1);
const firstUserMessage = userMessages.at(0);
const templateMessages: Huggingface.ITemplateMessage[] = [
{ role: 'system', content: systemPrompt.trim() },
];
if (keepUsers) {
// Sliding-window mode: skip the oldest user messages until only
// `keepUsers` remain; merge consecutive assistant messages.
let usersRemaining = messages.filter(m => m.role === 'user').length;
let wasStory = false;
for (const message of messages) {
const { role } = message;
const swipe = MessageTools.getSwipe(message);
let content = swipe?.content ?? '';
if (role === 'user' && usersRemaining > keepUsers) {
usersRemaining--;
// NOTE(review): .at(-1) cannot be empty here (system message pushed
// above), but strictNullChecks would flag this access — confirm config.
} else if (role === 'assistant' && templateMessages.at(-1).role === 'assistant') {
wasStory = true;
// Merge adjacent assistant turns into one message.
templateMessages.at(-1).content += '\n\n' + content;
} else if (role === 'user' && !message.technical) {
templateMessages.push({
role: message.role,
// isStart is true only before any assistant content has appeared.
content: userPromptTemplate.render({ prompt: content, isStart: !wasStory }),
});
} else {
if (role === 'assistant') {
wasStory = true;
}
// Technical user messages and first-seen assistant turns pass through verbatim.
templateMessages.push({ role, content });
}
}
} else {
// Story mode: concatenate every assistant swipe, using its summary
// instead of full content for all but the last MESSAGES_TO_KEEP.
const story = promptMessages.filter(m => m.role === 'assistant')
.map((m, i, msgs) => {
const swipe = MessageTools.getSwipe(m);
if (!swipe) return '';
let { content, summary } = swipe;
if (summary && i < msgs.length - MESSAGES_TO_KEEP) {
content = summary;
}
return content;
}).join('\n\n');
if (story.length > 0) {
// Replay the opening user prompt, then the whole story as one assistant turn.
const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
templateMessages.push({ role: 'user', content: userPromptTemplate.render({ prompt, isStart: true }) });
templateMessages.push({ role: 'assistant', content: story });
}
let userPrompt = MessageTools.getSwipe(lastUserMessage)?.content;
// Wrap the latest user prompt in the template unless it is technical
// or we are continuing (the continuation turn was already templated).
if (!lastUserMessage?.technical && !isContinue && userPrompt) {
userPrompt = userPromptTemplate.render({ prompt: userPrompt, isStart: story.length === 0 });
}
if (userPrompt) {
templateMessages.push({ role: 'user', content: userPrompt });
}
}
// Ensure the message right after system is a user turn — presumably
// because chat templates expect user-first after system; verify against
// Huggingface.applyChatTemplate's requirements.
if (templateMessages[1]?.role !== 'user') {
const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
templateMessages.splice(1, 0, {
role: 'user',
content: userPromptTemplate.render({ prompt, isStart: true }),
});
}
// Prepend the lore/world info to the first user message.
templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`;
const prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages);
return {
prompt,
isContinue,
isRegen,
};
},
// Stream tokens from the backend, flipping the `generating` flag around
// the stream. Blank banned words are filtered out before sending.
generate: async function* (prompt, extraSettings = {}) {
try {
generating.setTrue();
console.log('[LLM.generate]', prompt);
yield* Connection.generate(connection, prompt, {
...extraSettings,
banned_tokens: bannedWords.filter(w => w.trim()),
});
} finally {
// Guarantees the flag clears even if the consumer abandons the generator.
generating.setFalse();
}
},
// Render the summarize prompt for one message, run a full generation,
// and trim the result to the last complete sentence.
summarize: async (message) => {
const content = Huggingface.applyTemplate(summarizePrompt, { message });
const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
const tokens = await Array.fromAsync(actions.generate(prompt));
return MessageTools.trimSentence(tokens.join(''));
},
// Exact token count from the backend (vs. the local approximation).
countTokens: async (prompt) => {
return await Connection.countTokens(connection, prompt);
},
}), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);
// Drive the next assistant generation when StateContext raises triggerNext.
// Fire-and-forget async; deps are only [triggerNext], so all other values
// are read from the closure captured at that render.
useEffect(() => void (async () => {
if (triggerNext && !generating.value) {
setTriggerNext(false);
let messageId = messages.length - 1;
let text: string = '';
const { prompt, isRegen } = await actions.compilePrompt(messages);
if (!isRegen) {
// Not a regen: append a fresh empty assistant message to stream into.
addMessage('', 'assistant');
messageId++;
}
editSummary(messageId, 'Generating...');
for await (const chunk of actions.generate(prompt)) {
text += chunk;
// NOTE(review): promptTokens is the value captured when this effect ran;
// the displayed count is that base plus a rough estimate of text so far.
setPromptTokens(promptTokens + approximateTokens(text));
editMessage(messageId, text.trim());
}
text = MessageTools.trimSentence(text);
editMessage(messageId, text);
editSummary(messageId, '');
MessageTools.playReady();
}
})(), [triggerNext]);
// Background summarization: whenever messages change, summarize any
// multi-line assistant message that lacks a summary. The module-level
// `processing.summarizing` flag prevents overlapping runs across renders.
useEffect(() => void (async () => {
if (summaryEnabled && !generating.value && !processing.summarizing) {
try {
processing.summarizing = true;
for (let id = 0; id < messages.length; id++) {
const message = messages[id];
const swipe = MessageTools.getSwipe(message);
if (message.role === 'assistant' && swipe?.content?.includes('\n') && !swipe.summary) {
const summary = await actions.summarize(swipe.content);
editSummary(id, summary);
}
}
} catch (e) {
console.error(`Could not summarize`, e)
} finally {
processing.summarizing = false;
}
}
})(), [messages]);
// On connection change (or when blocking is lifted): reset the displayed
// model metadata and refetch it from the backend.
useEffect(() => {
if (!blockConnection.value) {
setPromptTokens(0);
setContextLength(0);
setModelName('');
getContextLength().then(setContextLength);
getModelName().then(normalizeModel).then(setModelName);
}
}, [connection, blockConnection.value]);
// When the model name resolves, look up its chat template on Huggingface;
// fall back to ChatML if none is found.
useEffect(() => {
setModelTemplate('');
if (modelName) {
Huggingface.findModelTemplate(modelName)
.then((template) => {
if (template) {
setModelTemplate(template);
setInstruct(template);
} else {
setInstruct(Instruct.CHATML);
}
});
}
}, [modelName]);
// Recount prompt tokens exactly (backend call), guarded against re-entry
// via the module-level `processing.tokenizing` flag.
const calculateTokens = useCallback(async () => {
if (!processing.tokenizing && !blockConnection.value && !generating.value) {
try {
processing.tokenizing = true;
const { prompt } = await actions.compilePrompt(messages);
const tokens = await actions.countTokens(prompt);
setPromptTokens(tokens);
} catch (e) {
console.error(`Could not count tokens`, e)
} finally {
processing.tokenizing = false;
}
}
}, [actions, messages, blockConnection.value]);
useEffect(() => {
calculateTokens();
}, [messages, connection, blockConnection.value, /* systemPrompt, lore, userPrompt TODO debounce*/]);
// Probe the current instruct template for tool-call support; treat any
// probe failure as "no tool calls".
useEffect(() => {
try {
const hasTools = Huggingface.testToolCalls(connection.instruct);
setHasToolCalls(hasTools);
} catch {
setHasToolCalls(false);
}
}, [connection.instruct]);
const rawContext: IContext = {
generating: generating.value,
blockConnection,
modelName,
modelTemplate,
hasToolCalls,
promptTokens,
contextLength,
};
// NOTE(review): Object.values(rawContext) as a deps array is unconventional —
// it only works because the object literal above always has the same keys in
// the same order, keeping the deps array length stable across renders.
const context = useMemo(() => rawContext, Object.values(rawContext));
const value = useMemo(() => ({ ...context, ...actions }), [context, actions])
return (
<LLMContext.Provider value={value}>
{children}
</LLMContext.Provider>
);
}