1
0
Fork 0

AI Story: add instruct format selector

This commit is contained in:
Pabloader 2024-11-02 18:25:19 +00:00
parent eed4f492cc
commit a4ce47d9d8
6 changed files with 64 additions and 29 deletions

View File

@ -22,13 +22,13 @@
} }
textarea, textarea,
input { input,
select {
color: var(--color); color: var(--color);
border: var(--border); border: var(--border);
background-color: var(--backgroundColorDark); background-color: var(--backgroundColorDark);
font-size: 1em; font-size: 1em;
font-family: sans-serif; font-family: sans-serif;
appearance: none;
outline: none; outline: none;
} }
@ -106,6 +106,7 @@ body {
flex-direction: row; flex-direction: row;
height: auto; height: auto;
width: 100%; width: 100%;
>textarea { >textarea {
min-height: 48px; min-height: 48px;
resize: none; resize: none;

View File

@ -6,14 +6,21 @@
width: 100%; width: 100%;
border: var(--border); border: var(--border);
>input { .valid {
&.valid {
background-color: var(--green); background-color: var(--green);
} }
&.invalid { .invalid {
background-color: var(--red); background-color: var(--red);
} }
.inputs {
display: flex;
flex-direction: row;
select {
text-transform: capitalize;
}
} }
.buttons { .buttons {

View File

@ -2,7 +2,7 @@ import { useCallback, useContext, useEffect, useMemo, useState } from "preact/ho
import { useBool } from "@common/hooks/useBool"; import { useBool } from "@common/hooks/useBool";
import { Modal } from "@common/components/modal/modal"; import { Modal } from "@common/components/modal/modal";
import { StateContext } from "../../contexts/state"; import { Instruct, StateContext } from "../../contexts/state";
import { LLMContext } from "../../contexts/llm"; import { LLMContext } from "../../contexts/llm";
import { MiniChat } from "../minichat/minichat"; import { MiniChat } from "../minichat/minichat";
import { AutoTextarea } from "../autoTextarea"; import { AutoTextarea } from "../autoTextarea";
@ -13,8 +13,8 @@ import { Ace } from "../ace";
export const Header = () => { export const Header = () => {
const { getContextLength } = useContext(LLMContext); const { getContextLength } = useContext(LLMContext);
const { const {
messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct,
setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct
} = useContext(StateContext); } = useContext(StateContext);
const [urlValid, setUrlValid] = useState(false); const [urlValid, setUrlValid] = useState(false);
const [urlEditing, setUrlEditing] = useState(false); const [urlEditing, setUrlEditing] = useState(false);
@ -66,11 +66,21 @@ export const Header = () => {
return ( return (
<div class={styles.header}> <div class={styles.header}>
<div class={styles.inputs}>
<input value={connectionUrl} <input value={connectionUrl}
onInput={setConnectionUrl} onInput={setConnectionUrl}
onFocus={handleFocusUrl} onFocus={handleFocusUrl}
onBlur={handleBlurUrl} onBlur={handleBlurUrl}
class={urlEditing ? '' : urlValid ? styles.valid : styles.invalid} /> class={urlEditing ? '' : urlValid ? styles.valid : styles.invalid}
/>
<select value={instruct} onChange={setInstruct}>
{Object.entries(Instruct).map(([label, value]) => (
<option value={value} key={value}>
{label.toLowerCase()}
</option>
))}
</select>
</div>
<div class={styles.buttons}> <div class={styles.buttons}>
<button class='icon color' title='Edit lore' onClick={loreOpen.setTrue}> <button class='icon color' title='Edit lore' onClick={loreOpen.setTrue}>
🌍 🌍
@ -86,7 +96,11 @@ export const Header = () => {
</div> </div>
<Modal open={loreOpen.value} onClose={loreOpen.setFalse}> <Modal open={loreOpen.value} onClose={loreOpen.setFalse}>
<h3 class={styles.modalTitle}>Lore Editor</h3> <h3 class={styles.modalTitle}>Lore Editor</h3>
<AutoTextarea value={lore} onInput={setLore} /> <AutoTextarea
value={lore}
onInput={setLore}
placeholder="Describe your world, for example: World of Awoo has big mountains and wide rivers."
/>
</Modal> </Modal>
<Modal open={promptsOpen.value} onClose={promptsOpen.setFalse}> <Modal open={promptsOpen.value} onClose={promptsOpen.setFalse}>
<h3 class={styles.modalTitle}>Prompts Editor</h3> <h3 class={styles.modalTitle}>Prompts Editor</h3>
@ -99,6 +113,7 @@ export const Header = () => {
<hr /> <hr />
<h4 class={styles.modalTitle}>Banned phrases</h4> <h4 class={styles.modalTitle}>Banned phrases</h4>
<AutoTextarea <AutoTextarea
placeholder="Each phrase on separate line"
value={bannedWordsInput} value={bannedWordsInput}
onInput={handleSetBannedWords} onInput={handleSetBannedWords}
onBlur={handleBlurBannedWords} onBlur={handleBlurBannedWords}

View File

@ -1 +0,0 @@
export const LLAMA_TEMPLATE = `{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}`;

View File

@ -1,6 +1,5 @@
import Lock from "@common/lock"; import Lock from "@common/lock";
import SSE from "@common/sse"; import SSE from "@common/sse";
import { LLAMA_TEMPLATE } from "../const";
import { createContext } from "preact"; import { createContext } from "preact";
import { useContext, useEffect, useMemo } from "preact/hooks"; import { useContext, useEffect, useMemo } from "preact/hooks";
import { MessageTools, type IMessage } from "../messages"; import { MessageTools, type IMessage } from "../messages";
@ -58,7 +57,7 @@ export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
export const LLMContextProvider = ({ children }: { children?: any }) => { export const LLMContextProvider = ({ children }: { children?: any }) => {
const { const {
connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct,
setTriggerNext, addMessage, editMessage, setTriggerNext, addMessage, editMessage,
} = useContext(StateContext); } = useContext(StateContext);
const generating = useBool(false); const generating = useBool(false);
@ -74,13 +73,11 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
}, [userPrompt]); }, [userPrompt]);
const actions: IActions = useMemo(() => ({ const actions: IActions = useMemo(() => ({
applyChatTemplate: (messages: ITemplateMessage[], templateString: string, eosToken = '</s>') => { applyChatTemplate: (messages: ITemplateMessage[], templateString: string) => {
const template = new Template(templateString); const template = new Template(templateString);
const prompt = template.render({ const prompt = template.render({
messages, messages,
bos_token: '',
eos_token: eosToken,
add_generation_prompt: true, add_generation_prompt: true,
}); });
@ -160,7 +157,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
} }
} }
const prompt = actions.applyChatTemplate(templateMessages, LLAMA_TEMPLATE); const prompt = actions.applyChatTemplate(templateMessages, instruct);
return { return {
prompt, prompt,
isContinue, isContinue,
@ -269,7 +266,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
return 0; return 0;
}, },
}), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords]); }), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct]);
useEffect(() => void (async () => { useEffect(() => void (async () => {
if (triggerNext && !generating.value) { if (triggerNext && !generating.value) {

View File

@ -6,8 +6,9 @@ import { useInputState } from "@common/hooks/useInputState";
interface IContext { interface IContext {
connectionUrl: string; connectionUrl: string;
input: string; input: string;
lore: string; instruct: string;
systemPrompt: string; systemPrompt: string;
lore: string;
userPrompt: string; userPrompt: string;
bannedWords: string[]; bannedWords: string[];
messages: IMessage[]; messages: IMessage[];
@ -17,6 +18,7 @@ interface IContext {
interface IActions { interface IActions {
setConnectionUrl: (url: string | Event) => void; setConnectionUrl: (url: string | Event) => void;
setInput: (url: string | Event) => void; setInput: (url: string | Event) => void;
setInstruct: (template: string | Event) => void;
setLore: (lore: string | Event) => void; setLore: (lore: string | Event) => void;
setSystemPrompt: (prompt: string | Event) => void; setSystemPrompt: (prompt: string | Event) => void;
setUserPrompt: (prompt: string | Event) => void; setUserPrompt: (prompt: string | Event) => void;
@ -35,6 +37,16 @@ interface IActions {
const SAVE_KEY = 'ai_game_save_state'; const SAVE_KEY = 'ai_game_save_state';
export enum Instruct {
LLAMA = `{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}`,
MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + '</s>'}}{%- endif %}{%- endfor %}`,
CHATML = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n\n' }}{% endif %}`,
ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\n\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\n\n' + message['content'] + '\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\n\n' + message['content'] + '\n\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\n\n' }}{% endif %}`,
};
export const saveContext = (context: IContext) => { export const saveContext = (context: IContext) => {
const contextToSave: Partial<IContext> = { ...context }; const contextToSave: Partial<IContext> = { ...context };
delete contextToSave.triggerNext; delete contextToSave.triggerNext;
@ -44,8 +56,9 @@ export const saveContext = (context: IContext) => {
export const loadContext = (): IContext => { export const loadContext = (): IContext => {
const defaultContext: IContext = { const defaultContext: IContext = {
connectionUrl: 'http://192.168.10.102:5001', connectionUrl: 'http://localhost:5001',
input: '', input: '',
instruct: Instruct.LLAMA,
systemPrompt: 'You are creative writer. Write a story based on the world description below.', systemPrompt: 'You are creative writer. Write a story based on the world description below.',
lore: '', lore: '',
userPrompt: `{% if prompt %}{% if isStart %}Start{% else %}Continue{% endif %} this story, taking information into account: {{ prompt | trim }} userPrompt: `{% if prompt %}{% if isStart %}Start{% else %}Continue{% endif %} this story, taking information into account: {{ prompt | trim }}
@ -75,6 +88,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
const loadedContext = useMemo(() => loadContext(), []); const loadedContext = useMemo(() => loadContext(), []);
const [connectionUrl, setConnectionUrl] = useInputState(loadedContext.connectionUrl); const [connectionUrl, setConnectionUrl] = useInputState(loadedContext.connectionUrl);
const [input, setInput] = useInputState(loadedContext.input); const [input, setInput] = useInputState(loadedContext.input);
const [instruct, setInstruct] = useInputState(loadedContext.instruct);
const [lore, setLore] = useInputState(loadedContext.lore); const [lore, setLore] = useInputState(loadedContext.lore);
const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt); const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt);
const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt); const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt);
@ -86,6 +100,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
const actions: IActions = useMemo(() => ({ const actions: IActions = useMemo(() => ({
setConnectionUrl, setConnectionUrl,
setInput, setInput,
setInstruct,
setSystemPrompt, setSystemPrompt,
setUserPrompt, setUserPrompt,
setLore, setLore,
@ -162,6 +177,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
const rawContext: IContext = { const rawContext: IContext = {
connectionUrl, connectionUrl,
input, input,
instruct,
systemPrompt, systemPrompt,
lore, lore,
userPrompt, userPrompt,