AIStory: continue button
This commit is contained in:
parent
9c4cc61573
commit
a213e0407c
|
|
@ -16,7 +16,7 @@ interface IProps {
|
||||||
}
|
}
|
||||||
|
|
||||||
export const Message = ({ message, index, isLastUser, isLastAssistant, onNeedScroll }: IProps) => {
|
export const Message = ({ message, index, isLastUser, isLastAssistant, onNeedScroll }: IProps) => {
|
||||||
const { messages, editMessage, editSummary, deleteMessage, setCurrentSwipe, setMessages } = useContext(StateContext);
|
const { messages, editMessage, editSummary, deleteMessage, setCurrentSwipe, setMessages, continueMessage } = useContext(StateContext);
|
||||||
const [editing, setEditing] = useState(false);
|
const [editing, setEditing] = useState(false);
|
||||||
const [editedMessage, setEditedMessage] = useInputState('');
|
const [editedMessage, setEditedMessage] = useInputState('');
|
||||||
const textRef = useRef<HTMLDivElement>(null);
|
const textRef = useRef<HTMLDivElement>(null);
|
||||||
|
|
@ -70,6 +70,10 @@ export const Message = ({ message, index, isLastUser, isLastAssistant, onNeedScr
|
||||||
DOMTools.animate(textRef.current, 'swipe-from-right');
|
DOMTools.animate(textRef.current, 'swipe-from-right');
|
||||||
}, [setCurrentSwipe, index, message]);
|
}, [setCurrentSwipe, index, message]);
|
||||||
|
|
||||||
|
const handleContinueMessage = useCallback(() => {
|
||||||
|
continueMessage(true);
|
||||||
|
}, [continueMessage]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div class={`${styles.message} ${styles[message.role]} ${isLastUser ? styles.lastUser : ''}`}>
|
<div class={`${styles.message} ${styles[message.role]} ${isLastUser ? styles.lastUser : ''}`}>
|
||||||
<div class={styles.content}>
|
<div class={styles.content}>
|
||||||
|
|
@ -89,13 +93,14 @@ export const Message = ({ message, index, isLastUser, isLastAssistant, onNeedScr
|
||||||
<button class='icon' onClick={handleCancelEdit} title='Cancel'>❌</button>
|
<button class='icon' onClick={handleCancelEdit} title='Cancel'>❌</button>
|
||||||
</>
|
</>
|
||||||
: <>
|
: <>
|
||||||
{isLastAssistant &&
|
{isLastAssistant && <>
|
||||||
<div class={styles.swipes}>
|
<div class={styles.swipes}>
|
||||||
<div onClick={handleSwipeLeft}>◀</div>
|
<div onClick={handleSwipeLeft}>◀</div>
|
||||||
<div>{message.currentSwipe + 1}/{message.swipes.length}</div>
|
<div>{message.currentSwipe + 1}/{message.swipes.length}</div>
|
||||||
<div onClick={handleSwipeRight}>▶</div>
|
<div onClick={handleSwipeRight}>▶</div>
|
||||||
</div>
|
</div>
|
||||||
}
|
<button class='icon' onClick={handleContinueMessage} title="Continue">▶</button>
|
||||||
|
</>}
|
||||||
<button class='icon' onClick={handleEnableEdit} title="Edit">🖊</button>
|
<button class='icon' onClick={handleEnableEdit} title="Edit">🖊</button>
|
||||||
</>
|
</>
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import styles from './minichat.module.css';
|
||||||
import { LLMContext } from "../../contexts/llm";
|
import { LLMContext } from "../../contexts/llm";
|
||||||
import { FormattedMessage } from "../message/formattedMessage";
|
import { FormattedMessage } from "../message/formattedMessage";
|
||||||
import { AutoTextarea } from "../autoTextarea";
|
import { AutoTextarea } from "../autoTextarea";
|
||||||
|
import { useBool } from "@common/hooks/useBool";
|
||||||
|
|
||||||
interface IProps {
|
interface IProps {
|
||||||
open: boolean;
|
open: boolean;
|
||||||
|
|
@ -16,9 +17,10 @@ interface IProps {
|
||||||
}
|
}
|
||||||
|
|
||||||
export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) => {
|
export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) => {
|
||||||
const { generating, stopGeneration, generate, compilePrompt } = useContext(LLMContext);
|
const { stopGeneration, generate, compilePrompt } = useContext(LLMContext);
|
||||||
const [messages, setMessages] = useState<IMessage[]>([]);
|
const [messages, setMessages] = useState<IMessage[]>([]);
|
||||||
const ref = useRef<HTMLDivElement>(null);
|
const ref = useRef<HTMLDivElement>(null);
|
||||||
|
const generating = useBool();
|
||||||
|
|
||||||
const answer = useMemo(() =>
|
const answer = useMemo(() =>
|
||||||
MessageTools.getSwipe(messages.filter(m => m.role === 'assistant').at(-1))?.content,
|
MessageTools.getSwipe(messages.filter(m => m.role === 'assistant').at(-1))?.content,
|
||||||
|
|
@ -33,7 +35,7 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setTimeout(() => DOMTools.scrollDown(ref.current, false), 100);
|
setTimeout(() => DOMTools.scrollDown(ref.current, false), 100);
|
||||||
}, [generating, open]);
|
}, [generating.value, open]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
DOMTools.scrollDown(ref.current, false);
|
DOMTools.scrollDown(ref.current, false);
|
||||||
|
|
@ -47,19 +49,21 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
|
||||||
}, [messages.length, handleInit]);
|
}, [messages.length, handleInit]);
|
||||||
|
|
||||||
const handleGenerate = useCallback(async () => {
|
const handleGenerate = useCallback(async () => {
|
||||||
if (messages.length > 0 && !generating) {
|
if (messages.length > 0 && !generating.value) {
|
||||||
const promptMessages: IMessage[] = [...history, ...messages];
|
const promptMessages: IMessage[] = [...history, ...messages];
|
||||||
const { prompt } = await compilePrompt(promptMessages, { keepUsers: messages.length + 1 });
|
const { prompt } = await compilePrompt(promptMessages, { keepUsers: messages.length + 1, continueLast: true });
|
||||||
|
|
||||||
let text = '';
|
let text = '';
|
||||||
const messageId = messages.length;
|
const messageId = messages.length;
|
||||||
const newMessages = [...messages, MessageTools.create('', 'assistant', true)];
|
const newMessages = [...messages, MessageTools.create('', 'assistant', true)];
|
||||||
setMessages(newMessages);
|
setMessages(newMessages);
|
||||||
|
|
||||||
|
generating.setTrue();
|
||||||
for await (const chunk of generate(prompt)) {
|
for await (const chunk of generate(prompt)) {
|
||||||
text += chunk;
|
text += chunk;
|
||||||
setMessages(MessageTools.updateSwipe(newMessages, messageId, { content: text.trim() }));
|
setMessages(MessageTools.updateSwipe(newMessages, messageId, { content: text.trim() }));
|
||||||
}
|
}
|
||||||
|
generating.setFalse();
|
||||||
|
|
||||||
setMessages([
|
setMessages([
|
||||||
...MessageTools.updateSwipe(newMessages, messageId, { content: MessageTools.trimSentence(text) }),
|
...MessageTools.updateSwipe(newMessages, messageId, { content: MessageTools.trimSentence(text) }),
|
||||||
|
|
@ -90,7 +94,7 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
|
||||||
<div class={styles.minichat} ref={ref}>
|
<div class={styles.minichat} ref={ref}>
|
||||||
<div class={styles.messages}>
|
<div class={styles.messages}>
|
||||||
{messages.map((m, i) => (
|
{messages.map((m, i) => (
|
||||||
generating
|
generating.value
|
||||||
? <FormattedMessage key={i} class={`${styles[m.role]} ${styles.message}`}>
|
? <FormattedMessage key={i} class={`${styles[m.role]} ${styles.message}`}>
|
||||||
{MessageTools.getSwipe(m)?.content ?? ''}
|
{MessageTools.getSwipe(m)?.content ?? ''}
|
||||||
</FormattedMessage>
|
</FormattedMessage>
|
||||||
|
|
@ -105,18 +109,18 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class={styles.buttons}>
|
<div class={styles.buttons}>
|
||||||
{generating
|
{generating.value
|
||||||
? <button onClick={stopGeneration}>Stop</button>
|
? <button onClick={stopGeneration}>Stop</button>
|
||||||
: <button onClick={handleGenerate}>Generate</button>
|
: <button onClick={handleGenerate}>Generate</button>
|
||||||
}
|
}
|
||||||
<button onClick={() => handleInit()} class={`${generating ? 'disabled' : ''}`}>
|
<button onClick={() => handleInit()} class={`${generating.value ? 'disabled' : ''}`}>
|
||||||
Clear
|
Clear
|
||||||
</button>
|
</button>
|
||||||
{Object.entries(buttons).map(([label, onClick], i) => (
|
{Object.entries(buttons).map(([label, onClick], i) => (
|
||||||
<button
|
<button
|
||||||
key={i}
|
key={i}
|
||||||
onClick={() => onClick(answer ?? '')}
|
onClick={() => onClick(answer ?? '')}
|
||||||
class={`${(generating || !answer) ? 'disabled' : ''}`}
|
class={`${(generating.value || !answer) ? 'disabled' : ''}`}
|
||||||
>
|
>
|
||||||
{label}
|
{label}
|
||||||
</button>
|
</button>
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import Lock from "@common/lock";
|
import Lock from "@common/lock";
|
||||||
import SSE from "@common/sse";
|
import SSE from "@common/sse";
|
||||||
import { throttle } from "@common/utils";
|
import { throttle } from "@common/utils";
|
||||||
import delay, { clearDelay } from "delay";
|
import delay from "delay";
|
||||||
|
|
||||||
interface IBaseConnection {
|
interface IBaseConnection {
|
||||||
instruct: string;
|
instruct: string;
|
||||||
|
|
@ -105,7 +105,7 @@ export const normalizeModel = (model: string) => {
|
||||||
.trim();
|
.trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
export const approximateTokens = (prompt: string): number => prompt.split(/[^a-z0-9]+/i).length;
|
export const approximateTokens = (prompt: string): number => Math.round(prompt.length / 4);
|
||||||
|
|
||||||
export type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
|
export type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import { useAsyncEffect } from "@common/hooks/useAsyncEffect";
|
||||||
|
|
||||||
interface ICompileArgs {
|
interface ICompileArgs {
|
||||||
keepUsers?: number;
|
keepUsers?: number;
|
||||||
raw?: boolean;
|
continueLast?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface ICompiledPrompt {
|
interface ICompiledPrompt {
|
||||||
|
|
@ -48,8 +48,8 @@ const processing = {
|
||||||
|
|
||||||
export const LLMContextProvider = ({ children }: { children?: any }) => {
|
export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
const {
|
const {
|
||||||
connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
|
connection, messages, triggerNext, continueLast, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
|
||||||
setTriggerNext, addMessage, editMessage, editSummary,
|
setTriggerNext, setContinueLast, addMessage, editMessage, editSummary,
|
||||||
} = useContext(StateContext);
|
} = useContext(StateContext);
|
||||||
|
|
||||||
const generating = useBool(false);
|
const generating = useBool(false);
|
||||||
|
|
@ -69,13 +69,22 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
}, [userPrompt]);
|
}, [userPrompt]);
|
||||||
|
|
||||||
const actions: IActions = useMemo(() => ({
|
const actions: IActions = useMemo(() => ({
|
||||||
compilePrompt: async (messages, { keepUsers } = {}) => {
|
compilePrompt: async (messages, { keepUsers, continueLast = false } = {}) => {
|
||||||
const promptMessages = messages.slice();
|
const lastMessage = messages.at(-1);
|
||||||
const lastMessage = promptMessages.at(-1);
|
const lastMessageContent = MessageTools.getSwipe(lastMessage)?.content;
|
||||||
const isAssistantLast = lastMessage?.role === 'assistant';
|
const isAssistantLast = lastMessage?.role === 'assistant';
|
||||||
const isRegen = isAssistantLast && !MessageTools.getSwipe(lastMessage)?.content;
|
let isRegen = continueLast;
|
||||||
|
|
||||||
|
if (!isAssistantLast) {
|
||||||
|
isRegen = false;
|
||||||
|
} else if (!lastMessageContent) {
|
||||||
|
isRegen = true;
|
||||||
|
}
|
||||||
|
|
||||||
const isContinue = isAssistantLast && !isRegen;
|
const isContinue = isAssistantLast && !isRegen;
|
||||||
|
|
||||||
|
const promptMessages = continueLast ? messages.slice(0, -1) : messages.slice();
|
||||||
|
|
||||||
if (isContinue) {
|
if (isContinue) {
|
||||||
promptMessages.push(MessageTools.create(userPromptTemplate.render({})));
|
promptMessages.push(MessageTools.create(userPromptTemplate.render({})));
|
||||||
}
|
}
|
||||||
|
|
@ -153,7 +162,12 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
|
|
||||||
templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`;
|
templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`;
|
||||||
|
|
||||||
const prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages);
|
let prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages);
|
||||||
|
|
||||||
|
if (isRegen) {
|
||||||
|
prompt += lastMessageContent;
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
prompt,
|
prompt,
|
||||||
isContinue,
|
isContinue,
|
||||||
|
|
@ -194,20 +208,23 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
return await Connection.countTokens(connection, prompt);
|
return await Connection.countTokens(connection, prompt);
|
||||||
},
|
},
|
||||||
stopGeneration: () => {
|
stopGeneration: () => {
|
||||||
Connection.stopGeneration();
|
Connection.stopGeneration();
|
||||||
},
|
},
|
||||||
}), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);
|
}), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);
|
||||||
|
|
||||||
useAsyncEffect(async () => {
|
useAsyncEffect(async () => {
|
||||||
if (triggerNext && !generating.value) {
|
if (triggerNext && !generating.value) {
|
||||||
setTriggerNext(false);
|
setTriggerNext(false);
|
||||||
|
setContinueLast(false);
|
||||||
|
|
||||||
let messageId = messages.length - 1;
|
let messageId = messages.length - 1;
|
||||||
let text: string = '';
|
let text = '';
|
||||||
|
|
||||||
const { prompt, isRegen } = await actions.compilePrompt(messages);
|
const { prompt, isRegen } = await actions.compilePrompt(messages, { continueLast });
|
||||||
|
|
||||||
if (!isRegen) {
|
if (isRegen) {
|
||||||
|
text = MessageTools.getSwipe(messages.at(-1))?.content ?? '';
|
||||||
|
} else {
|
||||||
addMessage('', 'assistant');
|
addMessage('', 'assistant');
|
||||||
messageId++;
|
messageId++;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,9 @@ interface IContext {
|
||||||
summaryEnabled: boolean;
|
summaryEnabled: boolean;
|
||||||
bannedWords: string[];
|
bannedWords: string[];
|
||||||
messages: IMessage[];
|
messages: IMessage[];
|
||||||
|
//
|
||||||
triggerNext: boolean;
|
triggerNext: boolean;
|
||||||
|
continueLast: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface IComputableContext {
|
interface IComputableContext {
|
||||||
|
|
@ -33,9 +35,11 @@ interface IActions {
|
||||||
setUserPrompt: (prompt: string | Event) => void;
|
setUserPrompt: (prompt: string | Event) => void;
|
||||||
setSummarizePrompt: (prompt: string | Event) => void;
|
setSummarizePrompt: (prompt: string | Event) => void;
|
||||||
setBannedWords: (words: string[]) => void;
|
setBannedWords: (words: string[]) => void;
|
||||||
setTriggerNext: (triggerNext: boolean) => void;
|
|
||||||
setSummaryEnabled: (summaryEnabled: boolean) => void;
|
setSummaryEnabled: (summaryEnabled: boolean) => void;
|
||||||
|
|
||||||
|
setTriggerNext: (triggerNext: boolean) => void;
|
||||||
|
setContinueLast: (continueLast: boolean) => void;
|
||||||
|
|
||||||
setMessages: (messages: IMessage[]) => void;
|
setMessages: (messages: IMessage[]) => void;
|
||||||
addMessage: (content: string, role: IMessage['role'], triggerNext?: boolean) => void;
|
addMessage: (content: string, role: IMessage['role'], triggerNext?: boolean) => void;
|
||||||
editMessage: (index: number, content: string) => void;
|
editMessage: (index: number, content: string) => void;
|
||||||
|
|
@ -44,7 +48,7 @@ interface IActions {
|
||||||
setCurrentSwipe: (index: number, swipe: number) => void;
|
setCurrentSwipe: (index: number, swipe: number) => void;
|
||||||
addSwipe: (index: number, content: string) => void;
|
addSwipe: (index: number, content: string) => void;
|
||||||
|
|
||||||
continueMessage: () => void;
|
continueMessage: (continueLast?: boolean) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
const SAVE_KEY = 'ai_game_save_state';
|
const SAVE_KEY = 'ai_game_save_state';
|
||||||
|
|
@ -88,11 +92,13 @@ Make sure to follow the world description and rules exactly. Avoid cliffhangers
|
||||||
bannedWords: [],
|
bannedWords: [],
|
||||||
messages: [],
|
messages: [],
|
||||||
triggerNext: false,
|
triggerNext: false,
|
||||||
|
continueLast: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
export const saveContext = (context: IContext) => {
|
export const saveContext = (context: IContext) => {
|
||||||
const contextToSave: Partial<IContext> = { ...context };
|
const contextToSave: Partial<IContext> = { ...context };
|
||||||
delete contextToSave.triggerNext;
|
delete contextToSave.triggerNext;
|
||||||
|
delete contextToSave.continueLast;
|
||||||
|
|
||||||
localStorage.setItem(SAVE_KEY, JSON.stringify(contextToSave));
|
localStorage.setItem(SAVE_KEY, JSON.stringify(contextToSave));
|
||||||
}
|
}
|
||||||
|
|
@ -130,6 +136,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
||||||
const connection = availableConnections[currentConnection] ?? DEFAULT_CONTEXT.availableConnections[0];
|
const connection = availableConnections[currentConnection] ?? DEFAULT_CONTEXT.availableConnections[0];
|
||||||
|
|
||||||
const [triggerNext, setTriggerNext] = useState(false);
|
const [triggerNext, setTriggerNext] = useState(false);
|
||||||
|
const [continueLast, setContinueLast] = useState(false);
|
||||||
const [instruct, setInstruct] = useInputState(connection.instruct);
|
const [instruct, setInstruct] = useInputState(connection.instruct);
|
||||||
|
|
||||||
const setConnection = useCallback((c: IConnection) => {
|
const setConnection = useCallback((c: IConnection) => {
|
||||||
|
|
@ -153,8 +160,11 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
||||||
setUserPrompt,
|
setUserPrompt,
|
||||||
setSummarizePrompt,
|
setSummarizePrompt,
|
||||||
setLore,
|
setLore,
|
||||||
setTriggerNext,
|
|
||||||
setSummaryEnabled,
|
setSummaryEnabled,
|
||||||
|
|
||||||
|
setTriggerNext,
|
||||||
|
setContinueLast,
|
||||||
|
|
||||||
setBannedWords: (words) => setBannedWords(words.slice()),
|
setBannedWords: (words) => setBannedWords(words.slice()),
|
||||||
setAvailableConnections: (connections) => setAvailableConnections(connections.slice()),
|
setAvailableConnections: (connections) => setAvailableConnections(connections.slice()),
|
||||||
|
|
||||||
|
|
@ -224,7 +234,10 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
continueMessage: () => setTriggerNext(true),
|
continueMessage: (c = false) => {
|
||||||
|
setTriggerNext(true);
|
||||||
|
setContinueLast(c);
|
||||||
|
},
|
||||||
}), []);
|
}), []);
|
||||||
|
|
||||||
const rawContext: IContext & IComputableContext = {
|
const rawContext: IContext & IComputableContext = {
|
||||||
|
|
@ -239,7 +252,9 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
||||||
summaryEnabled,
|
summaryEnabled,
|
||||||
bannedWords,
|
bannedWords,
|
||||||
messages,
|
messages,
|
||||||
|
//
|
||||||
triggerNext,
|
triggerNext,
|
||||||
|
continueLast,
|
||||||
};
|
};
|
||||||
|
|
||||||
const context = useMemo(() => rawContext, Object.values(rawContext));
|
const context = useMemo(() => rawContext, Object.values(rawContext));
|
||||||
|
|
|
||||||
|
|
@ -249,7 +249,6 @@ export namespace Huggingface {
|
||||||
|
|
||||||
if (config.bos_token) {
|
if (config.bos_token) {
|
||||||
template = template
|
template = template
|
||||||
.replaceAll(config.bos_token, '')
|
|
||||||
.replace(/\{\{ ?(''|"") ?\}\}/g, '');
|
.replace(/\{\{ ?(''|"") ?\}\}/g, '');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue