AI Story: add instruct format selector
This commit is contained in:
parent
eed4f492cc
commit
a4ce47d9d8
|
|
@ -22,13 +22,13 @@
|
||||||
}
|
}
|
||||||
|
|
||||||
textarea,
|
textarea,
|
||||||
input {
|
input,
|
||||||
|
select {
|
||||||
color: var(--color);
|
color: var(--color);
|
||||||
border: var(--border);
|
border: var(--border);
|
||||||
background-color: var(--backgroundColorDark);
|
background-color: var(--backgroundColorDark);
|
||||||
font-size: 1em;
|
font-size: 1em;
|
||||||
font-family: sans-serif;
|
font-family: sans-serif;
|
||||||
appearance: none;
|
|
||||||
outline: none;
|
outline: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -106,7 +106,8 @@ body {
|
||||||
flex-direction: row;
|
flex-direction: row;
|
||||||
height: auto;
|
height: auto;
|
||||||
width: 100%;
|
width: 100%;
|
||||||
> textarea {
|
|
||||||
|
>textarea {
|
||||||
min-height: 48px;
|
min-height: 48px;
|
||||||
resize: none;
|
resize: none;
|
||||||
background-color: var(--backgroundColor);
|
background-color: var(--backgroundColor);
|
||||||
|
|
|
||||||
|
|
@ -6,13 +6,20 @@
|
||||||
width: 100%;
|
width: 100%;
|
||||||
border: var(--border);
|
border: var(--border);
|
||||||
|
|
||||||
>input {
|
.valid {
|
||||||
&.valid {
|
background-color: var(--green);
|
||||||
background-color: var(--green);
|
}
|
||||||
}
|
|
||||||
|
|
||||||
&.invalid {
|
.invalid {
|
||||||
background-color: var(--red);
|
background-color: var(--red);
|
||||||
|
}
|
||||||
|
|
||||||
|
.inputs {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: row;
|
||||||
|
|
||||||
|
select {
|
||||||
|
text-transform: capitalize;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -25,7 +32,7 @@
|
||||||
}
|
}
|
||||||
|
|
||||||
.modalTitle {
|
.modalTitle {
|
||||||
margin: 0;
|
margin: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
.scrollPane {
|
.scrollPane {
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import { useCallback, useContext, useEffect, useMemo, useState } from "preact/ho
|
||||||
import { useBool } from "@common/hooks/useBool";
|
import { useBool } from "@common/hooks/useBool";
|
||||||
import { Modal } from "@common/components/modal/modal";
|
import { Modal } from "@common/components/modal/modal";
|
||||||
|
|
||||||
import { StateContext } from "../../contexts/state";
|
import { Instruct, StateContext } from "../../contexts/state";
|
||||||
import { LLMContext } from "../../contexts/llm";
|
import { LLMContext } from "../../contexts/llm";
|
||||||
import { MiniChat } from "../minichat/minichat";
|
import { MiniChat } from "../minichat/minichat";
|
||||||
import { AutoTextarea } from "../autoTextarea";
|
import { AutoTextarea } from "../autoTextarea";
|
||||||
|
|
@ -13,8 +13,8 @@ import { Ace } from "../ace";
|
||||||
export const Header = () => {
|
export const Header = () => {
|
||||||
const { getContextLength } = useContext(LLMContext);
|
const { getContextLength } = useContext(LLMContext);
|
||||||
const {
|
const {
|
||||||
messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords,
|
messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct,
|
||||||
setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords,
|
setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct
|
||||||
} = useContext(StateContext);
|
} = useContext(StateContext);
|
||||||
const [urlValid, setUrlValid] = useState(false);
|
const [urlValid, setUrlValid] = useState(false);
|
||||||
const [urlEditing, setUrlEditing] = useState(false);
|
const [urlEditing, setUrlEditing] = useState(false);
|
||||||
|
|
@ -66,11 +66,21 @@ export const Header = () => {
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div class={styles.header}>
|
<div class={styles.header}>
|
||||||
<input value={connectionUrl}
|
<div class={styles.inputs}>
|
||||||
onInput={setConnectionUrl}
|
<input value={connectionUrl}
|
||||||
onFocus={handleFocusUrl}
|
onInput={setConnectionUrl}
|
||||||
onBlur={handleBlurUrl}
|
onFocus={handleFocusUrl}
|
||||||
class={urlEditing ? '' : urlValid ? styles.valid : styles.invalid} />
|
onBlur={handleBlurUrl}
|
||||||
|
class={urlEditing ? '' : urlValid ? styles.valid : styles.invalid}
|
||||||
|
/>
|
||||||
|
<select value={instruct} onChange={setInstruct}>
|
||||||
|
{Object.entries(Instruct).map(([label, value]) => (
|
||||||
|
<option value={value} key={value}>
|
||||||
|
{label.toLowerCase()}
|
||||||
|
</option>
|
||||||
|
))}
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
<div class={styles.buttons}>
|
<div class={styles.buttons}>
|
||||||
<button class='icon color' title='Edit lore' onClick={loreOpen.setTrue}>
|
<button class='icon color' title='Edit lore' onClick={loreOpen.setTrue}>
|
||||||
🌍
|
🌍
|
||||||
|
|
@ -86,7 +96,11 @@ export const Header = () => {
|
||||||
</div>
|
</div>
|
||||||
<Modal open={loreOpen.value} onClose={loreOpen.setFalse}>
|
<Modal open={loreOpen.value} onClose={loreOpen.setFalse}>
|
||||||
<h3 class={styles.modalTitle}>Lore Editor</h3>
|
<h3 class={styles.modalTitle}>Lore Editor</h3>
|
||||||
<AutoTextarea value={lore} onInput={setLore} />
|
<AutoTextarea
|
||||||
|
value={lore}
|
||||||
|
onInput={setLore}
|
||||||
|
placeholder="Describe your world, for example: World of Awoo has big mountains and wide rivers."
|
||||||
|
/>
|
||||||
</Modal>
|
</Modal>
|
||||||
<Modal open={promptsOpen.value} onClose={promptsOpen.setFalse}>
|
<Modal open={promptsOpen.value} onClose={promptsOpen.setFalse}>
|
||||||
<h3 class={styles.modalTitle}>Prompts Editor</h3>
|
<h3 class={styles.modalTitle}>Prompts Editor</h3>
|
||||||
|
|
@ -99,6 +113,7 @@ export const Header = () => {
|
||||||
<hr />
|
<hr />
|
||||||
<h4 class={styles.modalTitle}>Banned phrases</h4>
|
<h4 class={styles.modalTitle}>Banned phrases</h4>
|
||||||
<AutoTextarea
|
<AutoTextarea
|
||||||
|
placeholder="Each phrase on separate line"
|
||||||
value={bannedWordsInput}
|
value={bannedWordsInput}
|
||||||
onInput={handleSetBannedWords}
|
onInput={handleSetBannedWords}
|
||||||
onBlur={handleBlurBannedWords}
|
onBlur={handleBlurBannedWords}
|
||||||
|
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
export const LLAMA_TEMPLATE = `{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}`;
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import Lock from "@common/lock";
|
import Lock from "@common/lock";
|
||||||
import SSE from "@common/sse";
|
import SSE from "@common/sse";
|
||||||
import { LLAMA_TEMPLATE } from "../const";
|
|
||||||
import { createContext } from "preact";
|
import { createContext } from "preact";
|
||||||
import { useContext, useEffect, useMemo } from "preact/hooks";
|
import { useContext, useEffect, useMemo } from "preact/hooks";
|
||||||
import { MessageTools, type IMessage } from "../messages";
|
import { MessageTools, type IMessage } from "../messages";
|
||||||
|
|
@ -58,7 +57,7 @@ export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
|
||||||
|
|
||||||
export const LLMContextProvider = ({ children }: { children?: any }) => {
|
export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
const {
|
const {
|
||||||
connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords,
|
connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct,
|
||||||
setTriggerNext, addMessage, editMessage,
|
setTriggerNext, addMessage, editMessage,
|
||||||
} = useContext(StateContext);
|
} = useContext(StateContext);
|
||||||
const generating = useBool(false);
|
const generating = useBool(false);
|
||||||
|
|
@ -74,13 +73,11 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
}, [userPrompt]);
|
}, [userPrompt]);
|
||||||
|
|
||||||
const actions: IActions = useMemo(() => ({
|
const actions: IActions = useMemo(() => ({
|
||||||
applyChatTemplate: (messages: ITemplateMessage[], templateString: string, eosToken = '</s>') => {
|
applyChatTemplate: (messages: ITemplateMessage[], templateString: string) => {
|
||||||
const template = new Template(templateString);
|
const template = new Template(templateString);
|
||||||
|
|
||||||
const prompt = template.render({
|
const prompt = template.render({
|
||||||
messages,
|
messages,
|
||||||
bos_token: '',
|
|
||||||
eos_token: eosToken,
|
|
||||||
add_generation_prompt: true,
|
add_generation_prompt: true,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
@ -160,7 +157,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const prompt = actions.applyChatTemplate(templateMessages, LLAMA_TEMPLATE);
|
const prompt = actions.applyChatTemplate(templateMessages, instruct);
|
||||||
return {
|
return {
|
||||||
prompt,
|
prompt,
|
||||||
isContinue,
|
isContinue,
|
||||||
|
|
@ -269,7 +266,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
},
|
},
|
||||||
}), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords]);
|
}), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct]);
|
||||||
|
|
||||||
useEffect(() => void (async () => {
|
useEffect(() => void (async () => {
|
||||||
if (triggerNext && !generating.value) {
|
if (triggerNext && !generating.value) {
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,9 @@ import { useInputState } from "@common/hooks/useInputState";
|
||||||
interface IContext {
|
interface IContext {
|
||||||
connectionUrl: string;
|
connectionUrl: string;
|
||||||
input: string;
|
input: string;
|
||||||
lore: string;
|
instruct: string;
|
||||||
systemPrompt: string;
|
systemPrompt: string;
|
||||||
|
lore: string;
|
||||||
userPrompt: string;
|
userPrompt: string;
|
||||||
bannedWords: string[];
|
bannedWords: string[];
|
||||||
messages: IMessage[];
|
messages: IMessage[];
|
||||||
|
|
@ -17,6 +18,7 @@ interface IContext {
|
||||||
interface IActions {
|
interface IActions {
|
||||||
setConnectionUrl: (url: string | Event) => void;
|
setConnectionUrl: (url: string | Event) => void;
|
||||||
setInput: (url: string | Event) => void;
|
setInput: (url: string | Event) => void;
|
||||||
|
setInstruct: (template: string | Event) => void;
|
||||||
setLore: (lore: string | Event) => void;
|
setLore: (lore: string | Event) => void;
|
||||||
setSystemPrompt: (prompt: string | Event) => void;
|
setSystemPrompt: (prompt: string | Event) => void;
|
||||||
setUserPrompt: (prompt: string | Event) => void;
|
setUserPrompt: (prompt: string | Event) => void;
|
||||||
|
|
@ -35,6 +37,16 @@ interface IActions {
|
||||||
|
|
||||||
const SAVE_KEY = 'ai_game_save_state';
|
const SAVE_KEY = 'ai_game_save_state';
|
||||||
|
|
||||||
|
export enum Instruct {
|
||||||
|
LLAMA = `{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}`,
|
||||||
|
|
||||||
|
MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + '</s>'}}{%- endif %}{%- endfor %}`,
|
||||||
|
|
||||||
|
CHATML = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n\n' }}{% endif %}`,
|
||||||
|
|
||||||
|
ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\n\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\n\n' + message['content'] + '\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\n\n' + message['content'] + '\n\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\n\n' }}{% endif %}`,
|
||||||
|
};
|
||||||
|
|
||||||
export const saveContext = (context: IContext) => {
|
export const saveContext = (context: IContext) => {
|
||||||
const contextToSave: Partial<IContext> = { ...context };
|
const contextToSave: Partial<IContext> = { ...context };
|
||||||
delete contextToSave.triggerNext;
|
delete contextToSave.triggerNext;
|
||||||
|
|
@ -44,8 +56,9 @@ export const saveContext = (context: IContext) => {
|
||||||
|
|
||||||
export const loadContext = (): IContext => {
|
export const loadContext = (): IContext => {
|
||||||
const defaultContext: IContext = {
|
const defaultContext: IContext = {
|
||||||
connectionUrl: 'http://192.168.10.102:5001',
|
connectionUrl: 'http://localhost:5001',
|
||||||
input: '',
|
input: '',
|
||||||
|
instruct: Instruct.LLAMA,
|
||||||
systemPrompt: 'You are creative writer. Write a story based on the world description below.',
|
systemPrompt: 'You are creative writer. Write a story based on the world description below.',
|
||||||
lore: '',
|
lore: '',
|
||||||
userPrompt: `{% if prompt %}{% if isStart %}Start{% else %}Continue{% endif %} this story, taking information into account: {{ prompt | trim }}
|
userPrompt: `{% if prompt %}{% if isStart %}Start{% else %}Continue{% endif %} this story, taking information into account: {{ prompt | trim }}
|
||||||
|
|
@ -75,6 +88,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
||||||
const loadedContext = useMemo(() => loadContext(), []);
|
const loadedContext = useMemo(() => loadContext(), []);
|
||||||
const [connectionUrl, setConnectionUrl] = useInputState(loadedContext.connectionUrl);
|
const [connectionUrl, setConnectionUrl] = useInputState(loadedContext.connectionUrl);
|
||||||
const [input, setInput] = useInputState(loadedContext.input);
|
const [input, setInput] = useInputState(loadedContext.input);
|
||||||
|
const [instruct, setInstruct] = useInputState(loadedContext.instruct);
|
||||||
const [lore, setLore] = useInputState(loadedContext.lore);
|
const [lore, setLore] = useInputState(loadedContext.lore);
|
||||||
const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt);
|
const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt);
|
||||||
const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt);
|
const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt);
|
||||||
|
|
@ -86,6 +100,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
||||||
const actions: IActions = useMemo(() => ({
|
const actions: IActions = useMemo(() => ({
|
||||||
setConnectionUrl,
|
setConnectionUrl,
|
||||||
setInput,
|
setInput,
|
||||||
|
setInstruct,
|
||||||
setSystemPrompt,
|
setSystemPrompt,
|
||||||
setUserPrompt,
|
setUserPrompt,
|
||||||
setLore,
|
setLore,
|
||||||
|
|
@ -162,6 +177,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
||||||
const rawContext: IContext = {
|
const rawContext: IContext = {
|
||||||
connectionUrl,
|
connectionUrl,
|
||||||
input,
|
input,
|
||||||
|
instruct,
|
||||||
systemPrompt,
|
systemPrompt,
|
||||||
lore,
|
lore,
|
||||||
userPrompt,
|
userPrompt,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue