// tsgames/src/games/ai-story/tools/huggingface.ts
import { gguf } from '@huggingface/gguf';
import * as hub from '@huggingface/hub';
import { Template } from '@huggingface/jinja';
import { Tokenizer } from '@huggingface/tokenizers';
import { normalizeModel } from './model';
import { loadObject, saveObject } from './storage';
// Discovery and rendering of chat templates / tokenizers for models hosted on
// the Hugging Face Hub. Lookup results are memoized in-memory and persisted
// via the storage helpers.
export namespace Huggingface {
/** One chat turn as consumed by a jinja chat template. */
export interface ITemplateMessage {
role: 'user' | 'assistant' | 'system';
content: string;
}
// JSON-Schema-like parameter descriptions (OpenAI-style function/tool schema).
interface INumberParameter {
type: 'number';
enum?: number[];
description?: string;
}
interface IStringParameter {
type: 'string';
enum?: string[];
description?: string;
}
interface IArrayParameter {
type: 'array';
description?: string;
/** Schema of each array element. */
items: IParameter;
}
interface IObjectParameter {
type: 'object';
description?: string;
properties: Record<string, IParameter>;
/** Property names that must be present; filled in by updateRequired(). */
required?: string[];
}
type IParameter = INumberParameter | IStringParameter | IArrayParameter | IObjectParameter;
/** OpenAI-compatible tool wrapper around a single function schema. */
interface ITool {
type: 'function',
function: {
name: string;
description?: string;
parameters?: IObjectParameter;
}
}
/** Public, flatter function description; converted to ITool before templating. */
export interface IFunction {
name: string;
description?: string;
parameters?: Record<string, IParameter>;
}
/** Minimal subset of a HF tokenizer_config.json needed for templating. */
interface TokenizerConfig {
chat_template: string;
bos_token?: string;
eos_token?: string;
}
// Raw parsed tokenizer.json; opaque here, passed straight to the Tokenizer
// constructor.
type TokenizerJson = any;
/** Pair of [tokenizer_config, tokenizer.json]; either half may be null. */
type TokenizerInfo = [TokenizerConfig | null, TokenizerJson | null];
// Persistent-storage keys for the two caches below.
const TEMPLATE_CACHE_KEY = 'ai_game_template_cache';
const TOKENIZER_CACHE_KEY = 'ai_game_tokenizer_cache';
// model name -> chat template source (findModelTemplate also stores null for
// failed lookups — NOTE(review): that disagrees with the declared value type).
const templateCache: Record<string, string> = {};
// model name -> [config, tokenizer.json] as found on the Hub.
const tokenizerCache: Record<string, TokenizerInfo> = {};
// Hydrates both caches from persistent storage; must be awaited before any
// cache read (loadHuggingfaceTokenizer does so).
const prevLoading: Promise<unknown> = Promise.all([
loadObject(TEMPLATE_CACHE_KEY, {}).then(c => Object.assign(templateCache, c)),
loadObject(TOKENIZER_CACHE_KEY, {}).then(c => Object.assign(tokenizerCache, c)),
]);
// Memoized compiled jinja templates (keyed by template source) and tokenizers
// (keyed by normalized model name).
const compiledTemplates = new Map<string, Template>();
const compiledTokenizers = new Map<string, Tokenizer | null>();
/**
 * Narrowing guard: true when `obj` is a non-null object carrying `field`
 * (own or inherited, via the `in` operator).
 */
const hasField = <T extends string>(obj: unknown, field: T): obj is Record<T, unknown> => {
  if (obj == null || typeof obj !== 'object') return false;
  return field in obj;
};
/**
 * Narrowing guard for a tokenizer config: requires a string `chat_template`;
 * `bos_token` / `eos_token` may each be absent, falsy, or a string.
 */
const isTokenizerConfig = (obj: unknown): obj is TokenizerConfig => {
  if (!hasField(obj, 'chat_template') || typeof obj.chat_template !== 'string') return false;
  const tokenOk = (value: unknown): boolean => !value || typeof value === 'string';
  const eosOk = !hasField(obj, 'eos_token') || tokenOk(obj.eos_token);
  const bosOk = !hasField(obj, 'bos_token') || tokenOk(obj.bos_token);
  return eosOk && bosOk;
};
/**
 * Finds tokenizer metadata for `modelName` on the Hugging Face Hub.
 *
 * Search order:
 *  1. the cache (only complete [config, json] pairs count as hits);
 *  2. non-gated repos whose normalized name contains `modelName`, sorted by
 *     downloads, skipping GGUF-named repos: tokenizer.json (unless
 *     configOnly), then the hub-provided config.tokenizer_config, then
 *     tokenizer_config.json (with chat_template.jinja as template fallback);
 *  3. GGUF file metadata as a last resort (chat template + bos/eos tokens).
 *
 * @param modelName raw model name; normalized before searching.
 * @param configOnly skip downloading tokenizer.json and skip writing the cache.
 * @returns [config, tokenizerJson]; either element may be null.
 */
const loadHuggingfaceTokenizer = async (modelName: string, configOnly = false): Promise<TokenizerInfo> => {
// Make sure the persisted cache has been hydrated before reading it.
await prevLoading;
modelName = normalizeModel(modelName);
console.log(`[huggingface] searching config for '${modelName}'`);
const cachedConfig = tokenizerCache[modelName];
// Only a complete pair is a cache hit; partial entries trigger a re-query.
if (cachedConfig && cachedConfig[0] != null && cachedConfig[1] != null) {
console.log(`[huggingface] found cached config for '${modelName}'`);
return cachedConfig;
}
const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName }, additionalFields: ['config'] }));
// Keep accessible repos that actually match the normalized name, most
// downloaded first.
const models = hubModels.filter(m => {
if (m.gated) return false;
if (!normalizeModel(m.name).includes(modelName)) return false;
return true;
}).sort((a, b) => b.downloads - a.downloads);
let tokenizerConfig: TokenizerConfig | null = null;
let tokenizerJson: TokenizerJson | null = null;
for (const model of models) {
const { config, name } = model;
// GGUF repos rarely ship tokenizer files; they are handled separately below.
if (name.toLowerCase().includes('gguf')) continue;
if (!tokenizerJson && !configOnly) {
try {
console.log(`[huggingface] searching tokenizer in '${name}/tokenizer.json'`);
const fileResponse = await hub.downloadFile({
repo: name,
path: 'tokenizer.json',
});
if (fileResponse) {
tokenizerJson = JSON.parse(await fileResponse.text());
console.log(`[huggingface] found tokenizer in '${name}/tokenizer.json'`);
}
} catch { }
}
// Cheapest config source: the tokenizer_config embedded in the hub listing.
if (!tokenizerConfig) {
if (hasField(config, 'tokenizer_config') && isTokenizerConfig(config.tokenizer_config)) {
tokenizerConfig = config.tokenizer_config;
console.log(`[huggingface] found config for '${modelName}' in '${name}'`);
}
}
// Otherwise download tokenizer_config.json from the repo.
if (!tokenizerConfig) {
try {
console.log(`[huggingface] searching config in '${name}/tokenizer_config.json'`);
const fileResponse = await hub.downloadFile({
repo: name,
path: 'tokenizer_config.json',
});
if (fileResponse) {
const maybeConfig = JSON.parse(await fileResponse.text());
// Some repos keep the template in a separate chat_template.jinja file.
if (!hasField(maybeConfig, 'chat_template') || !maybeConfig.chat_template) {
console.log(`[huggingface] searching template in '${name}/chat_template.jinja'`);
const templateResponse = await hub.downloadFile({
repo: name,
path: 'chat_template.jinja',
}).catch(() => null);
if (templateResponse) {
const template = await templateResponse.text().catch(() => null);
if (template) {
maybeConfig.chat_template = template;
}
}
}
if (isTokenizerConfig(maybeConfig)) {
tokenizerConfig = maybeConfig;
console.log(`[huggingface] found config for '${modelName}' in '${name}/tokenizer_config.json'`);
break; // NOTE(review): this can exit before tokenizer.json was found when configOnly is false — confirm intended
}
}
} catch { }
}
}
// Last resort: pull the chat template and special tokens out of GGUF metadata.
if (!tokenizerConfig) {
for (const model of models) {
try {
for await (const file of hub.listFiles({ repo: model.name, recursive: true })) {
if (file.type !== 'file' || !file.path.endsWith('.gguf')) continue;
try {
console.log(`[huggingface] searching config in '${model.name}/${file.path}'`);
const fileInfo = await hub.fileDownloadInfo({ repo: model.name, path: file.path });
if (fileInfo?.url) {
// gguf() reads only the metadata header, not the whole weights file.
const { metadata } = await gguf(fileInfo.url);
if ('tokenizer.chat_template' in metadata) {
const chat_template = metadata['tokenizer.chat_template'];
const tokens = metadata['tokenizer.ggml.tokens'];
// Resolve bos/eos token strings from their vocabulary ids.
const bos_token = tokens[metadata['tokenizer.ggml.bos_token_id']];
const eos_token = tokens[metadata['tokenizer.ggml.eos_token_id']];
const maybeConfig = {
chat_template,
bos_token,
eos_token,
}
if (isTokenizerConfig(maybeConfig)) {
tokenizerConfig = maybeConfig;
console.log(`[huggingface] found config for '${modelName}' in '${model.name}/${file.path}'`);
break;
}
} else if ('tokenizer.ggml.model' in metadata) {
break; // no reason to touch different quants
}
}
} catch { }
}
} catch { }
if (tokenizerConfig) {
break;
}
}
}
if (tokenizerConfig) {
// Normalize/pretty-print the template before caching it.
if (tokenizerConfig.chat_template) {
tokenizerConfig.chat_template = formatTemplate(tokenizerConfig.chat_template, tokenizerConfig);
}
// NOTE(review): tokenizerJson may still be null here, so a partial pair can
// be persisted; such entries never satisfy the cache check above.
const info: TokenizerInfo = [tokenizerConfig, tokenizerJson];
if (!configOnly) {
tokenizerCache[modelName] = info;
saveObject(TOKENIZER_CACHE_KEY, tokenizerCache);
}
return info;
}
console.log(`[huggingface] not found config for '${modelName}'`);
return [null, null];
};
/**
 * Recursively walks a parameter schema and, for every object node, marks all
 * of its properties as required (OpenAI-style tool schemas expect this).
 * Mutates `param` in place and returns it for chaining.
 */
