diff --git a/src/common/utils.ts b/src/common/utils.ts index 353316f..8ae716b 100644 --- a/src/common/utils.ts +++ b/src/common/utils.ts @@ -161,6 +161,16 @@ export const formatTime = (seconds: number): string => { return parts.join(' '); }; +export const fuzzyMatch = (target: string, query: string): boolean => { + const t = target.toLowerCase(); + const q = query.toLowerCase(); + let qi = 0; + for (let ti = 0; ti < t.length && qi < q.length; ti++) { + if (t[ti] === q[qi]) qi++; + } + return qi === q.length; +}; + export const extractString = (e: Event | string): string => { if (typeof e === 'string') { return e; diff --git a/src/games/hordeseer/utils/calculations.ts b/src/games/hordeseer/utils/calculations.ts index f9d6253..9100736 100644 --- a/src/games/hordeseer/utils/calculations.ts +++ b/src/games/hordeseer/utils/calculations.ts @@ -1,4 +1,4 @@ -import { WorkerData } from "../utils/api"; +import { type WorkerData } from "../utils/api"; /** * Calculate total kudos/hour across a set of workers. diff --git a/src/games/storywriter/components/settings/connection.tsx b/src/games/storywriter/components/settings/connection.tsx index fe659fc..eedab87 100644 --- a/src/games/storywriter/components/settings/connection.tsx +++ b/src/games/storywriter/components/settings/connection.tsx @@ -2,6 +2,7 @@ import { useBool } from "@common/hooks/useBool"; import { useQuery } from "@common/hooks/useAsyncState"; import { useInputState } from "@common/hooks/useInputState"; import { useUpdate } from "@common/hooks/useUpdate"; +import { fuzzyMatch } from "@common/utils"; import clsx from "clsx"; import { useMemo, useRef } from "preact/hooks"; import styles from "../../assets/settings-modal.module.css"; @@ -9,10 +10,11 @@ import { useAppState } from "../../contexts/state"; import LLM from "../../utils/llm"; export const ConnectionSettings = () => { - const { connection, model, dispatch } = useAppState(); + const { connection, model, imageModel, dispatch } = useAppState(); const [url, setUrl] = useInputState(connection?.url ?? ""); const [apiKey, setApiKey] = useInputState(connection?.apiKey ?? ""); const [selectedModel, setSelectedModel] = useInputState(model?.id ?? ""); + const [selectedImageModel, setSelectedImageModel] = useInputState(imageModel?.id ?? ""); const [update, triggerFetch] = useUpdate(); const showPassword = useBool(false); @@ -34,7 +36,14 @@ export const ConnectionSettings = () => { return r.data; }, []); + const fetchImageModels = useMemo(() => async (conn: LLM.Connection | null) => { + if (!conn) return []; + const r = await LLM.getImageModels(conn); + return r.data; + }, []); + const modelsData = useQuery(fetchModels, connectionToFetch); + const imageModelsData = useQuery(fetchImageModels, connectionToFetch); const isLoadingModels = connectionToFetch != null && modelsData == undefined; const [modelFilter, setModelFilter] = useInputState(""); @@ -55,23 +64,20 @@ export const ConnectionSettings = () => { const filteredGroupedModels = useMemo(() => { if (!modelFilter) return groupedModels; - const query = modelFilter.toLowerCase(); - const fuzzyMatch = (target: string) => { - const t = target.toLowerCase(); - let qi = 0; - for (let ti = 0; ti < t.length && qi < query.length; ti++) { - if (t[ti] === query[qi]) qi++; - } - return qi === query.length; - }; return groupedModels .map(({ context, models }) => ({ context, - models: models.filter(m => m.id === selectedModel || fuzzyMatch(m.id)), + models: models.filter(m => m.id === selectedModel || fuzzyMatch(m.id, modelFilter)), })) .filter(({ models }) => models.length > 0); }, [groupedModels, modelFilter, selectedModel]); + const filteredImageModels = useMemo(() => { + const sorted = [...(imageModelsData ?? [])].sort((a, b) => a.id.localeCompare(b.id)); + if (!modelFilter) return sorted; + return sorted.filter(m => m.id === selectedImageModel || fuzzyMatch(m.id, modelFilter)); + }, [imageModelsData, modelFilter, selectedImageModel]); + const handleBlur = () => { if (url && apiKey) { dispatch({ type: "SET_CONNECTION", connection: { url, apiKey } }); @@ -93,6 +99,13 @@ export const ConnectionSettings = () => { dispatch({ type: "SET_MODEL", model: selectedModelInfo }); }; + const handleImageModelChange = (e: Event) => { + setSelectedImageModel(e); + const target = e.target as HTMLSelectElement; + const selectedModelInfo = imageModelsData?.find(m => m.id === target.value) ?? null; + dispatch({ type: "SET_IMAGE_MODEL", model: selectedModelInfo }); + }; + const connectionToTest = url && apiKey ? { url, apiKey } : null; return ( @@ -173,6 +186,30 @@ export const ConnectionSettings = () => {
Enter connection details to load models
)} +Loading models...
+ ) : imageModelsData.length > 0 ? ( + + ) : ( +No image models available
+ ) + ) : ( +Enter connection details to load models
+ )} +