diff --git a/src/common/components/modal/modal.module.css b/src/common/components/modal/modal.module.css
index 0dc6ad7..1e40825 100644
--- a/src/common/components/modal/modal.module.css
+++ b/src/common/components/modal/modal.module.css
@@ -18,7 +18,7 @@
     border-radius: var(--border-radius, 0);
 
     &::backdrop {
-        background-color: var(--shadeColor, rgba(0, 0, 0, 0.2));
+        backdrop-filter: blur(5px);
     }
 
     >.content {
diff --git a/src/games/ai/assets/bg.jpg b/src/games/ai/assets/bg.jpg
new file mode 100644
index 0000000..cec9429
Binary files /dev/null and b/src/games/ai/assets/bg.jpg differ
diff --git a/src/games/ai/assets/style.css b/src/games/ai/assets/style.css
index 254532e..c3b4611 100644
--- a/src/games/ai/assets/style.css
+++ b/src/games/ai/assets/style.css
@@ -1,6 +1,6 @@
 :root {
-    --backgroundColorDark: #111111;
-    --backgroundColor: #333333;
+    --backgroundColorDark: rgba(0, 0, 0, 0.3);
+    --backgroundColor: rgba(51, 51, 51, 0.9);
     --color: #DCDCD2;
     --italicColor: #AFAFAF;
     --quoteColor: #D4E5FF;
@@ -67,19 +67,27 @@ button {
 
 body {
     color: var(--color);
-    background-color: var(--backgroundColor);
     width: 100dvw;
     height: 100dvh;
-
-    display: flex;
-    flex-direction: row;
-    justify-content: center;
     font-size: 16px;
     line-height: 1.5;
 
+    .root {
+        background-size: cover;
+        background-position: center;
+        background-repeat: no-repeat;
+        width: 100%;
+        height: 100%;
+
+        display: flex;
+        flex-direction: row;
+        justify-content: center;
+    }
+
     .app {
         display: flex;
         flex-direction: column;
+        background-color: var(--backgroundColor);
         width: 100%;
         max-width: 1200px;
 
diff --git a/src/games/ai/components/app.tsx b/src/games/ai/components/app.tsx
index 97b1265..4d4aa76 100644
--- a/src/games/ai/components/app.tsx
+++ b/src/games/ai/components/app.tsx
@@ -2,12 +2,16 @@ import { Header } from "./header/header";
 import { Chat } from "./chat";
 import { Input } from "./input";
 
+import bgImage from '../assets/bg.jpg';
+
 export const App = () => {
     return (
-        <div className="app">
-            <Header />
-            <Chat />
-            <Input />
+        <div className="root" style={{ backgroundImage: `url(${bgImage})` }}>
+            <div className="app">
+                <Header />
+                <Chat />
+                <Input />
+            </div>
         </div>
     );
 };
diff --git a/src/games/ai/components/header/header.tsx b/src/games/ai/components/header/header.tsx
index 2b21610..2f2bfae 100644
--- a/src/games/ai/components/header/header.tsx
+++ b/src/games/ai/components/header/header.tsx
@@ -13,8 +13,8 @@ import { Ace } from "../ace";
 export const Header = () => {
     const { modelName, modelTemplate, contextLength, promptTokens, blockConnection } = useContext(LLMContext);
     const {
-        messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct, summarizePrompt,
-        setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct, setSummarizePrompt,
+        messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct, summarizePrompt, summaryEnabled,
+        setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct, setSummarizePrompt, setSummaryEnabled,
     } = useContext(StateContext);
 
     const loreOpen = useBool();
@@ -52,6 +52,12 @@ export const Header = () => {
         }
     }, [setBannedWords]);
 
+    const handleSetSummaryEnabled = useCallback((e: Event) => {
+        if (e.target instanceof HTMLInputElement) {
+            setSummaryEnabled(e.target.checked);
+        }
+    }, [setSummaryEnabled]);
+
     return (
@@ -128,6 +134,10 @@ export const Header = () => {
             Summary template
+            [4 added lines: a checkbox input wired to summaryEnabled and handleSetSummaryEnabled; the original JSX markup was lost in extraction]
             Instruct template
diff --git a/src/games/ai/components/minichat/minichat.tsx b/src/games/ai/components/minichat/minichat.tsx
index 3ba990d..d096057 100644
--- a/src/games/ai/components/minichat/minichat.tsx
+++ b/src/games/ai/components/minichat/minichat.tsx
@@ -32,7 +32,7 @@ export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps)
     }, []);
 
     useEffect(() => {
-        DOMTools.scrollDown(ref.current, false);
+        setTimeout(() => DOMTools.scrollDown(ref.current, false), 100);
     }, [generating, open]);
 
     useEffect(() => {
diff --git a/src/games/ai/contexts/llm.tsx b/src/games/ai/contexts/llm.tsx
index 24760e6..9eb7294 100644
--- a/src/games/ai/contexts/llm.tsx
+++ b/src/games/ai/contexts/llm.tsx
@@ -38,11 +38,18 @@ const DEFAULT_GENERATION_SETTINGS = {
     top_k: 100,
     top_p: 0.92,
     banned_tokens: [],
-    max_length: 512,
+    max_length: 300,
     trim_stop: true,
-    stop_sequence: ['[INST]', '[/INST]', '</s>', '<|']
+    stop_sequence: ['[INST]', '[/INST]', '</s>', '<|'],
+    dry_allowed_length: 5,
+    dry_multiplier: 0.8,
+    dry_base: 1,
+    dry_sequence_breakers: ["\n", ":", "\"", "*"],
+    dry_penalty_last_n: 0
 }
 
+const MESSAGES_TO_KEEP = 10;
+
 type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
 
 interface IActions {
@@ -81,14 +88,18 @@ export const normalizeModel = (model: string) => {
 
 export const LLMContext = createContext({} as ILLMContext);
 
+const processing = {
+    tokenizing: false,
+    summarizing: false,
+}
+
 export const LLMContextProvider = ({ children }: { children?: any }) => {
     const {
-        connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct, summarizePrompt,
+        connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct, summarizePrompt, summaryEnabled,
         setTriggerNext, addMessage, editMessage, editSummary, setInstruct,
     } = useContext(StateContext);
 
     const generating = useBool(false);
-    const summarizing = useBool(false);
     const blockConnection = useBool(false);
     const [promptTokens, setPromptTokens] = useState(0);
     const [contextLength, setContextLength] = useState(0);
@@ -156,10 +167,8 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
         const lastUserMessage = userMessages.at(-1);
         const firstUserMessage = userMessages.at(0);
 
-        const system = `${systemPrompt}\n\n${lore}`.trim();
-
         const templateMessages: Huggingface.ITemplateMessage[] = [
-            { role: 'system', content: system },
+            { role: 'system', content: systemPrompt.trim() },
         ];
 
         if (keepUsers) {
@@ -168,7 +177,8 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
 
             for (const message of messages) {
                 const { role } = message;
-                const content = MessageTools.getSwipe(message)?.content ?? '';
+                const swipe = MessageTools.getSwipe(message);
+                let content = swipe?.content ?? '';
                 if (role === 'user' && usersRemaining > keepUsers) {
                     usersRemaining--;
                 } else if (role === 'assistant' && templateMessages.at(-1).role === 'assistant') {
@@ -188,7 +198,16 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
             }
         } else {
            const story = promptMessages.filter(m => m.role === 'assistant')
-                .map(m => MessageTools.getSwipe(m)?.content.trim()).join('\n\n');
+                .map((m, i, msgs) => {
+                    const swipe = MessageTools.getSwipe(m);
+                    if (!swipe) return '';
+
+                    let { content, summary } = swipe;
+                    if (summary && i < msgs.length - MESSAGES_TO_KEEP) {
+                        content = summary;
+                    }
+                    return content;
+                }).join('\n\n');
 
             if (story.length > 0) {
                 const prompt = MessageTools.getSwipe(firstUserMessage)?.content;
@@ -215,6 +234,8 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
             });
         }
 
+        templateMessages[1].content = `${lore}\n\n${templateMessages[1].content}`;
+
         const prompt = Huggingface.applyChatTemplate(instruct, templateMessages);
         return {
             prompt,
@@ -326,8 +347,6 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
 
             let text: string = '';
             const { prompt, isRegen } = await actions.compilePrompt(messages);
-            const tokens = await actions.countTokens(prompt);
-            setPromptTokens(tokens);
 
             if (!isRegen) {
                 addMessage('', 'assistant');
@@ -337,7 +356,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
             editSummary(messageId, 'Generating...');
             for await (const chunk of actions.generate(prompt)) {
                 text += chunk;
-                setPromptTokens(tokens + Math.round(text.length * 0.25));
+                setPromptTokens(promptTokens + Math.round(text.length * 0.25));
                 editMessage(messageId, text.trim());
             }
 
@@ -345,39 +364,37 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
             editMessage(messageId, text);
             editSummary(messageId, '');
 
-            setPromptTokens(0); // trigger calculation
-
             MessageTools.playReady();
         }
-    })(), [actions, triggerNext, messages, generating.value]);
+    })(), [triggerNext]);
 
     useEffect(() => void (async () => {
-        if (!generating.value && !summarizing.value) {
-            summarizing.setTrue();
-            for (let id = 0; id < messages.length; id++) {
-                const message = messages[id];
-                const swipe = MessageTools.getSwipe(message);
-                if (message.role === 'assistant' && swipe?.content?.includes('\n') && !swipe.summary) {
-                    const summary = await actions.summarize(swipe.content);
-                    editSummary(id, summary);
+        if (summaryEnabled && !generating.value && !processing.summarizing) {
+            try {
+                processing.summarizing = true;
+                for (let id = 0; id < messages.length; id++) {
+                    const message = messages[id];
+                    const swipe = MessageTools.getSwipe(message);
+                    if (message.role === 'assistant' && swipe?.content?.includes('\n') && !swipe.summary) {
+                        const summary = await actions.summarize(swipe.content);
+                        editSummary(id, summary);
+                    }
                 }
+            } catch (e) {
+                console.error(`Could not summarize`, e)
+            } finally {
+                processing.summarizing = false;
             }
-            summarizing.setFalse();
         }
-    })(), [messages, generating.value, summarizing.value]);
+    })(), [messages]);
 
     useEffect(() => {
         if (!blockConnection.value) {
             setPromptTokens(0);
             setContextLength(0);
+            setModelName('');
             getContextLength().then(setContextLength);
-        }
-    }, [connectionUrl, instruct, blockConnection.value]);
-
-    useEffect(() => {
-        if (!blockConnection.value) {
-            setModelName('');
             getModelName().then(normalizeModel).then(setModelName);
         }
     }, [connectionUrl, blockConnection.value]);
@@ -397,14 +414,24 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
         }
     }, [modelName]);
 
