import { createContext } from "preact";
import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
import { MessageTools, type IMessage } from "../tools/messages";
import { StateContext } from "./state";
import { useBool } from "@common/hooks/useBool";
import { Huggingface } from "../tools/huggingface";
import { Connection, type IGenerationSettings } from "../tools/connection";
import { throttle } from "@common/utils";
import { useAsyncEffect } from "@common/hooks/useAsyncEffect";
import { approximateTokens, normalizeModel } from "../tools/model";
interface ICompileArgs {
|
|
keepUsers?: number;
|
|
continueLast?: boolean;
|
|
}
|
|
|
|
interface ICompiledPrompt {
|
|
prompt: string;
|
|
isContinue: boolean;
|
|
isRegen: boolean;
|
|
}
|
|
|
|
interface IContext {
|
|
generating: boolean;
|
|
modelName: string;
|
|
hasToolCalls: boolean;
|
|
promptTokens: number;
|
|
contextLength: number;
|
|
spentKudos: number;
|
|
}
|
|
|
|
const MESSAGES_TO_KEEP = 10;
|
|
|
|
interface IActions {
|
|
compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
|
|
generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
|
|
stopGeneration: () => void;
|
|
summarize: (content: string) => Promise<string>;
|
|
countTokens: (prompt: string) => Promise<number>;
|
|
}
|
|
export type ILLMContext = IContext & IActions;
|
|
|
|
export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
|
|
|
|
const processing = {
|
|
tokenizing: false,
|
|
summarizing: false,
|
|
}
|
|
|
|
export const LLMContextProvider = ({ children }: { children?: any }) => {
|
|
const {
|
|
connection, messages, triggerNext, continueLast, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
|
|
setTriggerNext, setContinueLast, addMessage, editMessage, editSummary, setTotalSpentKudos,
|
|
} = useContext(StateContext);
|
|
|
|
const generating = useBool(false);
|
|
const [promptTokens, setPromptTokens] = useState(0);
|
|
const [contextLength, setContextLength] = useState(0);
|
|
const [modelName, setModelName] = useState('');
|
|
const [hasToolCalls, setHasToolCalls] = useState(false);
|
|
const [spentKudos, setSpentKudos] = useState(0);
|
|
|
|
const isOnline = useMemo(() => contextLength > 0, [contextLength]);
|
|
|
|
const actions: IActions = useMemo(() => ({
|
|
compilePrompt: async (messages, { keepUsers, continueLast = false } = {}) => {
|
|
const lastMessage = messages.at(-1);
|
|
const lastMessageContent = MessageTools.getSwipe(lastMessage)?.content;
|
|
const isAssistantLast = lastMessage?.role === 'assistant';
|
|
let isRegen = continueLast;
|
|
|
|
if (!isAssistantLast) {
|
|
isRegen = false;
|
|
} else if (!lastMessageContent) {
|
|
isRegen = true;
|
|
}
|
|
|
|
const isContinue = isAssistantLast && !isRegen;
|
|
|
|
const promptMessages = continueLast ? messages.slice(0, -1) : messages.slice();
|
|
|
|
if (isContinue) {
|
|
promptMessages.push(MessageTools.create(Huggingface.applyTemplate(userPrompt, {})));
|
|
}
|
|
|
|
const userMessages = promptMessages.filter(m => m.role === 'user');
|
|
const lastUserMessage = userMessages.at(-1);
|
|
const firstUserMessage = userMessages.at(0);
|
|
|
|
const templateMessages: Huggingface.ITemplateMessage[] = [
|
|
{ role: 'system', content: systemPrompt.trim() },
|
|
];
|
|
|
|
if (keepUsers) {
|
|
let usersRemaining = messages.filter(m => m.role === 'user').length;
|
|
let wasStory = false;
|
|
|
|
for (const message of messages) {
|
|
const { role } = message;
|
|
const swipe = MessageTools.getSwipe(message);
|
|
let content = swipe?.content ?? '';
|
|
if (role === 'user' && usersRemaining > keepUsers) {
|
|
usersRemaining--;
|
|
} else if (role === 'assistant' && templateMessages.at(-1).role === 'assistant') {
|
|
wasStory = true;
|
|
templateMessages.at(-1).content += '\n\n' + content;
|
|
} else if (role === 'user' && !message.technical) {
|
|
templateMessages.push({
|
|
role: message.role,
|
|
content: Huggingface.applyTemplate(userPrompt, { prompt: content, isStart: !wasStory }),
|
|
});
|
|
} else {
|
|
if (role === 'assistant') {
|
|
wasStory = true;
|
|
}
|
|
templateMessages.push({ role, content });
|
|
}
|
|
}
|
|
} else {
|
|
const story = promptMessages.filter(m => m.role === 'assistant')
|
|
.map((m, i, msgs) => {
|
|
const swipe = MessageTools.getSwipe(m);
|
|
if (!swipe) return '';
|
|
|
|
let { content, summary } = swipe;
|
|
if (summary && i < msgs.length - MESSAGES_TO_KEEP) {
|
|
content = summary;
|
|
}
|
|
return content;
|
|
}).join('\n\n');
|
|
|
|
if (story.length > 0) {
|
|
const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
|
|
templateMessages.push({ role: 'user', content: Huggingface.applyTemplate(userPrompt, { prompt, isStart: true }) });
|
|
templateMessages.push({ role: 'assistant', content: story });
|
|
}
|
|
|
|
let userMessage = MessageTools.getSwipe(lastUserMessage)?.content;
|
|
if (!lastUserMessage?.technical && !isContinue && userMessage) {
|
|
userMessage = Huggingface.applyTemplate(userPrompt, { prompt: userMessage, isStart: story.length === 0 });
|
|
}
|
|
|
|
if (userMessage) {
|
|
templateMessages.push({ role: 'user', content: userMessage });
|
|
}
|
|
}
|
|
|
|
if (templateMessages[1]?.role !== 'user') {
|
|
const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
|
|
|
|
templateMessages.splice(1, 0, {
|
|
role: 'user',
|
|
content: Huggingface.applyTemplate(userPrompt, { prompt, isStart: true }),
|
|
});
|
|
}
|
|
|
|
templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`;
|
|
|
|
let prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages);
|
|
|
|
if (isRegen) {
|
|
prompt += lastMessageContent;
|
|
}
|
|
|
|
return {
|
|
prompt,
|
|
isContinue,
|
|
isRegen,
|
|
};
|
|
},
|
|
generate: async function* (prompt, extraSettings = {}) {
|
|
try {
|
|
console.log('[LLM.generate]', prompt);
|
|
|
|
setSpentKudos(0);
|
|
for await (const { text, cost } of Connection.generate(connection, prompt, {
|
|
...extraSettings,
|
|
banned_tokens: bannedWords.filter(w => w.trim()),
|
|
})) {
|
|
setSpentKudos(sk => sk + cost);
|
|
setTotalSpentKudos(sk => sk + cost);
|
|
yield text;
|
|
}
|
|
} catch (e) {
|
|
if (e instanceof Error && e.name !== 'AbortError') {
|
|
alert(e.message);
|
|
} else {
|
|
console.error('[LLM.generate]', e);
|
|
}
|
|
}
|
|
},
|
|
summarize: async (message) => {
|
|
try {
|
|
const content = Huggingface.applyTemplate(summarizePrompt, { message });
|
|
const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
|
|
console.log('[LLM.summarize]', prompt);
|
|
|
|
const tokens = await Array.fromAsync(Connection.generate(connection, prompt, {}));
|
|
|
|
return MessageTools.trimSentence(tokens.join(''));
|
|
} catch (e) {
|
|
console.error('Error summarizing:', e);
|
|
return '';
|
|
}
|
|
},
|
|
countTokens: async (prompt) => {
|
|
return await Connection.countTokens(connection, prompt);
|
|
},
|
|
stopGeneration: () => {
|
|
Connection.stopGeneration();
|
|
},
|
|
}), [connection, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt]);
|
|
|
|
useAsyncEffect(async () => {
|
|
if (isOnline && triggerNext && !generating.value) {
|
|
setTriggerNext(false);
|
|
setContinueLast(false);
|
|
|
|
let messageId = messages.length - 1;
|
|
let text = '';
|
|
|
|
const { prompt, isRegen } = await actions.compilePrompt(messages, { continueLast });
|
|
|
|
if (isRegen) {
|
|
text = MessageTools.getSwipe(messages.at(-1))?.content ?? '';
|
|
} else {
|
|
addMessage('', 'assistant');
|
|
messageId++;
|
|
}
|
|
|
|
generating.setTrue();
|
|
editSummary(messageId, 'Generating...');
|
|
for await (const chunk of actions.generate(prompt)) {
|
|
text += chunk;
|
|
setPromptTokens(promptTokens + approximateTokens(text));
|
|
editMessage(messageId, text.trim());
|
|
}
|
|
generating.setFalse();
|
|
|
|
text = MessageTools.trimSentence(text);
|
|
editMessage(messageId, text);
|
|
editSummary(messageId, '');
|
|
|
|
MessageTools.playReady();
|
|
}
|
|
}, [triggerNext, isOnline]);
|
|
|
|
useAsyncEffect(async () => {
|
|
if (isOnline && summaryEnabled && !processing.summarizing) {
|
|
try {
|
|
processing.summarizing = true;
|
|
for (let id = 0; id < messages.length; id++) {
|
|
const message = messages[id];
|
|
const swipe = MessageTools.getSwipe(message);
|
|
if (message.role === 'assistant' && swipe?.content?.includes('\n') && !swipe.summary) {
|
|
const summary = await actions.summarize(swipe.content);
|
|
editSummary(id, summary);
|
|
}
|
|
}
|
|
} catch (e) {
|
|
console.error(`Could not summarize`, e)
|
|
} finally {
|
|
processing.summarizing = false;
|
|
}
|
|
}
|
|
}, [messages, summaryEnabled, isOnline]);
|
|
|
|
useEffect(throttle(() => {
|
|
Connection.getContextLength(connection).then(setContextLength);
|
|
Connection.getModelName(connection).then(normalizeModel).then(setModelName);
|
|
}, 1000, true), [connection]);
|
|
|
|
const calculateTokens = useCallback(throttle(async () => {
|
|
if (isOnline && !processing.tokenizing && !generating.value) {
|
|
try {
|
|
processing.tokenizing = true;
|
|
const { prompt } = await actions.compilePrompt(messages);
|
|
const tokens = await actions.countTokens(prompt);
|
|
setPromptTokens(tokens);
|
|
} catch (e) {
|
|
console.error(`Could not count tokens`, e)
|
|
} finally {
|
|
processing.tokenizing = false;
|
|
}
|
|
}
|
|
}, 1000, true), [actions, messages, isOnline]);
|
|
|
|
useEffect(() => {
|
|
calculateTokens();
|
|
}, [messages, connection, systemPrompt, lore, userPrompt, isOnline]);
|
|
|
|
useEffect(() => {
|
|
try {
|
|
const hasTools = Huggingface.testToolCalls(connection.instruct);
|
|
setHasToolCalls(hasTools);
|
|
} catch {
|
|
setHasToolCalls(false);
|
|
}
|
|
}, [connection.instruct]);
|
|
|
|
const rawContext: IContext = {
|
|
generating: generating.value,
|
|
modelName,
|
|
hasToolCalls,
|
|
promptTokens,
|
|
contextLength,
|
|
spentKudos,
|
|
};
|
|
|
|
const context = useMemo(() => rawContext, Object.values(rawContext));
|
|
const value = useMemo(() => ({ ...context, ...actions }), [context, actions])
|
|
|
|
return (
|
|
<LLMContext.Provider value={value}>
|
|
{children}
|
|
</LLMContext.Provider>
|
|
);
|
|
} |