Compare commits

No commits in common. "c480f5a7d16efc424ff1d3352e786b322a27a844" and "9c4cc615732e2d8516e83ec00db26da73a304af7" have entirely different histories.

18 changed files with 178 additions and 303 deletions

BIN
bun.lockb

Binary file not shown.

View File

@ -11,7 +11,6 @@
"@huggingface/gguf": "0.1.12",
"@huggingface/hub": "0.19.0",
"@huggingface/jinja": "0.3.1",
"@huggingface/transformers": "3.0.2",
"@inquirer/select": "2.3.10",
"ace-builds": "1.36.3",
"classnames": "2.5.1",

View File

@ -7,8 +7,6 @@
--green: #AFAFAF;
--red: #7F0000;
--green: #007F00;
--brightRed: #DD0000;
--brightGreen: #00DD00;
--shadeColor: rgba(0, 128, 128, 0.3);
--border: 1px solid var(--color);

View File

@ -2,7 +2,7 @@ import { useEffect, useRef } from "preact/hooks";
import type { JSX } from "preact/jsx-runtime"
import { useIsVisible } from '@common/hooks/useIsVisible';
import { DOMTools } from "../tools/dom";
import { DOMTools } from "../dom";
export const AutoTextarea = (props: JSX.HTMLAttributes<HTMLTextAreaElement>) => {
const { value } = props;

View File

@ -1,8 +1,8 @@
import { useCallback, useContext, useEffect, useRef } from "preact/hooks";
import { StateContext } from "../contexts/state";
import { Message } from "./message/message";
import { MessageTools } from "../tools/messages";
import { DOMTools } from "../tools/dom";
import { MessageTools } from "../messages";
import { DOMTools } from "../dom";
export const Chat = () => {
const { messages } = useContext(StateContext);

View File

@ -1,10 +1,11 @@
import { useCallback, useEffect, useMemo, useState } from 'preact/hooks';
import styles from './header.module.css';
import { Connection, HORDE_ANON_KEY, isHordeConnection, isKoboldConnection, type IConnection, type IHordeModel } from '../../tools/connection';
import { Connection, HORDE_ANON_KEY, isHordeConnection, isKoboldConnection, type IConnection, type IHordeModel } from '../../connection';
import { Instruct } from '../../contexts/state';
import { useInputState } from '@common/hooks/useInputState';
import { useInputCallback } from '@common/hooks/useInputCallback';
import { Huggingface } from '../../tools/huggingface';
import { Huggingface } from '../../huggingface';
interface IProps {
connection: IConnection;
@ -12,13 +13,10 @@ interface IProps {
}
export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
// kobold
const [connectionUrl, setConnectionUrl] = useInputState('');
// horde
const [apiKey, setApiKey] = useInputState(HORDE_ANON_KEY);
const [modelName, setModelName] = useInputState('');
const [instruct, setInstruct] = useInputState('');
const [modelTemplate, setModelTemplate] = useInputState('');
const [hordeModels, setHordeModels] = useState<IHordeModel[]>([]);
const [contextLength, setContextLength] = useState<number>(0);
@ -29,14 +27,11 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
return 'unknown';
}, [connection]);
const isOnline = useMemo(() => contextLength > 0, [contextLength]);
const urlValid = useMemo(() => contextLength > 0, [contextLength]);
useEffect(() => {
setInstruct(connection.instruct);
if (isKoboldConnection(connection)) {
setConnectionUrl(connection.url);
Connection.getContextLength(connection).then(setContextLength);
} else if (isHordeConnection(connection)) {
setModelName(connection.model);
setApiKey(connection.apiKey || HORDE_ANON_KEY);
@ -44,6 +39,9 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
Connection.getHordeModels()
.then(m => setHordeModels(Array.from(m.values()).sort((a, b) => a.name.localeCompare(b.name))));
}
Connection.getContextLength(connection).then(setContextLength);
Connection.getModelName(connection).then(setModelName);
}, [connection]);
useEffect(() => {
@ -52,44 +50,47 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
.then(template => {
if (template) {
setModelTemplate(template);
setInstruct(template);
}
});
}
}, [modelName]);
const setInstruct = useInputCallback((instruct) => {
setConnection({ ...connection, instruct });
}, [connection, setConnection]);
const setBackendType = useInputCallback((type) => {
if (type === 'kobold') {
setConnection({
instruct,
instruct: connection.instruct,
url: connectionUrl,
});
} else if (type === 'horde') {
setConnection({
instruct,
instruct: connection.instruct,
apiKey,
model: modelName,
});
}
}, [setConnection, connectionUrl, apiKey, modelName, instruct]);
}, [connection, setConnection, connectionUrl, apiKey, modelName]);
const handleBlurUrl = useCallback(() => {
const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i;
const url = connectionUrl.replace(regex, 'http$1://$2');
setConnection({
instruct,
instruct: connection.instruct,
url,
});
}, [connectionUrl, instruct, setConnection]);
}, [connection, connectionUrl, setConnection]);
const handleBlurHorde = useCallback(() => {
setConnection({
instruct,
instruct: connection.instruct,
apiKey,
model: modelName,
});
}, [apiKey, modelName, instruct, setConnection]);
}, [connection, apiKey, modelName, setConnection]);
return (
<div class={styles.connectionEditor}>
@ -97,7 +98,7 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
<option value='kobold'>Kobold CPP</option>
<option value='horde'>Horde</option>
</select>
<select value={instruct} onChange={setInstruct} title='Instruct template'>
<select value={connection.instruct} onChange={setInstruct} title='Instruct template'>
{modelName && modelTemplate && <optgroup label='Native model template'>
<option value={modelTemplate} title='Native for model'>{modelName}</option>
</optgroup>}
@ -108,15 +109,15 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
</option>
))}
</optgroup>
{instruct !== modelTemplate && <optgroup label='Custom'>
<optgroup label='Custom'>
<option value={connection.instruct}>Custom</option>
</optgroup>}
</optgroup>
</select>
{isKoboldConnection(connection) && <input
value={connectionUrl}
onInput={setConnectionUrl}
onBlur={handleBlurUrl}
class={isOnline ? styles.valid : styles.invalid}
class={urlValid ? styles.valid : styles.invalid}
/>}
{isHordeConnection(connection) && <>
<input
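
A note on the `handleBlurUrl` callback in this file: its regex forces an explicit `http`/`https` scheme onto whatever the user typed and drops a trailing slash before the connection URL is stored. A minimal standalone sketch of that normalization (the helper name and sample inputs are illustrative, not from the repository):

```ts
// Sketch of the URL normalization performed in handleBlurUrl:
// add a scheme when one is missing and strip a single trailing slash.
const normalizeConnectionUrl = (raw: string): string =>
  raw.replace(/^(?:http(s?):\/\/)?(.*?)\/?$/i, 'http$1://$2');

console.log(normalizeConnectionUrl('localhost:5001/'));      // "http://localhost:5001"
console.log(normalizeConnectionUrl('https://my.host/api/')); // "https://my.host/api"
```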

View File

@ -29,14 +29,6 @@
flex-direction: row;
gap: 8px;
padding: 0 8px;
.online {
color: var(--brightGreen);
}
.offline {
color: var(--brightRed);
}
}
}

View File

@ -23,7 +23,6 @@ export const Header = () => {
const promptsOpen = useBool();
const genparamsOpen = useBool();
const assistantOpen = useBool();
const isOnline = useMemo(() => contextLength > 0, [contextLength]);
const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]);
@ -57,7 +56,7 @@ export const Header = () => {
<div class={styles.header}>
<div class={styles.inputs}>
<div class={styles.buttons}>
<button class={`icon ${isOnline ? styles.online: styles.offline}`} onClick={connectionsOpen.setTrue} title='Connection settings'>
<button class='icon' onClick={connectionsOpen.setTrue} title='Connection settings'>
🔌
</button>
</div>

View File

@ -1,5 +1,5 @@
import { useMemo } from "preact/hooks";
import { MessageTools } from "../../tools/messages";
import { MessageTools } from "../../messages";
import styles from './message.module.css';

View File

@ -1,7 +1,7 @@
import { useCallback, useContext, useEffect, useMemo, useRef, useState } from "preact/hooks";
import { MessageTools, type IMessage } from "../../tools/messages";
import { MessageTools, type IMessage } from "../../messages";
import { StateContext } from "../../contexts/state";
import { DOMTools } from "../../tools/dom";
import { DOMTools } from "../../dom";
import styles from './message.module.css';
import { AutoTextarea } from "../autoTextarea";
@ -16,7 +16,7 @@ interface IProps {
}
export const Message = ({ message, index, isLastUser, isLastAssistant, onNeedScroll }: IProps) => {
const { messages, editMessage, editSummary, deleteMessage, setCurrentSwipe, setMessages, continueMessage } = useContext(StateContext);
const { messages, editMessage, editSummary, deleteMessage, setCurrentSwipe, setMessages } = useContext(StateContext);
const [editing, setEditing] = useState(false);
const [editedMessage, setEditedMessage] = useInputState('');
const textRef = useRef<HTMLDivElement>(null);
@ -70,10 +70,6 @@ export const Message = ({ message, index, isLastUser, isLastAssistant, onNeedScr
DOMTools.animate(textRef.current, 'swipe-from-right');
}, [setCurrentSwipe, index, message]);
const handleContinueMessage = useCallback(() => {
continueMessage(true);
}, [continueMessage]);
return (
<div class={`${styles.message} ${styles[message.role]} ${isLastUser ? styles.lastUser : ''}`}>
<div class={styles.content}>
@ -93,14 +89,13 @@ export const Message = ({ message, index, isLastUser, isLastAssistant, onNeedScr
<button class='icon' onClick={handleCancelEdit} title='Cancel'></button>
</>
: <>
{isLastAssistant && <>
{isLastAssistant &&
<div class={styles.swipes}>
<div onClick={handleSwipeLeft}></div>
<div>{message.currentSwipe + 1}/{message.swipes.length}</div>
<div onClick={handleSwipeRight}></div>
</div>
<button class='icon' onClick={handleContinueMessage} title="Continue"></button>
</>}
}
<button class='icon' onClick={handleEnableEdit} title="Edit">🖊</button>
</>
}

View File

@ -1,13 +1,12 @@
import { MessageTools, type IMessage } from "../../tools/messages"
import { MessageTools, type IMessage } from "../../messages"
import { useCallback, useContext, useEffect, useMemo, useRef, useState } from "preact/hooks";
import { Modal } from "@common/components/modal/modal";
import { DOMTools } from "../../tools/dom";
import { DOMTools } from "../../dom";
import styles from './minichat.module.css';
import { LLMContext } from "../../contexts/llm";
import { FormattedMessage } from "../message/formattedMessage";
import { AutoTextarea } from "../autoTextarea";
import { useBool } from "@common/hooks/useBool";
interface IProps {
open: boolean;
@ -17,10 +16,9 @@ interface IProps {
}
export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) => {
const { stopGeneration, generate, compilePrompt } = useContext(LLMContext);
const { generating, stopGeneration, generate, compilePrompt } = useContext(LLMContext);
const [messages, setMessages] = useState<IMessage[]>([]);
const ref = useRef<HTMLDivElement>(null);
const generating = useBool();
const answer = useMemo(() =>
MessageTools.getSwipe(messages.filter(m => m.role === 'assistant').at(-1))?.content,
@ -35,7 +33,7 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
useEffect(() => {
setTimeout(() => DOMTools.scrollDown(ref.current, false), 100);
}, [generating.value, open]);
}, [generating, open]);
useEffect(() => {
DOMTools.scrollDown(ref.current, false);
@ -49,21 +47,19 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
}, [messages.length, handleInit]);
const handleGenerate = useCallback(async () => {
if (messages.length > 0 && !generating.value) {
if (messages.length > 0 && !generating) {
const promptMessages: IMessage[] = [...history, ...messages];
const { prompt } = await compilePrompt(promptMessages, { keepUsers: messages.length + 1, continueLast: true });
const { prompt } = await compilePrompt(promptMessages, { keepUsers: messages.length + 1 });
let text = '';
const messageId = messages.length;
const newMessages = [...messages, MessageTools.create('', 'assistant', true)];
setMessages(newMessages);
generating.setTrue();
for await (const chunk of generate(prompt)) {
text += chunk;
setMessages(MessageTools.updateSwipe(newMessages, messageId, { content: text.trim() }));
}
generating.setFalse();
setMessages([
...MessageTools.updateSwipe(newMessages, messageId, { content: MessageTools.trimSentence(text) }),
@ -94,7 +90,7 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
<div class={styles.minichat} ref={ref}>
<div class={styles.messages}>
{messages.map((m, i) => (
generating.value
generating
? <FormattedMessage key={i} class={`${styles[m.role]} ${styles.message}`}>
{MessageTools.getSwipe(m)?.content ?? ''}
</FormattedMessage>
@ -109,18 +105,18 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
</div>
</div>
<div class={styles.buttons}>
{generating.value
{generating
? <button onClick={stopGeneration}>Stop</button>
: <button onClick={handleGenerate}>Generate</button>
}
<button onClick={() => handleInit()} class={`${generating.value ? 'disabled' : ''}`}>
<button onClick={() => handleInit()} class={`${generating ? 'disabled' : ''}`}>
Clear
</button>
{Object.entries(buttons).map(([label, onClick], i) => (
<button
key={i}
onClick={() => onClick(answer ?? '')}
class={`${(generating.value || !answer) ? 'disabled' : ''}`}
class={`${(generating || !answer) ? 'disabled' : ''}`}
>
{label}
</button>

View File

@ -1,9 +1,7 @@
import Lock from "@common/lock";
import SSE from "@common/sse";
import { throttle } from "@common/utils";
import delay from "delay";
import { Huggingface } from "./huggingface";
import { approximateTokens, normalizeModel } from "./model";
import delay, { clearDelay } from "delay";
interface IBaseConnection {
instruct: string;
@ -81,6 +79,34 @@ const MAX_HORDE_LENGTH = 512;
const MAX_HORDE_CONTEXT = 32000;
export const HORDE_ANON_KEY = '0000000000';
export const normalizeModel = (model: string) => {
let currentModel = model.split(/[\\\/]/).at(-1);
currentModel = currentModel.split('::').at(0);
let normalizedModel: string;
do {
normalizedModel = currentModel;
currentModel = currentModel
.replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
.replace(/[ ._-](gptq|awq|exl2?|imat|i\d|h\d)/i, '') // remove quant name
.replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
.replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
.replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
.replace(/[ ._-]f(p|loat)?(8|16|32)/i, '')
.replace(/^(debug-?)+/i, '')
.trim();
} while (normalizedModel !== currentModel);
return normalizedModel
.replace(/[ _-]+/ig, '-')
.replace(/\.{2,}/, '-')
.replace(/[ ._-]+$/ig, '')
.trim();
}
export const approximateTokens = (prompt: string): number => prompt.split(/[^a-z0-9]+/i).length;
export type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
export namespace Connection {
@ -145,11 +171,7 @@ export namespace Connection {
sse.close();
}
async function* generateHorde(connection: IHordeConnection, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator<string> {
if (!connection.model) {
throw new Error('Horde not connected');
}
async function generateHorde(connection: Omit<IHordeConnection, keyof IBaseConnection>, prompt: string, extraSettings: IGenerationSettings = {}): Promise<string> {
const models = await getHordeModels();
const model = models.get(connection.model);
if (model) {
@ -170,11 +192,9 @@ export namespace Connection {
models: model.hordeNames,
workers: model.workers,
};
const bannedTokens = requestData.params.banned_tokens ?? [];
const { signal } = abortController;
while (true) {
const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, {
method: 'POST',
body: JSON.stringify(requestData),
@ -207,41 +227,19 @@ export namespace Connection {
};
const deleteRequest = async () => (await request('DELETE')) ?? '';
let text: string | null = null;
while (!text) {
while (true) {
try {
await delay(2500, { signal });
text = await request();
const text = await request();
if (text) {
const locaseText = text.toLowerCase();
let unsloppedText = text;
for (const ban of bannedTokens) {
const slopIdx = locaseText.indexOf(ban.toLowerCase());
if (slopIdx >= 0) {
console.log(`[horde] slop '${ban}' detected at ${slopIdx}`);
unsloppedText = unsloppedText.slice(0, slopIdx);
}
}
yield unsloppedText;
requestData.prompt += unsloppedText;
if (unsloppedText === text) {
return; // we are finished
}
if (unsloppedText.length === 0) {
requestData.params.temperature += 0.05;
}
return text;
}
} catch (e) {
console.error('Error in horde generation:', e);
return yield deleteRequest();
}
return deleteRequest();
}
}
}
@ -253,7 +251,7 @@ export namespace Connection {
if (isKoboldConnection(connection)) {
yield* generateKobold(connection.url, prompt, extraSettings);
} else if (isHordeConnection(connection)) {
yield* generateHorde(connection, prompt, extraSettings);
yield await generateHorde(connection, prompt, extraSettings);
}
}
@ -279,7 +277,7 @@ export namespace Connection {
for (const worker of goodWorkers) {
for (const modelName of worker.models) {
const normName = normalizeModel(modelName);
const normName = normalizeModel(modelName.toLowerCase());
let model = models.get(normName);
if (!model) {
model = {
@ -345,7 +343,7 @@ export namespace Connection {
} catch (e) {
console.error('Error getting max tokens', e);
}
} else if (isHordeConnection(connection) && connection.model) {
} else if (isHordeConnection(connection)) {
const models = await getHordeModels();
const model = models.get(connection.model);
if (model) {
@ -369,18 +367,7 @@ export namespace Connection {
return value;
}
} catch (e) {
console.error('Error counting tokens:', e);
}
} else {
const model = await getModelName(connection);
const tokenizer = await Huggingface.findTokenizer(model);
if (tokenizer) {
try {
const { input_ids } = await tokenizer(prompt);
return input_ids.data.length;
} catch (e) {
console.error('Error counting tokens with tokenizer:', e);
}
console.error('Error counting tokens', e);
}
}
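
For readers skimming the `normalizeModel` helper that this file now hosts (moved here from the deleted `model.ts`, the last file in this diff): it takes the final path segment of a model name and repeatedly strips context-length, quantization, and file-format suffixes until nothing changes, then collapses separators. A rough usage sketch; the inputs are made up, and the outputs are traced by hand from the regex chain above:

```ts
// Usage sketch; the import path matches how huggingface.ts imports this helper.
import { normalizeModel } from './connection';

console.log(normalizeModel('koboldcpp/Mistral-7B-Instruct-v0.2.Q4_K_M.gguf'));
// -> "Mistral-7B-Instruct-v0.2"

console.log(normalizeModel('TheBloke/MythoMax-L2-13B-GPTQ'));
// -> "MythoMax-L2-13B"

// Note: unlike the deleted model.ts version, this one no longer lowercases,
// so callers such as getHordeModels pass modelName.toLowerCase() themselves.
```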

View File

@ -1,17 +1,17 @@
import { createContext } from "preact";
import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
import { MessageTools, type IMessage } from "../tools/messages";
import { MessageTools, type IMessage } from "../messages";
import { StateContext } from "./state";
import { useBool } from "@common/hooks/useBool";
import { Huggingface } from "../tools/huggingface";
import { Connection, type IGenerationSettings } from "../tools/connection";
import { Template } from "@huggingface/jinja";
import { Huggingface } from "../huggingface";
import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection";
import { throttle } from "@common/utils";
import { useAsyncEffect } from "@common/hooks/useAsyncEffect";
import { approximateTokens, normalizeModel } from "../tools/model";
interface ICompileArgs {
keepUsers?: number;
continueLast?: boolean;
raw?: boolean;
}
interface ICompiledPrompt {
@ -48,8 +48,8 @@ const processing = {
export const LLMContextProvider = ({ children }: { children?: any }) => {
const {
connection, messages, triggerNext, continueLast, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
setTriggerNext, setContinueLast, addMessage, editMessage, editSummary,
connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
setTriggerNext, addMessage, editMessage, editSummary,
} = useContext(StateContext);
const generating = useBool(false);
@ -58,27 +58,26 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
const [modelName, setModelName] = useState('');
const [hasToolCalls, setHasToolCalls] = useState(false);
const isOnline = useMemo(() => contextLength > 0, [contextLength]);
const userPromptTemplate = useMemo(() => {
try {
return new Template(userPrompt)
} catch {
return {
render: () => userPrompt,
}
}
}, [userPrompt]);
const actions: IActions = useMemo(() => ({
compilePrompt: async (messages, { keepUsers, continueLast = false } = {}) => {
const lastMessage = messages.at(-1);
const lastMessageContent = MessageTools.getSwipe(lastMessage)?.content;
compilePrompt: async (messages, { keepUsers } = {}) => {
const promptMessages = messages.slice();
const lastMessage = promptMessages.at(-1);
const isAssistantLast = lastMessage?.role === 'assistant';
let isRegen = continueLast;
if (!isAssistantLast) {
isRegen = false;
} else if (!lastMessageContent) {
isRegen = true;
}
const isRegen = isAssistantLast && !MessageTools.getSwipe(lastMessage)?.content;
const isContinue = isAssistantLast && !isRegen;
const promptMessages = continueLast ? messages.slice(0, -1) : messages.slice();
if (isContinue) {
promptMessages.push(MessageTools.create(Huggingface.applyTemplate(userPrompt, {})));
promptMessages.push(MessageTools.create(userPromptTemplate.render({})));
}
const userMessages = promptMessages.filter(m => m.role === 'user');
@ -105,7 +104,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
} else if (role === 'user' && !message.technical) {
templateMessages.push({
role: message.role,
content: Huggingface.applyTemplate(userPrompt, { prompt: content, isStart: !wasStory }),
content: userPromptTemplate.render({ prompt: content, isStart: !wasStory }),
});
} else {
if (role === 'assistant') {
@ -129,17 +128,17 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
if (story.length > 0) {
const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
templateMessages.push({ role: 'user', content: Huggingface.applyTemplate(userPrompt, { prompt, isStart: true }) });
templateMessages.push({ role: 'user', content: userPromptTemplate.render({ prompt, isStart: true }) });
templateMessages.push({ role: 'assistant', content: story });
}
let userMessage = MessageTools.getSwipe(lastUserMessage)?.content;
if (!lastUserMessage?.technical && !isContinue && userMessage) {
userMessage = Huggingface.applyTemplate(userPrompt, { prompt: userMessage, isStart: story.length === 0 });
let userPrompt = MessageTools.getSwipe(lastUserMessage)?.content;
if (!lastUserMessage?.technical && !isContinue && userPrompt) {
userPrompt = userPromptTemplate.render({ prompt: userPrompt, isStart: story.length === 0 });
}
if (userMessage) {
templateMessages.push({ role: 'user', content: userMessage });
if (userPrompt) {
templateMessages.push({ role: 'user', content: userPrompt });
}
}
@ -148,18 +147,13 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
templateMessages.splice(1, 0, {
role: 'user',
content: Huggingface.applyTemplate(userPrompt, { prompt, isStart: true }),
content: userPromptTemplate.render({ prompt, isStart: true }),
});
}
templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`;
let prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages);
if (isRegen) {
prompt += lastMessageContent;
}
const prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages);
return {
prompt,
isContinue,
@ -202,21 +196,18 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
stopGeneration: () => {
Connection.stopGeneration();
},
}), [connection, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt]);
}), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);
useAsyncEffect(async () => {
if (isOnline && triggerNext && !generating.value) {
if (triggerNext && !generating.value) {
setTriggerNext(false);
setContinueLast(false);
let messageId = messages.length - 1;
let text = '';
let text: string = '';
const { prompt, isRegen } = await actions.compilePrompt(messages, { continueLast });
const { prompt, isRegen } = await actions.compilePrompt(messages);
if (isRegen) {
text = MessageTools.getSwipe(messages.at(-1))?.content ?? '';
} else {
if (!isRegen) {
addMessage('', 'assistant');
messageId++;
}
@ -236,10 +227,10 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
MessageTools.playReady();
}
}, [triggerNext, isOnline]);
}, [triggerNext]);
useAsyncEffect(async () => {
if (isOnline && summaryEnabled && !processing.summarizing) {
if (summaryEnabled && !processing.summarizing) {
try {
processing.summarizing = true;
for (let id = 0; id < messages.length; id++) {
@ -256,7 +247,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
processing.summarizing = false;
}
}
}, [messages, summaryEnabled, isOnline]);
}, [messages, summaryEnabled]);
useEffect(throttle(() => {
Connection.getContextLength(connection).then(setContextLength);
@ -264,7 +255,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
}, 1000, true), [connection]);
const calculateTokens = useCallback(throttle(async () => {
if (isOnline && !processing.tokenizing && !generating.value) {
if (!processing.tokenizing && !generating.value) {
try {
processing.tokenizing = true;
const { prompt } = await actions.compilePrompt(messages);
@ -276,11 +267,11 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
processing.tokenizing = false;
}
}
}, 1000, true), [actions, messages, isOnline]);
}, 1000, true), [actions, messages]);
useEffect(() => {
calculateTokens();
}, [messages, connection, systemPrompt, lore, userPrompt, isOnline]);
}, [messages, connection, systemPrompt, lore, userPrompt]);
useEffect(() => {
try {
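
The `userPromptTemplate` memo introduced in this file compiles the user prompt with `@huggingface/jinja` once per change and falls back to a pass-through renderer when the prompt is not valid Jinja, presumably so a malformed prompt cannot break generation. A standalone sketch of that pattern (the helper name is hypothetical):

```ts
import { Template } from "@huggingface/jinja";

// Hypothetical helper mirroring the userPromptTemplate memo:
// compile once, and fall back to returning the raw string on a parse error.
const compilePromptTemplate = (source: string) => {
  try {
    return new Template(source);
  } catch {
    return { render: (_context: Record<string, unknown>) => source };
  }
};

const template = compilePromptTemplate('Next: {{ prompt | trim }}');
console.log(template.render({ prompt: '  the heroes reach the gate  ' }));
// -> "Next: the heroes reach the gate"
```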

View File

@ -1,8 +1,8 @@
import { createContext } from "preact";
import { useCallback, useEffect, useMemo, useState } from "preact/hooks";
import { MessageTools, type IMessage } from "../tools/messages";
import { MessageTools, type IMessage } from "../messages";
import { useInputState } from "@common/hooks/useInputState";
import { type IConnection } from "../tools/connection";
import { type IConnection } from "../connection";
interface IContext {
currentConnection: number;
@ -15,9 +15,7 @@ interface IContext {
summaryEnabled: boolean;
bannedWords: string[];
messages: IMessage[];
//
triggerNext: boolean;
continueLast: boolean;
}
interface IComputableContext {
@ -35,10 +33,8 @@ interface IActions {
setUserPrompt: (prompt: string | Event) => void;
setSummarizePrompt: (prompt: string | Event) => void;
setBannedWords: (words: string[]) => void;
setSummaryEnabled: (summaryEnabled: boolean) => void;
setTriggerNext: (triggerNext: boolean) => void;
setContinueLast: (continueLast: boolean) => void;
setSummaryEnabled: (summaryEnabled: boolean) => void;
setMessages: (messages: IMessage[]) => void;
addMessage: (content: string, role: IMessage['role'], triggerNext?: boolean) => void;
@ -48,7 +44,7 @@ interface IActions {
setCurrentSwipe: (index: number, swipe: number) => void;
addSwipe: (index: number, content: string) => void;
continueMessage: (continueLast?: boolean) => void;
continueMessage: () => void;
}
const SAVE_KEY = 'ai_game_save_state';
@ -83,7 +79,7 @@ Continue the story forward.
{%- endif %}
{% if prompt -%}
This is the description of what should happen next in your answer: {{ prompt | trim }}
This is the description of What should happen next in your answer: {{ prompt | trim }}
{% endif %}
Remember that this story should be infinite and go forever.
Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`,
@ -92,13 +88,11 @@ Make sure to follow the world description and rules exactly. Avoid cliffhangers
bannedWords: [],
messages: [],
triggerNext: false,
continueLast: false,
};
export const saveContext = (context: IContext) => {
const contextToSave: Partial<IContext> = { ...context };
delete contextToSave.triggerNext;
delete contextToSave.continueLast;
localStorage.setItem(SAVE_KEY, JSON.stringify(contextToSave));
}
@ -136,7 +130,6 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
const connection = availableConnections[currentConnection] ?? DEFAULT_CONTEXT.availableConnections[0];
const [triggerNext, setTriggerNext] = useState(false);
const [continueLast, setContinueLast] = useState(false);
const [instruct, setInstruct] = useInputState(connection.instruct);
const setConnection = useCallback((c: IConnection) => {
@ -160,11 +153,8 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
setUserPrompt,
setSummarizePrompt,
setLore,
setSummaryEnabled,
setTriggerNext,
setContinueLast,
setSummaryEnabled,
setBannedWords: (words) => setBannedWords(words.slice()),
setAvailableConnections: (connections) => setAvailableConnections(connections.slice()),
@ -234,10 +224,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
}
)
),
continueMessage: (c = false) => {
setTriggerNext(true);
setContinueLast(c);
},
continueMessage: () => setTriggerNext(true),
}), []);
const rawContext: IContext & IComputableContext = {
@ -252,9 +239,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
summaryEnabled,
bannedWords,
messages,
//
triggerNext,
continueLast,
};
const context = useMemo(() => rawContext, Object.values(rawContext));

View File

@ -1,8 +1,7 @@
import { gguf } from '@huggingface/gguf';
import * as hub from '@huggingface/hub';
import { Template } from '@huggingface/jinja';
import { AutoTokenizer, PreTrainedTokenizer } from '@huggingface/transformers';
import { normalizeModel } from './model';
import { normalizeModel } from './connection';
export namespace Huggingface {
export interface ITemplateMessage {
@ -82,7 +81,6 @@ export namespace Huggingface {
const templateCache: Record<string, string> = loadCache();
const compiledTemplates = new Map<string, Template>();
const tokenizerCache = new Map<string, PreTrainedTokenizer | null>();
const hasField = <T extends string>(obj: unknown, field: T): obj is Record<T, unknown> => (
obj != null && typeof obj === 'object' && (field in obj)
@ -94,13 +92,13 @@ export namespace Huggingface {
);
const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise<TokenizerConfig | null> => {
modelName = normalizeModel(modelName);
console.log(`[huggingface] searching config for '${modelName}'`);
const searchModel = normalizeModel(modelName);
const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName }, additionalFields: ['config'] }));
const hubModels = await Array.fromAsync(hub.listModels({ search: { query: searchModel }, additionalFields: ['config'] }));
const models = hubModels.filter(m => {
if (m.gated) return false;
if (!normalizeModel(m.name).includes(modelName)) return false;
if (!normalizeModel(m.name).includes(searchModel)) return false;
return true;
}).sort((a, b) => b.downloads - a.downloads);
@ -118,8 +116,8 @@ export namespace Huggingface {
}
try {
console.log(`[huggingface] searching config in '${name}/tokenizer_config.json'`);
const fileResponse = await hub.downloadFile({ repo: name, path: 'tokenizer_config.json' });
console.log(`[huggingface] searching config in '${model.name}/tokenizer_config.json'`);
const fileResponse = await hub.downloadFile({ repo: model.name, path: 'tokenizer_config.json' });
if (fileResponse?.ok) {
const maybeConfig = await fileResponse.json();
if (isTokenizerConfig(maybeConfig)) {
@ -234,10 +232,10 @@ export namespace Huggingface {
}
export const findModelTemplate = async (modelName: string): Promise<string | null> => {
modelName = normalizeModel(modelName);
if (!modelName) return '';
const modelKey = modelName.toLowerCase().trim();
if (!modelKey) return '';
let template = templateCache[modelName] ?? null;
let template = templateCache[modelKey] ?? null;
if (template) {
console.log(`[huggingface] found cached template for '${modelName}'`);
@ -251,58 +249,18 @@ export namespace Huggingface {
if (config.bos_token) {
template = template
.replaceAll(config.bos_token, '')
.replace(/\{\{ ?(''|"") ?\}\}/g, '');
}
}
}
templateCache[modelName] = template;
templateCache[modelKey] = template;
saveCache(templateCache);
return template;
}
export const findTokenizer = async (modelName: string): Promise<PreTrainedTokenizer | null> => {
modelName = normalizeModel(modelName);
let tokenizer = tokenizerCache.get(modelName) ?? null;
if (tokenizer) {
return tokenizer;
} else if (!tokenizerCache.has(modelName)) {
console.log(`[huggingface] searching tokenizer for '${modelName}'`);
const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName } }));
const models = hubModels.filter(m => {
if (m.gated) return false;
if (m.name.toLowerCase().includes('gguf')) return false;
if (!normalizeModel(m.name).includes(modelName)) return false;
return true;
});
for (const model of models) {
const { name } = model;
try {
console.log(`[huggingface] searching tokenizer in '${name}'`);
tokenizer = await AutoTokenizer.from_pretrained(name);
break;
} catch { }
}
}
tokenizerCache.set(modelName, tokenizer);
if (tokenizer) {
console.log(`[huggingface] found tokenizer for '${modelName}'`);
} else {
console.log(`[huggingface] not found tokenizer for '${modelName}'`);
}
return tokenizer;
}
export const applyChatTemplate = (templateString: string, messages: ITemplateMessage[], functions?: IFunction[]) => (
applyTemplate(templateString, {
messages,
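
One detail of the `findModelTemplate` change in this file: when the tokenizer config declares a `bos_token`, the downloaded chat template has that token removed, and any `{{ '' }}` expressions the removal leaves behind are stripped before the template is cached under its lowercased model key. A minimal illustration with a made-up template string:

```ts
// Illustrative only: remove a literal BOS token from a chat template,
// then drop the empty {{ '' }} expressions that removal can leave behind.
const stripBosToken = (template: string, bosToken: string): string =>
  template
    .replaceAll(bosToken, '')
    .replace(/\{\{ ?(''|"") ?\}\}/g, '');

console.log(stripBosToken(
  "{{ '<s>' }}{% for m in messages %}{{ m.content }}{% endfor %}",
  '<s>',
));
// -> "{% for m in messages %}{{ m.content }}{% endfor %}"
```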

View File

@ -1,4 +1,5 @@
import messageSound from '../assets/message.mp3';
import { Template } from "@huggingface/jinja";
import messageSound from './assets/message.mp3';
export interface ISwipe {
content: string;

View File

@ -1,27 +0,0 @@
export const normalizeModel = (model: string) => {
let currentModel = model.split(/[\\\/]/).at(-1);
currentModel = currentModel.split('::').at(0).toLowerCase();
let normalizedModel: string;
do {
normalizedModel = currentModel;
currentModel = currentModel
.replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
.replace(/[ ._-](gptq|awq|exl2?|imat|i\d|h\d)/i, '') // remove quant name
.replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
.replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
.replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
.replace(/[ ._-]f(p|loat)?(8|16|32)/i, '')
.replace(/^(debug-?)+/i, '')
.trim();
} while (normalizedModel !== currentModel);
return normalizedModel
.replace(/[ _-]+/ig, '-')
.replace(/\.{2,}/, '-')
.replace(/[ ._-]+$/ig, '')
.trim();
}
export const approximateTokens = (prompt: string): number => Math.round(prompt.length / 4);
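
The deleted `model.ts` above estimated tokens with a characters-per-token heuristic, while the replacement in `connection.ts` counts runs of alphanumeric characters instead. Both are rough approximations; a side-by-side sketch for comparison (the sample string is taken from the default prompt earlier in this diff):

```ts
// Old heuristic (deleted model.ts): roughly four characters per token.
const approximateTokensByChars = (prompt: string): number =>
  Math.round(prompt.length / 4);

// New heuristic (connection.ts): one token per alphanumeric run in the prompt.
const approximateTokensByWords = (prompt: string): number =>
  prompt.split(/[^a-z0-9]+/i).length;

const sample = 'Remember that this story should be infinite and go forever.';
console.log(approximateTokensByChars(sample)); // 15
console.log(approximateTokensByWords(sample)); // 11 (ten words plus one trailing empty split)
```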