AI story: tool calling check
This commit is contained in:
parent 5805469581
commit 25c3f5dc25
@@ -8,6 +8,7 @@
     "bake": "bun build/build.ts"
   },
   "dependencies": {
+    "@huggingface/gguf": "0.1.12",
     "@huggingface/hub": "0.19.0",
     "@huggingface/jinja": "0.3.1",
     "@inquirer/select": "2.3.10",

@@ -60,8 +60,8 @@ export const Header = () => {
         onBlur={handleBlurUrl}
         class={blockConnection.value ? '' : urlValid ? styles.valid : styles.invalid}
       />
-      <select value={instruct} onChange={setInstruct}>
-        {modelName && modelTemplate && <option value={modelTemplate}>{modelName}</option>}
+      <select value={instruct} onChange={setInstruct} title='Instruct template'>
+        {modelName && modelTemplate && <option value={modelTemplate} title='Native for model'>{modelName}</option>}
         {Object.entries(Instruct).map(([label, value]) => (
           <option value={value} key={value}>
             {label.toLowerCase()}

@@ -3,17 +3,11 @@ import SSE from "@common/sse";
 import { createContext } from "preact";
 import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
 import { MessageTools, type IMessage } from "../messages";
-import { StateContext } from "./state";
+import { Instruct, StateContext } from "./state";
 import { useBool } from "@common/hooks/useBool";
-import { Template } from "@huggingface/jinja";
+import { Huggingface } from "../huggingface";
 
-interface ITemplateMessage {
-  role: 'user' | 'assistant' | 'system';
-  content: string;
-}
-
-interface ICompileArgs {
-  keepUsers?: number;
-}
 
@@ -29,6 +23,7 @@ interface IContext {
   blockConnection: ReturnType<typeof useBool>;
   modelName: string;
   modelTemplate: string;
+  hasToolCalls: boolean;
   promptTokens: number;
   contextLength: number;
 }

@@ -65,7 +60,7 @@ export const normalizeModel = (model: string) => {
   normalizedModel = currentModel;
 
   currentModel = currentModel
-    .replace(/[ ._-]\d(\d*k|\d+)(-context|$)/i, '') // remove context length, i.e. -32k
+    .replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
     .replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name
     .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
    .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
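The four rules in this hunk can be traced on hypothetical file names (not from the repo); a minimal sketch of just this part of the chain, using the replacement regexes as committed:

// Sketch: only the four rules visible above, applied in order.
const strip = (name: string) => name
  .replace(/[ ._-]\d+(k$|-context)/i, '')                              // context length
  .replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '')                     // quant name
  .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '')              // gguf/ggml suffix
  .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, ''); // quant size

strip('CoolStory-Q4_K_M-32k'); // 'CoolStory' — '-32k' is stripped first, then '-Q4_K_M'
strip('my-model.gguf-v3');     // 'my-model'
strip('SomeModel-GPTQ');       // 'SomeModel'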
@@ -82,19 +77,6 @@ export const normalizeModel = (model: string) => {
     .trim();
 }
 
-export const applyChatTemplate = (messages: ITemplateMessage[], templateString: string) => {
-  const template = new Template(templateString);
-
-  console.log(`Applying template:\n${templateString}`, messages);
-
-  const prompt = template.render({
-    messages,
-    add_generation_prompt: true,
-  });
-
-  return prompt;
-};
-
 export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
 
 export const LLMContextProvider = ({ children }: { children?: any }) => {

@@ -109,6 +91,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
   const [contextLength, setContextLength] = useState(0);
   const [modelName, setModelName] = useState('');
   const [modelTemplate, setModelTemplate] = useState('');
+  const [hasToolCalls, setHasToolCalls] = useState(false);
 
   const userPromptTemplate = useMemo(() => {
     try {

@@ -162,17 +145,17 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
     const isRegen = isAssistantLast && !MessageTools.getSwipe(lastMessage)?.content;
     const isContinue = isAssistantLast && !isRegen;
 
-    const userMessages = promptMessages.filter(m => m.role === 'user');
-    const lastUserMessage = userMessages.at(-1);
-    const firstUserMessage = userMessages.at(0);
-
     if (isContinue) {
       promptMessages.push(MessageTools.create(userPromptTemplate.render({})));
     }
 
+    const userMessages = promptMessages.filter(m => m.role === 'user');
+    const lastUserMessage = userMessages.at(-1);
+    const firstUserMessage = userMessages.at(0);
+
     const system = `${systemPrompt}\n\n${lore}`.trim();
 
-    const templateMessages: ITemplateMessage[] = [
+    const templateMessages: Huggingface.ITemplateMessage[] = [
       { role: 'system', content: system },
     ];
 
@@ -229,7 +212,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
       });
     }
 
-    const prompt = applyChatTemplate(templateMessages, instruct);
+    const prompt = Huggingface.applyChatTemplate(instruct, templateMessages);
     return {
       prompt,
       isContinue,

@@ -342,16 +325,14 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
 
     for await (const chunk of actions.generate(prompt)) {
       text += chunk;
-      setPromptTokens(tokens + 1);
+      setPromptTokens(tokens + Math.round(text.length * 0.25));
       editMessage(messageId, text);
     }
 
     text = MessageTools.trimSentence(text);
     editMessage(messageId, text);
 
-    const generatedTokens = await actions.countTokens(text);
-
-    setPromptTokens(tokens + generatedTokens);
+    setPromptTokens(0); // trigger calculation
 
     MessageTools.playReady();
   }
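The streaming update above replaces a per-chunk increment with a character-based estimate; a sketch of the heuristic (the ~4 characters per English token figure is a common rule of thumb, not something the commit states):

// Roughly 4 characters per token, hence 0.25 tokens per character.
// Cheap enough to run on every streamed chunk; the final setPromptTokens(0)
// then forces a real recount elsewhere.
const estimateTokens = (text: string): number => Math.round(text.length * 0.25);

estimateTokens('Once upon a time'); // 4 — close to the true token count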
@@ -381,6 +362,8 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
       if (template) {
         setModelTemplate(template);
         setInstruct(template);
+      } else {
+        setInstruct(Instruct.CHATML);
       }
     });
   }

@@ -395,11 +378,21 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
     }
   }, [actions, promptTokens, messages, blockConnection.value]);
 
+  useEffect(() => {
+    try {
+      const hasTools = Huggingface.testToolCalls(instruct);
+      setHasToolCalls(hasTools);
+    } catch {
+      setHasToolCalls(false);
+    }
+  }, [instruct]);
+
   const rawContext: IContext = {
     generating: generating.value,
     blockConnection,
     modelName,
     modelTemplate,
+    hasToolCalls,
     promptTokens,
     contextLength,
   };

@@ -38,12 +38,12 @@ interface IActions {
 const SAVE_KEY = 'ai_game_save_state';
 
 export enum Instruct {
-  CHATML = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n\n' }}{% endif %}`,
-
   LLAMA = `{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}`,
 
   MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + '</s>'}}{%- endif %}{%- endfor %}`,
 
+  CHATML = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n\n' }}{% endif %}`,
+
   ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\n\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\n\n' + message['content'] + '\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\n\n' + message['content'] + '\n\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\n\n' }}{% endif %}`,
 };
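Each enum value is a Jinja chat template; rendering one through @huggingface/jinja (as the rest of the commit does) shows the wire format. A minimal sketch with made-up messages:

import { Template } from '@huggingface/jinja';

const prompt = new Template(Instruct.CHATML).render({
  messages: [
    { role: 'system', content: 'You are a narrator.' },
    { role: 'user', content: 'Begin the story.' },
  ],
  add_generation_prompt: true,
});
// prompt ===
// '<|im_start|>system\n\nYou are a narrator.<|im_end|>\n' +
// '<|im_start|>user\n\nBegin the story.<|im_end|>\n' +
// '<|im_start|>assistant\n\n'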
@@ -58,7 +58,7 @@ export const loadContext = (): IContext => {
 const defaultContext: IContext = {
   connectionUrl: 'http://localhost:5001',
   input: '',
-  instruct: Instruct.LLAMA,
+  instruct: Instruct.CHATML,
   systemPrompt: 'You are creative writer. Write a story based on the world description below.',
   lore: '',
   userPrompt: `{% if prompt %}{% if isStart %}Start{% else %}Continue{% endif %} this story, taking information into account: {{ prompt | trim }}

@@ -1,27 +1,104 @@
+import { gguf } from '@huggingface/gguf';
 import * as hub from '@huggingface/hub';
+import { Template } from '@huggingface/jinja';
 
 export namespace Huggingface {
+  export interface ITemplateMessage {
+    role: 'user' | 'assistant' | 'system';
+    content: string;
+  }
+
+  interface INumberParameter {
+    type: 'number';
+    enum?: number[];
+    description?: string;
+  }
+
+  interface IStringParameter {
+    type: 'string';
+    enum?: string[];
+    description?: string;
+  }
+
+  interface IArrayParameter {
+    type: 'array';
+    description?: string;
+    items: IParameter;
+  }
+
+  interface IObjectParameter {
+    type: 'object';
+    description?: string;
+    properties: Record<string, IParameter>;
+    required?: string[];
+  }
+
+  type IParameter = INumberParameter | IStringParameter | IArrayParameter | IObjectParameter;
+
+  interface ITool {
+    type: 'function',
+    function: {
+      name: string;
+      description?: string;
+      parameters?: IObjectParameter;
+    }
+  }
+
+  export interface IFunction {
+    name: string;
+    description?: string;
+    parameters?: Record<string, IParameter>;
+  }
+
   interface TokenizerConfig {
     chat_template: string;
-    eos_token: string;
+    bos_token?: string;
+    eos_token?: string;
   }
 
+  const TEMPLATE_CACHE_KEY = 'ai_game_template_cache';
+
+  const loadCache = (): Record<string, string> => {
+    const json = localStorage.getItem(TEMPLATE_CACHE_KEY);
+
+    try {
+      if (json) {
+        const cache = JSON.parse(json);
+        if (cache && typeof cache === 'object') {
+          return cache
+        }
+      }
+    } catch { }
+
+    return {};
+  };
+
+  const saveCache = (cache: Record<string, string>) => {
+    const json = JSON.stringify(cache);
+    localStorage.setItem(TEMPLATE_CACHE_KEY, json);
+  };
+
+  const templateCache: Record<string, string> = loadCache();
+
   const hasField = <T extends string>(obj: unknown, field: T): obj is Record<T, unknown> => (
     obj != null && typeof obj === 'object' && (field in obj)
   );
   const isTokenizerConfig = (obj: unknown): obj is TokenizerConfig => (
     hasField(obj, 'chat_template') && typeof obj.chat_template === 'string'
-    && hasField(obj, 'eos_token') && typeof obj.eos_token === 'string'
+    && (!hasField(obj, 'eos_token') || !obj.eos_token || typeof obj.eos_token === 'string')
+    && (!hasField(obj, 'bos_token') || !obj.bos_token || typeof obj.bos_token === 'string')
   );
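hasField is a generic type guard that narrows `unknown` one property at a time, which is what lets isTokenizerConfig validate parsed JSON without casts. The same pattern in isolation:

// Sketch: each hasField call proves one property exists, so the following
// typeof check is type-safe on the narrowed value.
const parsed: unknown = JSON.parse('{"chat_template": "{% ... %}"}');

if (hasField(parsed, 'chat_template') && typeof parsed.chat_template === 'string') {
  parsed.chat_template.trim(); // typed as string here, no `as` needed
}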
 
   const loadHuggingfaceTokenizerConfig = async (model: string): Promise<TokenizerConfig | null> => {
-    console.log(`Searching for model '${model}'`);
+    console.log(`[huggingface] searching config for '${model}'`);
 
     const models = hub.listModels({ search: { query: model }, additionalFields: ['config'] });
+    const recheckModels: hub.ModelEntry[] = [];
 
     let tokenizerConfig: TokenizerConfig | null = null;
 
     for await (const model of models) {
+      recheckModels.push(model);
       const { config } = model;
       if (hasField(config, 'tokenizer_config') && isTokenizerConfig(config.tokenizer_config)) {
         tokenizerConfig = config.tokenizer_config;

@@ -29,6 +106,7 @@ export namespace Huggingface {
       }
 
       try {
+        console.log(`[huggingface] searching config in '${model.name}/tokenizer_config.json'`);
        const fileResponse = await hub.downloadFile({ repo: model.name, path: 'tokenizer_config.json' });
         if (fileResponse?.ok) {
           const maybeConfig = await fileResponse.json();

@@ -38,31 +116,149 @@
           }
         }
       } catch { }
     }
+    if (!tokenizerConfig) {
+      for (const model of recheckModels.slice(0, 10)) {
+        try {
+          for await (const file of hub.listFiles({ repo: model.name, recursive: true })) {
+            if (file.type !== 'file' || !file.path.endsWith('.gguf')) continue;
+            try {
+              console.log(`[huggingface] searching config in '${model.name}/${file.path}'`);
+              const fileInfo = await hub.fileDownloadInfo({ repo: model.name, path: file.path });
+              if (fileInfo?.downloadLink) {
+                const { metadata } = await gguf(fileInfo.downloadLink);
+                if ('tokenizer.chat_template' in metadata) {
+                  const chat_template = metadata['tokenizer.chat_template'];
+                  const tokens = metadata['tokenizer.ggml.tokens'];
+                  const bos_token = tokens[metadata['tokenizer.ggml.bos_token_id']];
+                  const eos_token = tokens[metadata['tokenizer.ggml.eos_token_id']];
+
+                  const maybeConfig = {
+                    chat_template,
+                    bos_token,
+                    eos_token,
+                  }
+
+                  if (isTokenizerConfig(maybeConfig)) {
+                    tokenizerConfig = maybeConfig;
+                    break;
+                  }
+                } else if ('tokenizer.ggml.model' in metadata) {
+                  break; // no reason to touch different quants
+                }
+              }
+            } catch { }
+          }
+        } catch { }
+
+        if (tokenizerConfig) {
+          break;
+        }
+      }
+    }
 
     if (tokenizerConfig) {
-      console.log(`Huggingface config for '${model}' found.`);
+      console.log(`[huggingface] found config for '${model}'`);
       return {
         chat_template: tokenizerConfig.chat_template,
         eos_token: tokenizerConfig.eos_token,
         bos_token: tokenizerConfig.bos_token,
       };
     }
 
-    console.log(`Huggingface config for '${model}' not found.`);
+    console.log(`[huggingface] not found config for '${model}'`);
     return null;
   };
 
-  export const findModelTemplate = async (modelName: string): Promise<string | null> => {
-    const config = await loadHuggingfaceTokenizerConfig(modelName);
-
-    if (config?.chat_template?.trim()) {
-      const template = config.chat_template.trim()
-        .replace('eos_token', `'${config.eos_token}'`)
-        .replace('bos_token', `''`);
-
-      return template;
-    }
-
-    return null;
-  }
+  function updateRequired<T extends IParameter>(param: T): T {
+    if ('items' in param) {
+      updateRequired(param.items);
+    } else if ('properties' in param) {
+      for (const prop of Object.values(param.properties)) {
+        updateRequired(prop);
+      }
+      param.required = Object.keys(param.properties);
+    }
+
+    return param;
+  }
+
+  const convertFunctionToTool = (fn: IFunction): ITool => ({
+    type: 'function',
+    function: {
+      name: fn.name,
+      description: fn.description,
+      parameters: updateRequired({
+        type: 'object',
+        properties: fn.parameters ?? {},
+      })
+    }
+  })
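For a concrete picture, converting a minimal IFunction (a made-up one, not from the commit) yields the OpenAI-style tool shape, with updateRequired marking every declared property as required:

const tool = convertFunctionToTool({
  name: 'roll_dice',
  parameters: { sides: { type: 'number' } },
});
// {
//   type: 'function',
//   function: {
//     name: 'roll_dice',
//     description: undefined,
//     parameters: {
//       type: 'object',
//       properties: { sides: { type: 'number' } },
//       required: ['sides'],
//     },
//   },
// }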
+
+  export const testToolCalls = (template: string): boolean => {
+    const history: ITemplateMessage[] = [
+      { role: 'system', content: 'You are calculator.' },
+      { role: 'user', content: 'Calculate 2 + 2.' },
+    ];
+
+    const needle = '___AWOORWA_NEEDLE__';
+
+    const tools: IFunction[] = [{
+      name: 'add',
+      description: 'Test function',
+      parameters: {
+        a: { type: 'number' },
+        b: { type: 'number' },
+        c: { type: 'array', items: { type: 'number' } },
+        d: { type: 'object', properties: { inside: { type: 'number', description: needle } } },
+      }
+    }];
+
+    const text = applyChatTemplate(template, history, tools);
+
+    console.log(text);
+
+    return text.includes(needle);
+  }
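The check works by planting a unique needle string inside a dummy tool's nested parameter description and rendering the template with it: a template that actually serializes `tools` will emit the needle somewhere in its output, while one that ignores the variable will not. For example:

// Plain ChatML never references `tools`, so the needle cannot surface:
Huggingface.testToolCalls(Instruct.CHATML); // false
// A template with a `{% for tool in tools %}` section would print the dummy
// function's schema — needle included — and return true. Templates that throw
// on the tools variable are handled by the caller's try/catch (see the
// useEffect hunk above).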
+
+  export const findModelTemplate = async (modelName: string): Promise<string | null> => {
+    let template = templateCache[modelName] ?? null;
+
+    if (template) {
+      console.log(`[huggingface] found cached template for '${modelName}'`);
+    } else {
+      const config = await loadHuggingfaceTokenizerConfig(modelName);
+
+      if (config?.chat_template?.trim()) {
+        template = config.chat_template.trim()
+          .replaceAll('eos_token', `'${config.eos_token ?? ''}'`)
+          .replaceAll('bos_token', `''`);
+
+        if (config.bos_token) {
+          template = template
+            .replaceAll(config.bos_token, '')
+            .replace(/\{\{ ?(''|"") ?\}\}/g, '');
+        }
+      }
+    }
+
+    templateCache[modelName] = template;
+    saveCache(templateCache);
+
+    return template;
+  }
+
+  export const applyChatTemplate = (templateString: string, messages: ITemplateMessage[], functions?: IFunction[]) => {
+    const template = new Template(templateString);
+
+    const prompt = template.render({
+      messages,
+      add_generation_prompt: true,
+      tools: functions?.map(convertFunctionToTool),
+    });
+
+    return prompt;
+  };
 }
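Taken together, the new namespace supports a flow like the following sketch (hypothetical model name; Instruct comes from the state module shown earlier):

const template = await Huggingface.findModelTemplate('some-model-7b')
  ?? Instruct.CHATML; // the provider falls back to ChatML the same way

const prompt = Huggingface.applyChatTemplate(template, [
  { role: 'system', content: 'You are a storyteller.' },
  { role: 'user', content: 'Continue the tale.' },
]);

// tools are only worth passing when the template can render them:
const supportsTools = Huggingface.testToolCalls(template);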