From a213e0407c8f09b28f14ddd3fbf66a3539632f81 Mon Sep 17 00:00:00 2001 From: Pabloader Date: Thu, 14 Nov 2024 14:24:21 +0000 Subject: [PATCH] AIStory: continue button --- .../ai-story/components/message/message.tsx | 11 +++-- .../ai-story/components/minichat/minichat.tsx | 20 +++++---- src/games/ai-story/connection.ts | 4 +- src/games/ai-story/contexts/llm.tsx | 41 +++++++++++++------ src/games/ai-story/contexts/state.tsx | 23 +++++++++-- src/games/ai-story/huggingface.ts | 1 - 6 files changed, 70 insertions(+), 30 deletions(-) diff --git a/src/games/ai-story/components/message/message.tsx b/src/games/ai-story/components/message/message.tsx index 2240a77..e8fd843 100644 --- a/src/games/ai-story/components/message/message.tsx +++ b/src/games/ai-story/components/message/message.tsx @@ -16,7 +16,7 @@ interface IProps { } export const Message = ({ message, index, isLastUser, isLastAssistant, onNeedScroll }: IProps) => { - const { messages, editMessage, editSummary, deleteMessage, setCurrentSwipe, setMessages } = useContext(StateContext); + const { messages, editMessage, editSummary, deleteMessage, setCurrentSwipe, setMessages, continueMessage } = useContext(StateContext); const [editing, setEditing] = useState(false); const [editedMessage, setEditedMessage] = useInputState(''); const textRef = useRef(null); @@ -70,6 +70,10 @@ export const Message = ({ message, index, isLastUser, isLastAssistant, onNeedScr DOMTools.animate(textRef.current, 'swipe-from-right'); }, [setCurrentSwipe, index, message]); + const handleContinueMessage = useCallback(() => { + continueMessage(true); + }, [continueMessage]); + return (
@@ -89,13 +93,14 @@ export const Message = ({ message, index, isLastUser, isLastAssistant, onNeedScr : <> - {isLastAssistant && + {isLastAssistant && <>
{message.currentSwipe + 1}/{message.swipes.length}
- } + + } } diff --git a/src/games/ai-story/components/minichat/minichat.tsx b/src/games/ai-story/components/minichat/minichat.tsx index 7e5b59f..9e63e9c 100644 --- a/src/games/ai-story/components/minichat/minichat.tsx +++ b/src/games/ai-story/components/minichat/minichat.tsx @@ -7,6 +7,7 @@ import styles from './minichat.module.css'; import { LLMContext } from "../../contexts/llm"; import { FormattedMessage } from "../message/formattedMessage"; import { AutoTextarea } from "../autoTextarea"; +import { useBool } from "@common/hooks/useBool"; interface IProps { open: boolean; @@ -16,9 +17,10 @@ interface IProps { } export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) => { - const { generating, stopGeneration, generate, compilePrompt } = useContext(LLMContext); + const { stopGeneration, generate, compilePrompt } = useContext(LLMContext); const [messages, setMessages] = useState([]); const ref = useRef(null); + const generating = useBool(); const answer = useMemo(() => MessageTools.getSwipe(messages.filter(m => m.role === 'assistant').at(-1))?.content, @@ -33,7 +35,7 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) useEffect(() => { setTimeout(() => DOMTools.scrollDown(ref.current, false), 100); - }, [generating, open]); + }, [generating.value, open]); useEffect(() => { DOMTools.scrollDown(ref.current, false); @@ -47,19 +49,21 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) }, [messages.length, handleInit]); const handleGenerate = useCallback(async () => { - if (messages.length > 0 && !generating) { + if (messages.length > 0 && !generating.value) { const promptMessages: IMessage[] = [...history, ...messages]; - const { prompt } = await compilePrompt(promptMessages, { keepUsers: messages.length + 1 }); + const { prompt } = await compilePrompt(promptMessages, { keepUsers: messages.length + 1, continueLast: true }); let text = ''; const messageId = messages.length; const newMessages = [...messages, MessageTools.create('', 'assistant', true)]; setMessages(newMessages); + generating.setTrue(); for await (const chunk of generate(prompt)) { text += chunk; setMessages(MessageTools.updateSwipe(newMessages, messageId, { content: text.trim() })); } + generating.setFalse(); setMessages([ ...MessageTools.updateSwipe(newMessages, messageId, { content: MessageTools.trimSentence(text) }), @@ -90,7 +94,7 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
{messages.map((m, i) => ( - generating + generating.value ? {MessageTools.getSwipe(m)?.content ?? ''} @@ -105,18 +109,18 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
- {generating + {generating.value ? : } - {Object.entries(buttons).map(([label, onClick], i) => ( diff --git a/src/games/ai-story/connection.ts b/src/games/ai-story/connection.ts index dd111c8..4892e05 100644 --- a/src/games/ai-story/connection.ts +++ b/src/games/ai-story/connection.ts @@ -1,7 +1,7 @@ import Lock from "@common/lock"; import SSE from "@common/sse"; import { throttle } from "@common/utils"; -import delay, { clearDelay } from "delay"; +import delay from "delay"; interface IBaseConnection { instruct: string; @@ -105,7 +105,7 @@ export const normalizeModel = (model: string) => { .trim(); } -export const approximateTokens = (prompt: string): number => prompt.split(/[^a-z0-9]+/i).length; +export const approximateTokens = (prompt: string): number => Math.round(prompt.length / 4); export type IGenerationSettings = Partial; diff --git a/src/games/ai-story/contexts/llm.tsx b/src/games/ai-story/contexts/llm.tsx index e9f85a9..d250b0b 100644 --- a/src/games/ai-story/contexts/llm.tsx +++ b/src/games/ai-story/contexts/llm.tsx @@ -11,7 +11,7 @@ import { useAsyncEffect } from "@common/hooks/useAsyncEffect"; interface ICompileArgs { keepUsers?: number; - raw?: boolean; + continueLast?: boolean; } interface ICompiledPrompt { @@ -48,8 +48,8 @@ const processing = { export const LLMContextProvider = ({ children }: { children?: any }) => { const { - connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled, - setTriggerNext, addMessage, editMessage, editSummary, + connection, messages, triggerNext, continueLast, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled, + setTriggerNext, setContinueLast, addMessage, editMessage, editSummary, } = useContext(StateContext); const generating = useBool(false); @@ -69,13 +69,22 @@ export const LLMContextProvider = ({ children }: { children?: any }) => { }, [userPrompt]); const actions: IActions = useMemo(() => ({ - compilePrompt: async (messages, { keepUsers } = {}) => { - const promptMessages = messages.slice(); - const lastMessage = promptMessages.at(-1); + compilePrompt: async (messages, { keepUsers, continueLast = false } = {}) => { + const lastMessage = messages.at(-1); + const lastMessageContent = MessageTools.getSwipe(lastMessage)?.content; const isAssistantLast = lastMessage?.role === 'assistant'; - const isRegen = isAssistantLast && !MessageTools.getSwipe(lastMessage)?.content; + let isRegen = continueLast; + + if (!isAssistantLast) { + isRegen = false; + } else if (!lastMessageContent) { + isRegen = true; + } + const isContinue = isAssistantLast && !isRegen; + const promptMessages = continueLast ? messages.slice(0, -1) : messages.slice(); + if (isContinue) { promptMessages.push(MessageTools.create(userPromptTemplate.render({}))); } @@ -153,7 +162,12 @@ export const LLMContextProvider = ({ children }: { children?: any }) => { templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`; - const prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages); + let prompt = Huggingface.applyChatTemplate(connection.instruct, templateMessages); + + if (isRegen) { + prompt += lastMessageContent; + } + return { prompt, isContinue, @@ -194,20 +208,23 @@ export const LLMContextProvider = ({ children }: { children?: any }) => { return await Connection.countTokens(connection, prompt); }, stopGeneration: () => { - Connection.stopGeneration(); + Connection.stopGeneration(); }, }), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]); useAsyncEffect(async () => { if (triggerNext && !generating.value) { setTriggerNext(false); + setContinueLast(false); let messageId = messages.length - 1; - let text: string = ''; + let text = ''; - const { prompt, isRegen } = await actions.compilePrompt(messages); + const { prompt, isRegen } = await actions.compilePrompt(messages, { continueLast }); - if (!isRegen) { + if (isRegen) { + text = MessageTools.getSwipe(messages.at(-1))?.content ?? ''; + } else { addMessage('', 'assistant'); messageId++; } diff --git a/src/games/ai-story/contexts/state.tsx b/src/games/ai-story/contexts/state.tsx index 09f1ce7..63f2936 100644 --- a/src/games/ai-story/contexts/state.tsx +++ b/src/games/ai-story/contexts/state.tsx @@ -15,7 +15,9 @@ interface IContext { summaryEnabled: boolean; bannedWords: string[]; messages: IMessage[]; + // triggerNext: boolean; + continueLast: boolean; } interface IComputableContext { @@ -33,9 +35,11 @@ interface IActions { setUserPrompt: (prompt: string | Event) => void; setSummarizePrompt: (prompt: string | Event) => void; setBannedWords: (words: string[]) => void; - setTriggerNext: (triggerNext: boolean) => void; setSummaryEnabled: (summaryEnabled: boolean) => void; + setTriggerNext: (triggerNext: boolean) => void; + setContinueLast: (continueLast: boolean) => void; + setMessages: (messages: IMessage[]) => void; addMessage: (content: string, role: IMessage['role'], triggerNext?: boolean) => void; editMessage: (index: number, content: string) => void; @@ -44,7 +48,7 @@ interface IActions { setCurrentSwipe: (index: number, swipe: number) => void; addSwipe: (index: number, content: string) => void; - continueMessage: () => void; + continueMessage: (continueLast?: boolean) => void; } const SAVE_KEY = 'ai_game_save_state'; @@ -88,11 +92,13 @@ Make sure to follow the world description and rules exactly. Avoid cliffhangers bannedWords: [], messages: [], triggerNext: false, + continueLast: false, }; export const saveContext = (context: IContext) => { const contextToSave: Partial = { ...context }; delete contextToSave.triggerNext; + delete contextToSave.continueLast; localStorage.setItem(SAVE_KEY, JSON.stringify(contextToSave)); } @@ -130,6 +136,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => { const connection = availableConnections[currentConnection] ?? DEFAULT_CONTEXT.availableConnections[0]; const [triggerNext, setTriggerNext] = useState(false); + const [continueLast, setContinueLast] = useState(false); const [instruct, setInstruct] = useInputState(connection.instruct); const setConnection = useCallback((c: IConnection) => { @@ -153,8 +160,11 @@ export const StateContextProvider = ({ children }: { children?: any }) => { setUserPrompt, setSummarizePrompt, setLore, - setTriggerNext, setSummaryEnabled, + + setTriggerNext, + setContinueLast, + setBannedWords: (words) => setBannedWords(words.slice()), setAvailableConnections: (connections) => setAvailableConnections(connections.slice()), @@ -224,7 +234,10 @@ export const StateContextProvider = ({ children }: { children?: any }) => { } ) ), - continueMessage: () => setTriggerNext(true), + continueMessage: (c = false) => { + setTriggerNext(true); + setContinueLast(c); + }, }), []); const rawContext: IContext & IComputableContext = { @@ -239,7 +252,9 @@ export const StateContextProvider = ({ children }: { children?: any }) => { summaryEnabled, bannedWords, messages, + // triggerNext, + continueLast, }; const context = useMemo(() => rawContext, Object.values(rawContext)); diff --git a/src/games/ai-story/huggingface.ts b/src/games/ai-story/huggingface.ts index 630504c..5885045 100644 --- a/src/games/ai-story/huggingface.ts +++ b/src/games/ai-story/huggingface.ts @@ -249,7 +249,6 @@ export namespace Huggingface { if (config.bos_token) { template = template - .replaceAll(config.bos_token, '') .replace(/\{\{ ?(''|"") ?\}\}/g, ''); } }