
AIStory: stopping

Pabloader 2024-11-12 16:32:52 +00:00
parent 017ef7aaa5
commit 277b315795
10 changed files with 122 additions and 104 deletions

BIN bun.lockb

Binary file not shown.

View File

@@ -14,6 +14,7 @@
     "@inquirer/select": "2.3.10",
     "ace-builds": "1.36.3",
     "classnames": "2.5.1",
+    "delay": "6.0.0",
     "preact": "10.22.0"
   },
   "devDependencies": {

View File

@@ -0,0 +1,4 @@
+import { useEffect } from "preact/hooks";
+
+export const useAsyncEffect = (fx: () => any, deps: any[]) =>
+    useEffect(() => void fx(), deps);
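
The hook wraps the async callback in void so useEffect never mistakes the returned promise for a cleanup function. A minimal usage sketch (loadStory is a hypothetical loader, not part of this commit):

import { useAsyncEffect } from "@common/hooks/useAsyncEffect";

useAsyncEffect(async () => {
    const story = await loadStory(storyId); // hypothetical async loader
    console.log(story);
}, [storyId]); // re-runs when storyId changes; the promise is intentionally discarded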

View File

@@ -1,4 +1,3 @@
-export const delay = async (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
 export const nextFrame = async (): Promise<number> => new Promise((resolve) => requestAnimationFrame(resolve));
 export const randInt = (min: number, max: number) => Math.round(min + (max - min - 1) * Math.random());
 
@@ -50,20 +49,30 @@ export const intHash = (seed: number, ...parts: number[]) => {
     return h1;
 };
 
 export const sinHash = (...data: number[]) => data.reduce((hash, n) => Math.sin((hash * 123.12 + n) * 756.12), 0) / 2 + 0.5;
 
-export const throttle = function <T, A extends unknown[], R, F extends (this: T, ...args: A) => R>(func: F, ms: number): F {
+export const throttle = function <T, A extends unknown[], R, F extends (this: T, ...args: A) => R>(func: F, ms: number, trailing = false): F {
     let isThrottled = false;
     let savedResult: R;
+    let savedThis: T;
+    let savedArgs: A | undefined;
 
     const wrapper: F = function (...args: A) {
-        if (!isThrottled) {
+        if (isThrottled) {
+            savedThis = this;
+            savedArgs = args;
+        } else {
             savedResult = func.apply(this, args);
+            savedArgs = undefined;
             isThrottled = true;
 
             setTimeout(function () {
                 isThrottled = false;
+                if (trailing && savedArgs) {
+                    savedResult = wrapper.apply(savedThis, savedArgs);
+                }
             }, ms);
         }
         return savedResult;
     } as F;
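
The new trailing flag makes the wrapper remember the last call that arrived during the cooldown and replay it once the window closes, so the final update is never dropped. A quick sketch of the behavior, assuming the throttle above:

import { throttle } from "@common/utils";

const report = throttle((n: number) => console.log('value:', n), 1000, true);

report(1); // runs immediately
report(2); // throttled: saved as the pending trailing call
report(3); // throttled: replaces 2 as the pending call
// after ~1s the wrapper replays report(3), so the latest value still lands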

View File

@@ -5,7 +5,7 @@ import { AutoTextarea } from "./autoTextarea";
 
 export const Input = () => {
     const { input, setInput, addMessage, continueMessage } = useContext(StateContext);
-    const { generating } = useContext(LLMContext);
+    const { generating, stopGeneration } = useContext(LLMContext);
 
     const handleSend = useCallback(async () => {
         if (!generating) {
@@ -29,7 +29,10 @@ export const Input = () => {
     return (
         <div class="chat-input">
             <AutoTextarea onInput={setInput} onKeyDown={handleKeyDown} value={input} />
-            <button onClick={handleSend} class={`${generating ? 'disabled' : ''}`}>{input ? 'Send' : 'Continue'}</button>
+            {generating
+                ? <button onClick={stopGeneration}>Stop</button>
+                : <button onClick={handleSend}>{input ? 'Send' : 'Continue'}</button>
+            }
         </div>
     );
 }

View File

@@ -16,7 +16,7 @@ interface IProps {
 }
 
 export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) => {
-    const { generating, generate, compilePrompt } = useContext(LLMContext);
+    const { generating, stopGeneration, generate, compilePrompt } = useContext(LLMContext);
     const [messages, setMessages] = useState<IMessage[]>([]);
     const ref = useRef<HTMLDivElement>(null);
 
@@ -105,9 +105,10 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
                 </div>
             </div>
             <div class={styles.buttons}>
-                <button onClick={handleGenerate} class={`${generating ? 'disabled' : ''}`}>
-                    Generate
-                </button>
+                {generating
+                    ? <button onClick={stopGeneration}>Stop</button>
+                    : <button onClick={handleGenerate}>Generate</button>
+                }
                 <button onClick={() => handleInit()} class={`${generating ? 'disabled' : ''}`}>
                     Clear
                 </button>

View File

@@ -1,6 +1,7 @@
 import Lock from "@common/lock";
 import SSE from "@common/sse";
-import { delay, throttle } from "@common/utils";
+import { throttle } from "@common/utils";
+import delay, { clearDelay } from "delay";
 
 interface IBaseConnection {
     instruct: string;
@@ -72,7 +73,7 @@ const DEFAULT_GENERATION_SETTINGS = {
     dry_penalty_last_n: 0
 }
 
-const MIN_PERFORMANCE = 5.0;
+const MIN_PERFORMANCE = 2.0;
 const MIN_WORKER_CONTEXT = 8192;
 const MAX_HORDE_LENGTH = 512;
 const MAX_HORDE_CONTEXT = 32000;
@@ -88,7 +89,7 @@ export const normalizeModel = (model: string) => {
     currentModel = currentModel
         .replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
-        .replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name
+        .replace(/[ ._-](gptq|awq|exl2?|imat|i\d|h\d)/i, '') // remove quant name
         .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
         .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
         .replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
@@ -104,14 +105,15 @@ export const normalizeModel = (model: string) => {
         .trim();
 }
 
-export const approximateTokens = (prompt: string): number =>
-    Math.round(prompt.split(/\s+/).length * 0.75);
+export const approximateTokens = (prompt: string): number => prompt.split(/[^a-z0-9]+/i).length;
 
 export type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
 
 export namespace Connection {
     const AIHORDE = 'https://aihorde.net';
+    let abortController = new AbortController();
 
     async function* generateKobold(url: string, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator<string> {
         const sse = new SSE(`${url}/api/extra/generate/stream`, {
             payload: JSON.stringify({
@@ -144,12 +146,14 @@ export namespace Connection {
             messageLock.release();
         };
 
+        abortController.signal.addEventListener('abort', handleEnd);
         sse.addEventListener('error', handleEnd);
         sse.addEventListener('abort', handleEnd);
         sse.addEventListener('readystatechange', (e) => {
             if (e.readyState === SSE.CLOSED) handleEnd();
         });
 
         while (!end || messages.length) {
             while (messages.length > 0) {
                 const message = messages.shift();
@@ -189,6 +193,8 @@ export namespace Connection {
             workers: model.workers,
         };
 
+        const { signal } = abortController;
+
         const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, {
             method: 'POST',
             body: JSON.stringify(requestData),
@@ -196,31 +202,44 @@ export namespace Connection {
                 'Content-Type': 'application/json',
                 apikey: connection.apiKey || HORDE_ANON_KEY,
             },
+            signal,
         });
-        if (!generateResponse.ok || Math.floor(generateResponse.status / 100) !== 2) {
+        if (!generateResponse.ok || generateResponse.status >= 400) {
             throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`);
         }
         const { id } = await generateResponse.json() as { id: string };
 
-        const deleteRequest = () => fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method: 'DELETE' })
-            .catch(e => console.error('Error deleting request', e));
-
-        while (true) {
-            await delay(2500);
-
-            const retrieveResponse = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`);
-            if (!retrieveResponse.ok || Math.floor(retrieveResponse.status / 100) !== 2) {
-                deleteRequest();
-                throw new Error(`Error retrieving generation: ${retrieveResponse.statusText}: ${await retrieveResponse.text()}`);
-            }
-            const result: IHordeResult = await retrieveResponse.json();
-
-            if (result.done && result.generations?.length === 1) {
-                const { text } = result.generations[0];
-                return text;
+        const request = async (method = 'GET'): Promise<string | null> => {
+            const response = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method });
+            if (response.ok && response.status < 400) {
+                const result: IHordeResult = await response.json();
+                if (result.generations?.length === 1) {
+                    const { text } = result.generations[0];
+                    return text;
+                }
+            } else {
+                throw new Error(await response.text());
+            }
+            return null;
+        };
+        const deleteRequest = async () => (await request('DELETE')) ?? '';
+
+        while (true) {
+            try {
+                await delay(2500, { signal });
+                const text = await request();
+                if (text) {
+                    return text;
+                }
+            } catch (e) {
+                console.error('Error in horde generation:', e);
+                return deleteRequest();
             }
         }
     }
@@ -236,15 +255,20 @@ export namespace Connection {
         }
     }
 
+    export function stopGeneration() {
+        abortController.abort();
+        abortController = new AbortController(); // refresh
+    }
+
     async function requestHordeModels(): Promise<Map<string, IHordeModel>> {
         try {
             const response = await fetch(`${AIHORDE}/api/v2/workers?type=text`);
             if (response.ok) {
                 const workers: IHordeWorker[] = await response.json();
                 const goodWorkers = workers.filter(w =>
                     w.online
                     && !w.maintenance_mode
                     && !w.flagged
                     && w.max_context_length >= MIN_WORKER_CONTEXT
                     && parseFloat(w.performance) >= MIN_PERFORMANCE
                 );
@@ -299,7 +323,7 @@ export namespace Connection {
             return result;
         }
     } catch (e) {
-        console.log('Error getting max tokens', e);
+        console.error('Error getting max tokens', e);
     }
 } else if (isHordeConnection(connection)) {
     return connection.model;
@@ -317,7 +341,7 @@ export namespace Connection {
                 return value;
             }
         } catch (e) {
-            console.log('Error getting max tokens', e);
+            console.error('Error getting max tokens', e);
         }
     } else if (isHordeConnection(connection)) {
         const models = await getHordeModels();
@@ -343,7 +367,7 @@ export namespace Connection {
                 return value;
             }
         } catch (e) {
-            console.log('Error counting tokens', e);
+            console.error('Error counting tokens', e);
         }
     }
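
All three cancellation points share one module-level AbortController: the same signal aborts the pending fetch, rejects the polling delay, and ends the SSE stream, and stopGeneration() immediately swaps in a fresh controller so the next request is not born aborted. A condensed sketch of the pattern in isolation (poll and stop are illustrative names, not the actual Horde client):

import delay from "delay";

let controller = new AbortController();

// Poll `url` until it returns a non-empty body; stop() interrupts
// both the sleep and the request through the shared signal.
async function poll(url: string): Promise<string | null> {
    const { signal } = controller;
    try {
        while (true) {
            await delay(2500, { signal }); // rejects with an AbortError once aborted
            const response = await fetch(url, { signal });
            const text = await response.text();
            if (text) return text;
        }
    } catch {
        return null; // aborted or network failure
    }
}

function stop() {
    controller.abort();
    controller = new AbortController(); // refresh so the next poll() starts clean
}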

View File

@@ -1,13 +1,13 @@
-import Lock from "@common/lock";
-import SSE from "@common/sse";
 import { createContext } from "preact";
 import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
 import { MessageTools, type IMessage } from "../messages";
-import { Instruct, StateContext } from "./state";
+import { StateContext } from "./state";
 import { useBool } from "@common/hooks/useBool";
 import { Template } from "@huggingface/jinja";
 import { Huggingface } from "../huggingface";
 import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection";
+import { throttle } from "@common/utils";
+import { useAsyncEffect } from "@common/hooks/useAsyncEffect";
 
 interface ICompileArgs {
     keepUsers?: number;
@@ -22,9 +22,7 @@ interface ICompiledPrompt {
 
 interface IContext {
     generating: boolean;
-    blockConnection: ReturnType<typeof useBool>;
     modelName: string;
-    modelTemplate: string;
     hasToolCalls: boolean;
     promptTokens: number;
     contextLength: number;
@@ -35,6 +33,7 @@ const MESSAGES_TO_KEEP = 10;
 interface IActions {
     compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
     generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
+    stopGeneration: () => void;
     summarize: (content: string) => Promise<string>;
     countTokens: (prompt: string) => Promise<number>;
 }
@@ -50,15 +49,13 @@ const processing = {
 export const LLMContextProvider = ({ children }: { children?: any }) => {
     const {
         connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
-        setTriggerNext, addMessage, editMessage, editSummary, setInstruct,
+        setTriggerNext, addMessage, editMessage, editSummary,
     } = useContext(StateContext);
 
     const generating = useBool(false);
-    const blockConnection = useBool(false);
     const [promptTokens, setPromptTokens] = useState(0);
     const [contextLength, setContextLength] = useState(0);
     const [modelName, setModelName] = useState('');
-    const [modelTemplate, setModelTemplate] = useState('');
     const [hasToolCalls, setHasToolCalls] = useState(false);
 
     const userPromptTemplate = useMemo(() => {
@@ -71,20 +68,6 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
         }
     }, [userPrompt]);
 
-    const getContextLength = useCallback(async () => {
-        if (!connection || blockConnection.value) {
-            return 0;
-        }
-        return Connection.getContextLength(connection);
-    }, [connection, blockConnection.value]);
-
-    const getModelName = useCallback(async () => {
-        if (!connection || blockConnection.value) {
-            return '';
-        }
-        return Connection.getModelName(connection);
-    }, [connection, blockConnection.value]);
-
     const actions: IActions = useMemo(() => ({
         compilePrompt: async (messages, { keepUsers } = {}) => {
             const promptMessages = messages.slice();
@@ -179,31 +162,43 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
         },
         generate: async function* (prompt, extraSettings = {}) {
             try {
-                generating.setTrue();
                 console.log('[LLM.generate]', prompt);
                 yield* Connection.generate(connection, prompt, {
                     ...extraSettings,
                     banned_tokens: bannedWords.filter(w => w.trim()),
                 });
-            } finally {
-                generating.setFalse();
+            } catch (e) {
+                if (e instanceof Error && e.name !== 'AbortError') {
+                    alert(e.message);
+                } else {
+                    console.error('[LLM.generate]', e);
+                }
             }
         },
         summarize: async (message) => {
-            const content = Huggingface.applyTemplate(summarizePrompt, { message });
-            const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
-
-            const tokens = await Array.fromAsync(actions.generate(prompt));
-            return MessageTools.trimSentence(tokens.join(''));
+            try {
+                const content = Huggingface.applyTemplate(summarizePrompt, { message });
+                const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
+                console.log('[LLM.summarize]', prompt);
+
+                const tokens = await Array.fromAsync(Connection.generate(connection, prompt, {}));
+                return MessageTools.trimSentence(tokens.join(''));
+            } catch (e) {
+                console.error('Error summarizing:', e);
+                return '';
+            }
         },
         countTokens: async (prompt) => {
             return await Connection.countTokens(connection, prompt);
        },
+        stopGeneration: () => {
+            Connection.stopGeneration();
+        },
     }), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);
 
-    useEffect(() => void (async () => {
+    useAsyncEffect(async () => {
         if (triggerNext && !generating.value) {
             setTriggerNext(false);
@@ -217,12 +212,14 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
             messageId++;
         }
 
+        generating.setTrue();
         editSummary(messageId, 'Generating...');
         for await (const chunk of actions.generate(prompt)) {
             text += chunk;
             setPromptTokens(promptTokens + approximateTokens(text));
             editMessage(messageId, text.trim());
         }
+        generating.setFalse();
 
         text = MessageTools.trimSentence(text);
         editMessage(messageId, text);
@@ -230,10 +227,10 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
             MessageTools.playReady();
         }
-    })(), [triggerNext]);
+    }, [triggerNext]);
 
-    useEffect(() => void (async () => {
-        if (summaryEnabled && !generating.value && !processing.summarizing) {
+    useAsyncEffect(async () => {
+        if (summaryEnabled && !processing.summarizing) {
             try {
                 processing.summarizing = true;
                 for (let id = 0; id < messages.length; id++) {
@@ -250,36 +247,15 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
                 processing.summarizing = false;
             }
         }
-    })(), [messages]);
+    }, [messages, summaryEnabled]);
 
-    useEffect(() => {
-        if (!blockConnection.value) {
-            setPromptTokens(0);
-            setContextLength(0);
-            setModelName('');
-            getContextLength().then(setContextLength);
-            getModelName().then(normalizeModel).then(setModelName);
-        }
-    }, [connection, blockConnection.value]);
-
-    useEffect(() => {
-        setModelTemplate('');
-        if (modelName) {
-            Huggingface.findModelTemplate(modelName)
-                .then((template) => {
-                    if (template) {
-                        setModelTemplate(template);
-                        setInstruct(template);
-                    } else {
-                        setInstruct(Instruct.CHATML);
-                    }
-                });
-        }
-    }, [modelName]);
-
-    const calculateTokens = useCallback(async () => {
-        if (!processing.tokenizing && !blockConnection.value && !generating.value) {
+    useEffect(throttle(() => {
+        Connection.getContextLength(connection).then(setContextLength);
+        Connection.getModelName(connection).then(normalizeModel).then(setModelName);
+    }, 1000, true), [connection]);
+
+    const calculateTokens = useCallback(throttle(async () => {
+        if (!processing.tokenizing && !generating.value) {
             try {
                 processing.tokenizing = true;
                 const { prompt } = await actions.compilePrompt(messages);
@@ -291,11 +267,11 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
                 processing.tokenizing = false;
             }
         }
-    }, [actions, messages, blockConnection.value]);
+    }, 1000, true), [actions, messages]);
 
     useEffect(() => {
         calculateTokens();
-    }, [messages, connection, blockConnection.value, /* systemPrompt, lore, userPrompt TODO debounce*/]);
+    }, [messages, connection, systemPrompt, lore, userPrompt]);
 
     useEffect(() => {
         try {
@@ -308,9 +284,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
 
     const rawContext: IContext = {
         generating: generating.value,
-        blockConnection,
         modelName,
-        modelTemplate,
         hasToolCalls,
         promptTokens,
         contextLength,
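
With generating now toggled by the trigger effect rather than inside generate(), background work such as summarize() no longer flips the global flag, and the model-info and token-count effects go through the new trailing throttle so bursts of edits cost at most one recompute per second. A reduced sketch of the throttled-effect pattern used here (ModelInfo and the logged refresh are illustrative stand-ins, not this provider):

import { useEffect } from "preact/hooks";
import { throttle } from "@common/utils";

const ModelInfo = ({ connection }: { connection: unknown }) => {
    // Leading call fires immediately; calls within the next second collapse;
    // the trailing replay applies the last connection change.
    useEffect(throttle(() => {
        console.log('refresh model info for', connection); // stand-in for the two Connection calls
    }, 1000, true), [connection]);
    return null;
};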

View File

@@ -83,7 +83,7 @@ What should happen next in your answer: {{ prompt | trim }}
 {% endif %}
 Remember that this story should be infinite and go forever.
 Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`,
-    summarizePrompt: 'Shrink following text down, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
+    summarizePrompt: 'Shrink following text down to one paragraph, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
     summaryEnabled: false,
     bannedWords: [],
     messages: [],

View File

@@ -1,6 +1,7 @@
 import { gguf } from '@huggingface/gguf';
 import * as hub from '@huggingface/hub';
 import { Template } from '@huggingface/jinja';
+import { normalizeModel } from './connection';
 
 export namespace Huggingface {
     export interface ITemplateMessage {
@@ -92,11 +93,12 @@ export namespace Huggingface {
     const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise<TokenizerConfig | null> => {
         console.log(`[huggingface] searching config for '${modelName}'`);
+        const searchModel = normalizeModel(modelName);
 
-        const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName }, additionalFields: ['config'] }));
+        const hubModels = await Array.fromAsync(hub.listModels({ search: { query: searchModel }, additionalFields: ['config'] }));
         const models = hubModels.filter(m => {
             if (m.gated) return false;
-            if (!m.name.toLowerCase().includes(modelName.toLowerCase())) return false;
+            if (!normalizeModel(m.name).includes(searchModel)) return false;
             return true;
         }).sort((a, b) => b.downloads - a.downloads);
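
Running both the query and each hub result through normalizeModel means quant and format suffixes no longer defeat the template lookup. Roughly, given the replace() chain in the connection diff above (outputs are approximate and assume normalizeModel also lowercases, which the .includes() comparison here implies):

// The local model id and the hub repo name reduce to comparable keys,
// so the substring match succeeds despite the .Q4_K_M.gguf suffix:
normalizeModel("Mistral-7B-Instruct-v0.3.Q4_K_M.gguf"); // ≈ "mistral-7b-instruct-v0.3"
normalizeModel("mistralai/Mistral-7B-Instruct-v0.3");   // ≈ "mistralai/mistral-7b-instruct-v0.3"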