import { gguf } from '@huggingface/gguf';
import * as hub from '@huggingface/hub';
import { Template } from '@huggingface/jinja';
import { AutoTokenizer, PreTrainedTokenizer } from '@huggingface/transformers';
import { normalizeModel } from './model';
import { loadObject, saveObject } from './storage';

/**
 * Hugging Face Hub helpers: locate a model's chat template and tokenizer on the
 * Hub, and render Jinja chat templates locally.
 */
export namespace Huggingface {
  /** One chat message as passed to a chat template. */
  export interface ITemplateMessage {
    role: 'user' | 'assistant' | 'system';
    content: string;
  }

  interface INumberParameter {
    type: 'number';
    enum?: number[];
    description?: string;
  }

  interface IStringParameter {
    type: 'string';
    enum?: string[];
    description?: string;
  }

  interface IArrayParameter {
    type: 'array';
    description?: string;
    items: IParameter;
  }

  interface IObjectParameter {
    type: 'object';
    description?: string;
    properties: Record<string, IParameter>;
    required?: string[];
  }

  /** JSON-Schema-like description of a single tool/function parameter. */
  type IParameter = INumberParameter | IStringParameter | IArrayParameter | IObjectParameter;

  /** Tool definition in the `{ type: 'function', function: {...} }` shape handed to templates as `tools`. */
  interface ITool {
    type: 'function';
    function: {
      name: string;
      description?: string;
      parameters?: IObjectParameter;
    };
  }

  /** Caller-facing function description; `parameters` maps each argument name to its schema. */
  export interface IFunction {
    name: string;
    description?: string;
    parameters?: Record<string, IParameter>;
  }

  /** The subset of a tokenizer config this module uses. */
  interface TokenizerConfig {
    chat_template: string;
    bos_token?: string;
    eos_token?: string;
  }

  /** A tokenizer config plus the place it was found (used only for logging). */
  interface FoundTokenizerConfig {
    config: TokenizerConfig;
    source: string;
  }

  const TEMPLATE_CACHE_KEY = 'ai_game_template_cache';

  /**
   * Normalized model name -> cleaned chat template, persisted via `saveObject`.
   * A `null` (or empty) entry records that the last lookup found nothing; such
   * entries are treated as misses and searched again on the next call.
   */
  const templateCache: Record<string, string | null> = {};

  /**
   * Resolves once the persisted template cache has been merged into
   * `templateCache`. `findModelTemplate` awaits it so an early call can neither
   * persist a near-empty cache over the stored one nor have its fresh result
   * overwritten when the stored entries are merged in afterwards. A load failure
   * is logged and the module continues with an empty cache.
   */
  const templateCacheLoaded: Promise<void> = loadObject(TEMPLATE_CACHE_KEY, {})
    .then(c => { Object.assign(templateCache, c); })
    .catch((e: unknown) => { console.error('[huggingface] failed to load template cache:', e); });

  /** Compiled Jinja templates keyed by their source text. */
  const compiledTemplates = new Map<string, Template>();

  /** Normalized model name -> tokenizer, or `null` when none could be loaded this session. */
  const tokenizerCache = new Map<string, PreTrainedTokenizer | null>();

  /** Type guard: `obj` is a non-null object that has own or inherited property `field`. */
  const hasField = <T extends string>(obj: unknown, field: T): obj is Record<T, unknown> => (
    obj != null && typeof obj === 'object' && (field in obj)
  );

  /** Extracts token text from a plain string or an AddedToken-style `{ content: string }` object. */
  const tokenText = (token: unknown): string | undefined => {
    if (typeof token === 'string') return token;
    if (hasField(token, 'content') && typeof token.content === 'string') return token.content;
    return undefined;
  };

  /**
   * Validates an untrusted tokenizer config (Hub API field, downloaded JSON, or
   * values read from GGUF metadata) and returns its normalized subset.
   *
   * @returns `null` when there is no string `chat_template`. `bos_token` and
   *   `eos_token` may be strings or AddedToken objects; any other shape is
   *   treated as absent rather than rejecting the whole config.
   */
  const parseTokenizerConfig = (obj: unknown): TokenizerConfig | null => {
    const chatTemplate = hasField(obj, 'chat_template') ? obj.chat_template : undefined;
    if (typeof chatTemplate !== 'string') return null;
    return {
      chat_template: chatTemplate,
      bos_token: hasField(obj, 'bos_token') ? tokenText(obj.bos_token) : undefined,
      eos_token: hasField(obj, 'eos_token') ? tokenText(obj.eos_token) : undefined,
    };
  };

  /**
   * First pass: for each non-GGUF repo, use the Hub-provided
   * `config.tokenizer_config`, otherwise download `tokenizer_config.json`.
   * Returns the first usable config, or `null`.
   */
  const findConfigInJson = async (
    models: ReadonlyArray<{ name: string; config?: unknown }>,
  ): Promise<FoundTokenizerConfig | null> => {
    for (const { name, config } of models) {
      // GGUF repos are covered by the fallback pass.
      if (name.toLowerCase().endsWith('-gguf')) continue;

      const inline = hasField(config, 'tokenizer_config') ? parseTokenizerConfig(config.tokenizer_config) : null;
      if (inline) return { config: inline, source: name };

      try {
        console.log(`[huggingface] searching config in '${name}/tokenizer_config.json'`);
        const fileResponse = await hub.downloadFile({
          repo: name,
          path: 'tokenizer_config.json',
        });
        if (fileResponse) {
          const parsed = parseTokenizerConfig(JSON.parse(await fileResponse.text()));
          if (parsed) return { config: parsed, source: `${name}/tokenizer_config.json` };
        }
      } catch {
        // Best-effort: missing file, network error or malformed JSON -> try the next repo.
      }
    }
    return null;
  };

  /**
   * Fallback pass: scan each repo's `.gguf` files and read the chat template
   * and BOS/EOS token text from the GGUF metadata. Returns the first usable
   * config, or `null`.
   */
  const findConfigInGguf = async (
    models: ReadonlyArray<{ name: string }>,
  ): Promise<FoundTokenizerConfig | null> => {
    for (const { name } of models) {
      try {
        for await (const file of hub.listFiles({ repo: name, recursive: true })) {
          if (file.type !== 'file' || !file.path.endsWith('.gguf')) continue;
          try {
            console.log(`[huggingface] searching config in '${name}/${file.path}'`);
            const fileInfo = await hub.fileDownloadInfo({ repo: name, path: file.path });
            if (!fileInfo?.url) continue;

            const { metadata } = await gguf(fileInfo.url);
            if ('tokenizer.chat_template' in metadata) {
              const chat_template = metadata['tokenizer.chat_template'];
              const tokens = metadata['tokenizer.ggml.tokens'];
              const bos_token = tokens[metadata['tokenizer.ggml.bos_token_id']];
              const eos_token = tokens[metadata['tokenizer.ggml.eos_token_id']];
              const parsed = parseTokenizerConfig({ chat_template, bos_token, eos_token });
              if (parsed) return { config: parsed, source: `${name}/${file.path}` };
            } else if ('tokenizer.ggml.model' in metadata) {
              // Tokenizer metadata without a chat template: other quants of this
              // repo are not expected to differ, so move on to the next repo.
              break;
            }
          } catch {
            // Best-effort: unreadable file or metadata -> try the next file.
          }
        }
      } catch {
        // Best-effort: listing this repo failed -> try the next repo.
      }
    }
    return null;
  };

  /**
   * Searches the Hub for non-gated repos whose normalized name contains the
   * normalized `modelName`, most-downloaded first, and returns the first chat
   * template found (JSON configs first, then GGUF metadata).
   *
   * @returns the config, or `null` when nothing usable was found. A failure of
   *   the initial Hub search itself propagates to the caller.
   */
  const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise<TokenizerConfig | null> => {
    modelName = normalizeModel(modelName);
    console.log(`[huggingface] searching config for '${modelName}'`);

    const hubModels = await Array.fromAsync(hub.listModels({
      search: { query: modelName },
      additionalFields: ['config'],
    }));
    const models = hubModels
      .filter(m => !m.gated && normalizeModel(m.name).includes(modelName))
      .sort((a, b) => b.downloads - a.downloads);

    const found = (await findConfigInJson(models)) ?? (await findConfigInGguf(models));
    if (found) {
      console.log(`[huggingface] found config for '${modelName}' in '${found.source}'`);
      return found.config;
    }
    console.log(`[huggingface] not found config for '${modelName}'`);
    return null;
  };

