diff --git a/src/games/ai/components/header/header.tsx b/src/games/ai/components/header/header.tsx index 59364da..445e174 100644 --- a/src/games/ai/components/header/header.tsx +++ b/src/games/ai/components/header/header.tsx @@ -13,12 +13,13 @@ import { Ace } from "../ace"; export const Header = () => { const { modelName, modelTemplate, contextLength, promptTokens, blockConnection } = useContext(LLMContext); const { - messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct, - setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct + messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct, summarizePrompt, + setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct, setSummarizePrompt, } = useContext(StateContext); const loreOpen = useBool(); const promptsOpen = useBool(); + const genparamsOpen = useBool(); const assistantOpen = useBool(); const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]); @@ -83,6 +84,9 @@ export const Header = () => { + @@ -100,6 +104,19 @@ export const Header = () => { placeholder="Describe your world, for example: World of Awoo has big mountains and wide rivers." /> + +

Generation Parameters

+
+

Banned phrases

+ +
+

Prompts Editor

@@ -109,17 +126,11 @@ export const Header = () => {

User prompt template


+

Summary template

+ +

Instruct template

-
-

Banned phrases

-
{ - const { editMessage, deleteMessage, setCurrentSwipe, addSwipe } = useContext(StateContext); + const { messages, editMessage, deleteMessage, setCurrentSwipe, setMessages } = useContext(StateContext); const [editing, setEditing] = useState(false); const [savedMessage, setSavedMessage] = useState(''); const textRef = useRef(null); @@ -41,6 +41,13 @@ export const Message = ({ message, index, isLastUser, isLastAssistant }: IProps) } }, [deleteMessage, index]); + const handleStopHere = useCallback(() => { + if (confirm('Delete all messages after that?')) { + setMessages(messages.filter((_, i) => i <= index)); + setEditing(false); + } + }, [messages, setMessages, index]); + const handleEdit = useCallback((e: InputEvent) => { if (e.target instanceof HTMLTextAreaElement) { const newContent = e.target.value; @@ -71,6 +78,7 @@ export const Message = ({ message, index, isLastUser, isLastAssistant }: IProps) ? <> + : <> diff --git a/src/games/ai/contexts/llm.tsx b/src/games/ai/contexts/llm.tsx index eb9c92e..4c41200 100644 --- a/src/games/ai/contexts/llm.tsx +++ b/src/games/ai/contexts/llm.tsx @@ -10,6 +10,7 @@ import { Huggingface } from "../huggingface"; interface ICompileArgs { keepUsers?: number; + raw?: boolean; } interface ICompiledPrompt { @@ -47,6 +48,7 @@ type IGenerationSettings = Partial; interface IActions { compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise; generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator; + summarize: (content: string) => Promise; countTokens: (prompt: string) => Promise; } export type ILLMContext = IContext & IActions; @@ -81,7 +83,7 @@ export const LLMContext = createContext({} as ILLMContext); export const LLMContextProvider = ({ children }: { children?: any }) => { const { - connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct, + connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct, summarizePrompt, setTriggerNext, addMessage, editMessage, setInstruct, } = useContext(StateContext); @@ -285,6 +287,14 @@ export const LLMContextProvider = ({ children }: { children?: any }) => { generating.setFalse(); } }, + summarize: async (message) => { + const content = Huggingface.applyTemplate(summarizePrompt, { message }); + const prompt = Huggingface.applyChatTemplate(instruct, [{ role: 'user', content }]); + + const tokens = await Array.fromAsync(actions.generate(prompt)); + + return tokens.join(''); + }, countTokens: async (prompt) => { if (!connectionUrl) { return 0; @@ -305,7 +315,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => { return 0; }, - }), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct]); + }), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct, summarizePrompt]); useEffect(() => void (async () => { if (triggerNext && !generating.value) { @@ -326,11 +336,11 @@ export const LLMContextProvider = ({ children }: { children?: any }) => { for await (const chunk of actions.generate(prompt)) { text += chunk; setPromptTokens(tokens + Math.round(text.length * 0.25)); - editMessage(messageId, text); + editMessage(messageId, text.trim()); } text = MessageTools.trimSentence(text); - editMessage(messageId, text); + editMessage(messageId, text.trim()); setPromptTokens(0); // trigger calculation diff --git a/src/games/ai/contexts/state.tsx b/src/games/ai/contexts/state.tsx index dd5235a..b69c96d 100644 --- a/src/games/ai/contexts/state.tsx +++ b/src/games/ai/contexts/state.tsx @@ -10,6 +10,7 @@ interface IContext { systemPrompt: string; lore: string; userPrompt: string; + summarizePrompt: string; bannedWords: string[]; messages: IMessage[]; triggerNext: boolean; @@ -22,6 +23,7 @@ interface IActions { setLore: (lore: string | Event) => void; setSystemPrompt: (prompt: string | Event) => void; setUserPrompt: (prompt: string | Event) => void; + setSummarizePrompt: (prompt: string | Event) => void; setBannedWords: (words: string[]) => void; setTriggerNext: (triggerNext: boolean) => void; @@ -63,6 +65,7 @@ export const loadContext = (): IContext => { lore: '', userPrompt: `{% if prompt %}{% if isStart %}Start{% else %}Continue{% endif %} this story, taking information into account: {{ prompt | trim }} Remember that this story should be infinite and go forever. Avoid cliffhangers and pauses, be creative.{% elif isStart %}Write a novel using information above as a reference. Make sure to follow the lore exactly and avoid cliffhangers.{% else %}Continue the story forward. Avoid cliffhangers and pauses.{% endif %}`, + summarizePrompt: 'Make the following text shorter, keeping all important details:\n\n{{ message }}\n\nYour answer should only contain the shortened text.', bannedWords: [], messages: [], triggerNext: false, @@ -92,6 +95,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => { const [lore, setLore] = useInputState(loadedContext.lore); const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt); const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt); + const [summarizePrompt, setSummarizePrompt] = useInputState(loadedContext.summarizePrompt); const [bannedWords, setBannedWords] = useState(loadedContext.bannedWords); const [messages, setMessages] = useState(loadedContext.messages); @@ -103,6 +107,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => { setInstruct, setSystemPrompt, setUserPrompt, + setSummarizePrompt, setLore, setTriggerNext, setBannedWords: (words) => setBannedWords([...words]), @@ -181,6 +186,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => { systemPrompt, lore, userPrompt, + summarizePrompt, bannedWords, messages, triggerNext, diff --git a/src/games/ai/huggingface.ts b/src/games/ai/huggingface.ts index b1d71bd..d99cb1a 100644 --- a/src/games/ai/huggingface.ts +++ b/src/games/ai/huggingface.ts @@ -79,6 +79,7 @@ export namespace Huggingface { }; const templateCache: Record = loadCache(); + const compiledTemplates = new Map(); const hasField = (obj: unknown, field: T): obj is Record => ( obj != null && typeof obj === 'object' && (field in obj) @@ -256,15 +257,29 @@ export namespace Huggingface { return template; } - export const applyChatTemplate = (templateString: string, messages: ITemplateMessage[], functions?: IFunction[]) => { - const template = new Template(templateString); - - const prompt = template.render({ + export const applyChatTemplate = (templateString: string, messages: ITemplateMessage[], functions?: IFunction[]) => ( + applyTemplate(templateString, { messages, add_generation_prompt: true, tools: functions?.map(convertFunctionToTool), - }); + }) + ); - return prompt; - }; + export const applyTemplate = (templateString: string, args: Record): string => { + try { + let template = compiledTemplates.get(templateString); + if (!template) { + template = new Template(templateString); + compiledTemplates.set(templateString, template); + } + + const result = template.render(args); + + return result; + } catch (e) { + console.error('[applyTemplate] error:', e); + } + + return ''; + } } \ No newline at end of file