function updateRequired<T extends IParameter>(param: T): T {
  if ('items' in param) {
    // Array schema: only the element schema needs processing.
    updateRequired(param.items);
    return param;
  }
  if ('properties' in param) {
    // Object schema: recurse into each child, then require every key.
    const childNames: string[] = [];
    for (const [key, child] of Object.entries(param.properties)) {
      updateRequired(child);
      childNames.push(key);
    }
    param.required = childNames;
  }
  return param;
}
/**
 * Wraps a flat IFunction description into the OpenAI-style ITool shape,
 * packing its parameters into an object schema with every field required.
 */
const convertFunctionToTool = (fn: IFunction): ITool => {
  const parameters = updateRequired({
    type: 'object',
    properties: fn.parameters ?? {},
  });
  return {
    type: 'function',
    function: {
      name: fn.name,
      description: fn.description,
      parameters,
    },
  };
};
/**
 * Probes whether a chat template actually renders tool definitions: renders a
 * tiny conversation with a dummy tool whose nested description carries a
 * unique needle, and reports whether the needle survives into the output.
 */
export const testToolCalls = (template: string): boolean => {
  const needle = '___AWOORWA_NEEDLE__';
  const probeTools: IFunction[] = [{
    name: 'add',
    description: 'Test function',
    parameters: {
      a: { type: 'number' },
      b: { type: 'number' },
      // Nested schemas verify that deep tool descriptions are emitted too.
      c: { type: 'array', items: { type: 'number' } },
      d: { type: 'object', properties: { inside: { type: 'number', description: needle } } },
    },
  }];
  const probeHistory: ITemplateMessage[] = [
    { role: 'system', content: 'You are calculator.' },
    { role: 'user', content: 'Calculate 2 + 2.' },
  ];
  const rendered = applyChatTemplate(template, probeHistory, probeTools);
  return rendered.includes(needle);
};
/**
 * Iteratively simplifies a jinja chat template: neutralizes raise_exception,
 * strips comments and whitespace-trim markers, folds string concatenations,
 * drops conditionals whose bodies became empty, and substitutes the
 * bos/eos special tokens with literals. Repeats until a fixed point.
 *
 * @param input  jinja template source.
 * @param config optional tokenizer config supplying the eos token literal.
 */