-    useEffect(() => {
-        if (promptTokens === 0 && !blockConnection.value) {
-            actions.compilePrompt(messages)
-                .then(({ prompt }) => actions.countTokens(prompt))
-                .then(setPromptTokens)
-                .catch(e => console.error(`Could not count tokens`, e));
+    const calculateTokens = useCallback(async () => {
+        if (!processing.tokenizing && !blockConnection.value && !generating.value) {
+            try {
+                processing.tokenizing = true;
+                const { prompt } = await actions.compilePrompt(messages);
+                const tokens = await actions.countTokens(prompt);
+                setPromptTokens(tokens);
+            } catch (e) {
+                console.error(`Could not count tokens`, e)
+            } finally {
+                processing.tokenizing = false;
+            }
         }
-    }, [actions, promptTokens, messages, blockConnection.value]);
+    }, [actions, messages, blockConnection.value]);
+
+    useEffect(() => {
+        calculateTokens();
+    }, [messages, connectionUrl, blockConnection.value, instruct, /* systemPrompt, lore, userPrompt TODO debounce*/]);
 
     useEffect(() => {
         try {
diff --git a/src/games/ai/contexts/state.tsx b/src/games/ai/contexts/state.tsx
index ac80911..958a8b7 100644
--- a/src/games/ai/contexts/state.tsx
+++ b/src/games/ai/contexts/state.tsx
@@ -11,6 +11,7 @@ interface IContext {
     lore: string;
     userPrompt: string;
     summarizePrompt: string;
+    summaryEnabled: boolean;
     bannedWords: string[];
     messages: IMessage[];
     triggerNext: boolean;
@@ -26,6 +27,7 @@ interface IActions {
     setSummarizePrompt: (prompt: string | Event) => void;
     setBannedWords: (words: string[]) => void;
     setTriggerNext: (triggerNext: boolean) => void;
+    setSummaryEnabled: (summaryEnabled: boolean) => void;
 
     setMessages: (messages: IMessage[]) => void;
     addMessage: (content: string, role: IMessage['role'], triggerNext?: boolean) => void;
@@ -64,11 +66,21 @@ export const loadContext = (): IContext => {
         connectionUrl: 'http://localhost:5001',
         input: '',
         instruct: Instruct.CHATML,
-        systemPrompt: 'You are creative writer. Write a story based on the world description below.',
+        systemPrompt: 'You are a creative writer. Write a story based on the world description below. Story should be adult and mature; and could include swearing, violence and unfairness. Portray characters realistically and stay in the lore.',
         lore: '',
-        userPrompt: `{% if prompt %}{% if isStart %}Start{% else %}Continue{% endif %} this story, taking information into account: {{ prompt | trim }}
-Remember that this story should be infinite and go forever. Avoid cliffhangers and pauses, be creative.{% elif isStart %}Write a novel using information above as a reference. Make sure to follow the lore exactly and avoid cliffhangers.{% else %}Continue the story forward. Avoid cliffhangers and pauses.{% endif %}`,
-        summarizePrompt: 'Shrink following text down to one paragraph, keeping all important details, your answer should contain only the shortened text:\n\n{{ message }}',
+        userPrompt: `{% if isStart -%}
+    Write a novel using information above as a reference.
+{%- else -%}
+    Continue the story forward.
+{%- endif %}
+
+{% if prompt -%}
+    This is the description of what I want to happen next: {{ prompt | trim }}
+{% endif %}
+Remember that this story should be infinite and go forever.
+Make sure to follow the world description and rules exactly. Avoid cliffhangers and pauses, be creative.`,
+        summarizePrompt: 'Shrink following text down, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
+        summaryEnabled: false,
         bannedWords: [],
         messages: [],
         triggerNext: false,
@@ -101,6 +113,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
     const [summarizePrompt, setSummarizePrompt] = useInputState(loadedContext.summarizePrompt);
     const [bannedWords, setBannedWords] = useState(loadedContext.bannedWords);
     const [messages, setMessages] = useState(loadedContext.messages);
+    const [summaryEnabled, setSummaryEnabled] = useState(loadedContext.summaryEnabled);
 
     const [triggerNext, setTriggerNext] = useState(false);
 
@@ -113,6 +126,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
         setSummarizePrompt,
         setLore,
         setTriggerNext,
+        setSummaryEnabled,
 
         setBannedWords: (words) => setBannedWords([...words]),
         setMessages: (newMessages) => setMessages(newMessages.slice()),
@@ -192,6 +206,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
         lore,
         userPrompt,
         summarizePrompt,
+        summaryEnabled,
         bannedWords,
         messages,
         triggerNext,