diff --git a/src/games/storywriter/components/chat-sidebar.tsx b/src/games/storywriter/components/chat-sidebar.tsx index 0eed2c0..b651623 100644 --- a/src/games/storywriter/components/chat-sidebar.tsx +++ b/src/games/storywriter/components/chat-sidebar.tsx @@ -5,6 +5,7 @@ import { useState, useRef, useEffect, useMemo, useCallback } from "preact/hooks" import LLM from "../utils/llm"; import { highlight } from "../utils/highlight"; import Prompt from "../utils/prompt"; +import { Tools } from "../utils/tools"; import clsx from "clsx"; export const ChatSidebar = () => { @@ -105,20 +106,19 @@ export const ChatSidebar = () => { if (tool_calls) { const toolMessages: ChatMessage[] = []; for (const tool of tool_calls) { - if (tool.function.name === 'test') { - const message: ChatMessage = { - id: crypto.randomUUID(), - role: 'tool', - content: `Test successful, received: ${JSON.stringify(tool.function.arguments)}`, - tool_call_id: tool.id, - }; - dispatch({ - type: 'ADD_CHAT_MESSAGE', - storyId: currentStory.id, - message, - }); - toolMessages.push(message); - } + const content = await Tools.executeTool(appState, tool); + const message: ChatMessage = { + id: crypto.randomUUID(), + role: 'tool', + content, + tool_call_id: tool.id, + }; + dispatch({ + type: 'ADD_CHAT_MESSAGE', + storyId: currentStory.id, + message, + }); + toolMessages.push(message); } return sendMessage([...newMessages, assistantMessage, ...toolMessages]); diff --git a/src/games/storywriter/utils/prompt.ts b/src/games/storywriter/utils/prompt.ts index d53f64e..9ca8922 100644 --- a/src/games/storywriter/utils/prompt.ts +++ b/src/games/storywriter/utils/prompt.ts @@ -1,27 +1,8 @@ import LLM from "./llm"; import type { AppState } from "../contexts/state"; +import { Tools } from "./tools"; namespace Prompt { - const tools: LLM.Tool[] = [ - { - type: 'function', - function: { - name: 'test', - description: 'A simple test function', - parameters: { - type: 'object', - properties: { - message: { - type: 'string', - description: 'The test message', - }, - }, - required: ['message'], - }, - }, - }, - ]; - export function compilePrompt(state: AppState, newMessages: LLM.ChatMessage[] = []): LLM.ChatCompletionRequest | null { const { currentStory, model } = state; @@ -40,7 +21,7 @@ namespace Prompt { return { model, messages, - tools, + tools: Tools.getTools(), // TODO banned_tokens }; } diff --git a/src/games/storywriter/utils/tools.ts b/src/games/storywriter/utils/tools.ts new file mode 100644 index 0000000..b75b645 --- /dev/null +++ b/src/games/storywriter/utils/tools.ts @@ -0,0 +1,65 @@ +import { formatError } from "@common/errors"; +import type { AppState } from "../contexts/state"; +import LLM from "./llm"; + +export namespace Tools { + interface Tool { + description: string; + parameters: LLM.ToolObjectParameter; + handler(args: string | Record, appState: AppState): unknown; + } + + const TOOLS: Record = { + 'test': { + handler: async (args) => ( + `Test successful, received: ${JSON.stringify(args)}` + ), + description: 'A simple test function', + parameters: { + type: 'object', + properties: { + message: { + type: 'string', + description: 'The test message', + }, + }, + required: ['message'], + }, + } + }; + + export function getTools(): LLM.Tool[] { + return Object.entries(TOOLS).map(([key, tool]) => { + return { + type: 'function', + function: { + name: key, + description: tool.description, + parameters: tool.parameters, + }, + }; + }); + } + + export async function executeTool(appState: AppState, toolCall: LLM.ToolCall): Promise { + const { function: fn } = toolCall; + let args = fn.arguments; + try { + if (typeof fn.arguments === 'string') { + args = JSON.parse(fn.arguments); + } + } catch { } + + const handler = TOOLS[fn.name]?.handler; + if (!handler) { + return `Unknown tool: ${fn.name}`; + } + + try { + const result = await handler(args, appState); + return JSON.stringify(result); + } catch (err) { + return formatError(err, 'Error executing tool'); + } + } +}