Working tool calls
This commit is contained in:
parent
f4144b70c7
commit
7f3e628954
|
|
@ -1,7 +1,7 @@
|
|||
import { Sidebar } from "./sidebar";
|
||||
import { useAppState } from "../contexts/state";
|
||||
import { useAppState, type ChatMessage } from "../contexts/state";
|
||||
import styles from '../assets/chat-sidebar.module.css';
|
||||
import { useState, useRef, useEffect } from "preact/hooks";
|
||||
import { useState, useRef, useEffect, useMemo, useCallback } from "preact/hooks";
|
||||
import LLM from "../utils/llm";
|
||||
import { highlight } from "../utils/highlight";
|
||||
import Prompt from "../utils/prompt";
|
||||
|
|
@ -14,7 +14,7 @@ export const ChatSidebar = () => {
|
|||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const messagesRef = useRef<HTMLDivElement>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
const abortControllerRef = useRef<AbortController>(new AbortController());
|
||||
|
||||
useEffect(() => {
|
||||
if (messagesRef.current) {
|
||||
|
|
@ -31,20 +31,16 @@ export const ChatSidebar = () => {
|
|||
};
|
||||
}, []);
|
||||
|
||||
const sendMessage = async () => {
|
||||
if (!currentStory || !input.trim() || !connection || !model || isLoading) return;
|
||||
const sendMessage = useCallback(async (newMessages: ChatMessage[]) => {
|
||||
if (!currentStory || !connection || !model) return;
|
||||
|
||||
const userMessage = {
|
||||
id: crypto.randomUUID(),
|
||||
role: 'user' as const,
|
||||
content: input.trim(),
|
||||
};
|
||||
|
||||
dispatch({
|
||||
type: 'ADD_CHAT_MESSAGE',
|
||||
storyId: currentStory.id,
|
||||
message: userMessage,
|
||||
});
|
||||
for (const message of newMessages) {
|
||||
dispatch({
|
||||
type: 'ADD_CHAT_MESSAGE',
|
||||
storyId: currentStory.id,
|
||||
message,
|
||||
});
|
||||
}
|
||||
|
||||
const assistantMessageId = crypto.randomUUID();
|
||||
dispatch({
|
||||
|
|
@ -57,11 +53,7 @@ export const ChatSidebar = () => {
|
|||
},
|
||||
});
|
||||
|
||||
setInput('');
|
||||
setIsLoading(true);
|
||||
setError(null);
|
||||
|
||||
const request = Prompt.compilePrompt(appState, userMessage);
|
||||
const request = Prompt.compilePrompt(appState, newMessages);
|
||||
|
||||
if (!request) {
|
||||
setError('Failed to compile prompt');
|
||||
|
|
@ -71,11 +63,18 @@ export const ChatSidebar = () => {
|
|||
|
||||
try {
|
||||
let accumulatedContent = '';
|
||||
let tool_calls: LLM.ToolCall[] | undefined;
|
||||
|
||||
for await (const chunk of LLM.generateStream(connection, request)) {
|
||||
const delta = chunk.choices[0]?.delta?.content;
|
||||
if (delta) {
|
||||
accumulatedContent += delta;
|
||||
const delta = chunk.choices[0]?.delta;
|
||||
|
||||
if (delta?.tool_calls) {
|
||||
tool_calls = delta.tool_calls;
|
||||
}
|
||||
|
||||
const content = delta?.content;
|
||||
if (content) {
|
||||
accumulatedContent += content;
|
||||
dispatch({
|
||||
type: 'ADD_CHAT_MESSAGE',
|
||||
storyId: currentStory.id,
|
||||
|
|
@ -83,6 +82,7 @@ export const ChatSidebar = () => {
|
|||
id: assistantMessageId,
|
||||
role: 'assistant',
|
||||
content: accumulatedContent,
|
||||
tool_calls,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
@ -90,22 +90,73 @@ export const ChatSidebar = () => {
|
|||
break;
|
||||
}
|
||||
}
|
||||
const assistantMessage: ChatMessage = {
|
||||
id: assistantMessageId,
|
||||
role: 'assistant',
|
||||
content: accumulatedContent,
|
||||
tool_calls,
|
||||
};
|
||||
dispatch({
|
||||
type: 'ADD_CHAT_MESSAGE',
|
||||
storyId: currentStory.id,
|
||||
message: assistantMessage,
|
||||
});
|
||||
|
||||
if (tool_calls) {
|
||||
const toolMessages: ChatMessage[] = [];
|
||||
for (const tool of tool_calls) {
|
||||
if (tool.function.name === 'test') {
|
||||
const message: ChatMessage = {
|
||||
id: crypto.randomUUID(),
|
||||
role: 'tool',
|
||||
content: `Test successful, received: ${JSON.stringify(tool.function.arguments)}`,
|
||||
tool_call_id: tool.id,
|
||||
};
|
||||
dispatch({
|
||||
type: 'ADD_CHAT_MESSAGE',
|
||||
storyId: currentStory.id,
|
||||
message,
|
||||
});
|
||||
toolMessages.push(message);
|
||||
}
|
||||
}
|
||||
|
||||
return sendMessage([...newMessages, assistantMessage, ...toolMessages]);
|
||||
}
|
||||
} catch (err) {
|
||||
const errorMessage = err instanceof Error
|
||||
? err.message
|
||||
: 'Failed to generate response';
|
||||
|
||||
setError(errorMessage);
|
||||
}
|
||||
}, [appState, currentStory, connection, model]);
|
||||
|
||||
const handleSendMessage = useCallback(async () => {
|
||||
if (!currentStory || !input.trim() || !connection || !model || isLoading) return;
|
||||
|
||||
const userMessage = {
|
||||
id: crypto.randomUUID(),
|
||||
role: 'user' as const,
|
||||
content: input.trim(),
|
||||
};
|
||||
|
||||
setInput('');
|
||||
setIsLoading(true);
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
await sendMessage([userMessage]);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
abortControllerRef.current = null;
|
||||
abortControllerRef.current = new AbortController();
|
||||
}
|
||||
};
|
||||
}, [currentStory, input, connection, model, isLoading]);
|
||||
|
||||
const handleKeyDown = (e: KeyboardEvent) => {
|
||||
if (e.key === 'Enter' && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
sendMessage();
|
||||
handleSendMessage();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -169,7 +220,7 @@ export const ChatSidebar = () => {
|
|||
/>
|
||||
<button
|
||||
class={styles.sendButton}
|
||||
onClick={sendMessage}
|
||||
onClick={handleSendMessage}
|
||||
disabled={isDisabled || !input.trim()}
|
||||
>
|
||||
{isLoading ? 'Sending...' : 'Send'}
|
||||
|
|
|
|||
|
|
@ -6,10 +6,8 @@ import { useStoredReducer } from "@common/hooks/useStoredState";
|
|||
|
||||
// ─── Types ────────────────────────────────────────────────────────────────────
|
||||
|
||||
export interface ChatMessage {
|
||||
export type ChatMessage = LLM.ChatMessage & {
|
||||
id: string;
|
||||
role: 'user' | 'assistant' | 'system';
|
||||
content: string;
|
||||
}
|
||||
|
||||
export interface Story {
|
||||
|
|
|
|||
|
|
@ -1,20 +1,95 @@
|
|||
import { formatError } from '@common/errors';
|
||||
import Lock from '@common/lock';
|
||||
import SSE, { type SSEEvent } from '@common/sse';
|
||||
import SSE from '@common/sse';
|
||||
|
||||
namespace LLM {
|
||||
export interface Connection {
|
||||
url: string;
|
||||
apiKey: string;
|
||||
}
|
||||
export interface ChatMessage {
|
||||
role: 'system' | 'user' | 'assistant';
|
||||
|
||||
export interface ToolCall {
|
||||
id: string;
|
||||
type: 'function';
|
||||
function: {
|
||||
name: string;
|
||||
arguments: string | Record<string, any>; // JSON string
|
||||
};
|
||||
}
|
||||
|
||||
interface ChatMessageUser {
|
||||
role: 'user';
|
||||
content: string;
|
||||
}
|
||||
|
||||
interface ChatMessageAssistant {
|
||||
role: 'assistant';
|
||||
content: string;
|
||||
tool_calls?: ToolCall[];
|
||||
}
|
||||
|
||||
interface ChatMessageSystem {
|
||||
role: 'system';
|
||||
content: string;
|
||||
}
|
||||
|
||||
interface ChatMessageTool {
|
||||
role: 'tool';
|
||||
content: string;
|
||||
tool_call_id?: string;
|
||||
}
|
||||
|
||||
export type ChatMessage = ChatMessageUser | ChatMessageAssistant | ChatMessageSystem | ChatMessageTool;
|
||||
|
||||
export interface ToolStringParameter {
|
||||
type: 'string';
|
||||
enum?: string[];
|
||||
description?: string;
|
||||
}
|
||||
|
||||
export interface ToolNumberParameter {
|
||||
type: 'number' | 'integer';
|
||||
enum?: number[];
|
||||
description?: string;
|
||||
}
|
||||
|
||||
export interface ToolBooleanParameter {
|
||||
type: 'boolean';
|
||||
enum?: boolean[];
|
||||
description?: string;
|
||||
}
|
||||
|
||||
export interface ToolArrayParameter {
|
||||
type: 'array';
|
||||
description?: string;
|
||||
items: ToolParameter;
|
||||
}
|
||||
|
||||
export interface ToolObjectParameter {
|
||||
type: 'object';
|
||||
description?: string;
|
||||
properties: Record<string, ToolParameter>;
|
||||
required?: string[];
|
||||
}
|
||||
|
||||
export type ToolParameter = ToolStringParameter | ToolNumberParameter | ToolBooleanParameter | ToolArrayParameter | ToolObjectParameter;
|
||||
|
||||
export interface Tool {
|
||||
type: 'function';
|
||||
function: {
|
||||
name: string;
|
||||
description?: string;
|
||||
parameters: {
|
||||
type: 'object';
|
||||
properties: Record<string, ToolParameter>;
|
||||
required?: string[];
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
export interface ChatCompletionRequest {
|
||||
model: string;
|
||||
messages: ChatMessage[];
|
||||
tools?: Tool[];
|
||||
temperature?: number;
|
||||
max_tokens?: number;
|
||||
stop?: string | string[];
|
||||
|
|
@ -27,7 +102,7 @@ namespace LLM {
|
|||
export interface ChatCompletionChoice {
|
||||
index: number;
|
||||
message: ChatMessage;
|
||||
finish_reason: 'stop' | 'length' | 'content_filter';
|
||||
finish_reason: 'stop' | 'tool_calls';
|
||||
}
|
||||
|
||||
export interface ChatCompletionResponse {
|
||||
|
|
@ -46,8 +121,8 @@ namespace LLM {
|
|||
|
||||
export interface ChatCompletionChunkChoice {
|
||||
index: number;
|
||||
delta: { role?: string; content?: string };
|
||||
finish_reason: 'stop' | 'length' | 'content_filter' | null;
|
||||
delta: { role?: string; content?: string; tool_calls?: ToolCall[] };
|
||||
finish_reason: 'stop' | 'tool_calls' | null;
|
||||
}
|
||||
|
||||
export interface ChatCompletionChunk {
|
||||
|
|
@ -63,6 +138,7 @@ namespace LLM {
|
|||
object: 'model';
|
||||
created: number;
|
||||
owned_by: string;
|
||||
support_tools: boolean;
|
||||
max_context?: number;
|
||||
max_length?: number;
|
||||
}
|
||||
|
|
@ -134,8 +210,6 @@ namespace LLM {
|
|||
if (closed) return;
|
||||
closed = true;
|
||||
controller.close();
|
||||
|
||||
console.log(formatError(e));
|
||||
};
|
||||
|
||||
sse.addEventListener('error', handleEnd);
|
||||
|
|
|
|||
|
|
@ -2,7 +2,27 @@ import LLM from "./llm";
|
|||
import type { AppState } from "../contexts/state";
|
||||
|
||||
namespace Prompt {
|
||||
export function compilePrompt(state: AppState, newMessage?: LLM.ChatMessage): LLM.ChatCompletionRequest | null {
|
||||
const tools: LLM.Tool[] = [
|
||||
{
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'test',
|
||||
description: 'A simple test function',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
message: {
|
||||
type: 'string',
|
||||
description: 'The test message',
|
||||
},
|
||||
},
|
||||
required: ['message'],
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
export function compilePrompt(state: AppState, newMessages: LLM.ChatMessage[] = []): LLM.ChatCompletionRequest | null {
|
||||
const { currentStory, model } = state;
|
||||
|
||||
if (!currentStory || !model) {
|
||||
|
|
@ -15,13 +35,12 @@ namespace Prompt {
|
|||
...currentStory.chatMessages,
|
||||
];
|
||||
|
||||
if (newMessage) {
|
||||
messages.push(newMessage);
|
||||
}
|
||||
messages.push(...newMessages);
|
||||
|
||||
return {
|
||||
model,
|
||||
messages,
|
||||
tools,
|
||||
// TODO banned_tokens
|
||||
};
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue