diff --git a/src/api/providers/fetchers/__tests__/openrouter.spec.ts b/src/api/providers/fetchers/__tests__/openrouter.spec.ts
index 44892d2024a..e0ab7f5c9a5 100644
--- a/src/api/providers/fetchers/__tests__/openrouter.spec.ts
+++ b/src/api/providers/fetchers/__tests__/openrouter.spec.ts
@@ -24,6 +24,7 @@ describe("OpenRouter API", () => {
 		const models = await getOpenRouterModels()
 
 		const openRouterSupportedCaching = Object.entries(models)
+			.filter(([id, _]) => id.startsWith("anthropic/claude") || id.startsWith("google/gemini")) // only these support cache_control breakpoints (https://openrouter.ai/docs/features/prompt-caching)
 			.filter(([_, model]) => model.supportsPromptCache)
 			.map(([id, _]) => id)
 
@@ -229,7 +230,7 @@ describe("OpenRouter API", () => {
 		const endpoints = await getOpenRouterModelEndpoints("google/gemini-2.5-pro-preview")
 
 		expect(endpoints).toEqual({
-			Google: {
+			"google-vertex": {
 				maxTokens: 65535,
 				contextWindow: 1048576,
 				supportsImages: true,
@@ -243,7 +244,7 @@ describe("OpenRouter API", () => {
 				supportsReasoningEffort: undefined,
 				supportedParameters: undefined,
 			},
-			"Google AI Studio": {
+			"google-ai-studio": {
 				maxTokens: 65536,
 				contextWindow: 1048576,
 				supportsImages: true,
diff --git a/src/api/providers/fetchers/openrouter.ts b/src/api/providers/fetchers/openrouter.ts
index be8fb26f7a3..89f971a2e66 100644
--- a/src/api/providers/fetchers/openrouter.ts
+++ b/src/api/providers/fetchers/openrouter.ts
@@ -58,6 +58,7 @@ export type OpenRouterModel = z.infer<typeof openRouterModelSchema>
 
 export const openRouterModelEndpointSchema = modelRouterBaseModelSchema.extend({
 	provider_name: z.string(),
+	tag: z.string().optional(),
 })
 
 export type OpenRouterModelEndpoint = z.infer<typeof openRouterModelEndpointSchema>
@@ -149,7 +150,7 @@ export async function getOpenRouterModelEndpoints(
 	const { id, architecture, endpoints } = data
 
 	for (const endpoint of endpoints) {
-		models[endpoint.provider_name] = parseOpenRouterModel({
+		models[endpoint.tag ?? endpoint.provider_name] = parseOpenRouterModel({
 			id,
 			model: endpoint,
 			modality: architecture?.modality,
@@ -188,7 +189,7 @@ export const parseOpenRouterModel = ({
 	const cacheReadsPrice = model.pricing?.input_cache_read
 		? parseApiPrice(model.pricing?.input_cache_read)
 		: undefined
-	const supportsPromptCache = typeof cacheWritesPrice !== "undefined" && typeof cacheReadsPrice !== "undefined"
+	const supportsPromptCache = typeof cacheReadsPrice !== "undefined" // some models support caching but don't charge a cacheWritesPrice, e.g. GPT-5
 
 	const modelInfo: ModelInfo = {
 		maxTokens: maxTokens || Math.ceil(model.context_length * 0.2),
diff --git a/webview-ui/src/components/ui/hooks/useOpenRouterModelProviders.ts b/webview-ui/src/components/ui/hooks/useOpenRouterModelProviders.ts
index dc50c0f6a6d..3a2f23dabc5 100644
--- a/webview-ui/src/components/ui/hooks/useOpenRouterModelProviders.ts
+++ b/webview-ui/src/components/ui/hooks/useOpenRouterModelProviders.ts
@@ -22,12 +22,15 @@ const openRouterEndpointsSchema = z.object({
 	endpoints: z.array(
 		z.object({
 			name: z.string(),
+			tag: z.string().optional(),
 			context_length: z.number(),
 			max_completion_tokens: z.number().nullish(),
 			pricing: z
 				.object({
 					prompt: z.union([z.string(), z.number()]).optional(),
 					completion: z.union([z.string(), z.number()]).optional(),
+					input_cache_read: z.union([z.string(), z.number()]).optional(),
+					input_cache_write: z.union([z.string(), z.number()]).optional(),
 				})
 				.optional(),
 		}),
@@ -51,49 +54,28 @@ async function getOpenRouterProvidersForModel(modelId: string) {
 			return models
 		}
 
-		const { id, description, architecture, endpoints } = result.data.data
+		const { description, architecture, endpoints } = result.data.data
 
 		for (const endpoint of endpoints) {
-			const providerName = endpoint.name.split("|")[0].trim()
+			const providerName = endpoint.tag ?? endpoint.name
 			const inputPrice = parseApiPrice(endpoint.pricing?.prompt)
 			const outputPrice = parseApiPrice(endpoint.pricing?.completion)
+			const cacheReadsPrice = parseApiPrice(endpoint.pricing?.input_cache_read)
+			const cacheWritesPrice = parseApiPrice(endpoint.pricing?.input_cache_write)
 
 			const modelInfo: OpenRouterModelProvider = {
 				maxTokens: endpoint.max_completion_tokens || endpoint.context_length,
 				contextWindow: endpoint.context_length,
 				supportsImages: architecture?.modality?.includes("image"),
-				supportsPromptCache: false,
+				supportsPromptCache: typeof cacheReadsPrice !== "undefined",
+				cacheReadsPrice,
+				cacheWritesPrice,
 				inputPrice,
 				outputPrice,
 				description,
 				label: providerName,
 			}
 
-			// TODO: This is wrong. We need to fetch the model info from
-			// OpenRouter instead of hardcoding it here. The endpoints payload
-			// doesn't include this unfortunately, so we need to get it from the
-			// main models endpoint.
-			switch (true) {
-				case modelId.startsWith("anthropic/claude-3.7-sonnet"):
-					modelInfo.supportsComputerUse = true
-					modelInfo.supportsPromptCache = true
-					modelInfo.cacheWritesPrice = 3.75
-					modelInfo.cacheReadsPrice = 0.3
-					modelInfo.maxTokens = id === "anthropic/claude-3.7-sonnet:thinking" ? 64_000 : 8192
-					break
-				case modelId.startsWith("anthropic/claude-3.5-sonnet-20240620"):
-					modelInfo.supportsPromptCache = true
-					modelInfo.cacheWritesPrice = 3.75
-					modelInfo.cacheReadsPrice = 0.3
-					modelInfo.maxTokens = 8192
-					break
-				default:
-					modelInfo.supportsPromptCache = true
-					modelInfo.cacheWritesPrice = 0.3
-					modelInfo.cacheReadsPrice = 0.03
-					break
-			}
-
 			models[providerName] = modelInfo
 		}
 	} catch (error) {