  /**
   * Recursively marks every property of every object parameter as required.
   * Mutates `param` in place and returns it; callers must pass a private copy.
   */
  function updateRequired<T extends IParameter>(param: T): T {
    if ('items' in param) {
      updateRequired(param.items);
    } else if ('properties' in param) {
      for (const prop of Object.values(param.properties)) {
        updateRequired(prop);
      }
      param.required = Object.keys(param.properties);
    }
    return param;
  }

  /**
   * Converts a caller-facing function description into the tool shape used by
   * chat templates. The parameters are deep-copied first so `updateRequired`
   * does not leak `required` arrays into the caller's definitions.
   */
  const convertFunctionToTool = (fn: IFunction): ITool => ({
    type: 'function',
    function: {
      name: fn.name,
      description: fn.description,
      parameters: updateRequired({
        type: 'object',
        properties: structuredClone(fn.parameters ?? {}),
      }),
    },
  });

  /**
   * Renders `template` with a sample conversation and one tool, and reports
   * whether a description nested inside an object parameter of that tool
   * appears in the output, i.e. whether the template renders tool definitions.
   */
  export const testToolCalls = (template: string): boolean => {
    const history: ITemplateMessage[] = [
      { role: 'system', content: 'You are calculator.' },
      { role: 'user', content: 'Calculate 2 + 2.' },
    ];
    const needle = '___AWOORWA_NEEDLE__';
    const tools: IFunction[] = [{
      name: 'add',
      description: 'Test function',
      parameters: {
        a: { type: 'number' },
        b: { type: 'number' },
        c: { type: 'array', items: { type: 'number' } },
        d: { type: 'object', properties: { inside: { type: 'number', description: needle } } },
      },
    }];
    const text = applyChatTemplate(template, history, tools);
    return text.includes(needle);
  };

  /**
   * Rewrites a raw Hub chat template into the form stored in the template cache.
   *
   * @param chatTemplate raw `chat_template` text.
   * @param eosToken EOS token text to inline as a string literal (`''` if absent).
   */
  const sanitizeTemplate = (chatTemplate: string, eosToken?: string): string => {
    const eosLiteral = `'${eosToken ?? ''}'`;
    return chatTemplate.trim()
      // Replace raise_exception('...') calls with an empty string literal.
      // NOTE(review): presumably because no raise_exception is supplied at render time.
      .replace(/raise_exception\(('[^')]+'|"[^")]+")\)/g, `''`)
      // Inline the EOS token as a literal; whole-word match so identifiers such as
      // `eos_token_id` or `add_eos_token` are left intact. A replacer function
      // avoids `$`-pattern interpretation of the token text.
      .replace(/\beos_token\b/g, () => eosLiteral)
      // Render BOS as an empty string. NOTE(review): presumably BOS is added elsewhere — confirm.
      .replace(/\bbos_token\b/g, `''`)
      // Drop `{{ '' }}` / `{{ "" }}` output expressions (e.g. left by the BOS replacement).
      .replace(/\{\{ ?(''|"") ?\}\}/g, '')
      // Escape a raw newline that directly precedes a closing quote.
      .replace(/\n'/g, `\\n'`)
      .replace(/\n"/g, `\\n"`)
      // Merge adjacent concatenated string literals: 'a' + 'b' -> 'ab'.
      .replace(/'\s*\+\s*'/g, '')
      .replace(/"\s*\+\s*"/g, '')
      // Collapse empty else / elif / if branches (e.g. left by the raise_exception removal).
      .replace(/\{%\s*else\s*%\}\{%\s*endif\s*%\}/gi, '{% endif %}')
      .replace(/\{%\s*elif[^}]+%\}\{%\s*endif\s*%\}/gi, '{% endif %}')
      .replace(/\{%\s*if[^}]+%\}\{%\s*endif\s*%\}/gi, '');
  };

