AIStory: stopping

parent 017ef7aaa5, commit 277b315795
package.json:
@@ -14,6 +14,7 @@
     "@inquirer/select": "2.3.10",
     "ace-builds": "1.36.3",
     "classnames": "2.5.1",
+    "delay": "6.0.0",
     "preact": "10.22.0"
   },
   "devDependencies": {
@common/hooks/useAsyncEffect (new file):
@@ -0,0 +1,4 @@
+import { useEffect } from "preact/hooks";
+
+export const useAsyncEffect = (fx: () => any, deps: any[]) =>
+  useEffect(() => void fx(), deps);
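The `void fx()` wrapper is the point of the hook: `useEffect` treats a returned value as a cleanup function, so the promise an async callback returns must be discarded rather than handed back to Preact. A minimal usage sketch (the `loadStory`/`setStory` names are hypothetical):

    // Hypothetical usage: kick off an async task when `storyId` changes;
    // the returned promise is intentionally discarded by the hook.
    useAsyncEffect(async () => {
      const story = await loadStory(storyId); // assumed helper
      setStory(story);
    }, [storyId]);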
@common/utils:
@@ -1,4 +1,3 @@
-export const delay = async (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
 export const nextFrame = async (): Promise<number> => new Promise((resolve) => requestAnimationFrame(resolve));
 
 export const randInt = (min: number, max: number) => Math.round(min + (max - min - 1) * Math.random());
@@ -50,20 +49,30 @@ export const intHash = (seed: number, ...parts: number[]) => {
   return h1;
 };
 
 export const sinHash = (...data: number[]) => data.reduce((hash, n) => Math.sin((hash * 123.12 + n) * 756.12), 0) / 2 + 0.5;
 
-export const throttle = function <T, A extends unknown[], R, F extends (this: T, ...args: A) => R>(func: F, ms: number): F {
+export const throttle = function <T, A extends unknown[], R, F extends (this: T, ...args: A) => R>(func: F, ms: number, trailing = false): F {
   let isThrottled = false;
   let savedResult: R;
+  let savedThis: T;
+  let savedArgs: A | undefined;
 
   const wrapper: F = function (...args: A) {
-    if (!isThrottled) {
+    if (isThrottled) {
+      savedThis = this;
+      savedArgs = args;
+    } else {
       savedResult = func.apply(this, args);
+      savedArgs = undefined;
 
       isThrottled = true;
 
       setTimeout(function () {
         isThrottled = false;
+        if (trailing && savedArgs) {
+          savedResult = wrapper.apply(savedThis, savedArgs);
+        }
       }, ms);
     }
 
     return savedResult;
   } as F;
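With `trailing` enabled, calls that land inside the throttle window are no longer dropped outright: the last one is saved and replayed when the window closes. A rough illustration:

    const log = throttle((msg: string) => console.log(msg), 1000, true);
    log('a'); // runs immediately
    log('b'); // window is open: call is saved
    log('c'); // overwrites the saved call
    // ~1s later the setTimeout fires and replays the last saved call: logs 'c'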
Input component:
@@ -5,7 +5,7 @@ import { AutoTextarea } from "./autoTextarea";
 
 export const Input = () => {
   const { input, setInput, addMessage, continueMessage } = useContext(StateContext);
-  const { generating } = useContext(LLMContext);
+  const { generating, stopGeneration } = useContext(LLMContext);
 
   const handleSend = useCallback(async () => {
     if (!generating) {
@@ -29,7 +29,10 @@ export const Input = () => {
   return (
     <div class="chat-input">
       <AutoTextarea onInput={setInput} onKeyDown={handleKeyDown} value={input} />
-      <button onClick={handleSend} class={`${generating ? 'disabled' : ''}`}>{input ? 'Send' : 'Continue'}</button>
+      {generating
+        ? <button onClick={stopGeneration}>Stop</button>
+        : <button onClick={handleSend}>{input ? 'Send' : 'Continue'}</button>
+      }
     </div>
   );
 }
MiniChat component:
@@ -16,7 +16,7 @@ interface IProps {
 }
 
 export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) => {
-  const { generating, generate, compilePrompt } = useContext(LLMContext);
+  const { generating, stopGeneration, generate, compilePrompt } = useContext(LLMContext);
   const [messages, setMessages] = useState<IMessage[]>([]);
   const ref = useRef<HTMLDivElement>(null);
 
@@ -105,9 +105,10 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
       </div>
     </div>
     <div class={styles.buttons}>
-      <button onClick={handleGenerate} class={`${generating ? 'disabled' : ''}`}>
-        Generate
-      </button>
+      {generating
+        ? <button onClick={stopGeneration}>Stop</button>
+        : <button onClick={handleGenerate}>Generate</button>
+      }
       <button onClick={() => handleInit()} class={`${generating ? 'disabled' : ''}`}>
         Clear
       </button>
connection module:
@@ -1,6 +1,7 @@
 import Lock from "@common/lock";
 import SSE from "@common/sse";
-import { delay, throttle } from "@common/utils";
+import { throttle } from "@common/utils";
+import delay, { clearDelay } from "delay";
 
 interface IBaseConnection {
   instruct: string;
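The hand-rolled `delay` from `@common/utils` is swapped for the `delay` npm package, presumably for its AbortSignal support: aborting the signal rejects the pending timer with an `AbortError` instead of letting it run to completion. A sketch of the behavior this commit relies on:

    import delay from "delay";

    const controller = new AbortController();
    const pending = delay(2500, { signal: controller.signal });
    controller.abort(); // `pending` now rejects with an AbortError
    await pending.catch((e) => console.log(e.name)); // "AbortError"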
@@ -72,7 +73,7 @@ const DEFAULT_GENERATION_SETTINGS = {
   dry_penalty_last_n: 0
 }
 
-const MIN_PERFORMANCE = 5.0;
+const MIN_PERFORMANCE = 2.0;
 const MIN_WORKER_CONTEXT = 8192;
 const MAX_HORDE_LENGTH = 512;
 const MAX_HORDE_CONTEXT = 32000;
@@ -88,7 +89,7 @@ export const normalizeModel = (model: string) => {
 
   currentModel = currentModel
     .replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
-    .replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name
+    .replace(/[ ._-](gptq|awq|exl2?|imat|i\d|h\d)/i, '') // remove quant name
    .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
    .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
    .replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
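The new `h\d` alternative catches suffixes such as the `-h6` marker that often appears on exl2 quant names. Assuming the earlier (unshown) part of `normalizeModel` lowercases its input, a typical name would reduce through the rules above like this:

    // Hypothetical trace through the replace chain shown above:
    normalizeModel("Mistral-7B-Instruct-v0.2.Q4_K_M.gguf");
    // ".gguf" stripped by the gguf/ggml rule, ".Q4_K_M" by the quant-size rule
    // → "mistral-7b-instruct-v0.2"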
@@ -104,14 +105,15 @@ export const normalizeModel = (model: string) => {
     .trim();
 }
 
-export const approximateTokens = (prompt: string): number =>
-  Math.round(prompt.split(/\s+/).length * 0.75);
+export const approximateTokens = (prompt: string): number => prompt.split(/[^a-z0-9]+/i).length;
 
 export type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
 
 export namespace Connection {
   const AIHORDE = 'https://aihorde.net';
 
+  let abortController = new AbortController();
+
   async function* generateKobold(url: string, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator<string> {
     const sse = new SSE(`${url}/api/extra/generate/stream`, {
       payload: JSON.stringify({
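The new approximation counts runs of alphanumeric characters instead of scaling whitespace-split words by 0.75, so punctuation-separated fragments count as their own tokens. Roughly:

    approximateTokens("Hello, world!"); // ["Hello", "world", ""] → 3
    approximateTokens("a-b c");         // ["a", "b", "c"] → 3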
@@ -144,12 +146,14 @@ export namespace Connection {
       messageLock.release();
     };
 
+    abortController.signal.addEventListener('abort', handleEnd);
     sse.addEventListener('error', handleEnd);
     sse.addEventListener('abort', handleEnd);
     sse.addEventListener('readystatechange', (e) => {
       if (e.readyState === SSE.CLOSED) handleEnd();
     });
 
+
     while (!end || messages.length) {
       while (messages.length > 0) {
         const message = messages.shift();
@@ -189,6 +193,8 @@ export namespace Connection {
       workers: model.workers,
     };
 
+    const { signal } = abortController;
+
     const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, {
       method: 'POST',
       body: JSON.stringify(requestData),
@@ -196,32 +202,45 @@
         'Content-Type': 'application/json',
         apikey: connection.apiKey || HORDE_ANON_KEY,
       },
+      signal,
     });
 
-    if (!generateResponse.ok || Math.floor(generateResponse.status / 100) !== 2) {
+    if (!generateResponse.ok || generateResponse.status >= 400) {
       throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`);
     }
 
     const { id } = await generateResponse.json() as { id: string };
-    const deleteRequest = () => fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method: 'DELETE' })
-      .catch(e => console.error('Error deleting request', e));
-
-    while (true) {
-      await delay(2500);
-
-      const retrieveResponse = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`);
-      if (!retrieveResponse.ok || Math.floor(retrieveResponse.status / 100) !== 2) {
-        deleteRequest();
-        throw new Error(`Error retrieving generation: ${retrieveResponse.statusText}: ${await retrieveResponse.text()}`);
-      }
-
-      const result: IHordeResult = await retrieveResponse.json();
-
-      if (result.done && result.generations?.length === 1) {
-        const { text } = result.generations[0];
-
-        return text;
-      }
-    }
-  }
+    const request = async (method = 'GET'): Promise<string | null> => {
+      const response = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method });
+      if (response.ok && response.status < 400) {
+        const result: IHordeResult = await response.json();
+        if (result.generations?.length === 1) {
+          const { text } = result.generations[0];
+
+          return text;
+        }
+      } else {
+        throw new Error(await response.text());
+      }
+
+      return null;
+    };
+
+    const deleteRequest = async () => (await request('DELETE')) ?? '';
+
+    while (true) {
+      try {
+        await delay(2500, { signal });
+
+        const text = await request();
+
+        if (text) {
+          return text;
+        }
+      } catch (e) {
+        console.error('Error in horde generation:', e);
+        return deleteRequest();
+      }
+    }
+  }
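The shape of the loop is what makes stopping work: an abort rejects both the in-flight `fetch` (via `signal`) and the `delay` timer, so control lands in the `catch` block, which logs and then issues the DELETE through `deleteRequest()` to cancel the queued Horde job. Non-2xx status responses take the same path, since `request` throws on them.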
@@ -236,6 +255,11 @@ export namespace Connection {
     }
   }
 
+  export function stopGeneration() {
+    abortController.abort();
+    abortController = new AbortController(); // refresh
+  }
+
   async function requestHordeModels(): Promise<Map<string, IHordeModel>> {
     try {
       const response = await fetch(`${AIHORDE}/api/v2/workers?type=text`);
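An `AbortController` is single-use: once `abort()` fires, its signal stays aborted forever. Hence the refresh, so the next generation starts with a clean signal. The pattern in isolation:

    let controller = new AbortController();

    function stop() {
      controller.abort();                 // everything holding the old signal cancels
      controller = new AbortController(); // later requests get a fresh, un-aborted signal
    }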
@@ -299,7 +323,7 @@
       return result;
     }
   } catch (e) {
-    console.log('Error getting max tokens', e);
+    console.error('Error getting max tokens', e);
   }
 } else if (isHordeConnection(connection)) {
   return connection.model;

@@ -317,7 +341,7 @@
     return value;
   }
 } catch (e) {
-  console.log('Error getting max tokens', e);
+  console.error('Error getting max tokens', e);
 }
 } else if (isHordeConnection(connection)) {
   const models = await getHordeModels();

@@ -343,7 +367,7 @@
     return value;
   }
 } catch (e) {
-  console.log('Error counting tokens', e);
+  console.error('Error counting tokens', e);
 }
 }
LLM context provider:
@@ -1,13 +1,13 @@
-import Lock from "@common/lock";
-import SSE from "@common/sse";
 import { createContext } from "preact";
 import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
 import { MessageTools, type IMessage } from "../messages";
-import { Instruct, StateContext } from "./state";
+import { StateContext } from "./state";
 import { useBool } from "@common/hooks/useBool";
 import { Template } from "@huggingface/jinja";
 import { Huggingface } from "../huggingface";
 import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection";
+import { throttle } from "@common/utils";
+import { useAsyncEffect } from "@common/hooks/useAsyncEffect";
 
 interface ICompileArgs {
   keepUsers?: number;
@@ -22,9 +22,7 @@ interface ICompiledPrompt {
 
 interface IContext {
   generating: boolean;
-  blockConnection: ReturnType<typeof useBool>;
   modelName: string;
-  modelTemplate: string;
   hasToolCalls: boolean;
   promptTokens: number;
   contextLength: number;
@@ -35,6 +33,7 @@ const MESSAGES_TO_KEEP = 10;
 interface IActions {
   compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
   generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
+  stopGeneration: () => void;
   summarize: (content: string) => Promise<string>;
   countTokens: (prompt: string) => Promise<number>;
 }
@@ -50,15 +49,13 @@ const processing = {
 export const LLMContextProvider = ({ children }: { children?: any }) => {
   const {
     connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
-    setTriggerNext, addMessage, editMessage, editSummary, setInstruct,
+    setTriggerNext, addMessage, editMessage, editSummary,
   } = useContext(StateContext);
 
   const generating = useBool(false);
-  const blockConnection = useBool(false);
   const [promptTokens, setPromptTokens] = useState(0);
   const [contextLength, setContextLength] = useState(0);
   const [modelName, setModelName] = useState('');
-  const [modelTemplate, setModelTemplate] = useState('');
   const [hasToolCalls, setHasToolCalls] = useState(false);
 
   const userPromptTemplate = useMemo(() => {
@@ -71,20 +68,6 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
     }
   }, [userPrompt]);
 
-  const getContextLength = useCallback(async () => {
-    if (!connection || blockConnection.value) {
-      return 0;
-    }
-    return Connection.getContextLength(connection);
-  }, [connection, blockConnection.value]);
-
-  const getModelName = useCallback(async () => {
-    if (!connection || blockConnection.value) {
-      return '';
-    }
-    return Connection.getModelName(connection);
-  }, [connection, blockConnection.value]);
-
   const actions: IActions = useMemo(() => ({
     compilePrompt: async (messages, { keepUsers } = {}) => {
       const promptMessages = messages.slice();
@@ -179,31 +162,43 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
     },
     generate: async function* (prompt, extraSettings = {}) {
       try {
-        generating.setTrue();
         console.log('[LLM.generate]', prompt);
 
         yield* Connection.generate(connection, prompt, {
           ...extraSettings,
           banned_tokens: bannedWords.filter(w => w.trim()),
         });
-      } finally {
-        generating.setFalse();
+      } catch (e) {
+        if (e instanceof Error && e.name !== 'AbortError') {
+          alert(e.message);
+        } else {
+          console.error('[LLM.generate]', e);
+        }
       }
     },
     summarize: async (message) => {
+      try {
        const content = Huggingface.applyTemplate(summarizePrompt, { message });
        const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
+       console.log('[LLM.summarize]', prompt);
 
-       const tokens = await Array.fromAsync(actions.generate(prompt));
+       const tokens = await Array.fromAsync(Connection.generate(connection, prompt, {}));
 
        return MessageTools.trimSentence(tokens.join(''));
+      } catch (e) {
+        console.error('Error summarizing:', e);
+        return '';
+      }
     },
     countTokens: async (prompt) => {
       return await Connection.countTokens(connection, prompt);
     },
+    stopGeneration: () => {
+      Connection.stopGeneration();
+    },
   }), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);
 
-  useEffect(() => void (async () => {
+  useAsyncEffect(async () => {
     if (triggerNext && !generating.value) {
       setTriggerNext(false);
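Note the asymmetry in the new error handling: an `AbortError` (the normal result of pressing Stop) is only logged, while real failures surface to the user via `alert`. The `generating` flag is also no longer toggled inside `generate` itself; the calling effect owns it now, as the next hunk shows.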
@@ -217,12 +212,14 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
       messageId++;
     }
 
+    generating.setTrue();
     editSummary(messageId, 'Generating...');
     for await (const chunk of actions.generate(prompt)) {
       text += chunk;
       setPromptTokens(promptTokens + approximateTokens(text));
       editMessage(messageId, text.trim());
     }
+    generating.setFalse();
 
     text = MessageTools.trimSentence(text);
     editMessage(messageId, text);
@@ -230,10 +227,10 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
 
       MessageTools.playReady();
     }
-  })(), [triggerNext]);
+  }, [triggerNext]);
 
-  useEffect(() => void (async () => {
-    if (summaryEnabled && !generating.value && !processing.summarizing) {
+  useAsyncEffect(async () => {
+    if (summaryEnabled && !processing.summarizing) {
       try {
         processing.summarizing = true;
         for (let id = 0; id < messages.length; id++) {
@@ -250,36 +247,15 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
       processing.summarizing = false;
     }
   }
-  })(), [messages]);
+  }, [messages, summaryEnabled]);
 
-  useEffect(() => {
-    if (!blockConnection.value) {
-      setPromptTokens(0);
-      setContextLength(0);
-      setModelName('');
-
-      getContextLength().then(setContextLength);
-      getModelName().then(normalizeModel).then(setModelName);
-    }
-  }, [connection, blockConnection.value]);
-
-  useEffect(() => {
-    setModelTemplate('');
-    if (modelName) {
-      Huggingface.findModelTemplate(modelName)
-        .then((template) => {
-          if (template) {
-            setModelTemplate(template);
-            setInstruct(template);
-          } else {
-            setInstruct(Instruct.CHATML);
-          }
-        });
-    }
-  }, [modelName]);
-
-  const calculateTokens = useCallback(async () => {
-    if (!processing.tokenizing && !blockConnection.value && !generating.value) {
+  useEffect(throttle(() => {
+    Connection.getContextLength(connection).then(setContextLength);
+    Connection.getModelName(connection).then(normalizeModel).then(setModelName);
+  }, 1000, true), [connection]);
+
+  const calculateTokens = useCallback(throttle(async () => {
+    if (!processing.tokenizing && !generating.value) {
       try {
         processing.tokenizing = true;
         const { prompt } = await actions.compilePrompt(messages);
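Both of these effects now lean on the trailing flag added to `throttle` above: bursts of `connection` or `messages` updates collapse to at most one run per second, and the trailing replay guarantees the final state is still processed rather than silently dropped.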
@@ -291,11 +267,11 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
         processing.tokenizing = false;
       }
     }
-  }, [actions, messages, blockConnection.value]);
+  }, 1000, true), [actions, messages]);
 
   useEffect(() => {
     calculateTokens();
-  }, [messages, connection, blockConnection.value, /* systemPrompt, lore, userPrompt TODO debounce*/]);
+  }, [messages, connection, systemPrompt, lore, userPrompt]);
 
   useEffect(() => {
     try {
@@ -308,9 +284,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
 
     const rawContext: IContext = {
       generating: generating.value,
-      blockConnection,
       modelName,
-      modelTemplate,
       hasToolCalls,
       promptTokens,
       contextLength,
state defaults:
@@ -83,7 +83,7 @@ What should happen next in your answer: {{ prompt | trim }}
 {% endif %}
 Remember that this story should be infinite and go forever.
 Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`,
-  summarizePrompt: 'Shrink following text down, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
+  summarizePrompt: 'Shrink following text down to one paragraph, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
   summaryEnabled: false,
   bannedWords: [],
   messages: [],
huggingface module:
@@ -1,6 +1,7 @@
 import { gguf } from '@huggingface/gguf';
 import * as hub from '@huggingface/hub';
 import { Template } from '@huggingface/jinja';
+import { normalizeModel } from './connection';
 
 export namespace Huggingface {
   export interface ITemplateMessage {
@@ -92,11 +93,12 @@ export namespace Huggingface {
 
   const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise<TokenizerConfig | null> => {
     console.log(`[huggingface] searching config for '${modelName}'`);
+    const searchModel = normalizeModel(modelName);
 
-    const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName }, additionalFields: ['config'] }));
+    const hubModels = await Array.fromAsync(hub.listModels({ search: { query: searchModel }, additionalFields: ['config'] }));
     const models = hubModels.filter(m => {
       if (m.gated) return false;
-      if (!m.name.toLowerCase().includes(modelName.toLowerCase())) return false;
+      if (!normalizeModel(m.name).includes(searchModel)) return false;
 
       return true;
     }).sort((a, b) => b.downloads - a.downloads);
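Normalizing both the query and each candidate's name makes the Hub lookup robust to quant suffixes: a worker reporting something like a `...Q4_K_M.gguf` filename can still match the base repository, where the old case-insensitive substring test against the raw name would have missed it.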