diff --git a/src/games/storywriter/components/chat-sidebar.tsx b/src/games/storywriter/components/chat-sidebar.tsx index 54835fe..0eed2c0 100644 --- a/src/games/storywriter/components/chat-sidebar.tsx +++ b/src/games/storywriter/components/chat-sidebar.tsx @@ -1,7 +1,7 @@ import { Sidebar } from "./sidebar"; -import { useAppState } from "../contexts/state"; +import { useAppState, type ChatMessage } from "../contexts/state"; import styles from '../assets/chat-sidebar.module.css'; -import { useState, useRef, useEffect } from "preact/hooks"; +import { useState, useRef, useEffect, useMemo, useCallback } from "preact/hooks"; import LLM from "../utils/llm"; import { highlight } from "../utils/highlight"; import Prompt from "../utils/prompt"; @@ -14,7 +14,7 @@ export const ChatSidebar = () => { const [isLoading, setIsLoading] = useState(false); const [error, setError] = useState(null); const messagesRef = useRef(null); - const abortControllerRef = useRef(null); + const abortControllerRef = useRef(new AbortController()); useEffect(() => { if (messagesRef.current) { @@ -31,20 +31,16 @@ export const ChatSidebar = () => { }; }, []); - const sendMessage = async () => { - if (!currentStory || !input.trim() || !connection || !model || isLoading) return; + const sendMessage = useCallback(async (newMessages: ChatMessage[]) => { + if (!currentStory || !connection || !model) return; - const userMessage = { - id: crypto.randomUUID(), - role: 'user' as const, - content: input.trim(), - }; - - dispatch({ - type: 'ADD_CHAT_MESSAGE', - storyId: currentStory.id, - message: userMessage, - }); + for (const message of newMessages) { + dispatch({ + type: 'ADD_CHAT_MESSAGE', + storyId: currentStory.id, + message, + }); + } const assistantMessageId = crypto.randomUUID(); dispatch({ @@ -57,11 +53,7 @@ export const ChatSidebar = () => { }, }); - setInput(''); - setIsLoading(true); - setError(null); - - const request = Prompt.compilePrompt(appState, userMessage); + const request = Prompt.compilePrompt(appState, newMessages); if (!request) { setError('Failed to compile prompt'); @@ -71,11 +63,18 @@ export const ChatSidebar = () => { try { let accumulatedContent = ''; + let tool_calls: LLM.ToolCall[] | undefined; for await (const chunk of LLM.generateStream(connection, request)) { - const delta = chunk.choices[0]?.delta?.content; - if (delta) { - accumulatedContent += delta; + const delta = chunk.choices[0]?.delta; + + if (delta?.tool_calls) { + tool_calls = delta.tool_calls; + } + + const content = delta?.content; + if (content) { + accumulatedContent += content; dispatch({ type: 'ADD_CHAT_MESSAGE', storyId: currentStory.id, @@ -83,6 +82,7 @@ export const ChatSidebar = () => { id: assistantMessageId, role: 'assistant', content: accumulatedContent, + tool_calls, }, }); } @@ -90,22 +90,73 @@ export const ChatSidebar = () => { break; } } + const assistantMessage: ChatMessage = { + id: assistantMessageId, + role: 'assistant', + content: accumulatedContent, + tool_calls, + }; + dispatch({ + type: 'ADD_CHAT_MESSAGE', + storyId: currentStory.id, + message: assistantMessage, + }); + + if (tool_calls) { + const toolMessages: ChatMessage[] = []; + for (const tool of tool_calls) { + if (tool.function.name === 'test') { + const message: ChatMessage = { + id: crypto.randomUUID(), + role: 'tool', + content: `Test successful, received: ${JSON.stringify(tool.function.arguments)}`, + tool_call_id: tool.id, + }; + dispatch({ + type: 'ADD_CHAT_MESSAGE', + storyId: currentStory.id, + message, + }); + toolMessages.push(message); + } + } + + return sendMessage([...newMessages, assistantMessage, ...toolMessages]); + } } catch (err) { const errorMessage = err instanceof Error ? err.message : 'Failed to generate response'; setError(errorMessage); + } + }, [appState, currentStory, connection, model]); + + const handleSendMessage = useCallback(async () => { + if (!currentStory || !input.trim() || !connection || !model || isLoading) return; + + const userMessage = { + id: crypto.randomUUID(), + role: 'user' as const, + content: input.trim(), + }; + + setInput(''); + setIsLoading(true); + setError(null); + + try { + await sendMessage([userMessage]); } finally { setIsLoading(false); - abortControllerRef.current = null; + abortControllerRef.current = new AbortController(); } - }; + }, [currentStory, input, connection, model, isLoading]); const handleKeyDown = (e: KeyboardEvent) => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); - sendMessage(); + handleSendMessage(); } }; @@ -169,7 +220,7 @@ export const ChatSidebar = () => { />