1
0
Fork 0
tsgames/src/games/ai/contexts/llm.tsx

169 lines
5.4 KiB
TypeScript

import Lock from "@common/lock";
import SSE from "@common/sse";
import { GENERATION_SETTINGS } from "../const";
import { createContext } from "preact";
import { useContext, useEffect, useMemo } from "preact/hooks";
import { MessageTools } from "../messages";
import { StateContext } from "./state";
import { useBool } from "@common/hooks/useBool";
/** Reactive state exposed to consumers of the LLM context. */
interface IContext {
  // True while a streaming generation request is in flight.
  generating: boolean;
}
/** Imperative actions exposed to consumers of the LLM context. */
interface IActions {
  // Streams generated tokens for `prompt`; `extraSettings` overrides
  // individual keys of GENERATION_SETTINGS for this call only.
  generate: (prompt: string, extraSettings?: Partial<typeof GENERATION_SETTINGS>) => AsyncGenerator<string>;
  // Resolves to the server-reported token count for `prompt` (0 on failure).
  countTokens(prompt: string): Promise<number>;
  // Resolves to the server's maximum context length in tokens (0 on failure).
  getContextLength(): Promise<number>;
}
/** Combined value type provided by LLMContext. */
export type ILLMContext = IContext & IActions;
// NOTE(review): default is an empty object cast to ILLMContext — consumers
// rendered outside the provider would get undefined members at runtime.
export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
/**
 * Provides LLM actions (streamed generation, token counting, context-length
 * lookup) and a `generating` flag to descendants via LLMContext.
 * Talks to an HTTP API (KoboldCpp-style endpoints — presumably; confirm
 * against the server) rooted at `connectionUrl` from StateContext.
 */
export const LLMContextProvider = ({ children }: { children?: any }) => {
  const { connectionUrl, messages, triggerNext, setTriggerNext, addMessage, editMessage } = useContext(StateContext);
  const generating = useBool(false);
  // NOTE(review): deps only track connectionUrl; assumes useBool's setters
  // are referentially stable across renders — confirm in @common/hooks.
  const actions: IActions = useMemo(() => ({
    /**
     * Streams generated tokens for `prompt` as an async generator.
     * Yields nothing when no connection URL is configured. Sets the shared
     * `generating` flag for the duration of the stream.
     */
    generate: async function* (prompt, extraSettings = {}) {
      if (!connectionUrl) {
        return;
      }
      try {
        generating.setTrue();
        console.log('[LLM.generate]', prompt);
        const sse = new SSE(`${connectionUrl}/api/extra/generate/stream`, {
          payload: JSON.stringify({
            ...GENERATION_SETTINGS,
            ...extraSettings,
            prompt,
          }),
        });
        // Token queue filled by SSE callbacks and drained by the loop below.
        // Renamed from `messages` to avoid shadowing the StateContext
        // `messages` array captured above.
        const tokenQueue: string[] = [];
        const queueLock = new Lock();
        let done = false;
        sse.addEventListener('message', (e) => {
          if (e.data) {
            const { token, finish_reason } = JSON.parse(e.data);
            tokenQueue.push(token);
            // Server emits the literal string 'null' while still generating.
            if (finish_reason && finish_reason !== 'null') {
              done = true;
            }
          }
          queueLock.release();
        });
        const handleEnd = () => {
          done = true;
          queueLock.release();
        };
        sse.addEventListener('error', handleEnd);
        sse.addEventListener('abort', handleEnd);
        sse.addEventListener('readystatechange', (e) => {
          if (e.readyState === SSE.CLOSED) handleEnd();
        });
        // Drain queued tokens; sleep on the lock while the stream is open
        // and the queue is empty.
        while (!done || tokenQueue.length) {
          while (tokenQueue.length > 0) {
            const token = tokenQueue.shift();
            if (token != null) {
              try {
                yield token;
              } catch {
                // Consumer threw into the generator; keep draining so the
                // SSE connection is still closed cleanly below.
              }
            }
          }
          if (!done) {
            await queueLock.wait();
          }
        }
        sse.close();
      } finally {
        generating.setFalse();
      }
    },
    /**
     * Asks the server for the token count of `prompt`.
     * Returns 0 when unconfigured, on HTTP failure, or on network error.
     */
    countTokens: async (prompt) => {
      if (!connectionUrl) {
        return 0;
      }
      try {
        const response = await fetch(`${connectionUrl}/api/extra/tokencount`, {
          body: JSON.stringify({ prompt }),
          // Fixed typo: was 'applicarion/json', which breaks servers that
          // validate the Content-Type of the JSON body.
          headers: { 'Content-Type': 'application/json' },
          method: 'POST',
        });
        if (response.ok) {
          const { value } = await response.json();
          return value;
        }
      } catch (e) {
        console.error('Error counting tokens', e);
      }
      return 0;
    },
    /**
     * Fetches the server's true maximum context length in tokens.
     * Returns 0 when unconfigured, on HTTP failure, or on network error.
     */
    getContextLength: async () => {
      if (!connectionUrl) {
        return 0;
      }
      try {
        const response = await fetch(`${connectionUrl}/api/extra/true_max_context_length`);
        if (response.ok) {
          const { value } = await response.json();
          return value;
        }
      } catch (e) {
        console.error('Error getting max tokens', e);
      }
      return 0;
    },
  }), [connectionUrl]);
  // Auto-generation driver: when `triggerNext` is raised and nothing is in
  // flight, compile the chat into a prompt and stream the reply into the
  // last (or a freshly appended) assistant message.
  // BUG FIX: the guard was `!generating` — `generating` is the useBool
  // wrapper object (always truthy), so the effect body could never run.
  // It must read the boolean `generating.value`; the dependency array now
  // tracks that primitive as well.
  useEffect(() => void (async () => {
    if (triggerNext && !generating.value) {
      setTriggerNext(false);
      let messageId = messages.length - 1;
      let text = '';
      const { prompt, isRegen } = await MessageTools.compilePrompt(messages);
      if (!isRegen) {
        // Fresh turn (not a regeneration): append an empty assistant
        // message and stream into it.
        addMessage('', 'assistant');
        messageId++;
      }
      for await (const chunk of actions.generate(prompt)) {
        text += chunk;
        editMessage(messageId, text);
      }
      // Trim any trailing partial sentence once the stream finishes.
      text = MessageTools.trimSentence(text);
      editMessage(messageId, text);
      MessageTools.playReady();
    }
  })(), [triggerNext, messages, generating.value]);
  const rawContext: IContext = {
    generating: generating.value,
  };
  // Memoize on the primitive values so consumers re-render only on change.
  const context = useMemo(() => rawContext, Object.values(rawContext));
  const value = useMemo(() => ({ ...context, ...actions }), [context, actions]);
  return (
    <LLMContext.Provider value={value}>
      {children}
    </LLMContext.Provider>
  );
};