AI Story: add instruct format selector
This commit is contained in:
parent
eed4f492cc
commit
a4ce47d9d8
|
|
@ -22,13 +22,13 @@
|
|||
}
|
||||
|
||||
textarea,
|
||||
input {
|
||||
input,
|
||||
select {
|
||||
color: var(--color);
|
||||
border: var(--border);
|
||||
background-color: var(--backgroundColorDark);
|
||||
font-size: 1em;
|
||||
font-family: sans-serif;
|
||||
appearance: none;
|
||||
outline: none;
|
||||
}
|
||||
|
||||
|
|
@ -106,7 +106,8 @@ body {
|
|||
flex-direction: row;
|
||||
height: auto;
|
||||
width: 100%;
|
||||
> textarea {
|
||||
|
||||
>textarea {
|
||||
min-height: 48px;
|
||||
resize: none;
|
||||
background-color: var(--backgroundColor);
|
||||
|
|
|
|||
|
|
@ -6,13 +6,20 @@
|
|||
width: 100%;
|
||||
border: var(--border);
|
||||
|
||||
>input {
|
||||
&.valid {
|
||||
background-color: var(--green);
|
||||
}
|
||||
.valid {
|
||||
background-color: var(--green);
|
||||
}
|
||||
|
||||
&.invalid {
|
||||
background-color: var(--red);
|
||||
.invalid {
|
||||
background-color: var(--red);
|
||||
}
|
||||
|
||||
.inputs {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
|
||||
select {
|
||||
text-transform: capitalize;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -25,7 +32,7 @@
|
|||
}
|
||||
|
||||
.modalTitle {
|
||||
margin: 0;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.scrollPane {
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import { useCallback, useContext, useEffect, useMemo, useState } from "preact/ho
|
|||
import { useBool } from "@common/hooks/useBool";
|
||||
import { Modal } from "@common/components/modal/modal";
|
||||
|
||||
import { StateContext } from "../../contexts/state";
|
||||
import { Instruct, StateContext } from "../../contexts/state";
|
||||
import { LLMContext } from "../../contexts/llm";
|
||||
import { MiniChat } from "../minichat/minichat";
|
||||
import { AutoTextarea } from "../autoTextarea";
|
||||
|
|
@ -13,8 +13,8 @@ import { Ace } from "../ace";
|
|||
export const Header = () => {
|
||||
const { getContextLength } = useContext(LLMContext);
|
||||
const {
|
||||
messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords,
|
||||
setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords,
|
||||
messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct,
|
||||
setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct
|
||||
} = useContext(StateContext);
|
||||
const [urlValid, setUrlValid] = useState(false);
|
||||
const [urlEditing, setUrlEditing] = useState(false);
|
||||
|
|
@ -66,11 +66,21 @@ export const Header = () => {
|
|||
|
||||
return (
|
||||
<div class={styles.header}>
|
||||
<input value={connectionUrl}
|
||||
onInput={setConnectionUrl}
|
||||
onFocus={handleFocusUrl}
|
||||
onBlur={handleBlurUrl}
|
||||
class={urlEditing ? '' : urlValid ? styles.valid : styles.invalid} />
|
||||
<div class={styles.inputs}>
|
||||
<input value={connectionUrl}
|
||||
onInput={setConnectionUrl}
|
||||
onFocus={handleFocusUrl}
|
||||
onBlur={handleBlurUrl}
|
||||
class={urlEditing ? '' : urlValid ? styles.valid : styles.invalid}
|
||||
/>
|
||||
<select value={instruct} onChange={setInstruct}>
|
||||
{Object.entries(Instruct).map(([label, value]) => (
|
||||
<option value={value} key={value}>
|
||||
{label.toLowerCase()}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</div>
|
||||
<div class={styles.buttons}>
|
||||
<button class='icon color' title='Edit lore' onClick={loreOpen.setTrue}>
|
||||
🌍
|
||||
|
|
@ -86,7 +96,11 @@ export const Header = () => {
|
|||
</div>
|
||||
<Modal open={loreOpen.value} onClose={loreOpen.setFalse}>
|
||||
<h3 class={styles.modalTitle}>Lore Editor</h3>
|
||||
<AutoTextarea value={lore} onInput={setLore} />
|
||||
<AutoTextarea
|
||||
value={lore}
|
||||
onInput={setLore}
|
||||
placeholder="Describe your world, for example: World of Awoo has big mountains and wide rivers."
|
||||
/>
|
||||
</Modal>
|
||||
<Modal open={promptsOpen.value} onClose={promptsOpen.setFalse}>
|
||||
<h3 class={styles.modalTitle}>Prompts Editor</h3>
|
||||
|
|
@ -99,6 +113,7 @@ export const Header = () => {
|
|||
<hr />
|
||||
<h4 class={styles.modalTitle}>Banned phrases</h4>
|
||||
<AutoTextarea
|
||||
placeholder="Each phrase on separate line"
|
||||
value={bannedWordsInput}
|
||||
onInput={handleSetBannedWords}
|
||||
onBlur={handleBlurBannedWords}
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
export const LLAMA_TEMPLATE = `{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}`;
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
import Lock from "@common/lock";
|
||||
import SSE from "@common/sse";
|
||||
import { LLAMA_TEMPLATE } from "../const";
|
||||
import { createContext } from "preact";
|
||||
import { useContext, useEffect, useMemo } from "preact/hooks";
|
||||
import { MessageTools, type IMessage } from "../messages";
|
||||
|
|
@ -58,7 +57,7 @@ export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
|
|||
|
||||
export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||
const {
|
||||
connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords,
|
||||
connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct,
|
||||
setTriggerNext, addMessage, editMessage,
|
||||
} = useContext(StateContext);
|
||||
const generating = useBool(false);
|
||||
|
|
@ -74,13 +73,11 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
|||
}, [userPrompt]);
|
||||
|
||||
const actions: IActions = useMemo(() => ({
|
||||
applyChatTemplate: (messages: ITemplateMessage[], templateString: string, eosToken = '</s>') => {
|
||||
applyChatTemplate: (messages: ITemplateMessage[], templateString: string) => {
|
||||
const template = new Template(templateString);
|
||||
|
||||
const prompt = template.render({
|
||||
messages,
|
||||
bos_token: '',
|
||||
eos_token: eosToken,
|
||||
add_generation_prompt: true,
|
||||
});
|
||||
|
||||
|
|
@ -160,7 +157,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
|||
}
|
||||
}
|
||||
|
||||
const prompt = actions.applyChatTemplate(templateMessages, LLAMA_TEMPLATE);
|
||||
const prompt = actions.applyChatTemplate(templateMessages, instruct);
|
||||
return {
|
||||
prompt,
|
||||
isContinue,
|
||||
|
|
@ -269,7 +266,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
|||
|
||||
return 0;
|
||||
},
|
||||
}), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords]);
|
||||
}), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct]);
|
||||
|
||||
useEffect(() => void (async () => {
|
||||
if (triggerNext && !generating.value) {
|
||||
|
|
|
|||
|
|
@ -6,8 +6,9 @@ import { useInputState } from "@common/hooks/useInputState";
|
|||
interface IContext {
|
||||
connectionUrl: string;
|
||||
input: string;
|
||||
lore: string;
|
||||
instruct: string;
|
||||
systemPrompt: string;
|
||||
lore: string;
|
||||
userPrompt: string;
|
||||
bannedWords: string[];
|
||||
messages: IMessage[];
|
||||
|
|
@ -17,6 +18,7 @@ interface IContext {
|
|||
interface IActions {
|
||||
setConnectionUrl: (url: string | Event) => void;
|
||||
setInput: (url: string | Event) => void;
|
||||
setInstruct: (template: string | Event) => void;
|
||||
setLore: (lore: string | Event) => void;
|
||||
setSystemPrompt: (prompt: string | Event) => void;
|
||||
setUserPrompt: (prompt: string | Event) => void;
|
||||
|
|
@ -35,6 +37,16 @@ interface IActions {
|
|||
|
||||
const SAVE_KEY = 'ai_game_save_state';
|
||||
|
||||
export enum Instruct {
|
||||
LLAMA = `{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}`,
|
||||
|
||||
MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + '</s>'}}{%- endif %}{%- endfor %}`,
|
||||
|
||||
CHATML = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n\n' }}{% endif %}`,
|
||||
|
||||
ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\n\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\n\n' + message['content'] + '\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\n\n' + message['content'] + '\n\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\n\n' }}{% endif %}`,
|
||||
};
|
||||
|
||||
export const saveContext = (context: IContext) => {
|
||||
const contextToSave: Partial<IContext> = { ...context };
|
||||
delete contextToSave.triggerNext;
|
||||
|
|
@ -44,8 +56,9 @@ export const saveContext = (context: IContext) => {
|
|||
|
||||
export const loadContext = (): IContext => {
|
||||
const defaultContext: IContext = {
|
||||
connectionUrl: 'http://192.168.10.102:5001',
|
||||
connectionUrl: 'http://localhost:5001',
|
||||
input: '',
|
||||
instruct: Instruct.LLAMA,
|
||||
systemPrompt: 'You are creative writer. Write a story based on the world description below.',
|
||||
lore: '',
|
||||
userPrompt: `{% if prompt %}{% if isStart %}Start{% else %}Continue{% endif %} this story, taking information into account: {{ prompt | trim }}
|
||||
|
|
@ -75,6 +88,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
|||
const loadedContext = useMemo(() => loadContext(), []);
|
||||
const [connectionUrl, setConnectionUrl] = useInputState(loadedContext.connectionUrl);
|
||||
const [input, setInput] = useInputState(loadedContext.input);
|
||||
const [instruct, setInstruct] = useInputState(loadedContext.instruct);
|
||||
const [lore, setLore] = useInputState(loadedContext.lore);
|
||||
const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt);
|
||||
const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt);
|
||||
|
|
@ -86,6 +100,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
|||
const actions: IActions = useMemo(() => ({
|
||||
setConnectionUrl,
|
||||
setInput,
|
||||
setInstruct,
|
||||
setSystemPrompt,
|
||||
setUserPrompt,
|
||||
setLore,
|
||||
|
|
@ -162,6 +177,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
|||
const rawContext: IContext = {
|
||||
connectionUrl,
|
||||
input,
|
||||
instruct,
|
||||
systemPrompt,
|
||||
lore,
|
||||
userPrompt,
|
||||
|
|
|
|||
Loading…
Reference in New Issue