
AIStory: stopping

Pabloader 2024-11-12 16:32:52 +00:00
parent 017ef7aaa5
commit 277b315795
10 changed files with 122 additions and 104 deletions

BIN bun.lockb

Binary file not shown.

View File

@ -14,6 +14,7 @@
"@inquirer/select": "2.3.10",
"ace-builds": "1.36.3",
"classnames": "2.5.1",
"delay": "6.0.0",
"preact": "10.22.0"
},
"devDependencies": {

View File

@ -0,0 +1,4 @@
import { useEffect } from "preact/hooks";
export const useAsyncEffect = (fx: () => any, deps: any[]) =>
useEffect(() => void fx(), deps);
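
A usage sketch for the new hook; the component and endpoint below are hypothetical. The hook simply wraps useEffect and discards the promise with void, so an async callback can be passed without Preact treating its return value as a cleanup function.

import { useState } from "preact/hooks";
import { useAsyncEffect } from "@common/hooks/useAsyncEffect";

// Hypothetical component and endpoint: reload the body whenever `id` changes.
const ItemView = ({ id }: { id: string }) => {
  const [body, setBody] = useState('');
  useAsyncEffect(async () => {
    const response = await fetch(`/api/items/${id}`); // assumed endpoint
    setBody(await response.text());
  }, [id]);
  return <pre>{body}</pre>;
};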

View File

@ -1,4 +1,3 @@
export const delay = async (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
export const nextFrame = async (): Promise<number> => new Promise((resolve) => requestAnimationFrame(resolve));
export const randInt = (min: number, max: number) => Math.round(min + (max - min - 1) * Math.random());
@ -50,20 +49,30 @@ export const intHash = (seed: number, ...parts: number[]) => {
return h1;
};
export const sinHash = (...data: number[]) => data.reduce((hash, n) => Math.sin((hash * 123.12 + n) * 756.12), 0) / 2 + 0.5;
export const throttle = function <T, A extends unknown[], R, F extends (this: T, ...args: A) => R>(func: F, ms: number): F {
export const throttle = function <T, A extends unknown[], R, F extends (this: T, ...args: A) => R>(func: F, ms: number, trailing = false): F {
let isThrottled = false;
let savedResult: R;
let savedThis: T;
let savedArgs: A | undefined;
const wrapper: F = function (...args: A) {
if (!isThrottled) {
if (isThrottled) {
savedThis = this;
savedArgs = args;
} else {
savedResult = func.apply(this, args);
savedArgs = undefined;
isThrottled = true;
setTimeout(function () {
isThrottled = false;
if (trailing && savedArgs) {
savedResult = wrapper.apply(savedThis, savedArgs);
}
}, ms);
}
return savedResult;
} as F;
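
A hedged usage sketch of the new trailing flag (the scroll handler is illustrative, not from this repo): calls that arrive while throttled save their arguments, and the last one is replayed when the window expires, so the final update is not dropped.

import { throttle } from "@common/utils";

// Illustrative handler: log at most once per 500 ms, but replay the last call
// made during the window (trailing = true) so the final position still lands.
const logScroll = throttle((y: number) => console.log('scrollY', y), 500, true);
window.addEventListener('scroll', () => logScroll(window.scrollY));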

View File

@ -5,7 +5,7 @@ import { AutoTextarea } from "./autoTextarea";
export const Input = () => {
const { input, setInput, addMessage, continueMessage } = useContext(StateContext);
const { generating } = useContext(LLMContext);
const { generating, stopGeneration } = useContext(LLMContext);
const handleSend = useCallback(async () => {
if (!generating) {
@ -29,7 +29,10 @@ export const Input = () => {
return (
<div class="chat-input">
<AutoTextarea onInput={setInput} onKeyDown={handleKeyDown} value={input} />
<button onClick={handleSend} class={`${generating ? 'disabled' : ''}`}>{input ? 'Send' : 'Continue'}</button>
{generating
? <button onClick={stopGeneration}>Stop</button>
: <button onClick={handleSend}>{input ? 'Send' : 'Continue'}</button>
}
</div>
);
}

View File

@ -16,7 +16,7 @@ interface IProps {
}
export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) => {
const { generating, generate, compilePrompt } = useContext(LLMContext);
const { generating, stopGeneration, generate, compilePrompt } = useContext(LLMContext);
const [messages, setMessages] = useState<IMessage[]>([]);
const ref = useRef<HTMLDivElement>(null);
@ -105,9 +105,10 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
</div>
</div>
<div class={styles.buttons}>
<button onClick={handleGenerate} class={`${generating ? 'disabled' : ''}`}>
Generate
</button>
{generating
? <button onClick={stopGeneration}>Stop</button>
: <button onClick={handleGenerate}>Generate</button>
}
<button onClick={() => handleInit()} class={`${generating ? 'disabled' : ''}`}>
Clear
</button>

View File

@ -1,6 +1,7 @@
import Lock from "@common/lock";
import SSE from "@common/sse";
import { delay, throttle } from "@common/utils";
import { throttle } from "@common/utils";
import delay, { clearDelay } from "delay";
interface IBaseConnection {
instruct: string;
@ -72,7 +73,7 @@ const DEFAULT_GENERATION_SETTINGS = {
dry_penalty_last_n: 0
}
const MIN_PERFORMANCE = 5.0;
const MIN_PERFORMANCE = 2.0;
const MIN_WORKER_CONTEXT = 8192;
const MAX_HORDE_LENGTH = 512;
const MAX_HORDE_CONTEXT = 32000;
@ -88,7 +89,7 @@ export const normalizeModel = (model: string) => {
currentModel = currentModel
.replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
.replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name
.replace(/[ ._-](gptq|awq|exl2?|imat|i\d|h\d)/i, '') // remove quant name
.replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
.replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
.replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
@ -104,14 +105,15 @@ export const normalizeModel = (model: string) => {
.trim();
}
export const approximateTokens = (prompt: string): number =>
Math.round(prompt.split(/\s+/).length * 0.75);
export const approximateTokens = (prompt: string): number => prompt.split(/[^a-z0-9]+/i).length;
export type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
export namespace Connection {
const AIHORDE = 'https://aihorde.net';
let abortController = new AbortController();
async function* generateKobold(url: string, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator<string> {
const sse = new SSE(`${url}/api/extra/generate/stream`, {
payload: JSON.stringify({
@ -144,12 +146,14 @@ export namespace Connection {
messageLock.release();
};
abortController.signal.addEventListener('abort', handleEnd);
sse.addEventListener('error', handleEnd);
sse.addEventListener('abort', handleEnd);
sse.addEventListener('readystatechange', (e) => {
if (e.readyState === SSE.CLOSED) handleEnd();
});
while (!end || messages.length) {
while (messages.length > 0) {
const message = messages.shift();
@ -189,6 +193,8 @@ export namespace Connection {
workers: model.workers,
};
const { signal } = abortController;
const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, {
method: 'POST',
body: JSON.stringify(requestData),
@ -196,31 +202,44 @@ export namespace Connection {
'Content-Type': 'application/json',
apikey: connection.apiKey || HORDE_ANON_KEY,
},
signal,
});
if (!generateResponse.ok || Math.floor(generateResponse.status / 100) !== 2) {
if (!generateResponse.ok || generateResponse.status >= 400) {
throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`);
}
const { id } = await generateResponse.json() as { id: string };
const deleteRequest = () => fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method: 'DELETE' })
.catch(e => console.error('Error deleting request', e));
const request = async (method = 'GET'): Promise<string | null> => {
const response = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method });
if (response.ok && response.status < 400) {
const result: IHordeResult = await response.json();
if (result.generations?.length === 1) {
const { text } = result.generations[0];
while (true) {
await delay(2500);
const retrieveResponse = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`);
if (!retrieveResponse.ok || Math.floor(retrieveResponse.status / 100) !== 2) {
deleteRequest();
throw new Error(`Error retrieving generation: ${retrieveResponse.statusText}: ${await retrieveResponse.text()}`);
return text;
}
} else {
throw new Error(await response.text());
}
const result: IHordeResult = await retrieveResponse.json();
return null;
};
if (result.done && result.generations?.length === 1) {
const { text } = result.generations[0];
const deleteRequest = async () => (await request('DELETE')) ?? '';
return text;
while (true) {
try {
await delay(2500, { signal });
const text = await request();
if (text) {
return text;
}
} catch (e) {
console.error('Error in horde generation:', e);
return deleteRequest();
}
}
}
@ -236,15 +255,20 @@ export namespace Connection {
}
}
export function stopGeneration() {
abortController.abort();
abortController = new AbortController(); // refresh
}
async function requestHordeModels(): Promise<Map<string, IHordeModel>> {
try {
const response = await fetch(`${AIHORDE}/api/v2/workers?type=text`);
if (response.ok) {
const workers: IHordeWorker[] = await response.json();
const goodWorkers = workers.filter(w =>
w.online
&& !w.maintenance_mode
&& !w.flagged
const goodWorkers = workers.filter(w =>
w.online
&& !w.maintenance_mode
&& !w.flagged
&& w.max_context_length >= MIN_WORKER_CONTEXT
&& parseFloat(w.performance) >= MIN_PERFORMANCE
);
@ -299,7 +323,7 @@ export namespace Connection {
return result;
}
} catch (e) {
console.log('Error getting max tokens', e);
console.error('Error getting max tokens', e);
}
} else if (isHordeConnection(connection)) {
return connection.model;
@ -317,7 +341,7 @@ export namespace Connection {
return value;
}
} catch (e) {
console.log('Error getting max tokens', e);
console.error('Error getting max tokens', e);
}
} else if (isHordeConnection(connection)) {
const models = await getHordeModels();
@ -343,7 +367,7 @@ export namespace Connection {
return value;
}
} catch (e) {
console.log('Error counting tokens', e);
console.error('Error counting tokens', e);
}
}
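
The stop mechanism rests on a single module-level AbortController whose signal is passed to fetch and to the polling delay; stopGeneration() aborts it and immediately installs a fresh controller so the next generation is not pre-aborted. A stripped-down sketch of the same pattern, with simplified names rather than the actual module:

// Simplified sketch of the cancellation pattern, not the real module:
// one shared controller, its signal threaded into every in-flight request,
// and a fresh controller installed after each abort.
let abortController = new AbortController();

export async function generate(url: string, body: unknown): Promise<string> {
  const { signal } = abortController;
  const response = await fetch(url, {
    method: 'POST',
    body: JSON.stringify(body),
    signal, // aborting rejects this fetch with an AbortError
  });
  return response.text();
}

export function stopGeneration(): void {
  abortController.abort();                 // cancel everything holding the old signal
  abortController = new AbortController(); // so later calls are not pre-aborted
}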

View File

@ -1,13 +1,13 @@
import Lock from "@common/lock";
import SSE from "@common/sse";
import { createContext } from "preact";
import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
import { MessageTools, type IMessage } from "../messages";
import { Instruct, StateContext } from "./state";
import { StateContext } from "./state";
import { useBool } from "@common/hooks/useBool";
import { Template } from "@huggingface/jinja";
import { Huggingface } from "../huggingface";
import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection";
import { throttle } from "@common/utils";
import { useAsyncEffect } from "@common/hooks/useAsyncEffect";
interface ICompileArgs {
keepUsers?: number;
@ -22,9 +22,7 @@ interface ICompiledPrompt {
interface IContext {
generating: boolean;
blockConnection: ReturnType<typeof useBool>;
modelName: string;
modelTemplate: string;
hasToolCalls: boolean;
promptTokens: number;
contextLength: number;
@ -35,6 +33,7 @@ const MESSAGES_TO_KEEP = 10;
interface IActions {
compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
stopGeneration: () => void;
summarize: (content: string) => Promise<string>;
countTokens: (prompt: string) => Promise<number>;
}
@ -50,15 +49,13 @@ const processing = {
export const LLMContextProvider = ({ children }: { children?: any }) => {
const {
connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
setTriggerNext, addMessage, editMessage, editSummary, setInstruct,
setTriggerNext, addMessage, editMessage, editSummary,
} = useContext(StateContext);
const generating = useBool(false);
const blockConnection = useBool(false);
const [promptTokens, setPromptTokens] = useState(0);
const [contextLength, setContextLength] = useState(0);
const [modelName, setModelName] = useState('');
const [modelTemplate, setModelTemplate] = useState('');
const [hasToolCalls, setHasToolCalls] = useState(false);
const userPromptTemplate = useMemo(() => {
@ -71,20 +68,6 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
}
}, [userPrompt]);
const getContextLength = useCallback(async () => {
if (!connection || blockConnection.value) {
return 0;
}
return Connection.getContextLength(connection);
}, [connection, blockConnection.value]);
const getModelName = useCallback(async () => {
if (!connection || blockConnection.value) {
return '';
}
return Connection.getModelName(connection);
}, [connection, blockConnection.value]);
const actions: IActions = useMemo(() => ({
compilePrompt: async (messages, { keepUsers } = {}) => {
const promptMessages = messages.slice();
@ -179,31 +162,43 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
},
generate: async function* (prompt, extraSettings = {}) {
try {
generating.setTrue();
console.log('[LLM.generate]', prompt);
yield* Connection.generate(connection, prompt, {
...extraSettings,
...extraSettings,
banned_tokens: bannedWords.filter(w => w.trim()),
});
} finally {
generating.setFalse();
} catch (e) {
if (e instanceof Error && e.name !== 'AbortError') {
alert(e.message);
} else {
console.error('[LLM.generate]', e);
}
}
},
summarize: async (message) => {
const content = Huggingface.applyTemplate(summarizePrompt, { message });
const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
try {
const content = Huggingface.applyTemplate(summarizePrompt, { message });
const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
console.log('[LLM.summarize]', prompt);
const tokens = await Array.fromAsync(actions.generate(prompt));
const tokens = await Array.fromAsync(Connection.generate(connection, prompt, {}));
return MessageTools.trimSentence(tokens.join(''));
return MessageTools.trimSentence(tokens.join(''));
} catch (e) {
console.error('Error summarizing:', e);
return '';
}
},
countTokens: async (prompt) => {
return await Connection.countTokens(connection, prompt);
},
stopGeneration: () => {
Connection.stopGeneration();
},
}), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);
useEffect(() => void (async () => {
useAsyncEffect(async () => {
if (triggerNext && !generating.value) {
setTriggerNext(false);
@ -217,12 +212,14 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
messageId++;
}
generating.setTrue();
editSummary(messageId, 'Generating...');
for await (const chunk of actions.generate(prompt)) {
text += chunk;
setPromptTokens(promptTokens + approximateTokens(text));
editMessage(messageId, text.trim());
}
generating.setFalse();
text = MessageTools.trimSentence(text);
editMessage(messageId, text);
@ -230,10 +227,10 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
MessageTools.playReady();
}
})(), [triggerNext]);
}, [triggerNext]);
useEffect(() => void (async () => {
if (summaryEnabled && !generating.value && !processing.summarizing) {
useAsyncEffect(async () => {
if (summaryEnabled && !processing.summarizing) {
try {
processing.summarizing = true;
for (let id = 0; id < messages.length; id++) {
@ -250,36 +247,15 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
processing.summarizing = false;
}
}
})(), [messages]);
}, [messages, summaryEnabled]);
useEffect(() => {
if (!blockConnection.value) {
setPromptTokens(0);
setContextLength(0);
setModelName('');
useEffect(throttle(() => {
Connection.getContextLength(connection).then(setContextLength);
Connection.getModelName(connection).then(normalizeModel).then(setModelName);
}, 1000, true), [connection]);
getContextLength().then(setContextLength);
getModelName().then(normalizeModel).then(setModelName);
}
}, [connection, blockConnection.value]);
useEffect(() => {
setModelTemplate('');
if (modelName) {
Huggingface.findModelTemplate(modelName)
.then((template) => {
if (template) {
setModelTemplate(template);
setInstruct(template);
} else {
setInstruct(Instruct.CHATML);
}
});
}
}, [modelName]);
const calculateTokens = useCallback(async () => {
if (!processing.tokenizing && !blockConnection.value && !generating.value) {
const calculateTokens = useCallback(throttle(async () => {
if (!processing.tokenizing && !generating.value) {
try {
processing.tokenizing = true;
const { prompt } = await actions.compilePrompt(messages);
@ -291,11 +267,11 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
processing.tokenizing = false;
}
}
}, [actions, messages, blockConnection.value]);
}, 1000, true), [actions, messages]);
useEffect(() => {
calculateTokens();
}, [messages, connection, blockConnection.value, /* systemPrompt, lore, userPrompt TODO debounce*/]);
}, [messages, connection, systemPrompt, lore, userPrompt]);
useEffect(() => {
try {
@ -308,9 +284,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
const rawContext: IContext = {
generating: generating.value,
blockConnection,
modelName,
modelTemplate,
hasToolCalls,
promptTokens,
contextLength,
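
For completeness, a hypothetical consumer of the new context action (the import path is assumed); Input and MiniChat above follow the same shape:

import { useContext } from "preact/hooks";
import { LLMContext } from "./llm"; // assumed path to this provider module

// Hypothetical consumer: any component under LLMContextProvider can cancel an
// in-flight generation through the new action, as Input and MiniChat do above.
const CancelButton = () => {
  const { generating, stopGeneration } = useContext(LLMContext);
  return generating ? <button onClick={stopGeneration}>Stop</button> : null;
};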

View File

@ -83,7 +83,7 @@ What should happen next in your answer: {{ prompt | trim }}
{% endif %}
Remember that this story should be infinite and go forever.
Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`,
summarizePrompt: 'Shrink following text down, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
summarizePrompt: 'Shrink following text down to one paragraph, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
summaryEnabled: false,
bannedWords: [],
messages: [],

View File

@ -1,6 +1,7 @@
import { gguf } from '@huggingface/gguf';
import * as hub from '@huggingface/hub';
import { Template } from '@huggingface/jinja';
import { normalizeModel } from './connection';
export namespace Huggingface {
export interface ITemplateMessage {
@ -92,11 +93,12 @@ export namespace Huggingface {
const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise<TokenizerConfig | null> => {
console.log(`[huggingface] searching config for '${modelName}'`);
const searchModel = normalizeModel(modelName);
const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName }, additionalFields: ['config'] }));
const hubModels = await Array.fromAsync(hub.listModels({ search: { query: searchModel }, additionalFields: ['config'] }));
const models = hubModels.filter(m => {
if (m.gated) return false;
if (!m.name.toLowerCase().includes(modelName.toLowerCase())) return false;
if (!normalizeModel(m.name).includes(searchModel)) return false;
return true;
}).sort((a, b) => b.downloads - a.downloads);