1
0
Fork 0

Working generation

This commit is contained in:
Pabloader 2024-10-29 18:19:21 +00:00
parent a2faacf130
commit 9df2993b38
11 changed files with 562 additions and 23 deletions

BIN
bun.lockb

Binary file not shown.

View File

@ -8,6 +8,7 @@
"bake": "bun build/build.ts"
},
"dependencies": {
"@huggingface/jinja": "0.3.1",
"@inquirer/select": "2.3.10",
"classnames": "2.5.1",
"preact": "10.22.0"

22
src/common/lock.ts Normal file
View File

@ -0,0 +1,22 @@
export default class Lock {
private doUnlock: (() => void) | null = null;
private lockPromise: Promise<void> | null = null;
get locked() {
return this.lockPromise !== null;
}
async wait(): Promise<void> {
if (!this.lockPromise) {
this.lockPromise = new Promise(resolve => this.doUnlock = resolve);
}
return this.lockPromise;
}
release(): void {
this.doUnlock?.();
this.doUnlock = null;
this.lockPromise = null;
}
}

270
src/common/sse.ts Normal file
View File

@ -0,0 +1,270 @@
export interface ISSEOptions {
headers?: Record<string, string>;
payload?: string;
method?: string;
withCredentials?: boolean;
debug?: boolean;
start?: boolean;
}
interface SSEEvent extends Event {
id?: string;
source?: SSE;
readyState?: number;
data?: string;
}
interface SSEEventListener {
(e: SSEEvent): void;
}
type OnEvent = `on${string}`;
type EventType = 'message' | 'error' | 'readystatechange' | 'abort' | 'open';
const FIELD_SEPARATOR = ':';
export default class SSE {
public static INITIALIZING = -1;
public static CONNECTING = 0;
public static OPEN = 1;
public static CLOSED = 2;
private headers: Record<string, string>;
private payload: string;
private method: string;
private withCredentials: boolean;
private debug: boolean;
private listeners: Record<string, SSEEventListener[]> = {};
private xhr: XMLHttpRequest | null = null;
private readyState: number = SSE.INITIALIZING;
private progress = 0;
private chunk = '';
[key: OnEvent]: SSEEventListener | undefined;
constructor(private url: string, options: ISSEOptions = {}) {
this.headers = options.headers || {};
this.payload = options.payload !== undefined ? options.payload : '';
this.method = options.method || (this.payload ? 'POST' : 'GET');
this.withCredentials = !!options.withCredentials;
this.debug = !!options.debug;
if (options.start === undefined || options.start) {
this.stream();
}
}
addEventListener(type: EventType, listener: SSEEventListener) {
if (this.listeners[type] === undefined) {
this.listeners[type] = [];
}
if (this.listeners[type].indexOf(listener) === -1) {
this.listeners[type].push(listener);
}
}
removeEventListener(type: EventType, listener: SSEEventListener) {
if (this.listeners[type] === undefined) {
return;
}
const filtered: SSEEventListener[] = [];
this.listeners[type].forEach((element) => {
if (element !== listener) {
filtered.push(element);
}
});
if (filtered.length === 0) {
delete this.listeners[type];
} else {
this.listeners[type] = filtered;
}
}
dispatchEvent(e: SSEEvent | null) {
if (!e) {
return true;
}
if (this.debug) {
console.debug(e);
}
e.source = this;
const onHandler: OnEvent = `on${e.type}`;
if (this.hasOwnProperty(onHandler) && this[onHandler]) {
this[onHandler].call(this, e);
if (e.defaultPrevented) {
return false;
}
}
if (this.listeners[e.type]) {
return this.listeners[e.type].every((callback) => {
callback(e);
return !e.defaultPrevented;
});
}
return true;
}
private _setReadyState(state: number) {
const event: SSEEvent = new CustomEvent('readystatechange');
event.readyState = state;
this.readyState = state;
this.dispatchEvent(event);
}
private _onStreamFailure = (e: ProgressEvent<XMLHttpRequestEventTarget>) => {
const event: SSEEvent = new CustomEvent('error');
if (e.currentTarget instanceof XMLHttpRequest) {
event.data = e.currentTarget.response;
}
this.dispatchEvent(event);
this.close();
}
private _onStreamAbort = () => {
this.dispatchEvent(new CustomEvent('abort'));
this.close();
}
private _onStreamProgress = (e: ProgressEvent<XMLHttpRequestEventTarget>) => {
if (!this.xhr) {
return;
}
if (this.xhr.status !== 200) {
this._onStreamFailure(e);
return;
}
if (this.readyState === SSE.CONNECTING) {
this.dispatchEvent(new CustomEvent('open'));
this._setReadyState(SSE.OPEN);
}
const data = this.xhr.responseText.substring(this.progress);
this.progress += data.length;
const parts = (this.chunk + data).split(/(\r\n|\r|\n){2}/g);
// we assume that the last chunk can be incomplete because of buffering or other network effects
// so we always save the last part to merge it with the next incoming packet
const lastPart = parts.pop();
parts.forEach((part) => {
if (part.trim().length > 0) {
this.dispatchEvent(this._parseEventChunk(part));
}
});
this.chunk = lastPart ?? '';
}
private _onStreamLoaded = (e: ProgressEvent<XMLHttpRequestEventTarget>) => {
this._onStreamProgress(e);
// Parse the last chunk.
this.dispatchEvent(this._parseEventChunk(this.chunk));
this.chunk = '';
}
/**
* Parse a received SSE event chunk into a constructed event object.
*
* Reference: https://html.spec.whatwg.org/multipage/server-sent-events.html#dispatchMessage
*/
private _parseEventChunk(chunk: string) {
if (!chunk || chunk.length === 0) {
return null;
}
if (this.debug) {
console.debug(chunk);
}
const e: Record<string, string> = { 'id': '', 'retry': '', 'data': '', 'event': 'message' };
chunk.split(/\n|\r\n|\r/).forEach((line) => {
line = line.trimEnd();
const index = line.indexOf(FIELD_SEPARATOR);
if (index <= 0) {
// Line was either empty, or started with a separator and is a comment.
// Either way, ignore.
return;
}
const field = line.substring(0, index);
if (!(field in e)) {
return;
}
// only first whitespace should be trimmed
const skip = (line[index + 1] === ' ') ? 2 : 1;
const value = line.substring(index + skip);
// consecutive 'data' is concatenated with newlines
if (field === 'data' && e[field] !== null) {
e['data'] += "\n" + value;
} else {
e[field] = value;
}
});
const event: SSEEvent = new CustomEvent(e.event);
event.data = e.data || '';
event.id = e.id;
return event;
};
private _checkStreamClosed = () => {
if (!this.xhr) {
return;
}
if (this.xhr.readyState === XMLHttpRequest.DONE) {
this._setReadyState(SSE.CLOSED);
}
};
/**
* starts the streaming
*/
stream() {
if (this.xhr) {
// Already connected.
return;
}
this._setReadyState(SSE.CONNECTING);
this.xhr = new XMLHttpRequest();
this.xhr.addEventListener('progress', this._onStreamProgress);
this.xhr.addEventListener('load', this._onStreamLoaded);
this.xhr.addEventListener('readystatechange', this._checkStreamClosed);
this.xhr.addEventListener('error', this._onStreamFailure);
this.xhr.addEventListener('abort', this._onStreamAbort);
this.xhr.open(this.method, this.url);
for (const header in this.headers) {
this.xhr.setRequestHeader(header, this.headers[header]);
}
this.xhr.withCredentials = this.withCredentials;
this.xhr.send(this.payload);
};
/**
* closes the stream
* @type Close
*/
close() {
if (this.readyState === SSE.CLOSED) {
return;
}
this.xhr?.abort();
this.xhr = null;
this._setReadyState(SSE.CLOSED);
};
}

View File

@ -1,7 +1,27 @@
import { useContext, useEffect, useRef } from "preact/hooks";
import { GlobalContext } from "../context";
import { Message } from "./message";
export const Chat = () => {
const { messages } = useContext(GlobalContext);
const chatRef = useRef<HTMLDivElement>(null);
const lastMessageContent = messages.at(-1)?.displayContent ?? messages.at(-1)?.content;
useEffect(() => {
if (chatRef.current) {
chatRef.current.scrollTo({
top: chatRef.current.scrollHeight,
behavior: 'smooth',
});
}
}, [messages.length, lastMessageContent]);
return (
<div class="chat">
Chat
<div class="chat" ref={chatRef}>
{messages.map((m, i) => (
<Message message={m} key={i} />
))}
</div>
);
}
}

