1
0
Fork 0

Compare commits

...

2 Commits

Author SHA1 Message Date
Pabloader 277b315795 AIStory: stopping 2024-11-12 16:32:52 +00:00
Pabloader 017ef7aaa5 AIStory: add basic horde support 2024-11-12 13:32:35 +00:00
16 changed files with 737 additions and 297 deletions

BIN
bun.lockb

Binary file not shown.

View File

@ -14,6 +14,7 @@
"@inquirer/select": "2.3.10", "@inquirer/select": "2.3.10",
"ace-builds": "1.36.3", "ace-builds": "1.36.3",
"classnames": "2.5.1", "classnames": "2.5.1",
"delay": "6.0.0",
"preact": "10.22.0" "preact": "10.22.0"
}, },
"devDependencies": { "devDependencies": {

View File

@ -0,0 +1,4 @@
import { useEffect } from "preact/hooks";
/**
 * useEffect variant for async effect bodies: invokes `fx` and discards the
 * returned promise, so the effect never hands a promise back to preact as a
 * cleanup function. Re-runs whenever `deps` change.
 */
export const useAsyncEffect = (fx: () => any, deps: any[]) => {
    useEffect(() => {
        void fx();
    }, deps);
};

View File

@ -0,0 +1,16 @@
import { useCallback } from "preact/hooks";
/**
 * Wraps a string-consuming callback so it can be used directly as a DOM event
 * handler. Accepts either a plain string or an Event whose target exposes a
 * string `value` (input/select/textarea); falls back to '' when no string can
 * be extracted. Memoized with useCallback over the caller-supplied deps.
 */
export function useInputCallback<T>(callback: (value: string) => T, deps: any[]): ((value: string | Event) => T) {
    return useCallback((input: Event | string) => {
        // Raw string: forward as-is.
        if (typeof input === 'string') {
            return callback(input);
        }
        // Event: pull the string value off the target when present.
        const { target } = input;
        if (target != null && 'value' in target && typeof target.value === 'string') {
            return callback(target.value);
        }
        // No usable value on the event — signal an empty input.
        return callback('');
    }, deps);
}

View File

@ -1,4 +1,3 @@
export const delay = async (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
export const nextFrame = async (): Promise<number> => new Promise((resolve) => requestAnimationFrame(resolve)); export const nextFrame = async (): Promise<number> => new Promise((resolve) => requestAnimationFrame(resolve));
export const randInt = (min: number, max: number) => Math.round(min + (max - min - 1) * Math.random()); export const randInt = (min: number, max: number) => Math.round(min + (max - min - 1) * Math.random());
@ -50,3 +49,32 @@ export const intHash = (seed: number, ...parts: number[]) => {
return h1; return h1;
}; };
export const sinHash = (...data: number[]) => data.reduce((hash, n) => Math.sin((hash * 123.12 + n) * 756.12), 0) / 2 + 0.5; export const sinHash = (...data: number[]) => data.reduce((hash, n) => Math.sin((hash * 123.12 + n) * 756.12), 0) / 2 + 0.5;
/**
 * Rate-limit `func` to at most one real invocation per `ms` window.
 * The first call in a window runs immediately (leading edge); calls during the
 * cooldown return the last real result. With `trailing` set, the most recent
 * suppressed call is replayed (with its `this` and args) when the window ends.
 */
export const throttle = function <T, A extends unknown[], R, F extends (this: T, ...args: A) => R>(func: F, ms: number, trailing = false): F {
    let cooling = false;           // inside the cooldown window?
    let lastResult: R;             // result of the most recent real invocation
    let pendingThis: T;            // context captured for an optional trailing call
    let pendingArgs: A | undefined; // args captured for an optional trailing call

    const throttled: F = function (...callArgs: A) {
        if (!cooling) {
            lastResult = func.apply(this, callArgs);
            pendingArgs = undefined;
            cooling = true;
            setTimeout(() => {
                cooling = false;
                if (trailing && pendingArgs) {
                    // Re-enter through the wrapper so a fresh window is armed.
                    lastResult = throttled.apply(pendingThis, pendingArgs);
                }
            }, ms);
        } else {
            pendingThis = this;
            pendingArgs = callArgs;
        }
        return lastResult;
    } as F;
    return throttled;
}

View File

@ -32,6 +32,10 @@ select {
outline: none; outline: none;
} }
option, optgroup {
background-color: var(--backgroundColor);
}
textarea { textarea {
resize: vertical; resize: vertical;
width: 100%; width: 100%;

View File

@ -15,7 +15,7 @@ export const Chat = () => {
const lastAssistantId = messages.findLastIndex(m => m.role === 'assistant'); const lastAssistantId = messages.findLastIndex(m => m.role === 'assistant');
useEffect(() => { useEffect(() => {
DOMTools.scrollDown(chatRef.current); setTimeout(() => DOMTools.scrollDown(chatRef.current, false), 100);
}, [messages.length, lastMessageContent]); }, [messages.length, lastMessageContent]);
return ( return (

View File

@ -0,0 +1,146 @@
import { useCallback, useContext, useEffect, useMemo, useState } from 'preact/hooks';
import styles from './header.module.css';
import { Connection, HORDE_ANON_KEY, isHordeConnection, isKoboldConnection, type IConnection, type IHordeModel } from '../../connection';
import { Instruct, StateContext } from '../../contexts/state';
import { useInputState } from '@common/hooks/useInputState';
import { useInputCallback } from '@common/hooks/useInputCallback';
import { Huggingface } from '../../huggingface';
// Props: the current connection plus a setter that commits a replacement
// connection object to application state.
interface IProps {
    connection: IConnection;
    setConnection: (c: IConnection) => void;
}
// Editor UI for the LLM backend connection: backend type (kobold/horde),
// instruct template, and backend-specific fields (server URL, or horde API
// key + model). Local input state is committed via setConnection on
// change/blur rather than on every keystroke.
export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
    // Local (uncommitted) copies of the editable fields.
    const [connectionUrl, setConnectionUrl] = useInputState('');
    const [apiKey, setApiKey] = useInputState(HORDE_ANON_KEY);
    const [modelName, setModelName] = useInputState('');
    const [modelTemplate, setModelTemplate] = useInputState(Instruct.CHATML);
    const [hordeModels, setHordeModels] = useState<IHordeModel[]>([]);
    const [contextLength, setContextLength] = useState<number>(0);
    // Derive the backend discriminator from the connection object's shape.
    const backendType = useMemo(() => {
        if (isKoboldConnection(connection)) return 'kobold';
        if (isHordeConnection(connection)) return 'horde';
        return 'unknown';
    }, [connection]);
    // The URL is treated as valid once the backend reports a context length.
    const urlValid = useMemo(() => contextLength > 0, [contextLength]);
    // Sync local fields (and, for horde, the model list) when the connection changes.
    useEffect(() => {
        if (isKoboldConnection(connection)) {
            setConnectionUrl(connection.url);
        } else if (isHordeConnection(connection)) {
            setModelName(connection.model);
            setApiKey(connection.apiKey || HORDE_ANON_KEY);
            Connection.getHordeModels()
                .then(m => setHordeModels(Array.from(m.values()).sort((a, b) => a.name.localeCompare(b.name))));
        }
        Connection.getContextLength(connection).then(setContextLength);
        // NOTE(review): for horde this overwrites the setModelName() above once
        // resolved; both should yield connection.model, so values should agree.
        Connection.getModelName(connection).then(setModelName);
    }, [connection]);
    // Resolve the model's native instruct template whenever the name changes.
    useEffect(() => {
        if (modelName) {
            Huggingface.findModelTemplate(modelName)
                .then(template => {
                    if (template) {
                        setModelTemplate(template);
                    }
                });
        }
    }, [modelName]);
    // Commit a template change, keeping the rest of the connection intact.
    const setInstruct = useInputCallback((instruct) => {
        setConnection({ ...connection, instruct });
    }, [connection, setConnection]);
    // Switching backend type rebuilds the connection object from the local
    // fields, carrying over only the instruct template.
    const setBackendType = useInputCallback((type) => {
        if (type === 'kobold') {
            setConnection({
                instruct: connection.instruct,
                url: connectionUrl,
            });
        } else if (type === 'horde') {
            setConnection({
                instruct: connection.instruct,
                apiKey,
                model: modelName,
            });
        }
    }, [connection, setConnection, connectionUrl, apiKey, modelName]);
    // Normalize the typed URL (prepend http:// when missing, strip one
    // trailing slash) and commit it on blur.
    const handleBlurUrl = useCallback(() => {
        const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i;
        const url = connectionUrl.replace(regex, 'http$1://$2');
        setConnection({
            instruct: connection.instruct,
            url,
        });
    }, [connection, connectionUrl, setConnection]);
    // Commit the horde key/model on blur.
    const handleBlurHorde = useCallback(() => {
        setConnection({
            instruct: connection.instruct,
            apiKey,
            model: modelName,
        });
    }, [connection, apiKey, modelName, setConnection]);
    return (
        <div class={styles.connectionEditor}>
            <select value={backendType} onChange={setBackendType}>
                <option value='kobold'>Kobold CPP</option>
                <option value='horde'>Horde</option>
            </select>
            {/* Template selector: native template (when detected), built-ins, or the current custom text. */}
            <select value={connection.instruct} onChange={setInstruct} title='Instruct template'>
                {modelName && modelTemplate && <optgroup label='Native model template'>
                    <option value={modelTemplate} title='Native for model'>{modelName}</option>
                </optgroup>}
                <optgroup label='Manual templates'>
                    {Object.entries(Instruct).map(([label, value]) => (
                        <option value={value} key={value}>
                            {label.toLowerCase()}
                        </option>
                    ))}
                </optgroup>
                <optgroup label='Custom'>
                    <option value={connection.instruct}>Custom</option>
                </optgroup>
            </select>
            {isKoboldConnection(connection) && <input
                value={connectionUrl}
                onInput={setConnectionUrl}
                onBlur={handleBlurUrl}
                class={urlValid ? styles.valid : styles.invalid}
            />}
            {isHordeConnection(connection) && <>
                <input
                    placeholder='Horde API key'
                    title='Horde API key'
                    value={apiKey}
                    onInput={setApiKey}
                    onBlur={handleBlurHorde}
                />
                <select
                    value={modelName}
                    onChange={setModelName}
                    onBlur={handleBlurHorde}
                    title='Horde model'
                >
                    {hordeModels.map((m) => (
                        <option value={m.name} key={m.name}>
                            {m.name} ({m.maxLength}/{m.maxContext})
                        </option>
                    ))}
                </select>
            </>}
        </div>
    );
};

View File

@ -45,3 +45,10 @@
overflow: hidden; overflow: hidden;
} }
} }
.connectionEditor {
display: flex;
flex-direction: row;
gap: 8px;
flex-wrap: wrap;
}

View File

@ -2,35 +2,29 @@ import { useCallback, useContext, useEffect, useMemo, useState } from "preact/ho
import { useBool } from "@common/hooks/useBool"; import { useBool } from "@common/hooks/useBool";
import { Modal } from "@common/components/modal/modal"; import { Modal } from "@common/components/modal/modal";
import { Instruct, StateContext } from "../../contexts/state"; import { StateContext } from "../../contexts/state";
import { LLMContext } from "../../contexts/llm"; import { LLMContext } from "../../contexts/llm";
import { MiniChat } from "../minichat/minichat"; import { MiniChat } from "../minichat/minichat";
import { AutoTextarea } from "../autoTextarea"; import { AutoTextarea } from "../autoTextarea";
import { Ace } from "../ace";
import { ConnectionEditor } from "./connectionEditor";
import styles from './header.module.css'; import styles from './header.module.css';
import { Ace } from "../ace";
export const Header = () => { export const Header = () => {
const { modelName, modelTemplate, contextLength, promptTokens, blockConnection } = useContext(LLMContext); const { contextLength, promptTokens, modelName } = useContext(LLMContext);
const { const {
messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct, summarizePrompt, summaryEnabled, messages, connection, systemPrompt, lore, userPrompt, bannedWords, summarizePrompt, summaryEnabled,
setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct, setSummarizePrompt, setSummaryEnabled, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct, setSummarizePrompt, setSummaryEnabled, setConnection,
} = useContext(StateContext); } = useContext(StateContext);
const connectionsOpen = useBool();
const loreOpen = useBool(); const loreOpen = useBool();
const promptsOpen = useBool(); const promptsOpen = useBool();
const genparamsOpen = useBool(); const genparamsOpen = useBool();
const assistantOpen = useBool(); const assistantOpen = useBool();
const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]); const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]);
const urlValid = useMemo(() => contextLength > 0, [contextLength]);
const handleBlurUrl = useCallback(() => {
const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i
const normalizedConnectionUrl = connectionUrl.replace(regex, 'http$1://$2');
setConnectionUrl(normalizedConnectionUrl);
blockConnection.setFalse();
}, [connectionUrl, setConnectionUrl, blockConnection]);
const handleAssistantAddSwipe = useCallback((answer: string) => { const handleAssistantAddSwipe = useCallback((answer: string) => {
const index = messages.findLastIndex(m => m.role === 'assistant'); const index = messages.findLastIndex(m => m.role === 'assistant');
@ -61,29 +55,13 @@ export const Header = () => {
return ( return (
<div class={styles.header}> <div class={styles.header}>
<div class={styles.inputs}> <div class={styles.inputs}>
<input value={connectionUrl} <div class={styles.buttons}>
onInput={setConnectionUrl} <button class='icon' onClick={connectionsOpen.setTrue} title='Connection settings'>
onFocus={blockConnection.setTrue} 🔌
onBlur={handleBlurUrl} </button>
class={blockConnection.value ? '' : urlValid ? styles.valid : styles.invalid} </div>
/>
<select value={instruct} onChange={setInstruct} title='Instruct template'>
{modelName && modelTemplate && <optgroup label='Native model template'>
<option value={modelTemplate} title='Native for model'>{modelName}</option>
</optgroup>}
<optgroup label='Manual templates'>
{Object.entries(Instruct).map(([label, value]) => (
<option value={value} key={value}>
{label.toLowerCase()}
</option>
))}
</optgroup>
<optgroup label='Custom'>
<option value={instruct}>Custom</option>
</optgroup>
</select>
<div class={styles.info}> <div class={styles.info}>
{promptTokens} / {contextLength} {modelName} - {promptTokens} / {contextLength}
</div> </div>
</div> </div>
<div class={styles.buttons}> <div class={styles.buttons}>
@ -102,6 +80,10 @@ export const Header = () => {
</button> </button>
</div> </div>
<Modal open={connectionsOpen.value} onClose={connectionsOpen.setFalse}>
<h3 class={styles.modalTitle}>Connection settings</h3>
<ConnectionEditor connection={connection} setConnection={setConnection} />
</Modal>
<Modal open={loreOpen.value} onClose={loreOpen.setFalse}> <Modal open={loreOpen.value} onClose={loreOpen.setFalse}>
<h3 class={styles.modalTitle}>Lore Editor</h3> <h3 class={styles.modalTitle}>Lore Editor</h3>
<AutoTextarea <AutoTextarea
@ -140,7 +122,7 @@ export const Header = () => {
</label> </label>
<hr /> <hr />
<h4 class={styles.modalTitle}>Instruct template</h4> <h4 class={styles.modalTitle}>Instruct template</h4>
<Ace value={instruct} onInput={setInstruct} /> <Ace value={connection.instruct} onInput={setInstruct} />
</div> </div>
</Modal> </Modal>
<MiniChat <MiniChat

View File

@ -5,7 +5,7 @@ import { AutoTextarea } from "./autoTextarea";
export const Input = () => { export const Input = () => {
const { input, setInput, addMessage, continueMessage } = useContext(StateContext); const { input, setInput, addMessage, continueMessage } = useContext(StateContext);
const { generating } = useContext(LLMContext); const { generating, stopGeneration } = useContext(LLMContext);
const handleSend = useCallback(async () => { const handleSend = useCallback(async () => {
if (!generating) { if (!generating) {
@ -29,7 +29,10 @@ export const Input = () => {
return ( return (
<div class="chat-input"> <div class="chat-input">
<AutoTextarea onInput={setInput} onKeyDown={handleKeyDown} value={input} /> <AutoTextarea onInput={setInput} onKeyDown={handleKeyDown} value={input} />
<button onClick={handleSend} class={`${generating ? 'disabled' : ''}`}>{input ? 'Send' : 'Continue'}</button> {generating
? <button onClick={stopGeneration}>Stop</button>
: <button onClick={handleSend}>{input ? 'Send' : 'Continue'}</button>
}
</div> </div>
); );
} }

View File

@ -16,7 +16,7 @@ interface IProps {
} }
export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) => { export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) => {
const { generating, generate, compilePrompt } = useContext(LLMContext); const { generating, stopGeneration, generate, compilePrompt } = useContext(LLMContext);
const [messages, setMessages] = useState<IMessage[]>([]); const [messages, setMessages] = useState<IMessage[]>([]);
const ref = useRef<HTMLDivElement>(null); const ref = useRef<HTMLDivElement>(null);
@ -105,9 +105,10 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
</div> </div>
</div> </div>
<div class={styles.buttons}> <div class={styles.buttons}>
<button onClick={handleGenerate} class={`${generating ? 'disabled' : ''}`}> {generating
Generate ? <button onClick={stopGeneration}>Stop</button>
</button> : <button onClick={handleGenerate}>Generate</button>
}
<button onClick={() => handleInit()} class={`${generating ? 'disabled' : ''}`}> <button onClick={() => handleInit()} class={`${generating ? 'disabled' : ''}`}>
Clear Clear
</button> </button>

376
src/games/ai/connection.ts Normal file
View File

@ -0,0 +1,376 @@
import Lock from "@common/lock";
import SSE from "@common/sse";
import { throttle } from "@common/utils";
import delay, { clearDelay } from "delay";
// Fields common to every backend connection.
interface IBaseConnection {
    // Instruct/chat template text applied when compiling prompts.
    instruct: string;
}
// Direct connection to a KoboldCPP-compatible server.
interface IKoboldConnection extends IBaseConnection {
    // Base URL of the server.
    url: string;
}
// Connection to the AI Horde distributed network.
interface IHordeConnection extends IBaseConnection {
    // Horde API key; the anonymous key is used when omitted.
    apiKey?: string;
    // Normalized model name to generate with.
    model: string;
}
// Type guard: a kobold connection is identified by a string `url` field.
export const isKoboldConnection = (obj: unknown): obj is IKoboldConnection => {
    if (obj == null || typeof obj !== 'object') {
        return false;
    }
    return 'url' in obj && typeof (obj as { url?: unknown }).url === 'string';
};
// Type guard: a horde connection is identified by a string `model` field.
export const isHordeConnection = (obj: unknown): obj is IHordeConnection => {
    if (obj == null || typeof obj !== 'object') {
        return false;
    }
    return 'model' in obj && typeof (obj as { model?: unknown }).model === 'string';
};
export type IConnection = IKoboldConnection | IHordeConnection;
// Worker entry as returned by the horde workers endpoint (?type=text).
interface IHordeWorker {
    id: string;
    // Raw (un-normalized) model names served by this worker.
    models: string[];
    flagged: boolean;
    online: boolean;
    maintenance_mode: boolean;
    max_context_length: number;
    max_length: number;
    // Numeric string; parsed with parseFloat when filtering workers.
    performance: string;
}
// A normalized model, aggregated from every worker that serves it.
export interface IHordeModel {
    // Normalized lower-case name (see normalizeModel).
    name: string;
    // Raw horde model names that map to this normalized entry.
    hordeNames: string[];
    // Most conservative per-reply token limit across the workers.
    maxLength: number;
    // Most conservative context window across the workers.
    maxContext: number;
    // Ids of the workers serving this model.
    workers: string[];
}
// Status payload of an async horde text-generation job.
interface IHordeResult {
    faulted: boolean;
    done: boolean;
    finished: number;
    generations?: {
        text: string;
    }[];
}
// Default sampler settings included in every request (KoboldCPP-style field
// names; the horde request reuses them inside `params`).
// NOTE(review): values appear tuned for story generation — confirm before
// reusing elsewhere.
const DEFAULT_GENERATION_SETTINGS = {
    temperature: 0.8,
    min_p: 0.1,
    rep_pen: 1.08,
    rep_pen_range: -1,
    rep_pen_slope: 0.7,
    top_k: 100,
    top_p: 0.92,
    banned_tokens: ['anticipat'],
    max_length: 300,
    trim_stop: true,
    stop_sequence: ['[INST]', '[/INST]', '</s>', '<|'],
    dry_allowed_length: 5,
    dry_multiplier: 0.8,
    dry_base: 1,
    dry_sequence_breakers: ["\n", ":", "\"", "*"],
    dry_penalty_last_n: 0
}
// Minimum parsed `performance` for a horde worker to be considered.
const MIN_PERFORMANCE = 2.0;
// Workers offering less context than this are skipped.
const MIN_WORKER_CONTEXT = 8192;
// Hard caps advertised for any horde model regardless of worker claims.
const MAX_HORDE_LENGTH = 512;
const MAX_HORDE_CONTEXT = 32000;
// AI Horde anonymous API key.
export const HORDE_ANON_KEY = '0000000000';
/**
 * Reduce a raw model identifier (possibly a file path or horde name) to a
 * canonical dash-separated name: strips directories, '::' suffixes, quant
 * names/sizes, gguf/ggml markers, bpw/float-precision suffixes and debug
 * prefixes, repeating until the name stabilizes.
 *
 * Fix: `split(...).at(-1)` / `.at(0)` are typed `string | undefined`, so the
 * subsequent `.split`/`.replace` calls fail under strictNullChecks. `split`
 * always yields at least one element, so indexing with a `??` fallback is
 * behavior-identical while satisfying the checker (and avoids ES2022 `.at`).
 */
export const normalizeModel = (model: string) => {
    const pathParts = model.split(/[\\\/]/);
    let currentModel = pathParts[pathParts.length - 1] ?? model;
    currentModel = currentModel.split('::')[0] ?? currentModel;
    let normalizedModel: string;
    do {
        normalizedModel = currentModel;
        currentModel = currentModel
            .replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
            .replace(/[ ._-](gptq|awq|exl2?|imat|i\d|h\d)/i, '') // remove quant name
            .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
            .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
            .replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
            .replace(/[ ._-]f(p|loat)?(8|16|32)/i, '') // remove float precision suffix
            .replace(/^(debug-?)+/i, '') // remove debug- prefixes
            .trim();
    } while (normalizedModel !== currentModel);
    // Final cleanup: unify separators and drop trailing punctuation.
    return normalizedModel
        .replace(/[ _-]+/ig, '-')
        .replace(/\.{2,}/, '-')
        .replace(/[ ._-]+$/ig, '')
        .trim();
}
// Rough token estimate used when no tokenizer endpoint is available:
// counts the chunks produced by splitting on non-alphanumeric runs.
export const approximateTokens = (prompt: string): number => {
    const chunks = prompt.split(/[^a-z0-9]+/i);
    return chunks.length;
};
export type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
export namespace Connection {
    const AIHORDE = 'https://aihorde.net';

    // Shared controller so stopGeneration() can cancel whichever request is in flight.
    let abortController = new AbortController();

    /**
     * Stream tokens from a KoboldCPP server over SSE.
     * Yields each token as it arrives; terminates when the server reports a
     * finish reason, the stream errors/closes, or stopGeneration() fires.
     */
    async function* generateKobold(url: string, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator<string> {
        const sse = new SSE(`${url}/api/extra/generate/stream`, {
            payload: JSON.stringify({
                ...DEFAULT_GENERATION_SETTINGS,
                ...extraSettings,
                prompt,
            }),
        });
        // Tokens buffered between SSE events and generator consumption.
        const messages: string[] = [];
        const messageLock = new Lock();
        let end = false;
        sse.addEventListener('message', (e) => {
            if (e.data) {
                const { token, finish_reason } = JSON.parse(e.data);
                messages.push(token);
                if (finish_reason && finish_reason !== 'null') {
                    end = true;
                }
            }
            messageLock.release();
        });
        const handleEnd = () => {
            end = true;
            messageLock.release();
        };
        // NOTE(review): this listener is never removed; it lingers on the
        // current controller until stopGeneration() replaces the controller.
        abortController.signal.addEventListener('abort', handleEnd);
        sse.addEventListener('error', handleEnd);
        sse.addEventListener('abort', handleEnd);
        sse.addEventListener('readystatechange', (e) => {
            if (e.readyState === SSE.CLOSED) handleEnd();
        });
        // Drain buffered tokens; block on the lock while the stream stays open.
        while (!end || messages.length) {
            while (messages.length > 0) {
                const message = messages.shift();
                if (message != null) {
                    try {
                        yield message;
                    } catch { }
                }
            }
            if (!end) {
                await messageLock.wait();
            }
        }
        sse.close();
    }

    /**
     * Run a generation on AI Horde: submit an async job, then poll its status
     * every 2.5s until text arrives or the request is aborted.
     * @throws when the model has no eligible workers or job submission fails.
     */
    async function generateHorde(connection: Omit<IHordeConnection, keyof IBaseConnection>, prompt: string, extraSettings: IGenerationSettings = {}): Promise<string> {
        const models = await getHordeModels();
        const model = models.get(connection.model);
        if (model) {
            // Respect both the workers' limit and the caller's requested length.
            let maxLength = Math.min(model.maxLength, DEFAULT_GENERATION_SETTINGS.max_length);
            if (extraSettings.max_length && extraSettings.max_length < maxLength) {
                maxLength = extraSettings.max_length;
            }
            const requestData = {
                prompt,
                params: {
                    ...DEFAULT_GENERATION_SETTINGS,
                    ...extraSettings,
                    n: 1,
                    max_context_length: model.maxContext,
                    max_length: maxLength,
                    rep_pen_range: Math.min(model.maxContext, 4096),
                },
                models: model.hordeNames,
                workers: model.workers,
            };
            const { signal } = abortController;
            const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, {
                method: 'POST',
                body: JSON.stringify(requestData),
                headers: {
                    'Content-Type': 'application/json',
                    apikey: connection.apiKey || HORDE_ANON_KEY,
                },
                signal,
            });
            if (!generateResponse.ok || generateResponse.status >= 400) {
                throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`);
            }
            const { id } = await generateResponse.json() as { id: string };
            // GET polls job status; DELETE cancels it (returning any partial text).
            const request = async (method = 'GET'): Promise<string | null> => {
                const response = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method });
                if (response.ok && response.status < 400) {
                    const result: IHordeResult = await response.json();
                    if (result.generations?.length === 1) {
                        const { text } = result.generations[0];
                        return text;
                    }
                } else {
                    throw new Error(await response.text());
                }
                return null;
            };
            const deleteRequest = async () => (await request('DELETE')) ?? '';
            while (true) {
                try {
                    // delay() rejects with AbortError when stopGeneration() fires.
                    await delay(2500, { signal });
                    const text = await request();
                    if (text) {
                        return text;
                    }
                } catch (e) {
                    console.error('Error in horde generation:', e);
                    // Cancel the job server-side and surface whatever was produced.
                    return deleteRequest();
                }
            }
        }
        throw new Error(`Model ${connection.model} is offline`);
    }

    /** Dispatch generation to the backend matching the connection type. */
    export async function* generate(connection: IConnection, prompt: string, extraSettings: IGenerationSettings = {}) {
        if (isKoboldConnection(connection)) {
            yield* generateKobold(connection.url, prompt, extraSettings);
        } else if (isHordeConnection(connection)) {
            // Horde is non-streaming: the whole completion arrives at once.
            yield await generateHorde(connection, prompt, extraSettings);
        }
    }

    /** Abort any in-flight generation and arm a fresh controller for the next one. */
    export function stopGeneration() {
        abortController.abort();
        abortController = new AbortController(); // refresh
    }

    /**
     * Fetch the online text workers and group them into models keyed by
     * normalized name, skipping flagged/offline/slow/small-context workers.
     * Returns an empty map on any network error.
     */
    async function requestHordeModels(): Promise<Map<string, IHordeModel>> {
        try {
            const response = await fetch(`${AIHORDE}/api/v2/workers?type=text`);
            if (response.ok) {
                const workers: IHordeWorker[] = await response.json();
                const goodWorkers = workers.filter(w =>
                    w.online
                    && !w.maintenance_mode
                    && !w.flagged
                    && w.max_context_length >= MIN_WORKER_CONTEXT
                    && parseFloat(w.performance) >= MIN_PERFORMANCE
                );
                const models = new Map<string, IHordeModel>();
                for (const worker of goodWorkers) {
                    for (const modelName of worker.models) {
                        const normName = normalizeModel(modelName.toLowerCase());
                        let model = models.get(normName);
                        if (!model) {
                            model = {
                                hordeNames: [],
                                maxContext: MAX_HORDE_CONTEXT,
                                maxLength: MAX_HORDE_LENGTH,
                                name: normName,
                                workers: []
                            }
                        }
                        if (!model.hordeNames.includes(modelName)) {
                            model.hordeNames.push(modelName);
                        }
                        if (!model.workers.includes(worker.id)) {
                            model.workers.push(worker.id);
                        }
                        // Advertise the most conservative limits across workers.
                        model.maxContext = Math.min(model.maxContext, worker.max_context_length);
                        model.maxLength = Math.min(model.maxLength, worker.max_length);
                        models.set(normName, model);
                    }
                }
                return models;
            }
        } catch (e) {
            console.error(e);
        }
        return new Map();
    };

    // Throttled: repeated calls within 10s return the cached promise instead
    // of re-hitting the workers endpoint.
    export const getHordeModels = throttle(requestHordeModels, 10000);

    /** Resolve a display name for the connected model ('' on failure). */
    export async function getModelName(connection: IConnection): Promise<string> {
        if (isKoboldConnection(connection)) {
            try {
                const response = await fetch(`${connection.url}/api/v1/model`);
                if (response.ok) {
                    const { result } = await response.json();
                    return result;
                }
            } catch (e) {
                console.error('Error getting max tokens', e);
            }
        } else if (isHordeConnection(connection)) {
            return connection.model;
        }
        return '';
    }

    /** Maximum context length supported by the backend (0 when unknown). */
    export async function getContextLength(connection: IConnection): Promise<number> {
        if (isKoboldConnection(connection)) {
            try {
                const response = await fetch(`${connection.url}/api/extra/true_max_context_length`);
                if (response.ok) {
                    const { value } = await response.json();
                    return value;
                }
            } catch (e) {
                console.error('Error getting max tokens', e);
            }
        } else if (isHordeConnection(connection)) {
            const models = await getHordeModels();
            const model = models.get(connection.model);
            if (model) {
                return model.maxContext;
            }
        }
        return 0;
    }

    /**
     * Count tokens in a prompt. Uses the kobold tokencount endpoint when the
     * connection supports it, otherwise a rough word-based approximation.
     * Fix: the Content-Type header was misspelled 'applicarion/json', which
     * can make the server reject or misparse the JSON body.
     */
    export async function countTokens(connection: IConnection, prompt: string) {
        if (isKoboldConnection(connection)) {
            try {
                const response = await fetch(`${connection.url}/api/extra/tokencount`, {
                    body: JSON.stringify({ prompt }),
                    headers: { 'Content-Type': 'application/json' },
                    method: 'POST',
                });
                if (response.ok) {
                    const { value } = await response.json();
                    return value;
                }
            } catch (e) {
                console.error('Error counting tokens', e);
            }
        }
        return approximateTokens(prompt);
    }
}

View File

@ -1,12 +1,13 @@
import Lock from "@common/lock";
import SSE from "@common/sse";
import { createContext } from "preact"; import { createContext } from "preact";
import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks"; import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
import { MessageTools, type IMessage } from "../messages"; import { MessageTools, type IMessage } from "../messages";
import { Instruct, StateContext } from "./state"; import { StateContext } from "./state";
import { useBool } from "@common/hooks/useBool"; import { useBool } from "@common/hooks/useBool";
import { Template } from "@huggingface/jinja"; import { Template } from "@huggingface/jinja";
import { Huggingface } from "../huggingface"; import { Huggingface } from "../huggingface";
import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection";
import { throttle } from "@common/utils";
import { useAsyncEffect } from "@common/hooks/useAsyncEffect";
interface ICompileArgs { interface ICompileArgs {
keepUsers?: number; keepUsers?: number;
@ -21,71 +22,23 @@ interface ICompiledPrompt {
interface IContext { interface IContext {
generating: boolean; generating: boolean;
blockConnection: ReturnType<typeof useBool>;
modelName: string; modelName: string;
modelTemplate: string;
hasToolCalls: boolean; hasToolCalls: boolean;
promptTokens: number; promptTokens: number;
contextLength: number; contextLength: number;
} }
const DEFAULT_GENERATION_SETTINGS = {
temperature: 0.8,
min_p: 0.1,
rep_pen: 1.08,
rep_pen_range: -1,
rep_pen_slope: 0.7,
top_k: 100,
top_p: 0.92,
banned_tokens: [],
max_length: 300,
trim_stop: true,
stop_sequence: ['[INST]', '[/INST]', '</s>', '<|'],
dry_allowed_length: 5,
dry_multiplier: 0.8,
dry_base: 1,
dry_sequence_breakers: ["\n", ":", "\"", "*"],
dry_penalty_last_n: 0
}
const MESSAGES_TO_KEEP = 10; const MESSAGES_TO_KEEP = 10;
type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
interface IActions { interface IActions {
compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>; compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>; generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
stopGeneration: () => void;
summarize: (content: string) => Promise<string>; summarize: (content: string) => Promise<string>;
countTokens: (prompt: string) => Promise<number>; countTokens: (prompt: string) => Promise<number>;
} }
export type ILLMContext = IContext & IActions; export type ILLMContext = IContext & IActions;
export const normalizeModel = (model: string) => {
let currentModel = model.split(/[\\\/]/).at(-1);
currentModel = currentModel.split('::').at(0);
let normalizedModel: string;
do {
normalizedModel = currentModel;
currentModel = currentModel
.replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
.replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name
.replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
.replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
.replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
.replace(/[ ._-]f(p|loat)?(8|16|32)/i, '')
.replace(/^(debug-?)+/i, '')
.trim();
} while (normalizedModel !== currentModel);
return normalizedModel
.replace(/[ _-]+/ig, '-')
.replace(/\.{2,}/, '-')
.replace(/[ ._-]+$/ig, '')
.trim();
}
export const LLMContext = createContext<ILLMContext>({} as ILLMContext); export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
const processing = { const processing = {
@ -95,16 +48,14 @@ const processing = {
export const LLMContextProvider = ({ children }: { children?: any }) => { export const LLMContextProvider = ({ children }: { children?: any }) => {
const { const {
connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct, summarizePrompt, summaryEnabled, connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
setTriggerNext, addMessage, editMessage, editSummary, setInstruct, setTriggerNext, addMessage, editMessage, editSummary,
} = useContext(StateContext); } = useContext(StateContext);
const generating = useBool(false); const generating = useBool(false);
const blockConnection = useBool(false);
const [promptTokens, setPromptTokens] = useState(0); const [promptTokens, setPromptTokens] = useState(0);
const [contextLength, setContextLength] = useState(0); const [contextLength, setContextLength] = useState(0);
const [modelName, setModelName] = useState(''); const [modelName, setModelName] = useState('');
const [modelTemplate, setModelTemplate] = useState('');
const [hasToolCalls, setHasToolCalls] = useState(false); const [hasToolCalls, setHasToolCalls] = useState(false);
const userPromptTemplate = useMemo(() => { const userPromptTemplate = useMemo(() => {
@ -117,40 +68,6 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
} }
}, [userPrompt]); }, [userPrompt]);
const getContextLength = useCallback(async () => {
if (!connectionUrl || blockConnection.value) {
return 0;
}
try {
const response = await fetch(`${connectionUrl}/api/extra/true_max_context_length`);
if (response.ok) {
const { value } = await response.json();
return value;
}
} catch (e) {
console.log('Error getting max tokens', e);
}
return 0;
}, [connectionUrl, blockConnection.value]);
const getModelName = useCallback(async () => {
if (!connectionUrl || blockConnection.value) {
return '';
}
try {
const response = await fetch(`${connectionUrl}/api/v1/model`);
if (response.ok) {
const { result } = await response.json();
return result;
}
} catch (e) {
console.log('Error getting max tokens', e);
}
return '';
}, [connectionUrl, blockConnection.value]);
const actions: IActions = useMemo(() => ({ const actions: IActions = useMemo(() => ({
compilePrompt: async (messages, { keepUsers } = {}) => { compilePrompt: async (messages, { keepUsers } = {}) => {
const promptMessages = messages.slice(); const promptMessages = messages.slice();
@ -236,7 +153,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`; templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`;
const prompt = Huggingface.applyChatTemplate(instruct, templateMessages); const prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages);
return { return {
prompt, prompt,
isContinue, isContinue,
@ -244,102 +161,44 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
}; };
}, },
generate: async function* (prompt, extraSettings = {}) { generate: async function* (prompt, extraSettings = {}) {
if (!connectionUrl) {
return;
}
try { try {
generating.setTrue();
console.log('[LLM.generate]', prompt); console.log('[LLM.generate]', prompt);
const sse = new SSE(`${connectionUrl}/api/extra/generate/stream`, { yield* Connection.generate(connection, prompt, {
payload: JSON.stringify({
...DEFAULT_GENERATION_SETTINGS,
banned_tokens: bannedWords.filter(w => w.trim()),
...extraSettings, ...extraSettings,
prompt, banned_tokens: bannedWords.filter(w => w.trim()),
}),
}); });
} catch (e) {
const messages: string[] = []; if (e instanceof Error && e.name !== 'AbortError') {
const messageLock = new Lock(); alert(e.message);
let end = false; } else {
console.error('[LLM.generate]', e);
sse.addEventListener('message', (e) => {
if (e.data) {
{
const { token, finish_reason } = JSON.parse(e.data);
messages.push(token);
if (finish_reason && finish_reason !== 'null') {
end = true;
} }
} }
}
messageLock.release();
});
const handleEnd = () => {
end = true;
messageLock.release();
};
sse.addEventListener('error', handleEnd);
sse.addEventListener('abort', handleEnd);
sse.addEventListener('readystatechange', (e) => {
if (e.readyState === SSE.CLOSED) handleEnd();
});
while (!end || messages.length) {
while (messages.length > 0) {
const message = messages.shift();
if (message != null) {
try {
yield message;
} catch { }
}
}
if (!end) {
await messageLock.wait();
}
}
sse.close();
} finally {
generating.setFalse();
}
}, },
summarize: async (message) => { summarize: async (message) => {
try {
const content = Huggingface.applyTemplate(summarizePrompt, { message }); const content = Huggingface.applyTemplate(summarizePrompt, { message });
const prompt = Huggingface.applyChatTemplate(instruct, [{ role: 'user', content }]); const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
console.log('[LLM.summarize]', prompt);
const tokens = await Array.fromAsync(actions.generate(prompt)); const tokens = await Array.fromAsync(Connection.generate(connection, prompt, {}));
return MessageTools.trimSentence(tokens.join('')); return MessageTools.trimSentence(tokens.join(''));
} catch (e) {
console.error('Error summarizing:', e);
return '';
}
}, },
countTokens: async (prompt) => { countTokens: async (prompt) => {
if (!connectionUrl) { return await Connection.countTokens(connection, prompt);
return 0;
}
try {
const response = await fetch(`${connectionUrl}/api/extra/tokencount`, {
body: JSON.stringify({ prompt }),
headers: { 'Content-Type': 'applicarion/json' },
method: 'POST',
});
if (response.ok) {
const { value } = await response.json();
return value;
}
} catch (e) {
console.log('Error counting tokens', e);
}
return 0;
}, },
}), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct, summarizePrompt]); stopGeneration: () => {
Connection.stopGeneration();
},
}), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);
useEffect(() => void (async () => { useAsyncEffect(async () => {
if (triggerNext && !generating.value) { if (triggerNext && !generating.value) {
setTriggerNext(false); setTriggerNext(false);
@ -353,12 +212,14 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
messageId++; messageId++;
} }
generating.setTrue();
editSummary(messageId, 'Generating...'); editSummary(messageId, 'Generating...');
for await (const chunk of actions.generate(prompt)) { for await (const chunk of actions.generate(prompt)) {
text += chunk; text += chunk;
setPromptTokens(promptTokens + Math.round(text.length * 0.25)); setPromptTokens(promptTokens + approximateTokens(text));
editMessage(messageId, text.trim()); editMessage(messageId, text.trim());
} }
generating.setFalse();
text = MessageTools.trimSentence(text); text = MessageTools.trimSentence(text);
editMessage(messageId, text); editMessage(messageId, text);
@ -366,10 +227,10 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
MessageTools.playReady(); MessageTools.playReady();
} }
})(), [triggerNext]); }, [triggerNext]);
useEffect(() => void (async () => { useAsyncEffect(async () => {
if (summaryEnabled && !generating.value && !processing.summarizing) { if (summaryEnabled && !processing.summarizing) {
try { try {
processing.summarizing = true; processing.summarizing = true;
for (let id = 0; id < messages.length; id++) { for (let id = 0; id < messages.length; id++) {
@ -386,36 +247,15 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
processing.summarizing = false; processing.summarizing = false;
} }
} }
})(), [messages]); }, [messages, summaryEnabled]);
useEffect(() => { useEffect(throttle(() => {
if (!blockConnection.value) { Connection.getContextLength(connection).then(setContextLength);
setPromptTokens(0); Connection.getModelName(connection).then(normalizeModel).then(setModelName);
setContextLength(0); }, 1000, true), [connection]);
setModelName('');
getContextLength().then(setContextLength); const calculateTokens = useCallback(throttle(async () => {
getModelName().then(normalizeModel).then(setModelName); if (!processing.tokenizing && !generating.value) {
}
}, [connectionUrl, blockConnection.value]);
useEffect(() => {
setModelTemplate('');
if (modelName) {
Huggingface.findModelTemplate(modelName)
.then((template) => {
if (template) {
setModelTemplate(template);
setInstruct(template);
} else {
setInstruct(Instruct.CHATML);
}
});
}
}, [modelName]);
const calculateTokens = useCallback(async () => {
if (!processing.tokenizing && !blockConnection.value && !generating.value) {
try { try {
processing.tokenizing = true; processing.tokenizing = true;
const { prompt } = await actions.compilePrompt(messages); const { prompt } = await actions.compilePrompt(messages);
@ -427,26 +267,24 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
processing.tokenizing = false; processing.tokenizing = false;
} }
} }
}, [actions, messages, blockConnection.value]); }, 1000, true), [actions, messages]);
useEffect(() => { useEffect(() => {
calculateTokens(); calculateTokens();
}, [messages, connectionUrl, blockConnection.value, instruct, /* systemPrompt, lore, userPrompt TODO debounce*/]); }, [messages, connection, systemPrompt, lore, userPrompt]);
useEffect(() => { useEffect(() => {
try { try {
const hasTools = Huggingface.testToolCalls(instruct); const hasTools = Huggingface.testToolCalls(connection.instruct);
setHasToolCalls(hasTools); setHasToolCalls(hasTools);
} catch { } catch {
setHasToolCalls(false); setHasToolCalls(false);
} }
}, [instruct]); }, [connection.instruct]);
const rawContext: IContext = { const rawContext: IContext = {
generating: generating.value, generating: generating.value,
blockConnection,
modelName, modelName,
modelTemplate,
hasToolCalls, hasToolCalls,
promptTokens, promptTokens,
contextLength, contextLength,

View File

@ -1,12 +1,13 @@
import { createContext } from "preact"; import { createContext } from "preact";
import { useEffect, useMemo, useState } from "preact/hooks"; import { useCallback, useEffect, useMemo, useState } from "preact/hooks";
import { MessageTools, type IMessage } from "../messages"; import { MessageTools, type IMessage } from "../messages";
import { useInputState } from "@common/hooks/useInputState"; import { useInputState } from "@common/hooks/useInputState";
import { type IConnection } from "../connection";
interface IContext { interface IContext {
connectionUrl: string; currentConnection: number;
availableConnections: IConnection[];
input: string; input: string;
instruct: string;
systemPrompt: string; systemPrompt: string;
lore: string; lore: string;
userPrompt: string; userPrompt: string;
@ -17,8 +18,14 @@ interface IContext {
triggerNext: boolean; triggerNext: boolean;
} }
interface IComputableContext {
connection: IConnection;
}
interface IActions { interface IActions {
setConnectionUrl: (url: string | Event) => void; setConnection: (connection: IConnection) => void;
setAvailableConnections: (connections: IConnection[]) => void;
setCurrentConnection: (connection: number) => void;
setInput: (url: string | Event) => void; setInput: (url: string | Event) => void;
setInstruct: (template: string | Event) => void; setInstruct: (template: string | Event) => void;
setLore: (lore: string | Event) => void; setLore: (lore: string | Event) => void;
@ -49,11 +56,40 @@ export enum Instruct {
MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + '</s>'}}{%- endif %}{%- endfor %}`, MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + '</s>'}}{%- endif %}{%- endfor %}`,
METHARME = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>' + message['content'] }}{% elif message['role'] == 'user' %}{{'<|user|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{'<|model|>' + message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% endif %}`,
GEMMA = `{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}`, GEMMA = `{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}`,
ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\\n\\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\\n\\n' + message['content'] + '\\n\\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\\n\\n' + message['content'] + '\\n\\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\\n\\n' }}{% endif %}`, ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\\n\\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\\n\\n' + message['content'] + '\\n\\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\\n\\n' + message['content'] + '\\n\\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\\n\\n' }}{% endif %}`,
}; };
const DEFAULT_CONTEXT: IContext = {
currentConnection: 0,
availableConnections: [{
url: 'http://localhost:5001',
instruct: Instruct.CHATML,
}],
input: '',
systemPrompt: 'You are a creative writer. Write a story based on the world description below. Story should be adult and mature; and could include swearing, violence and unfairness. Portray characters realistically and stay in the lore.',
lore: '',
userPrompt: `{% if isStart -%}
Write a novel using information above as a reference.
{%- else -%}
Continue the story forward.
{%- endif %}
{% if prompt -%}
What should happen next in your answer: {{ prompt | trim }}
{% endif %}
Remember that this story should be infinite and go forever.
Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`,
summarizePrompt: 'Shrink following text down to one paragraph, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
summaryEnabled: false,
bannedWords: [],
messages: [],
triggerNext: false,
};
export const saveContext = (context: IContext) => { export const saveContext = (context: IContext) => {
const contextToSave: Partial<IContext> = { ...context }; const contextToSave: Partial<IContext> = { ...context };
delete contextToSave.triggerNext; delete contextToSave.triggerNext;
@ -62,30 +98,6 @@ export const saveContext = (context: IContext) => {
} }
export const loadContext = (): IContext => { export const loadContext = (): IContext => {
const defaultContext: IContext = {
connectionUrl: 'http://localhost:5001',
input: '',
instruct: Instruct.CHATML,
systemPrompt: 'You are a creative writer. Write a story based on the world description below. Story should be adult and mature; and could include swearing, violence and unfairness. Portray characters realistically and stay in the lore.',
lore: '',
userPrompt: `{% if isStart -%}
Write a novel using information above as a reference.
{%- else -%}
Continue the story forward.
{%- endif %}
{% if prompt -%}
This is the description of what I want to happen next: {{ prompt | trim }}
{% endif %}
Remember that this story should be infinite and go forever.
Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`,
summarizePrompt: 'Shrink following text down, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
summaryEnabled: false,
bannedWords: [],
messages: [],
triggerNext: false,
};
let loadedContext: Partial<IContext> = {}; let loadedContext: Partial<IContext> = {};
try { try {
@ -95,18 +107,18 @@ Make sure to follow the world description and rules exactly. Avoid cliffhangers
} }
} catch { } } catch { }
return { ...defaultContext, ...loadedContext }; return { ...DEFAULT_CONTEXT, ...loadedContext };
} }
export type IStateContext = IContext & IActions; export type IStateContext = IContext & IActions & IComputableContext;
export const StateContext = createContext<IStateContext>({} as IStateContext); export const StateContext = createContext<IStateContext>({} as IStateContext);
export const StateContextProvider = ({ children }: { children?: any }) => { export const StateContextProvider = ({ children }: { children?: any }) => {
const loadedContext = useMemo(() => loadContext(), []); const loadedContext = useMemo(() => loadContext(), []);
const [connectionUrl, setConnectionUrl] = useInputState(loadedContext.connectionUrl); const [currentConnection, setCurrentConnection] = useState<number>(loadedContext.currentConnection);
const [availableConnections, setAvailableConnections] = useState<IConnection[]>(loadedContext.availableConnections);
const [input, setInput] = useInputState(loadedContext.input); const [input, setInput] = useInputState(loadedContext.input);
const [instruct, setInstruct] = useInputState(loadedContext.instruct);
const [lore, setLore] = useInputState(loadedContext.lore); const [lore, setLore] = useInputState(loadedContext.lore);
const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt); const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt);
const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt); const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt);
@ -115,10 +127,26 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
const [messages, setMessages] = useState(loadedContext.messages); const [messages, setMessages] = useState(loadedContext.messages);
const [summaryEnabled, setSummaryEnabled] = useState(loadedContext.summaryEnabled); const [summaryEnabled, setSummaryEnabled] = useState(loadedContext.summaryEnabled);
const connection = availableConnections[currentConnection] ?? DEFAULT_CONTEXT.availableConnections[0];
const [triggerNext, setTriggerNext] = useState(false); const [triggerNext, setTriggerNext] = useState(false);
const [instruct, setInstruct] = useInputState(connection.instruct);
const setConnection = useCallback((c: IConnection) => {
setAvailableConnections(availableConnections.map((ac, ai) => {
if (ai === currentConnection) {
return c;
} else {
return ac;
}
}));
}, [availableConnections, currentConnection]);
useEffect(() => setConnection({ ...connection, instruct }), [instruct]);
const actions: IActions = useMemo(() => ({ const actions: IActions = useMemo(() => ({
setConnectionUrl, setConnection,
setCurrentConnection,
setInput, setInput,
setInstruct, setInstruct,
setSystemPrompt, setSystemPrompt,
@ -127,7 +155,8 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
setLore, setLore,
setTriggerNext, setTriggerNext,
setSummaryEnabled, setSummaryEnabled,
setBannedWords: (words) => setBannedWords([...words]), setBannedWords: (words) => setBannedWords(words.slice()),
setAvailableConnections: (connections) => setAvailableConnections(connections.slice()),
setMessages: (newMessages) => setMessages(newMessages.slice()), setMessages: (newMessages) => setMessages(newMessages.slice()),
addMessage: (content, role, triggerNext = false) => { addMessage: (content, role, triggerNext = false) => {
@ -198,10 +227,11 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
continueMessage: () => setTriggerNext(true), continueMessage: () => setTriggerNext(true),
}), []); }), []);
const rawContext: IContext = { const rawContext: IContext & IComputableContext = {
connectionUrl, connection,
currentConnection,
availableConnections,
input, input,
instruct,
systemPrompt, systemPrompt,
lore, lore,
userPrompt, userPrompt,

View File

@ -1,6 +1,7 @@
import { gguf } from '@huggingface/gguf'; import { gguf } from '@huggingface/gguf';
import * as hub from '@huggingface/hub'; import * as hub from '@huggingface/hub';
import { Template } from '@huggingface/jinja'; import { Template } from '@huggingface/jinja';
import { normalizeModel } from './connection';
export namespace Huggingface { export namespace Huggingface {
export interface ITemplateMessage { export interface ITemplateMessage {
@ -92,11 +93,12 @@ export namespace Huggingface {
const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise<TokenizerConfig | null> => { const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise<TokenizerConfig | null> => {
console.log(`[huggingface] searching config for '${modelName}'`); console.log(`[huggingface] searching config for '${modelName}'`);
const searchModel = normalizeModel(modelName);
const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName }, additionalFields: ['config'] })); const hubModels = await Array.fromAsync(hub.listModels({ search: { query: searchModel }, additionalFields: ['config'] }));
const models = hubModels.filter(m => { const models = hubModels.filter(m => {
if (m.gated) return false; if (m.gated) return false;
if (!m.name.toLowerCase().includes(modelName.toLowerCase())) return false; if (!normalizeModel(m.name).includes(searchModel)) return false;
return true; return true;
}).sort((a, b) => b.downloads - a.downloads); }).sort((a, b) => b.downloads - a.downloads);
@ -230,7 +232,9 @@ export namespace Huggingface {
} }
export const findModelTemplate = async (modelName: string): Promise<string | null> => { export const findModelTemplate = async (modelName: string): Promise<string | null> => {
const modelKey = modelName.toLowerCase(); const modelKey = modelName.toLowerCase().trim();
if (!modelKey) return '';
let template = templateCache[modelKey] ?? null; let template = templateCache[modelKey] ?? null;
if (template) { if (template) {