Compare commits
2 Commits
9c4cc61573
...
c480f5a7d1
| Author | SHA1 | Date |
|---|---|---|
|
|
c480f5a7d1 | |
|
|
a213e0407c |
|
|
@ -11,6 +11,7 @@
|
||||||
"@huggingface/gguf": "0.1.12",
|
"@huggingface/gguf": "0.1.12",
|
||||||
"@huggingface/hub": "0.19.0",
|
"@huggingface/hub": "0.19.0",
|
||||||
"@huggingface/jinja": "0.3.1",
|
"@huggingface/jinja": "0.3.1",
|
||||||
|
"@huggingface/transformers": "3.0.2",
|
||||||
"@inquirer/select": "2.3.10",
|
"@inquirer/select": "2.3.10",
|
||||||
"ace-builds": "1.36.3",
|
"ace-builds": "1.36.3",
|
||||||
"classnames": "2.5.1",
|
"classnames": "2.5.1",
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@
|
||||||
--green: #AFAFAF;
|
--green: #AFAFAF;
|
||||||
--red: #7F0000;
|
--red: #7F0000;
|
||||||
--green: #007F00;
|
--green: #007F00;
|
||||||
|
--brightRed: #DD0000;
|
||||||
|
--brightGreen: #00DD00;
|
||||||
--shadeColor: rgba(0, 128, 128, 0.3);
|
--shadeColor: rgba(0, 128, 128, 0.3);
|
||||||
|
|
||||||
--border: 1px solid var(--color);
|
--border: 1px solid var(--color);
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import { useEffect, useRef } from "preact/hooks";
|
||||||
import type { JSX } from "preact/jsx-runtime"
|
import type { JSX } from "preact/jsx-runtime"
|
||||||
|
|
||||||
import { useIsVisible } from '@common/hooks/useIsVisible';
|
import { useIsVisible } from '@common/hooks/useIsVisible';
|
||||||
import { DOMTools } from "../dom";
|
import { DOMTools } from "../tools/dom";
|
||||||
|
|
||||||
export const AutoTextarea = (props: JSX.HTMLAttributes<HTMLTextAreaElement>) => {
|
export const AutoTextarea = (props: JSX.HTMLAttributes<HTMLTextAreaElement>) => {
|
||||||
const { value } = props;
|
const { value } = props;
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import { useCallback, useContext, useEffect, useRef } from "preact/hooks";
|
import { useCallback, useContext, useEffect, useRef } from "preact/hooks";
|
||||||
import { StateContext } from "../contexts/state";
|
import { StateContext } from "../contexts/state";
|
||||||
import { Message } from "./message/message";
|
import { Message } from "./message/message";
|
||||||
import { MessageTools } from "../messages";
|
import { MessageTools } from "../tools/messages";
|
||||||
import { DOMTools } from "../dom";
|
import { DOMTools } from "../tools/dom";
|
||||||
|
|
||||||
export const Chat = () => {
|
export const Chat = () => {
|
||||||
const { messages } = useContext(StateContext);
|
const { messages } = useContext(StateContext);
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,10 @@
|
||||||
import { useCallback, useEffect, useMemo, useState } from 'preact/hooks';
|
import { useCallback, useEffect, useMemo, useState } from 'preact/hooks';
|
||||||
|
|
||||||
import styles from './header.module.css';
|
import styles from './header.module.css';
|
||||||
import { Connection, HORDE_ANON_KEY, isHordeConnection, isKoboldConnection, type IConnection, type IHordeModel } from '../../connection';
|
import { Connection, HORDE_ANON_KEY, isHordeConnection, isKoboldConnection, type IConnection, type IHordeModel } from '../../tools/connection';
|
||||||
import { Instruct } from '../../contexts/state';
|
import { Instruct } from '../../contexts/state';
|
||||||
import { useInputState } from '@common/hooks/useInputState';
|
import { useInputState } from '@common/hooks/useInputState';
|
||||||
import { useInputCallback } from '@common/hooks/useInputCallback';
|
import { useInputCallback } from '@common/hooks/useInputCallback';
|
||||||
import { Huggingface } from '../../huggingface';
|
import { Huggingface } from '../../tools/huggingface';
|
||||||
|
|
||||||
interface IProps {
|
interface IProps {
|
||||||
connection: IConnection;
|
connection: IConnection;
|
||||||
|
|
@ -13,10 +12,13 @@ interface IProps {
|
||||||
}
|
}
|
||||||
|
|
||||||
export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
|
export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
|
||||||
|
// kobold
|
||||||
const [connectionUrl, setConnectionUrl] = useInputState('');
|
const [connectionUrl, setConnectionUrl] = useInputState('');
|
||||||
|
// horde
|
||||||
const [apiKey, setApiKey] = useInputState(HORDE_ANON_KEY);
|
const [apiKey, setApiKey] = useInputState(HORDE_ANON_KEY);
|
||||||
const [modelName, setModelName] = useInputState('');
|
const [modelName, setModelName] = useInputState('');
|
||||||
|
|
||||||
|
const [instruct, setInstruct] = useInputState('');
|
||||||
const [modelTemplate, setModelTemplate] = useInputState('');
|
const [modelTemplate, setModelTemplate] = useInputState('');
|
||||||
const [hordeModels, setHordeModels] = useState<IHordeModel[]>([]);
|
const [hordeModels, setHordeModels] = useState<IHordeModel[]>([]);
|
||||||
const [contextLength, setContextLength] = useState<number>(0);
|
const [contextLength, setContextLength] = useState<number>(0);
|
||||||
|
|
@ -27,11 +29,14 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
|
||||||
return 'unknown';
|
return 'unknown';
|
||||||
}, [connection]);
|
}, [connection]);
|
||||||
|
|
||||||
const urlValid = useMemo(() => contextLength > 0, [contextLength]);
|
const isOnline = useMemo(() => contextLength > 0, [contextLength]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
setInstruct(connection.instruct);
|
||||||
|
|
||||||
if (isKoboldConnection(connection)) {
|
if (isKoboldConnection(connection)) {
|
||||||
setConnectionUrl(connection.url);
|
setConnectionUrl(connection.url);
|
||||||
|
Connection.getContextLength(connection).then(setContextLength);
|
||||||
} else if (isHordeConnection(connection)) {
|
} else if (isHordeConnection(connection)) {
|
||||||
setModelName(connection.model);
|
setModelName(connection.model);
|
||||||
setApiKey(connection.apiKey || HORDE_ANON_KEY);
|
setApiKey(connection.apiKey || HORDE_ANON_KEY);
|
||||||
|
|
@ -39,9 +44,6 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
|
||||||
Connection.getHordeModels()
|
Connection.getHordeModels()
|
||||||
.then(m => setHordeModels(Array.from(m.values()).sort((a, b) => a.name.localeCompare(b.name))));
|
.then(m => setHordeModels(Array.from(m.values()).sort((a, b) => a.name.localeCompare(b.name))));
|
||||||
}
|
}
|
||||||
|
|
||||||
Connection.getContextLength(connection).then(setContextLength);
|
|
||||||
Connection.getModelName(connection).then(setModelName);
|
|
||||||
}, [connection]);
|
}, [connection]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
|
@ -50,47 +52,44 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
|
||||||
.then(template => {
|
.then(template => {
|
||||||
if (template) {
|
if (template) {
|
||||||
setModelTemplate(template);
|
setModelTemplate(template);
|
||||||
|
setInstruct(template);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}, [modelName]);
|
}, [modelName]);
|
||||||
|
|
||||||
const setInstruct = useInputCallback((instruct) => {
|
|
||||||
setConnection({ ...connection, instruct });
|
|
||||||
}, [connection, setConnection]);
|
|
||||||
|
|
||||||
const setBackendType = useInputCallback((type) => {
|
const setBackendType = useInputCallback((type) => {
|
||||||
if (type === 'kobold') {
|
if (type === 'kobold') {
|
||||||
setConnection({
|
setConnection({
|
||||||
instruct: connection.instruct,
|
instruct,
|
||||||
url: connectionUrl,
|
url: connectionUrl,
|
||||||
});
|
});
|
||||||
} else if (type === 'horde') {
|
} else if (type === 'horde') {
|
||||||
setConnection({
|
setConnection({
|
||||||
instruct: connection.instruct,
|
instruct,
|
||||||
apiKey,
|
apiKey,
|
||||||
model: modelName,
|
model: modelName,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}, [connection, setConnection, connectionUrl, apiKey, modelName]);
|
}, [setConnection, connectionUrl, apiKey, modelName, instruct]);
|
||||||
|
|
||||||
const handleBlurUrl = useCallback(() => {
|
const handleBlurUrl = useCallback(() => {
|
||||||
const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i;
|
const regex = /^(?:http(s?):\/\/)?(.*?)\/?$/i;
|
||||||
const url = connectionUrl.replace(regex, 'http$1://$2');
|
const url = connectionUrl.replace(regex, 'http$1://$2');
|
||||||
|
|
||||||
setConnection({
|
setConnection({
|
||||||
instruct: connection.instruct,
|
instruct,
|
||||||
url,
|
url,
|
||||||
});
|
});
|
||||||
}, [connection, connectionUrl, setConnection]);
|
}, [connectionUrl, instruct, setConnection]);
|
||||||
|
|
||||||
const handleBlurHorde = useCallback(() => {
|
const handleBlurHorde = useCallback(() => {
|
||||||
setConnection({
|
setConnection({
|
||||||
instruct: connection.instruct,
|
instruct,
|
||||||
apiKey,
|
apiKey,
|
||||||
model: modelName,
|
model: modelName,
|
||||||
});
|
});
|
||||||
}, [connection, apiKey, modelName, setConnection]);
|
}, [apiKey, modelName, instruct, setConnection]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div class={styles.connectionEditor}>
|
<div class={styles.connectionEditor}>
|
||||||
|
|
@ -98,7 +97,7 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
|
||||||
<option value='kobold'>Kobold CPP</option>
|
<option value='kobold'>Kobold CPP</option>
|
||||||
<option value='horde'>Horde</option>
|
<option value='horde'>Horde</option>
|
||||||
</select>
|
</select>
|
||||||
<select value={connection.instruct} onChange={setInstruct} title='Instruct template'>
|
<select value={instruct} onChange={setInstruct} title='Instruct template'>
|
||||||
{modelName && modelTemplate && <optgroup label='Native model template'>
|
{modelName && modelTemplate && <optgroup label='Native model template'>
|
||||||
<option value={modelTemplate} title='Native for model'>{modelName}</option>
|
<option value={modelTemplate} title='Native for model'>{modelName}</option>
|
||||||
</optgroup>}
|
</optgroup>}
|
||||||
|
|
@ -109,15 +108,15 @@ export const ConnectionEditor = ({ connection, setConnection }: IProps) => {
|
||||||
</option>
|
</option>
|
||||||
))}
|
))}
|
||||||
</optgroup>
|
</optgroup>
|
||||||
<optgroup label='Custom'>
|
{instruct !== modelTemplate && <optgroup label='Custom'>
|
||||||
<option value={connection.instruct}>Custom</option>
|
<option value={connection.instruct}>Custom</option>
|
||||||
</optgroup>
|
</optgroup>}
|
||||||
</select>
|
</select>
|
||||||
{isKoboldConnection(connection) && <input
|
{isKoboldConnection(connection) && <input
|
||||||
value={connectionUrl}
|
value={connectionUrl}
|
||||||
onInput={setConnectionUrl}
|
onInput={setConnectionUrl}
|
||||||
onBlur={handleBlurUrl}
|
onBlur={handleBlurUrl}
|
||||||
class={urlValid ? styles.valid : styles.invalid}
|
class={isOnline ? styles.valid : styles.invalid}
|
||||||
/>}
|
/>}
|
||||||
{isHordeConnection(connection) && <>
|
{isHordeConnection(connection) && <>
|
||||||
<input
|
<input
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,14 @@
|
||||||
flex-direction: row;
|
flex-direction: row;
|
||||||
gap: 8px;
|
gap: 8px;
|
||||||
padding: 0 8px;
|
padding: 0 8px;
|
||||||
|
|
||||||
|
.online {
|
||||||
|
color: var(--brightGreen);
|
||||||
|
}
|
||||||
|
|
||||||
|
.offline {
|
||||||
|
color: var(--brightRed);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ export const Header = () => {
|
||||||
const promptsOpen = useBool();
|
const promptsOpen = useBool();
|
||||||
const genparamsOpen = useBool();
|
const genparamsOpen = useBool();
|
||||||
const assistantOpen = useBool();
|
const assistantOpen = useBool();
|
||||||
|
const isOnline = useMemo(() => contextLength > 0, [contextLength]);
|
||||||
|
|
||||||
const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]);
|
const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]);
|
||||||
|
|
||||||
|
|
@ -56,7 +57,7 @@ export const Header = () => {
|
||||||
<div class={styles.header}>
|
<div class={styles.header}>
|
||||||
<div class={styles.inputs}>
|
<div class={styles.inputs}>
|
||||||
<div class={styles.buttons}>
|
<div class={styles.buttons}>
|
||||||
<button class='icon' onClick={connectionsOpen.setTrue} title='Connection settings'>
|
<button class={`icon ${isOnline ? styles.online: styles.offline}`} onClick={connectionsOpen.setTrue} title='Connection settings'>
|
||||||
🔌
|
🔌
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import { useMemo } from "preact/hooks";
|
import { useMemo } from "preact/hooks";
|
||||||
import { MessageTools } from "../../messages";
|
import { MessageTools } from "../../tools/messages";
|
||||||
|
|
||||||
import styles from './message.module.css';
|
import styles from './message.module.css';
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import { useCallback, useContext, useEffect, useMemo, useRef, useState } from "preact/hooks";
|
import { useCallback, useContext, useEffect, useMemo, useRef, useState } from "preact/hooks";
|
||||||
import { MessageTools, type IMessage } from "../../messages";
|
import { MessageTools, type IMessage } from "../../tools/messages";
|
||||||
import { StateContext } from "../../contexts/state";
|
import { StateContext } from "../../contexts/state";
|
||||||
import { DOMTools } from "../../dom";
|
import { DOMTools } from "../../tools/dom";
|
||||||
|
|
||||||
import styles from './message.module.css';
|
import styles from './message.module.css';
|
||||||
import { AutoTextarea } from "../autoTextarea";
|
import { AutoTextarea } from "../autoTextarea";
|
||||||
|
|
@ -16,7 +16,7 @@ interface IProps {
|
||||||
}
|
}
|
||||||
|
|
||||||
export const Message = ({ message, index, isLastUser, isLastAssistant, onNeedScroll }: IProps) => {
|
export const Message = ({ message, index, isLastUser, isLastAssistant, onNeedScroll }: IProps) => {
|
||||||
const { messages, editMessage, editSummary, deleteMessage, setCurrentSwipe, setMessages } = useContext(StateContext);
|
const { messages, editMessage, editSummary, deleteMessage, setCurrentSwipe, setMessages, continueMessage } = useContext(StateContext);
|
||||||
const [editing, setEditing] = useState(false);
|
const [editing, setEditing] = useState(false);
|
||||||
const [editedMessage, setEditedMessage] = useInputState('');
|
const [editedMessage, setEditedMessage] = useInputState('');
|
||||||
const textRef = useRef<HTMLDivElement>(null);
|
const textRef = useRef<HTMLDivElement>(null);
|
||||||
|
|
@ -70,6 +70,10 @@ export const Message = ({ message, index, isLastUser, isLastAssistant, onNeedScr
|
||||||
DOMTools.animate(textRef.current, 'swipe-from-right');
|
DOMTools.animate(textRef.current, 'swipe-from-right');
|
||||||
}, [setCurrentSwipe, index, message]);
|
}, [setCurrentSwipe, index, message]);
|
||||||
|
|
||||||
|
const handleContinueMessage = useCallback(() => {
|
||||||
|
continueMessage(true);
|
||||||
|
}, [continueMessage]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div class={`${styles.message} ${styles[message.role]} ${isLastUser ? styles.lastUser : ''}`}>
|
<div class={`${styles.message} ${styles[message.role]} ${isLastUser ? styles.lastUser : ''}`}>
|
||||||
<div class={styles.content}>
|
<div class={styles.content}>
|
||||||
|
|
@ -89,13 +93,14 @@ export const Message = ({ message, index, isLastUser, isLastAssistant, onNeedScr
|
||||||
<button class='icon' onClick={handleCancelEdit} title='Cancel'>❌</button>
|
<button class='icon' onClick={handleCancelEdit} title='Cancel'>❌</button>
|
||||||
</>
|
</>
|
||||||
: <>
|
: <>
|
||||||
{isLastAssistant &&
|
{isLastAssistant && <>
|
||||||
<div class={styles.swipes}>
|
<div class={styles.swipes}>
|
||||||
<div onClick={handleSwipeLeft}>◀</div>
|
<div onClick={handleSwipeLeft}>◀</div>
|
||||||
<div>{message.currentSwipe + 1}/{message.swipes.length}</div>
|
<div>{message.currentSwipe + 1}/{message.swipes.length}</div>
|
||||||
<div onClick={handleSwipeRight}>▶</div>
|
<div onClick={handleSwipeRight}>▶</div>
|
||||||
</div>
|
</div>
|
||||||
}
|
<button class='icon' onClick={handleContinueMessage} title="Continue">▶</button>
|
||||||
|
</>}
|
||||||
<button class='icon' onClick={handleEnableEdit} title="Edit">🖊</button>
|
<button class='icon' onClick={handleEnableEdit} title="Edit">🖊</button>
|
||||||
</>
|
</>
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
import { MessageTools, type IMessage } from "../../messages"
|
import { MessageTools, type IMessage } from "../../tools/messages"
|
||||||
import { useCallback, useContext, useEffect, useMemo, useRef, useState } from "preact/hooks";
|
import { useCallback, useContext, useEffect, useMemo, useRef, useState } from "preact/hooks";
|
||||||
import { Modal } from "@common/components/modal/modal";
|
import { Modal } from "@common/components/modal/modal";
|
||||||
import { DOMTools } from "../../dom";
|
import { DOMTools } from "../../tools/dom";
|
||||||
|
|
||||||
import styles from './minichat.module.css';
|
import styles from './minichat.module.css';
|
||||||
import { LLMContext } from "../../contexts/llm";
|
import { LLMContext } from "../../contexts/llm";
|
||||||
import { FormattedMessage } from "../message/formattedMessage";
|
import { FormattedMessage } from "../message/formattedMessage";
|
||||||
import { AutoTextarea } from "../autoTextarea";
|
import { AutoTextarea } from "../autoTextarea";
|
||||||
|
import { useBool } from "@common/hooks/useBool";
|
||||||
|
|
||||||
interface IProps {
|
interface IProps {
|
||||||
open: boolean;
|
open: boolean;
|
||||||
|
|
@ -16,9 +17,10 @@ interface IProps {
|
||||||
}
|
}
|
||||||
|
|
||||||
export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) => {
|
export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) => {
|
||||||
const { generating, stopGeneration, generate, compilePrompt } = useContext(LLMContext);
|
const { stopGeneration, generate, compilePrompt } = useContext(LLMContext);
|
||||||
const [messages, setMessages] = useState<IMessage[]>([]);
|
const [messages, setMessages] = useState<IMessage[]>([]);
|
||||||
const ref = useRef<HTMLDivElement>(null);
|
const ref = useRef<HTMLDivElement>(null);
|
||||||
|
const generating = useBool();
|
||||||
|
|
||||||
const answer = useMemo(() =>
|
const answer = useMemo(() =>
|
||||||
MessageTools.getSwipe(messages.filter(m => m.role === 'assistant').at(-1))?.content,
|
MessageTools.getSwipe(messages.filter(m => m.role === 'assistant').at(-1))?.content,
|
||||||
|
|
@ -33,7 +35,7 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setTimeout(() => DOMTools.scrollDown(ref.current, false), 100);
|
setTimeout(() => DOMTools.scrollDown(ref.current, false), 100);
|
||||||
}, [generating, open]);
|
}, [generating.value, open]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
DOMTools.scrollDown(ref.current, false);
|
DOMTools.scrollDown(ref.current, false);
|
||||||
|
|
@ -47,19 +49,21 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
|
||||||
}, [messages.length, handleInit]);
|
}, [messages.length, handleInit]);
|
||||||
|
|
||||||
const handleGenerate = useCallback(async () => {
|
const handleGenerate = useCallback(async () => {
|
||||||
if (messages.length > 0 && !generating) {
|
if (messages.length > 0 && !generating.value) {
|
||||||
const promptMessages: IMessage[] = [...history, ...messages];
|
const promptMessages: IMessage[] = [...history, ...messages];
|
||||||
const { prompt } = await compilePrompt(promptMessages, { keepUsers: messages.length + 1 });
|
const { prompt } = await compilePrompt(promptMessages, { keepUsers: messages.length + 1, continueLast: true });
|
||||||
|
|
||||||
let text = '';
|
let text = '';
|
||||||
const messageId = messages.length;
|
const messageId = messages.length;
|
||||||
const newMessages = [...messages, MessageTools.create('', 'assistant', true)];
|
const newMessages = [...messages, MessageTools.create('', 'assistant', true)];
|
||||||
setMessages(newMessages);
|
setMessages(newMessages);
|
||||||
|
|
||||||
|
generating.setTrue();
|
||||||
for await (const chunk of generate(prompt)) {
|
for await (const chunk of generate(prompt)) {
|
||||||
text += chunk;
|
text += chunk;
|
||||||
setMessages(MessageTools.updateSwipe(newMessages, messageId, { content: text.trim() }));
|
setMessages(MessageTools.updateSwipe(newMessages, messageId, { content: text.trim() }));
|
||||||
}
|
}
|
||||||
|
generating.setFalse();
|
||||||
|
|
||||||
setMessages([
|
setMessages([
|
||||||
...MessageTools.updateSwipe(newMessages, messageId, { content: MessageTools.trimSentence(text) }),
|
...MessageTools.updateSwipe(newMessages, messageId, { content: MessageTools.trimSentence(text) }),
|
||||||
|
|
@ -90,7 +94,7 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
|
||||||
<div class={styles.minichat} ref={ref}>
|
<div class={styles.minichat} ref={ref}>
|
||||||
<div class={styles.messages}>
|
<div class={styles.messages}>
|
||||||
{messages.map((m, i) => (
|
{messages.map((m, i) => (
|
||||||
generating
|
generating.value
|
||||||
? <FormattedMessage key={i} class={`${styles[m.role]} ${styles.message}`}>
|
? <FormattedMessage key={i} class={`${styles[m.role]} ${styles.message}`}>
|
||||||
{MessageTools.getSwipe(m)?.content ?? ''}
|
{MessageTools.getSwipe(m)?.content ?? ''}
|
||||||
</FormattedMessage>
|
</FormattedMessage>
|
||||||
|
|
@ -105,18 +109,18 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class={styles.buttons}>
|
<div class={styles.buttons}>
|
||||||
{generating
|
{generating.value
|
||||||
? <button onClick={stopGeneration}>Stop</button>
|
? <button onClick={stopGeneration}>Stop</button>
|
||||||
: <button onClick={handleGenerate}>Generate</button>
|
: <button onClick={handleGenerate}>Generate</button>
|
||||||
}
|
}
|
||||||
<button onClick={() => handleInit()} class={`${generating ? 'disabled' : ''}`}>
|
<button onClick={() => handleInit()} class={`${generating.value ? 'disabled' : ''}`}>
|
||||||
Clear
|
Clear
|
||||||
</button>
|
</button>
|
||||||
{Object.entries(buttons).map(([label, onClick], i) => (
|
{Object.entries(buttons).map(([label, onClick], i) => (
|
||||||
<button
|
<button
|
||||||
key={i}
|
key={i}
|
||||||
onClick={() => onClick(answer ?? '')}
|
onClick={() => onClick(answer ?? '')}
|
||||||
class={`${(generating || !answer) ? 'disabled' : ''}`}
|
class={`${(generating.value || !answer) ? 'disabled' : ''}`}
|
||||||
>
|
>
|
||||||
{label}
|
{label}
|
||||||
</button>
|
</button>
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,17 @@
|
||||||
import { createContext } from "preact";
|
import { createContext } from "preact";
|
||||||
import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
|
import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
|
||||||
import { MessageTools, type IMessage } from "../messages";
|
import { MessageTools, type IMessage } from "../tools/messages";
|
||||||
import { StateContext } from "./state";
|
import { StateContext } from "./state";
|
||||||
import { useBool } from "@common/hooks/useBool";
|
import { useBool } from "@common/hooks/useBool";
|
||||||
import { Template } from "@huggingface/jinja";
|
import { Huggingface } from "../tools/huggingface";
|
||||||
import { Huggingface } from "../huggingface";
|
import { Connection, type IGenerationSettings } from "../tools/connection";
|
||||||
import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection";
|
|
||||||
import { throttle } from "@common/utils";
|
import { throttle } from "@common/utils";
|
||||||
import { useAsyncEffect } from "@common/hooks/useAsyncEffect";
|
import { useAsyncEffect } from "@common/hooks/useAsyncEffect";
|
||||||
|
import { approximateTokens, normalizeModel } from "../tools/model";
|
||||||
|
|
||||||
interface ICompileArgs {
|
interface ICompileArgs {
|
||||||
keepUsers?: number;
|
keepUsers?: number;
|
||||||
raw?: boolean;
|
continueLast?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface ICompiledPrompt {
|
interface ICompiledPrompt {
|
||||||
|
|
@ -48,8 +48,8 @@ const processing = {
|
||||||
|
|
||||||
export const LLMContextProvider = ({ children }: { children?: any }) => {
|
export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
const {
|
const {
|
||||||
connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
|
connection, messages, triggerNext, continueLast, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
|
||||||
setTriggerNext, addMessage, editMessage, editSummary,
|
setTriggerNext, setContinueLast, addMessage, editMessage, editSummary,
|
||||||
} = useContext(StateContext);
|
} = useContext(StateContext);
|
||||||
|
|
||||||
const generating = useBool(false);
|
const generating = useBool(false);
|
||||||
|
|
@ -58,26 +58,27 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
const [modelName, setModelName] = useState('');
|
const [modelName, setModelName] = useState('');
|
||||||
const [hasToolCalls, setHasToolCalls] = useState(false);
|
const [hasToolCalls, setHasToolCalls] = useState(false);
|
||||||
|
|
||||||
const userPromptTemplate = useMemo(() => {
|
const isOnline = useMemo(() => contextLength > 0, [contextLength]);
|
||||||
try {
|
|
||||||
return new Template(userPrompt)
|
|
||||||
} catch {
|
|
||||||
return {
|
|
||||||
render: () => userPrompt,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}, [userPrompt]);
|
|
||||||
|
|
||||||
const actions: IActions = useMemo(() => ({
|
const actions: IActions = useMemo(() => ({
|
||||||
compilePrompt: async (messages, { keepUsers } = {}) => {
|
compilePrompt: async (messages, { keepUsers, continueLast = false } = {}) => {
|
||||||
const promptMessages = messages.slice();
|
const lastMessage = messages.at(-1);
|
||||||
const lastMessage = promptMessages.at(-1);
|
const lastMessageContent = MessageTools.getSwipe(lastMessage)?.content;
|
||||||
const isAssistantLast = lastMessage?.role === 'assistant';
|
const isAssistantLast = lastMessage?.role === 'assistant';
|
||||||
const isRegen = isAssistantLast && !MessageTools.getSwipe(lastMessage)?.content;
|
let isRegen = continueLast;
|
||||||
|
|
||||||
|
if (!isAssistantLast) {
|
||||||
|
isRegen = false;
|
||||||
|
} else if (!lastMessageContent) {
|
||||||
|
isRegen = true;
|
||||||
|
}
|
||||||
|
|
||||||
const isContinue = isAssistantLast && !isRegen;
|
const isContinue = isAssistantLast && !isRegen;
|
||||||
|
|
||||||
|
const promptMessages = continueLast ? messages.slice(0, -1) : messages.slice();
|
||||||
|
|
||||||
if (isContinue) {
|
if (isContinue) {
|
||||||
promptMessages.push(MessageTools.create(userPromptTemplate.render({})));
|
promptMessages.push(MessageTools.create(Huggingface.applyTemplate(userPrompt, {})));
|
||||||
}
|
}
|
||||||
|
|
||||||
const userMessages = promptMessages.filter(m => m.role === 'user');
|
const userMessages = promptMessages.filter(m => m.role === 'user');
|
||||||
|
|
@ -104,7 +105,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
} else if (role === 'user' && !message.technical) {
|
} else if (role === 'user' && !message.technical) {
|
||||||
templateMessages.push({
|
templateMessages.push({
|
||||||
role: message.role,
|
role: message.role,
|
||||||
content: userPromptTemplate.render({ prompt: content, isStart: !wasStory }),
|
content: Huggingface.applyTemplate(userPrompt, { prompt: content, isStart: !wasStory }),
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
if (role === 'assistant') {
|
if (role === 'assistant') {
|
||||||
|
|
@ -128,17 +129,17 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
|
|
||||||
if (story.length > 0) {
|
if (story.length > 0) {
|
||||||
const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
|
const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
|
||||||
templateMessages.push({ role: 'user', content: userPromptTemplate.render({ prompt, isStart: true }) });
|
templateMessages.push({ role: 'user', content: Huggingface.applyTemplate(userPrompt, { prompt, isStart: true }) });
|
||||||
templateMessages.push({ role: 'assistant', content: story });
|
templateMessages.push({ role: 'assistant', content: story });
|
||||||
}
|
}
|
||||||
|
|
||||||
let userPrompt = MessageTools.getSwipe(lastUserMessage)?.content;
|
let userMessage = MessageTools.getSwipe(lastUserMessage)?.content;
|
||||||
if (!lastUserMessage?.technical && !isContinue && userPrompt) {
|
if (!lastUserMessage?.technical && !isContinue && userMessage) {
|
||||||
userPrompt = userPromptTemplate.render({ prompt: userPrompt, isStart: story.length === 0 });
|
userMessage = Huggingface.applyTemplate(userPrompt, { prompt: userMessage, isStart: story.length === 0 });
|
||||||
}
|
}
|
||||||
|
|
||||||
if (userPrompt) {
|
if (userMessage) {
|
||||||
templateMessages.push({ role: 'user', content: userPrompt });
|
templateMessages.push({ role: 'user', content: userMessage });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -147,13 +148,18 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
|
|
||||||
templateMessages.splice(1, 0, {
|
templateMessages.splice(1, 0, {
|
||||||
role: 'user',
|
role: 'user',
|
||||||
content: userPromptTemplate.render({ prompt, isStart: true }),
|
content: Huggingface.applyTemplate(userPrompt, { prompt, isStart: true }),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`;
|
templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`;
|
||||||
|
|
||||||
const prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages);
|
let prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages);
|
||||||
|
|
||||||
|
if (isRegen) {
|
||||||
|
prompt += lastMessageContent;
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
prompt,
|
prompt,
|
||||||
isContinue,
|
isContinue,
|
||||||
|
|
@ -194,20 +200,23 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
return await Connection.countTokens(connection, prompt);
|
return await Connection.countTokens(connection, prompt);
|
||||||
},
|
},
|
||||||
stopGeneration: () => {
|
stopGeneration: () => {
|
||||||
Connection.stopGeneration();
|
Connection.stopGeneration();
|
||||||
},
|
},
|
||||||
}), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);
|
}), [connection, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt]);
|
||||||
|
|
||||||
useAsyncEffect(async () => {
|
useAsyncEffect(async () => {
|
||||||
if (triggerNext && !generating.value) {
|
if (isOnline && triggerNext && !generating.value) {
|
||||||
setTriggerNext(false);
|
setTriggerNext(false);
|
||||||
|
setContinueLast(false);
|
||||||
|
|
||||||
let messageId = messages.length - 1;
|
let messageId = messages.length - 1;
|
||||||
let text: string = '';
|
let text = '';
|
||||||
|
|
||||||
const { prompt, isRegen } = await actions.compilePrompt(messages);
|
const { prompt, isRegen } = await actions.compilePrompt(messages, { continueLast });
|
||||||
|
|
||||||
if (!isRegen) {
|
if (isRegen) {
|
||||||
|
text = MessageTools.getSwipe(messages.at(-1))?.content ?? '';
|
||||||
|
} else {
|
||||||
addMessage('', 'assistant');
|
addMessage('', 'assistant');
|
||||||
messageId++;
|
messageId++;
|
||||||
}
|
}
|
||||||
|
|
@ -227,10 +236,10 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
|
|
||||||
MessageTools.playReady();
|
MessageTools.playReady();
|
||||||
}
|
}
|
||||||
}, [triggerNext]);
|
}, [triggerNext, isOnline]);
|
||||||
|
|
||||||
useAsyncEffect(async () => {
|
useAsyncEffect(async () => {
|
||||||
if (summaryEnabled && !processing.summarizing) {
|
if (isOnline && summaryEnabled && !processing.summarizing) {
|
||||||
try {
|
try {
|
||||||
processing.summarizing = true;
|
processing.summarizing = true;
|
||||||
for (let id = 0; id < messages.length; id++) {
|
for (let id = 0; id < messages.length; id++) {
|
||||||
|
|
@ -247,7 +256,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
processing.summarizing = false;
|
processing.summarizing = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}, [messages, summaryEnabled]);
|
}, [messages, summaryEnabled, isOnline]);
|
||||||
|
|
||||||
useEffect(throttle(() => {
|
useEffect(throttle(() => {
|
||||||
Connection.getContextLength(connection).then(setContextLength);
|
Connection.getContextLength(connection).then(setContextLength);
|
||||||
|
|
@ -255,7 +264,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
}, 1000, true), [connection]);
|
}, 1000, true), [connection]);
|
||||||
|
|
||||||
const calculateTokens = useCallback(throttle(async () => {
|
const calculateTokens = useCallback(throttle(async () => {
|
||||||
if (!processing.tokenizing && !generating.value) {
|
if (isOnline && !processing.tokenizing && !generating.value) {
|
||||||
try {
|
try {
|
||||||
processing.tokenizing = true;
|
processing.tokenizing = true;
|
||||||
const { prompt } = await actions.compilePrompt(messages);
|
const { prompt } = await actions.compilePrompt(messages);
|
||||||
|
|
@ -267,11 +276,11 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
processing.tokenizing = false;
|
processing.tokenizing = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}, 1000, true), [actions, messages]);
|
}, 1000, true), [actions, messages, isOnline]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
calculateTokens();
|
calculateTokens();
|
||||||
}, [messages, connection, systemPrompt, lore, userPrompt]);
|
}, [messages, connection, systemPrompt, lore, userPrompt, isOnline]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
try {
|
try {
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import { createContext } from "preact";
|
import { createContext } from "preact";
|
||||||
import { useCallback, useEffect, useMemo, useState } from "preact/hooks";
|
import { useCallback, useEffect, useMemo, useState } from "preact/hooks";
|
||||||
import { MessageTools, type IMessage } from "../messages";
|
import { MessageTools, type IMessage } from "../tools/messages";
|
||||||
import { useInputState } from "@common/hooks/useInputState";
|
import { useInputState } from "@common/hooks/useInputState";
|
||||||
import { type IConnection } from "../connection";
|
import { type IConnection } from "../tools/connection";
|
||||||
|
|
||||||
interface IContext {
|
interface IContext {
|
||||||
currentConnection: number;
|
currentConnection: number;
|
||||||
|
|
@ -15,7 +15,9 @@ interface IContext {
|
||||||
summaryEnabled: boolean;
|
summaryEnabled: boolean;
|
||||||
bannedWords: string[];
|
bannedWords: string[];
|
||||||
messages: IMessage[];
|
messages: IMessage[];
|
||||||
|
//
|
||||||
triggerNext: boolean;
|
triggerNext: boolean;
|
||||||
|
continueLast: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface IComputableContext {
|
interface IComputableContext {
|
||||||
|
|
@ -33,9 +35,11 @@ interface IActions {
|
||||||
setUserPrompt: (prompt: string | Event) => void;
|
setUserPrompt: (prompt: string | Event) => void;
|
||||||
setSummarizePrompt: (prompt: string | Event) => void;
|
setSummarizePrompt: (prompt: string | Event) => void;
|
||||||
setBannedWords: (words: string[]) => void;
|
setBannedWords: (words: string[]) => void;
|
||||||
setTriggerNext: (triggerNext: boolean) => void;
|
|
||||||
setSummaryEnabled: (summaryEnabled: boolean) => void;
|
setSummaryEnabled: (summaryEnabled: boolean) => void;
|
||||||
|
|
||||||
|
setTriggerNext: (triggerNext: boolean) => void;
|
||||||
|
setContinueLast: (continueLast: boolean) => void;
|
||||||
|
|
||||||
setMessages: (messages: IMessage[]) => void;
|
setMessages: (messages: IMessage[]) => void;
|
||||||
addMessage: (content: string, role: IMessage['role'], triggerNext?: boolean) => void;
|
addMessage: (content: string, role: IMessage['role'], triggerNext?: boolean) => void;
|
||||||
editMessage: (index: number, content: string) => void;
|
editMessage: (index: number, content: string) => void;
|
||||||
|
|
@ -44,7 +48,7 @@ interface IActions {
|
||||||
setCurrentSwipe: (index: number, swipe: number) => void;
|
setCurrentSwipe: (index: number, swipe: number) => void;
|
||||||
addSwipe: (index: number, content: string) => void;
|
addSwipe: (index: number, content: string) => void;
|
||||||
|
|
||||||
continueMessage: () => void;
|
continueMessage: (continueLast?: boolean) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
const SAVE_KEY = 'ai_game_save_state';
|
const SAVE_KEY = 'ai_game_save_state';
|
||||||
|
|
@ -79,7 +83,7 @@ Continue the story forward.
|
||||||
{%- endif %}
|
{%- endif %}
|
||||||
|
|
||||||
{% if prompt -%}
|
{% if prompt -%}
|
||||||
This is the description of What should happen next in your answer: {{ prompt | trim }}
|
This is the description of what should happen next in your answer: {{ prompt | trim }}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
Remember that this story should be infinite and go forever.
|
Remember that this story should be infinite and go forever.
|
||||||
Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`,
|
Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`,
|
||||||
|
|
@ -88,11 +92,13 @@ Make sure to follow the world description and rules exactly. Avoid cliffhangers
|
||||||
bannedWords: [],
|
bannedWords: [],
|
||||||
messages: [],
|
messages: [],
|
||||||
triggerNext: false,
|
triggerNext: false,
|
||||||
|
continueLast: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
export const saveContext = (context: IContext) => {
|
export const saveContext = (context: IContext) => {
|
||||||
const contextToSave: Partial<IContext> = { ...context };
|
const contextToSave: Partial<IContext> = { ...context };
|
||||||
delete contextToSave.triggerNext;
|
delete contextToSave.triggerNext;
|
||||||
|
delete contextToSave.continueLast;
|
||||||
|
|
||||||
localStorage.setItem(SAVE_KEY, JSON.stringify(contextToSave));
|
localStorage.setItem(SAVE_KEY, JSON.stringify(contextToSave));
|
||||||
}
|
}
|
||||||
|
|
@ -130,6 +136,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
||||||
const connection = availableConnections[currentConnection] ?? DEFAULT_CONTEXT.availableConnections[0];
|
const connection = availableConnections[currentConnection] ?? DEFAULT_CONTEXT.availableConnections[0];
|
||||||
|
|
||||||
const [triggerNext, setTriggerNext] = useState(false);
|
const [triggerNext, setTriggerNext] = useState(false);
|
||||||
|
const [continueLast, setContinueLast] = useState(false);
|
||||||
const [instruct, setInstruct] = useInputState(connection.instruct);
|
const [instruct, setInstruct] = useInputState(connection.instruct);
|
||||||
|
|
||||||
const setConnection = useCallback((c: IConnection) => {
|
const setConnection = useCallback((c: IConnection) => {
|
||||||
|
|
@ -153,8 +160,11 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
||||||
setUserPrompt,
|
setUserPrompt,
|
||||||
setSummarizePrompt,
|
setSummarizePrompt,
|
||||||
setLore,
|
setLore,
|
||||||
setTriggerNext,
|
|
||||||
setSummaryEnabled,
|
setSummaryEnabled,
|
||||||
|
|
||||||
|
setTriggerNext,
|
||||||
|
setContinueLast,
|
||||||
|
|
||||||
setBannedWords: (words) => setBannedWords(words.slice()),
|
setBannedWords: (words) => setBannedWords(words.slice()),
|
||||||
setAvailableConnections: (connections) => setAvailableConnections(connections.slice()),
|
setAvailableConnections: (connections) => setAvailableConnections(connections.slice()),
|
||||||
|
|
||||||
|
|
@ -224,7 +234,10 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
continueMessage: () => setTriggerNext(true),
|
continueMessage: (c = false) => {
|
||||||
|
setTriggerNext(true);
|
||||||
|
setContinueLast(c);
|
||||||
|
},
|
||||||
}), []);
|
}), []);
|
||||||
|
|
||||||
const rawContext: IContext & IComputableContext = {
|
const rawContext: IContext & IComputableContext = {
|
||||||
|
|
@ -239,7 +252,9 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
||||||
summaryEnabled,
|
summaryEnabled,
|
||||||
bannedWords,
|
bannedWords,
|
||||||
messages,
|
messages,
|
||||||
|
//
|
||||||
triggerNext,
|
triggerNext,
|
||||||
|
continueLast,
|
||||||
};
|
};
|
||||||
|
|
||||||
const context = useMemo(() => rawContext, Object.values(rawContext));
|
const context = useMemo(() => rawContext, Object.values(rawContext));
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
import Lock from "@common/lock";
|
import Lock from "@common/lock";
|
||||||
import SSE from "@common/sse";
|
import SSE from "@common/sse";
|
||||||
import { throttle } from "@common/utils";
|
import { throttle } from "@common/utils";
|
||||||
import delay, { clearDelay } from "delay";
|
import delay from "delay";
|
||||||
|
import { Huggingface } from "./huggingface";
|
||||||
|
import { approximateTokens, normalizeModel } from "./model";
|
||||||
|
|
||||||
interface IBaseConnection {
|
interface IBaseConnection {
|
||||||
instruct: string;
|
instruct: string;
|
||||||
|
|
@ -79,34 +81,6 @@ const MAX_HORDE_LENGTH = 512;
|
||||||
const MAX_HORDE_CONTEXT = 32000;
|
const MAX_HORDE_CONTEXT = 32000;
|
||||||
export const HORDE_ANON_KEY = '0000000000';
|
export const HORDE_ANON_KEY = '0000000000';
|
||||||
|
|
||||||
export const normalizeModel = (model: string) => {
|
|
||||||
let currentModel = model.split(/[\\\/]/).at(-1);
|
|
||||||
currentModel = currentModel.split('::').at(0);
|
|
||||||
let normalizedModel: string;
|
|
||||||
|
|
||||||
do {
|
|
||||||
normalizedModel = currentModel;
|
|
||||||
|
|
||||||
currentModel = currentModel
|
|
||||||
.replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
|
|
||||||
.replace(/[ ._-](gptq|awq|exl2?|imat|i\d|h\d)/i, '') // remove quant name
|
|
||||||
.replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
|
|
||||||
.replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
|
|
||||||
.replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
|
|
||||||
.replace(/[ ._-]f(p|loat)?(8|16|32)/i, '')
|
|
||||||
.replace(/^(debug-?)+/i, '')
|
|
||||||
.trim();
|
|
||||||
} while (normalizedModel !== currentModel);
|
|
||||||
|
|
||||||
return normalizedModel
|
|
||||||
.replace(/[ _-]+/ig, '-')
|
|
||||||
.replace(/\.{2,}/, '-')
|
|
||||||
.replace(/[ ._-]+$/ig, '')
|
|
||||||
.trim();
|
|
||||||
}
|
|
||||||
|
|
||||||
export const approximateTokens = (prompt: string): number => prompt.split(/[^a-z0-9]+/i).length;
|
|
||||||
|
|
||||||
export type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
|
export type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
|
||||||
|
|
||||||
export namespace Connection {
|
export namespace Connection {
|
||||||
|
|
@ -171,7 +145,11 @@ export namespace Connection {
|
||||||
sse.close();
|
sse.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
async function generateHorde(connection: Omit<IHordeConnection, keyof IBaseConnection>, prompt: string, extraSettings: IGenerationSettings = {}): Promise<string> {
|
async function* generateHorde(connection: IHordeConnection, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator<string> {
|
||||||
|
if (!connection.model) {
|
||||||
|
throw new Error('Horde not connected');
|
||||||
|
}
|
||||||
|
|
||||||
const models = await getHordeModels();
|
const models = await getHordeModels();
|
||||||
const model = models.get(connection.model);
|
const model = models.get(connection.model);
|
||||||
if (model) {
|
if (model) {
|
||||||
|
|
@ -192,54 +170,78 @@ export namespace Connection {
|
||||||
models: model.hordeNames,
|
models: model.hordeNames,
|
||||||
workers: model.workers,
|
workers: model.workers,
|
||||||
};
|
};
|
||||||
|
const bannedTokens = requestData.params.banned_tokens ?? [];
|
||||||
|
|
||||||
const { signal } = abortController;
|
const { signal } = abortController;
|
||||||
|
|
||||||
const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, {
|
while (true) {
|
||||||
method: 'POST',
|
const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, {
|
||||||
body: JSON.stringify(requestData),
|
method: 'POST',
|
||||||
headers: {
|
body: JSON.stringify(requestData),
|
||||||
'Content-Type': 'application/json',
|
headers: {
|
||||||
apikey: connection.apiKey || HORDE_ANON_KEY,
|
'Content-Type': 'application/json',
|
||||||
},
|
apikey: connection.apiKey || HORDE_ANON_KEY,
|
||||||
signal,
|
},
|
||||||
});
|
signal,
|
||||||
|
});
|
||||||
|
|
||||||
if (!generateResponse.ok || generateResponse.status >= 400) {
|
if (!generateResponse.ok || generateResponse.status >= 400) {
|
||||||
throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`);
|
throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`);
|
||||||
}
|
|
||||||
|
|
||||||
const { id } = await generateResponse.json() as { id: string };
|
|
||||||
const request = async (method = 'GET'): Promise<string | null> => {
|
|
||||||
const response = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method });
|
|
||||||
if (response.ok && response.status < 400) {
|
|
||||||
const result: IHordeResult = await response.json();
|
|
||||||
if (result.generations?.length === 1) {
|
|
||||||
const { text } = result.generations[0];
|
|
||||||
|
|
||||||
return text;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
throw new Error(await response.text());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return null;
|
const { id } = await generateResponse.json() as { id: string };
|
||||||
};
|
const request = async (method = 'GET'): Promise<string | null> => {
|
||||||
|
const response = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method });
|
||||||
|
if (response.ok && response.status < 400) {
|
||||||
|
const result: IHordeResult = await response.json();
|
||||||
|
if (result.generations?.length === 1) {
|
||||||
|
const { text } = result.generations[0];
|
||||||
|
|
||||||
const deleteRequest = async () => (await request('DELETE')) ?? '';
|
return text;
|
||||||
|
}
|
||||||
while (true) {
|
} else {
|
||||||
try {
|
throw new Error(await response.text());
|
||||||
await delay(2500, { signal });
|
}
|
||||||
|
|
||||||
const text = await request();
|
return null;
|
||||||
|
};
|
||||||
if (text) {
|
|
||||||
return text;
|
const deleteRequest = async () => (await request('DELETE')) ?? '';
|
||||||
|
let text: string | null = null;
|
||||||
|
|
||||||
|
while (!text) {
|
||||||
|
try {
|
||||||
|
await delay(2500, { signal });
|
||||||
|
|
||||||
|
text = await request();
|
||||||
|
|
||||||
|
if (text) {
|
||||||
|
const locaseText = text.toLowerCase();
|
||||||
|
let unsloppedText = text;
|
||||||
|
for (const ban of bannedTokens) {
|
||||||
|
const slopIdx = locaseText.indexOf(ban.toLowerCase());
|
||||||
|
if (slopIdx >= 0) {
|
||||||
|
console.log(`[horde] slop '${ban}' detected at ${slopIdx}`);
|
||||||
|
unsloppedText = unsloppedText.slice(0, slopIdx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
yield unsloppedText;
|
||||||
|
|
||||||
|
requestData.prompt += unsloppedText;
|
||||||
|
|
||||||
|
if (unsloppedText === text) {
|
||||||
|
return; // we are finished
|
||||||
|
}
|
||||||
|
|
||||||
|
if (unsloppedText.length === 0) {
|
||||||
|
requestData.params.temperature += 0.05;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Error in horde generation:', e);
|
||||||
|
return yield deleteRequest();
|
||||||
}
|
}
|
||||||
} catch (e) {
|
|
||||||
console.error('Error in horde generation:', e);
|
|
||||||
return deleteRequest();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -251,7 +253,7 @@ export namespace Connection {
|
||||||
if (isKoboldConnection(connection)) {
|
if (isKoboldConnection(connection)) {
|
||||||
yield* generateKobold(connection.url, prompt, extraSettings);
|
yield* generateKobold(connection.url, prompt, extraSettings);
|
||||||
} else if (isHordeConnection(connection)) {
|
} else if (isHordeConnection(connection)) {
|
||||||
yield await generateHorde(connection, prompt, extraSettings);
|
yield* generateHorde(connection, prompt, extraSettings);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -277,7 +279,7 @@ export namespace Connection {
|
||||||
|
|
||||||
for (const worker of goodWorkers) {
|
for (const worker of goodWorkers) {
|
||||||
for (const modelName of worker.models) {
|
for (const modelName of worker.models) {
|
||||||
const normName = normalizeModel(modelName.toLowerCase());
|
const normName = normalizeModel(modelName);
|
||||||
let model = models.get(normName);
|
let model = models.get(normName);
|
||||||
if (!model) {
|
if (!model) {
|
||||||
model = {
|
model = {
|
||||||
|
|
@ -343,7 +345,7 @@ export namespace Connection {
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.error('Error getting max tokens', e);
|
console.error('Error getting max tokens', e);
|
||||||
}
|
}
|
||||||
} else if (isHordeConnection(connection)) {
|
} else if (isHordeConnection(connection) && connection.model) {
|
||||||
const models = await getHordeModels();
|
const models = await getHordeModels();
|
||||||
const model = models.get(connection.model);
|
const model = models.get(connection.model);
|
||||||
if (model) {
|
if (model) {
|
||||||
|
|
@ -367,7 +369,18 @@ export namespace Connection {
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.error('Error counting tokens', e);
|
console.error('Error counting tokens:', e);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const model = await getModelName(connection);
|
||||||
|
const tokenizer = await Huggingface.findTokenizer(model);
|
||||||
|
if (tokenizer) {
|
||||||
|
try {
|
||||||
|
const { input_ids } = await tokenizer(prompt);
|
||||||
|
return input_ids.data.length;
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Error counting tokens with tokenizer:', e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
import { gguf } from '@huggingface/gguf';
|
import { gguf } from '@huggingface/gguf';
|
||||||
import * as hub from '@huggingface/hub';
|
import * as hub from '@huggingface/hub';
|
||||||
import { Template } from '@huggingface/jinja';
|
import { Template } from '@huggingface/jinja';
|
||||||
import { normalizeModel } from './connection';
|
import { AutoTokenizer, PreTrainedTokenizer } from '@huggingface/transformers';
|
||||||
|
import { normalizeModel } from './model';
|
||||||
|
|
||||||
export namespace Huggingface {
|
export namespace Huggingface {
|
||||||
export interface ITemplateMessage {
|
export interface ITemplateMessage {
|
||||||
|
|
@ -81,6 +82,7 @@ export namespace Huggingface {
|
||||||
|
|
||||||
const templateCache: Record<string, string> = loadCache();
|
const templateCache: Record<string, string> = loadCache();
|
||||||
const compiledTemplates = new Map<string, Template>();
|
const compiledTemplates = new Map<string, Template>();
|
||||||
|
const tokenizerCache = new Map<string, PreTrainedTokenizer | null>();
|
||||||
|
|
||||||
const hasField = <T extends string>(obj: unknown, field: T): obj is Record<T, unknown> => (
|
const hasField = <T extends string>(obj: unknown, field: T): obj is Record<T, unknown> => (
|
||||||
obj != null && typeof obj === 'object' && (field in obj)
|
obj != null && typeof obj === 'object' && (field in obj)
|
||||||
|
|
@ -92,13 +94,13 @@ export namespace Huggingface {
|
||||||
);
|
);
|
||||||
|
|
||||||
const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise<TokenizerConfig | null> => {
|
const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise<TokenizerConfig | null> => {
|
||||||
|
modelName = normalizeModel(modelName);
|
||||||
console.log(`[huggingface] searching config for '${modelName}'`);
|
console.log(`[huggingface] searching config for '${modelName}'`);
|
||||||
const searchModel = normalizeModel(modelName);
|
|
||||||
|
|
||||||
const hubModels = await Array.fromAsync(hub.listModels({ search: { query: searchModel }, additionalFields: ['config'] }));
|
const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName }, additionalFields: ['config'] }));
|
||||||
const models = hubModels.filter(m => {
|
const models = hubModels.filter(m => {
|
||||||
if (m.gated) return false;
|
if (m.gated) return false;
|
||||||
if (!normalizeModel(m.name).includes(searchModel)) return false;
|
if (!normalizeModel(m.name).includes(modelName)) return false;
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}).sort((a, b) => b.downloads - a.downloads);
|
}).sort((a, b) => b.downloads - a.downloads);
|
||||||
|
|
@ -116,8 +118,8 @@ export namespace Huggingface {
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
console.log(`[huggingface] searching config in '${model.name}/tokenizer_config.json'`);
|
console.log(`[huggingface] searching config in '${name}/tokenizer_config.json'`);
|
||||||
const fileResponse = await hub.downloadFile({ repo: model.name, path: 'tokenizer_config.json' });
|
const fileResponse = await hub.downloadFile({ repo: name, path: 'tokenizer_config.json' });
|
||||||
if (fileResponse?.ok) {
|
if (fileResponse?.ok) {
|
||||||
const maybeConfig = await fileResponse.json();
|
const maybeConfig = await fileResponse.json();
|
||||||
if (isTokenizerConfig(maybeConfig)) {
|
if (isTokenizerConfig(maybeConfig)) {
|
||||||
|
|
@ -232,10 +234,10 @@ export namespace Huggingface {
|
||||||
}
|
}
|
||||||
|
|
||||||
export const findModelTemplate = async (modelName: string): Promise<string | null> => {
|
export const findModelTemplate = async (modelName: string): Promise<string | null> => {
|
||||||
const modelKey = modelName.toLowerCase().trim();
|
modelName = normalizeModel(modelName);
|
||||||
if (!modelKey) return '';
|
if (!modelName) return '';
|
||||||
|
|
||||||
let template = templateCache[modelKey] ?? null;
|
let template = templateCache[modelName] ?? null;
|
||||||
|
|
||||||
if (template) {
|
if (template) {
|
||||||
console.log(`[huggingface] found cached template for '${modelName}'`);
|
console.log(`[huggingface] found cached template for '${modelName}'`);
|
||||||
|
|
@ -249,18 +251,58 @@ export namespace Huggingface {
|
||||||
|
|
||||||
if (config.bos_token) {
|
if (config.bos_token) {
|
||||||
template = template
|
template = template
|
||||||
.replaceAll(config.bos_token, '')
|
|
||||||
.replace(/\{\{ ?(''|"") ?\}\}/g, '');
|
.replace(/\{\{ ?(''|"") ?\}\}/g, '');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
templateCache[modelKey] = template;
|
templateCache[modelName] = template;
|
||||||
saveCache(templateCache);
|
saveCache(templateCache);
|
||||||
|
|
||||||
return template;
|
return template;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const findTokenizer = async (modelName: string): Promise<PreTrainedTokenizer | null> => {
|
||||||
|
modelName = normalizeModel(modelName);
|
||||||
|
|
||||||
|
let tokenizer = tokenizerCache.get(modelName) ?? null;
|
||||||
|
|
||||||
|
if (tokenizer) {
|
||||||
|
return tokenizer;
|
||||||
|
} else if (!tokenizerCache.has(modelName)) {
|
||||||
|
console.log(`[huggingface] searching tokenizer for '${modelName}'`);
|
||||||
|
|
||||||
|
const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName } }));
|
||||||
|
const models = hubModels.filter(m => {
|
||||||
|
if (m.gated) return false;
|
||||||
|
if (m.name.toLowerCase().includes('gguf')) return false;
|
||||||
|
if (!normalizeModel(m.name).includes(modelName)) return false;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
|
||||||
|
for (const model of models) {
|
||||||
|
const { name } = model;
|
||||||
|
|
||||||
|
try {
|
||||||
|
console.log(`[huggingface] searching tokenizer in '${name}'`);
|
||||||
|
tokenizer = await AutoTokenizer.from_pretrained(name);
|
||||||
|
break;
|
||||||
|
} catch { }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenizerCache.set(modelName, tokenizer);
|
||||||
|
|
||||||
|
if (tokenizer) {
|
||||||
|
console.log(`[huggingface] found tokenizer for '${modelName}'`);
|
||||||
|
} else {
|
||||||
|
console.log(`[huggingface] not found tokenizer for '${modelName}'`);
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenizer;
|
||||||
|
}
|
||||||
|
|
||||||
export const applyChatTemplate = (templateString: string, messages: ITemplateMessage[], functions?: IFunction[]) => (
|
export const applyChatTemplate = (templateString: string, messages: ITemplateMessage[], functions?: IFunction[]) => (
|
||||||
applyTemplate(templateString, {
|
applyTemplate(templateString, {
|
||||||
messages,
|
messages,
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
import { Template } from "@huggingface/jinja";
|
import messageSound from '../assets/message.mp3';
|
||||||
import messageSound from './assets/message.mp3';
|
|
||||||
|
|
||||||
export interface ISwipe {
|
export interface ISwipe {
|
||||||
content: string;
|
content: string;
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
export const normalizeModel = (model: string) => {
|
||||||
|
let currentModel = model.split(/[\\\/]/).at(-1);
|
||||||
|
currentModel = currentModel.split('::').at(0).toLowerCase();
|
||||||
|
let normalizedModel: string;
|
||||||
|
|
||||||
|
do {
|
||||||
|
normalizedModel = currentModel;
|
||||||
|
|
||||||
|
currentModel = currentModel
|
||||||
|
.replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
|
||||||
|
.replace(/[ ._-](gptq|awq|exl2?|imat|i\d|h\d)/i, '') // remove quant name
|
||||||
|
.replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
|
||||||
|
.replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
|
||||||
|
.replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
|
||||||
|
.replace(/[ ._-]f(p|loat)?(8|16|32)/i, '')
|
||||||
|
.replace(/^(debug-?)+/i, '')
|
||||||
|
.trim();
|
||||||
|
} while (normalizedModel !== currentModel);
|
||||||
|
|
||||||
|
return normalizedModel
|
||||||
|
.replace(/[ _-]+/ig, '-')
|
||||||
|
.replace(/\.{2,}/, '-')
|
||||||
|
.replace(/[ ._-]+$/ig, '')
|
||||||
|
.trim();
|
||||||
|
}
|
||||||
|
|
||||||
|
export const approximateTokens = (prompt: string): number => Math.round(prompt.length / 4);
|
||||||
Loading…
Reference in New Issue