View File

@ -2,7 +2,9 @@ import { useCallback, useContext } from "preact/hooks";
import { GlobalContext } from "../context";
export const Input = () => {
const { input, setInput } = useContext(GlobalContext);
const { input, setInput, addMessage } = useContext(GlobalContext);
console.log({input});
const handleChange = useCallback((e: Event) => {
if (e.target instanceof HTMLTextAreaElement) {
@ -10,9 +12,12 @@ export const Input = () => {
}
}, []);
const handleSend = useCallback(() => {
console.log('Send:', input);
setInput('');
const handleSend = useCallback(async () => {
const newInput = input.trim();
if (newInput) {
addMessage(newInput, 'user', true);
setInput('');
}
}, [input]);
const handleKeyDown = useCallback((e: KeyboardEvent) => {

View File

@ -0,0 +1,26 @@
import { useContext, useMemo } from "preact/hooks";
import type { IMessage } from "../messages";
import { GlobalContext } from "../context";
interface IProps {
message: IMessage;
}
export const Message = ({ message }: IProps) => {
const { name } = useContext(GlobalContext);
return (
<div class="message">
<div class="header">
<div class="name">
{message.role === 'user' ? name : '---'}
</div>
<div class="buttons">
</div>
</div>
<div class="content">
{message.displayContent ?? message.content}
</div>
</div>
);
};

View File

@ -1,32 +1,114 @@
import { createContext } from "preact";
import { useMemo, useState } from "preact/hooks";
import { useEffect, useMemo, useState } from "preact/hooks";
import { compilePrompt, type IMessage } from "./messages";
import { generate } from "./generation";
export interface ISettings {
export interface IContext {
connectionUrl: string;
input: string;
name: string;
messages: IMessage[];
}
export interface IActions {
setConnectionUrl: (url: string) => void;
setInput: (url: string) => void;
setName: (name: string) => void;
setMessages: (messages: IMessage[]) => void;
addMessage: (content: string, role: IMessage['role'], triggerNext?: boolean) => void;
editMessage: (index: number, content: string) => void;
deleteMessage: (index: number) => void;
}
export type IGlobalContext = ISettings & IActions;
export type IGlobalContext = IContext & IActions;
const SAVE_KEY = 'ai_game_save_state';
const saveContext = (ctx: IContext) => {
localStorage.setItem(SAVE_KEY, JSON.stringify(ctx));
}
const loadContext = (): IContext => {
const defaultContext: IContext = {
connectionUrl: 'http://192.168.10.102:5001',
input: '',
name: 'Maya',
messages: [],
};
let loadedContext: Partial<IContext> = {};
try {
const json = localStorage.getItem(SAVE_KEY);
if (json) {
loadedContext = JSON.parse(json);
}
} catch { }
return { ...defaultContext, ...loadedContext };
}
export const GlobalContext = createContext<IGlobalContext>({} as IGlobalContext);
export const GlobalContextProvider = ({ children }: { children?: any }) => {
const [settings, setSettings] = useState<ISettings>({
connectionUrl: 'http://192.168.10.102:5001',
input: '',
});
const loadedContext = useMemo(() => loadContext(), []);
const [connectionUrl, setConnectionUrl] = useState(loadedContext.connectionUrl);
const [input, setInput] = useState(loadedContext.input);
const [name, setName] = useState(loadedContext.name);
const [messages, setMessages] = useState(loadedContext.messages);
const [triggerNext, setTriggerNext] = useState(false);
const actions: IActions = useMemo(() => ({
setConnectionUrl: (connectionUrl) => setSettings(s => ({ ...s, connectionUrl })),
setInput: (input) => setSettings(s => ({ ...s, input })),
setConnectionUrl,
setInput,
setName,
setMessages: (newMessages) => setMessages(newMessages.slice()),
addMessage: (content, role, triggerNext = false) => {
setMessages(messages => [
...messages,
{ role, content }
]);
setTriggerNext(triggerNext);
},
editMessage: (index, content) => setMessages(messages => (
messages.map((m, i) => ({ ...m, content: i === index ? content : m.content }))
)),
deleteMessage: (index) => setMessages(messages =>
messages.filter((_, i) => i !== index)
),
}), []);
const value = useMemo(() => ({ ...settings, ...actions }), [settings, actions])
useEffect(() => void (async () => {
if (triggerNext) {
setTriggerNext(false);
const prompt = await compilePrompt(messages);
const messageId = messages.length;
let text = '';
actions.addMessage('', 'assistant');
for await (const chunk of generate(connectionUrl, prompt, { temperature: 1.0 })) {
text += chunk;
actions.editMessage(messageId, text);
}
}
})(), [triggerNext, messages]);
const rawContext: IContext = {
connectionUrl,
input,
name,
messages,
};
const context = useMemo(() => rawContext, Object.values(rawContext));
useEffect(() => {
saveContext(context);
}, [context]);
const value = useMemo(() => ({ ...context, ...actions }), [context, actions])
return (
<GlobalContext.Provider value={value}>

View File

@ -0,0 +1,62 @@
import Lock from "@common/lock";
import SSE from "@common/sse";
interface IGenerationSettings {
temperature: number;
// TODO
}
export async function* generate(host: string, prompt: string, generationSetings: IGenerationSettings) {
const sse = new SSE(`${host}/api/extra/generate/stream`, {
payload: JSON.stringify({
...generationSetings,
prompt,
stop_sequence: ['\n'],
}),
});
const messages: string[] = [];
const messageLock = new Lock();
let end = false;
sse.addEventListener('message', (e) => {
if (e.data) {
{
const { token, finish_reason } = JSON.parse(e.data);
messages.push(token);
if (finish_reason && finish_reason !== 'null') {
end = true;
}
}
}
messageLock.release();
});
const handleEnd = () => {
end = true;
messageLock.release();
};
sse.addEventListener('error', handleEnd);
sse.addEventListener('abort', handleEnd);
sse.addEventListener('readystatechange', (e) => {
if (e.readyState === SSE.CLOSED) handleEnd();
});
while (!end || messages.length) {
while (messages.length > 0) {
const message = messages.shift();
if (message != null) {
try {
yield message;
} catch { }
}
}
if (!end) {
await messageLock.wait();
}
}
sse.close();
}

28
src/games/ai/messages.ts Normal file
View File

@ -0,0 +1,28 @@
import { Template } from "@huggingface/jinja";
import type { IContext } from "./context";
export interface IMessage {
role: 'user' | 'assistant' | 'system';
content: string;
displayContent?: string;
}
export const applyChatTemplate = (messages: IMessage[], templateString: string, eosToken = '</s>') => {
const template = new Template(templateString);
const prompt = template.render({
messages,
bos_token: '',
eos_token: eosToken,
add_generation_prompt: true,
});
return prompt;
}
export const compilePrompt = async (messages: IMessage[]): Promise<string> => {
// TODO chat template
// TODO tokenize
return applyChatTemplate(messages, "{% for message in messages %}{{ message['role'] + ': ' + message['content'] + '\\n' }}{% endfor %}")
}

View File

@ -3,8 +3,10 @@
}
:root {
--color: #DCDCD2;
--backgroundColor: #333333;
--color: #DCDCD2;
--italicColor: #AFAFAF;
--quoteColor: #D4E5FF;
}
body {
@ -25,6 +27,7 @@ body {
width: 100%;
max-width: 1200px;
height: 100%;
max-height: 100dvh;
>.header {
display: flex;
@ -36,13 +39,33 @@ body {
>.chat {
display: flex;
flex-direction: row;
flex-direction: column;
height: 100%;
background-color: gray;
flex-grow: 1;
width: 100%;
max-width: 100%;
overflow-x: hidden;
overflow-y: auto;
scrollbar-width: thin;
scrollbar-color: var(--color) transparent;
border: 1px solid var(--color);
>.message {
width: 100%;
padding: 12px;
>.header {
display: flex;
flex-direction: row;
justify-content: space-between;
>.name {
font-weight: bold;
}
}
>.content {
white-space: pre-wrap;
}
}
}
>.chat-input {
@ -50,10 +73,10 @@ body {
flex-direction: row;
height: auto;
min-height: 48px;
background-color: green;
width: 100%;
border: 1px solid var(--color);
textarea {
>textarea {
color: var(--color);
background-color: var(--backgroundColor);
font-size: 1em;