1
0
Fork 0

AIStory: add basic horde support

This commit is contained in:
Pabloader 2024-11-12 13:32:35 +00:00
parent ece1621e73
commit 017ef7aaa5
11 changed files with 655 additions and 233 deletions

View File

@ -0,0 +1,16 @@
import { useCallback } from "preact/hooks";
/**
 * Adapts a string-consuming callback so it can also be wired directly to DOM
 * events: when invoked with an Event, the callback receives `target.value`
 * (or '' when the target has no string `value`); plain strings pass through.
 */
export function useInputCallback<T>(callback: (value: string) => T, deps: any[]): ((value: string | Event) => T) {
    return useCallback((input: Event | string) => {
        if (typeof input === 'string') {
            return callback(input);
        }
        const node = input.target;
        // Only trust the event target when it actually exposes a string value.
        const text = node && 'value' in node && typeof node.value === 'string' ? node.value : '';
        return callback(text);
    }, deps);
}

View File

@ -50,3 +50,22 @@ export const intHash = (seed: number, ...parts: number[]) => {
return h1; return h1;
}; };
export const sinHash = (...data: number[]) => data.reduce((hash, n) => Math.sin((hash * 123.12 + n) * 756.12), 0) / 2 + 0.5; export const sinHash = (...data: number[]) => data.reduce((hash, n) => Math.sin((hash * 123.12 + n) * 756.12), 0) / 2 + 0.5;
/**
 * Rate-limits `func`: the first call runs immediately, and for the next `ms`
 * milliseconds further calls are swallowed, returning the last computed
 * result instead. `this` and arguments are forwarded untouched.
 */
export const throttle = function <T, A extends unknown[], R, F extends (this: T, ...args: A) => R>(func: F, ms: number): F {
    let blocked = false;
    let lastResult: R;
    return function (this: T, ...args: A): R {
        if (blocked) {
            // Inside the cooldown window: reuse the cached result.
            return lastResult;
        }
        // Call first, then arm the cooldown — so a throwing call does not block.
        lastResult = func.apply(this, args);
        blocked = true;
        setTimeout(() => {
            blocked = false;
        }, ms);
        return lastResult;
    } as F;
}

View File

@ -32,6 +32,10 @@ select {
outline: none; outline: none;
} }
option, optgroup {
background-color: var(--backgroundColor);
}
textarea { textarea {
resize: vertical; resize: vertical;
width: 100%; width: 100%;

View File

@ -15,7 +15,7 @@ export const Chat = () => {
const lastAssistantId = messages.findLastIndex(m => m.role === 'assistant'); const lastAssistantId = messages.findLastIndex(m => m.role === 'assistant');
useEffect(() => { useEffect(() => {
DOMTools.scrollDown(chatRef.current); setTimeout(() => DOMTools.scrollDown(chatRef.current, false), 100);
}, [messages.length, lastMessageContent]); }, [messages.length, lastMessageContent]);
return ( return (

View File

@ -0,0 +1,146 @@
import { useCallback, useContext, useEffect, useMemo, useState } from 'preact/hooks';
import styles from './header.module.css';
import { Connection, HORDE_ANON_KEY, isHordeConnection, isKoboldConnection, type IConnection, type IHordeModel } from '../../connection';
import { Instruct, StateContext } from '../../contexts/state';
import { useInputState } from '@common/hooks/useInputState';
import { useInputCallback } from '@common/hooks/useInputCallback';
import { Huggingface } from '../../huggingface';
/** Props for the connection settings editor. */
interface IProps {
    // Currently active connection (Kobold or Horde shaped).
    connection: IConnection;
    // Callback invoked with a fully rebuilt connection object on every change.
    setConnection: (c: IConnection) => void;
}

/**
 * Editor UI for backend connection settings: lets the user switch between a
 * Kobold CPP URL backend and an AI Horde backend, pick an instruct template,
 * and (for Horde) enter an API key and choose a model from the live list.
 * All edits are committed upward via `setConnection`.
 */
export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
    // Local form state, seeded from `connection` by the effect below.
    const [connectionUrl, setConnectionUrl] = useInputState('');
    const [apiKey, setApiKey] = useInputState(HORDE_ANON_KEY);
    const [modelName, setModelName] = useInputState('');
    const [modelTemplate, setModelTemplate] = useInputState(Instruct.CHATML);
    const [hordeModels, setHordeModels] = useState<IHordeModel[]>([]);
    const [contextLength, setContextLength] = useState<number>(0);
    // Discriminates the connection shape for the backend <select>.
    const backendType = useMemo(() => {
        if (isKoboldConnection(connection)) return 'kobold';
        if (isHordeConnection(connection)) return 'horde';
        return 'unknown';
    }, [connection]);
    // A reachable backend reports a positive context length; used to color the URL input.
    const urlValid = useMemo(() => contextLength > 0, [contextLength]);
    // Sync the form from the connection object, and (for Horde) fetch the
    // available model list sorted by name.
    useEffect(() => {
        if (isKoboldConnection(connection)) {
            setConnectionUrl(connection.url);
        } else if (isHordeConnection(connection)) {
            setModelName(connection.model);
            setApiKey(connection.apiKey || HORDE_ANON_KEY);
            Connection.getHordeModels()
                .then(m => setHordeModels(Array.from(m.values()).sort((a, b) => a.name.localeCompare(b.name))));
        }
        Connection.getContextLength(connection).then(setContextLength);
        Connection.getModelName(connection).then(setModelName);
    }, [connection]);
    // Look up the model's native chat template whenever the model name changes.
    useEffect(() => {
        if (modelName) {
            Huggingface.findModelTemplate(modelName)
                .then(template => {
                    if (template) {
                        setModelTemplate(template);
                    }
                });
        }
    }, [modelName]);
    // Commit a new instruct template, keeping the rest of the connection.
    const setInstruct = useInputCallback((instruct) => {
        setConnection({ ...connection, instruct });
    }, [connection, setConnection]);
    // Switch backend kind, rebuilding the connection from current form state.
    const setBackendType = useInputCallback((type) => {
        if (type === 'kobold') {
            setConnection({
                instruct: connection.instruct,
                url: connectionUrl,
            });
        } else if (type === 'horde') {
            setConnection({
                instruct: connection.instruct,
                apiKey,
                model: modelName,
            });
        }
    }, [connection, setConnection, connectionUrl, apiKey, modelName]);
    // On blur, normalize the URL (ensure scheme, drop trailing slash) and commit.
    const handleBlurUrl = useCallback(() => {
        const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i;
        const url = connectionUrl.replace(regex, 'http$1://$2');
        setConnection({
            instruct: connection.instruct,
            url,
        });
    }, [connection, connectionUrl, setConnection]);
    // Commit current Horde key/model when a Horde field loses focus.
    const handleBlurHorde = useCallback(() => {
        setConnection({
            instruct: connection.instruct,
            apiKey,
            model: modelName,
        });
    }, [connection, apiKey, modelName, setConnection]);
    return (
        <div class={styles.connectionEditor}>
            <select value={backendType} onChange={setBackendType}>
                <option value='kobold'>Kobold CPP</option>
                <option value='horde'>Horde</option>
            </select>
            <select value={connection.instruct} onChange={setInstruct} title='Instruct template'>
                {modelName && modelTemplate && <optgroup label='Native model template'>
                    <option value={modelTemplate} title='Native for model'>{modelName}</option>
                </optgroup>}
                <optgroup label='Manual templates'>
                    {Object.entries(Instruct).map(([label, value]) => (
                        <option value={value} key={value}>
                            {label.toLowerCase()}
                        </option>
                    ))}
                </optgroup>
                <optgroup label='Custom'>
                    <option value={connection.instruct}>Custom</option>
                </optgroup>
            </select>
            {isKoboldConnection(connection) && <input
                value={connectionUrl}
                onInput={setConnectionUrl}
                onBlur={handleBlurUrl}
                class={urlValid ? styles.valid : styles.invalid}
            />}
            {isHordeConnection(connection) && <>
                <input
                    placeholder='Horde API key'
                    title='Horde API key'
                    value={apiKey}
                    onInput={setApiKey}
                    onBlur={handleBlurHorde}
                />
                <select
                    value={modelName}
                    onChange={setModelName}
                    onBlur={handleBlurHorde}
                    title='Horde model'
                >
                    {hordeModels.map((m) => (
                        <option value={m.name} key={m.name}>
                            {m.name} ({m.maxLength}/{m.maxContext})
                        </option>
                    ))}
                </select>
            </>}
        </div>
    );
};

View File

@ -45,3 +45,10 @@
overflow: hidden; overflow: hidden;
} }
} }
.connectionEditor {
display: flex;
flex-direction: row;
gap: 8px;
flex-wrap: wrap;
}

View File

@ -2,35 +2,29 @@ import { useCallback, useContext, useEffect, useMemo, useState } from "preact/ho
import { useBool } from "@common/hooks/useBool"; import { useBool } from "@common/hooks/useBool";
import { Modal } from "@common/components/modal/modal"; import { Modal } from "@common/components/modal/modal";
import { Instruct, StateContext } from "../../contexts/state"; import { StateContext } from "../../contexts/state";
import { LLMContext } from "../../contexts/llm"; import { LLMContext } from "../../contexts/llm";
import { MiniChat } from "../minichat/minichat"; import { MiniChat } from "../minichat/minichat";
import { AutoTextarea } from "../autoTextarea"; import { AutoTextarea } from "../autoTextarea";
import { Ace } from "../ace";
import { ConnectionEditor } from "./connectionEditor";
import styles from './header.module.css'; import styles from './header.module.css';
import { Ace } from "../ace";
export const Header = () => { export const Header = () => {
const { modelName, modelTemplate, contextLength, promptTokens, blockConnection } = useContext(LLMContext); const { contextLength, promptTokens, modelName } = useContext(LLMContext);
const { const {
messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct, summarizePrompt, summaryEnabled, messages, connection, systemPrompt, lore, userPrompt, bannedWords, summarizePrompt, summaryEnabled,
setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct, setSummarizePrompt, setSummaryEnabled, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct, setSummarizePrompt, setSummaryEnabled, setConnection,
} = useContext(StateContext); } = useContext(StateContext);
const connectionsOpen = useBool();
const loreOpen = useBool(); const loreOpen = useBool();
const promptsOpen = useBool(); const promptsOpen = useBool();
const genparamsOpen = useBool(); const genparamsOpen = useBool();
const assistantOpen = useBool(); const assistantOpen = useBool();
const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]); const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]);
const urlValid = useMemo(() => contextLength > 0, [contextLength]);
const handleBlurUrl = useCallback(() => {
const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i
const normalizedConnectionUrl = connectionUrl.replace(regex, 'http$1://$2');
setConnectionUrl(normalizedConnectionUrl);
blockConnection.setFalse();
}, [connectionUrl, setConnectionUrl, blockConnection]);
const handleAssistantAddSwipe = useCallback((answer: string) => { const handleAssistantAddSwipe = useCallback((answer: string) => {
const index = messages.findLastIndex(m => m.role === 'assistant'); const index = messages.findLastIndex(m => m.role === 'assistant');
@ -61,29 +55,13 @@ export const Header = () => {
return ( return (
<div class={styles.header}> <div class={styles.header}>
<div class={styles.inputs}> <div class={styles.inputs}>
<input value={connectionUrl} <div class={styles.buttons}>
onInput={setConnectionUrl} <button class='icon' onClick={connectionsOpen.setTrue} title='Connection settings'>
onFocus={blockConnection.setTrue} 🔌
onBlur={handleBlurUrl} </button>
class={blockConnection.value ? '' : urlValid ? styles.valid : styles.invalid} </div>
/>
<select value={instruct} onChange={setInstruct} title='Instruct template'>
{modelName && modelTemplate && <optgroup label='Native model template'>
<option value={modelTemplate} title='Native for model'>{modelName}</option>
</optgroup>}
<optgroup label='Manual templates'>
{Object.entries(Instruct).map(([label, value]) => (
<option value={value} key={value}>
{label.toLowerCase()}
</option>
))}
</optgroup>
<optgroup label='Custom'>
<option value={instruct}>Custom</option>
</optgroup>
</select>
<div class={styles.info}> <div class={styles.info}>
{promptTokens} / {contextLength} {modelName} - {promptTokens} / {contextLength}
</div> </div>
</div> </div>
<div class={styles.buttons}> <div class={styles.buttons}>
@ -102,6 +80,10 @@ export const Header = () => {
</button> </button>
</div> </div>
<Modal open={connectionsOpen.value} onClose={connectionsOpen.setFalse}>
<h3 class={styles.modalTitle}>Connection settings</h3>
<ConnectionEditor connection={connection} setConnection={setConnection} />
</Modal>
<Modal open={loreOpen.value} onClose={loreOpen.setFalse}> <Modal open={loreOpen.value} onClose={loreOpen.setFalse}>
<h3 class={styles.modalTitle}>Lore Editor</h3> <h3 class={styles.modalTitle}>Lore Editor</h3>
<AutoTextarea <AutoTextarea
@ -140,7 +122,7 @@ export const Header = () => {
</label> </label>
<hr /> <hr />
<h4 class={styles.modalTitle}>Instruct template</h4> <h4 class={styles.modalTitle}>Instruct template</h4>
<Ace value={instruct} onInput={setInstruct} /> <Ace value={connection.instruct} onInput={setInstruct} />
</div> </div>
</Modal> </Modal>
<MiniChat <MiniChat

352
src/games/ai/connection.ts Normal file
View File

@ -0,0 +1,352 @@
import Lock from "@common/lock";
import SSE from "@common/sse";
import { delay, throttle } from "@common/utils";
/** Fields common to every backend connection. */
interface IBaseConnection {
    // Jinja-style instruct/chat template used when compiling prompts.
    instruct: string;
}

/** Connection to a Kobold CPP server, identified by its base URL. */
interface IKoboldConnection extends IBaseConnection {
    url: string;
}

/** Connection to the AI Horde network. */
interface IHordeConnection extends IBaseConnection {
    // Horde API key; the anonymous key is used when omitted.
    apiKey?: string;
    // Normalized model name to request from the horde.
    model: string;
}
/** Type guard: a connection is Kobold-shaped when it carries a string `url`. */
export const isKoboldConnection = (obj: unknown): obj is IKoboldConnection => {
    if (obj == null || typeof obj !== 'object') {
        return false;
    }
    return 'url' in obj && typeof obj.url === 'string';
};
/** Type guard: a connection is Horde-shaped when it carries a string `model`. */
export const isHordeConnection = (obj: unknown): obj is IHordeConnection => {
    if (obj == null || typeof obj !== 'object') {
        return false;
    }
    return 'model' in obj && typeof obj.model === 'string';
};
/** Any supported backend connection; discriminated by `url` vs `model`. */
export type IConnection = IKoboldConnection | IHordeConnection;

/** Worker record as returned by the Horde /v2/workers endpoint. */
interface IHordeWorker {
    id: string;
    models: string[];
    flagged: boolean;
    online: boolean;
    maintenance_mode: boolean;
    max_context_length: number;
    max_length: number;
    // NOTE(review): performance arrives as a string from the API and is
    // parsed with parseFloat elsewhere — presumably "N tokens/s"; confirm.
    performance: string;
}

/** Aggregated view of one normalized model across all healthy workers. */
export interface IHordeModel {
    // Normalized (lowercased, de-quantized) model name used as the map key.
    name: string;
    // Raw worker-reported model names that normalize to `name`.
    hordeNames: string[];
    // Most restrictive generation length across the model's workers.
    maxLength: number;
    // Most restrictive context length across the model's workers.
    maxContext: number;
    // IDs of workers serving this model.
    workers: string[];
}

/** Shape of the Horde async-generation status response. */
interface IHordeResult {
    faulted: boolean;
    done: boolean;
    finished: number;
    generations?: {
        text: string;
    }[];
}
// Baseline sampler settings sent to both backends; callers override via
// the `extraSettings` parameter of generate().
const DEFAULT_GENERATION_SETTINGS = {
    temperature: 0.8,
    min_p: 0.1,
    rep_pen: 1.08,
    rep_pen_range: -1,
    rep_pen_slope: 0.7,
    top_k: 100,
    top_p: 0.92,
    banned_tokens: ['anticipat'],
    max_length: 300,
    trim_stop: true,
    stop_sequence: ['[INST]', '[/INST]', '</s>', '<|'],
    dry_allowed_length: 5,
    dry_multiplier: 0.8,
    dry_base: 1,
    dry_sequence_breakers: ["\n", ":", "\"", "*"],
    dry_penalty_last_n: 0
}
// Minimum worker performance (parsed from the API's string field) to accept a worker.
const MIN_PERFORMANCE = 5.0;
// Workers with less context than this are filtered out entirely.
const MIN_WORKER_CONTEXT = 8192;
// Upper caps applied before taking the per-worker minimums.
const MAX_HORDE_LENGTH = 512;
const MAX_HORDE_CONTEXT = 32000;
// Horde's documented anonymous API key.
export const HORDE_ANON_KEY = '0000000000';
/**
 * Reduces a raw model identifier to a canonical dash-separated base name:
 * strips any path prefix and "::" suffix, then repeatedly removes quant /
 * context / format decorations until no pattern matches (fixpoint), and
 * finally collapses separators. E.g. "foo/bar.Q4_K_M.gguf" -> "bar".
 */
export const normalizeModel = (model: string) => {
    // Drop path components and anything after '::'.
    let name = model.split(/[\\\/]/).at(-1);
    name = name.split('::').at(0);
    // Decoration-stripping passes, applied in order each round.
    const strippers = [
        /[ ._-]\d+(k$|-context)/i, // context length, i.e. -32k
        /[ ._-](gptq|awq|exl2?|imat|i\d)/i, // quant name
        /([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, // gguf-v3/ggml/etc
        /[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, // quant size
        /[ ._-]\d+(\.\d+)?bpw/i, // bits-per-weight
        /[ ._-]f(p|loat)?(8|16|32)/i, // float precision
        /^(debug-?)+/i, // debug prefix
    ];
    // Iterate to a fixpoint: one suffix may reveal another underneath.
    let previous: string;
    do {
        previous = name;
        for (const pattern of strippers) {
            name = name.replace(pattern, '');
        }
        name = name.trim();
    } while (name !== previous);
    return name
        .replace(/[ _-]+/ig, '-')
        .replace(/\.{2,}/, '-')
        .replace(/[ ._-]+$/ig, '')
        .trim();
}
/** Rough token estimate: ~0.75 tokens per whitespace-separated word. */
export const approximateTokens = (prompt: string): number => {
    const words = prompt.split(/\s+/).length;
    return Math.round(words * 0.75);
};

/** Caller-overridable subset of the default sampler settings. */
export type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
export namespace Connection {
const AIHORDE = 'https://aihorde.net';
/**
 * Streams tokens from a Kobold CPP server over SSE.
 * Bridges the event-driven SSE callbacks into an async generator: listeners
 * push tokens into `messages` and release `messageLock`, while the generator
 * loop drains the queue and waits on the lock when it runs dry.
 */
async function* generateKobold(url: string, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator<string> {
    const sse = new SSE(`${url}/api/extra/generate/stream`, {
        payload: JSON.stringify({
            ...DEFAULT_GENERATION_SETTINGS,
            ...extraSettings,
            prompt,
        }),
    });
    const messages: string[] = []; // tokens received but not yet yielded
    const messageLock = new Lock();
    let end = false; // set once the stream has finished for any reason
    sse.addEventListener('message', (e) => {
        if (e.data) {
            {
                const { token, finish_reason } = JSON.parse(e.data);
                messages.push(token);
                // Server signals completion via a non-"null" finish_reason.
                if (finish_reason && finish_reason !== 'null') {
                    end = true;
                }
            }
        }
        messageLock.release();
    });
    // Any error/abort/close ends the stream and wakes the consumer loop.
    const handleEnd = () => {
        end = true;
        messageLock.release();
    };
    sse.addEventListener('error', handleEnd);
    sse.addEventListener('abort', handleEnd);
    sse.addEventListener('readystatechange', (e) => {
        if (e.readyState === SSE.CLOSED) handleEnd();
    });
    // Drain until the stream has ended AND the queue is empty.
    while (!end || messages.length) {
        while (messages.length > 0) {
            const message = messages.shift();
            if (message != null) {
                try {
                    // try/catch: a consumer-side throw into the generator
                    // must not abort draining the remaining tokens.
                    yield message;
                } catch { }
            }
        }
        if (!end) {
            await messageLock.wait();
        }
    }
    sse.close();
}
/**
 * Submits a generation request to AI Horde and polls until it completes.
 * Horde has no streaming API, so the whole completion is returned at once.
 * Throws when the model has no healthy workers, when submission fails, or
 * when a status poll fails (after best-effort cancellation of the request).
 * NOTE(review): the poll loop has no overall timeout — a request that never
 * finishes would poll forever; confirm whether the Horde guarantees
 * eventual done/faulted status.
 */
async function generateHorde(connection: Omit<IHordeConnection, keyof IBaseConnection>, prompt: string, extraSettings: IGenerationSettings = {}): Promise<string> {
    const models = await getHordeModels();
    const model = models.get(connection.model);
    if (model) {
        // Clamp requested length to what the model's workers support.
        let maxLength = Math.min(model.maxLength, DEFAULT_GENERATION_SETTINGS.max_length);
        if (extraSettings.max_length && extraSettings.max_length < maxLength) {
            maxLength = extraSettings.max_length;
        }
        const requestData = {
            prompt,
            params: {
                ...DEFAULT_GENERATION_SETTINGS,
                ...extraSettings,
                n: 1,
                max_context_length: model.maxContext,
                max_length: maxLength,
                rep_pen_range: Math.min(model.maxContext, 4096),
            },
            // Pin the request to the exact model aliases and workers we vetted.
            models: model.hordeNames,
            workers: model.workers,
        };
        const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, {
            method: 'POST',
            body: JSON.stringify(requestData),
            headers: {
                'Content-Type': 'application/json',
                apikey: connection.apiKey || HORDE_ANON_KEY,
            },
        });
        if (!generateResponse.ok || Math.floor(generateResponse.status / 100) !== 2) {
            throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`);
        }
        const { id } = await generateResponse.json() as { id: string };
        // Best-effort cancellation used when a status poll fails.
        const deleteRequest = () => fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method: 'DELETE' })
            .catch(e => console.error('Error deleting request', e));
        // Poll the status endpoint every 2.5s until the generation is done.
        while (true) {
            await delay(2500);
            const retrieveResponse = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`);
            if (!retrieveResponse.ok || Math.floor(retrieveResponse.status / 100) !== 2) {
                deleteRequest();
                throw new Error(`Error retrieving generation: ${retrieveResponse.statusText}: ${await retrieveResponse.text()}`);
            }
            const result: IHordeResult = await retrieveResponse.json();
            if (result.done && result.generations?.length === 1) {
                const { text } = result.generations[0];
                return text;
            }
        }
    }
    throw new Error(`Model ${connection.model} is offline`);
}
/**
 * Streams generated text for any connection kind: token-by-token for
 * Kobold, or a single final chunk for Horde (which cannot stream).
 * Yields nothing when the connection matches neither shape.
 */
export async function* generate(connection: IConnection, prompt: string, extraSettings: IGenerationSettings = {}) {
    if (isKoboldConnection(connection)) {
        yield* generateKobold(connection.url, prompt, extraSettings);
        return;
    }
    if (isHordeConnection(connection)) {
        const text = await generateHorde(connection, prompt, extraSettings);
        yield text;
    }
}
/**
 * Fetches the live list of Horde text workers and aggregates them into a map
 * of normalized model name -> IHordeModel. Unhealthy workers (offline, in
 * maintenance, flagged, too little context, too slow) are filtered out.
 * Per-model limits take the minimum across that model's workers.
 * Returns an empty map on any network/parse failure (logged, not thrown).
 */
async function requestHordeModels(): Promise<Map<string, IHordeModel>> {
    try {
        const response = await fetch(`${AIHORDE}/api/v2/workers?type=text`);
        if (response.ok) {
            const workers: IHordeWorker[] = await response.json();
            const goodWorkers = workers.filter(w =>
                w.online
                && !w.maintenance_mode
                && !w.flagged
                && w.max_context_length >= MIN_WORKER_CONTEXT
                && parseFloat(w.performance) >= MIN_PERFORMANCE
            );
            const models = new Map<string, IHordeModel>();
            for (const worker of goodWorkers) {
                for (const modelName of worker.models) {
                    // Group differently-quantized aliases under one normalized name.
                    const normName = normalizeModel(modelName.toLowerCase());
                    let model = models.get(normName);
                    if (!model) {
                        // Start from the global caps; worker minima are folded in below.
                        model = {
                            hordeNames: [],
                            maxContext: MAX_HORDE_CONTEXT,
                            maxLength: MAX_HORDE_LENGTH,
                            name: normName,
                            workers: []
                        }
                    }
                    if (!model.hordeNames.includes(modelName)) {
                        model.hordeNames.push(modelName);
                    }
                    if (!model.workers.includes(worker.id)) {
                        model.workers.push(worker.id);
                    }
                    // Keep the most restrictive limits so requests fit every worker.
                    model.maxContext = Math.min(model.maxContext, worker.max_context_length);
                    model.maxLength = Math.min(model.maxLength, worker.max_length);
                    models.set(normName, model);
                }
            }
            return models;
        }
    } catch (e) {
        console.error(e);
    }
    return new Map();
};
// Throttled accessor: at most one workers-list fetch per 10s; callers within
// the window receive the cached Promise/result.
export const getHordeModels = throttle(requestHordeModels, 10000);
/**
 * Resolves the display/model name for a connection: queries the Kobold
 * /api/v1/model endpoint, or returns the configured Horde model directly.
 * Returns '' on failure or for an unrecognized connection shape.
 */
export async function getModelName(connection: IConnection): Promise<string> {
    if (isKoboldConnection(connection)) {
        try {
            const response = await fetch(`${connection.url}/api/v1/model`);
            if (response.ok) {
                const { result } = await response.json();
                return result;
            }
        } catch (e) {
            // Fixed copy-pasted log text: this is the model-name lookup,
            // not the max-tokens lookup.
            console.log('Error getting model name', e);
        }
    } else if (isHordeConnection(connection)) {
        return connection.model;
    }
    return '';
}
/**
 * Resolves the usable context length for a connection: asks the Kobold
 * server directly, or takes the aggregated per-model limit for Horde.
 * Returns 0 on failure or for an unrecognized connection shape.
 */
export async function getContextLength(connection: IConnection): Promise<number> {
    if (isKoboldConnection(connection)) {
        try {
            const response = await fetch(`${connection.url}/api/extra/true_max_context_length`);
            if (response.ok) {
                const { value } = await response.json();
                return value;
            }
        } catch (e) {
            console.log('Error getting max tokens', e);
        }
        return 0;
    }
    if (isHordeConnection(connection)) {
        const model = (await getHordeModels()).get(connection.model);
        if (model) {
            return model.maxContext;
        }
    }
    return 0;
}
/**
 * Counts prompt tokens: uses the Kobold tokencount endpoint when available,
 * otherwise (Horde, or any Kobold failure) falls back to the word-based
 * approximation from approximateTokens().
 */
export async function countTokens(connection: IConnection, prompt: string) {
    if (isKoboldConnection(connection)) {
        try {
            const response = await fetch(`${connection.url}/api/extra/tokencount`, {
                body: JSON.stringify({ prompt }),
                // Fixed typo: was 'applicarion/json', an invalid MIME type.
                headers: { 'Content-Type': 'application/json' },
                method: 'POST',
            });
            if (response.ok) {
                const { value } = await response.json();
                return value;
            }
        } catch (e) {
            console.log('Error counting tokens', e);
        }
    }
    return approximateTokens(prompt);
}
}

View File

@ -7,6 +7,7 @@ import { Instruct, StateContext } from "./state";
import { useBool } from "@common/hooks/useBool"; import { useBool } from "@common/hooks/useBool";
import { Template } from "@huggingface/jinja"; import { Template } from "@huggingface/jinja";
import { Huggingface } from "../huggingface"; import { Huggingface } from "../huggingface";
import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection";
interface ICompileArgs { interface ICompileArgs {
keepUsers?: number; keepUsers?: number;
@ -29,29 +30,8 @@ interface IContext {
contextLength: number; contextLength: number;
} }
const DEFAULT_GENERATION_SETTINGS = {
temperature: 0.8,
min_p: 0.1,
rep_pen: 1.08,
rep_pen_range: -1,
rep_pen_slope: 0.7,
top_k: 100,
top_p: 0.92,
banned_tokens: [],
max_length: 300,
trim_stop: true,
stop_sequence: ['[INST]', '[/INST]', '</s>', '<|'],
dry_allowed_length: 5,
dry_multiplier: 0.8,
dry_base: 1,
dry_sequence_breakers: ["\n", ":", "\"", "*"],
dry_penalty_last_n: 0
}
const MESSAGES_TO_KEEP = 10; const MESSAGES_TO_KEEP = 10;
type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
interface IActions { interface IActions {
compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>; compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>; generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
@ -60,32 +40,6 @@ interface IActions {
} }
export type ILLMContext = IContext & IActions; export type ILLMContext = IContext & IActions;
export const normalizeModel = (model: string) => {
let currentModel = model.split(/[\\\/]/).at(-1);
currentModel = currentModel.split('::').at(0);
let normalizedModel: string;
do {
normalizedModel = currentModel;
currentModel = currentModel
.replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
.replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name
.replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
.replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
.replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
.replace(/[ ._-]f(p|loat)?(8|16|32)/i, '')
.replace(/^(debug-?)+/i, '')
.trim();
} while (normalizedModel !== currentModel);
return normalizedModel
.replace(/[ _-]+/ig, '-')
.replace(/\.{2,}/, '-')
.replace(/[ ._-]+$/ig, '')
.trim();
}
export const LLMContext = createContext<ILLMContext>({} as ILLMContext); export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
const processing = { const processing = {
@ -95,7 +49,7 @@ const processing = {
export const LLMContextProvider = ({ children }: { children?: any }) => { export const LLMContextProvider = ({ children }: { children?: any }) => {
const { const {
connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct, summarizePrompt, summaryEnabled, connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
setTriggerNext, addMessage, editMessage, editSummary, setInstruct, setTriggerNext, addMessage, editMessage, editSummary, setInstruct,
} = useContext(StateContext); } = useContext(StateContext);
@ -118,38 +72,18 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
}, [userPrompt]); }, [userPrompt]);
const getContextLength = useCallback(async () => { const getContextLength = useCallback(async () => {
if (!connectionUrl || blockConnection.value) { if (!connection || blockConnection.value) {
return 0; return 0;
} }
try { return Connection.getContextLength(connection);
const response = await fetch(`${connectionUrl}/api/extra/true_max_context_length`); }, [connection, blockConnection.value]);
if (response.ok) {
const { value } = await response.json();
return value;
}
} catch (e) {
console.log('Error getting max tokens', e);
}
return 0;
}, [connectionUrl, blockConnection.value]);
const getModelName = useCallback(async () => { const getModelName = useCallback(async () => {
if (!connectionUrl || blockConnection.value) { if (!connection || blockConnection.value) {
return ''; return '';
} }
try { return Connection.getModelName(connection);
const response = await fetch(`${connectionUrl}/api/v1/model`); }, [connection, blockConnection.value]);
if (response.ok) {
const { result } = await response.json();
return result;
}
} catch (e) {
console.log('Error getting max tokens', e);
}
return '';
}, [connectionUrl, blockConnection.value]);
const actions: IActions = useMemo(() => ({ const actions: IActions = useMemo(() => ({
compilePrompt: async (messages, { keepUsers } = {}) => { compilePrompt: async (messages, { keepUsers } = {}) => {
@ -236,7 +170,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`; templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`;
const prompt = Huggingface.applyChatTemplate(instruct, templateMessages); const prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages);
return { return {
prompt, prompt,
isContinue, isContinue,
@ -244,100 +178,30 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
}; };
}, },
generate: async function* (prompt, extraSettings = {}) { generate: async function* (prompt, extraSettings = {}) {
if (!connectionUrl) {
return;
}
try { try {
generating.setTrue(); generating.setTrue();
console.log('[LLM.generate]', prompt); console.log('[LLM.generate]', prompt);
const sse = new SSE(`${connectionUrl}/api/extra/generate/stream`, { yield* Connection.generate(connection, prompt, {
payload: JSON.stringify({
...DEFAULT_GENERATION_SETTINGS,
banned_tokens: bannedWords.filter(w => w.trim()),
...extraSettings, ...extraSettings,
prompt, banned_tokens: bannedWords.filter(w => w.trim()),
}),
}); });
const messages: string[] = [];
const messageLock = new Lock();
let end = false;
sse.addEventListener('message', (e) => {
if (e.data) {
{
const { token, finish_reason } = JSON.parse(e.data);
messages.push(token);
if (finish_reason && finish_reason !== 'null') {
end = true;
}
}
}
messageLock.release();
});
const handleEnd = () => {
end = true;
messageLock.release();
};
sse.addEventListener('error', handleEnd);
sse.addEventListener('abort', handleEnd);
sse.addEventListener('readystatechange', (e) => {
if (e.readyState === SSE.CLOSED) handleEnd();
});
while (!end || messages.length) {
while (messages.length > 0) {
const message = messages.shift();
if (message != null) {
try {
yield message;
} catch { }
}
}
if (!end) {
await messageLock.wait();
}
}
sse.close();
} finally { } finally {
generating.setFalse(); generating.setFalse();
} }
}, },
summarize: async (message) => { summarize: async (message) => {
const content = Huggingface.applyTemplate(summarizePrompt, { message }); const content = Huggingface.applyTemplate(summarizePrompt, { message });
const prompt = Huggingface.applyChatTemplate(instruct, [{ role: 'user', content }]); const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
const tokens = await Array.fromAsync(actions.generate(prompt)); const tokens = await Array.fromAsync(actions.generate(prompt));
return MessageTools.trimSentence(tokens.join('')); return MessageTools.trimSentence(tokens.join(''));
}, },
countTokens: async (prompt) => { countTokens: async (prompt) => {
if (!connectionUrl) { return await Connection.countTokens(connection, prompt);
return 0;
}
try {
const response = await fetch(`${connectionUrl}/api/extra/tokencount`, {
body: JSON.stringify({ prompt }),
headers: { 'Content-Type': 'applicarion/json' },
method: 'POST',
});
if (response.ok) {
const { value } = await response.json();
return value;
}
} catch (e) {
console.log('Error counting tokens', e);
}
return 0;
}, },
}), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct, summarizePrompt]); }), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);
useEffect(() => void (async () => { useEffect(() => void (async () => {
if (triggerNext && !generating.value) { if (triggerNext && !generating.value) {
@ -356,7 +220,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
editSummary(messageId, 'Generating...'); editSummary(messageId, 'Generating...');
for await (const chunk of actions.generate(prompt)) { for await (const chunk of actions.generate(prompt)) {
text += chunk; text += chunk;
setPromptTokens(promptTokens + Math.round(text.length * 0.25)); setPromptTokens(promptTokens + approximateTokens(text));
editMessage(messageId, text.trim()); editMessage(messageId, text.trim());
} }
@ -397,7 +261,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
getContextLength().then(setContextLength); getContextLength().then(setContextLength);
getModelName().then(normalizeModel).then(setModelName); getModelName().then(normalizeModel).then(setModelName);
} }
}, [connectionUrl, blockConnection.value]); }, [connection, blockConnection.value]);
useEffect(() => { useEffect(() => {
setModelTemplate(''); setModelTemplate('');
@ -431,16 +295,16 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
useEffect(() => { useEffect(() => {
calculateTokens(); calculateTokens();
}, [messages, connectionUrl, blockConnection.value, instruct, /* systemPrompt, lore, userPrompt TODO debounce*/]); }, [messages, connection, blockConnection.value, /* systemPrompt, lore, userPrompt TODO debounce*/]);
useEffect(() => { useEffect(() => {
try { try {
const hasTools = Huggingface.testToolCalls(instruct); const hasTools = Huggingface.testToolCalls(connection.instruct);
setHasToolCalls(hasTools); setHasToolCalls(hasTools);
} catch { } catch {
setHasToolCalls(false); setHasToolCalls(false);
} }
}, [instruct]); }, [connection.instruct]);
const rawContext: IContext = { const rawContext: IContext = {
generating: generating.value, generating: generating.value,

View File

@ -1,12 +1,13 @@
import { createContext } from "preact"; import { createContext } from "preact";
import { useEffect, useMemo, useState } from "preact/hooks"; import { useCallback, useEffect, useMemo, useState } from "preact/hooks";
import { MessageTools, type IMessage } from "../messages"; import { MessageTools, type IMessage } from "../messages";
import { useInputState } from "@common/hooks/useInputState"; import { useInputState } from "@common/hooks/useInputState";
import { type IConnection } from "../connection";
interface IContext { interface IContext {
connectionUrl: string; currentConnection: number;
availableConnections: IConnection[];
input: string; input: string;
instruct: string;
systemPrompt: string; systemPrompt: string;
lore: string; lore: string;
userPrompt: string; userPrompt: string;
@ -17,8 +18,14 @@ interface IContext {
triggerNext: boolean; triggerNext: boolean;
} }
interface IComputableContext {
connection: IConnection;
}
interface IActions { interface IActions {
setConnectionUrl: (url: string | Event) => void; setConnection: (connection: IConnection) => void;
setAvailableConnections: (connections: IConnection[]) => void;
setCurrentConnection: (connection: number) => void;
setInput: (url: string | Event) => void; setInput: (url: string | Event) => void;
setInstruct: (template: string | Event) => void; setInstruct: (template: string | Event) => void;
setLore: (lore: string | Event) => void; setLore: (lore: string | Event) => void;
@ -49,23 +56,20 @@ export enum Instruct {
MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + '</s>'}}{%- endif %}{%- endfor %}`, MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + '</s>'}}{%- endif %}{%- endfor %}`,
METHARME = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>' + message['content'] }}{% elif message['role'] == 'user' %}{{'<|user|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{'<|model|>' + message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% endif %}`,
GEMMA = `{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}`, GEMMA = `{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}`,
ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\\n\\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\\n\\n' + message['content'] + '\\n\\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\\n\\n' + message['content'] + '\\n\\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\\n\\n' }}{% endif %}`, ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\\n\\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\\n\\n' + message['content'] + '\\n\\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\\n\\n' + message['content'] + '\\n\\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\\n\\n' }}{% endif %}`,
}; };
export const saveContext = (context: IContext) => { const DEFAULT_CONTEXT: IContext = {
const contextToSave: Partial<IContext> = { ...context }; currentConnection: 0,
delete contextToSave.triggerNext; availableConnections: [{
url: 'http://localhost:5001',
localStorage.setItem(SAVE_KEY, JSON.stringify(contextToSave));
}
export const loadContext = (): IContext => {
const defaultContext: IContext = {
connectionUrl: 'http://localhost:5001',
input: '',
instruct: Instruct.CHATML, instruct: Instruct.CHATML,
}],
input: '',
systemPrompt: 'You are a creative writer. Write a story based on the world description below. Story should be adult and mature; and could include swearing, violence and unfairness. Portray characters realistically and stay in the lore.', systemPrompt: 'You are a creative writer. Write a story based on the world description below. Story should be adult and mature; and could include swearing, violence and unfairness. Portray characters realistically and stay in the lore.',
lore: '', lore: '',
userPrompt: `{% if isStart -%} userPrompt: `{% if isStart -%}
@ -75,7 +79,7 @@ export const loadContext = (): IContext => {
{%- endif %} {%- endif %}
{% if prompt -%} {% if prompt -%}
This is the description of what I want to happen next: {{ prompt | trim }} What should happen next in your answer: {{ prompt | trim }}
{% endif %} {% endif %}
Remember that this story should be infinite and go forever. Remember that this story should be infinite and go forever.
Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`, Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`,
@ -86,6 +90,14 @@ Make sure to follow the world description and rules exactly. Avoid cliffhangers
triggerNext: false, triggerNext: false,
}; };
export const saveContext = (context: IContext) => {
const contextToSave: Partial<IContext> = { ...context };
delete contextToSave.triggerNext;
localStorage.setItem(SAVE_KEY, JSON.stringify(contextToSave));
}
export const loadContext = (): IContext => {
let loadedContext: Partial<IContext> = {}; let loadedContext: Partial<IContext> = {};
try { try {
@ -95,18 +107,18 @@ Make sure to follow the world description and rules exactly. Avoid cliffhangers
} }
} catch { } } catch { }
return { ...defaultContext, ...loadedContext }; return { ...DEFAULT_CONTEXT, ...loadedContext };
} }
export type IStateContext = IContext & IActions; export type IStateContext = IContext & IActions & IComputableContext;
export const StateContext = createContext<IStateContext>({} as IStateContext); export const StateContext = createContext<IStateContext>({} as IStateContext);
export const StateContextProvider = ({ children }: { children?: any }) => { export const StateContextProvider = ({ children }: { children?: any }) => {
const loadedContext = useMemo(() => loadContext(), []); const loadedContext = useMemo(() => loadContext(), []);
const [connectionUrl, setConnectionUrl] = useInputState(loadedContext.connectionUrl); const [currentConnection, setCurrentConnection] = useState<number>(loadedContext.currentConnection);
const [availableConnections, setAvailableConnections] = useState<IConnection[]>(loadedContext.availableConnections);
const [input, setInput] = useInputState(loadedContext.input); const [input, setInput] = useInputState(loadedContext.input);
const [instruct, setInstruct] = useInputState(loadedContext.instruct);
const [lore, setLore] = useInputState(loadedContext.lore); const [lore, setLore] = useInputState(loadedContext.lore);
const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt); const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt);
const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt); const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt);
@ -115,10 +127,26 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
const [messages, setMessages] = useState(loadedContext.messages); const [messages, setMessages] = useState(loadedContext.messages);
const [summaryEnabled, setSummaryEnabled] = useState(loadedContext.summaryEnabled); const [summaryEnabled, setSummaryEnabled] = useState(loadedContext.summaryEnabled);
const connection = availableConnections[currentConnection] ?? DEFAULT_CONTEXT.availableConnections[0];
const [triggerNext, setTriggerNext] = useState(false); const [triggerNext, setTriggerNext] = useState(false);
const [instruct, setInstruct] = useInputState(connection.instruct);
const setConnection = useCallback((c: IConnection) => {
setAvailableConnections(availableConnections.map((ac, ai) => {
if (ai === currentConnection) {
return c;
} else {
return ac;
}
}));
}, [availableConnections, currentConnection]);
useEffect(() => setConnection({ ...connection, instruct }), [instruct]);
const actions: IActions = useMemo(() => ({ const actions: IActions = useMemo(() => ({
setConnectionUrl, setConnection,
setCurrentConnection,
setInput, setInput,
setInstruct, setInstruct,
setSystemPrompt, setSystemPrompt,
@ -127,7 +155,8 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
setLore, setLore,
setTriggerNext, setTriggerNext,
setSummaryEnabled, setSummaryEnabled,
setBannedWords: (words) => setBannedWords([...words]), setBannedWords: (words) => setBannedWords(words.slice()),
setAvailableConnections: (connections) => setAvailableConnections(connections.slice()),
setMessages: (newMessages) => setMessages(newMessages.slice()), setMessages: (newMessages) => setMessages(newMessages.slice()),
addMessage: (content, role, triggerNext = false) => { addMessage: (content, role, triggerNext = false) => {
@ -198,10 +227,11 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
continueMessage: () => setTriggerNext(true), continueMessage: () => setTriggerNext(true),
}), []); }), []);
const rawContext: IContext = { const rawContext: IContext & IComputableContext = {
connectionUrl, connection,
currentConnection,
availableConnections,
input, input,
instruct,
systemPrompt, systemPrompt,
lore, lore,
userPrompt, userPrompt,

View File

@ -230,7 +230,9 @@ export namespace Huggingface {
} }
export const findModelTemplate = async (modelName: string): Promise<string | null> => { export const findModelTemplate = async (modelName: string): Promise<string | null> => {
const modelKey = modelName.toLowerCase(); const modelKey = modelName.toLowerCase().trim();
if (!modelKey) return '';
let template = templateCache[modelKey] ?? null; let template = templateCache[modelKey] ?? null;
if (template) { if (template) {