From a4ce47d9d865db15be4c25fc4ae7afb389718344 Mon Sep 17 00:00:00 2001 From: Pabloader Date: Sat, 2 Nov 2024 18:25:19 +0000 Subject: [PATCH] AI Story: add instruct format selector --- src/games/ai/assets/style.css | 7 ++-- .../ai/components/header/header.module.css | 21 ++++++++---- src/games/ai/components/header/header.tsx | 33 ++++++++++++++----- src/games/ai/const.ts | 1 - src/games/ai/contexts/llm.tsx | 11 +++---- src/games/ai/contexts/state.tsx | 20 +++++++++-- 6 files changed, 64 insertions(+), 29 deletions(-) delete mode 100644 src/games/ai/const.ts diff --git a/src/games/ai/assets/style.css b/src/games/ai/assets/style.css index f9477b2..254532e 100644 --- a/src/games/ai/assets/style.css +++ b/src/games/ai/assets/style.css @@ -22,13 +22,13 @@ } textarea, -input { +input, +select { color: var(--color); border: var(--border); background-color: var(--backgroundColorDark); font-size: 1em; font-family: sans-serif; - appearance: none; outline: none; } @@ -106,7 +106,8 @@ body { flex-direction: row; height: auto; width: 100%; - > textarea { + + >textarea { min-height: 48px; resize: none; background-color: var(--backgroundColor); diff --git a/src/games/ai/components/header/header.module.css b/src/games/ai/components/header/header.module.css index d972408..ee7f15e 100644 --- a/src/games/ai/components/header/header.module.css +++ b/src/games/ai/components/header/header.module.css @@ -6,13 +6,20 @@ width: 100%; border: var(--border); - >input { - &.valid { - background-color: var(--green); - } + .valid { + background-color: var(--green); + } - &.invalid { - background-color: var(--red); + .invalid { + background-color: var(--red); + } + + .inputs { + display: flex; + flex-direction: row; + + select { + text-transform: capitalize; } } @@ -25,7 +32,7 @@ } .modalTitle { - margin: 0; + margin: 0; } .scrollPane { diff --git a/src/games/ai/components/header/header.tsx b/src/games/ai/components/header/header.tsx index af373cc..d0c4ab1 100644 --- a/src/games/ai/components/header/header.tsx +++ b/src/games/ai/components/header/header.tsx @@ -2,7 +2,7 @@ import { useCallback, useContext, useEffect, useMemo, useState } from "preact/ho import { useBool } from "@common/hooks/useBool"; import { Modal } from "@common/components/modal/modal"; -import { StateContext } from "../../contexts/state"; +import { Instruct, StateContext } from "../../contexts/state"; import { LLMContext } from "../../contexts/llm"; import { MiniChat } from "../minichat/minichat"; import { AutoTextarea } from "../autoTextarea"; @@ -13,8 +13,8 @@ import { Ace } from "../ace"; export const Header = () => { const { getContextLength } = useContext(LLMContext); const { - messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, - setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, + messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct, + setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct } = useContext(StateContext); const [urlValid, setUrlValid] = useState(false); const [urlEditing, setUrlEditing] = useState(false); @@ -66,11 +66,21 @@ export const Header = () => { return (
- +
+ + +

Lore Editor

- +

Prompts Editor

@@ -99,6 +113,7 @@ export const Header = () => {

Banned phrases

\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}`; \ No newline at end of file diff --git a/src/games/ai/contexts/llm.tsx b/src/games/ai/contexts/llm.tsx index dd1d378..a8a9e04 100644 --- a/src/games/ai/contexts/llm.tsx +++ b/src/games/ai/contexts/llm.tsx @@ -1,6 +1,5 @@ import Lock from "@common/lock"; import SSE from "@common/sse"; -import { LLAMA_TEMPLATE } from "../const"; import { createContext } from "preact"; import { useContext, useEffect, useMemo } from "preact/hooks"; import { MessageTools, type IMessage } from "../messages"; @@ -58,7 +57,7 @@ export const LLMContext = createContext({} as ILLMContext); export const LLMContextProvider = ({ children }: { children?: any }) => { const { - connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, + connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct, setTriggerNext, addMessage, editMessage, } = useContext(StateContext); const generating = useBool(false); @@ -74,13 +73,11 @@ export const LLMContextProvider = ({ children }: { children?: any }) => { }, [userPrompt]); const actions: IActions = useMemo(() => ({ - applyChatTemplate: (messages: ITemplateMessage[], templateString: string, eosToken = '') => { + applyChatTemplate: (messages: ITemplateMessage[], templateString: string) => { const template = new Template(templateString); const prompt = template.render({ messages, - bos_token: '', - eos_token: eosToken, add_generation_prompt: true, }); @@ -160,7 +157,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => { } } - const prompt = actions.applyChatTemplate(templateMessages, LLAMA_TEMPLATE); + const prompt = actions.applyChatTemplate(templateMessages, instruct); return { prompt, isContinue, @@ -269,7 +266,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => { return 0; }, - }), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords]); + }), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct]); useEffect(() => void (async () => { if (triggerNext && !generating.value) { diff --git a/src/games/ai/contexts/state.tsx b/src/games/ai/contexts/state.tsx index 7d5d3b8..98d7809 100644 --- a/src/games/ai/contexts/state.tsx +++ b/src/games/ai/contexts/state.tsx @@ -6,8 +6,9 @@ import { useInputState } from "@common/hooks/useInputState"; interface IContext { connectionUrl: string; input: string; - lore: string; + instruct: string; systemPrompt: string; + lore: string; userPrompt: string; bannedWords: string[]; messages: IMessage[]; @@ -17,6 +18,7 @@ interface IContext { interface IActions { setConnectionUrl: (url: string | Event) => void; setInput: (url: string | Event) => void; + setInstruct: (template: string | Event) => void; setLore: (lore: string | Event) => void; setSystemPrompt: (prompt: string | Event) => void; setUserPrompt: (prompt: string | Event) => void; @@ -35,6 +37,16 @@ interface IActions { const SAVE_KEY = 'ai_game_save_state'; +export enum Instruct { + LLAMA = `{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}`, + + MISTRAL = `{%- if messages[0]['role'] == 'system' %}{%- set system_message = messages[0]['content'] %}{%- set loop_messages = messages[1:] %}{%- else %}{%- set loop_messages = messages %}{%- endif %}{%- for message in loop_messages %}{%- if message['role'] == 'user' %}{%- if loop.first and system_message is defined %}{{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }}{%- else %}{{- ' [INST] ' + message['content'] + ' [/INST]' }}{%- endif %}{%- elif message['role'] == 'assistant' %}{{- ' ' + message['content'] + ''}}{%- endif %}{%- endfor %}`, + + CHATML = `{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n\n' }}{% endif %}`, + + ALPACA = `{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{ message['content'] + '\n\n'}}{% elif message['role'] == 'user' %}{{'### Instruction:\n\n' + message['content'] + '\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Response:\n\n' + message['content'] + '\n\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Response:\n\n' }}{% endif %}`, +}; + export const saveContext = (context: IContext) => { const contextToSave: Partial = { ...context }; delete contextToSave.triggerNext; @@ -44,8 +56,9 @@ export const saveContext = (context: IContext) => { export const loadContext = (): IContext => { const defaultContext: IContext = { - connectionUrl: 'http://192.168.10.102:5001', + connectionUrl: 'http://localhost:5001', input: '', + instruct: Instruct.LLAMA, systemPrompt: 'You are creative writer. Write a story based on the world description below.', lore: '', userPrompt: `{% if prompt %}{% if isStart %}Start{% else %}Continue{% endif %} this story, taking information into account: {{ prompt | trim }} @@ -75,6 +88,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => { const loadedContext = useMemo(() => loadContext(), []); const [connectionUrl, setConnectionUrl] = useInputState(loadedContext.connectionUrl); const [input, setInput] = useInputState(loadedContext.input); + const [instruct, setInstruct] = useInputState(loadedContext.instruct); const [lore, setLore] = useInputState(loadedContext.lore); const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt); const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt); @@ -86,6 +100,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => { const actions: IActions = useMemo(() => ({ setConnectionUrl, setInput, + setInstruct, setSystemPrompt, setUserPrompt, setLore, @@ -162,6 +177,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => { const rawContext: IContext = { connectionUrl, input, + instruct, systemPrompt, lore, userPrompt,