1
0
Fork 0

Working tool calls

This commit is contained in:
Pabloader 2026-03-21 22:10:31 +00:00
parent f4144b70c7
commit 7f3e628954
4 changed files with 186 additions and 44 deletions

View File

@ -1,7 +1,7 @@
import { Sidebar } from "./sidebar"; import { Sidebar } from "./sidebar";
import { useAppState } from "../contexts/state"; import { useAppState, type ChatMessage } from "../contexts/state";
import styles from '../assets/chat-sidebar.module.css'; import styles from '../assets/chat-sidebar.module.css';
import { useState, useRef, useEffect } from "preact/hooks"; import { useState, useRef, useEffect, useMemo, useCallback } from "preact/hooks";
import LLM from "../utils/llm"; import LLM from "../utils/llm";
import { highlight } from "../utils/highlight"; import { highlight } from "../utils/highlight";
import Prompt from "../utils/prompt"; import Prompt from "../utils/prompt";
@ -14,7 +14,7 @@ export const ChatSidebar = () => {
const [isLoading, setIsLoading] = useState(false); const [isLoading, setIsLoading] = useState(false);
const [error, setError] = useState<string | null>(null); const [error, setError] = useState<string | null>(null);
const messagesRef = useRef<HTMLDivElement>(null); const messagesRef = useRef<HTMLDivElement>(null);
const abortControllerRef = useRef<AbortController | null>(null); const abortControllerRef = useRef<AbortController>(new AbortController());
useEffect(() => { useEffect(() => {
if (messagesRef.current) { if (messagesRef.current) {
@ -31,20 +31,16 @@ export const ChatSidebar = () => {
}; };
}, []); }, []);
const sendMessage = async () => { const sendMessage = useCallback(async (newMessages: ChatMessage[]) => {
if (!currentStory || !input.trim() || !connection || !model || isLoading) return; if (!currentStory || !connection || !model) return;
const userMessage = {
id: crypto.randomUUID(),
role: 'user' as const,
content: input.trim(),
};
for (const message of newMessages) {
dispatch({ dispatch({
type: 'ADD_CHAT_MESSAGE', type: 'ADD_CHAT_MESSAGE',
storyId: currentStory.id, storyId: currentStory.id,
message: userMessage, message,
}); });
}
const assistantMessageId = crypto.randomUUID(); const assistantMessageId = crypto.randomUUID();
dispatch({ dispatch({
@ -57,11 +53,7 @@ export const ChatSidebar = () => {
}, },
}); });
setInput(''); const request = Prompt.compilePrompt(appState, newMessages);
setIsLoading(true);
setError(null);
const request = Prompt.compilePrompt(appState, userMessage);
if (!request) { if (!request) {
setError('Failed to compile prompt'); setError('Failed to compile prompt');
@ -71,11 +63,18 @@ export const ChatSidebar = () => {
try { try {
let accumulatedContent = ''; let accumulatedContent = '';
let tool_calls: LLM.ToolCall[] | undefined;
for await (const chunk of LLM.generateStream(connection, request)) { for await (const chunk of LLM.generateStream(connection, request)) {
const delta = chunk.choices[0]?.delta?.content; const delta = chunk.choices[0]?.delta;
if (delta) {
accumulatedContent += delta; if (delta?.tool_calls) {
tool_calls = delta.tool_calls;
}
const content = delta?.content;
if (content) {
accumulatedContent += content;
dispatch({ dispatch({
type: 'ADD_CHAT_MESSAGE', type: 'ADD_CHAT_MESSAGE',
storyId: currentStory.id, storyId: currentStory.id,
@ -83,6 +82,7 @@ export const ChatSidebar = () => {
id: assistantMessageId, id: assistantMessageId,
role: 'assistant', role: 'assistant',
content: accumulatedContent, content: accumulatedContent,
tool_calls,
}, },
}); });
} }
@ -90,22 +90,73 @@ export const ChatSidebar = () => {
break; break;
} }
} }
const assistantMessage: ChatMessage = {
id: assistantMessageId,
role: 'assistant',
content: accumulatedContent,
tool_calls,
};
dispatch({
type: 'ADD_CHAT_MESSAGE',
storyId: currentStory.id,
message: assistantMessage,
});
if (tool_calls) {
const toolMessages: ChatMessage[] = [];
for (const tool of tool_calls) {
if (tool.function.name === 'test') {
const message: ChatMessage = {
id: crypto.randomUUID(),
role: 'tool',
content: `Test successful, received: ${JSON.stringify(tool.function.arguments)}`,
tool_call_id: tool.id,
};
dispatch({
type: 'ADD_CHAT_MESSAGE',
storyId: currentStory.id,
message,
});
toolMessages.push(message);
}
}
return sendMessage([...newMessages, assistantMessage, ...toolMessages]);
}
} catch (err) { } catch (err) {
const errorMessage = err instanceof Error const errorMessage = err instanceof Error
? err.message ? err.message
: 'Failed to generate response'; : 'Failed to generate response';
setError(errorMessage); setError(errorMessage);
}
}, [appState, currentStory, connection, model]);
const handleSendMessage = useCallback(async () => {
if (!currentStory || !input.trim() || !connection || !model || isLoading) return;
const userMessage = {
id: crypto.randomUUID(),
role: 'user' as const,
content: input.trim(),
};
setInput('');
setIsLoading(true);
setError(null);
try {
await sendMessage([userMessage]);
} finally { } finally {
setIsLoading(false); setIsLoading(false);
abortControllerRef.current = null; abortControllerRef.current = new AbortController();
} }
}; }, [currentStory, input, connection, model, isLoading]);
const handleKeyDown = (e: KeyboardEvent) => { const handleKeyDown = (e: KeyboardEvent) => {
if (e.key === 'Enter' && !e.shiftKey) { if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault(); e.preventDefault();
sendMessage(); handleSendMessage();
} }
}; };
@ -169,7 +220,7 @@ export const ChatSidebar = () => {
/> />
<button <button
class={styles.sendButton} class={styles.sendButton}
onClick={sendMessage} onClick={handleSendMessage}
disabled={isDisabled || !input.trim()} disabled={isDisabled || !input.trim()}
> >
{isLoading ? 'Sending...' : 'Send'} {isLoading ? 'Sending...' : 'Send'}

View File

@ -6,10 +6,8 @@ import { useStoredReducer } from "@common/hooks/useStoredState";
// ─── Types ──────────────────────────────────────────────────────────────────── // ─── Types ────────────────────────────────────────────────────────────────────
export interface ChatMessage { export type ChatMessage = LLM.ChatMessage & {
id: string; id: string;
role: 'user' | 'assistant' | 'system';
content: string;
} }
export interface Story { export interface Story {

View File

@ -1,20 +1,95 @@
import { formatError } from '@common/errors'; import { formatError } from '@common/errors';
import Lock from '@common/lock'; import SSE from '@common/sse';
import SSE, { type SSEEvent } from '@common/sse';
namespace LLM { namespace LLM {
export interface Connection { export interface Connection {
url: string; url: string;
apiKey: string; apiKey: string;
} }
export interface ChatMessage {
role: 'system' | 'user' | 'assistant'; export interface ToolCall {
id: string;
type: 'function';
function: {
name: string;
arguments: string | Record<string, any>; // JSON string
};
}
interface ChatMessageUser {
role: 'user';
content: string; content: string;
} }
interface ChatMessageAssistant {
role: 'assistant';
content: string;
tool_calls?: ToolCall[];
}
interface ChatMessageSystem {
role: 'system';
content: string;
}
interface ChatMessageTool {
role: 'tool';
content: string;
tool_call_id?: string;
}
export type ChatMessage = ChatMessageUser | ChatMessageAssistant | ChatMessageSystem | ChatMessageTool;
export interface ToolStringParameter {
type: 'string';
enum?: string[];
description?: string;
}
export interface ToolNumberParameter {
type: 'number' | 'integer';
enum?: number[];
description?: string;
}
export interface ToolBooleanParameter {
type: 'boolean';
enum?: boolean[];
description?: string;
}
export interface ToolArrayParameter {
type: 'array';
description?: string;
items: ToolParameter;
}
export interface ToolObjectParameter {
type: 'object';
description?: string;
properties: Record<string, ToolParameter>;
required?: string[];
}
export type ToolParameter = ToolStringParameter | ToolNumberParameter | ToolBooleanParameter | ToolArrayParameter | ToolObjectParameter;
export interface Tool {
type: 'function';
function: {
name: string;
description?: string;
parameters: {
type: 'object';
properties: Record<string, ToolParameter>;
required?: string[];
};
};
}
export interface ChatCompletionRequest { export interface ChatCompletionRequest {
model: string; model: string;
messages: ChatMessage[]; messages: ChatMessage[];
tools?: Tool[];
temperature?: number; temperature?: number;
max_tokens?: number; max_tokens?: number;
stop?: string | string[]; stop?: string | string[];
@ -27,7 +102,7 @@ namespace LLM {
export interface ChatCompletionChoice { export interface ChatCompletionChoice {
index: number; index: number;
message: ChatMessage; message: ChatMessage;
finish_reason: 'stop' | 'length' | 'content_filter'; finish_reason: 'stop' | 'tool_calls';
} }
export interface ChatCompletionResponse { export interface ChatCompletionResponse {
@ -46,8 +121,8 @@ namespace LLM {
export interface ChatCompletionChunkChoice { export interface ChatCompletionChunkChoice {
index: number; index: number;
delta: { role?: string; content?: string }; delta: { role?: string; content?: string; tool_calls?: ToolCall[] };
finish_reason: 'stop' | 'length' | 'content_filter' | null; finish_reason: 'stop' | 'tool_calls' | null;
} }
export interface ChatCompletionChunk { export interface ChatCompletionChunk {
@ -63,6 +138,7 @@ namespace LLM {
object: 'model'; object: 'model';
created: number; created: number;
owned_by: string; owned_by: string;
support_tools: boolean;
max_context?: number; max_context?: number;
max_length?: number; max_length?: number;
} }
@ -134,8 +210,6 @@ namespace LLM {
if (closed) return; if (closed) return;
closed = true; closed = true;
controller.close(); controller.close();
console.log(formatError(e));
}; };
sse.addEventListener('error', handleEnd); sse.addEventListener('error', handleEnd);

View File

@ -2,7 +2,27 @@ import LLM from "./llm";
import type { AppState } from "../contexts/state"; import type { AppState } from "../contexts/state";
namespace Prompt { namespace Prompt {
export function compilePrompt(state: AppState, newMessage?: LLM.ChatMessage): LLM.ChatCompletionRequest | null { const tools: LLM.Tool[] = [
{
type: 'function',
function: {
name: 'test',
description: 'A simple test function',
parameters: {
type: 'object',
properties: {
message: {
type: 'string',
description: 'The test message',
},
},
required: ['message'],
},
},
},
];
export function compilePrompt(state: AppState, newMessages: LLM.ChatMessage[] = []): LLM.ChatCompletionRequest | null {
const { currentStory, model } = state; const { currentStory, model } = state;
if (!currentStory || !model) { if (!currentStory || !model) {
@ -15,13 +35,12 @@ namespace Prompt {
...currentStory.chatMessages, ...currentStory.chatMessages,
]; ];
if (newMessage) { messages.push(...newMessages);
messages.push(newMessage);
}
return { return {
model, model,
messages, messages,
tools,
// TODO banned_tokens // TODO banned_tokens
}; };
} }