export const minifyTemplate = (input: string, config?: TokenizerConfig) => {
  const eosLiteral = config?.eos_token ?? '';
  // Bug fix: '$' is a metacharacter in .replace()/.replaceAll() replacement
  // strings ('$$' emits one literal '$'). The original `replace('$', '$$')`
  // was a no-op (its own replacement '$$' collapses back to '$') and only
  // looked at the first occurrence, so an eos token containing '$1'/'$&'
  // corrupted the output. Escape every '$' properly, once, outside the loop.
  const eosEscaped = eosLiteral.replace(/\$/g, '$$$$');
  let minified = input;
  do {
    input = minified;
    minified = input
      .replace(/raise_exception\(('[^')]+'|"[^")]+")\)/g, `''`)
      // bos_token is dropped entirely; eos_token is inlined as a literal.
      .replace(/(['"])\s*\+\s*bos_token/gi, `$1`)
      .replace(/bos_token\s*\+\s*(['"])/gi, `$1`)
      .replace(/(['"])\s*\+\s*eos_token/gi, `${eosEscaped}$1`)
      .replace(/eos_token\s*\+\s*(['"])/gi, `$1${eosEscaped}`)
      // Strip jinja comments and whitespace-trim markers.
      .replace(/\{#-?[^#]+-?#}/gi, '')
      .replace(/\s*(\{[{%])-/gi, '$1')
      .replace(/-([}%]\})\s*/gi, '$1')
      // Remove empty output blocks and merge adjacent ones (the merge rule was
      // missing /g; the fixed point is the same but fewer passes are needed).
      .replace(/\{\{\s*(''|"")\s*\}\}/g, '')
      .replace(/\s*\}\}\{\{\s*/g, ' + ')
      .replace(/\n+['"]/g, (match) => match.replace(/\n/gi, '\\n'))
      .replace(/'\s*\+\s*'/g, '')
      .replace(/"\s*\+\s*"/g, '')
      // Collapse conditionals whose bodies became empty.
      .replace(/\{%\s*else\s*%\}\{%\s*endif\s*%\}/gi, '{% endif %}')
      .replace(/\{%\s*elif[^}]+%\}\{%\s*endif\s*%\}/gi, '{% endif %}')
      .replace(/\{%\s*if[^}]+%\}\{%\s*endif\s*%\}/gi, '')
      .replaceAll('bos_token', `''`)
      // Escaped here too: replaceAll replacement strings interpret '$' as well.
      .replaceAll('eos_token', `'${eosEscaped}'`);
  } while (minified !== input);
  return minified;
}
/**
 * Pretty-prints a jinja chat template: minifies it first, then splits it into
 * `{{- … -}}` / `{%- … -%}` blocks (injecting whitespace-trim markers into
 * every delimiter) and indents them by if/for nesting. Because every block
 * gains trim markers, the cosmetic newlines and indentation added here do not
 * change what the template renders.
 */
export const formatTemplate = (input: string, config?: TokenizerConfig) => {
  const minified = minifyTemplate(input, config);
  // Scanner states: plain text, just after '{', inside a block, after a
  // possible closing delimiter, inside a quoted string. (The original union
  // also listed an unused 'escaped' member; the boolean below serves that.)
  type ParserState = 'none' | 'open_brace' | 'block' | 'block_end' | 'quote';
  let state: ParserState = 'none';
  let currentBlock = '';
  let blockStart = ''; // '{' for {{ … }}, '%' for {% … %}
  let quoteStart = ''; // quote character that opened the current string
  let escaped = false; // true right after a backslash inside a string
  const blocks: string[] = [];
  for (const ch of minified) {
    currentBlock += ch;
    if (state === 'none') {
      if (ch === '{') {
        state = 'open_brace';
      }
    } else if (state === 'open_brace') {
      if (ch === '{' || ch === '%') {
        blockStart = ch;
        state = 'block';
        currentBlock += '-'; // force a leading whitespace-trim marker
      } else {
        state = 'none';
      }
    } else if (state === 'block') {
      if (ch === '"' || ch === "'") {
        quoteStart = ch;
        state = 'quote';
      } else if (ch === blockStart || blockStart === '{' && ch === '}') {
        // Possible start of the closing delimiter; inject the trailing '-'.
        currentBlock = currentBlock.slice(0, -1) + '-' + ch;
        state = 'block_end';
      }
    } else if (state === 'block_end') {
      if (ch === '}') {
        state = 'none';
        blocks.push(currentBlock);
        currentBlock = '';
      } else {
        state = 'block';
      }
    } else if (state === 'quote') {
      if (!escaped && ch === quoteStart) {
        state = 'block';
      } else if (!escaped && ch === '\\') {
        escaped = true;
      } else {
        escaped = false;
      }
    }
  }
  if (currentBlock) {
    blocks.push(currentBlock); // trailing text or unterminated block
  }
  // Indent blocks by control-flow nesting. Bug fix: the indent unit is now
  // two spaces to match the two-character dedent below — the original grew
  // the indent by a single space, so one `end` after two nested if/for
  // openers collapsed both levels at once.
  let indent = '';
  for (let i = 0; i < blocks.length; i++) {
    const line = blocks[i];
    const content = line.slice(3).trim(); // drop the '{{-' / '{%-' prefix
    if (content.startsWith('if ') || content.startsWith('for ')) {
      blocks[i] = indent + line;
      indent += '  ';
    } else if (content.startsWith('else ') || content.startsWith('elif ')) {
      indent = indent.slice(2);
      blocks[i] = indent + line;
      indent += '  ';
    } else if (content.startsWith("end")) {
      indent = indent.slice(2);
      blocks[i] = indent + line;
    } else {
      blocks[i] = indent + line;
    }
  }
  return blocks.filter(b => b.trim()).join('\n');
}
export const findModelTemplate = async (modelName: string): Promise<string | null> => {
modelName = normalizeModel(modelName);
if (!modelName) return '';
let template = templateCache[modelName] ?? null;
if (template) {
console.log(`[huggingface] found cached template for '${modelName}'`);
} else {
const [config] = await loadHuggingfaceTokenizer(modelName, true);
if (config?.chat_template?.trim()) {
template = config.chat_template;
}
}
templateCache[modelName] = template;
saveObject(TEMPLATE_CACHE_KEY, templateCache);
return template;
}
/**
 * Returns a compiled Tokenizer for the model, memoized per normalized name.
 * Requires both the tokenizer config and tokenizer.json to be found on the
 * Hub; otherwise resolves to null (and the failure is not memoized).
 */
export const findTokenizer = async (modelName: string): Promise<Tokenizer | null> => {
  modelName = normalizeModel(modelName);
  if (!modelName) return null;
  const cached = compiledTokenizers.get(modelName);
  if (cached) return cached;
  const [tokenizerConfig, tokenizerJson] = await loadHuggingfaceTokenizer(modelName);
  if (!tokenizerConfig || !tokenizerJson) return null;
  const tokenizer = new Tokenizer(tokenizerJson, tokenizerConfig);
  compiledTokenizers.set(modelName, tokenizer);
  return tokenizer;
}
/**
 * Renders a chat template with the given conversation and optional tool
 * definitions, always requesting a generation prompt and disabling thinking.
 */
export const applyChatTemplate = (templateString: string, messages: ITemplateMessage[], functions?: IFunction[]) => {
  const tools = functions?.map(convertFunctionToTool);
  return applyTemplate(templateString, {
    messages,
    add_generation_prompt: true,
    enable_thinking: false,
    tools,
  });
};
/**
 * Compiles (with memoization keyed by template source) and renders a jinja
 * template. Rendering errors are logged and swallowed; the function then
 * returns an empty string as a best-effort fallback.
 */
export const applyTemplate = (templateString: string, args: Record<string, any>): string => {
  try {
    const cached = compiledTemplates.get(templateString);
    const template = cached ?? new Template(templateString);
    if (!cached) {
      compiledTemplates.set(templateString, template);
    }
    return template.render(args);
  } catch (e) {
    console.error('[applyTemplate] error:', e);
    return '';
  }
};
}