diff --git a/src/nodeBridge.ts b/src/nodeBridge.ts index d4882618..084b3133 100644 --- a/src/nodeBridge.ts +++ b/src/nodeBridge.ts @@ -436,6 +436,62 @@ class NodeHandlerRegistry { }; }); + // test a specific model + this.messageBus.registerHandler('models.test', async (data) => { + const { cwd, modelId } = data; + const context = await this.getContext(cwd); + const startTime = Date.now(); + + try { + const { model, error } = await resolveModelWithContext( + modelId, + context, + ); + + if (error || !model) { + return { + success: false, + error: + error instanceof Error + ? error.message + : error || 'Model not found', + responseTime: Date.now() - startTime, + }; + } + + const m = await model._mCreator(); + const result = await m.doGenerate({ + prompt: [ + { + role: 'user', + content: [{ type: 'text', text: 'Hi' }], + }, + ], + }); + + const hasResponse = result.content?.some( + (c) => c.type === 'text' && (c as { text: string }).text, + ); + + return { + success: true, + data: { + modelId, + providerName: model.provider.name, + modelName: model.model.name, + responseTime: Date.now() - startTime, + hasResponse, + }, + }; + } catch (err) { + return { + success: false, + error: err instanceof Error ? err.message : 'Unknown error', + responseTime: Date.now() - startTime, + }; + } + }); + ////////////////////////////////////////////// // outputStyles this.messageBus.registerHandler('outputStyles.list', async (data) => { @@ -2356,6 +2412,17 @@ function buildSignalKey(cwd: string, sessionId: string) { return `${cwd}/${sessionId}`; } +// Default API endpoints for providers that use SDK defaults +const DEFAULT_PROVIDER_ENDPOINTS: Record = { + openai: 'https://api.openai.com', + google: 'https://generativelanguage.googleapis.com', + xai: 'https://api.x.ai', + anthropic: 'https://api.anthropic.com', + openrouter: 'https://openrouter.ai', + cerebras: 'https://api.cerebras.ai', + antigravity: 'https://antigravity.google', +}; + function normalizeProviders(providers: ProvidersMap, context: Context) { return Object.values(providers as Record).map( (provider) => { @@ -2382,12 +2449,16 @@ function normalizeProviders(providers: ProvidersMap, context: Context) { provider.options?.apiKey || context.config.provider?.[provider.id]?.options?.apiKey ); + // Use provider.api or fallback to default endpoint + const api = + provider.api || DEFAULT_PROVIDER_ENDPOINTS[provider.id] || undefined; return { id: provider.id, name: provider.name, doc: provider.doc, env: provider.env, apiEnv: provider.apiEnv, + api, validEnvs, hasApiKey, }; diff --git a/src/nodeBridge.types.ts b/src/nodeBridge.types.ts index f6ce90fd..941654d1 100644 --- a/src/nodeBridge.types.ts +++ b/src/nodeBridge.types.ts @@ -230,6 +230,23 @@ type ModelsListOutput = { }; }; +type ModelsTestInput = { + cwd: string; + modelId: string; +}; +type ModelsTestOutput = { + success: boolean; + error?: string; + responseTime: number; + data?: { + modelId: string; + providerName: string; + modelName: string; + responseTime: number; + hasResponse: boolean; + }; +}; + // ============================================================================ // Output Styles Handlers // ============================================================================ @@ -435,6 +452,7 @@ type ProvidersListOutput = { doc?: string; env?: string[]; apiEnv?: string[]; + api?: string; validEnvs: string[]; hasApiKey: boolean; }>; @@ -863,6 +881,7 @@ export type HandlerMap = { // Models handlers 'models.list': { input: ModelsListInput; output: ModelsListOutput }; + 'models.test': { input: ModelsTestInput; output: ModelsTestOutput }; // Output styles handlers 'outputStyles.list': { diff --git a/src/slash-commands/builtin/index.ts b/src/slash-commands/builtin/index.ts index a9dadba1..d7459da5 100644 --- a/src/slash-commands/builtin/index.ts +++ b/src/slash-commands/builtin/index.ts @@ -12,6 +12,7 @@ import { createLogoutCommand } from './logout'; import { createMcpCommand } from './mcp'; import { createModelCommand } from './model'; import { createOutputStyleCommand } from './output-style'; +import { pingCommand } from './ping'; import { createResumeCommand } from './resume'; import { createReviewCommand } from './review'; import { brainstormCommand } from './spec/brainstorm'; @@ -38,6 +39,7 @@ export function createBuiltinCommands(opts: { createMcpCommand(opts), createModelCommand(opts), createOutputStyleCommand(), + pingCommand, createResumeCommand(), createReviewCommand(opts.language), createTerminalSetupCommand(), diff --git a/src/slash-commands/builtin/ping.tsx b/src/slash-commands/builtin/ping.tsx new file mode 100644 index 00000000..516c37bf --- /dev/null +++ b/src/slash-commands/builtin/ping.tsx @@ -0,0 +1,533 @@ +import { Box, Text } from 'ink'; +import Spinner from 'ink-spinner'; +import React, { useEffect, useState } from 'react'; +import type { AssistantMessage } from '../../message'; +import { GradientText } from '../../ui/GradientText'; +import { useAppStore } from '../../ui/store'; +import { useTextGradientAnimation } from '../../ui/useTextGradientAnimation'; +import type { LocalJSXCommand } from '../types'; + +interface PingResult { + providerId: string; + providerName: string; + status: 'pending' | 'testing' | 'success' | 'failed'; + responseTime?: number; + error?: string; +} + +interface ModelTestResult { + modelId: string; + modelName: string; + providerName: string; + status: 'pending' | 'testing' | 'success' | 'failed'; + responseTime?: number; + error?: string; +} + +async function pingEndpoint(endpoint: string): Promise<{ + status: 'success' | 'failed'; + responseTime: number; + error?: string; +}> { + const startTime = Date.now(); + try { + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), 10000); + await fetch(endpoint, { method: 'HEAD', signal: controller.signal }); + clearTimeout(timeoutId); + return { status: 'success', responseTime: Date.now() - startTime }; + } catch (error) { + return { + status: 'failed', + responseTime: Date.now() - startTime, + error: error instanceof Error ? error.message : 'Network error', + }; + } +} + +function getLatencyColor(ms: number): string { + if (ms < 200) return 'green'; + if (ms < 500) return 'yellow'; + if (ms < 1000) return 'magenta'; + return 'red'; +} + +function getLatencyBar(ms: number, maxWidth = 20): string { + const normalized = Math.min(ms / 1000, 1); + const filled = Math.round(normalized * maxWidth); + return '█'.repeat(filled) + '░'.repeat(maxWidth - filled); +} + +function formatResultsText(results: PingResult[]): string { + const sorted = [...results].sort((a, b) => { + if (a.status === 'success' && b.status === 'failed') return -1; + if (a.status === 'failed' && b.status === 'success') return 1; + return (a.responseTime || 0) - (b.responseTime || 0); + }); + const successResults = sorted.filter((r) => r.status === 'success'); + const fastestResult = successResults[0]; + const maxNameLen = Math.max(...sorted.map((r) => r.providerName.length)); + const lines: string[] = ['📡 Ping', '']; + + for (const result of sorted) { + const paddedName = result.providerName.padEnd(maxNameLen); + if (result.status === 'success' && result.responseTime !== undefined) { + const ms = result.responseTime; + const icon = ms < 200 ? '🟢' : ms < 500 ? '🟡' : ms < 1000 ? '🟠' : '🔴'; + const badge = + fastestResult?.providerId === result.providerId ? ' ⚡' : ''; + lines.push(`${icon} ${paddedName} ${ms}ms${badge}`); + } else { + lines.push(`❌ ${paddedName} ${result.error || 'Failed'}`); + } + } + lines.push('', '🟢 <200ms 🟡 200-500ms 🟠 500-1000ms 🔴 >1000ms'); + return lines.join('\n'); +} + +function formatModelResultsText(results: ModelTestResult[]): string { + // Group by provider + const byProvider = new Map(); + for (const result of results) { + const list = byProvider.get(result.providerName) || []; + list.push(result); + byProvider.set(result.providerName, list); + } + + // Sort each provider's models by response time + for (const [, models] of byProvider) { + models.sort((a, b) => { + if (a.status === 'success' && b.status === 'failed') return -1; + if (a.status === 'failed' && b.status === 'success') return 1; + return (a.responseTime || 0) - (b.responseTime || 0); + }); + } + + const successResults = results.filter((r) => r.status === 'success'); + const fastestResult = successResults.sort( + (a, b) => (a.responseTime || 0) - (b.responseTime || 0), + )[0]; + + const maxNameLen = Math.max(...results.map((r) => r.modelName.length)); + const lines: string[] = ['🤖 Model Ping', '']; + + for (const [providerName, models] of byProvider) { + lines.push(`🏷️ ${providerName}`); + for (const result of models) { + const paddedName = result.modelName.padEnd(maxNameLen); + if (result.status === 'success' && result.responseTime !== undefined) { + const ms = result.responseTime; + const icon = + ms < 2000 ? '🟢' : ms < 5000 ? '🟡' : ms < 10000 ? '🟠' : '🔴'; + const badge = fastestResult?.modelId === result.modelId ? ' ⚡' : ''; + lines.push(` ${icon} ${paddedName} ${ms}ms${badge}`); + } else { + lines.push( + ` ❌ ${paddedName} ${result.error?.slice(0, 35) || 'Failed'}`, + ); + } + } + lines.push(''); + } + + lines.push('🟢 <2s 🟡 2-5s 🟠 5-10s 🔴 >10s'); + return lines.join('\n'); +} + +// Model ping component +function ModelPingComponent({ + onDone, +}: { + onDone: (result: string | null) => void; +}) { + const { bridge, cwd, addMessage } = useAppStore(); + const [results, setResults] = useState([]); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + const [completed, setCompleted] = useState(false); + + const titleText = 'Testing Models'; + const highlightIndex = useTextGradientAnimation(titleText, loading); + + useEffect(() => { + const runModelTests = async () => { + try { + const modelsResult = await bridge.request('models.list', { cwd }); + if (!modelsResult?.success) { + setError('Could not retrieve models'); + setLoading(false); + return; + } + + const providersResult = await bridge.request('providers.list', { cwd }); + if (!providersResult?.success) { + setError('Could not retrieve providers'); + setLoading(false); + return; + } + + const validProviderIds = new Set( + providersResult.data.providers + .filter((p) => p.validEnvs.length > 0 || p.hasApiKey) + .map((p) => p.id), + ); + + const modelsToTest: Array<{ + modelId: string; + modelName: string; + providerId: string; + providerName: string; + }> = []; + + for (const group of modelsResult.data.groupedModels) { + if (validProviderIds.has(group.providerId)) { + for (const model of group.models) { + modelsToTest.push({ + modelId: model.value, + modelName: model.name, + providerId: group.providerId, + providerName: group.provider, + }); + } + } + } + + if (modelsToTest.length === 0) { + setError('No configured models found to test'); + setLoading(false); + return; + } + + setResults( + modelsToTest.map((m) => ({ + modelId: m.modelId, + modelName: m.modelName, + providerName: m.providerName, + status: 'testing' as const, + })), + ); + + const testPromises = modelsToTest.map(async (model, idx) => { + const testResult = await bridge.request('models.test', { + cwd, + modelId: model.modelId, + }); + setResults((prev) => + prev.map((r, i) => + i === idx + ? { + ...r, + status: testResult.success ? 'success' : 'failed', + responseTime: + testResult.data?.responseTime || testResult.responseTime, + error: testResult.error, + } + : r, + ), + ); + return testResult; + }); + + await Promise.all(testPromises); + setLoading(false); + setCompleted(true); + } catch (err) { + setError(err instanceof Error ? err.message : 'Unknown error'); + setLoading(false); + } + }; + runModelTests(); + }, [bridge, cwd]); + + useEffect(() => { + if (completed && results.length > 0) { + const resultText = formatModelResultsText(results); + const assistantMessage: AssistantMessage = { + role: 'assistant', + content: resultText, + text: resultText, + model: 'system', + usage: { input_tokens: 0, output_tokens: 0 }, + }; + addMessage(assistantMessage); + const timer = setTimeout(() => onDone(''), 100); + return () => clearTimeout(timer); + } + }, [completed, results, onDone, addMessage]); + + if (error) { + return ( + + ❌ Error: {error} + + ); + } + + const completedCount = results.filter( + (r) => r.status === 'success' || r.status === 'failed', + ).length; + + if (completed) { + return ( + + ✓ Model test completed + + ); + } + + return ( + + + + {' '} + + + {results.length > 0 && ( + + {' '} + ({completedCount}/{results.length}) + + )} + + + {results.map((result) => ( + + + {result.status === 'testing' ? ( + + + + ) : result.status === 'success' ? ( + + ) : ( + + )} + + + + {result.modelName.slice(0, 22)} + + + + {result.status === 'success' && + result.responseTime !== undefined && ( + + {result.responseTime}ms + + )} + {result.status === 'failed' && ( + + {result.error?.slice(0, 30) || 'Failed'} + + )} + {result.status === 'testing' && ( + + testing... + + )} + + + ))} + + + ); +} + +// Network ping component +function NetworkPingComponent({ + onDone, +}: { + onDone: (result: string | null) => void; +}) { + const { bridge, cwd, addMessage } = useAppStore(); + const [results, setResults] = useState([]); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + const [completed, setCompleted] = useState(false); + + const titleText = 'Testing Network Latency'; + const highlightIndex = useTextGradientAnimation(titleText, loading); + + useEffect(() => { + const runPingTests = async () => { + try { + const providersResult = await bridge.request('providers.list', { cwd }); + if (!providersResult?.success) { + setError('Could not retrieve configured providers'); + setLoading(false); + return; + } + + const providersWithApi = providersResult.data.providers.filter( + (p) => p.api, + ); + if (providersWithApi.length === 0) { + setError('No providers with API endpoints found'); + setLoading(false); + return; + } + + setResults( + providersWithApi.map((p) => ({ + providerId: p.id, + providerName: p.name, + status: 'testing' as const, + })), + ); + + const pingPromises = providersWithApi.map(async (provider, idx) => { + const pingResult = await pingEndpoint(provider.api || ''); + setResults((prev) => + prev.map((r, i) => + i === idx + ? { + ...r, + status: pingResult.status, + responseTime: pingResult.responseTime, + error: pingResult.error, + } + : r, + ), + ); + return pingResult; + }); + + await Promise.all(pingPromises); + setLoading(false); + setCompleted(true); + } catch (err) { + setError(err instanceof Error ? err.message : 'Unknown error'); + setLoading(false); + } + }; + runPingTests(); + }, [bridge, cwd]); + + useEffect(() => { + if (completed && results.length > 0) { + const resultText = formatResultsText(results); + const assistantMessage: AssistantMessage = { + role: 'assistant', + content: resultText, + text: resultText, + model: 'system', + usage: { input_tokens: 0, output_tokens: 0 }, + }; + addMessage(assistantMessage); + const timer = setTimeout(() => onDone(''), 100); + return () => clearTimeout(timer); + } + }, [completed, results, onDone, addMessage]); + + if (error) { + return ( + + ❌ Error: {error} + + ); + } + + const completedCount = results.filter( + (r) => r.status === 'success' || r.status === 'failed', + ).length; + + if (completed) { + return ( + + ✓ Test completed + + ); + } + + return ( + + + + {' '} + + + {results.length > 0 && ( + + {' '} + ({completedCount}/{results.length}) + + )} + + + {results.map((result) => ( + + + {result.status === 'testing' ? ( + + + + ) : result.status === 'success' ? ( + + ) : ( + + )} + + + + {result.providerName} + + + + {result.status === 'success' && + result.responseTime !== undefined && ( + <> + + {getLatencyBar(result.responseTime, 12)} + + + {' '} + {result.responseTime}ms + + + )} + {result.status === 'failed' && ( + + {result.error?.slice(0, 25) || 'Failed'} + + )} + {result.status === 'testing' && ( + + testing... + + )} + + + ))} + + + ); +} + +export const pingCommand: LocalJSXCommand = { + type: 'local-jsx', + name: 'ping', + description: + 'Test network latency to providers. "/ping model" tests all configured models', + async call(onDone, _context, args) { + const isModelMode = args?.trim().toLowerCase() === 'model'; + + if (isModelMode) { + return React.createElement(ModelPingComponent, { onDone }); + } + return React.createElement(NetworkPingComponent, { onDone }); + }, +};