From 277b3157953be21ca1dcdf0c760292e18965cc96 Mon Sep 17 00:00:00 2001
From: Pabloader
Date: Tue, 12 Nov 2024 16:32:52 +0000
Subject: [PATCH] AIStory: stopping

---
 bun.lockb                                     | Bin 31409 -> 31747 bytes
 package.json                                  |   1 +
 src/common/hooks/useAsyncEffect.ts            |   4 +
 src/common/utils.ts                           |  15 ++-
 src/games/ai/components/input.tsx             |   7 +-
 src/games/ai/components/minichat/minichat.tsx |   9 +-
 src/games/ai/connection.ts                    |  76 ++++++++-----
 src/games/ai/contexts/llm.tsx                 | 106 +++++++-----------
 src/games/ai/contexts/state.tsx               |   2 +-
 src/games/ai/huggingface.ts                   |   6 +-
 10 files changed, 122 insertions(+), 104 deletions(-)
 create mode 100644 src/common/hooks/useAsyncEffect.ts

diff --git a/bun.lockb b/bun.lockb
index d9cef444a408f13ab3f110010778c6c78fb73765..40e1e247ab609eaca320fbae79aa4ad1dc8d3299 100755
GIT binary patch
[base85 binary data omitted: delta 5143 / delta 4936]
diff --git a/src/common/hooks/useAsyncEffect.ts b/src/common/hooks/useAsyncEffect.ts
new file mode 100644
--- /dev/null
+++ b/src/common/hooks/useAsyncEffect.ts
@@ -0,0 +1,4 @@
+import { useEffect } from "preact/hooks";
+
+export const useAsyncEffect = (fx: () => any, deps: any[]) =>
+    useEffect(() => void fx(), deps);
diff --git a/src/common/utils.ts b/src/common/utils.ts
index 3d00c8b..1deed16 100644
--- a/src/common/utils.ts
+++ b/src/common/utils.ts
@@ -1,4 +1,3 @@
-export const delay = async (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
 export const nextFrame = async (): Promise => new Promise((resolve) => requestAnimationFrame(resolve));
 
 export const randInt = (min: number, max: number) => Math.round(min + (max - min - 1) * Math.random());
@@ -50,20 +49,30 @@ export const intHash = (seed: number, ...parts: number[]) => {
     return h1;
 };
 
 export const sinHash = (...data: number[]) => data.reduce((hash, n) => Math.sin((hash * 123.12 + n) * 756.12), 0) / 2 + 0.5;
 
-export const throttle = function <T, A extends any[], R, F extends (this: T, ...args: A) => R>(func: F, ms: number): F {
+export const throttle = function <T, A extends any[], R, F extends (this: T, ...args: A) => R>(func: F, ms: number, trailing = false): F {
     let isThrottled = false;
     let savedResult: R;
+    let savedThis: T;
+    let savedArgs: A | undefined;
 
     const wrapper: F = function (...args: A) {
-        if (!isThrottled) {
+        if (isThrottled) {
+            savedThis = this;
+            savedArgs = args;
+        } else {
             savedResult = func.apply(this, args);
+            savedArgs = undefined;
             isThrottled = true;
             setTimeout(function () {
                 isThrottled = false;
+                if (trailing && savedArgs) {
+                    savedResult = wrapper.apply(savedThis, savedArgs);
+                }
             }, ms);
         }
 
+        return savedResult;
     } as F;
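A quick usage sketch of the new `trailing` flag on `throttle` (timings and log messages are illustrative only, not part of the patch):

```ts
import { throttle } from "@common/utils";

const log = throttle((msg: string) => console.log(msg), 1000);               // leading edge only
const logTrailing = throttle((msg: string) => console.log(msg), 1000, true); // leading + trailing edge

log("a");          // logs "a" immediately
log("b");          // swallowed: still inside the 1s window and trailing is off
logTrailing("x");  // logs "x" immediately
logTrailing("y");  // remembered, replayed as "y" when the 1s window closes
```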
diff --git a/src/games/ai/components/input.tsx b/src/games/ai/components/input.tsx
index df02a3c..70beb82 100644
--- a/src/games/ai/components/input.tsx
+++ b/src/games/ai/components/input.tsx
@@ -5,7 +5,7 @@ import { AutoTextarea } from "./autoTextarea";
 export const Input = () => {
     const { input, setInput, addMessage, continueMessage } = useContext(StateContext);
-    const { generating } = useContext(LLMContext);
+    const { generating, stopGeneration } = useContext(LLMContext);
 
     const handleSend = useCallback(async () => {
         if (!generating) {
@@ -29,7 +29,10 @@ export const Input = () => {
     return (
-
+            {generating
+                ?
+                :
+            }
     );
 }
\ No newline at end of file
diff --git a/src/games/ai/components/minichat/minichat.tsx b/src/games/ai/components/minichat/minichat.tsx
index d096057..7e5b59f 100644
--- a/src/games/ai/components/minichat/minichat.tsx
+++ b/src/games/ai/components/minichat/minichat.tsx
@@ -16,7 +16,7 @@ interface IProps {
 }
 
 export const MiniChat = ({ history = [], buttons = {}, open, onClose }: IProps) => {
-    const { generating, generate, compilePrompt } = useContext(LLMContext);
+    const { generating, stopGeneration, generate, compilePrompt } = useContext(LLMContext);
 
     const [messages, setMessages] = useState([]);
     const ref = useRef(null);
@@ -105,9 +105,10 @@
-
+                    {generating
+                        ?
+                        :
+                    }
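The `<Button>` markup inside the two JSX hunks above is not visible here, so this is only a rough sketch of the toggle the surrounding code implies: while `generating` is true the action button stops generation, otherwise it sends. The plain `<button>`, the labels, and the import path are assumptions, not the repo's actual markup:

```tsx
import { useContext } from "preact/hooks";
import { LLMContext } from "../contexts/llm"; // assumed path

// Hypothetical stand-in for the stripped <Button> markup in input.tsx / minichat.tsx.
export const SendOrStop = ({ onSend }: { onSend: () => void }) => {
    const { generating, stopGeneration } = useContext(LLMContext);

    return generating
        ? <button onClick={stopGeneration}>Stop</button>
        : <button onClick={onSend}>Send</button>;
};
```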
diff --git a/src/games/ai/connection.ts b/src/games/ai/connection.ts
index 0f6daa5..dd111c8 100644
--- a/src/games/ai/connection.ts
+++ b/src/games/ai/connection.ts
@@ -1,6 +1,7 @@
 import Lock from "@common/lock";
 import SSE from "@common/sse";
-import { delay, throttle } from "@common/utils";
+import { throttle } from "@common/utils";
+import delay, { clearDelay } from "delay";
 
 interface IBaseConnection {
     instruct: string;
@@ -72,7 +73,7 @@ const DEFAULT_GENERATION_SETTINGS = {
     dry_penalty_last_n: 0
 }
 
-const MIN_PERFORMANCE = 5.0;
+const MIN_PERFORMANCE = 2.0;
 const MIN_WORKER_CONTEXT = 8192;
 const MAX_HORDE_LENGTH = 512;
 const MAX_HORDE_CONTEXT = 32000;
@@ -88,7 +89,7 @@ export const normalizeModel = (model: string) => {
     currentModel = currentModel
         .replace(/[ ._-]\d+(k$|-context)/i, '') // remove context length, i.e. -32k
-        .replace(/[ ._-](gptq|awq|exl2?|imat|i\d)/i, '') // remove quant name
+        .replace(/[ ._-](gptq|awq|exl2?|imat|i\d|h\d)/i, '') // remove quant name
         .replace(/([ ._-]?gg(uf|ml)[ ._-]?(v[ ._-]?\d)?)/i, '') // remove gguf-v3/ggml/etc
         .replace(/[ ._-]i?q([ ._-]?\d[ ._-]?(k?[ ._-]?x*[ ._-]?[lms]?)?)+/i, '') // remove quant size
         .replace(/[ ._-]\d+(\.\d+)?bpw/i, '') // remove bpw
@@ -104,14 +105,15 @@ export const normalizeModel = (model: string) => {
         .trim();
 }
 
-export const approximateTokens = (prompt: string): number =>
-    Math.round(prompt.split(/\s+/).length * 0.75);
+export const approximateTokens = (prompt: string): number => prompt.split(/[^a-z0-9]+/i).length;
 
 export type IGenerationSettings = Partial;
 
 export namespace Connection {
     const AIHORDE = 'https://aihorde.net';
 
+    let abortController = new AbortController();
+
     async function* generateKobold(url: string, prompt: string, extraSettings: IGenerationSettings = {}): AsyncGenerator {
         const sse = new SSE(`${url}/api/extra/generate/stream`, {
             payload: JSON.stringify({
@@ -144,12 +146,14 @@ export namespace Connection {
             messageLock.release();
         };
 
+        abortController.signal.addEventListener('abort', handleEnd);
         sse.addEventListener('error', handleEnd);
         sse.addEventListener('abort', handleEnd);
         sse.addEventListener('readystatechange', (e) => {
             if (e.readyState === SSE.CLOSED) handleEnd();
         });
+
         while (!end || messages.length) {
             while (messages.length > 0) {
                 const message = messages.shift();
@@ -189,6 +193,8 @@ export namespace Connection {
             workers: model.workers,
         };
 
+        const { signal } = abortController;
+
         const generateResponse = await fetch(`${AIHORDE}/api/v2/generate/text/async`, {
             method: 'POST',
             body: JSON.stringify(requestData),
@@ -196,31 +202,44 @@
             headers: {
                 'Content-Type': 'application/json',
                 apikey: connection.apiKey || HORDE_ANON_KEY,
             },
+            signal,
         });
 
-        if (!generateResponse.ok || Math.floor(generateResponse.status / 100) !== 2) {
+        if (!generateResponse.ok || generateResponse.status >= 400) {
             throw new Error(`Error starting generation: ${generateResponse.statusText}: ${await generateResponse.text()}`);
         }
 
         const { id } = await generateResponse.json() as { id: string };
 
-        const deleteRequest = () => fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method: 'DELETE' })
-            .catch(e => console.error('Error deleting request', e));
+        const request = async (method = 'GET'): Promise => {
+            const response = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`, { method });
+            if (response.ok && response.status < 400) {
+                const result: IHordeResult = await response.json();
+                if (result.generations?.length === 1) {
+                    const { text } = result.generations[0];
-        while (true) {
-            await delay(2500);
-
-            const retrieveResponse = await fetch(`${AIHORDE}/api/v2/generate/text/status/${id}`);
-            if (!retrieveResponse.ok || Math.floor(retrieveResponse.status / 100) !== 2) {
-                deleteRequest();
-                throw new Error(`Error retrieving generation: ${retrieveResponse.statusText}: ${await retrieveResponse.text()}`);
+                    return text;
+                }
+            } else {
+                throw new Error(await response.text());
             }
 
-            const result: IHordeResult = await retrieveResponse.json();
+            return null;
+        };
 
-            if (result.done && result.generations?.length === 1) {
-                const { text } = result.generations[0];
+        const deleteRequest = async () => (await request('DELETE')) ?? '';
 
-                return text;
+        while (true) {
+            try {
+                await delay(2500, { signal });
+
+                const text = await request();
+
+                if (text) {
+                    return text;
+                }
+            } catch (e) {
+                console.error('Error in horde generation:', e);
+                return deleteRequest();
             }
         }
     }
@@ -236,15 +255,20 @@
         }
     }
 
+    export function stopGeneration() {
+        abortController.abort();
+        abortController = new AbortController(); // refresh
+    }
+
     async function requestHordeModels(): Promise> {
         try {
             const response = await fetch(`${AIHORDE}/api/v2/workers?type=text`);
             if (response.ok) {
                 const workers: IHordeWorker[] = await response.json();
-                const goodWorkers = workers.filter(w => 
-                    w.online 
-                    && !w.maintenance_mode 
-                    && !w.flagged 
+                const goodWorkers = workers.filter(w =>
+                    w.online
+                    && !w.maintenance_mode
+                    && !w.flagged
                     && w.max_context_length >= MIN_WORKER_CONTEXT
                     && parseFloat(w.performance) >= MIN_PERFORMANCE
                 );
@@ -299,7 +323,7 @@
                 return result;
             }
         } catch (e) {
-            console.log('Error getting max tokens', e);
+            console.error('Error getting max tokens', e);
         }
     } else if (isHordeConnection(connection)) {
         return connection.model;
@@ -317,7 +341,7 @@
                 return value;
             }
         } catch (e) {
-            console.log('Error getting max tokens', e);
+            console.error('Error getting max tokens', e);
         }
     } else if (isHordeConnection(connection)) {
         const models = await getHordeModels();
@@ -343,7 +367,7 @@
                 return value;
             }
         } catch (e) {
-            console.log('Error counting tokens', e);
+            console.error('Error counting tokens', e);
         }
     }
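Taken together, the connection.ts hunks above implement cooperative cancellation: one shared `AbortController` is handed both to `fetch()` and to `delay()` from the `delay` package, and `stopGeneration()` aborts it and swaps in a fresh one. A condensed, self-contained sketch of that flow, assuming only the APIs the patch itself uses (`pollUntilDone` and its callback are illustrative stand-ins for the Horde status request):

```ts
import delay from "delay";

let abortController = new AbortController();

// Mirrors Connection.stopGeneration() above: abort everything in flight, then
// replace the controller so the next generation starts with a clean signal.
export function stopGeneration() {
    abortController.abort();
    abortController = new AbortController();
}

// Stand-in for the Horde polling loop: delay() rejects with an AbortError once
// stopGeneration() is called, which exits the loop through the catch branch.
export async function pollUntilDone(poll: () => Promise<string | null>): Promise<string> {
    const { signal } = abortController;
    while (true) {
        try {
            await delay(2500, { signal });
            const text = await poll();
            if (text) return text;
        } catch (e) {
            console.error("Generation stopped or failed:", e);
            return "";
        }
    }
}
```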
diff --git a/src/games/ai/contexts/llm.tsx b/src/games/ai/contexts/llm.tsx
index 4a7eaae..e9f85a9 100644
--- a/src/games/ai/contexts/llm.tsx
+++ b/src/games/ai/contexts/llm.tsx
@@ -1,13 +1,13 @@
-import Lock from "@common/lock";
-import SSE from "@common/sse";
 import { createContext } from "preact";
 import { useCallback, useContext, useEffect, useMemo, useState } from "preact/hooks";
 import { MessageTools, type IMessage } from "../messages";
-import { Instruct, StateContext } from "./state";
+import { StateContext } from "./state";
 import { useBool } from "@common/hooks/useBool";
 import { Template } from "@huggingface/jinja";
 import { Huggingface } from "../huggingface";
 import { approximateTokens, Connection, normalizeModel, type IGenerationSettings } from "../connection";
+import { throttle } from "@common/utils";
+import { useAsyncEffect } from "@common/hooks/useAsyncEffect";
 
 interface ICompileArgs {
     keepUsers?: number;
@@ -22,9 +22,7 @@ interface ICompiledPrompt {
 
 interface IContext {
     generating: boolean;
-    blockConnection: ReturnType;
     modelName: string;
-    modelTemplate: string;
     hasToolCalls: boolean;
     promptTokens: number;
     contextLength: number;
@@ -35,6 +33,7 @@ const MESSAGES_TO_KEEP = 10;
 
 interface IActions {
     compilePrompt: (messages: IMessage[], args?: ICompileArgs) => Promise;
     generate: (prompt: string, extraSettings?: IGenerationSettings) => AsyncGenerator;
+    stopGeneration: () => void;
     summarize: (content: string) => Promise;
     countTokens: (prompt: string) => Promise;
 }
@@ -50,15 +49,13 @@ const processing = {
 export const LLMContextProvider = ({ children }: { children?: any }) => {
     const { connection, messages, triggerNext, lore, userPrompt, systemPrompt, bannedWords, summarizePrompt, summaryEnabled,
-        setTriggerNext, addMessage, editMessage, editSummary, setInstruct,
+        setTriggerNext, addMessage, editMessage, editSummary,
     } = useContext(StateContext);
 
     const generating = useBool(false);
-    const blockConnection = useBool(false);
     const [promptTokens, setPromptTokens] = useState(0);
     const [contextLength, setContextLength] = useState(0);
     const [modelName, setModelName] = useState('');
-    const [modelTemplate, setModelTemplate] = useState('');
     const [hasToolCalls, setHasToolCalls] = useState(false);
 
     const userPromptTemplate = useMemo(() => {
@@ -71,20 +68,6 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
         }
     }, [userPrompt]);
 
-    const getContextLength = useCallback(async () => {
-        if (!connection || blockConnection.value) {
-            return 0;
-        }
-        return Connection.getContextLength(connection);
-    }, [connection, blockConnection.value]);
-
-    const getModelName = useCallback(async () => {
-        if (!connection || blockConnection.value) {
-            return '';
-        }
-        return Connection.getModelName(connection);
-    }, [connection, blockConnection.value]);
-
     const actions: IActions = useMemo(() => ({
         compilePrompt: async (messages, { keepUsers } = {}) => {
             const promptMessages = messages.slice();
@@ -179,31 +162,43 @@ export const LLMContextProvider = ({ children }: { children?: any }) => {
         },
         generate: async function* (prompt, extraSettings = {}) {
             try {
-                generating.setTrue();
                 console.log('[LLM.generate]', prompt);
 
                 yield* Connection.generate(connection, prompt, {
-                    ...extraSettings, 
+                    ...extraSettings,
                     banned_tokens: bannedWords.filter(w => w.trim()),
                 });
-            } finally {
-                generating.setFalse();
+            } catch (e) {
+                if (e instanceof Error && e.name !== 'AbortError') {
+                    alert(e.message);
+                } else {
+                    console.error('[LLM.generate]', e);
+                }
             }
         },
         summarize: async (message) => {
-            const content = Huggingface.applyTemplate(summarizePrompt, { message });
-            const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
+            try {
+                const content = Huggingface.applyTemplate(summarizePrompt, { message });
+                const prompt = Huggingface.applyChatTemplate(connection.instruct, [{ role: 'user', content }]);
+                console.log('[LLM.summarize]', prompt);
 
-            const tokens = await Array.fromAsync(actions.generate(prompt));
+                const tokens = await Array.fromAsync(Connection.generate(connection, prompt, {}));
 
-            return MessageTools.trimSentence(tokens.join(''));
+                return MessageTools.trimSentence(tokens.join(''));
+            } catch (e) {
+                console.error('Error summarizing:', e);
+                return '';
+            }
         },
         countTokens: async (prompt) => {
             return await Connection.countTokens(connection, prompt);
         },
+        stopGeneration: () => {
+            Connection.stopGeneration();
+        },
     }), [connection, lore, userPromptTemplate, systemPrompt, bannedWords, summarizePrompt]);
 
-    useEffect(() => void (async () => {
+    useAsyncEffect(async () => {
         if (triggerNext && !generating.value) {
             setTriggerNext(false);
 
@@ -217,12 +212,14 @@
                 messageId++;
             }
 
+            generating.setTrue();
             editSummary(messageId, 'Generating...');
 
             for await (const chunk of actions.generate(prompt)) {
                 text += chunk;
                 setPromptTokens(promptTokens + approximateTokens(text));
                 editMessage(messageId, text.trim());
             }
+            generating.setFalse();
 
             text = MessageTools.trimSentence(text);
             editMessage(messageId, text);
@@ -230,10 +227,10 @@
             MessageTools.playReady();
         }
-    })(), [triggerNext]);
+    }, [triggerNext]);
 
-    useEffect(() => void (async () => {
-        if (summaryEnabled && !generating.value && !processing.summarizing) {
+    useAsyncEffect(async () => {
+        if (summaryEnabled && !processing.summarizing) {
             try {
                 processing.summarizing = true;
                 for (let id = 0; id < messages.length; id++) {
@@ -250,36 +247,15 @@
                 processing.summarizing = false;
             }
         }
-    })(), [messages]);
+    }, [messages, summaryEnabled]);
 
-    useEffect(() => {
-        if (!blockConnection.value) {
-            setPromptTokens(0);
-            setContextLength(0);
-            setModelName('');
+    useEffect(throttle(() => {
+        Connection.getContextLength(connection).then(setContextLength);
+        Connection.getModelName(connection).then(normalizeModel).then(setModelName);
+    }, 1000, true), [connection]);
 
-            getContextLength().then(setContextLength);
-            getModelName().then(normalizeModel).then(setModelName);
-        }
-    }, [connection, blockConnection.value]);
-
-    useEffect(() => {
-        setModelTemplate('');
-        if (modelName) {
-            Huggingface.findModelTemplate(modelName)
-                .then((template) => {
-                    if (template) {
-                        setModelTemplate(template);
-                        setInstruct(template);
-                    } else {
-                        setInstruct(Instruct.CHATML);
-                    }
-                });
-        }
-    }, [modelName]);
-
-    const calculateTokens = useCallback(async () => {
-        if (!processing.tokenizing && !blockConnection.value && !generating.value) {
+    const calculateTokens = useCallback(throttle(async () => {
+        if (!processing.tokenizing && !generating.value) {
             try {
                 processing.tokenizing = true;
                 const { prompt } = await actions.compilePrompt(messages);
@@ -291,11 +267,11 @@
                 processing.tokenizing = false;
             }
         }
-    }, [actions, messages, blockConnection.value]);
+    }, 1000, true), [actions, messages]);
 
     useEffect(() => {
         calculateTokens();
-    }, [messages, connection, blockConnection.value, /* systemPrompt, lore, userPrompt TODO debounce*/]);
+    }, [messages, connection, systemPrompt, lore, userPrompt]);
 
     useEffect(() => {
         try {
@@ -308,9 +284,7 @@
 
     const rawContext: IContext = {
         generating: generating.value,
-        blockConnection,
         modelName,
-        modelTemplate,
         hasToolCalls,
         promptTokens,
         contextLength,
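A small sketch of what the `useAsyncEffect` swap above buys: the old `useEffect(() => void (async () => { ... })(), deps)` dance collapses into a hook that simply discards the returned promise. The component and `loadModelName` below are made-up stand-ins for the `Connection` calls in the real effect:

```tsx
import { useState } from "preact/hooks";
import { useAsyncEffect } from "@common/hooks/useAsyncEffect";

// Illustrative component only; loadModelName stands in for Connection.getModelName(connection).
export const ModelBadge = ({ connection, loadModelName }: {
    connection: unknown;
    loadModelName: () => Promise<string>;
}) => {
    const [modelName, setModelName] = useState('');

    // The callback may be async: useAsyncEffect hands useEffect `() => void fx()`,
    // so the promise is ignored instead of being mistaken for a cleanup function.
    useAsyncEffect(async () => {
        setModelName(await loadModelName());
    }, [connection]);

    return <span>{modelName}</span>;
};
```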
diff --git a/src/games/ai/contexts/state.tsx b/src/games/ai/contexts/state.tsx
index d7a076b..aaa8334 100644
--- a/src/games/ai/contexts/state.tsx
+++ b/src/games/ai/contexts/state.tsx
@@ -83,7 +83,7 @@ What should happen next in your answer:
 {{ prompt | trim }}
 {% endif %}
 Remember that this story should be infinite and go forever. Make sure to follow the world description and rules exactly.
 Avoid cliffhangers and pauses, be creative.`,
-    summarizePrompt: 'Shrink following text down, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
+    summarizePrompt: 'Shrink following text down to one paragraph, keeping all important details:\n\n{{ message }}\n\nAnswer with shortened text only.',
     summaryEnabled: false,
     bannedWords: [],
     messages: [],
diff --git a/src/games/ai/huggingface.ts b/src/games/ai/huggingface.ts
index 3fd6ddb..630504c 100644
--- a/src/games/ai/huggingface.ts
+++ b/src/games/ai/huggingface.ts
@@ -1,6 +1,7 @@
 import { gguf } from '@huggingface/gguf';
 import * as hub from '@huggingface/hub';
 import { Template } from '@huggingface/jinja';
+import { normalizeModel } from './connection';
 
 export namespace Huggingface {
     export interface ITemplateMessage {
@@ -92,11 +93,12 @@ export namespace Huggingface {
     const loadHuggingfaceTokenizerConfig = async (modelName: string): Promise => {
         console.log(`[huggingface] searching config for '${modelName}'`);
+        const searchModel = normalizeModel(modelName);
 
-        const hubModels = await Array.fromAsync(hub.listModels({ search: { query: modelName }, additionalFields: ['config'] }));
+        const hubModels = await Array.fromAsync(hub.listModels({ search: { query: searchModel }, additionalFields: ['config'] }));
         const models = hubModels.filter(m => {
             if (m.gated) return false;
-            if (!m.name.toLowerCase().includes(modelName.toLowerCase())) return false;
+            if (!normalizeModel(m.name).includes(searchModel)) return false;
             return true;
         }).sort((a, b) => b.downloads - a.downloads);
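For context on the huggingface.ts change above: the search text and every Hub result now go through the same `normalizeModel()` from connection.ts, so quant/format suffixes on either side no longer break the substring match. A sketch of that matching logic under made-up model names (the backend-reported name below is invented, not from the patch):

```ts
import { normalizeModel } from "./connection";
import * as hub from "@huggingface/hub";

// Hypothetical name a Kobold/Horde backend might report for a quantized model.
const reportedName = "mythomax-l2-13b.Q5_K_M.gguf";
const searchModel = normalizeModel(reportedName);

// Mirrors the filter in loadHuggingfaceTokenizerConfig: normalize both sides before comparing.
const findCandidates = async () => {
    const hubModels = await Array.fromAsync(hub.listModels({ search: { query: searchModel } }));
    return hubModels.filter((m) => !m.gated && normalizeModel(m.name).includes(searchModel));
};
```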