Compare commits
3 Commits
a4ce47d9d8
...
b95506a095
| Author | SHA1 | Date |
|---|---|---|
|
|
b95506a095 | |
|
|
25c3f5dc25 | |
|
|
5805469581 |
|
|
@ -8,6 +8,8 @@
|
||||||
"bake": "bun build/build.ts"
|
"bake": "bun build/build.ts"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"@huggingface/gguf": "0.1.12",
|
||||||
|
"@huggingface/hub": "0.19.0",
|
||||||
"@huggingface/jinja": "0.3.1",
|
"@huggingface/jinja": "0.3.1",
|
||||||
"@inquirer/select": "2.3.10",
|
"@inquirer/select": "2.3.10",
|
||||||
"ace-builds": "1.36.3",
|
"ace-builds": "1.36.3",
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,8 @@ export const Ace = ({ value, onInput }: IAceProps) => {
|
||||||
displayIndentGuides: false,
|
displayIndentGuides: false,
|
||||||
fontSize: 16,
|
fontSize: 16,
|
||||||
maxLines: Infinity,
|
maxLines: Infinity,
|
||||||
|
tabSize: 2,
|
||||||
|
useSoftTabs: true,
|
||||||
wrap: "free",
|
wrap: "free",
|
||||||
});
|
});
|
||||||
return e;
|
return e;
|
||||||
|
|
|
||||||
|
|
@ -17,10 +17,11 @@
|
||||||
.inputs {
|
.inputs {
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: row;
|
flex-direction: row;
|
||||||
|
|
||||||
select {
|
|
||||||
text-transform: capitalize;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.info {
|
||||||
|
margin: 0 8px;
|
||||||
|
line-height: 36px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.buttons {
|
.buttons {
|
||||||
|
|
|
||||||
|
|
@ -11,38 +11,25 @@ import styles from './header.module.css';
|
||||||
import { Ace } from "../ace";
|
import { Ace } from "../ace";
|
||||||
|
|
||||||
export const Header = () => {
|
export const Header = () => {
|
||||||
const { getContextLength } = useContext(LLMContext);
|
const { modelName, modelTemplate, contextLength, promptTokens, blockConnection } = useContext(LLMContext);
|
||||||
const {
|
const {
|
||||||
messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct,
|
messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct,
|
||||||
setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct
|
setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct
|
||||||
} = useContext(StateContext);
|
} = useContext(StateContext);
|
||||||
const [urlValid, setUrlValid] = useState(false);
|
|
||||||
const [urlEditing, setUrlEditing] = useState(false);
|
|
||||||
|
|
||||||
const loreOpen = useBool();
|
const loreOpen = useBool();
|
||||||
const promptsOpen = useBool();
|
const promptsOpen = useBool();
|
||||||
const assistantOpen = useBool();
|
const assistantOpen = useBool();
|
||||||
|
|
||||||
const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]);
|
const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]);
|
||||||
|
const urlValid = useMemo(() => contextLength > 0, [contextLength]);
|
||||||
const handleFocusUrl = useCallback(() => setUrlEditing(true), []);
|
|
||||||
|
|
||||||
const handleBlurUrl = useCallback(() => {
|
const handleBlurUrl = useCallback(() => {
|
||||||
const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i
|
const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i
|
||||||
const normalizedConnectionUrl = connectionUrl.replace(regex, 'http$1://$2');
|
const normalizedConnectionUrl = connectionUrl.replace(regex, 'http$1://$2');
|
||||||
console.log({ connectionUrl, normalizedConnectionUrl })
|
|
||||||
setConnectionUrl(normalizedConnectionUrl);
|
setConnectionUrl(normalizedConnectionUrl);
|
||||||
setUrlEditing(false);
|
blockConnection.setFalse();
|
||||||
setUrlValid(false);
|
}, [connectionUrl, setConnectionUrl, blockConnection]);
|
||||||
}, [connectionUrl, setConnectionUrl]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (!urlEditing) {
|
|
||||||
getContextLength().then(length => {
|
|
||||||
setUrlValid(length > 0);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, [connectionUrl, urlEditing]);
|
|
||||||
|
|
||||||
const handleAssistantAddSwipe = useCallback((answer: string) => {
|
const handleAssistantAddSwipe = useCallback((answer: string) => {
|
||||||
const index = messages.findLastIndex(m => m.role === 'assistant');
|
const index = messages.findLastIndex(m => m.role === 'assistant');
|
||||||
|
|
@ -69,17 +56,28 @@ export const Header = () => {
|
||||||
<div class={styles.inputs}>
|
<div class={styles.inputs}>
|
||||||
<input value={connectionUrl}
|
<input value={connectionUrl}
|
||||||
onInput={setConnectionUrl}
|
onInput={setConnectionUrl}
|
||||||
onFocus={handleFocusUrl}
|
onFocus={blockConnection.setTrue}
|
||||||
onBlur={handleBlurUrl}
|
onBlur={handleBlurUrl}
|
||||||
class={urlEditing ? '' : urlValid ? styles.valid : styles.invalid}
|
class={blockConnection.value ? '' : urlValid ? styles.valid : styles.invalid}
|
||||||
/>
|
/>
|
||||||
<select value={instruct} onChange={setInstruct}>
|
<select value={instruct} onChange={setInstruct} title='Instruct template'>
|
||||||
|
{modelName && modelTemplate && <optgroup label='Native model template'>
|
||||||
|
<option value={modelTemplate} title='Native for model'>{modelName}</option>
|
||||||
|
</optgroup>}
|
||||||
|
<optgroup label='Manual templates'>
|
||||||
{Object.entries(Instruct).map(([label, value]) => (
|
{Object.entries(Instruct).map(([label, value]) => (
|
||||||
<option value={value} key={value}>
|
<option value={value} key={value}>
|
||||||
{label.toLowerCase()}
|
{label.toLowerCase()}
|
||||||
</option>
|
</option>
|
||||||
))}
|
))}
|
||||||
|
</optgroup>
|
||||||
|
<optgroup label='Custom'>
|
||||||
|
<option value={instruct}>Custom</option>
|
||||||
|
</optgroup>
|
||||||
</select>
|
</select>
|
||||||
|
<div class={styles.info}>
|
||||||
|
{promptTokens} / {contextLength}
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class={styles.buttons}>
|
<div class={styles.buttons}>
|
||||||
<button class='icon color' title='Edit lore' onClick={loreOpen.setTrue}>
|
<button class='icon color' title='Edit lore' onClick={loreOpen.setTrue}>
|
||||||
|
|
@ -111,6 +109,9 @@ export const Header = () => {
|
||||||
<h4 class={styles.modalTitle}>User prompt template</h4>
|
<h4 class={styles.modalTitle}>User prompt template</h4>
|
||||||
<Ace value={userPrompt} onInput={setUserPrompt} />
|
<Ace value={userPrompt} onInput={setUserPrompt} />
|
||||||
<hr />
|
<hr />
|
||||||
|
<h4 class={styles.modalTitle}>Instruct template</h4>
|
||||||
|
<Ace value={instruct} onInput={setInstruct} />
|
||||||
|
<hr />
|
||||||
<h4 class={styles.modalTitle}>Banned phrases</h4>
|
<h4 class={styles.modalTitle}>Banned phrases</h4>
|
||||||
<AutoTextarea
|
<AutoTextarea
|
||||||
placeholder="Each phrase on separate line"
|
placeholder="Each phrase on separate line"
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,12 @@
|
||||||
import Lock from "@common/lock";
|
import Lock from "@common/lock";
|
||||||
import SSE from "@common/sse";
|
import SSE from "@common/sse";
|
||||||
import { createContext } from "preact";
|
import { createContext } from "preact";
|
||||||
import { useContext, useEffect, useMemo } from "preact/hooks";
|
import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
|
||||||
import { MessageTools, type IMessage } from "../messages";
|
import { MessageTools, type IMessage } from "../messages";
|
||||||
import { StateContext } from "./state";
|
import { Instruct, StateContext } from "./state";
|
||||||
import { useBool } from "@common/hooks/useBool";
|
import { useBool } from "@common/hooks/useBool";
|
||||||
import { Template } from "@huggingface/jinja";
|
import { Template } from "@huggingface/jinja";
|
||||||
|
import { Huggingface } from "../huggingface";
|
||||||
|
|
||||||
interface ITemplateMessage {
|
|
||||||
role: 'user' | 'assistant' | 'system';
|
|
||||||
content: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface ICompileArgs {
|
interface ICompileArgs {
|
||||||
keepUsers?: number;
|
keepUsers?: number;
|
||||||
|
|
@ -25,6 +20,12 @@ interface ICompiledPrompt {
|
||||||
|
|
||||||
interface IContext {
|
interface IContext {
|
||||||
generating: boolean;
|
generating: boolean;
|
||||||
|
blockConnection: ReturnType<typeof useBool>;
|
||||||
|
modelName: string;
|
||||||
|
modelTemplate: string;
|
||||||
|
hasToolCalls: boolean;
|
||||||
|
promptTokens: number;
|
||||||
|
contextLength: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
const DEFAULT_GENERATION_SETTINGS = {
|
const DEFAULT_GENERATION_SETTINGS = {
|
||||||
|
|
@ -44,23 +45,53 @@ const DEFAULT_GENERATION_SETTINGS = {
|
||||||
type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
|
type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
|
||||||
|
|
||||||
interface IActions {
|
interface IActions {
|
||||||
applyChatTemplate: (messages: ITemplateMessage[], templateString: string, eosToken?: string) => string;
|
|
||||||
compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
|
compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
|
||||||
generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
|
generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
|
||||||
countTokens(prompt: string): Promise<number>;
|
countTokens: (prompt: string) => Promise<number>;
|
||||||
getContextLength(): Promise<number>;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
export type ILLMContext = IContext & IActions;
|
export type ILLMContext = IContext & IActions;
|
||||||
|
|
||||||
|
export const normalizeModel = (model: string) => {
|
||||||
|
let currentModel = model.split(/[\\\/]/).at(-1);
|
||||||
|
currentModel = currentModel.split('::').at(0);
|
||||||
|
let normalizedModel: string;
|
||||||
|
|
||||||
|
do {
|
||||||
|
normalizedModel = currentModel;
|
||||||
|
|
||||||
|
currentModel = currentModel
|
||||||
|
.replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
|
||||||
|
.replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name
|
||||||
|
.replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
|
||||||
|
.replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
|
||||||
|
.replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
|
||||||
|
.replace(/[ ._-]f(p|loat)?(8|16|32)/i, '')
|
||||||
|
.replace(/^(debug-?)+/i, '')
|
||||||
|
.trim();
|
||||||
|
} while (normalizedModel !== currentModel);
|
||||||
|
|
||||||
|
return normalizedModel
|
||||||
|
.replace(/[ _-]+/ig, '-')
|
||||||
|
.replace(/\.{2,}/, '-')
|
||||||
|
.replace(/[ ._-]+$/ig, '')
|
||||||
|
.trim();
|
||||||
|
}
|
||||||
|
|
||||||
export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
|
export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
|
||||||
|
|
||||||
export const LLMContextProvider = ({ children }: { children?: any }) => {
|
export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
const {
|
const {
|
||||||
connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct,
|
connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct,
|
||||||
setTriggerNext, addMessage, editMessage,
|
setTriggerNext, addMessage, editMessage, setInstruct,
|
||||||
} = useContext(StateContext);
|
} = useContext(StateContext);
|
||||||
|
|
||||||
const generating = useBool(false);
|
const generating = useBool(false);
|
||||||
|
const blockConnection = useBool(false);
|
||||||
|
const [promptTokens, setPromptTokens] = useState(0);
|
||||||
|
const [contextLength, setContextLength] = useState(0);
|
||||||
|
const [modelName, setModelName] = useState('');
|
||||||
|
const [modelTemplate, setModelTemplate] = useState('');
|
||||||
|
const [hasToolCalls, setHasToolCalls] = useState(false);
|
||||||
|
|
||||||
const userPromptTemplate = useMemo(() => {
|
const userPromptTemplate = useMemo(() => {
|
||||||
try {
|
try {
|
||||||
|
|
@ -72,17 +103,41 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
}
|
}
|
||||||
}, [userPrompt]);
|
}, [userPrompt]);
|
||||||
|
|
||||||
|
const getContextLength = useCallback(async () => {
|
||||||
|
if (!connectionUrl || blockConnection.value) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
const response = await fetch(`${connectionUrl}/api/extra/true_max_context_length`);
|
||||||
|
if (response.ok) {
|
||||||
|
const { value } = await response.json();
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.log('Error getting max tokens', e);
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}, [connectionUrl, blockConnection.value]);
|
||||||
|
|
||||||
|
const getModelName = useCallback(async () => {
|
||||||
|
if (!connectionUrl || blockConnection.value) {
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
const response = await fetch(`${connectionUrl}/api/v1/model`);
|
||||||
|
if (response.ok) {
|
||||||
|
const { result } = await response.json();
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.log('Error getting max tokens', e);
|
||||||
|
}
|
||||||
|
|
||||||
|
return '';
|
||||||
|
}, [connectionUrl, blockConnection.value]);
|
||||||
|
|
||||||
const actions: IActions = useMemo(() => ({
|
const actions: IActions = useMemo(() => ({
|
||||||
applyChatTemplate: (messages: ITemplateMessage[], templateString: string) => {
|
|
||||||
const template = new Template(templateString);
|
|
||||||
|
|
||||||
const prompt = template.render({
|
|
||||||
messages,
|
|
||||||
add_generation_prompt: true,
|
|
||||||
});
|
|
||||||
|
|
||||||
return prompt;
|
|
||||||
},
|
|
||||||
compilePrompt: async (messages, { keepUsers } = {}) => {
|
compilePrompt: async (messages, { keepUsers } = {}) => {
|
||||||
const promptMessages = messages.slice();
|
const promptMessages = messages.slice();
|
||||||
const lastMessage = promptMessages.at(-1);
|
const lastMessage = promptMessages.at(-1);
|
||||||
|
|
@ -90,17 +145,17 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
const isRegen = isAssistantLast && !MessageTools.getSwipe(lastMessage)?.content;
|
const isRegen = isAssistantLast && !MessageTools.getSwipe(lastMessage)?.content;
|
||||||
const isContinue = isAssistantLast && !isRegen;
|
const isContinue = isAssistantLast && !isRegen;
|
||||||
|
|
||||||
const userMessages = promptMessages.filter(m => m.role === 'user');
|
|
||||||
const lastUserMessage = userMessages.at(-1);
|
|
||||||
const firstUserMessage = userMessages.at(0);
|
|
||||||
|
|
||||||
if (isContinue) {
|
if (isContinue) {
|
||||||
promptMessages.push(MessageTools.create(userPromptTemplate.render({})));
|
promptMessages.push(MessageTools.create(userPromptTemplate.render({})));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const userMessages = promptMessages.filter(m => m.role === 'user');
|
||||||
|
const lastUserMessage = userMessages.at(-1);
|
||||||
|
const firstUserMessage = userMessages.at(0);
|
||||||
|
|
||||||
const system = `${systemPrompt}\n\n${lore}`.trim();
|
const system = `${systemPrompt}\n\n${lore}`.trim();
|
||||||
|
|
||||||
const templateMessages: ITemplateMessage[] = [
|
const templateMessages: Huggingface.ITemplateMessage[] = [
|
||||||
{ role: 'system', content: system },
|
{ role: 'system', content: system },
|
||||||
];
|
];
|
||||||
|
|
||||||
|
|
@ -128,15 +183,6 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
templateMessages.push({ role, content });
|
templateMessages.push({ role, content });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (templateMessages[1]?.role !== 'user') {
|
|
||||||
const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
|
|
||||||
|
|
||||||
templateMessages.splice(1, 0, {
|
|
||||||
role: 'user',
|
|
||||||
content: userPromptTemplate.render({ prompt, isStart: true }),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
const story = promptMessages.filter(m => m.role === 'assistant')
|
const story = promptMessages.filter(m => m.role === 'assistant')
|
||||||
.map(m => MessageTools.getSwipe(m)?.content.trim()).join('\n\n');
|
.map(m => MessageTools.getSwipe(m)?.content.trim()).join('\n\n');
|
||||||
|
|
@ -157,7 +203,16 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const prompt = actions.applyChatTemplate(templateMessages, instruct);
|
if (templateMessages[1]?.role !== 'user') {
|
||||||
|
const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
|
||||||
|
|
||||||
|
templateMessages.splice(1, 0, {
|
||||||
|
role: 'user',
|
||||||
|
content: userPromptTemplate.render({ prompt, isStart: true }),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const prompt = Huggingface.applyChatTemplate(instruct, templateMessages);
|
||||||
return {
|
return {
|
||||||
prompt,
|
prompt,
|
||||||
isContinue,
|
isContinue,
|
||||||
|
|
@ -248,22 +303,6 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
console.log('Error counting tokens', e);
|
console.log('Error counting tokens', e);
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0;
|
|
||||||
},
|
|
||||||
getContextLength: async () => {
|
|
||||||
if (!connectionUrl) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
const response = await fetch(`${connectionUrl}/api/extra/true_max_context_length`);
|
|
||||||
if (response.ok) {
|
|
||||||
const { value } = await response.json();
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
console.log('Error getting max tokens', e);
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
},
|
},
|
||||||
}), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct]);
|
}), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct]);
|
||||||
|
|
@ -276,6 +315,8 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
let text: string = '';
|
let text: string = '';
|
||||||
|
|
||||||
const { prompt, isRegen } = await actions.compilePrompt(messages);
|
const { prompt, isRegen } = await actions.compilePrompt(messages);
|
||||||
|
const tokens = await actions.countTokens(prompt);
|
||||||
|
setPromptTokens(tokens);
|
||||||
|
|
||||||
if (!isRegen) {
|
if (!isRegen) {
|
||||||
addMessage('', 'assistant');
|
addMessage('', 'assistant');
|
||||||
|
|
@ -284,18 +325,76 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
|
|
||||||
for await (const chunk of actions.generate(prompt)) {
|
for await (const chunk of actions.generate(prompt)) {
|
||||||
text += chunk;
|
text += chunk;
|
||||||
|
setPromptTokens(tokens + Math.round(text.length * 0.25));
|
||||||
editMessage(messageId, text);
|
editMessage(messageId, text);
|
||||||
}
|
}
|
||||||
|
|
||||||
text = MessageTools.trimSentence(text);
|
text = MessageTools.trimSentence(text);
|
||||||
editMessage(messageId, text);
|
editMessage(messageId, text);
|
||||||
|
|
||||||
|
setPromptTokens(0); // trigger calculation
|
||||||
|
|
||||||
MessageTools.playReady();
|
MessageTools.playReady();
|
||||||
}
|
}
|
||||||
})(), [triggerNext, messages, generating.value]);
|
})(), [actions, triggerNext, messages, generating.value]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!blockConnection.value) {
|
||||||
|
setPromptTokens(0);
|
||||||
|
setContextLength(0);
|
||||||
|
|
||||||
|
getContextLength().then(setContextLength);
|
||||||
|
}
|
||||||
|
}, [connectionUrl, instruct, blockConnection.value]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!blockConnection.value) {
|
||||||
|
setModelName('');
|
||||||
|
getModelName().then(normalizeModel).then(setModelName);
|
||||||
|
}
|
||||||
|
}, [connectionUrl, blockConnection.value]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
setModelTemplate('');
|
||||||
|
if (modelName) {
|
||||||
|
Huggingface.findModelTemplate(modelName)
|
||||||
|
.then((template) => {
|
||||||
|
if (template) {
|
||||||
|
setModelTemplate(template);
|
||||||
|
setInstruct(template);
|
||||||
|
} else {
|
||||||
|
setInstruct(Instruct.CHATML);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [modelName]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (promptTokens === 0 && !blockConnection.value) {
|
||||||
|
actions.compilePrompt(messages)
|
||||||
|
.then(({ prompt }) => actions.countTokens(prompt))
|
||||||
|
.then(setPromptTokens)
|
||||||
|
.catch(e => console.error(`Could not count tokens`, e));
|
||||||
|
}
|
||||||
|
}, [actions, promptTokens, messages, blockConnection.value]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
try {
|
||||||
|
const hasTools = Huggingface.testToolCalls(instruct);
|
||||||
|
setHasToolCalls(hasTools);
|
||||||
|
} catch {
|
||||||
|
setHasToolCalls(false);
|
||||||
|
}
|
||||||
|
}, [instruct]);
|
||||||
|
|
||||||
const rawContext: IContext = {
|
const rawContext: IContext = {
|
||||||
generating: generating.value,
|
generating: generating.value,
|
||||||
|
blockConnection,
|
||||||
|
modelName,
|
||||||
|
modelTemplate,
|
||||||
|
hasToolCalls,
|
||||||
|
promptTokens,
|
||||||
|
contextLength,
|
||||||
};
|
};
|
||||||
|
|
||||||
const context = useMemo(() => rawContext, Object.values(rawContext));
|
const context = useMemo(() => rawContext, Object.values(rawContext));
|
||||||
|
|
|
||||||
|
|
@ -38,13 +38,13 @@ interface IActions {
|
||||||
const SAVE_KEY = 'ai_game_save_state';
|
const SAVE_KEY = 'ai_game_save_state';
|
||||||
|
|
||||||
export enum Instruct {
|
export enum Instruct {
|
||||||
LLAMA = `{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}`,
|
CHATML = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n\\n' }}{% endif %}`,
|
||||||
|
|
||||||
MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + '</s>'}}{%- endif %}{%- endfor %}`,
|
LLAMA = `{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' + message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}{% endif %}`,
|
||||||
|
|
||||||
CHATML = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n\n' }}{% endif %}`,
|
MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + '</s>'}}{%- endif %}{%- endfor %}`,
|
||||||
|
|
||||||
ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\n\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\n\n' + message['content'] + '\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\n\n' + message['content'] + '\n\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\n\n' }}{% endif %}`,
|
ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\\n\\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\\n\\n' + message['content'] + '\\n\\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\\n\\n' + message['content'] + '\\n\\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\\n\\n' }}{% endif %}`,
|
||||||
};
|
};
|
||||||
|
|
||||||
export const saveContext = (context: IContext) => {
|
export const saveContext = (context: IContext) => {
|
||||||
|
|
@ -58,7 +58,7 @@ export const loadContext = (): IContext => {
|
||||||
const defaultContext: IContext = {
|
const defaultContext: IContext = {
|
||||||
connectionUrl: 'http://localhost:5001',
|
connectionUrl: 'http://localhost:5001',
|
||||||
input: '',
|
input: '',
|
||||||
instruct: Instruct.LLAMA,
|
instruct: Instruct.CHATML,
|
||||||
systemPrompt: 'You are creative writer. Write a story based on the world description below.',
|
systemPrompt: 'You are creative writer. Write a story based on the world description below.',
|
||||||
lore: '',
|
lore: '',
|
||||||
userPrompt: `{% if prompt %}{% if isStart %}Start{% else %}Continue{% endif %} this story, taking information into account: {{ prompt | trim }}
|
userPrompt: `{% if prompt %}{% if isStart %}Start{% else %}Continue{% endif %} this story, taking information into account: {{ prompt | trim }}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,270 @@
|
||||||
|
import { gguf } from '@huggingface/gguf';
|
||||||
|
import * as hub from '@huggingface/hub';
|
||||||
|
import { Template } from '@huggingface/jinja';
|
||||||
|
|
||||||
|
export namespace Huggingface {
|
||||||
|
export interface ITemplateMessage {
|
||||||
|
role: 'user' | 'assistant' | 'system';
|
||||||
|
content: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface INumberParameter {
|
||||||
|
type: 'number';
|
||||||
|
enum?: number[];
|
||||||
|
description?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface IStringParameter {
|
||||||
|
type: 'string';
|
||||||
|
enum?: string[];
|
||||||
|
description?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface IArrayParameter {
|
||||||
|
type: 'array';
|
||||||
|
description?: string;
|
||||||
|
items: IParameter;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface IObjectParameter {
|
||||||
|
type: 'object';
|
||||||
|
description?: string;
|
||||||
|
properties: Record<string, IParameter>;
|
||||||
|
required?: string[];
|
||||||
|
}
|
||||||
|
|
||||||
|
type IParameter = INumberParameter | IStringParameter | IArrayParameter | IObjectParameter;
|
||||||
|
|
||||||
|
interface ITool {
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: string;
|
||||||
|
description?: string;
|
||||||
|
parameters?: IObjectParameter;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface IFunction {
|
||||||
|
name: string;
|
||||||
|
description?: string;
|
||||||
|
parameters?: Record<string, IParameter>;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface TokenizerConfig {
|
||||||
|
chat_template: string;
|
||||||
|
bos_token?: string;
|
||||||
|
eos_token?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const TEMPLATE_CACHE_KEY = 'ai_game_template_cache';
|
||||||
|
|
||||||
|
const loadCache = (): Record<string, string> => {
|
||||||
|
const json = localStorage.getItem(TEMPLATE_CACHE_KEY);
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (json) {
|
||||||
|
const cache = JSON.parse(json);
|
||||||
|
if (cache && typeof cache === 'object') {
|
||||||
|
return cache
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch { }
|
||||||
|
|
||||||
|
return {};
|
||||||
|
};
|
||||||
|
|
||||||
|
const saveCache = (cache: Record<string, string>) => {
|
||||||
|
const json = JSON.stringify(cache);
|
||||||
|
localStorage.setItem(TEMPLATE_CACHE_KEY, json);
|
||||||
|
};
|
||||||
|
|
||||||
|
const templateCache: Record<string, string> = loadCache();
|
||||||
|
|
||||||
|
const hasField = <T extends string>(obj: unknown, field: T): obj is Record<T, unknown> => (
|
||||||
|
obj != null && typeof obj === 'object' && (field in obj)
|
||||||
|
);
|
||||||
|
const isTokenizerConfig = (obj: unknown): obj is TokenizerConfig => (
|
||||||
|
hasField(obj, 'chat_template') && (typeof obj.chat_template === 'string')
|
||||||
|
&& (!hasField(obj, 'eos_token') || !obj.eos_token || typeof obj.eos_token === 'string')
|
||||||
|
&& (!hasField(obj, 'bos_token') || !obj.bos_token || typeof obj.bos_token === 'string')
|
||||||
|
);
|
||||||
|
|
||||||
|
const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise<TokenizerConfig | null> => {
|
||||||
|
console.log(`[huggingface] searching config for '${modelName}'`);
|
||||||
|
|
||||||
|
const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName }, additionalFields: ['config'] }));
|
||||||
|
const models = hubModels.filter(m => {
|
||||||
|
if (m.gated) return false;
|
||||||
|
if (!m.name.toLowerCase().includes(modelName.toLowerCase())) return false;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}).sort((a, b) => b.downloads - a.downloads);
|
||||||
|
|
||||||
|
let tokenizerConfig: TokenizerConfig | null = null;
|
||||||
|
|
||||||
|
for (const model of models) {
|
||||||
|
const { config, name } = model;
|
||||||
|
|
||||||
|
if (name.toLowerCase().endsWith('-gguf')) continue;
|
||||||
|
|
||||||
|
if (hasField(config, 'tokenizer_config') && isTokenizerConfig(config.tokenizer_config)) {
|
||||||
|
tokenizerConfig = config.tokenizer_config;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
console.log(`[huggingface] searching config in '${model.name}/tokenizer_config.json'`);
|
||||||
|
const fileResponse = await hub.downloadFile({ repo: model.name, path: 'tokenizer_config.json' });
|
||||||
|
if (fileResponse?.ok) {
|
||||||
|
const maybeConfig = await fileResponse.json();
|
||||||
|
if (isTokenizerConfig(maybeConfig)) {
|
||||||
|
tokenizerConfig = maybeConfig;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch { }
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!tokenizerConfig) {
|
||||||
|
for (const model of models) {
|
||||||
|
try {
|
||||||
|
for await (const file of hub.listFiles({ repo: model.name, recursive: true })) {
|
||||||
|
if (file.type !== 'file' || !file.path.endsWith('.gguf')) continue;
|
||||||
|
try {
|
||||||
|
console.log(`[huggingface] searching config in '${model.name}/${file.path}'`);
|
||||||
|
const fileInfo = await hub.fileDownloadInfo({ repo: model.name, path: file.path });
|
||||||
|
if (fileInfo?.downloadLink) {
|
||||||
|
const { metadata } = await gguf(fileInfo.downloadLink);
|
||||||
|
if ('tokenizer.chat_template' in metadata) {
|
||||||
|
const chat_template = metadata['tokenizer.chat_template'];
|
||||||
|
const tokens = metadata['tokenizer.ggml.tokens'];
|
||||||
|
const bos_token = tokens[metadata['tokenizer.ggml.bos_token_id']];
|
||||||
|
const eos_token = tokens[metadata['tokenizer.ggml.eos_token_id']];
|
||||||
|
|
||||||
|
const maybeConfig = {
|
||||||
|
chat_template,
|
||||||
|
bos_token,
|
||||||
|
eos_token,
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isTokenizerConfig(maybeConfig)) {
|
||||||
|
tokenizerConfig = maybeConfig;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else if ('tokenizer.ggml.model' in metadata) {
|
||||||
|
break; // no reason to touch different quants
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch { }
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch { }
|
||||||
|
|
||||||
|
if (tokenizerConfig) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tokenizerConfig) {
|
||||||
|
console.log(`[huggingface] found config for '${modelName}'`);
|
||||||
|
return {
|
||||||
|
chat_template: tokenizerConfig.chat_template,
|
||||||
|
eos_token: tokenizerConfig.eos_token,
|
||||||
|
bos_token: tokenizerConfig.bos_token,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(`[huggingface] not found config for '${modelName}'`);
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
|
||||||
|
function updateRequired<T extends IParameter>(param: T): T {
|
||||||
|
if ('items' in param) {
|
||||||
|
updateRequired(param.items);
|
||||||
|
} else if ('properties' in param) {
|
||||||
|
for (const prop of Object.values(param.properties)) {
|
||||||
|
updateRequired(prop);
|
||||||
|
}
|
||||||
|
param.required = Object.keys(param.properties);
|
||||||
|
}
|
||||||
|
|
||||||
|
return param;
|
||||||
|
}
|
||||||
|
|
||||||
|
const convertFunctionToTool = (fn: IFunction): ITool => ({
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: fn.name,
|
||||||
|
description: fn.description,
|
||||||
|
parameters: updateRequired({
|
||||||
|
type: 'object',
|
||||||
|
properties: fn.parameters ?? {},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
export const testToolCalls = (template: string): boolean => {
|
||||||
|
const history: ITemplateMessage[] = [
|
||||||
|
{ role: 'system', content: 'You are calculator.' },
|
||||||
|
{ role: 'user', content: 'Calculate 2 + 2.' },
|
||||||
|
];
|
||||||
|
|
||||||
|
const needle = '___AWOORWA_NEEDLE__';
|
||||||
|
|
||||||
|
const tools: IFunction[] = [{
|
||||||
|
name: 'add',
|
||||||
|
description: 'Test function',
|
||||||
|
parameters: {
|
||||||
|
a: { type: 'number' },
|
||||||
|
b: { type: 'number' },
|
||||||
|
c: { type: 'array', items: { type: 'number' } },
|
||||||
|
d: { type: 'object', properties: { inside: { type: 'number', description: needle } } },
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
|
||||||
|
const text = applyChatTemplate(template, history, tools);
|
||||||
|
|
||||||
|
return text.includes(needle);
|
||||||
|
}
|
||||||
|
|
||||||
|
export const findModelTemplate = async (modelName: string): Promise<string | null> => {
|
||||||
|
const modelKey = modelName.toLowerCase();
|
||||||
|
let template = templateCache[modelKey] ?? null;
|
||||||
|
|
||||||
|
if (template) {
|
||||||
|
console.log(`[huggingface] found cached template for '${modelName}'`);
|
||||||
|
} else {
|
||||||
|
const config = await loadHuggingfaceTokenizerConfig(modelName);
|
||||||
|
|
||||||
|
if (config?.chat_template?.trim()) {
|
||||||
|
template = config.chat_template.trim()
|
||||||
|
.replaceAll('eos_token', `'${config.eos_token ?? ''}'`)
|
||||||
|
.replaceAll('bos_token', `''`);
|
||||||
|
|
||||||
|
if (config.bos_token) {
|
||||||
|
template = template
|
||||||
|
.replaceAll(config.bos_token, '')
|
||||||
|
.replace(/\{\{ ?(''|"") ?\}\}/g, '');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
templateCache[modelKey] = template;
|
||||||
|
saveCache(templateCache);
|
||||||
|
|
||||||
|
return template;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const applyChatTemplate = (templateString: string, messages: ITemplateMessage[], functions?: IFunction[]) => {
|
||||||
|
const template = new Template(templateString);
|
||||||
|
|
||||||
|
const prompt = template.render({
|
||||||
|
messages,
|
||||||
|
add_generation_prompt: true,
|
||||||
|
tools: functions?.map(convertFunctionToTool),
|
||||||
|
});
|
||||||
|
|
||||||
|
return prompt;
|
||||||
|
};
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue