diff --git a/src/common/utils.ts b/src/common/utils.ts index 353316f..8ae716b 100644 --- a/src/common/utils.ts +++ b/src/common/utils.ts @@ -161,6 +161,16 @@ export const formatTime = (seconds: number): string => { return parts.join(' '); }; +export const fuzzyMatch = (target: string, query: string): boolean => { + const t = target.toLowerCase(); + const q = query.toLowerCase(); + let qi = 0; + for (let ti = 0; ti < t.length && qi < q.length; ti++) { + if (t[ti] === q[qi]) qi++; + } + return qi === q.length; +}; + export const extractString = (e: Event | string): string => { if (typeof e === 'string') { return e; diff --git a/src/games/hordeseer/utils/calculations.ts b/src/games/hordeseer/utils/calculations.ts index f9d6253..9100736 100644 --- a/src/games/hordeseer/utils/calculations.ts +++ b/src/games/hordeseer/utils/calculations.ts @@ -1,4 +1,4 @@ -import { WorkerData } from "../utils/api"; +import { type WorkerData } from "../utils/api"; /** * Calculate total kudos/hour across a set of workers. diff --git a/src/games/storywriter/components/settings/connection.tsx b/src/games/storywriter/components/settings/connection.tsx index fe659fc..eedab87 100644 --- a/src/games/storywriter/components/settings/connection.tsx +++ b/src/games/storywriter/components/settings/connection.tsx @@ -2,6 +2,7 @@ import { useBool } from "@common/hooks/useBool"; import { useQuery } from "@common/hooks/useAsyncState"; import { useInputState } from "@common/hooks/useInputState"; import { useUpdate } from "@common/hooks/useUpdate"; +import { fuzzyMatch } from "@common/utils"; import clsx from "clsx"; import { useMemo, useRef } from "preact/hooks"; import styles from "../../assets/settings-modal.module.css"; @@ -9,10 +10,11 @@ import { useAppState } from "../../contexts/state"; import LLM from "../../utils/llm"; export const ConnectionSettings = () => { - const { connection, model, dispatch } = useAppState(); + const { connection, model, imageModel, dispatch } = useAppState(); const [url, setUrl] = useInputState(connection?.url ?? ""); const [apiKey, setApiKey] = useInputState(connection?.apiKey ?? ""); const [selectedModel, setSelectedModel] = useInputState(model?.id ?? ""); + const [selectedImageModel, setSelectedImageModel] = useInputState(imageModel?.id ?? ""); const [update, triggerFetch] = useUpdate(); const showPassword = useBool(false); @@ -34,7 +36,14 @@ export const ConnectionSettings = () => { return r.data; }, []); + const fetchImageModels = useMemo(() => async (conn: LLM.Connection | null) => { + if (!conn) return []; + const r = await LLM.getImageModels(conn); + return r.data; + }, []); + const modelsData = useQuery(fetchModels, connectionToFetch); + const imageModelsData = useQuery(fetchImageModels, connectionToFetch); const isLoadingModels = connectionToFetch != null && modelsData == undefined; const [modelFilter, setModelFilter] = useInputState(""); @@ -55,23 +64,20 @@ export const ConnectionSettings = () => { const filteredGroupedModels = useMemo(() => { if (!modelFilter) return groupedModels; - const query = modelFilter.toLowerCase(); - const fuzzyMatch = (target: string) => { - const t = target.toLowerCase(); - let qi = 0; - for (let ti = 0; ti < t.length && qi < query.length; ti++) { - if (t[ti] === query[qi]) qi++; - } - return qi === query.length; - }; return groupedModels .map(({ context, models }) => ({ context, - models: models.filter(m => m.id === selectedModel || fuzzyMatch(m.id)), + models: models.filter(m => m.id === selectedModel || fuzzyMatch(m.id, modelFilter)), })) .filter(({ models }) => models.length > 0); }, [groupedModels, modelFilter, selectedModel]); + const filteredImageModels = useMemo(() => { + const sorted = [...(imageModelsData ?? [])].sort((a, b) => a.id.localeCompare(b.id)); + if (!modelFilter) return sorted; + return sorted.filter(m => m.id === selectedImageModel || fuzzyMatch(m.id, modelFilter)); + }, [imageModelsData, modelFilter, selectedImageModel]); + const handleBlur = () => { if (url && apiKey) { dispatch({ type: "SET_CONNECTION", connection: { url, apiKey } }); @@ -93,6 +99,13 @@ export const ConnectionSettings = () => { dispatch({ type: "SET_MODEL", model: selectedModelInfo }); }; + const handleImageModelChange = (e: Event) => { + setSelectedImageModel(e); + const target = e.target as HTMLSelectElement; + const selectedModelInfo = imageModelsData?.find(m => m.id === target.value) ?? null; + dispatch({ type: "SET_IMAGE_MODEL", model: selectedModelInfo }); + }; + const connectionToTest = url && apiKey ? { url, apiKey } : null; return ( @@ -173,6 +186,30 @@ export const ConnectionSettings = () => {

Enter connection details to load models

)} +
+ + {connectionToTest ? ( + imageModelsData == undefined ? ( +

Loading models...

+ ) : imageModelsData.length > 0 ? ( + + ) : ( +

No image models available

+ ) + ) : ( +

Enter connection details to load models

+ )} +
); }; diff --git a/src/games/storywriter/contexts/state.tsx b/src/games/storywriter/contexts/state.tsx index 31144ca..91d63b2 100644 --- a/src/games/storywriter/contexts/state.tsx +++ b/src/games/storywriter/contexts/state.tsx @@ -145,7 +145,8 @@ interface IState { currentTab: Tab; chatOpen: boolean; connection: LLM.Connection | null; - model: LLM.ModelInfo | null; + model: LLM.ModelInfoText | null; + imageModel: LLM.ModelInfoImage | null; enableThinking: boolean; bannedTokens: string[]; systemInstruction: string; @@ -193,7 +194,8 @@ type Action = | { type: 'EDIT_CHAT_MESSAGE'; worldId: string; storyId: string; messageId: string; content: string } // Connection | { type: 'SET_CONNECTION'; connection: LLM.Connection | null } - | { type: 'SET_MODEL'; model: LLM.ModelInfo | null } + | { type: 'SET_MODEL'; model: LLM.ModelInfoText | null } + | { type: 'SET_IMAGE_MODEL'; model: LLM.ModelInfoImage | null } | { type: 'SET_ENABLE_THINKING'; enable: boolean } | { type: 'SET_BANNED_TOKENS'; tokens: string[] } // Characters @@ -253,6 +255,7 @@ const DEFAULT_STATE: IState = { chatOpen: false, connection: null, model: null, + imageModel: null, enableThinking: false, bannedTokens: [], userName: 'User', @@ -500,6 +503,9 @@ function reducer(state: IState, action: Action): IState { case 'SET_MODEL': { return { ...state, model: action.model }; } + case 'SET_IMAGE_MODEL': { + return { ...state, imageModel: action.model }; + } case 'SET_ENABLE_THINKING': { return { ...state, enableThinking: action.enable }; } @@ -647,7 +653,8 @@ export interface AppState { currentTab: Tab; chatOpen: boolean; connection: LLM.Connection | null; - model: LLM.ModelInfo | null; + model: LLM.ModelInfoText | null; + imageModel: LLM.ModelInfoImage | null; enableThinking: boolean; bannedTokens: string[]; systemInstruction: string; @@ -702,6 +709,7 @@ export const StateContextProvider = ({ children }: { children?: any }) => { chatOpen: state.chatOpen, connection: state.connection, model: state.model, + imageModel: state.imageModel ?? null, enableThinking: state.enableThinking, bannedTokens: state.bannedTokens ?? [], systemInstruction, diff --git a/src/games/storywriter/utils/llm.ts b/src/games/storywriter/utils/llm.ts index ddefe00..855dae3 100644 --- a/src/games/storywriter/utils/llm.ts +++ b/src/games/storywriter/utils/llm.ts @@ -132,7 +132,7 @@ namespace LLM { }; } - interface ModelInfoText extends BaseModelInfo { + export interface ModelInfoText extends BaseModelInfo { context_length: number; top_provider: { context_length: number; @@ -141,7 +141,7 @@ namespace LLM { }; } - interface ModelInfoImage extends BaseModelInfo { + export interface ModelInfoImage extends BaseModelInfo { } export type ModelInfo = ModelInfoText | ModelInfoImage; diff --git a/test/common/utils.test.ts b/test/common/utils.test.ts index ec3ffdd..72fe77a 100644 --- a/test/common/utils.test.ts +++ b/test/common/utils.test.ts @@ -18,6 +18,7 @@ import { callUpdater, formatTime, formatNumber, + fuzzyMatch, } from '@common/utils'; describe('utils', () => { @@ -365,6 +366,45 @@ describe('utils', () => { }); }); + describe('fuzzyMatch', () => { + it('returns true for exact match', () => { + expect(fuzzyMatch('hello', 'hello')).toBe(true); + }); + + it('returns true when query chars appear in order', () => { + expect(fuzzyMatch('Deliberate', 'dlt')).toBe(true); + }); + + it('returns true for non-contiguous subsequence', () => { + expect(fuzzyMatch('Analog Diffusion', 'andif')).toBe(true); + }); + + it('returns false when query chars are out of order', () => { + expect(fuzzyMatch('abc', 'ca')).toBe(false); + }); + + it('returns false when a query char is missing from target', () => { + expect(fuzzyMatch('Deliberate', 'dltz')).toBe(false); + }); + + it('returns true for empty query', () => { + expect(fuzzyMatch('anything', '')).toBe(true); + }); + + it('returns false for empty target with non-empty query', () => { + expect(fuzzyMatch('', 'a')).toBe(false); + }); + + it('returns false when query is longer than target', () => { + expect(fuzzyMatch('ab', 'abc')).toBe(false); + }); + + it('is case-insensitive', () => { + expect(fuzzyMatch('DreamShaper', 'dreamshaper')).toBe(true); + expect(fuzzyMatch('dreamshaper', 'DREAM')).toBe(true); + }); + }); + describe('formatTime', () => { it('should return 0:00 for zero seconds', () => { expect(formatTime(0)).toBe('0:00');