
AI story: load model template from HF & count tokens

Pabloader 2024-11-03 13:06:06 +00:00
parent a4ce47d9d8
commit 5805469581
7 changed files with 235 additions and 62 deletions

BIN
bun.lockb

Binary file not shown.

View File

@@ -8,6 +8,7 @@
"bake": "bun build/build.ts"
},
"dependencies": {
"@huggingface/hub": "0.19.0",
"@huggingface/jinja": "0.3.1",
"@inquirer/select": "2.3.10",
"ace-builds": "1.36.3",

View File

@@ -26,6 +26,8 @@ export const Ace = ({ value, onInput }: IAceProps) => {
displayIndentGuides: false,
fontSize: 16,
maxLines: Infinity,
tabSize: 2,
useSoftTabs: true,
wrap: "free",
});
return e;
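
For reference, these options map directly onto Ace's setOptions; a standalone equivalent (the element id is illustrative):

import ace from "ace-builds";

// Same options as above, outside the Preact wrapper.
const editor = ace.edit("editor");
editor.setOptions({
  displayIndentGuides: false,
  fontSize: 16,
  maxLines: Infinity,
  tabSize: 2,        // added in this commit
  useSoftTabs: true, // added in this commit
  wrap: "free",
});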

View File

@@ -23,6 +23,11 @@
}
}
.info {
margin: 0 8px;
line-height: 36px;
}
.buttons {
display: flex;
flex-direction: row;

View File

@@ -11,38 +11,25 @@ import styles from './header.module.css';
import { Ace } from "../ace";
export const Header = () => {
const { getContextLength } = useContext(LLMContext);
const { modelName, modelTemplate, contextLength, promptTokens, blockConnection } = useContext(LLMContext);
const {
messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct,
setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct
} = useContext(StateContext);
const [urlValid, setUrlValid] = useState(false);
const [urlEditing, setUrlEditing] = useState(false);
const loreOpen = useBool();
const promptsOpen = useBool();
const assistantOpen = useBool();
const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]);
const handleFocusUrl = useCallback(() => setUrlEditing(true), []);
const urlValid = useMemo(() => contextLength > 0, [contextLength]);
const handleBlurUrl = useCallback(() => {
const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i
const normalizedConnectionUrl = connectionUrl.replace(regex, 'http$1://$2');
console.log({ connectionUrl, normalizedConnectionUrl })
setConnectionUrl(normalizedConnectionUrl);
setUrlEditing(false);
setUrlValid(false);
}, [connectionUrl, setConnectionUrl]);
useEffect(() => {
if (!urlEditing) {
getContextLength().then(length => {
setUrlValid(length > 0);
});
}
}, [connectionUrl, urlEditing]);
blockConnection.setFalse();
}, [connectionUrl, setConnectionUrl, blockConnection]);
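
The regex in handleBlurUrl prepends http:// when the scheme is missing and strips one trailing slash; for example:

const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i;
const normalize = (url: string) => url.replace(regex, 'http$1://$2');

normalize('localhost:5001/');     // "http://localhost:5001"
normalize('https://example.com'); // "https://example.com" (unchanged)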
const handleAssistantAddSwipe = useCallback((answer: string) => {
const index = messages.findLastIndex(m => m.role === 'assistant');
@@ -69,17 +56,21 @@ export const Header = () => {
<div class={styles.inputs}>
<input value={connectionUrl}
onInput={setConnectionUrl}
onFocus={handleFocusUrl}
onFocus={blockConnection.setTrue}
onBlur={handleBlurUrl}
class={urlEditing ? '' : urlValid ? styles.valid : styles.invalid}
class={blockConnection.value ? '' : urlValid ? styles.valid : styles.invalid}
/>
<select value={instruct} onChange={setInstruct}>
{modelName && modelTemplate && <option value={modelTemplate}>{modelName}</option>}
{Object.entries(Instruct).map(([label, value]) => (
<option value={value} key={value}>
{label.toLowerCase()}
</option>
))}
</select>
<div class={styles.info}>
{promptTokens} / {contextLength}
</div>
</div>
<div class={styles.buttons}>
<button class='icon color' title='Edit lore' onClick={loreOpen.setTrue}>

View File

@@ -1,11 +1,12 @@
import Lock from "@common/lock";
import SSE from "@common/sse";
import { createContext } from "preact";
import { useContext, useEffect, useMemo } from "preact/hooks";
import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
import { MessageTools, type IMessage } from "../messages";
import { StateContext } from "./state";
import { useBool } from "@common/hooks/useBool";
import { Template } from "@huggingface/jinja";
import { Huggingface } from "../huggingface";
interface ITemplateMessage {
@@ -25,6 +26,11 @@ interface ICompiledPrompt {
interface IContext {
generating: boolean;
blockConnection: ReturnType<typeof useBool>;
modelName: string;
modelTemplate: string;
promptTokens: number;
contextLength: number;
}
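
useBool itself is outside this diff; judging from how blockConnection is used (value, setTrue, setFalse), a plausible minimal implementation would be:

import { useMemo, useState } from "preact/hooks";

// Hypothetical sketch; the real hook lives in @common/hooks/useBool.
export const useBool = (initial = false) => {
  const [value, setValue] = useState(initial);
  return useMemo(() => ({
    value,
    setTrue: () => setValue(true),
    setFalse: () => setValue(false),
  }), [value]);
};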
const DEFAULT_GENERATION_SETTINGS = {
@@ -44,23 +50,65 @@ const DEFAULT_GENERATION_SETTINGS = {
type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
interface IActions {
applyChatTemplate: (messages: ITemplateMessage[], templateString: string, eosToken?: string) => string;
compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
countTokens(prompt: string): Promise<number>;
getContextLength(): Promise<number>;
countTokens: (prompt: string) => Promise<number>;
}
export type ILLMContext = IContext & IActions;
export const normalizeModel = (model: string) => {
let currentModel = model.split(/[\\\/]/).at(-1);
currentModel = currentModel.split('::').at(0);
let normalizedModel: string;
do {
normalizedModel = currentModel;
currentModel = currentModel
.replace(/[ ._-]\d(\d*k|\d+)(-context|$)/i, '') // remove context length, e.g. -32k
.replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name
.replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
.replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
.replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
.replace(/[ ._-]f(p|loat)?(8|16|32)/i, '')
.replace(/^(debug-?)+/i, '')
.trim();
} while (normalizedModel !== currentModel);
return normalizedModel
.replace(/[ _-]+/ig, '-')
.replace(/\.{2,}/, '-')
.replace(/[ ._-]+$/ig, '')
.trim();
}
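
The loop keeps stripping quantization and file-format suffixes until the name stops changing. Hand-traced on illustrative names (not from the commit):

normalizeModel('Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf'); // "Meta-Llama-3.1-8B-Instruct"
normalizeModel('TheBloke/Mistral-7B-Instruct-v0.2-GPTQ'); // "Mistral-7B-Instruct-v0.2"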
export const applyChatTemplate = (messages: ITemplateMessage[], templateString: string) => {
const template = new Template(templateString);
console.log(`Applying template:\n${templateString}`, messages);
const prompt = template.render({
messages,
add_generation_prompt: true,
});
return prompt;
};
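
Given a ChatML-style template (illustrative; real templates come from the instruct presets or, after this commit, from the hub), the render produces the familiar turn markers:

const chatml =
  '{% for m in messages %}<|im_start|>{{ m.role }}\n{{ m.content }}<|im_end|>\n{% endfor %}' +
  '{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}';

applyChatTemplate([{ role: 'user', content: 'Hi' }], chatml);
// "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"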
export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
export const LLMContextProvider = ({ children }: { children?: any }) => {
const {
connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct,
setTriggerNext, addMessage, editMessage,
setTriggerNext, addMessage, editMessage, setInstruct,
} = useContext(StateContext);
const generating = useBool(false);
const blockConnection = useBool(false);
const [promptTokens, setPromptTokens] = useState(0);
const [contextLength, setContextLength] = useState(0);
const [modelName, setModelName] = useState('');
const [modelTemplate, setModelTemplate] = useState('');
const userPromptTemplate = useMemo(() => {
try {
@@ -72,17 +120,41 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
}
}, [userPrompt]);
const getContextLength = useCallback(async () => {
if (!connectionUrl || blockConnection.value) {
return 0;
}
try {
const response = await fetch(`${connectionUrl}/api/extra/true_max_context_length`);
if (response.ok) {
const { value } = await response.json();
return value;
}
} catch (e) {
console.log('Error getting max tokens', e);
}
return 0;
}, [connectionUrl, blockConnection.value]);
const getModelName = useCallback(async () => {
if (!connectionUrl || blockConnection.value) {
return '';
}
try {
const response = await fetch(`${connectionUrl}/api/v1/model`);
if (response.ok) {
const { result } = await response.json();
return result;
}
} catch (e) {
console.log('Error getting model name', e);
}
return '';
}, [connectionUrl, blockConnection.value]);
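
Both probes target KoboldCpp-style endpoints, as the URLs above show; stripped of the hook plumbing they reduce to (base URL illustrative):

const base = 'http://localhost:5001';

// Maximum context length the backend will actually serve.
const { value } = await fetch(`${base}/api/extra/true_max_context_length`)
  .then(r => r.json()); // e.g. { value: 8192 }

// Name of the currently loaded model.
const { result } = await fetch(`${base}/api/v1/model`)
  .then(r => r.json()); // e.g. { result: "koboldcpp/Meta-Llama-3.1-8B-Instruct" }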
const actions: IActions = useMemo(() => ({
applyChatTemplate: (messages: ITemplateMessage[], templateString: string) => {
const template = new Template(templateString);
const prompt = template.render({
messages,
add_generation_prompt: true,
});
return prompt;
},
compilePrompt: async (messages, { keepUsers } = {}) => {
const promptMessages = messages.slice();
const lastMessage = promptMessages.at(-1);
@@ -128,15 +200,6 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
templateMessages.push({ role, content });
}
}
if (templateMessages[1]?.role !== 'user') {
const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
templateMessages.splice(1, 0, {
role: 'user',
content: userPromptTemplate.render({ prompt, isStart: true }),
});
}
} else {
const story = promptMessages.filter(m => m.role === 'assistant')
.map(m => MessageTools.getSwipe(m)?.content.trim()).join('\n\n');
@@ -157,7 +220,16 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
}
}
const prompt = actions.applyChatTemplate(templateMessages, instruct);
if (templateMessages[1]?.role !== 'user') {
const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
templateMessages.splice(1, 0, {
role: 'user',
content: userPromptTemplate.render({ prompt, isStart: true }),
});
}
const prompt = applyChatTemplate(templateMessages, instruct);
return {
prompt,
isContinue,
@@ -248,22 +320,6 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
console.log('Error counting tokens', e);
}
return 0;
},
getContextLength: async () => {
if (!connectionUrl) {
return 0;
}
try {
const response = await fetch(`${connectionUrl}/api/extra/true_max_context_length`);
if (response.ok) {
const { value } = await response.json();
return value;
}
} catch (e) {
console.log('Error getting max tokens', e);
}
return 0;
},
}), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct]);
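
The body of countTokens is not visible in these hunks; a plausible sketch, assuming KoboldCpp's POST /api/extra/tokencount endpoint (an assumption, not confirmed by the diff):

countTokens: async (prompt: string) => {
  try {
    // Assumed endpoint; the visible hunks only show this method's catch block.
    const response = await fetch(`${connectionUrl}/api/extra/tokencount`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ prompt }),
    });
    if (response.ok) {
      const { value } = await response.json();
      return value;
    }
  } catch (e) {
    console.log('Error counting tokens', e);
  }
  return 0;
},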
@@ -276,6 +332,8 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
let text: string = '';
const { prompt, isRegen } = await actions.compilePrompt(messages);
const tokens = await actions.countTokens(prompt);
setPromptTokens(tokens);
if (!isRegen) {
addMessage('', 'assistant');
@@ -284,18 +342,66 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
for await (const chunk of actions.generate(prompt)) {
text += chunk;
setPromptTokens(tokens + 1);
editMessage(messageId, text);
}
text = MessageTools.trimSentence(text);
editMessage(messageId, text);
const generatedTokens = await actions.countTokens(text);
setPromptTokens(tokens + generatedTokens);
MessageTools.playReady();
}
})(), [triggerNext, messages, generating.value]);
})(), [actions, triggerNext, messages, generating.value]);
useEffect(() => {
if (!blockConnection.value) {
setPromptTokens(0);
setContextLength(0);
getContextLength().then(setContextLength);
}
}, [connectionUrl, instruct, blockConnection.value]);
useEffect(() => {
if (!blockConnection.value) {
setModelName('');
getModelName().then(normalizeModel).then(setModelName);
}
}, [connectionUrl, blockConnection.value]);
useEffect(() => {
setModelTemplate('');
if (modelName) {
Huggingface.findModelTemplate(modelName)
.then((template) => {
if (template) {
setModelTemplate(template);
setInstruct(template);
}
});
}
}, [modelName]);
useEffect(() => {
if (promptTokens === 0 && !blockConnection.value) {
actions.compilePrompt(messages)
.then(({ prompt }) => actions.countTokens(prompt))
.then(setPromptTokens)
.catch(e => console.error(`Could not count tokens`, e));
}
}, [actions, promptTokens, messages, blockConnection.value]);
const rawContext: IContext = {
generating: generating.value,
blockConnection,
modelName,
modelTemplate,
promptTokens,
contextLength,
};
const context = useMemo(() => rawContext, Object.values(rawContext));

View File

@@ -0,0 +1,68 @@
import * as hub from '@huggingface/hub';
export namespace Huggingface {
interface TokenizerConfig {
chat_template: string;
eos_token: string;
}
const hasField = <T extends string>(obj: unknown, field: T): obj is Record<T, unknown> => (
obj != null && typeof obj === 'object' && (field in obj)
);
const isTokenizerConfig = (obj: unknown): obj is TokenizerConfig => (
hasField(obj, 'chat_template') && typeof obj.chat_template === 'string'
&& hasField(obj, 'eos_token') && typeof obj.eos_token === 'string'
);
const loadHuggingfaceTokenizerConfig = async (model: string): Promise<TokenizerConfig | null> => {
console.log(`Searching for model '${model}'`);
const models = hub.listModels({ search: { query: model }, additionalFields: ['config'] });
let tokenizerConfig: TokenizerConfig | null = null;
for await (const model of models) {
const { config } = model;
if (hasField(config, 'tokenizer_config') && isTokenizerConfig(config.tokenizer_config)) {
tokenizerConfig = config.tokenizer_config;
break;
}
try {
const fileResponse = await hub.downloadFile({ repo: model.name, path: 'tokenizer_config.json' });
if (fileResponse?.ok) {
const maybeConfig = await fileResponse.json();
if (isTokenizerConfig(maybeConfig)) {
tokenizerConfig = maybeConfig;
break;
}
}
} catch { }
}
if (tokenizerConfig) {
console.log(`Huggingface config for '${model}' found.`);
return {
chat_template: tokenizerConfig.chat_template,
eos_token: tokenizerConfig.eos_token,
};
}
console.log(`Huggingface config for '${model}' not found.`);
return null;
};
export const findModelTemplate = async (modelName: string): Promise<string | null> => {
const config = await loadHuggingfaceTokenizerConfig(modelName);
if (config?.chat_template?.trim()) {
const template = config.chat_template.trim()
.replace('eos_token', `'${config.eos_token}'`)
.replace('bos_token', `''`);
return template;
}
return null;
}
}
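
Put together, the provider normalizes the backend's reported model name, looks up that model's chat template on the hub, and adopts it as the active instruct preset. A condensed sketch of the flow (names from this commit; the model string is illustrative):

const name = normalizeModel('Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf');
const template = await Huggingface.findModelTemplate(name);
if (template) {
  setInstruct(template); // the header's <select> now offers the model's own template
}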