Working generation
This commit is contained in:
parent
a2faacf130
commit
9df2993b38
|
|
@ -8,6 +8,7 @@
|
|||
"bake": "bun build/build.ts"
|
||||
},
|
||||
"dependencies": {
|
||||
"@huggingface/jinja": "0.3.1",
|
||||
"@inquirer/select": "2.3.10",
|
||||
"classnames": "2.5.1",
|
||||
"preact": "10.22.0"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,22 @@
|
|||
/**
 * A minimal wait/release gate.
 *
 * The first `wait()` call creates a pending promise; every caller that waits
 * before the next `release()` is resumed by that same release. Calling
 * `release()` while nobody is waiting is a no-op.
 */
export default class Lock {
    private lockPromise: Promise<void> | null = null;
    private doUnlock: (() => void) | null = null;

    /** `true` while a promise created by `wait()` has not been released yet. */
    get locked() {
        return this.lockPromise !== null;
    }

    /** Resolves on the next `release()` call. */
    async wait(): Promise<void> {
        if (!this.lockPromise) {
            this.lockPromise = new Promise((resolve) => {
                this.doUnlock = resolve;
            });
        }

        return this.lockPromise;
    }

    /** Resumes all current waiters and resets the gate to the unlocked state. */
    release(): void {
        const unlock = this.doUnlock;

        this.doUnlock = null;
        this.lockPromise = null;

        // Resolving is synchronous, but waiters only continue on a later
        // microtask, so clearing the state first is not observable to them.
        if (unlock) {
            unlock();
        }
    }
}
|
||||
|
|
@ -0,0 +1,270 @@
|
|||
/** Options accepted by the {@link SSE} constructor. */
export interface ISSEOptions {
    /** Extra request headers, sent verbatim. */
    headers?: Record<string, string>;
    /** Request body. When non-empty and no method is given, the request is a POST. */
    payload?: string;
    /** HTTP method; defaults to POST when a payload is present, GET otherwise. */
    method?: string;
    /** Forwarded to `XMLHttpRequest.withCredentials`. */
    withCredentials?: boolean;
    /** When true, every raw chunk and dispatched event is logged with `console.debug`. */
    debug?: boolean;
    /** Start streaming from the constructor (default) unless explicitly `false`. */
    start?: boolean;
}

/** Event object handed to listeners; the extra fields are filled in where applicable. */
interface SSEEvent extends Event {
    id?: string;
    source?: SSE;
    readyState?: number;
    data?: string;
}

interface SSEEventListener {
    (e: SSEEvent): void;
}

type OnEvent = `on${string}`;
type EventType = 'message' | 'error' | 'readystatechange' | 'abort' | 'open';

const FIELD_SEPARATOR = ':';

// A blank line (two consecutive line endings) terminates an event. The
// `\r(?!\n)` alternative stops the engine from backtracking a single CRLF into
// "CR followed by LF", which would otherwise count as two line endings and
// split multi-line events in CRLF streams.
const EVENT_BOUNDARY = /(?:\r\n|\r(?!\n)|\n){2}/;

/**
 * Server-sent-events client built on XMLHttpRequest, so that (unlike the
 * native EventSource) custom headers, methods and request bodies can be used.
 *
 * Events are delivered to listeners registered with `addEventListener` and to
 * own `on<type>` properties assigned on the instance.
 */
export default class SSE {
    public static INITIALIZING = -1;
    public static CONNECTING = 0;
    public static OPEN = 1;
    public static CLOSED = 2;

    private headers: Record<string, string>;
    private payload: string;
    private method: string;
    private withCredentials: boolean;
    private debug: boolean;
    private listeners: Record<string, SSEEventListener[]> = {};
    private xhr: XMLHttpRequest | null = null;
    private readyState: number = SSE.INITIALIZING;
    /** Number of characters of `xhr.responseText` already consumed. */
    private progress = 0;
    /** Trailing, possibly incomplete event text carried over to the next progress tick. */
    private chunk = '';

    [key: OnEvent]: SSEEventListener | undefined;

    constructor(private url: string, options: ISSEOptions = {}) {
        this.headers = options.headers ?? {};
        this.payload = options.payload ?? '';
        this.method = options.method || (this.payload ? 'POST' : 'GET');
        this.withCredentials = !!options.withCredentials;
        this.debug = !!options.debug;

        if (options.start === undefined || options.start) {
            this.stream();
        }
    }

    /** Registers `listener` for `type`; registering the same listener twice has no effect. */
    addEventListener(type: EventType, listener: SSEEventListener) {
        const registered = this.listeners[type] ?? (this.listeners[type] = []);

        if (!registered.includes(listener)) {
            registered.push(listener);
        }
    }

    /** Removes `listener` from `type`; unknown listeners are ignored. */
    removeEventListener(type: EventType, listener: SSEEventListener) {
        const registered = this.listeners[type];
        if (registered === undefined) {
            return;
        }

        const remaining = registered.filter((element) => element !== listener);
        if (remaining.length === 0) {
            delete this.listeners[type];
        } else {
            this.listeners[type] = remaining;
        }
    }

    /**
     * Delivers `e` to the own `on<type>` handler (if any) and then to the
     * registered listeners, stopping as soon as the event is default-prevented.
     *
     * @returns `false` when a handler prevented the default, `true` otherwise
     *          (including when `e` is `null`).
     */
    dispatchEvent(e: SSEEvent | null) {
        if (!e) {
            return true;
        }

        if (this.debug) {
            console.debug(e);
        }

        e.source = this;

        const onHandler: OnEvent = `on${e.type}`;
        const handler = this[onHandler];
        if (Object.prototype.hasOwnProperty.call(this, onHandler) && handler) {
            handler.call(this, e);
            if (e.defaultPrevented) {
                return false;
            }
        }

        const registered = this.listeners[e.type];
        if (registered) {
            return registered.every((callback) => {
                callback(e);
                return !e.defaultPrevented;
            });
        }

        return true;
    }

    /** Stores the new state and announces it with a `readystatechange` event. */
    private _setReadyState(state: number) {
        const event: SSEEvent = new CustomEvent('readystatechange');
        event.readyState = state;
        this.readyState = state;
        this.dispatchEvent(event);
    }

    /** Emits an `error` event carrying the response body (when available) and closes. */
    private _onStreamFailure = (e: ProgressEvent<XMLHttpRequestEventTarget>) => {
        const event: SSEEvent = new CustomEvent('error');
        if (e.currentTarget instanceof XMLHttpRequest) {
            event.data = e.currentTarget.response;
        }
        this.dispatchEvent(event);
        this.close();
    }

    private _onStreamAbort = () => {
        this.dispatchEvent(new CustomEvent('abort'));
        this.close();
    }

    /** Consumes the newly received part of the response and dispatches every complete event. */
    private _onStreamProgress = (e: ProgressEvent<XMLHttpRequestEventTarget>) => {
        if (!this.xhr) {
            return;
        }

        if (this.xhr.status !== 200) {
            this._onStreamFailure(e);
            return;
        }

        if (this.readyState === SSE.CONNECTING) {
            this.dispatchEvent(new CustomEvent('open'));
            this._setReadyState(SSE.OPEN);
        }

        const data = this.xhr.responseText.substring(this.progress);
        this.progress += data.length;

        const parts = (this.chunk + data).split(EVENT_BOUNDARY);

        // The last part can be incomplete because of buffering or other network
        // effects, so it is always kept back and merged with the next packet.
        const lastPart = parts.pop();
        for (const part of parts) {
            if (part.trim().length > 0) {
                this.dispatchEvent(this._parseEventChunk(part));
            }
        }
        this.chunk = lastPart ?? '';
    }

    private _onStreamLoaded = (e: ProgressEvent<XMLHttpRequestEventTarget>) => {
        this._onStreamProgress(e);

        // Parse the last chunk.
        this.dispatchEvent(this._parseEventChunk(this.chunk));
        this.chunk = '';
    }

    /**
     * Parse a received SSE event chunk into a constructed event object.
     *
     * Reference: https://html.spec.whatwg.org/multipage/server-sent-events.html#dispatchMessage
     *
     * @returns the event, or `null` for an empty chunk.
     */
    private _parseEventChunk(chunk: string): SSEEvent | null {
        if (!chunk || chunk.length === 0) {
            return null;
        }

        if (this.debug) {
            console.debug(chunk);
        }

        const fields = { id: '', retry: '', event: 'message' };
        const dataLines: string[] = [];

        for (const rawLine of chunk.split(/\r\n|\r|\n/)) {
            const line = rawLine.trimEnd();
            const index = line.indexOf(FIELD_SEPARATOR);
            if (index <= 0) {
                // Line was either empty, or started with a separator and is a comment.
                // Either way, ignore.
                continue;
            }

            const field = line.substring(0, index);

            // only first whitespace should be trimmed
            const skip = (line[index + 1] === ' ') ? 2 : 1;
            const value = line.substring(index + skip);

            if (field === 'data') {
                // Consecutive 'data' lines are joined with newlines below; collecting
                // them first avoids a stray leading "\n" on the first line.
                dataLines.push(value);
            } else if (field === 'id' || field === 'retry' || field === 'event') {
                fields[field] = value;
            }
            // Unknown field names are ignored.
        }

        const event: SSEEvent = new CustomEvent(fields.event);
        event.data = dataLines.join('\n');
        event.id = fields.id;
        return event;
    };

    /** Marks the stream closed once the underlying request has finished. */
    private _checkStreamClosed = () => {
        if (!this.xhr) {
            return;
        }

        if (this.xhr.readyState === XMLHttpRequest.DONE) {
            this._setReadyState(SSE.CLOSED);
        }
    };

    /**
     * Starts the streaming request. Does nothing when a request already exists.
     */
    stream() {
        if (this.xhr) {
            // Already connected.
            return;
        }

        this._setReadyState(SSE.CONNECTING);

        const xhr = new XMLHttpRequest();
        this.xhr = xhr;
        xhr.addEventListener('progress', this._onStreamProgress);
        xhr.addEventListener('load', this._onStreamLoaded);
        xhr.addEventListener('readystatechange', this._checkStreamClosed);
        xhr.addEventListener('error', this._onStreamFailure);
        xhr.addEventListener('abort', this._onStreamAbort);
        xhr.open(this.method, this.url);
        for (const [header, value] of Object.entries(this.headers)) {
            xhr.setRequestHeader(header, value);
        }
        xhr.withCredentials = this.withCredentials;
        xhr.send(this.payload);
    };

    /**
     * Closes the stream: aborts the request (if any) and switches to CLOSED.
     * Calling it on an already closed stream is a no-op.
     */
    close() {
        if (this.readyState === SSE.CLOSED) {
            return;
        }

        this.xhr?.abort();
        this.xhr = null;
        this._setReadyState(SSE.CLOSED);
    };
}
|
||||
|
|
@ -1,7 +1,27 @@
|
|||
import { useContext, useEffect, useRef } from "preact/hooks";
|
||||
import { GlobalContext } from "../context";
|
||||
import { Message } from "./message";
|
||||
|
||||
/**
 * Scrollable list of all chat messages from the global context.
 *
 * The list scrolls smoothly to the bottom whenever a message is added or the
 * text of the last message changes (e.g. while a reply is being streamed).
 */
export const Chat = () => {
    const { messages } = useContext(GlobalContext);
    const chatRef = useRef<HTMLDivElement>(null);

    // Text currently shown for the newest message; used only as an effect
    // dependency so that growth of the last message re-triggers scrolling.
    const lastMessage = messages.at(-1);
    const lastMessageContent = lastMessage?.displayContent ?? lastMessage?.content;

    useEffect(() => {
        const container = chatRef.current;
        if (container) {
            container.scrollTo({
                top: container.scrollHeight,
                behavior: 'smooth',
            });
        }
    }, [messages.length, lastMessageContent]);

    return (
        <div class="chat" ref={chatRef}>
            {messages.map((m, i) => (
                <Message message={m} key={i} />
            ))}
        </div>
    );
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@ import { useCallback, useContext } from "preact/hooks";
|
|||
import { GlobalContext } from "../context";
|
||||
|
||||
export const Input = () => {
|
||||
const { input, setInput } = useContext(GlobalContext);
|
||||
const { input, setInput, addMessage } = useContext(GlobalContext);
|
||||
|
||||
console.log({input});
|
||||
|
||||
const handleChange = useCallback((e: Event) => {
|
||||
if (e.target instanceof HTMLTextAreaElement) {
|
||||
|
|
@ -10,9 +12,12 @@ export const Input = () => {
|
|||
}
|
||||
}, []);
|
||||
|
||||
const handleSend = useCallback(() => {
|
||||
console.log('Send:', input);
|
||||
setInput('');
|
||||
const handleSend = useCallback(async () => {
|
||||
const newInput = input.trim();
|
||||
if (newInput) {
|
||||
addMessage(newInput, 'user', true);
|
||||
setInput('');
|
||||
}
|
||||
}, [input]);
|
||||
|
||||
const handleKeyDown = useCallback((e: KeyboardEvent) => {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
import { useContext, useMemo } from "preact/hooks";
|
||||
import type { IMessage } from "../messages";
|
||||
import { GlobalContext } from "../context";
|
||||
|
||||
interface IProps {
    /** The chat message to render. */
    message: IMessage;
}

/**
 * Renders one chat message: a header with the author label and an (empty)
 * button area, followed by the message text.
 */
export const Message = ({ message }: IProps) => {
    const { name } = useContext(GlobalContext);

    // User messages are labelled with the configured name, all others with a placeholder.
    const author = message.role === 'user' ? name : '---';
    // The display text takes precedence over the raw content when it is set.
    const text = message.displayContent ?? message.content;

    return (
        <div class="message">
            <div class="header">
                <div class="name">{author}</div>
                <div class="buttons"></div>
            </div>
            <div class="content">{text}</div>
        </div>
    );
};
|
||||
|
|
@ -1,32 +1,114 @@
|
|||
import { createContext } from "preact";
|
||||
import { useMemo, useState } from "preact/hooks";
|
||||
import { useEffect, useMemo, useState } from "preact/hooks";
|
||||
import { compilePrompt, type IMessage } from "./messages";
|
||||
import { generate } from "./generation";
|
||||
|
||||
export interface ISettings {
|
||||
export interface IContext {
|
||||
connectionUrl: string;
|
||||
input: string;
|
||||
name: string;
|
||||
messages: IMessage[];
|
||||
}
|
||||
|
||||
export interface IActions {
|
||||
setConnectionUrl: (url: string) => void;
|
||||
setInput: (url: string) => void;
|
||||
setName: (name: string) => void;
|
||||
|
||||
setMessages: (messages: IMessage[]) => void;
|
||||
addMessage: (content: string, role: IMessage['role'], triggerNext?: boolean) => void;
|
||||
editMessage: (index: number, content: string) => void;
|
||||
deleteMessage: (index: number) => void;
|
||||
}
|
||||
|
||||
export type IGlobalContext = ISettings & IActions;
|
||||
export type IGlobalContext = IContext & IActions;
|
||||
|
||||
const SAVE_KEY = 'ai_game_save_state';
|
||||
|
||||
/**
 * Persists the given state to `localStorage` under `SAVE_KEY` as JSON.
 *
 * Saving is best-effort: `setItem` throws when the quota is exceeded or
 * storage is unavailable, and that must not break the caller, so the failure
 * is only logged.
 */
const saveContext = (ctx: IContext) => {
    try {
        localStorage.setItem(SAVE_KEY, JSON.stringify(ctx));
    } catch (e: unknown) {
        console.warn('Failed to save state', e);
    }
}
|
||||
|
||||
/**
 * Builds the initial state: built-in defaults overlaid with whatever valid
 * values are stored in `localStorage` under `SAVE_KEY`.
 *
 * Stored data is untrusted (it may be corrupt or written by another version),
 * so each field is taken over only when it has the expected type; anything
 * else falls back to its default.
 */
const loadContext = (): IContext => {
    const defaultContext: IContext = {
        connectionUrl: 'http://192.168.10.102:5001',
        input: '',
        name: 'Maya',
        messages: [],
    };

    let stored: unknown = null;

    try {
        const json = localStorage.getItem(SAVE_KEY);
        if (json) {
            stored = JSON.parse(json);
        }
    } catch {
        // Storage unavailable or JSON corrupt: start from the defaults.
        return defaultContext;
    }

    if (typeof stored !== 'object' || stored === null || Array.isArray(stored)) {
        return defaultContext;
    }

    const saved = stored as Partial<Record<keyof IContext, unknown>>;
    const stringOr = (value: unknown, fallback: string) => (
        typeof value === 'string' ? value : fallback
    );

    return {
        connectionUrl: stringOr(saved.connectionUrl, defaultContext.connectionUrl),
        input: stringOr(saved.input, defaultContext.input),
        name: stringOr(saved.name, defaultContext.name),
        messages: Array.isArray(saved.messages)
            ? saved.messages as IContext['messages']
            : defaultContext.messages,
    };
}
|
||||
|
||||
export const GlobalContext = createContext<IGlobalContext>({} as IGlobalContext);
|
||||
|
||||
export const GlobalContextProvider = ({ children }: { children?: any }) => {
|
||||
const [settings, setSettings] = useState<ISettings>({
|
||||
connectionUrl: 'http://192.168.10.102:5001',
|
||||
input: '',
|
||||
});
|
||||
const loadedContext = useMemo(() => loadContext(), []);
|
||||
const [connectionUrl, setConnectionUrl] = useState(loadedContext.connectionUrl);
|
||||
const [input, setInput] = useState(loadedContext.input);
|
||||
const [name, setName] = useState(loadedContext.name);
|
||||
const [messages, setMessages] = useState(loadedContext.messages);
|
||||
const [triggerNext, setTriggerNext] = useState(false);
|
||||
|
||||
const actions: IActions = useMemo(() => ({
|
||||
setConnectionUrl: (connectionUrl) => setSettings(s => ({ ...s, connectionUrl })),
|
||||
setInput: (input) => setSettings(s => ({ ...s, input })),
|
||||
setConnectionUrl,
|
||||
setInput,
|
||||
setName,
|
||||
|
||||
setMessages: (newMessages) => setMessages(newMessages.slice()),
|
||||
addMessage: (content, role, triggerNext = false) => {
|
||||
setMessages(messages => [
|
||||
...messages,
|
||||
{ role, content }
|
||||
]);
|
||||
setTriggerNext(triggerNext);
|
||||
},
|
||||
editMessage: (index, content) => setMessages(messages => (
|
||||
messages.map((m, i) => ({ ...m, content: i === index ? content : m.content }))
|
||||
)),
|
||||
deleteMessage: (index) => setMessages(messages =>
|
||||
messages.filter((_, i) => i !== index)
|
||||
),
|
||||
}), []);
|
||||
|
||||
const value = useMemo(() => ({ ...settings, ...actions }), [settings, actions])
|
||||
useEffect(() => void (async () => {
|
||||
if (triggerNext) {
|
||||
setTriggerNext(false);
|
||||
|
||||
const prompt = await compilePrompt(messages);
|
||||
const messageId = messages.length;
|
||||
let text = '';
|
||||
actions.addMessage('', 'assistant');
|
||||
for await (const chunk of generate(connectionUrl, prompt, { temperature: 1.0 })) {
|
||||
text += chunk;
|
||||
actions.editMessage(messageId, text);
|
||||
}
|
||||
}
|
||||
})(), [triggerNext, messages]);
|
||||
|
||||
const rawContext: IContext = {
|
||||
connectionUrl,
|
||||
input,
|
||||
name,
|
||||
messages,
|
||||
};
|
||||
|
||||
const context = useMemo(() => rawContext, Object.values(rawContext));
|
||||
|
||||
useEffect(() => {
|
||||
saveContext(context);
|
||||
}, [context]);
|
||||
|
||||
const value = useMemo(() => ({ ...context, ...actions }), [context, actions])
|
||||
|
||||
return (
|
||||
<GlobalContext.Provider value={value}>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,62 @@
|
|||
import Lock from "@common/lock";
|
||||
import SSE from "@common/sse";
|
||||
|
||||
interface IGenerationSettings {
    temperature: number;
    // TODO
}

/**
 * Streams generated text from `${host}/api/extra/generate/stream`.
 *
 * The request body is the generation settings plus the prompt and a
 * single-newline stop sequence. Each received token is yielded in arrival
 * order; the generator finishes once the server reports a finish reason or the
 * stream errors, aborts or closes, and all queued tokens have been yielded.
 *
 * The underlying request is always closed when the generator ends, including
 * when the consumer stops iterating early.
 *
 * @param host base URL of the generation server
 * @param prompt fully compiled prompt text
 * @param generationSetings sampling settings merged into the request body
 */
export async function* generate(host: string, prompt: string, generationSetings: IGenerationSettings) {
    const sse = new SSE(`${host}/api/extra/generate/stream`, {
        payload: JSON.stringify({
            ...generationSetings,
            prompt,
            stop_sequence: ['\n'],
        }),
    });

    // Tokens received but not yet yielded.
    const messages: string[] = [];
    // Lets the consumer loop sleep until a listener has something new for it.
    const messageLock = new Lock();
    let end = false;

    const handleEnd = () => {
        end = true;
        messageLock.release();
    };

    sse.addEventListener('message', (e) => {
        if (e.data) {
            try {
                const { token, finish_reason } = JSON.parse(e.data);
                if (typeof token === 'string') {
                    messages.push(token);
                }

                if (finish_reason && finish_reason !== 'null') {
                    end = true;
                }
            } catch (err: unknown) {
                // A malformed frame is skipped; the consumer is still woken below
                // so that it never waits on an event that already happened.
                console.warn('Ignoring malformed stream message', err);
            }
        }
        messageLock.release();
    });

    sse.addEventListener('error', handleEnd);
    sse.addEventListener('abort', handleEnd);
    sse.addEventListener('readystatechange', (e) => {
        if (e.readyState === SSE.CLOSED) handleEnd();
    });

    try {
        while (!end || messages.length) {
            while (messages.length > 0) {
                const message = messages.shift();
                if (message != null) {
                    try {
                        yield message;
                    } catch {
                        // An error injected by the consumer via throw() is ignored
                        // and streaming continues.
                    }
                }
            }
            if (!end) {
                await messageLock.wait();
            }
        }
    } finally {
        // Runs on normal completion and when the consumer leaves its loop early
        // (return()), so the request never keeps streaming unobserved.
        sse.close();
    }
}
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
import { Template } from "@huggingface/jinja";
|
||||
import type { IContext } from "./context";
|
||||
|
||||
/** A single chat message. */
export interface IMessage {
    /** Who the message is attributed to. */
    role: 'user' | 'assistant' | 'system';
    /** Message text; this is the field the prompt template reads. */
    content: string;
    /** Optional alternative text; presumably what the UI shows instead of `content` when set. */
    displayContent?: string;
}
|
||||
|
||||
/**
 * Renders `messages` through the given Jinja chat template.
 *
 * The template receives `messages`, an empty `bos_token`, `eos_token` set to
 * `eosToken`, and `add_generation_prompt: true`.
 *
 * @returns the rendered prompt text
 */
export const applyChatTemplate = (messages: IMessage[], templateString: string, eosToken = '</s>') => {
    const compiled = new Template(templateString);

    return compiled.render({
        messages,
        bos_token: '',
        eos_token: eosToken,
        add_generation_prompt: true,
    });
}
|
||||
|
||||
/**
 * Compiles the conversation into a prompt using a fixed template that emits
 * one `role: content` line per message.
 */
export const compilePrompt = async (messages: IMessage[]): Promise<string> => {
    // TODO chat template
    // TODO tokenize

    const lineTemplate = "{% for message in messages %}{{ message['role'] + ': ' + message['content'] + '\\n' }}{% endfor %}";

    return applyChatTemplate(messages, lineTemplate);
}
|
||||
|
|
@ -3,8 +3,10 @@
|
|||
}
|
||||
|
||||
:root {
|
||||
--color: #DCDCD2;
|
||||
--backgroundColor: #333333;
|
||||
--color: #DCDCD2;
|
||||
--italicColor: #AFAFAF;
|
||||
--quoteColor: #D4E5FF;
|
||||
}
|
||||
|
||||
body {
|
||||
|
|
@ -25,6 +27,7 @@ body {
|
|||
width: 100%;
|
||||
max-width: 1200px;
|
||||
height: 100%;
|
||||
max-height: 100dvh;
|
||||
|
||||
>.header {
|
||||
display: flex;
|
||||
|
|
@ -36,13 +39,33 @@ body {
|
|||
|
||||
>.chat {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
flex-direction: column;
|
||||
height: 100%;
|
||||
background-color: gray;
|
||||
flex-grow: 1;
|
||||
width: 100%;
|
||||
max-width: 100%;
|
||||
overflow-x: hidden;
|
||||
overflow-y: auto;
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: var(--color) transparent;
|
||||
border: 1px solid var(--color);
|
||||
|
||||
>.message {
|
||||
width: 100%;
|
||||
padding: 12px;
|
||||
|
||||
>.header {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
justify-content: space-between;
|
||||
|
||||
>.name {
|
||||
font-weight: bold;
|
||||
}
|
||||
}
|
||||
>.content {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
>.chat-input {
|
||||
|
|
@ -50,10 +73,10 @@ body {
|
|||
flex-direction: row;
|
||||
height: auto;
|
||||
min-height: 48px;
|
||||
background-color: green;
|
||||
width: 100%;
|
||||
border: 1px solid var(--color);
|
||||
|
||||
textarea {
|
||||
>textarea {
|
||||
color: var(--color);
|
||||
background-color: var(--backgroundColor);
|
||||
font-size: 1em;
|
||||
|
|
|
|||
Loading…
Reference in New Issue