1
0
Fork 0
tsgames/src/games/ai-story/tools/huggingface.ts

328 lines
12 KiB
TypeScript

import { gguf } from '@huggingface/gguf';
import * as hub from '@huggingface/hub';
import { Template } from '@huggingface/jinja';
import { AutoTokenizer, PreTrainedTokenizer } from '@huggingface/transformers';
import { normalizeModel } from './model';
import { loadObject, saveObject } from './storage';
export namespace Huggingface {
  /** A single chat turn passed to a Jinja chat template. */
  export interface ITemplateMessage {
    role: 'user' | 'assistant' | 'system';
    content: string;
  }

  // JSON-schema-like parameter descriptions used to build tool definitions.
  interface INumberParameter {
    type: 'number';
    enum?: number[];
    description?: string;
  }
  interface IStringParameter {
    type: 'string';
    enum?: string[];
    description?: string;
  }
  interface IArrayParameter {
    type: 'array';
    description?: string;
    items: IParameter;
  }
  interface IObjectParameter {
    type: 'object';
    description?: string;
    properties: Record<string, IParameter>;
    required?: string[];
  }
  type IParameter = INumberParameter | IStringParameter | IArrayParameter | IObjectParameter;

  /** Tool definition in the `{ type: 'function', function: {...} }` shape passed to templates as `tools`. */
  interface ITool {
    type: 'function';
    function: {
      name: string;
      description?: string;
      parameters?: IObjectParameter;
    };
  }

  /** A function offered to the model; `parameters` maps argument names to their schemas. */
  export interface IFunction {
    name: string;
    description?: string;
    parameters?: Record<string, IParameter>;
  }

  /** Normalized subset of a tokenizer config used by this module. */
  interface TokenizerConfig {
    chat_template: string;
    bos_token?: string;
    eos_token?: string;
  }

  /** A config together with a description of where it was found (used only for logging). */
  interface FoundConfig {
    config: TokenizerConfig;
    source: string;
  }

  const TEMPLATE_CACHE_KEY = 'ai_game_template_cache';
  /** Persisted map model name -> sanitized template; `null` marks a lookup that found nothing. */
  const templateCache: Record<string, string | null> = {};
  /**
   * Resolves once the persisted cache has been merged into `templateCache`.
   * `findModelTemplate` awaits it so a lookup that races the initial load neither
   * misses cached entries nor saves a partial cache over the stored one.
   */
  const templateCacheLoaded: Promise<void> = (async () => {
    try {
      Object.assign(templateCache, await loadObject(TEMPLATE_CACHE_KEY, {}));
    } catch (e) {
      console.error('[huggingface] failed to load template cache:', e);
    }
  })();
  /** Parsed templates keyed by their source text, so each template is parsed once. */
  const compiledTemplates = new Map<string, Template>();
  /** Session cache model name -> tokenizer; `null` marks a search that found nothing. */
  const tokenizerCache = new Map<string, PreTrainedTokenizer | null>();

  const hasField = <T extends string>(obj: unknown, field: T): obj is Record<T, unknown> => (
    obj != null && typeof obj === 'object' && (field in obj)
  );

  /**
   * Reads a special token stored either as a plain string or as a serialized
   * `AddedToken`-style object (`{ content: string, ... }`); anything else yields `undefined`.
   */
  const readToken = (value: unknown): string | undefined => {
    if (typeof value === 'string') return value;
    if (hasField(value, 'content') && typeof value.content === 'string') return value.content;
    return undefined;
  };

  /**
   * Reads `chat_template`, which is either a template string or a list of
   * named templates (`[{ name, template }, ...]`). For the list form the entry
   * named 'default' is preferred, otherwise the first entry with a string template.
   */
  const readChatTemplate = (value: unknown): string | undefined => {
    if (typeof value === 'string') return value;
    if (!Array.isArray(value)) return undefined;
    const entries: unknown[] = value;
    let fallback: string | undefined;
    for (const entry of entries) {
      if (!hasField(entry, 'template')) continue;
      const { template } = entry;
      if (typeof template !== 'string') continue;
      if (hasField(entry, 'name') && entry.name === 'default') return template;
      if (fallback === undefined) fallback = template;
    }
    return fallback;
  };

