From 9df2993b389bd32f509df9f569481cc4935f8d06 Mon Sep 17 00:00:00 2001 From: Pabloader Date: Tue, 29 Oct 2024 18:19:21 +0000 Subject: [PATCH] Working generation --- bun.lockb | Bin 29555 -> 29916 bytes package.json | 1 + src/common/lock.ts | 22 +++ src/common/sse.ts | 270 ++++++++++++++++++++++++++++ src/games/ai/components/chat.tsx | 26 ++- src/games/ai/components/input.tsx | 13 +- src/games/ai/components/message.tsx | 26 +++ src/games/ai/context.tsx | 102 +++++++++-- src/games/ai/generation.ts | 62 +++++++ src/games/ai/messages.ts | 28 +++ src/games/ai/style.css | 35 +++- 11 files changed, 562 insertions(+), 23 deletions(-) create mode 100644 src/common/lock.ts create mode 100644 src/common/sse.ts create mode 100644 src/games/ai/components/message.tsx create mode 100644 src/games/ai/generation.ts create mode 100644 src/games/ai/messages.ts diff --git a/bun.lockb b/bun.lockb index a12c8c06747e325dfc5e87447328772d6257ed40..1fd6ee67d367e95abf892cf6798d7afc6987c22a 100755 GIT binary patch delta 4622 zcmeHKdr*|u760xBD2uG}hL07&7l_EhF3Tb!`zgMHD=H`~3sSSLM3#d1qKHP_(R6CY z`pR*tNL3`s#2Aek(NvviV$)2T(Z(8+4r`K0ZJkVBVt&OP_sd+t4Z?{~iWSbXrQ*k#W4_l$~OfA)UF)Y84vTkbk`*E>49kKB3pnZ?I; zz5ATE^5r>eB(2Qd?MmPDi;5s4NlSdoJ*#|e(u!4$9xqh7(Gev{{%yb*;7%b)!+=|W z(ZDgXBpHD1;D-Wx!4Cm01s@Bn3CUjx3+kN&vc0PCAVvYBH$m`cG?Ib+KsIn6ij0>d zB*_Rg4GIo;E+jt~GWZmb7qS`13+xE>?-?AFH-e9Yyb}&Et^o3a9}79=25#aB7D37# zRj7D{%iD6jp4E*~EEwjS8lj0&9@3if=@4W9!+^tqpFvvzUIDV=X<#bwAdq*8fV@MKfcW`OV84@qsc=(U zz-q6j#j~`r7jl^N?*JwMHv?I}1PkLM^R`j5oN1Vs7-T!BOHR-3fhAt`t`Pb|h%uV+ z7q9|2?EjK($R9=Si1B>9KJDe+zSqL~g?N^_B2@7{1yCoGGDsB%$&LCZb)n9s0O}@E z2CL#Yxlw;kU8r*?fcgzmB2|$@?nu?J2%#_;b3~9QvPkTs0AzoItOl|$@(eCA%tvrI zz(q1Ac2Ji=H3T3lgG|;&jU^>Y6&`X&sS$?|LX)T=szd}R5TzOvgveCgmZ&555LNV0 z7wQ`n7@`_8BPB^iKl_aNqH<_gf`OE1Rh%Jrv}*Vk{XA2+HhUul&~u0sqiT2`TVvPO zaCp!}5k;v{4ucD4G66C{BjL>!QiiJH0=b8(vOI+L4|NzEI81g|b4@%!$}rXN8e~>Y zb|bRL@HMzcz)6~kDhdo!BX;6w-SpwGl88^xTCBBZ;aI&|>nK`2trd@dtw(FEL91SC zT|p~2!i=+S&}7@ta#2IfTD+aSj|!AWk*pBD#2$g16V;12-|a=0Qc` zFm=WAF$dyR*&a`+2@bJ=+zG1qHFYJZa!dkUPjDEPBM9>`CW7FT@&-63INUW4mD98` zF-wj~r2UBw(LjMjRdkb*q>A^*outYUNwh!7Ar~dl^(2RQiIikjd`<3TRkkJ5{$z*v z2?ZcJMM{b)Zj(EuM8Q^Yyldq_jG*%=>6N&rB>V%k@*rYUrBNVUBhpw9>t%wtFtR+G z8I%W6P=%6=f0`uaaD}2Foj0Xh*ha1<`+h`RO14lxqoCtTGA>Cis5Ci(OIt~i#z_(* zZeFQ4#6O5E=l?q5Pnr?x{|```|9cJKFR9&fycKgo6Z{b5ZKw$8Gx9>`h4=@N<&_e} zrNw!8$ik2cBTwl9vAh<SM`*wBhq`C)`(n z&tr~DOd!hy7oN~3QJd)o>O3mWb%}i1joLz=qPEghvr80E7it@Qg}RW+^IUX0&q&Af zoT7;CfIE_Jq=tN_D5mH0T~uo^lF{N6CFHiaDAH=A^WaL!V1-xUTCGl@(phj#1@NlC zDW;OI0AAVP6}V}n*x(ho4K}BkL2rRuTL`ZTonjVs7Q!n#yaHEF*>-pZZoAznoD=}J zwa7@NMNU!C_hiwJWqLFsfp$fN8`_t(_#W$QUftLyCVh_Wz()%nK>wfcFeL!_KdyiK zcw2nt=GDRf{fWNB$xpNWTSio~`2BMk_p; zb*9V>@@wZ&W<|{mBRIZX_yVd1;d?~8x8?(R2y3%BzJw=%azSQL9taPgWC2;}VnzH& zyq?K#$OO6vU6&&&W952^x{gJy#YLHtOK z0*waogOmjt2O1Ab0ZjlM30c7S3-UEb`4( zvvnfJqeI!(uLwdP%KP#J{uENZYFvlb-49&a2y=dAhV1R;>AEi_KiTnJ&VLHls<; zkDjT1=4jWYIxlOPi+RNdY1jfqXfvAhtm;CQ?{)d!wO!D%nJv8Tll0Moq!9>NJ%g*h z_QjZWZ82`B*s+~Z(VP@o^<#X7;yDj3S6{DtKf7p;>((}GL@yqvT{`Jrb^cP|S zVPBgl9F2a;U8tD!eDGMC5goL$R4lNzeD*wyo~}_WGXS z5n5RJHuY5(*!BGGpA0#dEVFG+uAA(^y^urlKakrBj>u$4fLo+`T~p6^Uw`f?_cX!d8KJUp@Y(0N}--E z7Chx$ckw_E{B4hYIHCN!#N&C(I@h4NV>1^O+k2_irHB*sGgpC}8b`0VQWE}ob7I7p zIQqn86*Xk4EiviYXWWv%?5TZp)^Sgr8omRSNZ-^uT8Cw_{HnMGWb> zde-_);MJBl17Y(9dTx`^t(f#&{pJ_vQeVrsdlf6rH`_U=g0s9p_0akf_rC!#{Pd-L zorjv=@yTN2L^F;dA6h%zp>T^D$KRDGpVZT3cNxXh<%?r9wJu%UAx~ZU$S^!ncqsf~ zr^f7ybv8}v@T_jN;)$C2$Bxp+Z~D$0q{DUTsn0_z8d{5QUFvC_*58}LF9FK*HFUkM yC$_15*)reqWsiBhjd{(!<;|YH&id}bG^o*(G38(S@^XVZxrivUr7yjbI-ZwKF&RN&-b3I0rAhz#M2h#+t;7DHSpDw2cCQX+M>&6zB2oAzgE8So!*%z zjdWg&PgAdVN?NJv_hf9_yEMp1(!>78S~~q}rFEU{Ev*|QsSh17lGL*am;l@=BxwS0 zJunWKXpkf$uoZkf@M-X|zH-z7|Jd66)|Rez z>3E1=-5xxFd(lC?FGF|`%B+6_JflJ;&}=f{x&=kc1Z}Lp$Ju+r2y&(QX*CH z6!{}nE;<}lX1Ja}qlqrDfs{$A@hYwbpSa*NG@277 zsYKh~`bej!rNAT=o>r1n!;M5*ljIV4G?b)@^%Q{U?Maj~*=0<}RVu`o2)0Sy0L~3= z1K*u-rHxKZ&N956M5B{kB94?~RV*Pt>c?p)Sv9;inMRXchR-HbPKrx3lRrfjgEW+) z8m=ePC`7YJF{`4L{HPDokhx63sqoF#%6P=3kYI*Z7zKT6&e+B*wx*4k4#a<6nhC;n zm1cpsFtR+G8I){@1g%N)%_Lx@2Bg;!cxPa0Etye>B{qX{=Yn;Ll=EI?h%c~t29IloBk652ASQ;9@ z$UChL@#B%@wIJ573t@c-mjSsjvd0=gtmgr7VdU{%5X*fb{vrK{!E0Ax{Ke5_KVBfd zNIzbn@4P_6N7LJkv@y#~tFugWF3T;_NzV3=Iom{Av)v+t&VoAwE?;qr>9k4l(9aYT zT?03RlpGIb=a^_mjvEo}GPq0NDstT-i~4ds)SGLfTi_I$ljot5JQMBDbBkOGhqTwAkVig*1rTO5dQiQC)#Y*l7rL5s5;NaFDmqLyd(dI#%cwPBL0O z6m2z8ht)0S&{1%2fJ?WzMH%^R9$IaKRp2T}w!N@u~H0heFo7L~NA2v!xr zDsb~jalk4Eta7-;19TbOC2$qRZn2R1ieXhTtOB=~<~U)M6IMCh!c75ipF2&%RVB>^ zI$jzll7`Qfejyb7ZPSlW7Gib}J3@cwv_9&$tmz6~&sD>pls}#j^KmGeq*c*CRrx~x z8%y{98!X)BsKh})|F7#jPn!(l!c(x$qAKPjX-hRtt!~KY_`xw|IcNol<5LZY)mf9L z;y=qAP#z6dr#N^Ke4F{j%TFqP(Q47B3CK~2BN&I<2S5*kc)O)^t2%29M{_5Lw`2z4 zwZ|_%ehu;qQJM2pKLRxMY&u=D$H5cuq%%M~DKE?$;0;X!@nZbB3QXM3N6p91ZKhiOHlUnjAECsQ zH`$6;ZX{=2I@Q)~iZTSnG*Z_~o%PA0kNWEEVgVhims9mrV@>VfFUfnk!G-0{y(DRh!UEa8^Uv790BZ@58pCo<%E_TyO z>~(x`N@pq#>}Ho zD8}}nFgHB(_!rMzdGPBO)}hCN9;+m2%cSaQ*UdKHifiS^7ofi+IPxO3U>k4J<`r_P zo`hX&JT$WS?h7{zqT32%VE%S`1zPR&E3cfU=Smf^dp^iXdhZzKvIpn-GMetMFdObg zlexi$w8PyXi&$!Js7lrIu=7P-m+t;$*KLg9nX$ivbgsdU7%DunArMPx9&@Umes#aG zvFl99M>b8_X0A7G2>Z}Owj)*^xDa2AnqwLUg z#rD;gNBchg>xX0g)``Jv(*B#)pP$%2u+<=DLW0=SQ$Zb#a%w2+>-mHE^wp7(`=7#G z8~d(ET1A76c8i`RcJ(c~|Js3>TeRcQe5=RVwa;9b(A~RqPlEVhA{o3kBzX>RS*o5l zCO-V{U5!6kc-SEHN3iBdp*fY_@=kH+Ipn$iw~B20ZjX*>>RIOg?X7oby>bW_OP`~L zOg9 jaP#dy4Rp-E^rFN89OX(U{iA7M_@#&TMGbFlYfktd4~-ir diff --git a/package.json b/package.json index 4b73220..90f7aa3 100644 --- a/package.json +++ b/package.json @@ -8,6 +8,7 @@ "bake": "bun build/build.ts" }, "dependencies": { + "@huggingface/jinja": "0.3.1", "@inquirer/select": "2.3.10", "classnames": "2.5.1", "preact": "10.22.0" diff --git a/src/common/lock.ts b/src/common/lock.ts new file mode 100644 index 0000000..fa64e58 --- /dev/null +++ b/src/common/lock.ts @@ -0,0 +1,22 @@ +export default class Lock { + private doUnlock: (() => void) | null = null; + private lockPromise: Promise | null = null; + + get locked() { + return this.lockPromise !== null; + } + + async wait(): Promise { + if (!this.lockPromise) { + this.lockPromise = new Promise(resolve => this.doUnlock = resolve); + } + + return this.lockPromise; + } + + release(): void { + this.doUnlock?.(); + this.doUnlock = null; + this.lockPromise = null; + } +} \ No newline at end of file diff --git a/src/common/sse.ts b/src/common/sse.ts new file mode 100644 index 0000000..bf59b51 --- /dev/null +++ b/src/common/sse.ts @@ -0,0 +1,270 @@ +export interface ISSEOptions { + headers?: Record; + payload?: string; + method?: string; + withCredentials?: boolean; + debug?: boolean; + start?: boolean; +} + +interface SSEEvent extends Event { + id?: string; + source?: SSE; + readyState?: number; + data?: string; +} + +interface SSEEventListener { + (e: SSEEvent): void; +} + +type OnEvent = `on${string}`; +type EventType = 'message' | 'error' | 'readystatechange' | 'abort' | 'open'; + +const FIELD_SEPARATOR = ':'; + +export default class SSE { + public static INITIALIZING = -1; + public static CONNECTING = 0; + public static OPEN = 1; + public static CLOSED = 2; + + private headers: Record; + private payload: string; + private method: string; + private withCredentials: boolean; + private debug: boolean; + private listeners: Record = {}; + private xhr: XMLHttpRequest | null = null; + private readyState: number = SSE.INITIALIZING; + private progress = 0; + private chunk = ''; + + [key: OnEvent]: SSEEventListener | undefined; + + constructor(private url: string, options: ISSEOptions = {}) { + this.headers = options.headers || {}; + this.payload = options.payload !== undefined ? options.payload : ''; + this.method = options.method || (this.payload ? 'POST' : 'GET'); + this.withCredentials = !!options.withCredentials; + this.debug = !!options.debug; + + + if (options.start === undefined || options.start) { + this.stream(); + } + } + + addEventListener(type: EventType, listener: SSEEventListener) { + if (this.listeners[type] === undefined) { + this.listeners[type] = []; + } + + if (this.listeners[type].indexOf(listener) === -1) { + this.listeners[type].push(listener); + } + } + + removeEventListener(type: EventType, listener: SSEEventListener) { + if (this.listeners[type] === undefined) { + return; + } + + const filtered: SSEEventListener[] = []; + this.listeners[type].forEach((element) => { + if (element !== listener) { + filtered.push(element); + } + }); + if (filtered.length === 0) { + delete this.listeners[type]; + } else { + this.listeners[type] = filtered; + } + } + + dispatchEvent(e: SSEEvent | null) { + if (!e) { + return true; + } + + if (this.debug) { + console.debug(e); + } + + e.source = this; + + const onHandler: OnEvent = `on${e.type}`; + if (this.hasOwnProperty(onHandler) && this[onHandler]) { + this[onHandler].call(this, e); + if (e.defaultPrevented) { + return false; + } + } + + if (this.listeners[e.type]) { + return this.listeners[e.type].every((callback) => { + callback(e); + return !e.defaultPrevented; + }); + } + + return true; + } + + private _setReadyState(state: number) { + const event: SSEEvent = new CustomEvent('readystatechange'); + event.readyState = state; + this.readyState = state; + this.dispatchEvent(event); + } + + private _onStreamFailure = (e: ProgressEvent) => { + const event: SSEEvent = new CustomEvent('error'); + if (e.currentTarget instanceof XMLHttpRequest) { + event.data = e.currentTarget.response; + } + this.dispatchEvent(event); + this.close(); + } + + private _onStreamAbort = () => { + this.dispatchEvent(new CustomEvent('abort')); + this.close(); + } + + private _onStreamProgress = (e: ProgressEvent) => { + if (!this.xhr) { + return; + } + + if (this.xhr.status !== 200) { + this._onStreamFailure(e); + return; + } + + if (this.readyState === SSE.CONNECTING) { + this.dispatchEvent(new CustomEvent('open')); + this._setReadyState(SSE.OPEN); + } + + const data = this.xhr.responseText.substring(this.progress); + + this.progress += data.length; + const parts = (this.chunk + data).split(/(\r\n|\r|\n){2}/g); + + // we assume that the last chunk can be incomplete because of buffering or other network effects + // so we always save the last part to merge it with the next incoming packet + const lastPart = parts.pop(); + parts.forEach((part) => { + if (part.trim().length > 0) { + this.dispatchEvent(this._parseEventChunk(part)); + } + }); + this.chunk = lastPart ?? ''; + } + + private _onStreamLoaded = (e: ProgressEvent) => { + this._onStreamProgress(e); + + // Parse the last chunk. + this.dispatchEvent(this._parseEventChunk(this.chunk)); + this.chunk = ''; + } + + /** + * Parse a received SSE event chunk into a constructed event object. + * + * Reference: https://html.spec.whatwg.org/multipage/server-sent-events.html#dispatchMessage + */ + private _parseEventChunk(chunk: string) { + if (!chunk || chunk.length === 0) { + return null; + } + + if (this.debug) { + console.debug(chunk); + } + + const e: Record = { 'id': '', 'retry': '', 'data': '', 'event': 'message' }; + chunk.split(/\n|\r\n|\r/).forEach((line) => { + line = line.trimEnd(); + const index = line.indexOf(FIELD_SEPARATOR); + if (index <= 0) { + // Line was either empty, or started with a separator and is a comment. + // Either way, ignore. + return; + } + + const field = line.substring(0, index); + if (!(field in e)) { + return; + } + + // only first whitespace should be trimmed + const skip = (line[index + 1] === ' ') ? 2 : 1; + const value = line.substring(index + skip); + + // consecutive 'data' is concatenated with newlines + if (field === 'data' && e[field] !== null) { + e['data'] += "\n" + value; + } else { + e[field] = value; + } + }); + + const event: SSEEvent = new CustomEvent(e.event); + event.data = e.data || ''; + event.id = e.id; + return event; + }; + + private _checkStreamClosed = () => { + if (!this.xhr) { + return; + } + + if (this.xhr.readyState === XMLHttpRequest.DONE) { + this._setReadyState(SSE.CLOSED); + } + }; + + /** + * starts the streaming + */ + stream() { + if (this.xhr) { + // Already connected. + return; + } + + this._setReadyState(SSE.CONNECTING); + + this.xhr = new XMLHttpRequest(); + this.xhr.addEventListener('progress', this._onStreamProgress); + this.xhr.addEventListener('load', this._onStreamLoaded); + this.xhr.addEventListener('readystatechange', this._checkStreamClosed); + this.xhr.addEventListener('error', this._onStreamFailure); + this.xhr.addEventListener('abort', this._onStreamAbort); + this.xhr.open(this.method, this.url); + for (const header in this.headers) { + this.xhr.setRequestHeader(header, this.headers[header]); + } + this.xhr.withCredentials = this.withCredentials; + this.xhr.send(this.payload); + }; + + /** + * closes the stream + * @type Close + */ + close() { + if (this.readyState === SSE.CLOSED) { + return; + } + + this.xhr?.abort(); + this.xhr = null; + this._setReadyState(SSE.CLOSED); + }; +} \ No newline at end of file diff --git a/src/games/ai/components/chat.tsx b/src/games/ai/components/chat.tsx index a63b983..3e644e3 100644 --- a/src/games/ai/components/chat.tsx +++ b/src/games/ai/components/chat.tsx @@ -1,7 +1,27 @@ +import { useContext, useEffect, useRef } from "preact/hooks"; +import { GlobalContext } from "../context"; +import { Message } from "./message"; + export const Chat = () => { + const { messages } = useContext(GlobalContext); + const chatRef = useRef(null); + + const lastMessageContent = messages.at(-1)?.displayContent ?? messages.at(-1)?.content; + + useEffect(() => { + if (chatRef.current) { + chatRef.current.scrollTo({ + top: chatRef.current.scrollHeight, + behavior: 'smooth', + }); + } + }, [messages.length, lastMessageContent]); + return ( -
- Chat +
+ {messages.map((m, i) => ( + + ))}
); -} \ No newline at end of file +} diff --git a/src/games/ai/components/input.tsx b/src/games/ai/components/input.tsx index e95f53f..7a6e210 100644 --- a/src/games/ai/components/input.tsx +++ b/src/games/ai/components/input.tsx @@ -2,7 +2,9 @@ import { useCallback, useContext } from "preact/hooks"; import { GlobalContext } from "../context"; export const Input = () => { - const { input, setInput } = useContext(GlobalContext); + const { input, setInput, addMessage } = useContext(GlobalContext); + + console.log({input}); const handleChange = useCallback((e: Event) => { if (e.target instanceof HTMLTextAreaElement) { @@ -10,9 +12,12 @@ export const Input = () => { } }, []); - const handleSend = useCallback(() => { - console.log('Send:', input); - setInput(''); + const handleSend = useCallback(async () => { + const newInput = input.trim(); + if (newInput) { + addMessage(newInput, 'user', true); + setInput(''); + } }, [input]); const handleKeyDown = useCallback((e: KeyboardEvent) => { diff --git a/src/games/ai/components/message.tsx b/src/games/ai/components/message.tsx new file mode 100644 index 0000000..9407e61 --- /dev/null +++ b/src/games/ai/components/message.tsx @@ -0,0 +1,26 @@ +import { useContext, useMemo } from "preact/hooks"; +import type { IMessage } from "../messages"; +import { GlobalContext } from "../context"; + +interface IProps { + message: IMessage; +} + +export const Message = ({ message }: IProps) => { + const { name } = useContext(GlobalContext); + + return ( +
+
+
+ {message.role === 'user' ? name : '---'} +
+
+
+
+
+ {message.displayContent ?? message.content} +
+
+ ); +}; \ No newline at end of file diff --git a/src/games/ai/context.tsx b/src/games/ai/context.tsx index 52dfe3b..21984f8 100644 --- a/src/games/ai/context.tsx +++ b/src/games/ai/context.tsx @@ -1,32 +1,114 @@ import { createContext } from "preact"; -import { useMemo, useState } from "preact/hooks"; +import { useEffect, useMemo, useState } from "preact/hooks"; +import { compilePrompt, type IMessage } from "./messages"; +import { generate } from "./generation"; -export interface ISettings { +export interface IContext { connectionUrl: string; input: string; + name: string; + messages: IMessage[]; } export interface IActions { setConnectionUrl: (url: string) => void; setInput: (url: string) => void; + setName: (name: string) => void; + + setMessages: (messages: IMessage[]) => void; + addMessage: (content: string, role: IMessage['role'], triggerNext?: boolean) => void; + editMessage: (index: number, content: string) => void; + deleteMessage: (index: number) => void; } -export type IGlobalContext = ISettings & IActions; +export type IGlobalContext = IContext & IActions; + +const SAVE_KEY = 'ai_game_save_state'; + +const saveContext = (ctx: IContext) => { + localStorage.setItem(SAVE_KEY, JSON.stringify(ctx)); +} + +const loadContext = (): IContext => { + const defaultContext: IContext = { + connectionUrl: 'http://192.168.10.102:5001', + input: '', + name: 'Maya', + messages: [], + }; + + let loadedContext: Partial = {}; + + try { + const json = localStorage.getItem(SAVE_KEY); + if (json) { + loadedContext = JSON.parse(json); + } + } catch { } + + return { ...defaultContext, ...loadedContext }; +} export const GlobalContext = createContext({} as IGlobalContext); export const GlobalContextProvider = ({ children }: { children?: any }) => { - const [settings, setSettings] = useState({ - connectionUrl: 'http://192.168.10.102:5001', - input: '', - }); + const loadedContext = useMemo(() => loadContext(), []); + const [connectionUrl, setConnectionUrl] = useState(loadedContext.connectionUrl); + const [input, setInput] = useState(loadedContext.input); + const [name, setName] = useState(loadedContext.name); + const [messages, setMessages] = useState(loadedContext.messages); + const [triggerNext, setTriggerNext] = useState(false); const actions: IActions = useMemo(() => ({ - setConnectionUrl: (connectionUrl) => setSettings(s => ({ ...s, connectionUrl })), - setInput: (input) => setSettings(s => ({ ...s, input })), + setConnectionUrl, + setInput, + setName, + + setMessages: (newMessages) => setMessages(newMessages.slice()), + addMessage: (content, role, triggerNext = false) => { + setMessages(messages => [ + ...messages, + { role, content } + ]); + setTriggerNext(triggerNext); + }, + editMessage: (index, content) => setMessages(messages => ( + messages.map((m, i) => ({ ...m, content: i === index ? content : m.content })) + )), + deleteMessage: (index) => setMessages(messages => + messages.filter((_, i) => i !== index) + ), }), []); - const value = useMemo(() => ({ ...settings, ...actions }), [settings, actions]) + useEffect(() => void (async () => { + if (triggerNext) { + setTriggerNext(false); + + const prompt = await compilePrompt(messages); + const messageId = messages.length; + let text = ''; + actions.addMessage('', 'assistant'); + for await (const chunk of generate(connectionUrl, prompt, { temperature: 1.0 })) { + text += chunk; + actions.editMessage(messageId, text); + } + } + })(), [triggerNext, messages]); + + const rawContext: IContext = { + connectionUrl, + input, + name, + messages, + }; + + const context = useMemo(() => rawContext, Object.values(rawContext)); + + useEffect(() => { + saveContext(context); + }, [context]); + + const value = useMemo(() => ({ ...context, ...actions }), [context, actions]) return ( diff --git a/src/games/ai/generation.ts b/src/games/ai/generation.ts new file mode 100644 index 0000000..41d3112 --- /dev/null +++ b/src/games/ai/generation.ts @@ -0,0 +1,62 @@ +import Lock from "@common/lock"; +import SSE from "@common/sse"; + +interface IGenerationSettings { + temperature: number; + // TODO +} + +export async function* generate(host: string, prompt: string, generationSetings: IGenerationSettings) { + const sse = new SSE(`${host}/api/extra/generate/stream`, { + payload: JSON.stringify({ + ...generationSetings, + prompt, + stop_sequence: ['\n'], + }), + }); + + const messages: string[] = []; + const messageLock = new Lock(); + let end = false; + + sse.addEventListener('message', (e) => { + if (e.data) { + { + const { token, finish_reason } = JSON.parse(e.data); + messages.push(token); + + if (finish_reason && finish_reason !== 'null') { + end = true; + } + } + } + messageLock.release(); + }); + + const handleEnd = () => { + end = true; + messageLock.release(); + }; + + sse.addEventListener('error', handleEnd); + sse.addEventListener('abort', handleEnd); + sse.addEventListener('readystatechange', (e) => { + if (e.readyState === SSE.CLOSED) handleEnd(); + }); + + while (!end || messages.length) { + while (messages.length > 0) { + const message = messages.shift(); + if (message != null) { + try { + yield message; + } catch { } + } + } + if (!end) { + await messageLock.wait(); + } + } + + sse.close(); +} \ No newline at end of file diff --git a/src/games/ai/messages.ts b/src/games/ai/messages.ts new file mode 100644 index 0000000..e94b895 --- /dev/null +++ b/src/games/ai/messages.ts @@ -0,0 +1,28 @@ +import { Template } from "@huggingface/jinja"; +import type { IContext } from "./context"; + +export interface IMessage { + role: 'user' | 'assistant' | 'system'; + content: string; + displayContent?: string; +} + +export const applyChatTemplate = (messages: IMessage[], templateString: string, eosToken = '') => { + const template = new Template(templateString); + + const prompt = template.render({ + messages, + bos_token: '', + eos_token: eosToken, + add_generation_prompt: true, + }); + + return prompt; +} + +export const compilePrompt = async (messages: IMessage[]): Promise => { + // TODO chat template + // TODO tokenize + + return applyChatTemplate(messages, "{% for message in messages %}{{ message['role'] + ': ' + message['content'] + '\\n' }}{% endfor %}") +} \ No newline at end of file diff --git a/src/games/ai/style.css b/src/games/ai/style.css index c7faeaa..7caee46 100644 --- a/src/games/ai/style.css +++ b/src/games/ai/style.css @@ -3,8 +3,10 @@ } :root { - --color: #DCDCD2; --backgroundColor: #333333; + --color: #DCDCD2; + --italicColor: #AFAFAF; + --quoteColor: #D4E5FF; } body { @@ -25,6 +27,7 @@ body { width: 100%; max-width: 1200px; height: 100%; + max-height: 100dvh; >.header { display: flex; @@ -36,13 +39,33 @@ body { >.chat { display: flex; - flex-direction: row; + flex-direction: column; height: 100%; - background-color: gray; flex-grow: 1; - width: 100%; + max-width: 100%; + overflow-x: hidden; + overflow-y: auto; scrollbar-width: thin; scrollbar-color: var(--color) transparent; + border: 1px solid var(--color); + + >.message { + width: 100%; + padding: 12px; + + >.header { + display: flex; + flex-direction: row; + justify-content: space-between; + + >.name { + font-weight: bold; + } + } + >.content { + white-space: pre-wrap; + } + } } >.chat-input { @@ -50,10 +73,10 @@ body { flex-direction: row; height: auto; min-height: 48px; - background-color: green; width: 100%; + border: 1px solid var(--color); - textarea { + >textarea { color: var(--color); background-color: var(--backgroundColor); font-size: 1em;