diff --git a/src/games/ai-story/components/message/message.tsx b/src/games/ai-story/components/message/message.tsx
index dc6d146..1e4f957 100644
--- a/src/games/ai-story/components/message/message.tsx
+++ b/src/games/ai-story/components/message/message.tsx
@@ -24,6 +24,7 @@ export const Message = ({ message, index, isLastUser, isLastAssistant }: IProps)
const content = swipe?.content;
const summary = swipe?.summary;
+ const cost = swipe?.cost ?? 0;
const htmlContent = useMemo(() => MessageTools.format(content ?? ''), [content]);
const handleEnableEdit = useCallback(() => {
@@ -32,10 +33,10 @@ export const Message = ({ message, index, isLastUser, isLastAssistant }: IProps)
}, [content]);
const handleSaveEdit = useCallback(() => {
- editMessage(index, editedMessage.trim());
- editSummary(index, '');
+ editMessage(index, editedMessage.trim(), cost);
+ editSummary(index, '', 0);
setEditing(false);
- }, [editMessage, editSummary, index, editedMessage]);
+ }, [editMessage, editSummary, index, editedMessage, cost]);
const handleCancelEdit = useCallback(() => {
setEditing(false);
@@ -77,6 +78,7 @@ export const Message = ({ message, index, isLastUser, isLastAssistant }: IProps)
: <>
{summary &&
{summary}}
+ {cost > 0 &&
💲 {cost}}
>
}
diff --git a/src/games/ai-story/contexts/llm.tsx b/src/games/ai-story/contexts/llm.tsx
index 08576a3..942ef66 100644
--- a/src/games/ai-story/contexts/llm.tsx
+++ b/src/games/ai-story/contexts/llm.tsx
@@ -33,9 +33,9 @@ const MESSAGES_TO_KEEP = 10;
interface IActions {
compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise;
- generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator;
+ generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator;
stopGeneration: () => void;
- summarize: (content: string) => Promise;
+ summarize: (content: string) => Promise;
countTokens: (prompt: string) => Promise;
}
export type ILLMContext = IContext & IActions;
@@ -168,18 +168,17 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
isRegen,
};
},
- generate: async function* (prompt, extraSettings = {}) {
+ generate: async function* (prompt, extraSettings = {}): AsyncGenerator {
try {
console.log('[LLM.generate]', prompt);
- setSpentKudos(0);
for await (const { text, cost } of Connection.generate(connection, prompt, {
...extraSettings,
banned_tokens: bannedWords.filter(w => w.trim()),
})) {
setSpentKudos(sk => sk + cost);
setTotalSpentKudos(sk => sk + cost);
- yield text;
+ yield { text, cost };
}
} catch (e) {
if (e instanceof Error && e.name !== 'AbortError') {
@@ -189,7 +188,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
}
}
},
- summarize: async (message) => {
+ summarize: async (message) => {
try {
const content = Huggingface.applyTemplate(summarizePrompt, { message });
const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
@@ -204,10 +203,13 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
setSpentKudos(sk => sk + summary.cost);
setTotalSpentKudos(sk => sk + summary.cost);
- return MessageTools.trimSentence(summary.text);
+ return {
+ text: MessageTools.trimSentence(summary.text),
+ cost: summary.cost,
+ };
} catch (e) {
console.error('Error summarizing:', e);
- return '';
+ return { text: '', cost: 0 };
}
},
countTokens: async (prompt) => {
@@ -225,6 +227,7 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
let messageId = messages.length - 1;
let text = '';
+ let cost = 0;
const { prompt, isRegen } = await actions.compilePrompt(messages, { continueLast });
@@ -236,17 +239,18 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
}
generating.setTrue();
- editSummary(messageId, 'Generating...');
+ editSummary(messageId, 'Generating...', 0);
for await (const chunk of actions.generate(prompt)) {
- text += chunk;
+ text += chunk.text;
+ cost += chunk.cost;
setPromptTokens(promptTokens + approximateTokens(text));
- editMessage(messageId, text.trim());
+ editMessage(messageId, text.trim(), cost);
}
generating.setFalse();
text = MessageTools.trimSentence(text);
- editMessage(messageId, text);
- editSummary(messageId, '');
+ editMessage(messageId, text, cost);
+ editSummary(messageId, '', 0);
MessageTools.playReady();
}
@@ -260,8 +264,8 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
const message = messages[id];
const swipe = MessageTools.getSwipe(message);
if (message.role === 'assistant' && swipe?.content?.includes('\n') && !swipe.summary) {
- const summary = await actions.summarize(swipe.content);
- editSummary(id, summary);
+ const { text, cost } = await actions.summarize(swipe.content);
+ editSummary(id, text, cost);
}
}
} catch (e) {
diff --git a/src/games/ai-story/contexts/state.tsx b/src/games/ai-story/contexts/state.tsx
index fe4a077..0138484 100644
--- a/src/games/ai-story/contexts/state.tsx
+++ b/src/games/ai-story/contexts/state.tsx
@@ -80,8 +80,8 @@ interface IActions {
setMessages: (messages: IMessage[]) => void;
addMessage: (content: string, role: IMessage['role'], triggerNext?: boolean) => void;
- editMessage: (index: number, content: string) => void;
- editSummary: (index: number, summary: string) => void;
+ editMessage: (index: number, content: string, cost: number) => void;
+ editSummary: (index: number, summary: string, cost: number) => void;
deleteMessage: (index: number) => void;
setCurrentSwipe: (index: number, swipe: number) => void;
addSwipe: (index: number, content: string) => void;
@@ -270,11 +270,11 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
]);
setTriggerNext(triggerNext);
},
- editMessage: (index, content) => {
- setMessages(messages => MessageTools.updateSwipe(messages, index, { content }));
+ editMessage: (index, content, cost) => {
+ setMessages(messages => MessageTools.updateSwipe(messages, index, { content }, cost));
},
- editSummary: (index, summary) => {
- setMessages(messages => MessageTools.updateSwipe(messages, index, { summary }));
+ editSummary: (index, summary, cost) => {
+ setMessages(messages => MessageTools.updateSwipe(messages, index, { summary }, cost));
},
deleteMessage: (index) => setMessages(messages =>
messages.filter((_, i) => i !== index)
@@ -291,7 +291,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
if (currentSwipe >= swipes.length) {
if (latestSwipe.content.length > 0) {
currentSwipe = swipes.length;
- swipes.push({ content: '' });
+ swipes.push({ content: '', cost: 0 });
} else {
currentSwipe = swipes.length - 1;
}
@@ -315,7 +315,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => {
messages.map(
(message, i) => {
if (i === index) {
- const swipes = [...message.swipes, { content }];
+ const swipes = [...message.swipes, { content, cost: 0 }];
return {
...message,
diff --git a/src/games/ai-story/tools/messages.ts b/src/games/ai-story/tools/messages.ts
index 6d15636..1e482d2 100644
--- a/src/games/ai-story/tools/messages.ts
+++ b/src/games/ai-story/tools/messages.ts
@@ -3,6 +3,7 @@ import messageSound from '../assets/message.mp3';
export interface ISwipe {
content: string;
summary?: string;
+ cost: number;
}
export interface IMessage {
@@ -14,8 +15,8 @@ export interface IMessage {
export namespace MessageTools {
export const getSwipe = (message?: IMessage | null) => message?.swipes[message?.currentSwipe];
- export const create = (content: string, role: IMessage['role'] = 'user', technical = false): IMessage => (
- { role, currentSwipe: 0, swipes: [{ content }], technical }
+ export const create = (content: string, role: IMessage['role'] = 'user', technical = false, cost = 0): IMessage => (
+ { role, currentSwipe: 0, swipes: [{ content, cost }], technical }
);
export const playReady = () => {
@@ -97,12 +98,12 @@ export namespace MessageTools {
return text.trim();
}
- export const updateSwipe = (messages: IMessage[], index: number, update: Partial) => (
+ export const updateSwipe = (messages: IMessage[], index: number, update: Partial, cost = 0) => (
messages.map(
(m, i) => ({
...m,
swipes: i === index
- ? m.swipes.map((s, si) => (si === m.currentSwipe ? { ...s, ...update } : s))
+ ? m.swipes.map((s, si) => (si === m.currentSwipe ? { ...s, ...update, cost: s.cost + cost } : s))
: m.swipes
})
)