  /**
   * Returns the cleaned chat template for `modelName`, from the persisted cache
   * or by searching the Hub.
   *
   * @returns `''` for an empty model name; the template string; or `null` when
   *   no template was found (recorded in the cache and retried next time).
   */
  export const findModelTemplate = async (modelName: string): Promise<string | null> => {
    modelName = normalizeModel(modelName);
    if (!modelName) return '';
    await templateCacheLoaded;

    let template = templateCache[modelName] ?? null;
    if (template) {
      console.log(`[huggingface] found cached template for '${modelName}'`);
    } else {
      const config = await loadHuggingfaceTokenizerConfig(modelName);
      if (config?.chat_template?.trim()) {
        template = sanitizeTemplate(config.chat_template, config.eos_token);
      }
    }

    // Persist only when the entry changed; cache hits need no write.
    if (templateCache[modelName] !== template) {
      templateCache[modelName] = template;
      saveObject(TEMPLATE_CACHE_KEY, templateCache);
    }
    return template;
  };

  /**
   * Loads a tokenizer for `modelName` from the first loadable non-gated,
   * non-GGUF Hub repo whose normalized name contains it (most-downloaded first).
   * Both hits and misses are cached in memory for the session, so a missing
   * tokenizer does not trigger a new Hub search on every call.
   *
   * @returns the tokenizer, or `null` for an empty name / when none loads.
   */
  export const findTokenizer = async (modelName: string): Promise<PreTrainedTokenizer | null> => {
    modelName = normalizeModel(modelName);
    if (!modelName) return null;
    if (tokenizerCache.has(modelName)) {
      return tokenizerCache.get(modelName) ?? null;
    }

    console.log(`[huggingface] searching tokenizer for '${modelName}'`);
    const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName } }));
    const models = hubModels
      .filter(m => (
        !m.gated
        && !m.name.toLowerCase().includes('gguf')
        && normalizeModel(m.name).includes(modelName)
      ))
      .sort((a, b) => b.downloads - a.downloads);

    let tokenizer: PreTrainedTokenizer | null = null;
    let foundName = '';
    for (const { name } of models) {
      try {
        console.log(`[huggingface] searching tokenizer in '${name}'`);
        tokenizer = await AutoTokenizer.from_pretrained(name);
        foundName = name;
        break;
      } catch {
        // Best-effort: this repo has no loadable tokenizer -> try the next one.
      }
    }

    tokenizerCache.set(modelName, tokenizer);
    if (tokenizer) {
      console.log(`[huggingface] found tokenizer for '${modelName}' in '${foundName}'`);
    } else {
      console.log(`[huggingface] not found tokenizer for '${modelName}'`);
    }
    return tokenizer;
  };

  /**
   * Renders a chat template for `messages` with `add_generation_prompt: true`,
   * passing `functions` (if any) as `tools`.
   *
   * @returns the rendered prompt, or `''` if compiling/rendering fails.
   */
  export const applyChatTemplate = (templateString: string, messages: ITemplateMessage[], functions?: IFunction[]): string => (
    applyTemplate(templateString, {
      messages,
      add_generation_prompt: true,
      tools: functions?.map(convertFunctionToTool),
    })
  );

  /**
   * Compiles (memoized by source text) and renders a Jinja template.
   *
   * @returns the rendered text, or `''` after logging if compiling or rendering throws.
   */
  export const applyTemplate = (templateString: string, args: Record<string, unknown>): string => {
    try {
      let template = compiledTemplates.get(templateString);
      if (!template) {
        template = new Template(templateString);
        compiledTemplates.set(templateString, template);
      }
      return template.render(args);
    } catch (e) {
      console.error('[applyTemplate] error:', e);
    }
    return '';
  };
}