import Lock from "@common/lock";
import SSE from "@common/sse";
import { createContext } from "preact";
import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
import { MessageTools, type IMessage } from "../messages";
import { Instruct, StateContext } from "./state";
import { useBool } from "@common/hooks/useBool";
import { Template } from "@huggingface/jinja";
import { Huggingface } from "../huggingface";
import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection";

interface ICompileArgs {
    keepUsers?: number;
    raw?: boolean;
}

interface ICompiledPrompt {
    prompt: string;
    isContinue: boolean;
    isRegen: boolean;
}

interface IContext {
    generating: boolean;
    blockConnection: ReturnType<typeof useBool>;
    modelName: string;
    modelTemplate: string;
    hasToolCalls: boolean;
    promptTokens: number;
    contextLength: number;
}

const MESSAGES_TO_KEEP = 10;

interface IActions {
    compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
    generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
    summarize: (content: string) => Promise<string>;
    countTokens: (prompt: string) => Promise<number>;
}
export type ILLMContext = IContext & IActions;

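// Context exposing LLM state and actions to the component tree; populated by LLMContextProvider below.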
export const LLMContext = createContext<ILLMContext>({} as ILLMContext);

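// Module-level flags guarding against overlapping background work (token counting, summarization) across renders.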
const processing = {
    tokenizing: false,
    summarizing: false,
};

export const LLMContextProvider = ({ children }: { children?: any }) => {
    const {
        connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
        setTriggerNext, addMessage, editMessage, editSummary, setInstruct,
    } = useContext(StateContext);

    const generating = useBool(false);
    const blockConnection = useBool(false);
    const [promptTokens, setPromptTokens] = useState(0);
    const [contextLength, setContextLength] = useState(0);
    const [modelName, setModelName] = useState('');
    const [modelTemplate, setModelTemplate] = useState('');
    const [hasToolCalls, setHasToolCalls] = useState(false);

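    // Compile the user prompt as a Jinja template; fall back to rendering the raw string if it does not parse.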
    const userPromptTemplate = useMemo(() => {
        try {
            return new Template(userPrompt);
        } catch {
            return {
                render: () => userPrompt,
            };
        }
    }, [userPrompt]);

    const getContextLength = useCallback(async () => {
        if (!connection || blockConnection.value) {
            return 0;
        }
        return Connection.getContextLength(connection);
    }, [connection, blockConnection.value]);

    const getModelName = useCallback(async () => {
        if (!connection || blockConnection.value) {
            return '';
        }
        return Connection.getModelName(connection);
    }, [connection, blockConnection.value]);

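    // Actions exposed through the context; memoized on the settings they close over.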
    const actions: IActions = useMemo(() => ({
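        // Build the chat-template messages (system prompt, user turns, merged story), detect whether this
        // is a continue or a regen, and render everything through the connection's instruct template.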
        compilePrompt: async (messages, { keepUsers } = {}) => {
            const promptMessages = messages.slice();
            const lastMessage = promptMessages.at(-1);
            const isAssistantLast = lastMessage?.role === 'assistant';
            const isRegen = isAssistantLast && !MessageTools.getSwipe(lastMessage)?.content;
            const isContinue = isAssistantLast && !isRegen;

            if (isContinue) {
                promptMessages.push(MessageTools.create(userPromptTemplate.render({})));
            }

            const userMessages = promptMessages.filter(m => m.role === 'user');
            const lastUserMessage = userMessages.at(-1);
            const firstUserMessage = userMessages.at(0);

            const templateMessages: Huggingface.ITemplateMessage[] = [
                { role: 'system', content: systemPrompt.trim() },
            ];

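            // keepUsers mode: drop all but the most recent user prompts and merge consecutive
            // assistant replies into a single story block.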
            if (keepUsers) {
                let usersRemaining = messages.filter(m => m.role === 'user').length;
                let wasStory = false;

                for (const message of messages) {
                    const { role } = message;
                    const swipe = MessageTools.getSwipe(message);
                    let content = swipe?.content ?? '';
                    if (role === 'user' && usersRemaining > keepUsers) {
                        usersRemaining--;
                    } else if (role === 'assistant' && templateMessages.at(-1).role === 'assistant') {
                        wasStory = true;
                        templateMessages.at(-1).content += '\n\n' + content;
                    } else if (role === 'user' && !message.technical) {
                        templateMessages.push({
                            role: message.role,
                            content: userPromptTemplate.render({ prompt: content, isStart: !wasStory }),
                        });
                    } else {
                        if (role === 'assistant') {
                            wasStory = true;
                        }
                        templateMessages.push({ role, content });
                    }
                }
            } else {
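                // Default mode: collapse all assistant messages into one story block, substituting stored
                // summaries for everything older than the last MESSAGES_TO_KEEP messages.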
                const story = promptMessages.filter(m => m.role === 'assistant')
                    .map((m, i, msgs) => {
                        const swipe = MessageTools.getSwipe(m);
                        if (!swipe) return '';

                        let { content, summary } = swipe;
                        if (summary && i < msgs.length - MESSAGES_TO_KEEP) {
                            content = summary;
                        }
                        return content;
                    }).join('\n\n');

                if (story.length > 0) {
                    const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
                    templateMessages.push({ role: 'user', content: userPromptTemplate.render({ prompt, isStart: true }) });
                    templateMessages.push({ role: 'assistant', content: story });
                }

                let userPrompt = MessageTools.getSwipe(lastUserMessage)?.content;
                if (!lastUserMessage?.technical && !isContinue && userPrompt) {
                    userPrompt = userPromptTemplate.render({ prompt: userPrompt, isStart: story.length === 0 });
                }

                if (userPrompt) {
                    templateMessages.push({ role: 'user', content: userPrompt });
                }
            }

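            // Make sure the message right after the system prompt is a user turn, then prepend the lore to it.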
            if (templateMessages[1]?.role !== 'user') {
                const prompt = MessageTools.getSwipe(firstUserMessage)?.content;

                templateMessages.splice(1, 0, {
                    role: 'user',
                    content: userPromptTemplate.render({ prompt, isStart: true }),
                });
            }

            templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`;

            const prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages);
            return {
                prompt,
                isContinue,
                isRegen,
            };
        },
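        // Stream tokens from the backend while tracking generation state; empty entries are stripped
        // from the banned-word list before it is sent.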
        generate: async function* (prompt, extraSettings = {}) {
            try {
                generating.setTrue();
                console.log('[LLM.generate]', prompt);

                yield* Connection.generate(connection, prompt, {
                    ...extraSettings,
                    banned_tokens: bannedWords.filter(w => w.trim()),
                });
            } finally {
                generating.setFalse();
            }
        },
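        // Summarize a single message: render the summarize prompt, run it through the chat template,
        // collect the full generation, and trim it to the last complete sentence.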
        summarize: async (message) => {
            const content = Huggingface.applyTemplate(summarizePrompt, { message });
            const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);

            const tokens = await Array.fromAsync(actions.generate(prompt));

            return MessageTools.trimSentence(tokens.join(''));
        },
        countTokens: async (prompt) => {
            return await Connection.countTokens(connection, prompt);
        },
    }), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);

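    // Drive a generation pass whenever triggerNext is set: compile the prompt, stream the reply into a
    // (possibly new) assistant message, then trim it to a full sentence and clear the placeholder summary.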
    useEffect(() => void (async () => {
        if (triggerNext && !generating.value) {
            setTriggerNext(false);

            let messageId = messages.length - 1;
            let text: string = '';

            const { prompt, isRegen } = await actions.compilePrompt(messages);

            if (!isRegen) {
                addMessage('', 'assistant');
                messageId++;
            }

            editSummary(messageId, 'Generating...');
            for await (const chunk of actions.generate(prompt)) {
                text += chunk;
                setPromptTokens(promptTokens + approximateTokens(text));
                editMessage(messageId, text.trim());
            }

            text = MessageTools.trimSentence(text);
            editMessage(messageId, text);
            editSummary(messageId, '');

            MessageTools.playReady();
        }
    })(), [triggerNext]);

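    // Background summarization: generate a summary for every multi-line assistant message that does not have one yet.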
    useEffect(() => void (async () => {
        if (summaryEnabled && !generating.value && !processing.summarizing) {
            try {
                processing.summarizing = true;
                for (let id = 0; id < messages.length; id++) {
                    const message = messages[id];
                    const swipe = MessageTools.getSwipe(message);
                    if (message.role === 'assistant' && swipe?.content?.includes('\n') && !swipe.summary) {
                        const summary = await actions.summarize(swipe.content);
                        editSummary(id, summary);
                    }
                }
            } catch (e) {
                console.error(`Could not summarize`, e);
            } finally {
                processing.summarizing = false;
            }
        }
    })(), [messages]);

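    // Refresh context length and model name whenever the connection changes, unless the connection is blocked.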
    useEffect(() => {
        if (!blockConnection.value) {
            setPromptTokens(0);
            setContextLength(0);
            setModelName('');

            getContextLength().then(setContextLength);
            getModelName().then(normalizeModel).then(setModelName);
        }
    }, [connection, blockConnection.value]);

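    // Look up an instruct template for the detected model; fall back to ChatML if none is found.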
    useEffect(() => {
        setModelTemplate('');
        if (modelName) {
            Huggingface.findModelTemplate(modelName)
                .then((template) => {
                    if (template) {
                        setModelTemplate(template);
                        setInstruct(template);
                    } else {
                        setInstruct(Instruct.CHATML);
                    }
                });
        }
    }, [modelName]);

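    // Recount the prompt tokens against the backend, guarded by the module-level tokenizing flag
    // so only one count runs at a time.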
    const calculateTokens = useCallback(async () => {
        if (!processing.tokenizing && !blockConnection.value && !generating.value) {
            try {
                processing.tokenizing = true;
                const { prompt } = await actions.compilePrompt(messages);
                const tokens = await actions.countTokens(prompt);
                setPromptTokens(tokens);
            } catch (e) {
                console.error(`Could not count tokens`, e);
            } finally {
                processing.tokenizing = false;
            }
        }
    }, [actions, messages, blockConnection.value]);

    useEffect(() => {
        calculateTokens();
    }, [messages, connection, blockConnection.value, /* systemPrompt, lore, userPrompt TODO debounce*/]);

    useEffect(() => {
        try {
            const hasTools = Huggingface.testToolCalls(connection.instruct);
            setHasToolCalls(hasTools);
        } catch {
            setHasToolCalls(false);
        }
    }, [connection.instruct]);

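    // Memoize the context value on its individual fields so consumers only re-render when one of them changes.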
    const rawContext: IContext = {
        generating: generating.value,
        blockConnection,
        modelName,
        modelTemplate,
        hasToolCalls,
        promptTokens,
        contextLength,
    };

    const context = useMemo(() => rawContext, Object.values(rawContext));
    const value = useMemo(() => ({ ...context, ...actions }), [context, actions]);

    return (
        <LLMContext.Provider value={value}>
            {children}
        </LLMContext.Provider>
    );
}