  /**
   * Validates and normalizes a raw tokenizer config object.
   * @returns the normalized config, or `null` when it has no non-blank chat template.
   */
  const toTokenizerConfig = (obj: unknown): TokenizerConfig | null => {
    if (!hasField(obj, 'chat_template')) return null;
    const chat_template = readChatTemplate(obj.chat_template);
    if (!chat_template?.trim()) return null;
    return {
      chat_template,
      bos_token: hasField(obj, 'bos_token') ? readToken(obj.bos_token) : undefined,
      eos_token: hasField(obj, 'eos_token') ? readToken(obj.eos_token) : undefined,
    };
  };

  /**
   * First search pass: for each non-`-gguf` repo, try the `tokenizer_config` embedded
   * in the hub listing, then the repo's `tokenizer_config.json` file.
   */
  const searchRepoConfigs = async (
    models: readonly { name: string; config?: unknown }[],
  ): Promise<FoundConfig | null> => {
    for (const { name, config } of models) {
      if (name.toLowerCase().endsWith('-gguf')) continue;
      if (hasField(config, 'tokenizer_config')) {
        const found = toTokenizerConfig(config.tokenizer_config);
        if (found) return { config: found, source: name };
      }
      try {
        console.log(`[huggingface] searching config in '${name}/tokenizer_config.json'`);
        const fileResponse = await hub.downloadFile({
          repo: name,
          path: 'tokenizer_config.json',
        });
        if (fileResponse) {
          const found = toTokenizerConfig(JSON.parse(await fileResponse.text()));
          if (found) return { config: found, source: `${name}/tokenizer_config.json` };
        }
      } catch {
        // Best effort: missing file, network failure or invalid JSON - try the next repo.
      }
    }
    return null;
  };

  /**
   * Second search pass: read the chat template and BOS/EOS tokens from the
   * metadata of `.gguf` files in each repo.
   */
  const searchGgufConfigs = async (models: readonly { name: string }[]): Promise<FoundConfig | null> => {
    for (const { name } of models) {
      try {
        for await (const file of hub.listFiles({ repo: name, recursive: true })) {
          if (file.type !== 'file' || !file.path.endsWith('.gguf')) continue;
          try {
            console.log(`[huggingface] searching config in '${name}/${file.path}'`);
            const fileInfo = await hub.fileDownloadInfo({ repo: name, path: file.path });
            if (!fileInfo?.url) continue;
            const { metadata } = await gguf(fileInfo.url);
            if ('tokenizer.chat_template' in metadata) {
              const tokens = metadata['tokenizer.ggml.tokens'];
              const found = toTokenizerConfig({
                chat_template: metadata['tokenizer.chat_template'],
                bos_token: tokens[metadata['tokenizer.ggml.bos_token_id']],
                eos_token: tokens[metadata['tokenizer.ggml.eos_token_id']],
              });
              if (found) return { config: found, source: `${name}/${file.path}` };
            } else if ('tokenizer.ggml.model' in metadata) {
              // Tokenizer metadata without a chat template: other quants of this
              // repo are assumed to lack one as well, so skip the rest of the repo.
              break;
            }
          } catch {
            // Best effort: unreadable file or metadata - try the next file.
          }
        }
      } catch {
        // Best effort: file listing failed - try the next repo.
      }
    }
    return null;
  };

