diff --git a/src/games/storywriter/components/settings/connection.tsx b/src/games/storywriter/components/settings/connection.tsx
index 4eecefa..fe659fc 100644
--- a/src/games/storywriter/components/settings/connection.tsx
+++ b/src/games/storywriter/components/settings/connection.tsx
@@ -30,7 +30,7 @@ export const ConnectionSettings = () => {
 
   const fetchModels = useMemo(() => async (conn: LLM.Connection | null) => {
     if (!conn) return [];
-    const r = await LLM.getModels(conn);
+    const r = await LLM.getTextModels(conn);
     return r.data;
   }, []);
 
diff --git a/src/games/storywriter/utils/llm.ts b/src/games/storywriter/utils/llm.ts
index ac6b349..ddefe00 100644
--- a/src/games/storywriter/utils/llm.ts
+++ b/src/games/storywriter/utils/llm.ts
@@ -118,13 +118,22 @@ namespace LLM {
     error: string;
   }
-  export interface ModelInfo {
+  type Modality = 'text' | 'image';
+
+  interface BaseModelInfo {
     id: string;
     object: 'model';
     created: number;
     owned_by: string;
-    context_length: number;
     supported_parameters: string[];
+    architecture?: {
+      input_modalities: Modality[];
+      output_modalities: Modality[];
+    };
+  }
+
+  interface ModelInfoText extends BaseModelInfo {
+    context_length: number;
     top_provider: {
       context_length: number;
       max_completion_tokens: number;
     };
   }
@@ -132,11 +141,22 @@
 
-  export interface ModelsResponse {
-    object: 'list';
-    data: ModelInfo[];
+  interface ModelInfoImage extends BaseModelInfo {
   }
 
+  export type ModelInfo = ModelInfoText | ModelInfoImage;
+
+  const isTextModel = (model: ModelInfo): model is ModelInfoText => ('context_length' in model);
+  const isImageModel = (model: ModelInfo): model is ModelInfoImage => Boolean(
+    !isTextModel(model) &&
+    model.architecture &&
+    (model.architecture.output_modalities).includes('image')
+  );
+
+  export interface ModelsResponse<T extends ModelInfo = ModelInfo> {
+    object: 'list';
+    data: T[];
+  }
 
   interface CountTokensRequestString {
     model: string;
@@ -246,8 +266,18 @@
     return e != null && typeof e === 'object' && 'data' in e
       && typeof e.data === 'string';
   }
-  export async function getModels(connection: Connection): Promise<ModelsResponse> {
-    return request<ModelsResponse>(connection, '/v1/models');
+  export async function getTextModels(connection: Connection): Promise<ModelsResponse<ModelInfoText>> {
+    const response = await request<ModelsResponse>(connection, '/v1/models');
+
+    response.data = response.data.filter(isTextModel);
+    return response as ModelsResponse<ModelInfoText>;
+  }
+
+  export async function getImageModels(connection: Connection): Promise<ModelsResponse<ModelInfoImage>> {
+    const response = await request<ModelsResponse>(connection, '/v1/models');
+
+    response.data = response.data.filter(isImageModel);
+    return response as ModelsResponse<ModelInfoImage>;
   }
 
   export async function countTokens(connection: Connection, body: CountTokensRequest) {
@@ -262,7 +292,7 @@
   }
 
   export async function generate(connection: Connection, config: ChatCompletionRequest) {
-    return request(connection, '/v1/chat/completions', 'POST', {
+    return request<ChatCompletionResponse>(connection, '/v1/chat/completions', 'POST', {
       ...config,
       stream: false,
     });