AIStory: stopping
parent 017ef7aaa5
commit 277b315795
@@ -14,6 +14,7 @@
     "@inquirer/select": "2.3.10",
     "ace-builds": "1.36.3",
     "classnames": "2.5.1",
+    "delay": "6.0.0",
     "preact": "10.22.0"
   },
   "devDependencies": {

@@ -0,0 +1,4 @@
+import { useEffect } from "preact/hooks";
+
+export const useAsyncEffect = (fx: () => any, deps: any[]) =>
+  useEffect(() => void fx(), deps);

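For orientation, a small usage sketch of the new hook (the component and endpoint below are made up): useEffect treats a returned Promise as a cleanup function, so an async callback must be wrapped, which is exactly what useAsyncEffect does with `void fx()`.

import { useAsyncEffect } from "@common/hooks/useAsyncEffect";

const Example = ({ id }: { id: string }) => {
  useAsyncEffect(async () => {
    // fire-and-forget async work; the Promise is discarded by `void fx()`
    const res = await fetch(`/api/item/${id}`); // hypothetical endpoint
    console.log(await res.json());
  }, [id]);

  return null;
};
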
@@ -1,4 +1,3 @@
-export const delay = async (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
 export const nextFrame = async (): Promise<number> => new Promise((resolve) => requestAnimationFrame(resolve));
 
 export const randInt = (min: number, max: number) => Math.round(min + (max - min - 1) * Math.random());

@@ -50,20 +49,30 @@ export const intHash = (seed: number, ...parts: number[]) => {
   return h1;
 };
 export const sinHash = (...data: number[]) => data.reduce((hash, n) => Math.sin((hash * 123.12 + n) * 756.12), 0) / 2 + 0.5;
-export const throttle = function <T, A extends unknown[], R, F extends (this: T, ...args: A) => R>(func: F, ms: number): F {
+export const throttle = function <T, A extends unknown[], R, F extends (this: T, ...args: A) => R>(func: F, ms: number, trailing = false): F {
   let isThrottled = false;
   let savedResult: R;
+  let savedThis: T;
+  let savedArgs: A | undefined;
 
   const wrapper: F = function (...args: A) {
-    if (!isThrottled) {
+    if (isThrottled) {
+      savedThis = this;
+      savedArgs = args;
+    } else {
       savedResult = func.apply(this, args);
+      savedArgs = undefined;
 
       isThrottled = true;
 
       setTimeout(function () {
         isThrottled = false;
+        if (trailing && savedArgs) {
+          savedResult = wrapper.apply(savedThis, savedArgs);
+        }
       }, ms);
     }
 
     return savedResult;
   } as F;

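A quick behavior sketch of the new trailing flag (values illustrative): without it, calls landing inside the throttle window are simply dropped; with trailing = true the last saved call is replayed when the window closes, so the final update is never lost.

const log = throttle((msg: string) => console.log(msg), 1000, true);

log('a'); // leading edge: runs immediately
log('b'); // inside the window: saved
log('c'); // inside the window: overwrites 'b'
// after ~1s the wrapper replays the last saved call and logs 'c'
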
@@ -5,7 +5,7 @@ import { AutoTextarea } from "./autoTextarea";
 
 export const Input = () => {
   const { input, setInput, addMessage, continueMessage } = useContext(StateContext);
-  const { generating } = useContext(LLMContext);
+  const { generating, stopGeneration } = useContext(LLMContext);
 
   const handleSend = useCallback(async () => {
     if (!generating) {

@@ -29,7 +29,10 @@ export const Input = () => {
   return (
     <div class="chat-input">
       <AutoTextarea onInput={setInput} onKeyDown={handleKeyDown} value={input} />
-      <button onClick={handleSend} class={`${generating ? 'disabled' : ''}`}>{input ? 'Send' : 'Continue'}</button>
+      {generating
+        ? <button onClick={stopGeneration}>Stop</button>
+        : <button onClick={handleSend}>{input ? 'Send' : 'Continue'}</button>
+      }
     </div>
   );
 }

@@ -16,7 +16,7 @@ interface IProps {
 }
 
 export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) => {
-  const { generating, generate, compilePrompt } = useContext(LLMContext);
+  const { generating, stopGeneration, generate, compilePrompt } = useContext(LLMContext);
   const [messages, setMessages] = useState<IMessage[]>([]);
   const ref = useRef<HTMLDivElement>(null);
 

@@ -105,9 +105,10 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
         </div>
       </div>
       <div class={styles.buttons}>
-        <button onClick={handleGenerate} class={`${generating ? 'disabled' : ''}`}>
-          Generate
-        </button>
+        {generating
+          ? <button onClick={stopGeneration}>Stop</button>
+          : <button onClick={handleGenerate}>Generate</button>
+        }
         <button onClick={() => handleInit()} class={`${generating ? 'disabled' : ''}`}>
           Clear
         </button>

@@ -1,6 +1,7 @@
 import Lock from "@common/lock";
 import SSE from "@common/sse";
-import { delay, throttle } from "@common/utils";
+import { throttle } from "@common/utils";
+import delay, { clearDelay } from "delay";
 
 interface IBaseConnection {
   instruct: string;

@@ -72,7 +73,7 @@ const DEFAULT_GENERATION_SETTINGS = {
   dry_penalty_last_n: 0
 }
 
-const MIN_PERFORMANCE = 5.0;
+const MIN_PERFORMANCE = 2.0;
 const MIN_WORKER_CONTEXT = 8192;
 const MAX_HORDE_LENGTH = 512;
 const MAX_HORDE_CONTEXT = 32000;

@@ -88,7 +89,7 @@ export const normalizeModel = (model: string) => {
 
   currentModel = currentModel
     .replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
-    .replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name
+    .replace(/[ ._-](gptq|awq|exl2?|imat|i\d|h\d)/i, '') // remove quant name
     .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
     .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
     .replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw

@@ -104,14 +105,15 @@ export const normalizeModel = (model: string) => {
     .trim();
 }
 
-export const approximateTokens = (prompt: string): number =>
-  Math.round(prompt.split(/\s+/).length * 0.75);
+export const approximateTokens = (prompt: string): number => prompt.split(/[^a-z0-9]+/i).length;
 
 export type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
 
 export namespace Connection {
   const AIHORDE = 'https://aihorde.net';
 
+  let abortController = new AbortController();
+
   async function* generateKobold(url: string, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator<string> {
     const sse = new SSE(`${url}/api/extra/generate/stream`, {
       payload: JSON.stringify({

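Rough sanity check of the new token heuristic (illustrative): it now counts runs of alphanumeric characters instead of 0.75 × whitespace-separated words, so punctuation splits tokens and a trailing delimiter contributes one empty element.

approximateTokens('Hello, world!');
// 'Hello, world!'.split(/[^a-z0-9]+/i) → ['Hello', 'world', ''] → 3
// old estimate: Math.round('Hello, world!'.split(/\s+/).length * 0.75) → 2
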
@@ -144,12 +146,14 @@ export namespace Connection {
       messageLock.release();
     };
 
+    abortController.signal.addEventListener('abort', handleEnd);
+
     sse.addEventListener('error', handleEnd);
     sse.addEventListener('abort', handleEnd);
     sse.addEventListener('readystatechange', (e) => {
       if (e.readyState === SSE.CLOSED) handleEnd();
     });
 
     while (!end || messages.length) {
       while (messages.length > 0) {
         const message = messages.shift();

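With this, the shared controller also tears down streaming (SSE) generations, not just Horde polling: an external abort funnels into the same handleEnd used for stream errors and close. Since several of these events can fire together, an end handler in this position typically has to be idempotent; a schematic sketch (names and guard are illustrative, not the repo's actual handler):

const controller = new AbortController();

let ended = false;
const handleEnd = () => {
  if (ended) return; // abort, error and close may all fire; run cleanup once
  ended = true;
  console.log('stream finished');
};

controller.signal.addEventListener('abort', handleEnd);
controller.abort(); // logs 'stream finished' exactly once
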
@@ -189,6 +193,8 @@ export namespace Connection {
       workers: model.workers,
     };
 
+    const { signal } = abortController;
+
     const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, {
       method: 'POST',
       body: JSON.stringify(requestData),

@@ -196,32 +202,45 @@ export namespace Connection {
         'Content-Type': 'application/json',
         apikey: connection.apiKey || HORDE_ANON_KEY,
       },
+      signal,
     });
 
-    if (!generateResponse.ok || Math.floor(generateResponse.status / 100) !== 2) {
+    if (!generateResponse.ok || generateResponse.status >= 400) {
       throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`);
     }
 
     const { id } = await generateResponse.json() as { id: string };
-    const deleteRequest = () => fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method: 'DELETE' })
-      .catch(e => console.error('Error deleting request', e));
-
-    while (true) {
-      await delay(2500);
-
-      const retrieveResponse = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`);
-      if (!retrieveResponse.ok || Math.floor(retrieveResponse.status / 100) !== 2) {
-        deleteRequest();
-        throw new Error(`Error retrieving generation: ${retrieveResponse.statusText}: ${await retrieveResponse.text()}`);
-      }
-
-      const result: IHordeResult = await retrieveResponse.json();
-
-      if (result.done && result.generations?.length === 1) {
+
+    const request = async (method = 'GET'): Promise<string | null> => {
+      const response = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method });
+      if (response.ok && response.status < 400) {
+        const result: IHordeResult = await response.json();
+        if (result.generations?.length === 1) {
           const { text } = result.generations[0];
 
           return text;
         }
+      } else {
+        throw new Error(await response.text());
+      }
+
+      return null;
+    };
+
+    const deleteRequest = async () => (await request('DELETE')) ?? '';
+
+    while (true) {
+      try {
+        await delay(2500, { signal });
+
+        const text = await request();
+
+        if (text) {
+          return text;
+        }
+      } catch (e) {
+        console.error('Error in horde generation:', e);
+        return deleteRequest();
+      }
+    }
   }
 

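The polling loop leans on the delay package's AbortSignal support: aborting the shared controller rejects the pending delay(2500, { signal }) with an AbortError, dropping control into the catch branch, which then deletes the queued Horde request. A standalone sketch of that mechanism (timings illustrative):

import delay from "delay";

const controller = new AbortController();
setTimeout(() => controller.abort(), 500);

(async () => {
  try {
    await delay(2500, { signal: controller.signal });
  } catch (e) {
    // delay rejects as soon as the signal fires (~500ms here, not 2500ms)
    console.log((e as Error).name); // "AbortError"
  }
})();
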
@@ -236,6 +255,11 @@ export namespace Connection {
     }
   }
 
+  export function stopGeneration() {
+    abortController.abort();
+    abortController = new AbortController(); // refresh
+  }
+
   async function requestHordeModels(): Promise<Map<string, IHordeModel>> {
     try {
       const response = await fetch(`${AIHORDE}/api/v2/workers?type=text`);

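Note the refresh: an AbortController is single-use, and its signal stays aborted forever once fired, so a fresh controller must replace it or every later generation would abort instantly. A minimal illustration (URLs hypothetical):

let controller = new AbortController();

fetch('/a', { signal: controller.signal }).catch(() => {}); // in flight
controller.abort();                                         // '/a' is cancelled

fetch('/b', { signal: controller.signal }).catch(() => {}); // rejects immediately: signal already aborted
controller = new AbortController();                         // refresh, as stopGeneration() does
fetch('/c', { signal: controller.signal }).catch(() => {}); // works again
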
@@ -299,7 +323,7 @@ export namespace Connection {
         return result;
       }
     } catch (e) {
-      console.log('Error getting max tokens', e);
+      console.error('Error getting max tokens', e);
     }
   } else if (isHordeConnection(connection)) {
     return connection.model;

@@ -317,7 +341,7 @@ export namespace Connection {
         return value;
       }
     } catch (e) {
-      console.log('Error getting max tokens', e);
+      console.error('Error getting max tokens', e);
     }
   } else if (isHordeConnection(connection)) {
     const models = await getHordeModels();

@@ -343,7 +367,7 @@ export namespace Connection {
        return value;
      }
    } catch (e) {
-      console.log('Error counting tokens', e);
+      console.error('Error counting tokens', e);
    }
  }
 

@@ -1,13 +1,13 @@
 import Lock from "@common/lock";
 import SSE from "@common/sse";
 import { createContext } from "preact";
 import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
 import { MessageTools, type IMessage } from "../messages";
-import { Instruct, StateContext } from "./state";
+import { StateContext } from "./state";
 import { useBool } from "@common/hooks/useBool";
 import { Template } from "@huggingface/jinja";
 import { Huggingface } from "../huggingface";
 import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection";
 import { throttle } from "@common/utils";
+import { useAsyncEffect } from "@common/hooks/useAsyncEffect";
 
 interface ICompileArgs {
   keepUsers?: number;

@@ -22,9 +22,7 @@ interface ICompiledPrompt {
 
 interface IContext {
   generating: boolean;
-  blockConnection: ReturnType<typeof useBool>;
   modelName: string;
   modelTemplate: string;
   hasToolCalls: boolean;
   promptTokens: number;
   contextLength: number;

@@ -35,6 +33,7 @@ const MESSAGES_TO_KEEP = 10;
 interface IActions {
   compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
   generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
+  stopGeneration: () => void;
   summarize: (content: string) => Promise<string>;
   countTokens: (prompt: string) => Promise<number>;
 }

@@ -50,15 +49,13 @@ const processing = {
 export const LLMContextProvider = ({ children }: { children?: any }) => {
   const {
     connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
-    setTriggerNext, addMessage, editMessage, editSummary, setInstruct,
+    setTriggerNext, addMessage, editMessage, editSummary,
   } = useContext(StateContext);
 
   const generating = useBool(false);
-  const blockConnection = useBool(false);
   const [promptTokens, setPromptTokens] = useState(0);
   const [contextLength, setContextLength] = useState(0);
   const [modelName, setModelName] = useState('');
   const [modelTemplate, setModelTemplate] = useState('');
   const [hasToolCalls, setHasToolCalls] = useState(false);
 
   const userPromptTemplate = useMemo(() => {

@@ -71,20 +68,6 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
     }
   }, [userPrompt]);
 
-  const getContextLength = useCallback(async () => {
-    if (!connection || blockConnection.value) {
-      return 0;
-    }
-    return Connection.getContextLength(connection);
-  }, [connection, blockConnection.value]);
-
-  const getModelName = useCallback(async () => {
-    if (!connection || blockConnection.value) {
-      return '';
-    }
-    return Connection.getModelName(connection);
-  }, [connection, blockConnection.value]);
-
   const actions: IActions = useMemo(() => ({
     compilePrompt: async (messages, { keepUsers } = {}) => {
       const promptMessages = messages.slice();

@@ -179,31 +162,43 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
     },
     generate: async function* (prompt, extraSettings = {}) {
       try {
-        generating.setTrue();
         console.log('[LLM.generate]', prompt);
 
         yield* Connection.generate(connection, prompt, {
           ...extraSettings,
           banned_tokens: bannedWords.filter(w => w.trim()),
         });
-      } finally {
-        generating.setFalse();
+      } catch (e) {
+        if (e instanceof Error && e.name !== 'AbortError') {
+          alert(e.message);
+        } else {
+          console.error('[LLM.generate]', e);
+        }
       }
     },
     summarize: async (message) => {
       try {
         const content = Huggingface.applyTemplate(summarizePrompt, { message });
         const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
         console.log('[LLM.summarize]', prompt);
 
-        const tokens = await Array.fromAsync(actions.generate(prompt));
+        const tokens = await Array.fromAsync(Connection.generate(connection, prompt, {}));
 
         return MessageTools.trimSentence(tokens.join(''));
       } catch (e) {
         console.error('Error summarizing:', e);
         return '';
       }
     },
     countTokens: async (prompt) => {
       return await Connection.countTokens(connection, prompt);
     },
+    stopGeneration: () => {
+      Connection.stopGeneration();
+    },
   }), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);
 
-  useEffect(() => void (async () => {
+  useAsyncEffect(async () => {
     if (triggerNext && !generating.value) {
       setTriggerNext(false);
 

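The catch above distinguishes user-initiated stops from real failures: an aborted fetch rejects with a DOMException named 'AbortError', which shouldn't produce an alert. Schematically (endpoint hypothetical, not the repo's actual call site):

async function generateOnce(signal: AbortSignal) {
  try {
    await fetch('/generate', { signal }); // hypothetical endpoint
  } catch (e) {
    // a DOMException is an Error, so the name check covers both cases
    if (e instanceof Error && e.name === 'AbortError') {
      console.log('stopped by the user, nothing to report');
    } else {
      alert((e as Error).message); // a genuine failure: surface it
    }
  }
}
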
@@ -217,12 +212,14 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
       messageId++;
     }
 
+    generating.setTrue();
     editSummary(messageId, 'Generating...');
     for await (const chunk of actions.generate(prompt)) {
       text += chunk;
       setPromptTokens(promptTokens + approximateTokens(text));
       editMessage(messageId, text.trim());
     }
+    generating.setFalse();
 
     text = MessageTools.trimSentence(text);
     editMessage(messageId, text);

@@ -230,10 +227,10 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
 
       MessageTools.playReady();
     }
-  })(), [triggerNext]);
+  }, [triggerNext]);
 
-  useEffect(() => void (async () => {
-    if (summaryEnabled && !generating.value && !processing.summarizing) {
+  useAsyncEffect(async () => {
+    if (summaryEnabled && !processing.summarizing) {
       try {
         processing.summarizing = true;
         for (let id = 0; id < messages.length; id++) {

@@ -250,36 +247,15 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
         processing.summarizing = false;
       }
     }
-  })(), [messages]);
+  }, [messages, summaryEnabled]);
 
-  useEffect(() => {
-    if (!blockConnection.value) {
-      setPromptTokens(0);
-      setContextLength(0);
-      setModelName('');
-
-      getContextLength().then(setContextLength);
-      getModelName().then(normalizeModel).then(setModelName);
-    }
-  }, [connection, blockConnection.value]);
+  useEffect(throttle(() => {
+    Connection.getContextLength(connection).then(setContextLength);
+    Connection.getModelName(connection).then(normalizeModel).then(setModelName);
+  }, 1000, true), [connection]);
 
-  useEffect(() => {
-    setModelTemplate('');
-    if (modelName) {
-      Huggingface.findModelTemplate(modelName)
-        .then((template) => {
-          if (template) {
-            setModelTemplate(template);
-            setInstruct(template);
-          } else {
-            setInstruct(Instruct.CHATML);
-          }
-        });
-    }
-  }, [modelName]);
-
-  const calculateTokens = useCallback(async () => {
-    if (!processing.tokenizing && !blockConnection.value && !generating.value) {
+  const calculateTokens = useCallback(throttle(async () => {
+    if (!processing.tokenizing && !generating.value) {
       try {
         processing.tokenizing = true;
         const { prompt } = await actions.compilePrompt(messages);

@@ -291,11 +267,11 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
         processing.tokenizing = false;
       }
     }
-  }, [actions, messages, blockConnection.value]);
+  }, 1000, true), [actions, messages]);
 
   useEffect(() => {
     calculateTokens();
-  }, [messages, connection, blockConnection.value, /* systemPrompt, lore, userPrompt TODO debounce*/]);
+  }, [messages, connection, systemPrompt, lore, userPrompt]);
 
   useEffect(() => {
     try {

@@ -308,9 +284,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
 
   const rawContext: IContext = {
     generating: generating.value,
-    blockConnection,
     modelName,
     modelTemplate,
     hasToolCalls,
     promptTokens,
     contextLength,

@@ -83,7 +83,7 @@ What should happen next in your answer: {{ prompt | trim }}
 {% endif %}
 Remember that this story should be infinite and go forever.
 Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`,
-  summarizePrompt: 'Shrink following text down, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
+  summarizePrompt: 'Shrink following text down to one paragraph, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
   summaryEnabled: false,
   bannedWords: [],
   messages: [],

@@ -1,6 +1,7 @@
 import { gguf } from '@huggingface/gguf';
 import * as hub from '@huggingface/hub';
 import { Template } from '@huggingface/jinja';
+import { normalizeModel } from './connection';
 
 export namespace Huggingface {
   export interface ITemplateMessage {

@@ -92,11 +93,12 @@
 
   const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise<TokenizerConfig | null> => {
     console.log(`[huggingface] searching config for '${modelName}'`);
+    const searchModel = normalizeModel(modelName);
 
-    const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName }, additionalFields: ['config'] }));
+    const hubModels = await Array.fromAsync(hub.listModels({ search: { query: searchModel }, additionalFields: ['config'] }));
     const models = hubModels.filter(m => {
       if (m.gated) return false;
-      if (!m.name.toLowerCase().includes(modelName.toLowerCase())) return false;
+      if (!normalizeModel(m.name).includes(searchModel)) return false;
 
       return true;
     }).sort((a, b) => b.downloads - a.downloads);

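The filter now normalizes both sides before comparing, so differently quantized or suffixed repo names can still match the model name reported by the backend. A hypothetical illustration (exact outputs depend on normalizeModel's full rule set, which is only partially visible in the hunks above):

// both names reduced to a common base before comparison (expected behavior)
const searchModel = normalizeModel('Mistral-7B-Instruct-v0.3.Q4_K_M.gguf'); // e.g. 'mistral-7b-instruct-v0.3'
normalizeModel('Mistral-7B-Instruct-v0.3-GPTQ').includes(searchModel);      // expected: true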