  /**
   * Searches the Hugging Face hub for a tokenizer config with a chat template for `modelName`.
   * Candidates are non-gated repos whose normalized name contains the normalized model
   * name, tried in order of descending downloads.
   * @returns the config, or `null` when none was found.
   */
  const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise<TokenizerConfig | null> => {
    modelName = normalizeModel(modelName);
    console.log(`[huggingface] searching config for '${modelName}'`);
    const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName }, additionalFields: ['config'] }));
    const models = hubModels
      .filter(m => !m.gated && normalizeModel(m.name).includes(modelName))
      .sort((a, b) => b.downloads - a.downloads);
    const found = (await searchRepoConfigs(models)) ?? (await searchGgufConfigs(models));
    if (found) {
      console.log(`[huggingface] found config for '${modelName}' in '${found.source}'`);
      return found.config;
    }
    console.log(`[huggingface] not found config for '${modelName}'`);
    return null;
  };

  /**
   * Returns a copy of `param` in which every nested object schema lists all of its
   * properties as `required`. The input is not modified, so callers' function
   * definitions stay untouched.
   */
  function markAllRequired(param: IParameter): IParameter {
    switch (param.type) {
      case 'array':
        return { ...param, items: markAllRequired(param.items) };
      case 'object':
        return requireAllProperties(param);
      default:
        return param;
    }
  }

  /** Object-schema case of `markAllRequired`: copies properties and marks all of them required. */
  function requireAllProperties(param: IObjectParameter): IObjectParameter {
    const properties = Object.fromEntries(
      Object.entries(param.properties).map(([key, prop]): [string, IParameter] => [key, markAllRequired(prop)]),
    );
    return { ...param, properties, required: Object.keys(properties) };
  }

  /** Wraps an `IFunction` into the tool shape templates expect, with all parameters required. */
  const convertFunctionToTool = (fn: IFunction): ITool => ({
    type: 'function',
    function: {
      name: fn.name,
      description: fn.description,
      parameters: requireAllProperties({
        type: 'object',
        properties: fn.parameters ?? {},
      }),
    },
  });

  /**
   * Checks whether `template` renders tool definitions: renders a sample chat with one tool
   * whose nested parameter description is a unique needle, and looks for it in the output.
   */
  export const testToolCalls = (template: string): boolean => {
    const history: ITemplateMessage[] = [
      { role: 'system', content: 'You are calculator.' },
      { role: 'user', content: 'Calculate 2 + 2.' },
    ];
    const needle = '___AWOORWA_NEEDLE__';
    const tools: IFunction[] = [{
      name: 'add',
      description: 'Test function',
      parameters: {
        a: { type: 'number' },
        b: { type: 'number' },
        c: { type: 'array', items: { type: 'number' } },
        d: { type: 'object', properties: { inside: { type: 'number', description: needle } } },
      },
    }];
    const text = applyChatTemplate(template, history, tools);
    return text.includes(needle);
  };

  /**
   * Heuristically rewrites a raw chat template before use:
   * - `raise_exception(...)` calls become empty strings;
   * - `eos_token` is inlined as a quoted literal of the config's EOS token and
   *   `bos_token` as an empty string;
   * - empty `{{ '' }}` outputs are dropped;
   * - a literal newline before a quote is turned into an escaped `\n`;
   * - adjacent string-literal concatenations are merged;
   * - empty `{% else %}`/`{% elif %}` branches and empty `{% if %}` blocks are removed.
   */
  const sanitizeTemplate = (chatTemplate: string, eosToken: string | undefined): string => (
    chatTemplate.trim()
      .replace(/raise_exception\(('[^')]+'|"[^")]+")\)/g, `''`)
      .replaceAll('eos_token', `'${eosToken ?? ''}'`)
      .replaceAll('bos_token', `''`)
      .replace(/\{\{ ?(''|"") ?\}\}/g, '')
      .replace(/\n'/g, `\\n'`)
      .replace(/\n"/g, `\\n"`)
      .replace(/'\s*\+\s*'/g, '')
      .replace(/"\s*\+\s*"/g, '')
      .replace(/\{%\s*else\s*%\}\{%\s*endif\s*%\}/gi, '{% endif %}')
      .replace(/\{%\s*elif[^}]+%\}\{%\s*endif\s*%\}/gi, '{% endif %}')
      .replace(/\{%\s*if[^}]+%\}\{%\s*endif\s*%\}/gi, '')
  );

  /** Writes `templateCache` to storage without blocking the caller; asynchronous failures are logged. */
  const persistTemplateCache = (): void => {
    Promise.resolve(saveObject(TEMPLATE_CACHE_KEY, templateCache))
      .catch(e => console.error('[huggingface] failed to save template cache:', e));
  };

  /**
   * Finds (and caches) a sanitized chat template for `modelName`.
   * A cached non-empty template is returned directly. Otherwise the hub is searched,
   * including when the cache only records a previous miss.
   * @returns the template, `null` when none was found, or `''` for an empty model name.
   */
  export const findModelTemplate = async (modelName: string): Promise<string | null> => {
    modelName = normalizeModel(modelName);
    if (!modelName) return '';
    await templateCacheLoaded;
    const cached = templateCache[modelName] ?? null;
    if (cached) {
      console.log(`[huggingface] found cached template for '${modelName}'`);
      return cached;
    }
    const config = await loadHuggingfaceTokenizerConfig(modelName);
    const template = config?.chat_template?.trim()
      ? sanitizeTemplate(config.chat_template, config.eos_token)
      : null;
    templateCache[modelName] = template;
    persistTemplateCache();
    return template;
  };

  /**
   * Finds a tokenizer for `modelName` by trying `AutoTokenizer.from_pretrained` on matching
   * non-gated, non-GGUF hub repos in order of descending downloads. Both hits and misses
   * are cached for the session, so a model without a loadable tokenizer is searched only once.
   * @returns the tokenizer, or `null` when none could be loaded or the name is empty.
   */
  export const findTokenizer = async (modelName: string): Promise<PreTrainedTokenizer | null> => {
    modelName = normalizeModel(modelName);
    if (!modelName) return null;
    if (tokenizerCache.has(modelName)) {
      return tokenizerCache.get(modelName) ?? null;
    }
    console.log(`[huggingface] searching tokenizer for '${modelName}'`);
    const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName } }));
    const models = hubModels
      .filter(m => (
        !m.gated
        && !m.name.toLowerCase().includes('gguf')
        && normalizeModel(m.name).includes(modelName)
      ))
      .sort((a, b) => b.downloads - a.downloads);
    let tokenizer: PreTrainedTokenizer | null = null;
    let foundName = '';
    for (const { name } of models) {
      try {
        console.log(`[huggingface] searching tokenizer in '${name}'`);
        tokenizer = await AutoTokenizer.from_pretrained(name);
        foundName = name;
        break;
      } catch {
        // No loadable tokenizer in this repo - try the next candidate.
      }
    }
    tokenizerCache.set(modelName, tokenizer);
    if (tokenizer) {
      console.log(`[huggingface] found tokenizer for '${modelName}' in '${foundName}'`);
    } else {
      console.log(`[huggingface] not found tokenizer for '${modelName}'`);
    }
    return tokenizer;
  };

  /**
   * Renders a chat template with `messages`, `add_generation_prompt: true` and, when given,
   * `functions` converted to tool definitions. Returns `''` if rendering fails.
   */
  export const applyChatTemplate = (templateString: string, messages: ITemplateMessage[], functions?: IFunction[]) => (
    applyTemplate(templateString, {
      messages,
      add_generation_prompt: true,
      tools: functions?.map(convertFunctionToTool),
    })
  );

  /**
   * Renders a Jinja template with `args`, reusing a previously parsed `Template` for the
   * same source text. Parse/render errors are logged and yield `''`.
   */
  export const applyTemplate = (templateString: string, args: Record<string, unknown>): string => {
    try {
      let template = compiledTemplates.get(templateString);
      if (!template) {
        template = new Template(templateString);
        compiledTemplates.set(templateString, template);
      }
      return template.render(args);
    } catch (e) {
      console.error('[applyTemplate] error:', e);
    }
    return '';
  };
}