// TypeScript - 169 lines, 5.4 KiB
import Lock from "@common/lock";
|
|
import SSE from "@common/sse";
|
|
import { GENERATION_SETTINGS } from "../const";
|
|
import { createContext } from "preact";
|
|
import { useContext, useEffect, useMemo } from "preact/hooks";
|
|
import { MessageTools } from "../messages";
|
|
import { StateContext } from "./state";
|
|
import { useBool } from "@common/hooks/useBool";
|
|
|
|
/**
 * Reactive state published through `LLMContext`.
 */
interface IContext {
  /** `true` from the moment `generate` starts streaming until it finishes. */
  generating: boolean;
}
|
|
|
|
/**
 * Backend operations published through `LLMContext`.
 * All members are declared as function-typed properties for consistency.
 */
interface IActions {
  /**
   * Streams generated text chunks for `prompt`.
   * `extraSettings` override `GENERATION_SETTINGS` for this request only.
   */
  generate: (prompt: string, extraSettings?: Partial<typeof GENERATION_SETTINGS>) => AsyncGenerator<string>;

  /** Resolves to the backend's token count for `prompt` (the provider resolves 0 on failure). */
  countTokens: (prompt: string) => Promise<number>;

  /** Resolves to the backend's maximum context length (the provider resolves 0 on failure). */
  getContextLength: () => Promise<number>;
}
|
|
/** What consumers of `LLMContext` receive: the reactive state merged with the actions. */
export type ILLMContext = IActions & IContext;

// Placeholder default: an empty object cast to the full shape.
// NOTE(review): reading this context outside a provider yields an object with none of the members.
const emptyLLMContext = {} as ILLMContext;

/** Context carrying the LLM state and actions; populated by `LLMContextProvider`. */
export const LLMContext = createContext<ILLMContext>(emptyLLMContext);
|
|
|
|
/**
 * Supplies `LLMContext` to `children`.
 *
 * Reads `connectionUrl` and the message list from `StateContext` and exposes:
 * - `generating`: `true` while a streaming generation is running;
 * - `generate`, `countTokens`, `getContextLength`: clients for the backend
 *   endpoints under `${connectionUrl}/api/extra/`.
 *
 * It also reacts to `triggerNext`: when the flag is set and nothing is generating,
 * it clears the flag, compiles a prompt from the messages and streams the reply
 * into a new assistant message (or into the last message for a regeneration).
 */
export const LLMContextProvider = ({ children }: { children?: any }) => {
  const { connectionUrl, messages, triggerNext, setTriggerNext, addMessage, editMessage } = useContext(StateContext);

  const generating = useBool(false);

  // NOTE(review): `generating` is not a dependency here; this assumes the setters
  // returned by `useBool` are stable across renders - confirm.
  const actions: IActions = useMemo(() => ({
    /**
     * Streams tokens for `prompt` from `/api/extra/generate/stream`.
     * Yields nothing when no connection URL is configured. The SSE stream is closed
     * and `generating` is reset even if the consumer stops iterating early.
     */
    generate: async function* (prompt, extraSettings = {}) {
      if (!connectionUrl) {
        return;
      }

      try {
        generating.setTrue();
        console.log('[LLM.generate]', prompt);

        const sse = new SSE(`${connectionUrl}/api/extra/generate/stream`, {
          payload: JSON.stringify({
            ...GENERATION_SETTINGS,
            ...extraSettings,
            prompt,
          }),
        });

        try {
          // Tokens received from the stream but not yet yielded to the consumer.
          const pending: string[] = [];
          // Released on every stream event so the loop below wakes up and drains `pending`.
          const messageLock = new Lock();
          let end = false;

          sse.addEventListener('message', (e) => {
            try {
              if (e.data) {
                const { token, finish_reason } = JSON.parse(e.data);
                pending.push(token);

                // `finish_reason` can arrive as the string 'null'; that does not mean "finished".
                if (finish_reason && finish_reason !== 'null') {
                  end = true;
                }
              }
            } catch (err) {
              console.log('[LLM.generate] Malformed stream event', err);
            } finally {
              // Always wake the reader, even when this event could not be parsed.
              messageLock.release();
            }
          });

          const handleEnd = () => {
            end = true;
            messageLock.release();
          };

          sse.addEventListener('error', handleEnd);
          sse.addEventListener('abort', handleEnd);
          sse.addEventListener('readystatechange', (e) => {
            if (e.readyState === SSE.CLOSED) handleEnd();
          });

          while (!end || pending.length) {
            while (pending.length > 0) {
              const token = pending.shift();
              if (token != null) {
                try {
                  yield token;
                } catch {
                  // Errors injected through the generator's `throw()` are ignored so streaming continues.
                }
              }
            }

            if (!end) {
              await messageLock.wait();
            }
          }
        } finally {
          // Runs on normal completion and when the consumer returns early
          // (e.g. `break` in `for await`), so the stream is never leaked.
          sse.close();
        }
      } finally {
        generating.setFalse();
      }
    },

    /** Asks the backend how many tokens `prompt` takes; 0 when unavailable or on error. */
    countTokens: async (prompt) => {
      if (!connectionUrl) {
        return 0;
      }

      try {
        const response = await fetch(`${connectionUrl}/api/extra/tokencount`, {
          body: JSON.stringify({ prompt }),
          headers: { 'Content-Type': 'application/json' },
          method: 'POST',
        });

        if (response.ok) {
          const { value } = await response.json();
          // Guard against a malformed response so callers always get a number.
          return typeof value === 'number' ? value : 0;
        }
      } catch (e) {
        console.log('Error counting tokens', e);
      }

      return 0;
    },

    /** Asks the backend for its maximum context length; 0 when unavailable or on error. */
    getContextLength: async () => {
      if (!connectionUrl) {
        return 0;
      }

      try {
        const response = await fetch(`${connectionUrl}/api/extra/true_max_context_length`);

        if (response.ok) {
          const { value } = await response.json();
          // Guard against a malformed response so callers always get a number.
          return typeof value === 'number' ? value : 0;
        }
      } catch (e) {
        console.log('Error getting max tokens', e);
      }

      return 0;
    },
  }), [connectionUrl]);

  // Generate the next reply when requested and no generation is already running.
  useEffect(() => {
    // `generating` is the useBool handle (an object, always truthy); its state lives in `.value`.
    if (!triggerNext || generating.value) {
      return;
    }

    setTriggerNext(false);

    const generateReply = async () => {
      let messageId = messages.length - 1;
      let text = '';

      const { prompt, isRegen } = await MessageTools.compilePrompt(messages);

      // A regeneration streams into the last message; otherwise append an empty
      // assistant message and stream into that one.
      if (!isRegen) {
        addMessage('', 'assistant');
        messageId++;
      }

      for await (const chunk of actions.generate(prompt)) {
        text += chunk;
        editMessage(messageId, text);
      }

      text = MessageTools.trimSentence(text);
      editMessage(messageId, text);

      MessageTools.playReady();
    };

    generateReply().catch((e) => console.log('Error generating reply', e));
  }, [triggerNext, messages, generating.value]);

  const context: IContext = useMemo(() => ({ generating: generating.value }), [generating.value]);
  const value = useMemo(() => ({ ...context, ...actions }), [context, actions]);

  return (
    <LLMContext.Provider value={value}>
      {children}
    </LLMContext.Provider>
  );
}