1
0
Fork 0

Compare commits

..

No commits in common. "77802634aef151e296af8ac4f0b2f4943213b4d8" and "0711d3b89ad1fa677e208c97066a2134319546c9" have entirely different histories.

8 changed files with 117 additions and 311 deletions

View File

@ -1,6 +1,6 @@
import { useCallback, useEffect, useMemo, useState } from 'preact/hooks'; import { useCallback, useEffect, useMemo, useState } from 'preact/hooks';
import styles from './header.module.css'; import styles from './header.module.css';
import { Connection, HORDE_ANON_KEY, type IConnection, type IHordeModel } from '../../tools/connection'; import { Connection, HORDE_ANON_KEY, isHordeConnection, isKoboldConnection, type IConnection, type IHordeModel } from '../../tools/connection';
import { Instruct } from '../../contexts/state'; import { Instruct } from '../../contexts/state';
import { useInputState } from '@common/hooks/useInputState'; import { useInputState } from '@common/hooks/useInputState';
import { useInputCallback } from '@common/hooks/useInputCallback'; import { useInputCallback } from '@common/hooks/useInputCallback';
@ -23,17 +23,24 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
const [hordeModels, setHordeModels] = useState<IHordeModel[]>([]); const [hordeModels, setHordeModels] = useState<IHordeModel[]>([]);
const [contextLength, setContextLength] = useState<number>(0); const [contextLength, setContextLength] = useState<number>(0);
const backendType = useMemo(() => {
if (isKoboldConnection(connection)) return 'kobold';
if (isHordeConnection(connection)) return 'horde';
return 'unknown';
}, [connection]);
const isOnline = useMemo(() => contextLength > 0, [contextLength]); const isOnline = useMemo(() => contextLength > 0, [contextLength]);
useEffect(() => { useEffect(() => {
setInstruct(connection.instruct); setInstruct(connection.instruct);
connection.url && setConnectionUrl(connection.url);
connection.model && setModelName(connection.model); if (isKoboldConnection(connection)) {
setConnectionUrl(connection.url);
Connection.getContextLength(connection).then(setContextLength);
} else if (isHordeConnection(connection)) {
setModelName(connection.model);
setApiKey(connection.apiKey || HORDE_ANON_KEY); setApiKey(connection.apiKey || HORDE_ANON_KEY);
if (connection.type === 'kobold') {
Connection.getContextLength(connection).then(setContextLength);
} else if (connection.type === 'horde') {
Connection.getHordeModels() Connection.getHordeModels()
.then(m => setHordeModels(Array.from(m.values()).sort((a, b) => a.name.localeCompare(b.name)))); .then(m => setHordeModels(Array.from(m.values()).sort((a, b) => a.name.localeCompare(b.name))));
} }
@ -52,17 +59,17 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
}, [modelName]); }, [modelName]);
const setBackendType = useInputCallback((type) => { const setBackendType = useInputCallback((type) => {
switch (type) { if (type === 'kobold') {
case 'kobold':
case 'horde':
setConnection({ setConnection({
type,
instruct, instruct,
url: connectionUrl, url: connectionUrl,
});
} else if (type === 'horde') {
setConnection({
instruct,
apiKey, apiKey,
model: modelName, model: modelName,
}); });
break;
} }
}, [setConnection, connectionUrl, apiKey, modelName, instruct]); }, [setConnection, connectionUrl, apiKey, modelName, instruct]);
@ -75,19 +82,14 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
const url = connectionUrl.replace(regex, 'http$1://$2'); const url = connectionUrl.replace(regex, 'http$1://$2');
setConnection({ setConnection({
type: 'kobold',
instruct, instruct,
url, url,
apiKey,
model: modelName,
}); });
}, [connectionUrl, instruct, setConnection]); }, [connectionUrl, instruct, setConnection]);
const handleBlurHorde = useCallback(() => { const handleBlurHorde = useCallback(() => {
setConnection({ setConnection({
type: 'horde',
instruct, instruct,
url: connectionUrl,
apiKey, apiKey,
model: modelName, model: modelName,
}); });
@ -95,7 +97,7 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
return ( return (
<div class={styles.connectionEditor}> <div class={styles.connectionEditor}>
<select value={connection.type} onChange={setBackendType}> <select value={backendType} onChange={setBackendType}>
<option value='kobold'>Kobold CPP</option> <option value='kobold'>Kobold CPP</option>
<option value='horde'>Horde</option> <option value='horde'>Horde</option>
</select> </select>
@ -114,13 +116,13 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
<option value={connection.instruct}>Custom</option> <option value={connection.instruct}>Custom</option>
</optgroup>} </optgroup>}
</select> </select>
{connection.type === 'kobold' && <input {isKoboldConnection(connection) && <input
value={connectionUrl} value={connectionUrl}
onInput={setConnectionUrl} onInput={setConnectionUrl}
onBlur={handleBlurUrl} onBlur={handleBlurUrl}
class={isOnline ? styles.valid : styles.invalid} class={isOnline ? styles.valid : styles.invalid}
/>} />}
{connection.type === 'horde' && <> {isHordeConnection(connection) && <>
<input <input
placeholder='Horde API key' placeholder='Horde API key'
title='Horde API key' title='Horde API key'

View File

@ -22,11 +22,6 @@
.info { .info {
margin: 0 8px; margin: 0 8px;
line-height: 36px; line-height: 36px;
display: flex;
flex-direction: row;
justify-content: center;
align-items: center;
gap: 16px;
} }
.buttons { .buttons {
@ -65,27 +60,3 @@
gap: 8px; gap: 8px;
flex-wrap: wrap; flex-wrap: wrap;
} }
.lore {
display: flex;
flex-direction: column;
gap: 10px;
min-height: 80dvh;
.currentStory {
display: flex;
flex-direction: row;
justify-content: center;
align-items: center;
gap: 10px;
.storiesSelector {
height: 24px;
flex-grow: 1;
}
}
.loreText {
flex-grow: 1;
}
}

View File

@ -1,9 +1,8 @@
import { useCallback, useContext, useMemo } from "preact/hooks"; import { useCallback, useContext, useMemo } from "preact/hooks";
import { useBool } from "@common/hooks/useBool"; import { useBool } from "@common/hooks/useBool";
import { Modal } from "@common/components/modal/Modal"; import { Modal } from "@common/components/modal/Modal";
import { useInputCallback } from "@common/hooks/useInputCallback";
import { DEFAULT_STORY, StateContext } from "../../contexts/state"; import { StateContext } from "../../contexts/state";
import { LLMContext } from "../../contexts/llm"; import { LLMContext } from "../../contexts/llm";
import { MiniChat } from "../minichat/minichat"; import { MiniChat } from "../minichat/minichat";
import { AutoTextarea } from "../autoTextarea"; import { AutoTextarea } from "../autoTextarea";
@ -13,31 +12,10 @@ import { ConnectionEditor } from "./connectionEditor";
import styles from './header.module.css'; import styles from './header.module.css';
export const Header = () => { export const Header = () => {
const { contextLength, promptTokens, modelName, spentKudos } = useContext(LLMContext); const { contextLength, promptTokens, modelName } = useContext(LLMContext);
const { const {
messages, messages, connection, systemPrompt, lore, userPrompt, bannedWords, summarizePrompt, summaryEnabled,
connection, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct, setSummarizePrompt, setSummaryEnabled, setConnection,
systemPrompt,
lore,
userPrompt,
bannedWords,
summarizePrompt,
summaryEnabled,
totalSpentKudos,
stories,
currentStory,
setSystemPrompt,
setLore,
setUserPrompt,
addSwipe,
setBannedWords,
setInstruct,
setSummarizePrompt,
setSummaryEnabled,
setConnection,
setCurrentStory,
createStory,
deleteStory,
} = useContext(StateContext); } = useContext(StateContext);
const connectionsOpen = useBool(); const connectionsOpen = useBool();
@ -75,24 +53,6 @@ export const Header = () => {
} }
}, [setSummaryEnabled]); }, [setSummaryEnabled]);
const handleChangeStory = useInputCallback((story) => {
if (story === '@new') {
const id = prompt('Story id');
if (id) {
createStory(id);
setCurrentStory(id);
}
} else {
setCurrentStory(story);
}
}, []);
const handleDeleteStory = useCallback(() => {
if (confirm(`Delete story "${currentStory}"?`)) {
deleteStory(currentStory);
}
}, [currentStory]);
return ( return (
<div class={styles.header}> <div class={styles.header}>
<div class={styles.inputs}> <div class={styles.inputs}>
@ -102,12 +62,7 @@ export const Header = () => {
</button> </button>
</div> </div>
<div class={styles.info}> <div class={styles.info}>
<span>{modelName}</span> {modelName} - {promptTokens} / {contextLength}
<span>📃{promptTokens}/{contextLength}</span>
{connection.type === 'horde' ? <>
<span>💲{spentKudos}</span>
<span>💰{totalSpentKudos}</span>
</> : null}
</div> </div>
</div> </div>
<div class={styles.buttons}> <div class={styles.buttons}>
@ -130,26 +85,12 @@ export const Header = () => {
<h3 class={styles.modalTitle}>Connection settings</h3> <h3 class={styles.modalTitle}>Connection settings</h3>
<ConnectionEditor connection={connection} setConnection={setConnection} /> <ConnectionEditor connection={connection} setConnection={setConnection} />
</Modal> </Modal>
<Modal open={loreOpen.value} onClose={loreOpen.setFalse} class={styles.lore}> <Modal open={loreOpen.value} onClose={loreOpen.setFalse}>
<h3 class={styles.modalTitle}>Lore Editor</h3> <h3 class={styles.modalTitle}>Lore Editor</h3>
<div class={styles.currentStory}>
<select value={currentStory} onChange={handleChangeStory} class={styles.storiesSelector}>
{Object.keys(stories).map((story) => (
<option key={story} value={story}>{story}</option>
))}
<option value='@new'>New Story...</option>
</select>
{currentStory !== DEFAULT_STORY
? <button class='icon' onClick={handleDeleteStory}>
🗑
</button>
: null}
</div>
<AutoTextarea <AutoTextarea
value={lore} value={lore}
onInput={setLore} onInput={setLore}
placeholder="Describe your world, for example: World of Awoo has big mountains and wide rivers." placeholder="Describe your world, for example: World of Awoo has big mountains and wide rivers."
class={styles.loreText}
/> />
</Modal> </Modal>
<Modal open={genparamsOpen.value} onClose={genparamsOpen.setFalse}> <Modal open={genparamsOpen.value} onClose={genparamsOpen.setFalse}>

View File

@ -26,7 +26,6 @@ interface IContext {
hasToolCalls: boolean; hasToolCalls: boolean;
promptTokens: number; promptTokens: number;
contextLength: number; contextLength: number;
spentKudos: number;
} }
const MESSAGES_TO_KEEP = 10; const MESSAGES_TO_KEEP = 10;
@ -50,7 +49,7 @@ const processing = {
export const LLMContextProvider = ({ children }: { children?: any }) => { export const LLMContextProvider = ({ children }: { children?: any }) => {
const { const {
connection, messages, triggerNext, continueLast, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled, connection, messages, triggerNext, continueLast, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
setTriggerNext, setContinueLast, addMessage, editMessage, editSummary, setTotalSpentKudos, setTriggerNext, setContinueLast, addMessage, editMessage, editSummary,
} = useContext(StateContext); } = useContext(StateContext);
const generating = useBool(false); const generating = useBool(false);
@ -58,7 +57,6 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
const [contextLength, setContextLength] = useState(0); const [contextLength, setContextLength] = useState(0);
const [modelName, setModelName] = useState(''); const [modelName, setModelName] = useState('');
const [hasToolCalls, setHasToolCalls] = useState(false); const [hasToolCalls, setHasToolCalls] = useState(false);
const [spentKudos, setSpentKudos] = useState(0);
const isOnline = useMemo(() => contextLength > 0, [contextLength]); const isOnline = useMemo(() => contextLength > 0, [contextLength]);
@ -172,15 +170,10 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
try { try {
console.log('[LLM.generate]', prompt); console.log('[LLM.generate]', prompt);
setSpentKudos(0); yield* Connection.generate(connection, prompt, {
for await (const { text, cost } of Connection.generate(connection, prompt, {
...extraSettings, ...extraSettings,
banned_tokens: bannedWords.filter(w => w.trim()), banned_tokens: bannedWords.filter(w => w.trim()),
})) { });
setSpentKudos(sk => sk + cost);
setTotalSpentKudos(sk => sk + cost);
yield text;
}
} catch (e) { } catch (e) {
if (e instanceof Error && e.name !== 'AbortError') { if (e instanceof Error && e.name !== 'AbortError') {
alert(e.message); alert(e.message);
@ -195,16 +188,9 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]); const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
console.log('[LLM.summarize]', prompt); console.log('[LLM.summarize]', prompt);
const tokens = await Array.fromAsync(Connection.generate(connection, prompt)); const tokens = await Array.fromAsync(Connection.generate(connection, prompt, {}));
const summary = tokens.reduce((sum, token) => ({
text: sum.text + token.text,
cost: sum.cost + token.cost,
}), { text: '', cost: 0 });
setSpentKudos(sk => sk + summary.cost); return MessageTools.trimSentence(tokens.join(''));
setTotalSpentKudos(sk => sk + summary.cost);
return MessageTools.trimSentence(summary.text);
} catch (e) { } catch (e) {
console.error('Error summarizing:', e); console.error('Error summarizing:', e);
return ''; return '';
@ -311,7 +297,6 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
hasToolCalls, hasToolCalls,
promptTokens, promptTokens,
contextLength, contextLength,
spentKudos,
}; };
const context = useMemo(() => rawContext, Object.values(rawContext)); const context = useMemo(() => rawContext, Object.values(rawContext));

View File

@ -1,28 +1,20 @@
import { createContext } from "preact"; import { createContext } from "preact";
import { useCallback, useEffect, useMemo, useState, type Dispatch, type StateUpdater } from "preact/hooks"; import { useCallback, useEffect, useMemo, useState } from "preact/hooks";
import { MessageTools, type IMessage } from "../tools/messages"; import { MessageTools, type IMessage } from "../tools/messages";
import { useInputState } from "@common/hooks/useInputState"; import { useInputState } from "@common/hooks/useInputState";
import { type IConnection } from "../tools/connection"; import { type IConnection } from "../tools/connection";
import { loadObject, saveObject } from "../tools/storage";
import { useInputCallback } from "@common/hooks/useInputCallback";
interface IStory {
lore: string;
messages: IMessage[];
}
interface IContext { interface IContext {
currentConnection: number; currentConnection: number;
availableConnections: IConnection[]; availableConnections: IConnection[];
input: string; input: string;
systemPrompt: string; systemPrompt: string;
lore: string;
userPrompt: string; userPrompt: string;
summarizePrompt: string; summarizePrompt: string;
summaryEnabled: boolean; summaryEnabled: boolean;
bannedWords: string[]; bannedWords: string[];
totalSpentKudos: number; messages: IMessage[];
stories: Record<string, IStory>;
currentStory: string;
// //
triggerNext: boolean; triggerNext: boolean;
continueLast: boolean; continueLast: boolean;
@ -30,8 +22,6 @@ interface IContext {
interface IComputableContext { interface IComputableContext {
connection: IConnection; connection: IConnection;
lore: string;
messages: IMessage[];
} }
interface IActions { interface IActions {
@ -46,7 +36,6 @@ interface IActions {
setSummarizePrompt: (prompt: string | Event) => void; setSummarizePrompt: (prompt: string | Event) => void;
setBannedWords: (words: string[]) => void; setBannedWords: (words: string[]) => void;
setSummaryEnabled: (summaryEnabled: boolean) => void; setSummaryEnabled: (summaryEnabled: boolean) => void;
setTotalSpentKudos: Dispatch<StateUpdater<number>>;
setTriggerNext: (triggerNext: boolean) => void; setTriggerNext: (triggerNext: boolean) => void;
setContinueLast: (continueLast: boolean) => void; setContinueLast: (continueLast: boolean) => void;
@ -60,14 +49,9 @@ interface IActions {
addSwipe: (index: number, content: string) => void; addSwipe: (index: number, content: string) => void;
continueMessage: (continueLast?: boolean) => void; continueMessage: (continueLast?: boolean) => void;
setCurrentStory: (id: string) => void;
createStory: (id: string) => void;
deleteStory: (id: string) => void;
} }
const SAVE_KEY = 'ai_game_save_state'; const SAVE_KEY = 'ai_game_save_state';
export const DEFAULT_STORY = 'default';
export enum Instruct { export enum Instruct {
CHATML = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n\\n' }}{% endif %}`, CHATML = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n\\n' }}{% endif %}`,
@ -86,14 +70,12 @@ export enum Instruct {
const DEFAULT_CONTEXT: IContext = { const DEFAULT_CONTEXT: IContext = {
currentConnection: 0, currentConnection: 0,
availableConnections: [{ availableConnections: [{
type: 'kobold',
url: 'http://localhost:5001', url: 'http://localhost:5001',
instruct: Instruct.MISTRAL, instruct: Instruct.CHATML,
}], }],
input: '', input: '',
systemPrompt: 'You are a creative writer. Write a story based on the world description below. Story should be adult and mature; and could include swearing, violence and unfairness. Portray characters realistically and stay in the lore.', systemPrompt: 'You are a creative writer. Write a story based on the world description below. Story should be adult and mature; and could include swearing, violence and unfairness. Portray characters realistically and stay in the lore.',
stories: {}, lore: '',
currentStory: DEFAULT_STORY,
userPrompt: `{% if isStart -%} userPrompt: `{% if isStart -%}
Write a novel using information above as a reference. Write a novel using information above as a reference.
{%- else -%} {%- else -%}
@ -108,45 +90,48 @@ Make sure to follow the world description and rules exactly. Avoid cliffhangers
summarizePrompt: 'Summarize following text in one paragraph:\n\n{{ message }}\n\nAnswer with shortened text only.', summarizePrompt: 'Summarize following text in one paragraph:\n\n{{ message }}\n\nAnswer with shortened text only.',
summaryEnabled: true, summaryEnabled: true,
bannedWords: [], bannedWords: [],
totalSpentKudos: 0, messages: [],
triggerNext: false, triggerNext: false,
continueLast: false, continueLast: false,
}; };
const EMPTY_STORY: IStory = { export const saveContext = (context: IContext) => {
lore: '', const contextToSave: Partial<IContext> = { ...context };
messages: [],
};
const saveContext = async (context: IContext & IComputableContext) => {
const contextToSave: Partial<IContext & IComputableContext> = { ...context };
delete contextToSave.connection;
delete contextToSave.triggerNext; delete contextToSave.triggerNext;
delete contextToSave.continueLast; delete contextToSave.continueLast;
delete contextToSave.lore;
delete contextToSave.messages;
return saveObject(SAVE_KEY, contextToSave); localStorage.setItem(SAVE_KEY, JSON.stringify(contextToSave));
}
export const loadContext = (): IContext => {
let loadedContext: Partial<IContext> = {};
try {
const json = localStorage.getItem(SAVE_KEY);
if (json) {
loadedContext = JSON.parse(json);
}
} catch { }
return { ...DEFAULT_CONTEXT, ...loadedContext };
} }
export type IStateContext = IContext & IActions & IComputableContext; export type IStateContext = IContext & IActions & IComputableContext;
export const StateContext = createContext<IStateContext>({} as IStateContext); export const StateContext = createContext<IStateContext>({} as IStateContext);
const loadedContext = await loadObject(SAVE_KEY, DEFAULT_CONTEXT);
export const StateContextProvider = ({ children }: { children?: any }) => { export const StateContextProvider = ({ children }: { children?: any }) => {
const loadedContext = useMemo(() => loadContext(), []);
const [currentConnection, setCurrentConnection] = useState<number>(loadedContext.currentConnection); const [currentConnection, setCurrentConnection] = useState<number>(loadedContext.currentConnection);
const [availableConnections, setAvailableConnections] = useState<IConnection[]>(loadedContext.availableConnections); const [availableConnections, setAvailableConnections] = useState<IConnection[]>(loadedContext.availableConnections);
const [input, setInput] = useInputState(loadedContext.input); const [input, setInput] = useInputState(loadedContext.input);
const [stories, setStories] = useState(loadedContext.stories); const [lore, setLore] = useInputState(loadedContext.lore);
const [currentStory, setCurrentStory] = useState(loadedContext.currentStory);
const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt); const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt);
const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt); const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt);
const [summarizePrompt, setSummarizePrompt] = useInputState(loadedContext.summarizePrompt); const [summarizePrompt, setSummarizePrompt] = useInputState(loadedContext.summarizePrompt);
const [bannedWords, setBannedWords] = useState<string[]>(loadedContext.bannedWords); const [bannedWords, setBannedWords] = useState<string[]>(loadedContext.bannedWords);
const [messages, setMessages] = useState(loadedContext.messages);
const [summaryEnabled, setSummaryEnabled] = useState(loadedContext.summaryEnabled); const [summaryEnabled, setSummaryEnabled] = useState(loadedContext.summaryEnabled);
const [totalSpentKudos, setTotalSpentKudos] = useState(loadedContext.totalSpentKudos);
const connection = availableConnections[currentConnection] ?? DEFAULT_CONTEXT.availableConnections[0]; const connection = availableConnections[currentConnection] ?? DEFAULT_CONTEXT.availableConnections[0];
@ -166,35 +151,6 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
useEffect(() => setConnection({ ...connection, instruct }), [instruct]); useEffect(() => setConnection({ ...connection, instruct }), [instruct]);
const setLore = useInputCallback((lore) => {
if (!currentStory) return;
setStories(ss => ({
...ss,
[currentStory]: {
...EMPTY_STORY,
...stories[currentStory],
lore,
}
}));
}, [currentStory]);
const setMessages = useCallback((msg: StateUpdater<IMessage[]>) => {
if (!currentStory) return;
let messages = (typeof msg === 'function')
? msg(stories[currentStory]?.messages ?? EMPTY_STORY.messages)
: msg;
setStories(ss => ({
...ss,
[currentStory]: {
...EMPTY_STORY,
...stories[currentStory],
messages,
}
}));
}, [currentStory]);
const actions: IActions = useMemo(() => ({ const actions: IActions = useMemo(() => ({
setConnection, setConnection,
setCurrentConnection, setCurrentConnection,
@ -208,8 +164,6 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
setTriggerNext, setTriggerNext,
setContinueLast, setContinueLast,
setTotalSpentKudos,
setCurrentStory,
setBannedWords: (words) => setBannedWords(words.slice()), setBannedWords: (words) => setBannedWords(words.slice()),
setAvailableConnections: (connections) => setAvailableConnections(connections.slice()), setAvailableConnections: (connections) => setAvailableConnections(connections.slice()),
@ -284,19 +238,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
setTriggerNext(true); setTriggerNext(true);
setContinueLast(c); setContinueLast(c);
}, },
createStory: (id: string) => { }), []);
setStories(ss => ({
...ss,
[id]: { ...EMPTY_STORY }
}))
},
deleteStory: (id: string) => {
if (id === DEFAULT_STORY) return;
setStories(ss => Object.fromEntries(Object.entries(ss).filter(([k]) => k !== id)));
setCurrentStory(cs => cs === id ? DEFAULT_STORY : cs);
}
}), [setLore, setMessages]);
const rawContext: IContext & IComputableContext = { const rawContext: IContext & IComputableContext = {
connection, connection,
@ -304,18 +246,15 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
availableConnections, availableConnections,
input, input,
systemPrompt, systemPrompt,
lore,
userPrompt, userPrompt,
summarizePrompt, summarizePrompt,
summaryEnabled, summaryEnabled,
bannedWords, bannedWords,
totalSpentKudos, messages,
stories,
currentStory,
// //
triggerNext, triggerNext,
continueLast, continueLast,
lore: stories[currentStory]?.lore ?? '',
messages: stories[currentStory]?.messages ?? [],
}; };
const context = useMemo(() => rawContext, Object.values(rawContext)); const context = useMemo(() => rawContext, Object.values(rawContext));

View File

@ -6,24 +6,26 @@ import { Huggingface } from "./huggingface";
import { approximateTokens, normalizeModel } from "./model"; import { approximateTokens, normalizeModel } from "./model";
interface IBaseConnection { interface IBaseConnection {
type: 'kobold' | 'horde';
instruct: string; instruct: string;
url?: string;
apiKey?: string;
model?: string;
} }
interface IKoboldConnection extends IBaseConnection { interface IKoboldConnection extends IBaseConnection {
type: 'kobold';
url: string; url: string;
} }
interface IHordeConnection extends IBaseConnection { interface IHordeConnection extends IBaseConnection {
type: 'horde';
apiKey?: string; apiKey?: string;
model: string; model: string;
} }
export const isKoboldConnection = (obj: unknown): obj is IKoboldConnection => (
obj != null && typeof obj === 'object' && 'url' in obj && typeof obj.url === 'string'
);
export const isHordeConnection = (obj: unknown): obj is IHordeConnection => (
obj != null && typeof obj === 'object' && 'model' in obj && typeof obj.model === 'string'
);
export type IConnection = IKoboldConnection | IHordeConnection; export type IConnection = IKoboldConnection | IHordeConnection;
interface IHordeWorker { interface IHordeWorker {
@ -49,7 +51,6 @@ interface IHordeResult {
faulted: boolean; faulted: boolean;
done: boolean; done: boolean;
finished: number; finished: number;
kudos: number;
generations?: { generations?: {
text: string; text: string;
}[]; }[];
@ -87,12 +88,7 @@ export namespace Connection {
let abortController = new AbortController(); let abortController = new AbortController();
export interface TextChunk { async function* generateKobold(url: string, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator<string> {
text: string;
cost: number;
}
async function* generateKobold(url: string, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator<TextChunk> {
const sse = new SSE(`${url}/api/extra/generate/stream`, { const sse = new SSE(`${url}/api/extra/generate/stream`, {
payload: JSON.stringify({ payload: JSON.stringify({
...DEFAULT_GENERATION_SETTINGS, ...DEFAULT_GENERATION_SETTINGS,
@ -134,10 +130,10 @@ export namespace Connection {
while (!end || messages.length) { while (!end || messages.length) {
while (messages.length > 0) { while (messages.length > 0) {
const text = messages.shift(); const message = messages.shift();
if (text != null) { if (message != null) {
try { try {
yield { text, cost: 0 }; yield message;
} catch { } } catch { }
} }
} }
@ -149,7 +145,7 @@ export namespace Connection {
sse.close(); sse.close();
} }
async function* generateHorde(connection: IHordeConnection, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator<TextChunk> { async function* generateHorde(connection: IHordeConnection, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator<string> {
if (!connection.model) { if (!connection.model) {
throw new Error('Horde not connected'); throw new Error('Horde not connected');
} }
@ -194,14 +190,14 @@ export namespace Connection {
} }
const { id } = await generateResponse.json() as { id: string }; const { id } = await generateResponse.json() as { id: string };
const request = async (method = 'GET'): Promise<TextChunk | null> => { const request = async (method = 'GET'): Promise<string | null> => {
const response = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method }); const response = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method });
if (response.ok && response.status < 400) { if (response.ok && response.status < 400) {
const result: IHordeResult = await response.json(); const result: IHordeResult = await response.json();
if (result.generations?.length === 1) { if (result.generations?.length === 1) {
const { text } = result.generations[0]; const { text } = result.generations[0];
return { text, cost: result.kudos }; return text;
} }
} else { } else {
throw new Error(await response.text()); throw new Error(await response.text());
@ -210,17 +206,16 @@ export namespace Connection {
return null; return null;
}; };
const deleteRequest = async () => (await request('DELETE')) ?? { text: '', cost: 0 }; const deleteRequest = async () => (await request('DELETE')) ?? '';
let text: string | null = null; let text: string | null = null;
while (!text) { while (!text) {
try { try {
await delay(2500, { signal }); await delay(2500, { signal });
const response = await request(); text = await request();
if (response?.text) { if (text) {
text = response.text;
for (const sequence of requestData.params.stop_sequence) { for (const sequence of requestData.params.stop_sequence) {
const stopIdx = text.indexOf(sequence); const stopIdx = text.indexOf(sequence);
if (stopIdx >= 0) { if (stopIdx >= 0) {
@ -238,7 +233,7 @@ export namespace Connection {
} }
} }
yield { text: unsloppedText, cost: response.cost }; yield unsloppedText;
requestData.prompt += unsloppedText; requestData.prompt += unsloppedText;
@ -262,9 +257,9 @@ export namespace Connection {
} }
export async function* generate(connection: IConnection, prompt: string, extraSettings: IGenerationSettings = {}) { export async function* generate(connection: IConnection, prompt: string, extraSettings: IGenerationSettings = {}) {
if (connection.type === 'kobold') { if (isKoboldConnection(connection)) {
yield* generateKobold(connection.url, prompt, extraSettings); yield* generateKobold(connection.url, prompt, extraSettings);
} else if (connection.type === 'horde') { } else if (isHordeConnection(connection)) {
yield* generateHorde(connection, prompt, extraSettings); yield* generateHorde(connection, prompt, extraSettings);
} }
} }
@ -329,7 +324,7 @@ export namespace Connection {
export const getHordeModels = throttle(requestHordeModels, 10000); export const getHordeModels = throttle(requestHordeModels, 10000);
export async function getModelName(connection: IConnection): Promise<string> { export async function getModelName(connection: IConnection): Promise<string> {
if (connection.type === 'kobold') { if (isKoboldConnection(connection)) {
try { try {
const response = await fetch(`${connection.url}/api/v1/model`); const response = await fetch(`${connection.url}/api/v1/model`);
if (response.ok) { if (response.ok) {
@ -339,7 +334,7 @@ export namespace Connection {
} catch (e) { } catch (e) {
console.error('Error getting max tokens', e); console.error('Error getting max tokens', e);
} }
} else if (connection.type === 'horde') { } else if (isHordeConnection(connection)) {
return connection.model; return connection.model;
} }
@ -347,7 +342,7 @@ export namespace Connection {
} }
export async function getContextLength(connection: IConnection): Promise<number> { export async function getContextLength(connection: IConnection): Promise<number> {
if (connection.type === 'kobold') { if (isKoboldConnection(connection)) {
try { try {
const response = await fetch(`${connection.url}/api/extra/true_max_context_length`); const response = await fetch(`${connection.url}/api/extra/true_max_context_length`);
if (response.ok) { if (response.ok) {
@ -357,7 +352,7 @@ export namespace Connection {
} catch (e) { } catch (e) {
console.error('Error getting max tokens', e); console.error('Error getting max tokens', e);
} }
} else if (connection.type === 'horde' && connection.model) { } else if (isHordeConnection(connection) && connection.model) {
const models = await getHordeModels(); const models = await getHordeModels();
const model = models.get(connection.model); const model = models.get(connection.model);
if (model) { if (model) {
@ -369,7 +364,7 @@ export namespace Connection {
} }
export async function countTokens(connection: IConnection, prompt: string) { export async function countTokens(connection: IConnection, prompt: string) {
if (connection.type === 'kobold') { if (isKoboldConnection(connection)) {
try { try {
const response = await fetch(`${connection.url}/api/extra/tokencount`, { const response = await fetch(`${connection.url}/api/extra/tokencount`, {
body: JSON.stringify({ prompt }), body: JSON.stringify({ prompt }),
@ -383,7 +378,7 @@ export namespace Connection {
} catch (e) { } catch (e) {
console.error('Error counting tokens:', e); console.error('Error counting tokens:', e);
} }
} else if (connection.type === 'horde') { } else {
const model = await getModelName(connection); const model = await getModelName(connection);
const tokenizer = await Huggingface.findTokenizer(model); const tokenizer = await Huggingface.findTokenizer(model);
if (tokenizer) { if (tokenizer) {

View File

@ -3,7 +3,6 @@ import * as hub from '@huggingface/hub';
import { Template } from '@huggingface/jinja'; import { Template } from '@huggingface/jinja';
import { AutoTokenizer, PreTrainedTokenizer } from '@huggingface/transformers'; import { AutoTokenizer, PreTrainedTokenizer } from '@huggingface/transformers';
import { normalizeModel } from './model'; import { normalizeModel } from './model';
import { loadObject, saveObject } from './storage';
export namespace Huggingface { export namespace Huggingface {
export interface ITemplateMessage { export interface ITemplateMessage {
@ -61,9 +60,27 @@ export namespace Huggingface {
const TEMPLATE_CACHE_KEY = 'ai_game_template_cache'; const TEMPLATE_CACHE_KEY = 'ai_game_template_cache';
const templateCache: Record<string, string> = {}; const loadCache = (): Record<string, string> => {
loadObject(TEMPLATE_CACHE_KEY, {}).then(c => Object.assign(templateCache, c)); const json = localStorage.getItem(TEMPLATE_CACHE_KEY);
try {
if (json) {
const cache = JSON.parse(json);
if (cache && typeof cache === 'object') {
return cache
}
}
} catch { }
return {};
};
const saveCache = (cache: Record<string, string>) => {
const json = JSON.stringify(cache);
localStorage.setItem(TEMPLATE_CACHE_KEY, json);
};
const templateCache: Record<string, string> = loadCache();
const compiledTemplates = new Map<string, Template>(); const compiledTemplates = new Map<string, Template>();
const tokenizerCache = new Map<string, PreTrainedTokenizer | null>(); const tokenizerCache = new Map<string, PreTrainedTokenizer | null>();
@ -244,7 +261,7 @@ export namespace Huggingface {
} }
templateCache[modelName] = template; templateCache[modelName] = template;
saveObject(TEMPLATE_CACHE_KEY, templateCache); saveCache(templateCache);
return template; return template;
} }

View File

@ -1,44 +0,0 @@
const API_KEY = 'awoorwa32';
export const loadObject = async <T>(key: string, defaultObject: T): Promise<T> => {
let localObject: Partial<T> = {};
try {
const json = localStorage.getItem(key);
if (json) {
localObject = JSON.parse(json);
}
} catch { }
let remoteObject: Partial<T> = {};
try {
const response = await fetch(`https://demo.pabloader.ru/storage/${key}`);
if (response.ok) {
remoteObject = await response.json();
}
} catch { }
return { ...defaultObject, ...localObject, ...remoteObject };
}
export const saveObject = async <T>(key: string, obj: T) => {
const saveData = JSON.stringify(obj);
localStorage.setItem(key, saveData);
try {
const url = new URL('https://demo.pabloader.ru/storage/index.php');
url.searchParams.set('filename', key);
const response = await fetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': `Bearer ${API_KEY}`,
},
body: saveData,
});
if (!response.ok) {
throw new Error('Failed to save context');
}
} catch {
}
}