diff --git a/bun.lockb b/bun.lockb
index 40e1e24..888bd9a 100755
Binary files a/bun.lockb and b/bun.lockb differ
diff --git a/package.json b/package.json
index 404c108..7971592 100644
--- a/package.json
+++ b/package.json
@@ -11,6 +11,7 @@
     "@huggingface/gguf": "0.1.12",
     "@huggingface/hub": "0.19.0",
     "@huggingface/jinja": "0.3.1",
+    "@huggingface/transformers": "3.0.2",
     "@inquirer/select": "2.3.10",
     "ace-builds": "1.36.3",
     "classnames": "2.5.1",
diff --git a/src/games/ai-story/assets/style.css b/src/games/ai-story/assets/style.css
index bdcc6e5..ab8f09f 100644
--- a/src/games/ai-story/assets/style.css
+++ b/src/games/ai-story/assets/style.css
@@ -7,6 +7,8 @@
     --green: #AFAFAF;
     --red: #7F0000;
    --green: #007F00;
+    --brightRed: #DD0000;
+    --brightGreen: #00DD00;
    --shadeColor: rgba(0, 128, 128, 0.3);
 
    --border: 1px solid var(--color);
diff --git a/src/games/ai-story/components/autoTextarea.tsx b/src/games/ai-story/components/autoTextarea.tsx
index d0c592b..ac69a93 100644
--- a/src/games/ai-story/components/autoTextarea.tsx
+++ b/src/games/ai-story/components/autoTextarea.tsx
@@ -2,7 +2,7 @@
 import { useEffect, useRef } from "preact/hooks";
 import type { JSX } from "preact/jsx-runtime"
 import { useIsVisible } from '@common/hooks/useIsVisible';
-import { DOMTools } from "../dom";
+import { DOMTools } from "../tools/dom";
 
 export const AutoTextarea = (props: JSX.HTMLAttributes) => {
     const { value } = props;
diff --git a/src/games/ai-story/components/chat.tsx b/src/games/ai-story/components/chat.tsx
index 7a07f75..aee10c8 100644
--- a/src/games/ai-story/components/chat.tsx
+++ b/src/games/ai-story/components/chat.tsx
@@ -1,8 +1,8 @@
 import { useCallback, useContext, useEffect, useRef } from "preact/hooks";
 import { StateContext } from "../contexts/state";
 import { Message } from "./message/message";
-import { MessageTools } from "../messages";
-import { DOMTools } from "../dom";
+import { MessageTools } from "../tools/messages";
+import { DOMTools } from "../tools/dom";
 
 export const Chat = () => {
     const { messages } = useContext(StateContext);
diff --git a/src/games/ai-story/components/header/connectionEditor.tsx b/src/games/ai-story/components/header/connectionEditor.tsx
index 5836e2e..81b3592 100644
--- a/src/games/ai-story/components/header/connectionEditor.tsx
+++ b/src/games/ai-story/components/header/connectionEditor.tsx
@@ -1,11 +1,10 @@
 import { useCallback, useEffect, useMemo, useState } from 'preact/hooks';
-
 import styles from './header.module.css';
-import { Connection, HORDE_ANON_KEY, isHordeConnection, isKoboldConnection, type IConnection, type IHordeModel } from '../../connection';
+import { Connection, HORDE_ANON_KEY, isHordeConnection, isKoboldConnection, type IConnection, type IHordeModel } from '../../tools/connection';
 import { Instruct } from '../../contexts/state';
 import { useInputState } from '@common/hooks/useInputState';
 import { useInputCallback } from '@common/hooks/useInputCallback';
-import { Huggingface } from '../../huggingface';
+import { Huggingface } from '../../tools/huggingface';
 
 interface IProps {
     connection: IConnection;
@@ -13,10 +12,13 @@ interface IProps {
 }
 
 export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
+    // kobold
     const [connectionUrl, setConnectionUrl] = useInputState('');
+    // horde
     const [apiKey, setApiKey] = useInputState(HORDE_ANON_KEY);
     const [modelName, setModelName] = useInputState('');
 
+    const [instruct, setInstruct] = useInputState('');
     const [modelTemplate, setModelTemplate] = useInputState('');
     const [hordeModels, setHordeModels] = useState([]);
     const [contextLength, setContextLength] = useState(0);
@@ -27,11 +29,14 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
         return 'unknown';
     }, [connection]);
 
-    const urlValid = useMemo(() => contextLength > 0, [contextLength]);
+    const isOnline = useMemo(() => contextLength > 0, [contextLength]);
 
     useEffect(() => {
+        setInstruct(connection.instruct);
+
         if (isKoboldConnection(connection)) {
             setConnectionUrl(connection.url);
+            Connection.getContextLength(connection).then(setContextLength);
         } else if (isHordeConnection(connection)) {
             setModelName(connection.model);
             setApiKey(connection.apiKey || HORDE_ANON_KEY);
@@ -39,9 +44,6 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
             Connection.getHordeModels()
                 .then(m => setHordeModels(Array.from(m.values()).sort((a, b) => a.name.localeCompare(b.name))));
         }
-
-        Connection.getContextLength(connection).then(setContextLength);
-        Connection.getModelName(connection).then(setModelName);
     }, [connection]);
 
     useEffect(() => {
@@ -50,47 +52,44 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
                 .then(template => {
                     if (template) {
                         setModelTemplate(template);
+                        setInstruct(template);
                     }
                 });
         }
     }, [modelName]);
 
-    const setInstruct = useInputCallback((instruct) => {
-        setConnection({ ...connection, instruct });
-    }, [connection, setConnection]);
-
     const setBackendType = useInputCallback((type) => {
         if (type === 'kobold') {
             setConnection({
-                instruct: connection.instruct,
+                instruct,
                 url: connectionUrl,
             });
         } else if (type === 'horde') {
             setConnection({
-                instruct: connection.instruct,
+                instruct,
                 apiKey,
                 model: modelName,
             });
         }
-    }, [connection, setConnection, connectionUrl, apiKey, modelName]);
+    }, [setConnection, connectionUrl, apiKey, modelName, instruct]);
 
     const handleBlurUrl = useCallback(() => {
         const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i;
         const url = connectionUrl.replace(regex, 'http$1://$2');
 
         setConnection({
-            instruct: connection.instruct,
+            instruct,
             url,
         });
-    }, [connection, connectionUrl, setConnection]);
+    }, [connectionUrl, instruct, setConnection]);
 
     const handleBlurHorde = useCallback(() => {
         setConnection({
-            instruct: connection.instruct,
+            instruct,
             apiKey,
             model: modelName,
         });
-    }, [connection, apiKey, modelName, setConnection]);
+    }, [apiKey, modelName, instruct, setConnection]);
 
     return (
@@ -98,7 +97,7 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
-
+
             {modelName && modelTemplate && }
@@ -109,15 +108,15 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
                 ))}
-
+
             {instruct !== modelTemplate &&
-
+
             }
             {isKoboldConnection(connection) && }
             {isHordeConnection(connection) && <>
diff --git a/src/games/ai-story/components/header/header.tsx b/src/games/ai-story/components/header/header.tsx
 export const Header = () => {
     const promptsOpen = useBool();
     const genparamsOpen = useBool();
     const assistantOpen = useBool();
+    const isOnline = useMemo(() => contextLength > 0, [contextLength]);
 
     const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]);
@@ -56,7 +57,7 @@ export const Header = () => {
-
diff --git a/src/games/ai-story/components/message/formattedMessage.tsx b/src/games/ai-story/components/message/formattedMessage.tsx
index 8834fa2..e63d419 100644
--- a/src/games/ai-story/components/message/formattedMessage.tsx
+++ b/src/games/ai-story/components/message/formattedMessage.tsx
@@ -1,5 +1,5 @@
 import { useMemo } from "preact/hooks";
-import { MessageTools } from "../../messages";
+import { MessageTools } from "../../tools/messages";
 
 import styles from './message.module.css';
 
diff --git a/src/games/ai-story/components/message/message.tsx b/src/games/ai-story/components/message/message.tsx
index e8fd843..fee63b7 100644
--- a/src/games/ai-story/components/message/message.tsx
+++ b/src/games/ai-story/components/message/message.tsx
@@ -1,7 +1,7 @@
 import { useCallback, useContext, useEffect, useMemo, useRef, useState } from "preact/hooks";
-import { MessageTools, type IMessage } from "../../messages";
+import { MessageTools, type IMessage } from "../../tools/messages";
 import { StateContext } from "../../contexts/state";
-import { DOMTools } from "../../dom";
+import { DOMTools } from "../../tools/dom";
 
 import styles from './message.module.css';
 import { AutoTextarea } from "../autoTextarea";
diff --git a/src/games/ai-story/components/minichat/minichat.tsx b/src/games/ai-story/components/minichat/minichat.tsx
index 9e63e9c..5b421f1 100644
--- a/src/games/ai-story/components/minichat/minichat.tsx
+++ b/src/games/ai-story/components/minichat/minichat.tsx
@@ -1,7 +1,7 @@
-import { MessageTools, type IMessage } from "../../messages"
+import { MessageTools, type IMessage } from "../../tools/messages"
 import { useCallback, useContext, useEffect, useMemo, useRef, useState } from "preact/hooks";
 import { Modal } from "@common/components/modal/modal";
-import { DOMTools } from "../../dom";
+import { DOMTools } from "../../tools/dom";
 
 import styles from './minichat.module.css';
 import { LLMContext } from "../../contexts/llm";
diff --git a/src/games/ai-story/contexts/llm.tsx b/src/games/ai-story/contexts/llm.tsx
index d250b0b..078dfa2 100644
--- a/src/games/ai-story/contexts/llm.tsx
+++ b/src/games/ai-story/contexts/llm.tsx
@@ -1,13 +1,13 @@
 import { createContext } from "preact";
 import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
-import { MessageTools, type IMessage } from "../messages";
+import { MessageTools, type IMessage } from "../tools/messages";
 import { StateContext } from "./state";
 import { useBool } from "@common/hooks/useBool";
-import { Template } from "@huggingface/jinja";
-import { Huggingface } from "../huggingface";
-import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection";
+import { Huggingface } from "../tools/huggingface";
+import { Connection, type IGenerationSettings } from "../tools/connection";
 import { throttle } from "@common/utils";
 import { useAsyncEffect } from "@common/hooks/useAsyncEffect";
+import { approximateTokens, normalizeModel } from "../tools/model";
 
 interface ICompileArgs {
     keepUsers?: number;
@@ -58,15 +58,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
     const [modelName, setModelName] = useState('');
     const [hasToolCalls, setHasToolCalls] = useState(false);
 
-    const userPromptTemplate = useMemo(() => {
-        try {
-            return new Template(userPrompt)
-        } catch {
-            return {
-                render: () => userPrompt,
-            }
-        }
-    }, [userPrompt]);
+    const isOnline = useMemo(() => contextLength > 0, [contextLength]);
 
     const actions: IActions = useMemo(() => ({
         compilePrompt: async (messages, { keepUsers, continueLast = false } = {}) => {
@@ -86,7 +78,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
             const promptMessages = continueLast ? messages.slice(0, -1) : messages.slice();
 
             if (isContinue) {
-                promptMessages.push(MessageTools.create(userPromptTemplate.render({})));
+                promptMessages.push(MessageTools.create(Huggingface.applyTemplate(userPrompt, {})));
             }
 
             const userMessages = promptMessages.filter(m => m.role === 'user');
@@ -113,7 +105,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
                 } else if (role === 'user' && !message.technical) {
                     templateMessages.push({
                         role: message.role,
-                        content: userPromptTemplate.render({ prompt: content, isStart: !wasStory }),
+                        content: Huggingface.applyTemplate(userPrompt, { prompt: content, isStart: !wasStory }),
                     });
                 } else {
                     if (role === 'assistant') {
@@ -137,17 +129,17 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
 
                 if (story.length > 0) {
                     const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
-                    templateMessages.push({ role: 'user', content: userPromptTemplate.render({ prompt, isStart: true }) });
+                    templateMessages.push({ role: 'user', content: Huggingface.applyTemplate(userPrompt, { prompt, isStart: true }) });
                     templateMessages.push({ role: 'assistant', content: story });
                 }
 
-                let userPrompt = MessageTools.getSwipe(lastUserMessage)?.content;
-                if (!lastUserMessage?.technical && !isContinue && userPrompt) {
-                    userPrompt = userPromptTemplate.render({ prompt: userPrompt, isStart: story.length === 0 });
+                let userMessage = MessageTools.getSwipe(lastUserMessage)?.content;
+                if (!lastUserMessage?.technical && !isContinue && userMessage) {
+                    userMessage = Huggingface.applyTemplate(userPrompt, { prompt: userMessage, isStart: story.length === 0 });
                 }
 
-                if (userPrompt) {
-                    templateMessages.push({ role: 'user', content: userPrompt });
+                if (userMessage) {
+                    templateMessages.push({ role: 'user', content: userMessage });
                 }
             }
 
@@ -156,7 +148,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
 
                 templateMessages.splice(1, 0, {
                     role: 'user',
-                    content: userPromptTemplate.render({ prompt, isStart: true }),
+                    content: Huggingface.applyTemplate(userPrompt, { prompt, isStart: true }),
                 });
             }
 
@@ -210,10 +202,10 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
         stopGeneration: () => {
             Connection.stopGeneration();
         },
-    }), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);
+    }), [connection, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt]);
 
     useAsyncEffect(async () => {
-        if (triggerNext && !generating.value) {
+        if (isOnline && triggerNext && !generating.value) {
             setTriggerNext(false);
             setContinueLast(false);
 
@@ -244,10 +236,10 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
 
             MessageTools.playReady();
         }
-    }, [triggerNext]);
+    }, [triggerNext, isOnline]);
 
     useAsyncEffect(async () => {
-        if (summaryEnabled && !processing.summarizing) {
+        if (isOnline && summaryEnabled && !processing.summarizing) {
             try {
                 processing.summarizing = true;
                 for (let id = 0; id < messages.length; id++) {
@@ -264,7 +256,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
                 processing.summarizing = false;
             }
         }
-    }, [messages, summaryEnabled]);
+    }, [messages, summaryEnabled, isOnline]);
 
     useEffect(throttle(() => {
         Connection.getContextLength(connection).then(setContextLength);
@@ -272,7 +264,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
     }, 1000, true), [connection]);
 
     const calculateTokens = useCallback(throttle(async () => {
-        if (!processing.tokenizing && !generating.value) {
+        if (isOnline && !processing.tokenizing && !generating.value) {
             try {
                 processing.tokenizing = true;
                 const { prompt } = await actions.compilePrompt(messages);
@@ -284,11 +276,11 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
                 processing.tokenizing = false;
             }
         }
-    }, 1000, true), [actions, messages]);
+    }, 1000, true), [actions, messages, isOnline]);
 
     useEffect(() => {
         calculateTokens();
-    }, [messages, connection, systemPrompt, lore, userPrompt]);
+    }, [messages, connection, systemPrompt, lore, userPrompt, isOnline]);
 
     useEffect(() => {
         try {
diff --git a/src/games/ai-story/contexts/state.tsx b/src/games/ai-story/contexts/state.tsx
index 63f2936..2e4fa41 100644
--- a/src/games/ai-story/contexts/state.tsx
+++ b/src/games/ai-story/contexts/state.tsx
@@ -1,8 +1,8 @@
 import { createContext } from "preact";
 import { useCallback, useEffect, useMemo, useState } from "preact/hooks";
-import { MessageTools, type IMessage } from "../messages";
+import { MessageTools, type IMessage } from "../tools/messages";
 import { useInputState } from "@common/hooks/useInputState";
-import { type IConnection } from "../connection";
+import { type IConnection } from "../tools/connection";
 
 interface IContext {
     currentConnection: number;
@@ -83,7 +83,7 @@ Continue the story forward.
 {%- endif %}
 
 {% if prompt -%}
-This is the description of What should happen next in your answer: {{ prompt | trim }}
+This is the description of what should happen next in your answer: {{ prompt | trim }}
 {% endif %}
 
 Remember that this story should be infinite and go forever. Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`,
diff --git a/src/games/ai-story/connection.ts b/src/games/ai-story/tools/connection.ts
similarity index 70%
rename from src/games/ai-story/connection.ts
rename to src/games/ai-story/tools/connection.ts
index 4892e05..7b8b571 100644
--- a/src/games/ai-story/connection.ts
+++ b/src/games/ai-story/tools/connection.ts
@@ -2,6 +2,8 @@ import Lock from "@common/lock";
 import SSE from "@common/sse";
 import { throttle } from "@common/utils";
 import delay from "delay";
+import { Huggingface } from "./huggingface";
+import { approximateTokens, normalizeModel } from "./model";
 
 interface IBaseConnection {
     instruct: string;
@@ -79,34 +81,6 @@ const MAX_HORDE_LENGTH = 512;
 const MAX_HORDE_CONTEXT = 32000;
 export const HORDE_ANON_KEY = '0000000000';
 
-export const normalizeModel = (model: string) => {
-    let currentModel = model.split(/[\\\/]/).at(-1);
-    currentModel = currentModel.split('::').at(0);
-    let normalizedModel: string;
-
-    do {
-        normalizedModel = currentModel;
-
-        currentModel = currentModel
-            .replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
-32k - .replace(/[ ._-](gptq|awq|exl2?|imat|i\d|h\d)/i, '') // remove quant name - .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc - .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size - .replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw - .replace(/[ ._-]f(p|loat)?(8|16|32)/i, '') - .replace(/^(debug-?)+/i, '') - .trim(); - } while (normalizedModel !== currentModel); - - return normalizedModel - .replace(/[ _-]+/ig, '-') - .replace(/\.{2,}/, '-') - .replace(/[ ._-]+$/ig, '') - .trim(); -} - -export const approximateTokens = (prompt: string): number => Math.round(prompt.length / 4); - export type IGenerationSettings = Partial; export namespace Connection { @@ -171,7 +145,11 @@ export namespace Connection { sse.close(); } - async function generateHorde(connection: Omit, prompt: string, extraSettings: IGenerationSettings = {}): Promise { + async function* generateHorde(connection: IHordeConnection, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator { + if (!connection.model) { + throw new Error('Horde not connected'); + } + const models = await getHordeModels(); const model = models.get(connection.model); if (model) { @@ -192,54 +170,78 @@ export namespace Connection { models: model.hordeNames, workers: model.workers, }; + const bannedTokens = requestData.params.banned_tokens ?? []; const { signal } = abortController; - const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, { - method: 'POST', - body: JSON.stringify(requestData), - headers: { - 'Content-Type': 'application/json', - apikey: connection.apiKey || HORDE_ANON_KEY, - }, - signal, - }); + while (true) { + const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, { + method: 'POST', + body: JSON.stringify(requestData), + headers: { + 'Content-Type': 'application/json', + apikey: connection.apiKey || HORDE_ANON_KEY, + }, + signal, + }); - if (!generateResponse.ok || generateResponse.status >= 400) { - throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`); - } - - const { id } = await generateResponse.json() as { id: string }; - const request = async (method = 'GET'): Promise => { - const response = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method }); - if (response.ok && response.status < 400) { - const result: IHordeResult = await response.json(); - if (result.generations?.length === 1) { - const { text } = result.generations[0]; - - return text; - } - } else { - throw new Error(await response.text()); + if (!generateResponse.ok || generateResponse.status >= 400) { + throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`); } - return null; - }; + const { id } = await generateResponse.json() as { id: string }; + const request = async (method = 'GET'): Promise => { + const response = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method }); + if (response.ok && response.status < 400) { + const result: IHordeResult = await response.json(); + if (result.generations?.length === 1) { + const { text } = result.generations[0]; - const deleteRequest = async () => (await request('DELETE')) ?? 
-        const deleteRequest = async () => (await request('DELETE')) ?? '';
-
-        while (true) {
-            try {
-                await delay(2500, { signal });
-
-                const text = await request();
-
-                if (text) {
-                    return text;
+                        return text;
+                    }
+                } else {
+                    throw new Error(await response.text());
+                }
+
+                return null;
+            };
+
+            const deleteRequest = async () => (await request('DELETE')) ?? '';
+            let text: string | null = null;
+
+            while (!text) {
+                try {
+                    await delay(2500, { signal });
+
+                    text = await request();
+
+                    if (text) {
+                        const locaseText = text.toLowerCase();
+                        let unsloppedText = text;
+                        for (const ban of bannedTokens) {
+                            const slopIdx = locaseText.indexOf(ban.toLowerCase());
+                            if (slopIdx >= 0) {
+                                console.log(`[horde] slop '${ban}' detected at ${slopIdx}`);
+                                unsloppedText = unsloppedText.slice(0, slopIdx);
+                            }
+                        }
+
+                        yield unsloppedText;
+
+                        requestData.prompt += unsloppedText;
+
+                        if (unsloppedText === text) {
+                            return; // we are finished
+                        }
+
+                        if (unsloppedText.length === 0) {
+                            requestData.params.temperature += 0.05;
+                        }
+                    }
+                } catch (e) {
+                    console.error('Error in horde generation:', e);
+                    return yield deleteRequest();
                 }
-            } catch (e) {
-                console.error('Error in horde generation:', e);
-                return deleteRequest();
             }
         }
     }
@@ -251,7 +253,7 @@ export namespace Connection {
         if (isKoboldConnection(connection)) {
            yield* generateKobold(connection.url, prompt, extraSettings);
         } else if (isHordeConnection(connection)) {
-            yield await generateHorde(connection, prompt, extraSettings);
+            yield* generateHorde(connection, prompt, extraSettings);
         }
     }
@@ -277,7 +279,7 @@ export namespace Connection {
 
         for (const worker of goodWorkers) {
             for (const modelName of worker.models) {
-                const normName = normalizeModel(modelName.toLowerCase());
+                const normName = normalizeModel(modelName);
                 let model = models.get(normName);
                 if (!model) {
                     model = {
@@ -343,7 +345,7 @@ export namespace Connection {
             } catch (e) {
                 console.error('Error getting max tokens', e);
             }
-        } else if (isHordeConnection(connection)) {
+        } else if (isHordeConnection(connection) && connection.model) {
             const models = await getHordeModels();
             const model = models.get(connection.model);
             if (model) {
@@ -367,7 +369,18 @@ export namespace Connection {
                 return value;
             }
         } catch (e) {
-            console.error('Error counting tokens', e);
+            console.error('Error counting tokens:', e);
+            }
+        } else {
+            const model = await getModelName(connection);
+            const tokenizer = await Huggingface.findTokenizer(model);
+            if (tokenizer) {
+                try {
+                    const { input_ids } = await tokenizer(prompt);
+                    return input_ids.data.length;
+                } catch (e) {
+                    console.error('Error counting tokens with tokenizer:', e);
+                }
             }
         }
diff --git a/src/games/ai-story/dom.ts b/src/games/ai-story/tools/dom.ts
similarity index 100%
rename from src/games/ai-story/dom.ts
rename to src/games/ai-story/tools/dom.ts
diff --git a/src/games/ai-story/huggingface.ts b/src/games/ai-story/tools/huggingface.ts
similarity index 82%
rename from src/games/ai-story/huggingface.ts
rename to src/games/ai-story/tools/huggingface.ts
index 5885045..80f22f2 100644
--- a/src/games/ai-story/huggingface.ts
+++ b/src/games/ai-story/tools/huggingface.ts
@@ -1,7 +1,8 @@
 import { gguf } from '@huggingface/gguf';
 import * as hub from '@huggingface/hub';
 import { Template } from '@huggingface/jinja';
-import { normalizeModel } from './connection';
+import { AutoTokenizer, PreTrainedTokenizer } from '@huggingface/transformers';
+import { normalizeModel } from './model';
 
 export namespace Huggingface {
     export interface ITemplateMessage {
@@ -81,6 +82,7 @@ export namespace Huggingface {
 
     const templateCache: Record = loadCache();
     const compiledTemplates = new Map();
+    const tokenizerCache = new Map();
 
     const hasField = (obj: unknown, field: T): obj is Record => (
         obj != null && typeof obj === 'object' && (field in obj)
@@ -92,13 +94,13 @@ export namespace Huggingface {
     );
 
     const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise => {
+        modelName = normalizeModel(modelName);
         console.log(`[huggingface] searching config for '${modelName}'`);
-        const searchModel = normalizeModel(modelName);
-        const hubModels = await Array.fromAsync(hub.listModels({ search: { query: searchModel }, additionalFields: ['config'] }));
+        const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName }, additionalFields: ['config'] }));
 
         const models = hubModels.filter(m => {
             if (m.gated) return false;
-            if (!normalizeModel(m.name).includes(searchModel)) return false;
+            if (!normalizeModel(m.name).includes(modelName)) return false;
 
             return true;
         }).sort((a, b) => b.downloads - a.downloads);
@@ -116,8 +118,8 @@ export namespace Huggingface {
             }
 
             try {
-                console.log(`[huggingface] searching config in '${model.name}/tokenizer_config.json'`);
-                const fileResponse = await hub.downloadFile({ repo: model.name, path: 'tokenizer_config.json' });
+                console.log(`[huggingface] searching config in '${name}/tokenizer_config.json'`);
+                const fileResponse = await hub.downloadFile({ repo: name, path: 'tokenizer_config.json' });
                 if (fileResponse?.ok) {
                     const maybeConfig = await fileResponse.json();
                     if (isTokenizerConfig(maybeConfig)) {
@@ -232,10 +234,10 @@ export namespace Huggingface {
     }
 
     export const findModelTemplate = async (modelName: string): Promise => {
-        const modelKey = modelName.toLowerCase().trim();
-        if (!modelKey) return '';
+        modelName = normalizeModel(modelName);
+        if (!modelName) return '';
 
-        let template = templateCache[modelKey] ?? null;
+        let template = templateCache[modelName] ?? null;
 
         if (template) {
             console.log(`[huggingface] found cached template for '${modelName}'`);
@@ -254,12 +256,53 @@ export namespace Huggingface {
             }
         }
 
-        templateCache[modelKey] = template;
+        templateCache[modelName] = template;
         saveCache(templateCache);
 
         return template;
     }
 
+    export const findTokenizer = async (modelName: string): Promise => {
+        modelName = normalizeModel(modelName);
+
+        let tokenizer = tokenizerCache.get(modelName) ?? null;
+
+        if (tokenizer) {
+            return tokenizer;
+        } else if (!tokenizerCache.has(modelName)) {
+            console.log(`[huggingface] searching tokenizer for '${modelName}'`);
+
+            const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName } }));
+            const models = hubModels.filter(m => {
+                if (m.gated) return false;
+                if (m.name.toLowerCase().includes('gguf')) return false;
+                if (!normalizeModel(m.name).includes(modelName)) return false;
+
+                return true;
+            });
+
+            for (const model of models) {
+                const { name } = model;
+
+                try {
+                    console.log(`[huggingface] searching tokenizer in '${name}'`);
+                    tokenizer = await AutoTokenizer.from_pretrained(name);
+                    break;
+                } catch { }
+            }
+        }
+
+        tokenizerCache.set(modelName, tokenizer);
+
+        if (tokenizer) {
+            console.log(`[huggingface] found tokenizer for '${modelName}'`);
+        } else {
+            console.log(`[huggingface] not found tokenizer for '${modelName}'`);
+        }
+
+        return tokenizer;
+    }
+
     export const applyChatTemplate = (templateString: string, messages: ITemplateMessage[], functions?: IFunction[]) => (
         applyTemplate(templateString, {
             messages,
diff --git a/src/games/ai-story/messages.ts b/src/games/ai-story/tools/messages.ts
similarity index 97%
rename from src/games/ai-story/messages.ts
rename to src/games/ai-story/tools/messages.ts
index 19420e6..6d15636 100644
--- a/src/games/ai-story/messages.ts
+++ b/src/games/ai-story/tools/messages.ts
@@ -1,5 +1,4 @@
-import { Template } from "@huggingface/jinja";
-import messageSound from './assets/message.mp3';
+import messageSound from '../assets/message.mp3';
 
 export interface ISwipe {
     content: string;
diff --git a/src/games/ai-story/tools/model.ts b/src/games/ai-story/tools/model.ts
new file mode 100644
index 0000000..1cb2a9b
--- /dev/null
+++ b/src/games/ai-story/tools/model.ts
@@ -0,0 +1,27 @@
+export const normalizeModel = (model: string) => {
+    let currentModel = model.split(/[\\\/]/).at(-1);
+    currentModel = currentModel.split('::').at(0).toLowerCase();
+    let normalizedModel: string;
+
+    do {
+        normalizedModel = currentModel;
+
+        currentModel = currentModel
+            .replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
+            .replace(/[ ._-](gptq|awq|exl2?|imat|i\d|h\d)/i, '') // remove quant name
+            .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
+            .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
+            .replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
+            .replace(/[ ._-]f(p|loat)?(8|16|32)/i, '')
+            .replace(/^(debug-?)+/i, '')
+            .trim();
+    } while (normalizedModel !== currentModel);
+
+    return normalizedModel
+        .replace(/[ _-]+/ig, '-')
+        .replace(/\.{2,}/, '-')
+        .replace(/[ ._-]+$/ig, '')
+        .trim();
+}
+
+export const approximateTokens = (prompt: string): number => Math.round(prompt.length / 4);
\ No newline at end of file
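
Review note, not part of the patch: `generateHorde` is now an async generator with a retry loop. Each Horde result is scanned for banned phrases, truncated at the first hit, yielded, and appended to the prompt before re-submitting, so generation continues past the cut instead of returning the slop verbatim. A minimal sketch of that control flow, under stated assumptions: `fetchHordeChunk` is a hypothetical stand-in for the diff's async-generate plus poll-status round trip, while the truncation and exit conditions mirror the diff.

```typescript
// Sketch only: the Horde retry loop from tools/connection.ts, reduced to its control flow.
async function* generateUntilClean(
    fetchHordeChunk: (prompt: string) => Promise<string>, // hypothetical stand-in
    prompt: string,
    bannedTokens: string[],
): AsyncGenerator<string> {
    while (true) {
        const text = await fetchHordeChunk(prompt);

        // Truncate at banned phrases, case-insensitively; repeated slicing with
        // indices taken from the original text keeps the shortest clean prefix.
        const locase = text.toLowerCase();
        let clean = text;
        for (const ban of bannedTokens) {
            const idx = locase.indexOf(ban.toLowerCase());
            if (idx >= 0) clean = clean.slice(0, idx);
        }

        yield clean;
        prompt += clean;            // resume generation after the kept prefix
        if (clean === text) return; // nothing was cut, so the result is final
        // The diff also nudges temperature up by 0.05 when the whole chunk was
        // cut, to steer the next attempt away from the same banned continuation.
    }
}
```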
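Review note, not part of the patch: with `@huggingface/transformers` added, `Connection.countTokens` gains a layered fallback: Kobold's own token-count endpoint when a URL is connected, otherwise a Hub-resolved transformers.js tokenizer via `Huggingface.findTokenizer`, and, judging by the `approximateTokens` import, the old heuristic of roughly four characters per token as the last resort. A condensed sketch, assuming a model name that `normalizeModel` can match to a non-gated Hub repo; `countTokensSketch` is a hypothetical name, while the `AutoTokenizer` call and the `{ input_ids }` result shape are exactly what the diff uses:

```typescript
import { AutoTokenizer } from '@huggingface/transformers';

// Last-resort heuristic, as exported from tools/model.ts.
const approximateTokens = (prompt: string): number => Math.round(prompt.length / 4);

// Hypothetical condensation of the non-Kobold branch of Connection.countTokens.
const countTokensSketch = async (modelName: string, prompt: string): Promise<number> => {
    try {
        // transformers.js resolves tokenizer.json / tokenizer_config.json from the Hub.
        const tokenizer = await AutoTokenizer.from_pretrained(modelName);
        const { input_ids } = await tokenizer(prompt);
        return input_ids.data.length; // one id per token
    } catch {
        // Gated repo, GGUF-only uploads, or network failure: fall back to the estimate.
        return approximateTokens(prompt);
    }
};
```

This is also why `normalizeModel` moved to `tools/model.ts` and now lowercases internally: both the worker-reported name and Hub repo names pass through it, so a pair like `Mistral-7B-Instruct-v0.2.Q4_K_M.gguf` and `mistralai/Mistral-7B-Instruct-v0.2` (an assumed example, not a repo test vector) should both reduce to `mistral-7b-instruct-v0.2` and match in `findTokenizer`'s filter.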