AI story: improve hf model search
This commit is contained in:
parent 25c3f5dc25
commit b95506a095
@@ -17,10 +17,6 @@
     .inputs {
         display: flex;
         flex-direction: row;
-
-        select {
-            text-transform: capitalize;
-        }
     }

     .info {
@@ -61,12 +61,19 @@ export const Header = () => {
                 class={blockConnection.value ? '' : urlValid ? styles.valid : styles.invalid}
             />
             <select value={instruct} onChange={setInstruct} title='Instruct template'>
-                {modelName && modelTemplate && <option value={modelTemplate} title='Native for model'>{modelName}</option>}
-                {Object.entries(Instruct).map(([label, value]) => (
-                    <option value={value} key={value}>
-                        {label.toLowerCase()}
-                    </option>
-                ))}
+                {modelName && modelTemplate && <optgroup label='Native model template'>
+                    <option value={modelTemplate} title='Native for model'>{modelName}</option>
+                </optgroup>}
+                <optgroup label='Manual templates'>
+                    {Object.entries(Instruct).map(([label, value]) => (
+                        <option value={value} key={value}>
+                            {label.toLowerCase()}
+                        </option>
+                    ))}
+                </optgroup>
+                <optgroup label='Custom'>
+                    <option value={instruct}>Custom</option>
+                </optgroup>
             </select>
             <div class={styles.info}>
                 {promptTokens} / {contextLength}
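Grouping the options into `<optgroup>`s separates the model's native template, the built-in presets, and the current custom value; the trailing `Custom` entry keeps the controlled `<select>` from snapping back to the first option when `instruct` matches no preset. A minimal sketch of that pattern, assuming a Preact setup like the surrounding code (the names `presets` and `TemplatePicker` are hypothetical):

```tsx
import { h } from 'preact';

// Sketch of a controlled <select> with a catch-all option. When `value` equals
// a preset, the browser selects the first matching <option> (the preset); when
// it matches none, the Custom entry carries the value so the select stays in sync.
const presets = ['chatml', 'llama'];

export const TemplatePicker = ({ value, onChange }: { value: string; onChange: (v: string) => void }) => (
    <select value={value} onChange={e => onChange(e.currentTarget.value)}>
        <optgroup label='Presets'>
            {presets.map(p => <option value={p} key={p}>{p}</option>)}
        </optgroup>
        <optgroup label='Custom'>
            <option value={value}>Custom</option>
        </optgroup>
    </select>
);
```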
@@ -102,6 +109,9 @@ export const Header = () => {
                 <h4 class={styles.modalTitle}>User prompt template</h4>
                 <Ace value={userPrompt} onInput={setUserPrompt} />
                 <hr />
+                <h4 class={styles.modalTitle}>Instruct template</h4>
+                <Ace value={instruct} onInput={setInstruct} />
+                <hr />
                 <h4 class={styles.modalTitle}>Banned phrases</h4>
                 <AutoTextarea
                     placeholder="Each phrase on separate line"
@@ -38,13 +38,13 @@ interface IActions {
 const SAVE_KEY = 'ai_game_save_state';

 export enum Instruct {
-    CHATML = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n\n' }}{% endif %}`,
+    CHATML = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n\\n' }}{% endif %}`,

-    LLAMA = `{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}`,
+    LLAMA = `{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' + message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}{% endif %}`,

-    MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + '</s>'}}{%- endif %}{%- endfor %}`,
+    MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + '</s>'}}{%- endif %}{%- endfor %}`,

-    ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\n\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\n\n' + message['content'] + '\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\n\n' + message['content'] + '\n\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\n\n' }}{% endif %}`,
+    ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\\n\\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\\n\\n' + message['content'] + '\\n\\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\\n\\n' + message['content'] + '\\n\\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\\n\\n' }}{% endif %}`,
 };

 export const saveContext = (context: IContext) => {
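The only change in these templates is `\n` becoming `\\n`. Inside a JS template literal, `\n` is consumed by JavaScript and turns into a real newline character, while `\\n` leaves a literal backslash-n that the Jinja engine expands itself; presumably this also makes the enum strings match the raw `chat_template` text fetched from the hub, which the template comparison below relies on. A quick TypeScript illustration of the difference:

```ts
// The escape is consumed at a different layer depending on how it is written.
const raw = `{{ '<|im_start|>assistant\n\n' }}`;       // JS replaces \n with a newline char
const escaped = `{{ '<|im_start|>assistant\\n\\n' }}`; // JS keeps the two chars '\' + 'n'

console.log(raw.includes('\\n'));     // false: the Jinja source contains real newlines
console.log(escaped.includes('\\n')); // true: Jinja sees the '\n' escape and expands it
```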
@@ -84,22 +84,29 @@ export namespace Huggingface {
         obj != null && typeof obj === 'object' && (field in obj)
     );
     const isTokenizerConfig = (obj: unknown): obj is TokenizerConfig => (
-        hasField(obj, 'chat_template') && typeof obj.chat_template === 'string'
+        hasField(obj, 'chat_template') && (typeof obj.chat_template === 'string')
         && (!hasField(obj, 'eos_token') || !obj.eos_token || typeof obj.eos_token === 'string')
         && (!hasField(obj, 'bos_token') || !obj.bos_token || typeof obj.bos_token === 'string')
     );

-    const loadHuggingfaceTokenizerConfig = async (model: string): Promise<TokenizerConfig | null> => {
-        console.log(`[huggingface] searching config for '${model}'`);
+    const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise<TokenizerConfig | null> => {
+        console.log(`[huggingface] searching config for '${modelName}'`);

-        const models = hub.listModels({ search: { query: model }, additionalFields: ['config'] });
-        const recheckModels: hub.ModelEntry[] = [];
+        const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName }, additionalFields: ['config'] }));
+        const models = hubModels.filter(m => {
+            if (m.gated) return false;
+            if (!m.name.toLowerCase().includes(modelName.toLowerCase())) return false;
+
+            return true;
+        }).sort((a, b) => b.downloads - a.downloads);

         let tokenizerConfig: TokenizerConfig | null = null;

-        for await (const model of models) {
-            recheckModels.push(model);
-            const { config } = model;
+        for (const model of models) {
+            const { config, name } = model;
+
+            if (name.toLowerCase().endsWith('-gguf')) continue;
+
             if (hasField(config, 'tokenizer_config') && isTokenizerConfig(config.tokenizer_config)) {
                 tokenizerConfig = config.tokenizer_config;
                 break;
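`hub.listModels` yields results as an async iterable, so the old code could only stream it with `for await`. `Array.fromAsync` materializes the whole result set first, which is what allows dropping gated repos, discarding fuzzy search hits whose names don't actually contain the query, and sorting by downloads before any per-model work. A self-contained sketch of the pattern, with a fake generator standing in for `hub.listModels` (top-level `await` assumes an ES module):

```ts
// Collect an async iterable into an array, then filter and sort it.
// Array.fromAsync needs a recent runtime (Node 22+ / modern browsers);
// older ones can build the array with a plain for-await loop instead.
async function* fakeListModels() {
    yield { name: 'a/llama-3-8b', downloads: 5, gated: false };
    yield { name: 'b/other-model', downloads: 9, gated: false };
    yield { name: 'c/llama-3-70b', downloads: 7, gated: true };
}

const all = await Array.fromAsync(fakeListModels());
const candidates = all
    .filter(m => !m.gated && m.name.toLowerCase().includes('llama-3'))
    .sort((a, b) => b.downloads - a.downloads);

console.log(candidates.map(m => m.name)); // ['a/llama-3-8b']
```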
@@ -116,10 +123,10 @@ export namespace Huggingface {
                     }
                 }
             } catch { }

         }

         if (!tokenizerConfig) {
-            for (const model of recheckModels.slice(0, 10)) {
+            for (const model of models) {
                 try {
                     for await (const file of hub.listFiles({ repo: model.name, recursive: true })) {
                         if (file.type !== 'file' || !file.path.endsWith('.gguf')) continue;
@@ -160,7 +167,7 @@
         }

         if (tokenizerConfig) {
-            console.log(`[huggingface] found config for '${model}'`);
+            console.log(`[huggingface] found config for '${modelName}'`);
             return {
                 chat_template: tokenizerConfig.chat_template,
                 eos_token: tokenizerConfig.eos_token,
@@ -168,7 +175,7 @@
             };
         }

-        console.log(`[huggingface] not found config for '${model}'`);
+        console.log(`[huggingface] not found config for '${modelName}'`);
         return null;
     };
@@ -218,13 +225,12 @@

         const text = applyChatTemplate(template, history, tools);

-        console.log(text);
-
         return text.includes(needle);
     }

     export const findModelTemplate = async (modelName: string): Promise<string | null> => {
-        let template = templateCache[modelName] ?? null;
+        const modelKey = modelName.toLowerCase();
+        let template = templateCache[modelKey] ?? null;

         if (template) {
             console.log(`[huggingface] found cached template for '${modelName}'`);
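Lower-casing the cache key makes lookups case-insensitive, so `MyModel-7B` and `mymodel-7b` share one cached entry while log output keeps the caller's original spelling. A small sketch, assuming `templateCache` is a plain record as the surrounding code suggests (the helper names are illustrative):

```ts
// Case-insensitive template cache keyed on the normalized model name.
// The record shape and helper names are assumptions for illustration.
const templateCache: Record<string, string | null> = {};

const remember = (modelName: string, template: string | null) => {
    templateCache[modelName.toLowerCase()] = template;
};

const lookup = (modelName: string): string | null =>
    templateCache[modelName.toLowerCase()] ?? null;

remember('MyModel-7B', '{% for message in messages %}...{% endfor %}');
console.log(lookup('mymodel-7b') !== null); // true: same entry, different casing
```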
@@ -244,7 +250,7 @@
             }
         }

-        templateCache[modelName] = template;
+        templateCache[modelKey] = template;
         saveCache(templateCache);

         return template;