AI story: add summary api, add "stop here" button
This commit is contained in:
parent
b95506a095
commit
cc43e035fe
|
|
@ -13,12 +13,13 @@ import { Ace } from "../ace";
|
|||
export const Header = () => {
|
||||
const { modelName, modelTemplate, contextLength, promptTokens, blockConnection } = useContext(LLMContext);
|
||||
const {
|
||||
messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct,
|
||||
setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct
|
||||
messages, connectionUrl, systemPrompt, lore, userPrompt, bannedWords, instruct, summarizePrompt,
|
||||
setConnectionUrl, setSystemPrompt, setLore, setUserPrompt, addSwipe, setBannedWords, setInstruct, setSummarizePrompt,
|
||||
} = useContext(StateContext);
|
||||
|
||||
const loreOpen = useBool();
|
||||
const promptsOpen = useBool();
|
||||
const genparamsOpen = useBool();
|
||||
const assistantOpen = useBool();
|
||||
|
||||
const bannedWordsInput = useMemo(() => bannedWords.join('\n'), [bannedWords]);
|
||||
|
|
@ -83,6 +84,9 @@ export const Header = () => {
|
|||
<button class='icon color' title='Edit lore' onClick={loreOpen.setTrue}>
|
||||
🌍
|
||||
</button>
|
||||
<button class='icon color' title='Generation parameters' onClick={genparamsOpen.setTrue}>
|
||||
⚙
|
||||
</button>
|
||||
<button class='icon color' title='Edit prompts' onClick={promptsOpen.setTrue}>
|
||||
📃
|
||||
</button>
|
||||
|
|
@ -100,6 +104,19 @@ export const Header = () => {
|
|||
placeholder="Describe your world, for example: World of Awoo has big mountains and wide rivers."
|
||||
/>
|
||||
</Modal>
|
||||
<Modal open={genparamsOpen.value} onClose={genparamsOpen.setFalse}>
|
||||
<h3 class={styles.modalTitle}>Generation Parameters</h3>
|
||||
<div className={styles.scrollPane}>
|
||||
<h4 class={styles.modalTitle}>Banned phrases</h4>
|
||||
<AutoTextarea
|
||||
placeholder="Each phrase on separate line"
|
||||
value={bannedWordsInput}
|
||||
onInput={handleSetBannedWords}
|
||||
onBlur={handleBlurBannedWords}
|
||||
class={styles.template}
|
||||
/>
|
||||
</div>
|
||||
</Modal>
|
||||
<Modal open={promptsOpen.value} onClose={promptsOpen.setFalse}>
|
||||
<h3 class={styles.modalTitle}>Prompts Editor</h3>
|
||||
<div className={styles.scrollPane}>
|
||||
|
|
@ -109,17 +126,11 @@ export const Header = () => {
|
|||
<h4 class={styles.modalTitle}>User prompt template</h4>
|
||||
<Ace value={userPrompt} onInput={setUserPrompt} />
|
||||
<hr />
|
||||
<h4 class={styles.modalTitle}>Summary template</h4>
|
||||
<Ace value={summarizePrompt} onInput={setSummarizePrompt} />
|
||||
<hr />
|
||||
<h4 class={styles.modalTitle}>Instruct template</h4>
|
||||
<Ace value={instruct} onInput={setInstruct} />
|
||||
<hr />
|
||||
<h4 class={styles.modalTitle}>Banned phrases</h4>
|
||||
<AutoTextarea
|
||||
placeholder="Each phrase on separate line"
|
||||
value={bannedWordsInput}
|
||||
onInput={handleSetBannedWords}
|
||||
onBlur={handleBlurBannedWords}
|
||||
class={styles.template}
|
||||
/>
|
||||
</div>
|
||||
</Modal>
|
||||
<MiniChat
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ interface IProps {
|
|||
}
|
||||
|
||||
export const Message = ({ message, index, isLastUser, isLastAssistant }: IProps) => {
|
||||
const { editMessage, deleteMessage, setCurrentSwipe, addSwipe } = useContext(StateContext);
|
||||
const { messages, editMessage, deleteMessage, setCurrentSwipe, setMessages } = useContext(StateContext);
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [savedMessage, setSavedMessage] = useState('');
|
||||
const textRef = useRef<HTMLDivElement>(null);
|
||||
|
|
@ -41,6 +41,13 @@ export const Message = ({ message, index, isLastUser, isLastAssistant }: IProps)
|
|||
}
|
||||
}, [deleteMessage, index]);
|
||||
|
||||
const handleStopHere = useCallback(() => {
|
||||
if (confirm('Delete all messages after that?')) {
|
||||
setMessages(messages.filter((_, i) => i <= index));
|
||||
setEditing(false);
|
||||
}
|
||||
}, [messages, setMessages, index]);
|
||||
|
||||
const handleEdit = useCallback((e: InputEvent) => {
|
||||
if (e.target instanceof HTMLTextAreaElement) {
|
||||
const newContent = e.target.value;
|
||||
|
|
@ -71,6 +78,7 @@ export const Message = ({ message, index, isLastUser, isLastAssistant }: IProps)
|
|||
? <>
|
||||
<button class='icon' onClick={handleToggleEdit}>✔</button>
|
||||
<button class='icon' onClick={handleDeleteMessage}>🗑️</button>
|
||||
<button class='icon' onClick={handleStopHere} title='Stop here'>⛔</button>
|
||||
<button class='icon' onClick={handleCancelEdit}>❌</button>
|
||||
</>
|
||||
: <>
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import { Huggingface } from "../huggingface";
|
|||
|
||||
interface ICompileArgs {
|
||||
keepUsers?: number;
|
||||
raw?: boolean;
|
||||
}
|
||||
|
||||
interface ICompiledPrompt {
|
||||
|
|
@ -47,6 +48,7 @@ type IGenerationSettings = Partial<typeof DEFAULT_GENERATION_SETTINGS>;
|
|||
interface IActions {
|
||||
compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise<ICompiledPrompt>;
|
||||
generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator<string>;
|
||||
summarize: (content: string) => Promise<string>;
|
||||
countTokens: (prompt: string) => Promise<number>;
|
||||
}
|
||||
export type ILLMContext = IContext & IActions;
|
||||
|
|
@ -81,7 +83,7 @@ export const LLMContext = createContext<ILLMContext>({} as ILLMContext);
|
|||
|
||||
export const LLMContextProvider = ({ children }: { children?: any }) => {
|
||||
const {
|
||||
connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct,
|
||||
connectionUrl, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, instruct, summarizePrompt,
|
||||
setTriggerNext, addMessage, editMessage, setInstruct,
|
||||
} = useContext(StateContext);
|
||||
|
||||
|
|
@ -285,6 +287,14 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
|||
generating.setFalse();
|
||||
}
|
||||
},
|
||||
summarize: async (message) => {
|
||||
const content = Huggingface.applyTemplate(summarizePrompt, { message });
|
||||
const prompt = Huggingface.applyChatTemplate(instruct, [{ role: 'user', content }]);
|
||||
|
||||
const tokens = await Array.fromAsync(actions.generate(prompt));
|
||||
|
||||
return tokens.join('');
|
||||
},
|
||||
countTokens: async (prompt) => {
|
||||
if (!connectionUrl) {
|
||||
return 0;
|
||||
|
|
@ -305,7 +315,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
|||
|
||||
return 0;
|
||||
},
|
||||
}), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct]);
|
||||
}), [connectionUrl, lore, userPromptTemplate, systemPrompt, bannedWords, instruct, summarizePrompt]);
|
||||
|
||||
useEffect(() => void (async () => {
|
||||
if (triggerNext && !generating.value) {
|
||||
|
|
@ -326,11 +336,11 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
|
|||
for await (const chunk of actions.generate(prompt)) {
|
||||
text += chunk;
|
||||
setPromptTokens(tokens + Math.round(text.length * 0.25));
|
||||
editMessage(messageId, text);
|
||||
editMessage(messageId, text.trim());
|
||||
}
|
||||
|
||||
text = MessageTools.trimSentence(text);
|
||||
editMessage(messageId, text);
|
||||
editMessage(messageId, text.trim());
|
||||
|
||||
setPromptTokens(0); // trigger calculation
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ interface IContext {
|
|||
systemPrompt: string;
|
||||
lore: string;
|
||||
userPrompt: string;
|
||||
summarizePrompt: string;
|
||||
bannedWords: string[];
|
||||
messages: IMessage[];
|
||||
triggerNext: boolean;
|
||||
|
|
@ -22,6 +23,7 @@ interface IActions {
|
|||
setLore: (lore: string | Event) => void;
|
||||
setSystemPrompt: (prompt: string | Event) => void;
|
||||
setUserPrompt: (prompt: string | Event) => void;
|
||||
setSummarizePrompt: (prompt: string | Event) => void;
|
||||
setBannedWords: (words: string[]) => void;
|
||||
setTriggerNext: (triggerNext: boolean) => void;
|
||||
|
||||
|
|
@ -63,6 +65,7 @@ export const loadContext = (): IContext => {
|
|||
lore: '',
|
||||
userPrompt: `{% if prompt %}{% if isStart %}Start{% else %}Continue{% endif %} this story, taking information into account: {{ prompt | trim }}
|
||||
Remember that this story should be infinite and go forever. Avoid cliffhangers and pauses, be creative.{% elif isStart %}Write a novel using information above as a reference. Make sure to follow the lore exactly and avoid cliffhangers.{% else %}Continue the story forward. Avoid cliffhangers and pauses.{% endif %}`,
|
||||
summarizePrompt: 'Make the following text shorter, keeping all important details:\n\n{{ message }}\n\nYour answer should only contain the shortened text.',
|
||||
bannedWords: [],
|
||||
messages: [],
|
||||
triggerNext: false,
|
||||
|
|
@ -92,6 +95,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
|||
const [lore, setLore] = useInputState(loadedContext.lore);
|
||||
const [systemPrompt, setSystemPrompt] = useInputState(loadedContext.systemPrompt);
|
||||
const [userPrompt, setUserPrompt] = useInputState(loadedContext.userPrompt);
|
||||
const [summarizePrompt, setSummarizePrompt] = useInputState(loadedContext.summarizePrompt);
|
||||
const [bannedWords, setBannedWords] = useState<string[]>(loadedContext.bannedWords);
|
||||
const [messages, setMessages] = useState(loadedContext.messages);
|
||||
|
||||
|
|
@ -103,6 +107,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
|||
setInstruct,
|
||||
setSystemPrompt,
|
||||
setUserPrompt,
|
||||
setSummarizePrompt,
|
||||
setLore,
|
||||
setTriggerNext,
|
||||
setBannedWords: (words) => setBannedWords([...words]),
|
||||
|
|
@ -181,6 +186,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
|
|||
systemPrompt,
|
||||
lore,
|
||||
userPrompt,
|
||||
summarizePrompt,
|
||||
bannedWords,
|
||||
messages,
|
||||
triggerNext,
|
||||
|
|
|
|||
|
|
@ -79,6 +79,7 @@ export namespace Huggingface {
|
|||
};
|
||||
|
||||
const templateCache: Record<string, string> = loadCache();
|
||||
const compiledTemplates = new Map<string, Template>();
|
||||
|
||||
const hasField = <T extends string>(obj: unknown, field: T): obj is Record<T, unknown> => (
|
||||
obj != null && typeof obj === 'object' && (field in obj)
|
||||
|
|
@ -256,15 +257,29 @@ export namespace Huggingface {
|
|||
return template;
|
||||
}
|
||||
|
||||
export const applyChatTemplate = (templateString: string, messages: ITemplateMessage[], functions?: IFunction[]) => {
|
||||
const template = new Template(templateString);
|
||||
|
||||
const prompt = template.render({
|
||||
export const applyChatTemplate = (templateString: string, messages: ITemplateMessage[], functions?: IFunction[]) => (
|
||||
applyTemplate(templateString, {
|
||||
messages,
|
||||
add_generation_prompt: true,
|
||||
tools: functions?.map(convertFunctionToTool),
|
||||
});
|
||||
})
|
||||
);
|
||||
|
||||
return prompt;
|
||||
};
|
||||
export const applyTemplate = (templateString: string, args: Record<string, any>): string => {
|
||||
try {
|
||||
let template = compiledTemplates.get(templateString);
|
||||
if (!template) {
|
||||
template = new Template(templateString);
|
||||
compiledTemplates.set(templateString, template);
|
||||
}
|
||||
|
||||
const result = template.render(args);
|
||||
|
||||
return result;
|
||||
} catch (e) {
|
||||
console.error('[applyTemplate] error:', e);
|
||||
}
|
||||
|
||||
return '';
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue