diff --git a/bun.lockb b/bun.lockb index a12c8c0..1fd6ee6 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/package.json b/package.json index 4b73220..90f7aa3 100644 --- a/package.json +++ b/package.json @@ -8,6 +8,7 @@ "bake": "bun build/build.ts" }, "dependencies": { + "@huggingface/jinja": "0.3.1", "@inquirer/select": "2.3.10", "classnames": "2.5.1", "preact": "10.22.0" diff --git a/src/common/lock.ts b/src/common/lock.ts new file mode 100644 index 0000000..fa64e58 --- /dev/null +++ b/src/common/lock.ts @@ -0,0 +1,22 @@ +export default class Lock { + private doUnlock: (() => void) | null = null; + private lockPromise: Promise | null = null; + + get locked() { + return this.lockPromise !== null; + } + + async wait(): Promise { + if (!this.lockPromise) { + this.lockPromise = new Promise(resolve => this.doUnlock = resolve); + } + + return this.lockPromise; + } + + release(): void { + this.doUnlock?.(); + this.doUnlock = null; + this.lockPromise = null; + } +} \ No newline at end of file diff --git a/src/common/sse.ts b/src/common/sse.ts new file mode 100644 index 0000000..bf59b51 --- /dev/null +++ b/src/common/sse.ts @@ -0,0 +1,270 @@ +export interface ISSEOptions { + headers?: Record; + payload?: string; + method?: string; + withCredentials?: boolean; + debug?: boolean; + start?: boolean; +} + +interface SSEEvent extends Event { + id?: string; + source?: SSE; + readyState?: number; + data?: string; +} + +interface SSEEventListener { + (e: SSEEvent): void; +} + +type OnEvent = `on${string}`; +type EventType = 'message' | 'error' | 'readystatechange' | 'abort' | 'open'; + +const FIELD_SEPARATOR = ':'; + +export default class SSE { + public static INITIALIZING = -1; + public static CONNECTING = 0; + public static OPEN = 1; + public static CLOSED = 2; + + private headers: Record; + private payload: string; + private method: string; + private withCredentials: boolean; + private debug: boolean; + private listeners: Record = {}; + private xhr: XMLHttpRequest | null = null; + private readyState: number = SSE.INITIALIZING; + private progress = 0; + private chunk = ''; + + [key: OnEvent]: SSEEventListener | undefined; + + constructor(private url: string, options: ISSEOptions = {}) { + this.headers = options.headers || {}; + this.payload = options.payload !== undefined ? options.payload : ''; + this.method = options.method || (this.payload ? 'POST' : 'GET'); + this.withCredentials = !!options.withCredentials; + this.debug = !!options.debug; + + + if (options.start === undefined || options.start) { + this.stream(); + } + } + + addEventListener(type: EventType, listener: SSEEventListener) { + if (this.listeners[type] === undefined) { + this.listeners[type] = []; + } + + if (this.listeners[type].indexOf(listener) === -1) { + this.listeners[type].push(listener); + } + } + + removeEventListener(type: EventType, listener: SSEEventListener) { + if (this.listeners[type] === undefined) { + return; + } + + const filtered: SSEEventListener[] = []; + this.listeners[type].forEach((element) => { + if (element !== listener) { + filtered.push(element); + } + }); + if (filtered.length === 0) { + delete this.listeners[type]; + } else { + this.listeners[type] = filtered; + } + } + + dispatchEvent(e: SSEEvent | null) { + if (!e) { + return true; + } + + if (this.debug) { + console.debug(e); + } + + e.source = this; + + const onHandler: OnEvent = `on${e.type}`; + if (this.hasOwnProperty(onHandler) && this[onHandler]) { + this[onHandler].call(this, e); + if (e.defaultPrevented) { + return false; + } + } + + if (this.listeners[e.type]) { + return this.listeners[e.type].every((callback) => { + callback(e); + return !e.defaultPrevented; + }); + } + + return true; + } + + private _setReadyState(state: number) { + const event: SSEEvent = new CustomEvent('readystatechange'); + event.readyState = state; + this.readyState = state; + this.dispatchEvent(event); + } + + private _onStreamFailure = (e: ProgressEvent) => { + const event: SSEEvent = new CustomEvent('error'); + if (e.currentTarget instanceof XMLHttpRequest) { + event.data = e.currentTarget.response; + } + this.dispatchEvent(event); + this.close(); + } + + private _onStreamAbort = () => { + this.dispatchEvent(new CustomEvent('abort')); + this.close(); + } + + private _onStreamProgress = (e: ProgressEvent) => { + if (!this.xhr) { + return; + } + + if (this.xhr.status !== 200) { + this._onStreamFailure(e); + return; + } + + if (this.readyState === SSE.CONNECTING) { + this.dispatchEvent(new CustomEvent('open')); + this._setReadyState(SSE.OPEN); + } + + const data = this.xhr.responseText.substring(this.progress); + + this.progress += data.length; + const parts = (this.chunk + data).split(/(\r\n|\r|\n){2}/g); + + // we assume that the last chunk can be incomplete because of buffering or other network effects + // so we always save the last part to merge it with the next incoming packet + const lastPart = parts.pop(); + parts.forEach((part) => { + if (part.trim().length > 0) { + this.dispatchEvent(this._parseEventChunk(part)); + } + }); + this.chunk = lastPart ?? ''; + } + + private _onStreamLoaded = (e: ProgressEvent) => { + this._onStreamProgress(e); + + // Parse the last chunk. + this.dispatchEvent(this._parseEventChunk(this.chunk)); + this.chunk = ''; + } + + /** + * Parse a received SSE event chunk into a constructed event object. + * + * Reference: https://html.spec.whatwg.org/multipage/server-sent-events.html#dispatchMessage + */ + private _parseEventChunk(chunk: string) { + if (!chunk || chunk.length === 0) { + return null; + } + + if (this.debug) { + console.debug(chunk); + } + + const e: Record = { 'id': '', 'retry': '', 'data': '', 'event': 'message' }; + chunk.split(/\n|\r\n|\r/).forEach((line) => { + line = line.trimEnd(); + const index = line.indexOf(FIELD_SEPARATOR); + if (index <= 0) { + // Line was either empty, or started with a separator and is a comment. + // Either way, ignore. + return; + } + + const field = line.substring(0, index); + if (!(field in e)) { + return; + } + + // only first whitespace should be trimmed + const skip = (line[index + 1] === ' ') ? 2 : 1; + const value = line.substring(index + skip); + + // consecutive 'data' is concatenated with newlines + if (field === 'data' && e[field] !== null) { + e['data'] += "\n" + value; + } else { + e[field] = value; + } + }); + + const event: SSEEvent = new CustomEvent(e.event); + event.data = e.data || ''; + event.id = e.id; + return event; + }; + + private _checkStreamClosed = () => { + if (!this.xhr) { + return; + } + + if (this.xhr.readyState === XMLHttpRequest.DONE) { + this._setReadyState(SSE.CLOSED); + } + }; + + /** + * starts the streaming + */ + stream() { + if (this.xhr) { + // Already connected. + return; + } + + this._setReadyState(SSE.CONNECTING); + + this.xhr = new XMLHttpRequest(); + this.xhr.addEventListener('progress', this._onStreamProgress); + this.xhr.addEventListener('load', this._onStreamLoaded); + this.xhr.addEventListener('readystatechange', this._checkStreamClosed); + this.xhr.addEventListener('error', this._onStreamFailure); + this.xhr.addEventListener('abort', this._onStreamAbort); + this.xhr.open(this.method, this.url); + for (const header in this.headers) { + this.xhr.setRequestHeader(header, this.headers[header]); + } + this.xhr.withCredentials = this.withCredentials; + this.xhr.send(this.payload); + }; + + /** + * closes the stream + * @type Close + */ + close() { + if (this.readyState === SSE.CLOSED) { + return; + } + + this.xhr?.abort(); + this.xhr = null; + this._setReadyState(SSE.CLOSED); + }; +} \ No newline at end of file diff --git a/src/games/ai/components/chat.tsx b/src/games/ai/components/chat.tsx index a63b983..3e644e3 100644 --- a/src/games/ai/components/chat.tsx +++ b/src/games/ai/components/chat.tsx @@ -1,7 +1,27 @@ +import { useContext, useEffect, useRef } from "preact/hooks"; +import { GlobalContext } from "../context"; +import { Message } from "./message"; + export const Chat = () => { + const { messages } = useContext(GlobalContext); + const chatRef = useRef(null); + + const lastMessageContent = messages.at(-1)?.displayContent ?? messages.at(-1)?.content; + + useEffect(() => { + if (chatRef.current) { + chatRef.current.scrollTo({ + top: chatRef.current.scrollHeight, + behavior: 'smooth', + }); + } + }, [messages.length, lastMessageContent]); + return ( -
- Chat +
+ {messages.map((m, i) => ( + + ))}
); -} \ No newline at end of file +} diff --git a/src/games/ai/components/input.tsx b/src/games/ai/components/input.tsx index e95f53f..7a6e210 100644 --- a/src/games/ai/components/input.tsx +++ b/src/games/ai/components/input.tsx @@ -2,7 +2,9 @@ import { useCallback, useContext } from "preact/hooks"; import { GlobalContext } from "../context"; export const Input = () => { - const { input, setInput } = useContext(GlobalContext); + const { input, setInput, addMessage } = useContext(GlobalContext); + + console.log({input}); const handleChange = useCallback((e: Event) => { if (e.target instanceof HTMLTextAreaElement) { @@ -10,9 +12,12 @@ export const Input = () => { } }, []); - const handleSend = useCallback(() => { - console.log('Send:', input); - setInput(''); + const handleSend = useCallback(async () => { + const newInput = input.trim(); + if (newInput) { + addMessage(newInput, 'user', true); + setInput(''); + } }, [input]); const handleKeyDown = useCallback((e: KeyboardEvent) => { diff --git a/src/games/ai/components/message.tsx b/src/games/ai/components/message.tsx new file mode 100644 index 0000000..9407e61 --- /dev/null +++ b/src/games/ai/components/message.tsx @@ -0,0 +1,26 @@ +import { useContext, useMemo } from "preact/hooks"; +import type { IMessage } from "../messages"; +import { GlobalContext } from "../context"; + +interface IProps { + message: IMessage; +} + +export const Message = ({ message }: IProps) => { + const { name } = useContext(GlobalContext); + + return ( +
+
+
+ {message.role === 'user' ? name : '---'} +
+
+
+
+
+ {message.displayContent ?? message.content} +
+
+ ); +}; \ No newline at end of file diff --git a/src/games/ai/context.tsx b/src/games/ai/context.tsx index 52dfe3b..21984f8 100644 --- a/src/games/ai/context.tsx +++ b/src/games/ai/context.tsx @@ -1,32 +1,114 @@ import { createContext } from "preact"; -import { useMemo, useState } from "preact/hooks"; +import { useEffect, useMemo, useState } from "preact/hooks"; +import { compilePrompt, type IMessage } from "./messages"; +import { generate } from "./generation"; -export interface ISettings { +export interface IContext { connectionUrl: string; input: string; + name: string; + messages: IMessage[]; } export interface IActions { setConnectionUrl: (url: string) => void; setInput: (url: string) => void; + setName: (name: string) => void; + + setMessages: (messages: IMessage[]) => void; + addMessage: (content: string, role: IMessage['role'], triggerNext?: boolean) => void; + editMessage: (index: number, content: string) => void; + deleteMessage: (index: number) => void; } -export type IGlobalContext = ISettings & IActions; +export type IGlobalContext = IContext & IActions; + +const SAVE_KEY = 'ai_game_save_state'; + +const saveContext = (ctx: IContext) => { + localStorage.setItem(SAVE_KEY, JSON.stringify(ctx)); +} + +const loadContext = (): IContext => { + const defaultContext: IContext = { + connectionUrl: 'http://192.168.10.102:5001', + input: '', + name: 'Maya', + messages: [], + }; + + let loadedContext: Partial = {}; + + try { + const json = localStorage.getItem(SAVE_KEY); + if (json) { + loadedContext = JSON.parse(json); + } + } catch { } + + return { ...defaultContext, ...loadedContext }; +} export const GlobalContext = createContext({} as IGlobalContext); export const GlobalContextProvider = ({ children }: { children?: any }) => { - const [settings, setSettings] = useState({ - connectionUrl: 'http://192.168.10.102:5001', - input: '', - }); + const loadedContext = useMemo(() => loadContext(), []); + const [connectionUrl, setConnectionUrl] = useState(loadedContext.connectionUrl); + const [input, setInput] = useState(loadedContext.input); + const [name, setName] = useState(loadedContext.name); + const [messages, setMessages] = useState(loadedContext.messages); + const [triggerNext, setTriggerNext] = useState(false); const actions: IActions = useMemo(() => ({ - setConnectionUrl: (connectionUrl) => setSettings(s => ({ ...s, connectionUrl })), - setInput: (input) => setSettings(s => ({ ...s, input })), + setConnectionUrl, + setInput, + setName, + + setMessages: (newMessages) => setMessages(newMessages.slice()), + addMessage: (content, role, triggerNext = false) => { + setMessages(messages => [ + ...messages, + { role, content } + ]); + setTriggerNext(triggerNext); + }, + editMessage: (index, content) => setMessages(messages => ( + messages.map((m, i) => ({ ...m, content: i === index ? content : m.content })) + )), + deleteMessage: (index) => setMessages(messages => + messages.filter((_, i) => i !== index) + ), }), []); - const value = useMemo(() => ({ ...settings, ...actions }), [settings, actions]) + useEffect(() => void (async () => { + if (triggerNext) { + setTriggerNext(false); + + const prompt = await compilePrompt(messages); + const messageId = messages.length; + let text = ''; + actions.addMessage('', 'assistant'); + for await (const chunk of generate(connectionUrl, prompt, { temperature: 1.0 })) { + text += chunk; + actions.editMessage(messageId, text); + } + } + })(), [triggerNext, messages]); + + const rawContext: IContext = { + connectionUrl, + input, + name, + messages, + }; + + const context = useMemo(() => rawContext, Object.values(rawContext)); + + useEffect(() => { + saveContext(context); + }, [context]); + + const value = useMemo(() => ({ ...context, ...actions }), [context, actions]) return ( diff --git a/src/games/ai/generation.ts b/src/games/ai/generation.ts new file mode 100644 index 0000000..41d3112 --- /dev/null +++ b/src/games/ai/generation.ts @@ -0,0 +1,62 @@ +import Lock from "@common/lock"; +import SSE from "@common/sse"; + +interface IGenerationSettings { + temperature: number; + // TODO +} + +export async function* generate(host: string, prompt: string, generationSetings: IGenerationSettings) { + const sse = new SSE(`${host}/api/extra/generate/stream`, { + payload: JSON.stringify({ + ...generationSetings, + prompt, + stop_sequence: ['\n'], + }), + }); + + const messages: string[] = []; + const messageLock = new Lock(); + let end = false; + + sse.addEventListener('message', (e) => { + if (e.data) { + { + const { token, finish_reason } = JSON.parse(e.data); + messages.push(token); + + if (finish_reason && finish_reason !== 'null') { + end = true; + } + } + } + messageLock.release(); + }); + + const handleEnd = () => { + end = true; + messageLock.release(); + }; + + sse.addEventListener('error', handleEnd); + sse.addEventListener('abort', handleEnd); + sse.addEventListener('readystatechange', (e) => { + if (e.readyState === SSE.CLOSED) handleEnd(); + }); + + while (!end || messages.length) { + while (messages.length > 0) { + const message = messages.shift(); + if (message != null) { + try { + yield message; + } catch { } + } + } + if (!end) { + await messageLock.wait(); + } + } + + sse.close(); +} \ No newline at end of file diff --git a/src/games/ai/messages.ts b/src/games/ai/messages.ts new file mode 100644 index 0000000..e94b895 --- /dev/null +++ b/src/games/ai/messages.ts @@ -0,0 +1,28 @@ +import { Template } from "@huggingface/jinja"; +import type { IContext } from "./context"; + +export interface IMessage { + role: 'user' | 'assistant' | 'system'; + content: string; + displayContent?: string; +} + +export const applyChatTemplate = (messages: IMessage[], templateString: string, eosToken = '') => { + const template = new Template(templateString); + + const prompt = template.render({ + messages, + bos_token: '', + eos_token: eosToken, + add_generation_prompt: true, + }); + + return prompt; +} + +export const compilePrompt = async (messages: IMessage[]): Promise => { + // TODO chat template + // TODO tokenize + + return applyChatTemplate(messages, "{% for message in messages %}{{ message['role'] + ': ' + message['content'] + '\\n' }}{% endfor %}") +} \ No newline at end of file diff --git a/src/games/ai/style.css b/src/games/ai/style.css index c7faeaa..7caee46 100644 --- a/src/games/ai/style.css +++ b/src/games/ai/style.css @@ -3,8 +3,10 @@ } :root { - --color: #DCDCD2; --backgroundColor: #333333; + --color: #DCDCD2; + --italicColor: #AFAFAF; + --quoteColor: #D4E5FF; } body { @@ -25,6 +27,7 @@ body { width: 100%; max-width: 1200px; height: 100%; + max-height: 100dvh; >.header { display: flex; @@ -36,13 +39,33 @@ body { >.chat { display: flex; - flex-direction: row; + flex-direction: column; height: 100%; - background-color: gray; flex-grow: 1; - width: 100%; + max-width: 100%; + overflow-x: hidden; + overflow-y: auto; scrollbar-width: thin; scrollbar-color: var(--color) transparent; + border: 1px solid var(--color); + + >.message { + width: 100%; + padding: 12px; + + >.header { + display: flex; + flex-direction: row; + justify-content: space-between; + + >.name { + font-weight: bold; + } + } + >.content { + white-space: pre-wrap; + } + } } >.chat-input { @@ -50,10 +73,10 @@ body { flex-direction: row; height: auto; min-height: 48px; - background-color: green; width: 100%; + border: 1px solid var(--color); - textarea { + >textarea { color: var(--color); background-color: var(--backgroundColor); font-size: 1em;