Compare commits

2 Commits

ece1621e73 ... 277b315795

| Author | SHA1 | Date |
|---|---|---|
|  | 277b315795 |  |
|  | 017ef7aaa5 |  |
@@ -14,6 +14,7 @@
     "@inquirer/select": "2.3.10",
     "ace-builds": "1.36.3",
     "classnames": "2.5.1",
+    "delay": "6.0.0",
     "preact": "10.22.0"
   },
   "devDependencies": {
@@ -0,0 +1,4 @@
+import { useEffect } from "preact/hooks";
+
+export const useAsyncEffect = (fx: () => any, deps: any[]) =>
+  useEffect(() => void fx(), deps);
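Note: a minimal usage sketch of the `useAsyncEffect` hook added in this hunk (illustrative only — the `UserBadge` component and the `/api/users` endpoint are hypothetical and not part of this commit):

```tsx
import { useState } from "preact/hooks";
import { useAsyncEffect } from "@common/hooks/useAsyncEffect";

// Hypothetical component: runs an async effect on mount without the usual
// `useEffect(() => { (async () => { ... })(); }, [...])` boilerplate.
const UserBadge = ({ id }: { id: string }) => {
  const [name, setName] = useState('');

  useAsyncEffect(async () => {
    const res = await fetch(`/api/users/${id}`); // hypothetical endpoint
    if (res.ok) setName((await res.json()).name);
  }, [id]);

  return <span>{name}</span>;
};
```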
@@ -0,0 +1,16 @@
+import { useCallback } from "preact/hooks";
+
+export function useInputCallback<T>(callback: (value: string) => T, deps: any[]): ((value: string | Event) => T) {
+  return useCallback((e: Event | string) => {
+    if (typeof e === 'string') {
+      return callback(e);
+    } else {
+      const { target } = e;
+      if (target && 'value' in target && typeof target.value === 'string') {
+        return callback(target.value);
+      }
+    }
+
+    return callback('');
+  }, deps);
+}
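Note: a short usage sketch of `useInputCallback` (illustrative; the `SearchBox` component is hypothetical). The hook accepts either a plain string or an input `Event` and always hands the wrapped callback a string:

```tsx
import { useState } from "preact/hooks";
import { useInputCallback } from "@common/hooks/useInputCallback";

// Hypothetical search box: the same handler serves onInput events
// and programmatic calls with a plain string.
const SearchBox = () => {
  const [query, setQuery] = useState('');
  const handleQuery = useInputCallback((value) => setQuery(value), []);

  return (
    <div>
      <input value={query} onInput={handleQuery} />
      <button onClick={() => handleQuery('')}>Clear</button>
    </div>
  );
};
```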
@@ -1,4 +1,3 @@
-export const delay = async (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
 export const nextFrame = async (): Promise<number> => new Promise((resolve) => requestAnimationFrame(resolve));
 
 export const randInt = (min: number, max: number) => Math.round(min + (max - min - 1) * Math.random());
@@ -50,3 +49,32 @@ export const intHash = (seed: number, ...parts: number[]) => {
   return h1;
 };
 export const sinHash = (...data: number[]) => data.reduce((hash, n) => Math.sin((hash * 123.12 + n) * 756.12), 0) / 2 + 0.5;
+
+export const throttle = function <T, A extends unknown[], R, F extends (this: T, ...args: A) => R>(func: F, ms: number, trailing = false): F {
+  let isThrottled = false;
+  let savedResult: R;
+  let savedThis: T;
+  let savedArgs: A | undefined;
+
+  const wrapper: F = function (...args: A) {
+    if (isThrottled) {
+      savedThis = this;
+      savedArgs = args;
+    } else {
+      savedResult = func.apply(this, args);
+      savedArgs = undefined;
+
+      isThrottled = true;
+
+      setTimeout(function () {
+        isThrottled = false;
+        if (trailing && savedArgs) {
+          savedResult = wrapper.apply(savedThis, savedArgs);
+        }
+      }, ms);
+    }
+
+    return savedResult;
+  } as F;
+
+  return wrapper;
+}
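Note: a brief usage sketch of the `throttle` helper added above (illustrative; the scroll handler is hypothetical). With `trailing = true`, the last call made during the cooldown window is replayed once the window expires:

```ts
import { throttle } from "@common/utils";

// Hypothetical scroll reporter: runs at most once every 200 ms,
// and the final position is still reported thanks to the trailing call.
const reportScroll = throttle((y: number) => {
  console.log('scrolled to', y);
}, 200, true);

window.addEventListener('scroll', () => reportScroll(window.scrollY));
```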
@@ -32,6 +32,10 @@ select {
   outline: none;
 }
 
+option, optgroup {
+  background-color: var(--backgroundColor);
+}
+
 textarea {
   resize: vertical;
   width: 100%;
@@ -15,7 +15,7 @@ export const Chat = () => {
   const lastAssistantId = messages.findLastIndex(m => m.role === 'assistant');
 
   useEffect(() => {
-    DOMTools.scrollDown(chatRef.current);
+    setTimeout(() => DOMTools.scrollDown(chatRef.current, false), 100);
   }, [messages.length, lastMessageContent]);
 
   return (
@@ -0,0 +1,146 @@
+import { useCallback, useContext, useEffect, useMemo, useState } from 'preact/hooks';
+
+import styles from './header.module.css';
+import { Connection, HORDE_ANON_KEY, isHordeConnection, isKoboldConnection, type IConnection, type IHordeModel } from '../../connection';
+import { Instruct, StateContext } from '../../contexts/state';
+import { useInputState } from '@common/hooks/useInputState';
+import { useInputCallback } from '@common/hooks/useInputCallback';
+import { Huggingface } from '../../huggingface';
+
+interface IProps {
+  connection: IConnection;
+  setConnection: (c: IConnection) => void;
+}
+
+export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
+  const [connectionUrl, setConnectionUrl] = useInputState('');
+  const [apiKey, setApiKey] = useInputState(HORDE_ANON_KEY);
+  const [modelName, setModelName] = useInputState('');
+
+  const [modelTemplate, setModelTemplate] = useInputState(Instruct.CHATML);
+  const [hordeModels, setHordeModels] = useState<IHordeModel[]>([]);
+  const [contextLength, setContextLength] = useState<number>(0);
+
+  const backendType = useMemo(() => {
+    if (isKoboldConnection(connection)) return 'kobold';
+    if (isHordeConnection(connection)) return 'horde';
+    return 'unknown';
+  }, [connection]);
+
+  const urlValid = useMemo(() => contextLength > 0, [contextLength]);
+
+  useEffect(() => {
+    if (isKoboldConnection(connection)) {
+      setConnectionUrl(connection.url);
+    } else if (isHordeConnection(connection)) {
+      setModelName(connection.model);
+      setApiKey(connection.apiKey || HORDE_ANON_KEY);
+
+      Connection.getHordeModels()
+        .then(m => setHordeModels(Array.from(m.values()).sort((a, b) => a.name.localeCompare(b.name))));
+    }
+
+    Connection.getContextLength(connection).then(setContextLength);
+    Connection.getModelName(connection).then(setModelName);
+  }, [connection]);
+
+  useEffect(() => {
+    if (modelName) {
+      Huggingface.findModelTemplate(modelName)
+        .then(template => {
+          if (template) {
+            setModelTemplate(template);
+          }
+        });
+    }
+  }, [modelName]);
+
+  const setInstruct = useInputCallback((instruct) => {
+    setConnection({ ...connection, instruct });
+  }, [connection, setConnection]);
+
+  const setBackendType = useInputCallback((type) => {
+    if (type === 'kobold') {
+      setConnection({
+        instruct: connection.instruct,
+        url: connectionUrl,
+      });
+    } else if (type === 'horde') {
+      setConnection({
+        instruct: connection.instruct,
+        apiKey,
+        model: modelName,
+      });
+    }
+  }, [connection, setConnection, connectionUrl, apiKey, modelName]);
+
+  const handleBlurUrl = useCallback(() => {
+    const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i;
+    const url = connectionUrl.replace(regex, 'http$1://$2');
+
+    setConnection({
+      instruct: connection.instruct,
+      url,
+    });
+  }, [connection, connectionUrl, setConnection]);
+
+  const handleBlurHorde = useCallback(() => {
+    setConnection({
+      instruct: connection.instruct,
+      apiKey,
+      model: modelName,
+    });
+  }, [connection, apiKey, modelName, setConnection]);
+
+  return (
+    <div class={styles.connectionEditor}>
+      <select value={backendType} onChange={setBackendType}>
+        <option value='kobold'>Kobold CPP</option>
+        <option value='horde'>Horde</option>
+      </select>
+      <select value={connection.instruct} onChange={setInstruct} title='Instruct template'>
+        {modelName && modelTemplate && <optgroup label='Native model template'>
+          <option value={modelTemplate} title='Native for model'>{modelName}</option>
+        </optgroup>}
+        <optgroup label='Manual templates'>
+          {Object.entries(Instruct).map(([label, value]) => (
+            <option value={value} key={value}>
+              {label.toLowerCase()}
+            </option>
+          ))}
+        </optgroup>
+        <optgroup label='Custom'>
+          <option value={connection.instruct}>Custom</option>
+        </optgroup>
+      </select>
+      {isKoboldConnection(connection) && <input
+        value={connectionUrl}
+        onInput={setConnectionUrl}
+        onBlur={handleBlurUrl}
+        class={urlValid ? styles.valid : styles.invalid}
+      />}
+      {isHordeConnection(connection) && <>
+        <input
+          placeholder='Horde API key'
+          title='Horde API key'
+          value={apiKey}
+          onInput={setApiKey}
+          onBlur={handleBlurHorde}
+        />
+
+        <select
+          value={modelName}
+          onChange={setModelName}
+          onBlur={handleBlurHorde}
+          title='Horde model'
+        >
+          {hordeModels.map((m) => (
+            <option value={m.name} key={m.name}>
+              {m.name} ({m.maxLength}/{m.maxContext})
+            </option>
+          ))}
+        </select>
+      </>}
+    </div>
+  );
+};
@@ -45,3 +45,10 @@
     overflow: hidden;
   }
 }
+
+.connectionEditor {
+  display: flex;
+  flex-direction: row;
+  gap: 8px;
+  flex-wrap: wrap;
+}
@@ -2,35 +2,29 @@ import { useCallback, useContext, useEffect, useMemo, useState } from "preact/ho
 import { useBool } from "@common/hooks/useBool";
 import { Modal } from "@common/components/modal/modal";
 
-import { Instruct, StateContext } from "../../contexts/state";
+import { StateContext } from "../../contexts/state";
 import { LLMContext } from "../../contexts/llm";
 import { MiniChat } from "../minichat/minichat";
 import { AutoTextarea } from "../autoTextarea";
+import { Ace } from "../ace";
+import { ConnectionEditor } from "./connectionEditor";
 
 import styles from './header.module.css';
-import { Ace } from "../ace";
 
 export const Header = () => {
-  const { modelName, modelTemplate, contextLength, promptTokens, blockConnection } = useContext(LLMContext);
+  const { contextLength, promptTokens, modelName } = useContext(LLMContext);
   const {
-    messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct, summarizePrompt, summaryEnabled,
-    setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct, setSummarizePrompt, setSummaryEnabled,
+    messages, connection, systemPrompt, lore, userPrompt, bannedWords, summarizePrompt, summaryEnabled,
+    setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct, setSummarizePrompt, setSummaryEnabled, setConnection,
   } = useContext(StateContext);
 
+  const connectionsOpen = useBool();
   const loreOpen = useBool();
   const promptsOpen = useBool();
   const genparamsOpen = useBool();
   const assistantOpen = useBool();
 
   const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]);
-  const urlValid = useMemo(() => contextLength > 0, [contextLength]);
-
-  const handleBlurUrl = useCallback(() => {
-    const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i
-    const normalizedConnectionUrl = connectionUrl.replace(regex, 'http$1://$2');
-    setConnectionUrl(normalizedConnectionUrl);
-    blockConnection.setFalse();
-  }, [connectionUrl, setConnectionUrl, blockConnection]);
 
   const handleAssistantAddSwipe = useCallback((answer: string) => {
     const index = messages.findLastIndex(m => m.role === 'assistant');
@@ -61,29 +55,13 @@ export const Header = () => {
   return (
     <div class={styles.header}>
       <div class={styles.inputs}>
-        <input value={connectionUrl}
-          onInput={setConnectionUrl}
-          onFocus={blockConnection.setTrue}
-          onBlur={handleBlurUrl}
-          class={blockConnection.value ? '' : urlValid ? styles.valid : styles.invalid}
-        />
-        <select value={instruct} onChange={setInstruct} title='Instruct template'>
-          {modelName && modelTemplate && <optgroup label='Native model template'>
-            <option value={modelTemplate} title='Native for model'>{modelName}</option>
-          </optgroup>}
-          <optgroup label='Manual templates'>
-            {Object.entries(Instruct).map(([label, value]) => (
-              <option value={value} key={value}>
-                {label.toLowerCase()}
-              </option>
-            ))}
-          </optgroup>
-          <optgroup label='Custom'>
-            <option value={instruct}>Custom</option>
-          </optgroup>
-        </select>
+        <div class={styles.buttons}>
+          <button class='icon' onClick={connectionsOpen.setTrue} title='Connection settings'>
+            🔌
+          </button>
+        </div>
         <div class={styles.info}>
-          {promptTokens} / {contextLength}
+          {modelName} - {promptTokens} / {contextLength}
         </div>
       </div>
       <div class={styles.buttons}>
@@ -102,6 +80,10 @@ export const Header = () => {
           ❓
         </button>
       </div>
+      <Modal open={connectionsOpen.value} onClose={connectionsOpen.setFalse}>
+        <h3 class={styles.modalTitle}>Connection settings</h3>
+        <ConnectionEditor connection={connection} setConnection={setConnection} />
+      </Modal>
       <Modal open={loreOpen.value} onClose={loreOpen.setFalse}>
         <h3 class={styles.modalTitle}>Lore Editor</h3>
         <AutoTextarea
@@ -135,12 +117,12 @@ export const Header = () => {
         <h4 class={styles.modalTitle}>Summary template</h4>
         <Ace value={summarizePrompt} onInput={setSummarizePrompt} />
         <label>
-          <input type='checkbox' checked={summaryEnabled} onChange={handleSetSummaryEnabled}/>
+          <input type='checkbox' checked={summaryEnabled} onChange={handleSetSummaryEnabled} />
           Enable summarization
         </label>
         <hr />
         <h4 class={styles.modalTitle}>Instruct template</h4>
-        <Ace value={instruct} onInput={setInstruct} />
+        <Ace value={connection.instruct} onInput={setInstruct} />
       </div>
     </Modal>
     <MiniChat
@@ -5,7 +5,7 @@ import { AutoTextarea } from "./autoTextarea";
 
 export const Input = () => {
   const { input, setInput, addMessage, continueMessage } = useContext(StateContext);
-  const { generating } = useContext(LLMContext);
+  const { generating, stopGeneration } = useContext(LLMContext);
 
   const handleSend = useCallback(async () => {
     if (!generating) {
@@ -29,7 +29,10 @@ export const Input = () => {
   return (
     <div class="chat-input">
       <AutoTextarea onInput={setInput} onKeyDown={handleKeyDown} value={input} />
-      <button onClick={handleSend} class={`${generating ? 'disabled' : ''}`}>{input ? 'Send' : 'Continue'}</button>
+      {generating
+        ? <button onClick={stopGeneration}>Stop</button>
+        : <button onClick={handleSend}>{input ? 'Send' : 'Continue'}</button>
+      }
     </div>
   );
 }
@@ -16,7 +16,7 @@ interface IProps {
 }
 
 export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) => {
-  const { generating, generate, compilePrompt } = useContext(LLMContext);
+  const { generating, stopGeneration, generate, compilePrompt } = useContext(LLMContext);
   const [messages, setMessages] = useState<IMessage[]>([]);
   const ref = useRef<HTMLDivElement>(null);
 
@@ -105,9 +105,10 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
         </div>
       </div>
       <div class={styles.buttons}>
-        <button onClick={handleGenerate} class={`${generating ? 'disabled' : ''}`}>
-          Generate
-        </button>
+        {generating
+          ? <button onClick={stopGeneration}>Stop</button>
+          : <button onClick={handleGenerate}>Generate</button>
+        }
         <button onClick={() => handleInit()} class={`${generating ? 'disabled' : ''}`}>
           Clear
         </button>
@@ -0,0 +1,376 @@
+import Lock from "@common/lock";
+import SSE from "@common/sse";
+import { throttle } from "@common/utils";
+import delay, { clearDelay } from "delay";
+
+interface IBaseConnection {
+  instruct: string;
+}
+
+interface IKoboldConnection extends IBaseConnection {
+  url: string;
+}
+
+interface IHordeConnection extends IBaseConnection {
+  apiKey?: string;
+  model: string;
+}
+
+export const isKoboldConnection = (obj: unknown): obj is IKoboldConnection => (
+  obj != null && typeof obj === 'object' && 'url' in obj && typeof obj.url === 'string'
+);
+
+export const isHordeConnection = (obj: unknown): obj is IHordeConnection => (
+  obj != null && typeof obj === 'object' && 'model' in obj && typeof obj.model === 'string'
+);
+
+export type IConnection = IKoboldConnection | IHordeConnection;
+
+interface IHordeWorker {
+  id: string;
+  models: string[];
+  flagged: boolean;
+  online: boolean;
+  maintenance_mode: boolean;
+  max_context_length: number;
+  max_length: number;
+  performance: string;
+}
+
+export interface IHordeModel {
+  name: string;
+  hordeNames: string[];
+  maxLength: number;
+  maxContext: number;
+  workers: string[];
+}
+
+interface IHordeResult {
+  faulted: boolean;
+  done: boolean;
+  finished: number;
+  generations?: {
+    text: string;
+  }[];
+}
+
+const DEFAULT_GENERATION_SETTINGS = {
+  temperature: 0.8,
+  min_p: 0.1,
+  rep_pen: 1.08,
+  rep_pen_range: -1,
+  rep_pen_slope: 0.7,
+  top_k: 100,
+  top_p: 0.92,
+  banned_tokens: ['anticipat'],
+  max_length: 300,
+  trim_stop: true,
+  stop_sequence: ['[INST]', '[/INST]', '</s>', '<|'],
+  dry_allowed_length: 5,
+  dry_multiplier: 0.8,
+  dry_base: 1,
+  dry_sequence_breakers: ["\n", ":", "\"", "*"],
+  dry_penalty_last_n: 0
+}
+
+const MIN_PERFORMANCE = 2.0;
+const MIN_WORKER_CONTEXT = 8192;
+const MAX_HORDE_LENGTH = 512;
+const MAX_HORDE_CONTEXT = 32000;
+export const HORDE_ANON_KEY = '0000000000';
+
+export const normalizeModel = (model: string) => {
+  let currentModel = model.split(/[\\\/]/).at(-1);
+  currentModel = currentModel.split('::').at(0);
+  let normalizedModel: string;
+
+  do {
+    normalizedModel = currentModel;
+
+    currentModel = currentModel
+      .replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
+      .replace(/[ ._-](gptq|awq|exl2?|imat|i\d|h\d)/i, '') // remove quant name
+      .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
+      .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
+      .replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
+      .replace(/[ ._-]f(p|loat)?(8|16|32)/i, '')
+      .replace(/^(debug-?)+/i, '')
+      .trim();
+  } while (normalizedModel !== currentModel);
+
+  return normalizedModel
+    .replace(/[ _-]+/ig, '-')
+    .replace(/\.{2,}/, '-')
+    .replace(/[ ._-]+$/ig, '')
+    .trim();
+}
+
+export const approximateTokens = (prompt: string): number => prompt.split(/[^a-z0-9]+/i).length;
+
+export type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
+
+export namespace Connection {
+  const AIHORDE = 'https://aihorde.net';
+
+  let abortController = new AbortController();
+
+  async function* generateKobold(url: string, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator<string> {
+    const sse = new SSE(`${url}/api/extra/generate/stream`, {
+      payload: JSON.stringify({
+        ...DEFAULT_GENERATION_SETTINGS,
+        ...extraSettings,
+        prompt,
+      }),
+    });
+
+    const messages: string[] = [];
+    const messageLock = new Lock();
+    let end = false;
+
+    sse.addEventListener('message', (e) => {
+      if (e.data) {
+        {
+          const { token, finish_reason } = JSON.parse(e.data);
+          messages.push(token);
+
+          if (finish_reason && finish_reason !== 'null') {
+            end = true;
+          }
+        }
+      }
+      messageLock.release();
+    });
+
+    const handleEnd = () => {
+      end = true;
+      messageLock.release();
+    };
+
+    abortController.signal.addEventListener('abort', handleEnd);
+    sse.addEventListener('error', handleEnd);
+    sse.addEventListener('abort', handleEnd);
+    sse.addEventListener('readystatechange', (e) => {
+      if (e.readyState === SSE.CLOSED) handleEnd();
+    });
+
+    while (!end || messages.length) {
+      while (messages.length > 0) {
+        const message = messages.shift();
+        if (message != null) {
+          try {
+            yield message;
+          } catch { }
+        }
+      }
+      if (!end) {
+        await messageLock.wait();
+      }
+    }
+
+    sse.close();
+  }
+
+  async function generateHorde(connection: Omit<IHordeConnection, keyof IBaseConnection>, prompt: string, extraSettings: IGenerationSettings = {}): Promise<string> {
+    const models = await getHordeModels();
+    const model = models.get(connection.model);
+    if (model) {
+      let maxLength = Math.min(model.maxLength, DEFAULT_GENERATION_SETTINGS.max_length);
+      if (extraSettings.max_length && extraSettings.max_length < maxLength) {
+        maxLength = extraSettings.max_length;
+      }
+      const requestData = {
+        prompt,
+        params: {
+          ...DEFAULT_GENERATION_SETTINGS,
+          ...extraSettings,
+          n: 1,
+          max_context_length: model.maxContext,
+          max_length: maxLength,
+          rep_pen_range: Math.min(model.maxContext, 4096),
+        },
+        models: model.hordeNames,
+        workers: model.workers,
+      };
+
+      const { signal } = abortController;
+
+      const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, {
+        method: 'POST',
+        body: JSON.stringify(requestData),
+        headers: {
+          'Content-Type': 'application/json',
+          apikey: connection.apiKey || HORDE_ANON_KEY,
+        },
+        signal,
+      });
+
+      if (!generateResponse.ok || generateResponse.status >= 400) {
+        throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`);
+      }
+
+      const { id } = await generateResponse.json() as { id: string };
+      const request = async (method = 'GET'): Promise<string | null> => {
+        const response = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method });
+        if (response.ok && response.status < 400) {
+          const result: IHordeResult = await response.json();
+          if (result.generations?.length === 1) {
+            const { text } = result.generations[0];
+
+            return text;
+          }
+        } else {
+          throw new Error(await response.text());
+        }
+
+        return null;
+      };
+
+      const deleteRequest = async () => (await request('DELETE')) ?? '';
+
+      while (true) {
+        try {
+          await delay(2500, { signal });
+
+          const text = await request();
+
+          if (text) {
+            return text;
+          }
+        } catch (e) {
+          console.error('Error in horde generation:', e);
+          return deleteRequest();
+        }
+      }
+    }
+
+    throw new Error(`Model ${connection.model} is offline`);
+  }
+
+  export async function* generate(connection: IConnection, prompt: string, extraSettings: IGenerationSettings = {}) {
+    if (isKoboldConnection(connection)) {
+      yield* generateKobold(connection.url, prompt, extraSettings);
+    } else if (isHordeConnection(connection)) {
+      yield await generateHorde(connection, prompt, extraSettings);
+    }
+  }
+
+  export function stopGeneration() {
+    abortController.abort();
+    abortController = new AbortController(); // refresh
+  }
+
+  async function requestHordeModels(): Promise<Map<string, IHordeModel>> {
+    try {
+      const response = await fetch(`${AIHORDE}/api/v2/workers?type=text`);
+      if (response.ok) {
+        const workers: IHordeWorker[] = await response.json();
+        const goodWorkers = workers.filter(w =>
+          w.online
+          && !w.maintenance_mode
+          && !w.flagged
+          && w.max_context_length >= MIN_WORKER_CONTEXT
+          && parseFloat(w.performance) >= MIN_PERFORMANCE
+        );
+
+        const models = new Map<string, IHordeModel>();
+
+        for (const worker of goodWorkers) {
+          for (const modelName of worker.models) {
+            const normName = normalizeModel(modelName.toLowerCase());
+            let model = models.get(normName);
+            if (!model) {
+              model = {
+                hordeNames: [],
+                maxContext: MAX_HORDE_CONTEXT,
+                maxLength: MAX_HORDE_LENGTH,
+                name: normName,
+                workers: []
+              }
+            }
+
+            if (!model.hordeNames.includes(modelName)) {
+              model.hordeNames.push(modelName);
+            }
+            if (!model.workers.includes(worker.id)) {
+              model.workers.push(worker.id);
+            }
+
+            model.maxContext = Math.min(model.maxContext, worker.max_context_length);
+            model.maxLength = Math.min(model.maxLength, worker.max_length);
+
+            models.set(normName, model);
+          }
+        }
+
+        return models;
+      }
+    } catch (e) {
+      console.error(e);
+    }
+
+    return new Map();
+  };
+
+  export const getHordeModels = throttle(requestHordeModels, 10000);
+
+  export async function getModelName(connection: IConnection): Promise<string> {
+    if (isKoboldConnection(connection)) {
+      try {
+        const response = await fetch(`${connection.url}/api/v1/model`);
+        if (response.ok) {
+          const { result } = await response.json();
+          return result;
+        }
+      } catch (e) {
+        console.error('Error getting max tokens', e);
+      }
+    } else if (isHordeConnection(connection)) {
+      return connection.model;
+    }
+
+    return '';
+  }
+
+  export async function getContextLength(connection: IConnection): Promise<number> {
+    if (isKoboldConnection(connection)) {
+      try {
+        const response = await fetch(`${connection.url}/api/extra/true_max_context_length`);
+        if (response.ok) {
+          const { value } = await response.json();
+          return value;
+        }
+      } catch (e) {
+        console.error('Error getting max tokens', e);
+      }
+    } else if (isHordeConnection(connection)) {
+      const models = await getHordeModels();
+      const model = models.get(connection.model);
+      if (model) {
+        return model.maxContext;
+      }
+    }
+
+    return 0;
+  }
+
+  export async function countTokens(connection: IConnection, prompt: string) {
+    if (isKoboldConnection(connection)) {
+      try {
+        const response = await fetch(`${connection.url}/api/extra/tokencount`, {
+          body: JSON.stringify({ prompt }),
+          headers: { 'Content-Type': 'applicarion/json' },
+          method: 'POST',
+        });
+        if (response.ok) {
+          const { value } = await response.json();
+          return value;
+        }
+      } catch (e) {
+        console.error('Error counting tokens', e);
+      }
+    }
+
+    return approximateTokens(prompt);
+  }
+}
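Note: a short consumer-side sketch of the new `Connection` namespace (illustrative only — the `run` helper is hypothetical; the `Connection.generate` / `Connection.stopGeneration` API, the `IConnection` shape, and the local Kobold URL come from this diff):

```ts
import { Connection, type IConnection } from "./connection";

// Hypothetical Kobold CPP connection; the instruct template string would
// normally come from the Instruct enum in contexts/state.
const connection: IConnection = {
  url: 'http://localhost:5001',
  instruct: '{% for message in messages %}{{ message.content }}{% endfor %}', // placeholder template
};

async function run(prompt: string) {
  let text = '';
  // Streams token-by-token for Kobold; yields one final string for Horde.
  for await (const chunk of Connection.generate(connection, prompt)) {
    text += chunk;
  }
  return text;
}

// Either backend can be cancelled mid-generation:
// Connection.stopGeneration();
```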
@@ -1,12 +1,13 @@
-import Lock from "@common/lock";
-import SSE from "@common/sse";
 import { createContext } from "preact";
 import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
 import { MessageTools, type IMessage } from "../messages";
-import { Instruct, StateContext } from "./state";
+import { StateContext } from "./state";
 import { useBool } from "@common/hooks/useBool";
 import { Template } from "@huggingface/jinja";
 import { Huggingface } from "../huggingface";
+import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection";
+import { throttle } from "@common/utils";
+import { useAsyncEffect } from "@common/hooks/useAsyncEffect";
 
 interface ICompileArgs {
   keepUsers?: number;
@@ -21,71 +22,23 @@ interface ICompiledPrompt {
 
 interface IContext {
   generating: boolean;
-  blockConnection: ReturnType<typeof useBool>;
   modelName: string;
-  modelTemplate: string;
   hasToolCalls: boolean;
   promptTokens: number;
   contextLength: number;
 }
 
-const DEFAULT_GENERATION_SETTINGS = {
-  temperature: 0.8,
-  min_p: 0.1,
-  rep_pen: 1.08,
-  rep_pen_range: -1,
-  rep_pen_slope: 0.7,
-  top_k: 100,
-  top_p: 0.92,
-  banned_tokens: [],
-  max_length: 300,
-  trim_stop: true,
-  stop_sequence: ['[INST]', '[/INST]', '</s>', '<|'],
-  dry_allowed_length: 5,
-  dry_multiplier: 0.8,
-  dry_base: 1,
-  dry_sequence_breakers: ["\n", ":", "\"", "*"],
-  dry_penalty_last_n: 0
-}
-
 const MESSAGES_TO_KEEP = 10;
 
-type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
-
 interface IActions {
   compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
   generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
+  stopGeneration: () => void;
   summarize: (content: string) => Promise<string>;
   countTokens: (prompt: string) => Promise<number>;
 }
 export type ILLMContext = IContext & IActions;
 
-export const normalizeModel = (model: string) => {
-  let currentModel = model.split(/[\\\/]/).at(-1);
-  currentModel = currentModel.split('::').at(0);
-  let normalizedModel: string;
-
-  do {
-    normalizedModel = currentModel;
-
-    currentModel = currentModel
-      .replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
-      .replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name
-      .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
-      .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
-      .replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
-      .replace(/[ ._-]f(p|loat)?(8|16|32)/i, '')
-      .replace(/^(debug-?)+/i, '')
-      .trim();
-  } while (normalizedModel !== currentModel);
-
-  return normalizedModel
-    .replace(/[ _-]+/ig, '-')
-    .replace(/\.{2,}/, '-')
-    .replace(/[ ._-]+$/ig, '')
-    .trim();
-}
-
 export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
 
 const processing = {
@@ -95,16 +48,14 @@ const processing = {
 
 export const LLMContextProvider = ({ children }: { children?: any }) => {
   const {
-    connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct, summarizePrompt, summaryEnabled,
-    setTriggerNext, addMessage, editMessage, editSummary, setInstruct,
+    connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
+    setTriggerNext, addMessage, editMessage, editSummary,
   } = useContext(StateContext);
 
   const generating = useBool(false);
-  const blockConnection = useBool(false);
   const [promptTokens, setPromptTokens] = useState(0);
   const [contextLength, setContextLength] = useState(0);
   const [modelName, setModelName] = useState('');
-  const [modelTemplate, setModelTemplate] = useState('');
   const [hasToolCalls, setHasToolCalls] = useState(false);
 
   const userPromptTemplate = useMemo(() => {
@@ -117,40 +68,6 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
     }
   }, [userPrompt]);
 
-  const getContextLength = useCallback(async () => {
-    if (!connectionUrl || blockConnection.value) {
-      return 0;
-    }
-    try {
-      const response = await fetch(`${connectionUrl}/api/extra/true_max_context_length`);
-      if (response.ok) {
-        const { value } = await response.json();
-        return value;
-      }
-    } catch (e) {
-      console.log('Error getting max tokens', e);
-    }
-
-    return 0;
-  }, [connectionUrl, blockConnection.value]);
-
-  const getModelName = useCallback(async () => {
-    if (!connectionUrl || blockConnection.value) {
-      return '';
-    }
-    try {
-      const response = await fetch(`${connectionUrl}/api/v1/model`);
-      if (response.ok) {
-        const { result } = await response.json();
-        return result;
-      }
-    } catch (e) {
-      console.log('Error getting max tokens', e);
-    }
-
-    return '';
-  }, [connectionUrl, blockConnection.value]);
-
   const actions: IActions = useMemo(() => ({
     compilePrompt: async (messages, { keepUsers } = {}) => {
       const promptMessages = messages.slice();
@@ -236,7 +153,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
 
       templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`;
 
-      const prompt = Huggingface.applyChatTemplate(instruct, templateMessages);
+      const prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages);
       return {
         prompt,
         isContinue,
@@ -244,102 +161,44 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
       };
     },
     generate: async function* (prompt, extraSettings = {}) {
-      if (!connectionUrl) {
-        return;
-      }
-
       try {
-        generating.setTrue();
         console.log('[LLM.generate]', prompt);
 
-        const sse = new SSE(`${connectionUrl}/api/extra/generate/stream`, {
-          payload: JSON.stringify({
-            ...DEFAULT_GENERATION_SETTINGS,
-            banned_tokens: bannedWords.filter(w => w.trim()),
-            ...extraSettings,
-            prompt,
-          }),
-        });
-
-        const messages: string[] = [];
-        const messageLock = new Lock();
-        let end = false;
-
-        sse.addEventListener('message', (e) => {
-          if (e.data) {
-            {
-              const { token, finish_reason } = JSON.parse(e.data);
-              messages.push(token);
-
-              if (finish_reason && finish_reason !== 'null') {
-                end = true;
-              }
-            }
-          }
-          messageLock.release();
-        });
-
-        const handleEnd = () => {
-          end = true;
-          messageLock.release();
-        };
-
-        sse.addEventListener('error', handleEnd);
-        sse.addEventListener('abort', handleEnd);
-        sse.addEventListener('readystatechange', (e) => {
-          if (e.readyState === SSE.CLOSED) handleEnd();
-        });
-
-        while (!end || messages.length) {
-          while (messages.length > 0) {
-            const message = messages.shift();
-            if (message != null) {
-              try {
-                yield message;
-              } catch { }
-            }
-          }
-          if (!end) {
-            await messageLock.wait();
-          }
-        }
-
-        sse.close();
-      } finally {
-        generating.setFalse();
+        yield* Connection.generate(connection, prompt, {
+          ...extraSettings,
+          banned_tokens: bannedWords.filter(w => w.trim()),
+        });
+      } catch (e) {
+        if (e instanceof Error && e.name !== 'AbortError') {
+          alert(e.message);
+        } else {
+          console.error('[LLM.generate]', e);
+        }
       }
     },
     summarize: async (message) => {
-      const content = Huggingface.applyTemplate(summarizePrompt, { message });
-      const prompt = Huggingface.applyChatTemplate(instruct, [{ role: 'user', content }]);
-      const tokens = await Array.fromAsync(actions.generate(prompt));
-      return MessageTools.trimSentence(tokens.join(''));
+      try {
+        const content = Huggingface.applyTemplate(summarizePrompt, { message });
+        const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
+        console.log('[LLM.summarize]', prompt);
+
+        const tokens = await Array.fromAsync(Connection.generate(connection, prompt, {}));
+
+        return MessageTools.trimSentence(tokens.join(''));
+      } catch (e) {
+        console.error('Error summarizing:', e);
+        return '';
+      }
     },
     countTokens: async (prompt) => {
-      if (!connectionUrl) {
-        return 0;
-      }
-      try {
-        const response = await fetch(`${connectionUrl}/api/extra/tokencount`, {
-          body: JSON.stringify({ prompt }),
-          headers: { 'Content-Type': 'applicarion/json' },
-          method: 'POST',
-        });
-        if (response.ok) {
-          const { value } = await response.json();
-          return value;
-        }
-      } catch (e) {
-        console.log('Error counting tokens', e);
-      }
-
-      return 0;
+      return await Connection.countTokens(connection, prompt);
     },
-  }), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct, summarizePrompt]);
+    stopGeneration: () => {
+      Connection.stopGeneration();
+    },
+  }), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);
 
-  useEffect(() => void (async () => {
+  useAsyncEffect(async () => {
     if (triggerNext && !generating.value) {
       setTriggerNext(false);
@@ -353,12 +212,14 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
       messageId++;
     }
 
+    generating.setTrue();
     editSummary(messageId, 'Generating...');
     for await (const chunk of actions.generate(prompt)) {
       text += chunk;
-      setPromptTokens(promptTokens + Math.round(text.length * 0.25));
+      setPromptTokens(promptTokens + approximateTokens(text));
      editMessage(messageId, text.trim());
     }
+    generating.setFalse();
 
     text = MessageTools.trimSentence(text);
     editMessage(messageId, text);
@@ -366,10 +227,10 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
 
       MessageTools.playReady();
     }
-  })(), [triggerNext]);
+  }, [triggerNext]);
 
-  useEffect(() => void (async () => {
-    if (summaryEnabled && !generating.value && !processing.summarizing) {
+  useAsyncEffect(async () => {
+    if (summaryEnabled && !processing.summarizing) {
       try {
         processing.summarizing = true;
         for (let id = 0; id < messages.length; id++) {
@@ -386,36 +247,15 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
         processing.summarizing = false;
       }
     }
-  })(), [messages]);
+  }, [messages, summaryEnabled]);
 
-  useEffect(() => {
-    if (!blockConnection.value) {
-      setPromptTokens(0);
-      setContextLength(0);
-      setModelName('');
-
-      getContextLength().then(setContextLength);
-      getModelName().then(normalizeModel).then(setModelName);
-    }
-  }, [connectionUrl, blockConnection.value]);
-
-  useEffect(() => {
-    setModelTemplate('');
-    if (modelName) {
-      Huggingface.findModelTemplate(modelName)
-        .then((template) => {
-          if (template) {
-            setModelTemplate(template);
-            setInstruct(template);
-          } else {
-            setInstruct(Instruct.CHATML);
-          }
-        });
-    }
-  }, [modelName]);
-
-  const calculateTokens = useCallback(async () => {
-    if (!processing.tokenizing && !blockConnection.value && !generating.value) {
+  useEffect(throttle(() => {
+    Connection.getContextLength(connection).then(setContextLength);
+    Connection.getModelName(connection).then(normalizeModel).then(setModelName);
+  }, 1000, true), [connection]);
+
+  const calculateTokens = useCallback(throttle(async () => {
+    if (!processing.tokenizing && !generating.value) {
       try {
         processing.tokenizing = true;
         const { prompt } = await actions.compilePrompt(messages);
@@ -427,26 +267,24 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
         processing.tokenizing = false;
       }
     }
-  }, [actions, messages, blockConnection.value]);
+  }, 1000, true), [actions, messages]);
 
   useEffect(() => {
     calculateTokens();
-  }, [messages, connectionUrl, blockConnection.value, instruct, /* systemPrompt, lore, userPrompt TODO debounce*/]);
+  }, [messages, connection, systemPrompt, lore, userPrompt]);
 
   useEffect(() => {
     try {
-      const hasTools = Huggingface.testToolCalls(instruct);
+      const hasTools = Huggingface.testToolCalls(connection.instruct);
       setHasToolCalls(hasTools);
     } catch {
       setHasToolCalls(false);
     }
-  }, [instruct]);
+  }, [connection.instruct]);
 
   const rawContext: IContext = {
     generating: generating.value,
-    blockConnection,
     modelName,
-    modelTemplate,
     hasToolCalls,
     promptTokens,
     contextLength,
@@ -1,12 +1,13 @@
 import { createContext } from "preact";
-import { useEffect, useMemo, useState } from "preact/hooks";
+import { useCallback, useEffect, useMemo, useState } from "preact/hooks";
 import { MessageTools, type IMessage } from "../messages";
 import { useInputState } from "@common/hooks/useInputState";
+import { type IConnection } from "../connection";
 
 interface IContext {
-  connectionUrl: string;
+  currentConnection: number;
+  availableConnections: IConnection[];
   input: string;
-  instruct: string;
   systemPrompt: string;
   lore: string;
   userPrompt: string;
@@ -17,8 +18,14 @@ interface IContext {
   triggerNext: boolean;
 }
 
+interface IComputableContext {
+  connection: IConnection;
+}
+
 interface IActions {
-  setConnectionUrl: (url: string | Event) => void;
+  setConnection: (connection: IConnection) => void;
+  setAvailableConnections: (connections: IConnection[]) => void;
+  setCurrentConnection: (connection: number) => void;
   setInput: (url: string | Event) => void;
   setInstruct: (template: string | Event) => void;
   setLore: (lore: string | Event) => void;
@@ -49,11 +56,40 @@ export enum Instruct {
 
   MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + '</s>'}}{%- endif %}{%- endfor %}`,
 
+  METHARME = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>' + message['content'] }}{% elif message['role'] == 'user' %}{{'<|user|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{'<|model|>' + message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% endif %}`,
+
   GEMMA = `{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}`,
 
   ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\\n\\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\\n\\n' + message['content'] + '\\n\\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\\n\\n' + message['content'] + '\\n\\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\\n\\n' }}{% endif %}`,
 };
+
+const DEFAULT_CONTEXT: IContext = {
+  currentConnection: 0,
+  availableConnections: [{
+    url: 'http://localhost:5001',
+    instruct: Instruct.CHATML,
+  }],
+  input: '',
+  systemPrompt: 'You are a creative writer. Write a story based on the world description below. Story should be adult and mature; and could include swearing, violence and unfairness. Portray characters realistically and stay in the lore.',
+  lore: '',
+  userPrompt: `{% if isStart -%}
+Write a novel using information above as a reference.
+{%- else -%}
+Continue the story forward.
+{%- endif %}
+
+{% if prompt -%}
+What should happen next in your answer: {{ prompt | trim }}
+{% endif %}
+Remember that this story should be infinite and go forever.
+Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`,
+  summarizePrompt: 'Shrink following text down to one paragraph, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
+  summaryEnabled: false,
+  bannedWords: [],
+  messages: [],
+  triggerNext: false,
+};
+
 export const saveContext = (context: IContext) => {
   const contextToSave: Partial<IContext> = { ...context };
   delete contextToSave.triggerNext;
@ -62,30 +98,6 @@ export const saveContext = (context: IContext) => {
|
||||||
}
|
}

export const loadContext = (): IContext => {
  const defaultContext: IContext = {
    connectionUrl: 'http://localhost:5001',
    input: '',
    instruct: Instruct.CHATML,
    systemPrompt: 'You are a creative writer. Write a story based on the world description below. Story should be adult and mature; and could include swearing, violence and unfairness. Portray characters realistically and stay in the lore.',
    lore: '',
    userPrompt: `{% if isStart -%}
Write a novel using information above as a reference.
{%- else -%}
Continue the story forward.
{%- endif %}

{% if prompt -%}
This is the description of what I want to happen next: {{ prompt | trim }}
{% endif %}
Remember that this story should be infinite and go forever.
Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`,
    summarizePrompt: 'Shrink following text down, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
    summaryEnabled: false,
    bannedWords: [],
    messages: [],
    triggerNext: false,
  };

  let loadedContext: Partial<IContext> = {};

  try {

@ -95,18 +107,18 @@ Make sure to follow the world description and rules exactly. Avoid cliffhangers
    }
  } catch { }

  return { ...defaultContext, ...loadedContext };
  return { ...DEFAULT_CONTEXT, ...loadedContext };
}
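Note that spreading DEFAULT_CONTEXT under loadedContext only fills in keys that are missing entirely: a save written by the previous version still carries the removed top-level connectionUrl and instruct fields, and their values are not folded into availableConnections (the defaults are used instead). If migrating such saves matters, a hedged sketch along these lines could run before the spread; the helper name and the legacy field names are assumptions based on the removed code above:

// Hypothetical migration: fold a legacy single-connection save into the new shape.
const migrateLegacyContext = (loaded: Partial<IContext> & { connectionUrl?: string; instruct?: string }): Partial<IContext> => {
  if (!loaded.connectionUrl) return loaded; // already in the new multi-connection shape
  const { connectionUrl, instruct, ...rest } = loaded;
  return {
    ...rest,
    currentConnection: 0,
    availableConnections: [{ url: connectionUrl, instruct: instruct ?? Instruct.CHATML }],
  };
};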
export type IStateContext = IContext & IActions;
export type IStateContext = IContext & IActions & IComputableContext;

export const StateContext = createContext<IStateContext>({} as IStateContext);

export const StateContextProvider = ({ children }: { children?: any }) => {
  const loadedContext = useMemo(() => loadContext(), []);
  const [connectionUrl, setConnectionUrl] = useInputState(loadedContext.connectionUrl);
  const [currentConnection, setCurrentConnection] = useState<number>(loadedContext.currentConnection);
  const [availableConnections, setAvailableConnections] = useState<IConnection[]>(loadedContext.availableConnections);
  const [input, setInput] = useInputState(loadedContext.input);
  const [instruct, setInstruct] = useInputState(loadedContext.instruct);
  const [lore, setLore] = useInputState(loadedContext.lore);
  const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt);
  const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt);

@ -115,10 +127,26 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
  const [messages, setMessages] = useState(loadedContext.messages);
  const [summaryEnabled, setSummaryEnabled] = useState(loadedContext.summaryEnabled);

  const connection = availableConnections[currentConnection] ?? DEFAULT_CONTEXT.availableConnections[0];

  const [triggerNext, setTriggerNext] = useState(false);
  const [instruct, setInstruct] = useInputState(connection.instruct);

  const setConnection = useCallback((c: IConnection) => {
    setAvailableConnections(availableConnections.map((ac, ai) => {
      if (ai === currentConnection) {
        return c;
      } else {
        return ac;
      }
    }));
  }, [availableConnections, currentConnection]);

  useEffect(() => setConnection({ ...connection, instruct }), [instruct]);
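One thing to watch with setConnection: it maps over the availableConnections array captured when the callback was created, so two updates landing in the same render cycle (for example the instruct effect above plus a URL edit) could overwrite each other. A functional state update sidesteps the stale capture; a sketch with otherwise identical behaviour:

const setConnection = useCallback((c: IConnection) => {
  // Map over the latest connections array instead of the captured one.
  setAvailableConnections((prev) => prev.map((ac, ai) => (ai === currentConnection ? c : ac)));
}, [currentConnection]);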

  const actions: IActions = useMemo(() => ({
    setConnectionUrl,
    setConnection,
    setCurrentConnection,
    setInput,
    setInstruct,
    setSystemPrompt,

@ -127,7 +155,8 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
    setLore,
    setTriggerNext,
    setSummaryEnabled,
    setBannedWords: (words) => setBannedWords([...words]),
    setBannedWords: (words) => setBannedWords(words.slice()),
    setAvailableConnections: (connections) => setAvailableConnections(connections.slice()),
    setMessages: (newMessages) => setMessages(newMessages.slice()),
    addMessage: (content, role, triggerNext = false) => {

@ -198,10 +227,11 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
    continueMessage: () => setTriggerNext(true),
  }), []);

  const rawContext: IContext = {
  const rawContext: IContext & IComputableContext = {
    connectionUrl,
    connection,
    currentConnection,
    availableConnections,
    input,
    instruct,
    systemPrompt,
    lore,
    userPrompt,

@ -1,6 +1,7 @@
import { gguf } from '@huggingface/gguf';
import * as hub from '@huggingface/hub';
import { Template } from '@huggingface/jinja';
import { normalizeModel } from './connection';

export namespace Huggingface {
  export interface ITemplateMessage {
@ -92,11 +93,12 @@ export namespace Huggingface {

  const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise<TokenizerConfig | null> => {
    console.log(`[huggingface] searching config for '${modelName}'`);
    const searchModel = normalizeModel(modelName);

    const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName }, additionalFields: ['config'] }));
    const hubModels = await Array.fromAsync(hub.listModels({ search: { query: searchModel }, additionalFields: ['config'] }));
    const models = hubModels.filter(m => {
      if (m.gated) return false;
      if (!m.name.toLowerCase().includes(modelName.toLowerCase())) return false;
      if (!normalizeModel(m.name).includes(searchModel)) return false;

      return true;
    }).sort((a, b) => b.downloads - a.downloads);
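Filtering on normalizeModel(m.name).includes(searchModel) makes the Hub match tolerant of the quantization suffixes and separators that backend-reported model names tend to carry, where the old lowercase substring check needed a near-exact name. The helper comes from './connection' and its body is not part of this diff; a plausible sketch of what such a normalizer could do, as an assumption rather than the actual implementation:

// Assumed behaviour only: lowercase and strip common GGUF/quant decorations
// so that 'MyModel-7B-Q4_K_M.gguf' still matches a Hub repo named 'MyModel-7B'.
const normalizeModel = (name: string): string =>
  name
    .toLowerCase()
    .replace(/\.gguf$/, '')
    .replace(/[-_.](i?q\d[\w]*|gguf|imatrix)$/, '')
    .replace(/[\s_.]+/g, '-');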
@ -230,7 +232,9 @@ export namespace Huggingface {
  }

  export const findModelTemplate = async (modelName: string): Promise<string | null> => {
    const modelKey = modelName.toLowerCase();
    const modelKey = modelName.toLowerCase().trim();
    if (!modelKey) return '';

    let template = templateCache[modelKey] ?? null;

    if (template) {