1
0
Fork 0

AI Story: add instruct format selector

This commit is contained in:
Pabloader 2024-11-02 18:25:19 +00:00
parent eed4f492cc
commit a4ce47d9d8
6 changed files with 64 additions and 29 deletions

View File

@ -22,13 +22,13 @@
}
textarea,
input {
input,
select {
color: var(--color);
border: var(--border);
background-color: var(--backgroundColorDark);
font-size: 1em;
font-family: sans-serif;
appearance: none;
outline: none;
}
@ -106,7 +106,8 @@ body {
flex-direction: row;
height: auto;
width: 100%;
> textarea {
>textarea {
min-height: 48px;
resize: none;
background-color: var(--backgroundColor);

View File

@ -6,13 +6,20 @@
width: 100%;
border: var(--border);
>input {
&.valid {
background-color: var(--green);
}
.valid {
background-color: var(--green);
}
&.invalid {
background-color: var(--red);
.invalid {
background-color: var(--red);
}
.inputs {
display: flex;
flex-direction: row;
select {
text-transform: capitalize;
}
}
@ -25,7 +32,7 @@
}
.modalTitle {
margin: 0;
margin: 0;
}
.scrollPane {

View File

@ -2,7 +2,7 @@ import { useCallback, useContext, useEffect, useMemo, useState } from "preact/ho
import { useBool } from "@common/hooks/useBool";
import { Modal } from "@common/components/modal/modal";
import { StateContext } from "../../contexts/state";
import { Instruct, StateContext } from "../../contexts/state";
import { LLMContext } from "../../contexts/llm";
import { MiniChat } from "../minichat/minichat";
import { AutoTextarea } from "../autoTextarea";
@ -13,8 +13,8 @@ import { Ace } from "../ace";
export const Header = () => {
const { getContextLength } = useContext(LLMContext);
const {
messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords,
setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords,
messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct,
setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct
} = useContext(StateContext);
const [urlValid, setUrlValid] = useState(false);
const [urlEditing, setUrlEditing] = useState(false);
@ -66,11 +66,21 @@ export const Header = () => {
return (
<div class={styles.header}>
<input value={connectionUrl}
onInput={setConnectionUrl}
onFocus={handleFocusUrl}
onBlur={handleBlurUrl}
class={urlEditing ? '' : urlValid ? styles.valid : styles.invalid} />
<div class={styles.inputs}>
<input value={connectionUrl}
onInput={setConnectionUrl}
onFocus={handleFocusUrl}
onBlur={handleBlurUrl}
class={urlEditing ? '' : urlValid ? styles.valid : styles.invalid}
/>
<select value={instruct} onChange={setInstruct}>
{Object.entries(Instruct).map(([label, value]) => (
<option value={value} key={value}>
{label.toLowerCase()}
</option>
))}
</select>
</div>
<div class={styles.buttons}>
<button class='icon color' title='Edit lore' onClick={loreOpen.setTrue}>
🌍
@ -86,7 +96,11 @@ export const Header = () => {
</div>
<Modal open={loreOpen.value} onClose={loreOpen.setFalse}>
<h3 class={styles.modalTitle}>Lore Editor</h3>
<AutoTextarea value={lore} onInput={setLore} />
<AutoTextarea
value={lore}
onInput={setLore}
placeholder="Describe your world, for example: World of Awoo has big mountains and wide rivers."
/>
</Modal>
<Modal open={promptsOpen.value} onClose={promptsOpen.setFalse}>
<h3 class={styles.modalTitle}>Prompts Editor</h3>
@ -99,6 +113,7 @@ export const Header = () => {
<hr />
<h4 class={styles.modalTitle}>Banned phrases</h4>
<AutoTextarea
placeholder="Each phrase on separate line"
value={bannedWordsInput}
onInput={handleSetBannedWords}
onBlur={handleBlurBannedWords}

View File

@ -1 +0,0 @@
// Jinja-style chat template for the Llama 3 instruct format: each message is wrapped as
// <|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|> (content is trimmed),
// and when add_generation_prompt is set a trailing assistant header is appended so the
// model continues as the assistant. bos/eos handling is left to the renderer's variables.
export const LLAMA_TEMPLATE = `{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}`;

View File

@ -1,6 +1,5 @@
import Lock from "@common/lock";
import SSE from "@common/sse";
import { LLAMA_TEMPLATE } from "../const";
import { createContext } from "preact";
import { useContext, useEffect, useMemo } from "preact/hooks";
import { MessageTools, type IMessage } from "../messages";
@ -58,7 +57,7 @@ export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
export const LLMContextProvider = ({ children }: { children?: any }) => {
const {
connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords,
connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct,
setTriggerNext, addMessage, editMessage,
} = useContext(StateContext);
const generating = useBool(false);
@ -74,13 +73,11 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
}, [userPrompt]);
const actions: IActions = useMemo(() => ({
applyChatTemplate: (messages: ITemplateMessage[], templateString: string, eosToken = '</s>') => {
applyChatTemplate: (messages: ITemplateMessage[], templateString: string) => {
const template = new Template(templateString);
const prompt = template.render({
messages,
bos_token: '',
eos_token: eosToken,
add_generation_prompt: true,
});
@ -160,7 +157,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
}
}
const prompt = actions.applyChatTemplate(templateMessages, LLAMA_TEMPLATE);
const prompt = actions.applyChatTemplate(templateMessages, instruct);
return {
prompt,
isContinue,
@ -269,7 +266,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
return 0;
},
}), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords]);
}), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct]);
useEffect(() => void (async () => {
if (triggerNext && !generating.value) {

View File

@ -6,8 +6,9 @@ import { useInputState } from "@common/hooks/useInputState";
interface IContext {
connectionUrl: string;
input: string;
lore: string;
instruct: string;
systemPrompt: string;
lore: string;
userPrompt: string;
bannedWords: string[];
messages: IMessage[];
@ -17,6 +18,7 @@ interface IContext {
interface IActions {
setConnectionUrl: (url: string | Event) => void;
setInput: (url: string | Event) => void;
setInstruct: (template: string | Event) => void;
setLore: (lore: string | Event) => void;
setSystemPrompt: (prompt: string | Event) => void;
setUserPrompt: (prompt: string | Event) => void;
@ -35,6 +37,16 @@ interface IActions {
const SAVE_KEY = 'ai_game_save_state';
/**
 * Jinja-style chat templates for the supported instruct formats.
 * Each value is a complete template rendered with `messages` and
 * `add_generation_prompt`; the selected value is stored in state
 * (`instruct`) and passed to `applyChatTemplate` as the template string.
 * String enum keeps `Object.entries(Instruct)` usable for the <select>
 * options and `Instruct.LLAMA` as the persisted default.
 */
export enum Instruct {
    // Llama 3: role header uses a double newline by design
    // (<|start_header_id|>role<|end_header_id|>\n\n{content}<|eot_id|>).
    LLAMA = `{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}`,
    // Mistral: first user turn absorbs the system message; assistant turns close with </s>.
    MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + '</s>'}}{%- endif %}{%- endfor %}`,
    // ChatML: <|im_start|>{role}\n{content}<|im_end|>\n — a SINGLE newline after the
    // role (the previous '\n\n' inserted a spurious blank line into every turn and
    // into the generation prompt).
    CHATML = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}`,
    // Alpaca: ### Instruction / ### Response blocks. NOTE(review): community Alpaca
    // variants disagree on one vs. two newlines after the heading — keeping '\n\n'
    // as authored; confirm against the target model's expected prompt if output degrades.
    ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\n\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\n\n' + message['content'] + '\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\n\n' + message['content'] + '\n\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\n\n' }}{% endif %}`,
};
export const saveContext = (context: IContext) => {
const contextToSave: Partial<IContext> = { ...context };
delete contextToSave.triggerNext;
@ -44,8 +56,9 @@ export const saveContext = (context: IContext) => {
export const loadContext = (): IContext => {
const defaultContext: IContext = {
connectionUrl: 'http://192.168.10.102:5001',
connectionUrl: 'http://localhost:5001',
input: '',
instruct: Instruct.LLAMA,
systemPrompt: 'You are creative writer. Write a story based on the world description below.',
lore: '',
userPrompt: `{% if prompt %}{% if isStart %}Start{% else %}Continue{% endif %} this story, taking information into account: {{ prompt | trim }}
@ -75,6 +88,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
const loadedContext = useMemo(() => loadContext(), []);
const [connectionUrl, setConnectionUrl] = useInputState(loadedContext.connectionUrl);
const [input, setInput] = useInputState(loadedContext.input);
const [instruct, setInstruct] = useInputState(loadedContext.instruct);
const [lore, setLore] = useInputState(loadedContext.lore);
const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt);
const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt);
@ -86,6 +100,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
const actions: IActions = useMemo(() => ({
setConnectionUrl,
setInput,
setInstruct,
setSystemPrompt,
setUserPrompt,
setLore,
@ -162,6 +177,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
const rawContext: IContext = {
connectionUrl,
input,
instruct,
systemPrompt,
lore,
userPrompt,