diff --git a/apps/desktop/src/components/settings/ai/llm/select.tsx b/apps/desktop/src/components/settings/ai/llm/select.tsx index 01aa81a645..56cb8d1efe 100644 --- a/apps/desktop/src/components/settings/ai/llm/select.tsx +++ b/apps/desktop/src/components/settings/ai/llm/select.tsx @@ -29,6 +29,7 @@ import { listOllamaModels } from "../shared/list-ollama"; import { listGenericModels, listOpenAIModels } from "../shared/list-openai"; import { listOpenRouterModels } from "../shared/list-openrouter"; import { ModelCombobox } from "../shared/model-combobox"; +import { useLocalProviderStatus } from "../shared/use-local-provider-status"; import { HealthStatusIndicator, useConnectionHealth } from "./health"; import { PROVIDERS } from "./shared"; @@ -44,6 +45,10 @@ export function SelectProviderAndModel() { const isConfigured = !!(current_llm_provider && current_llm_model); const hasError = isConfigured && health.status === "error"; + // Get local provider statuses + const { status: ollamaStatus } = useLocalProviderStatus("ollama"); + const { status: lmStudioStatus } = useLocalProviderStatus("lmstudio"); + const handleSelectProvider = settings.UI.useSetValueCallback( "current_llm_provider", (provider: string) => provider, @@ -116,16 +121,35 @@ export function SelectProviderAndModel() { {PROVIDERS.map((provider) => { const status = configuredProviders[provider.id]; + // Get local provider status + let localStatus = null; + if (provider.id === "ollama") { + localStatus = ollamaStatus; + } else if (provider.id === "lmstudio") { + localStatus = lmStudioStatus; + } + + // For local providers, only enable if connected + // For other providers, enable if they have listModels + const isLocalProvider = + provider.id === "ollama" || provider.id === "lmstudio"; + const isEnabled = isLocalProvider + ? localStatus === "connected" && !!status?.listModels + : !!status?.listModels; + return (
{provider.icon} {provider.displayName} + {localStatus === "connected" && ( + + )}
@@ -194,6 +218,10 @@ function useConfiguredMapping(): Record { settings.STORE_ID, ); + // Get local provider statuses here as well for the mapping + const { status: ollamaStatus } = useLocalProviderStatus("ollama"); + const { status: lmStudioStatus } = useLocalProviderStatus("lmstudio"); + const mapping = useMemo(() => { return Object.fromEntries( PROVIDERS.map((provider) => { @@ -206,14 +234,25 @@ function useConfiguredMapping(): Record { const proLocked = requiresEntitlement(provider.requirements, "pro") && !billing.isPro; + // Check if it's a connected local provider + const isLocalProvider = + provider.id === "ollama" || provider.id === "lmstudio"; + const localStatus = + provider.id === "ollama" + ? ollamaStatus + : provider.id === "lmstudio" + ? lmStudioStatus + : null; + const isConnectedLocal = isLocalProvider && localStatus === "connected"; + const eligible = getProviderSelectionBlockers(provider.requirements, { isAuthenticated: !!auth?.session, isPro: billing.isPro, config: { base_url: baseUrl, api_key: apiKey }, - }).length === 0; + }).length === 0 || isConnectedLocal; - if (!eligible) { + if (!eligible && !isConnectedLocal) { return [provider.id, { listModels: undefined, proLocked }]; } @@ -261,7 +300,7 @@ function useConfiguredMapping(): Record { return [provider.id, { listModels: listModelsFunc, proLocked }]; }), ) as Record; - }, [configuredProviders, auth, billing]); + }, [configuredProviders, auth, billing, ollamaStatus, lmStudioStatus]); return mapping; } diff --git a/apps/desktop/src/components/settings/ai/llm/shared.tsx b/apps/desktop/src/components/settings/ai/llm/shared.tsx index cd4abeb097..a05199f2b1 100644 --- a/apps/desktop/src/components/settings/ai/llm/shared.tsx +++ b/apps/desktop/src/components/settings/ai/llm/shared.tsx @@ -23,6 +23,10 @@ type Provider = { icon: ReactNode; baseUrl?: string; requirements: ProviderRequirement[]; + links?: { + download?: { label: string; url: string }; + models?: { label: string; url: string }; + }; }; 
const _PROVIDERS = [ @@ -44,6 +48,13 @@ const _PROVIDERS = [ icon: , baseUrl: "http://127.0.0.1:1234/v1", requirements: [], + links: { + download: { + label: "Download LM Studio", + url: "https://lmstudio.ai/download", + }, + models: { label: "Available models", url: "https://lmstudio.ai/models" }, + }, }, { id: "ollama", @@ -52,6 +63,13 @@ const _PROVIDERS = [ icon: , baseUrl: "http://127.0.0.1:11434/v1", requirements: [], + links: { + download: { + label: "Download Ollama", + url: "https://ollama.com/download", + }, + models: { label: "Available models", url: "https://ollama.com/library" }, + }, }, { id: "openrouter", diff --git a/apps/desktop/src/components/settings/ai/shared/check-local-provider.ts b/apps/desktop/src/components/settings/ai/shared/check-local-provider.ts new file mode 100644 index 0000000000..82e5974439 --- /dev/null +++ b/apps/desktop/src/components/settings/ai/shared/check-local-provider.ts @@ -0,0 +1,51 @@ +import { LMStudioClient } from "@lmstudio/sdk"; +import { fetch as tauriFetch } from "@tauri-apps/plugin-http"; +import { Effect, pipe } from "effect"; + +export type LocalProviderStatus = "connected" | "disconnected" | "checking"; + +const CHECK_TIMEOUT = "2 seconds"; + +export async function checkOllamaConnection(baseUrl: string): Promise { + if (!baseUrl) { + return false; + } + + return pipe( + Effect.tryPromise(async () => { + const ollamaHost = baseUrl.replace(/\/v1\/?$/, ""); + const response = await tauriFetch(`${ollamaHost}/api/tags`, { + method: "GET", + headers: { + Origin: new URL(ollamaHost).origin, + }, + }); + return response.ok; + }), + Effect.timeout(CHECK_TIMEOUT), + Effect.catchAll(() => Effect.succeed(false)), + Effect.runPromise, + ); +} + +export async function checkLMStudioConnection( + baseUrl: string, +): Promise { + if (!baseUrl) { + return false; + } + + return pipe( + Effect.tryPromise(async () => { + const url = new URL(baseUrl); + const port = url.port || "1234"; + const formattedUrl = `ws://127.0.0.1:${port}`; 
+ const client = new LMStudioClient({ baseUrl: formattedUrl }); + await client.system.listDownloadedModels(); + return true; + }), + Effect.timeout(CHECK_TIMEOUT), + Effect.catchAll(() => Effect.succeed(false)), + Effect.runPromise, + ); +} diff --git a/apps/desktop/src/components/settings/ai/shared/index.tsx b/apps/desktop/src/components/settings/ai/shared/index.tsx index 6da46f8055..3a84fb1bf1 100644 --- a/apps/desktop/src/components/settings/ai/shared/index.tsx +++ b/apps/desktop/src/components/settings/ai/shared/index.tsx @@ -1,5 +1,6 @@ import { Icon } from "@iconify-icon/react"; import { type AnyFieldApi, useForm } from "@tanstack/react-form"; +import { MoveUpRight } from "lucide-react"; import type { ReactNode } from "react"; import { Streamdown } from "streamdown"; @@ -11,10 +12,12 @@ import { AccordionItem, AccordionTrigger, } from "@hypr/ui/components/ui/accordion"; +import { Button } from "@hypr/ui/components/ui/button"; import { InputGroup, InputGroupInput, } from "@hypr/ui/components/ui/input-group"; +import { Spinner } from "@hypr/ui/components/ui/spinner"; import { cn } from "@hypr/utils"; import { useBillingAccess } from "../../../../billing"; @@ -25,6 +28,7 @@ import { type ProviderRequirement, requiresEntitlement, } from "./eligibility"; +import { useLocalProviderStatus } from "./use-local-provider-status"; export * from "./model-combobox"; @@ -38,6 +42,10 @@ type ProviderConfig = { baseUrl?: string; disabled?: boolean; requirements: ProviderRequirement[]; + links?: { + download?: { label: string; url: string }; + models?: { label: string; url: string }; + }; }; function useIsProviderConfigured( @@ -94,6 +102,8 @@ export function NonHyprProviderCard({ providerType, providers, ); + const { status: localProviderStatus, refetch: refetchStatus } = + useLocalProviderStatus(config.id); const requiredFields = getRequiredConfigFields(config.requirements); const showApiKey = requiredFields.includes("api_key"); @@ -139,13 +149,32 @@ export function 
NonHyprProviderCard({ (config.disabled || locked) && "cursor-not-allowed opacity-30", ])} > -
- {config.icon} - {config.displayName} - {config.badge && ( - - {config.badge} - +
+
+ {config.icon} + {config.displayName} + {config.badge && ( + + {config.badge} + + )} + {localProviderStatus && ( + + )} +
+ {localProviderStatus && localProviderStatus !== "connected" && ( + )}
@@ -178,6 +207,32 @@ export function NonHyprProviderCard({ )} )} + {config.links && (config.links.download || config.links.models) && ( +
+ {config.links.download && ( + + {config.links.download.label} + + + )} + {config.links.models && ( + + {config.links.models.label} + + + )} +
+ )} {!showBaseUrl && config.baseUrl && (
@@ -237,6 +292,32 @@ export function StyledStreamdown({ ); } +function LocalProviderStatusBadge({ + status, +}: { + status: "connected" | "disconnected" | "checking"; +}) { + if (status === "checking") { + return ; + } + + if (status === "connected") { + return ( + + + Connected + + ); + } + + return ( + + + Not Running + + ); +} + function useProvider(id: string) { const providerRow = settings.UI.useRow("ai_providers", id, settings.STORE_ID); const setProvider = settings.UI.useSetPartialRowCallback( diff --git a/apps/desktop/src/components/settings/ai/shared/use-local-provider-status.ts b/apps/desktop/src/components/settings/ai/shared/use-local-provider-status.ts new file mode 100644 index 0000000000..4ee573f6a1 --- /dev/null +++ b/apps/desktop/src/components/settings/ai/shared/use-local-provider-status.ts @@ -0,0 +1,71 @@ +import { useQuery } from "@tanstack/react-query"; + +import * as settings from "../../../../store/tinybase/store/settings"; +import { + checkLMStudioConnection, + checkOllamaConnection, + type LocalProviderStatus, +} from "./check-local-provider"; + +const LOCAL_PROVIDERS = ["ollama", "lmstudio"] as const; +type LocalProviderId = (typeof LOCAL_PROVIDERS)[number]; + +const DEFAULT_OLLAMA_URL = "http://127.0.0.1:11434/v1"; +const DEFAULT_LMSTUDIO_URL = "http://127.0.0.1:1234/v1"; + +function isLocalProvider(providerId: string): providerId is LocalProviderId { + return LOCAL_PROVIDERS.includes(providerId as LocalProviderId); +} + +export function useLocalProviderStatus(providerId: string): { + status: LocalProviderStatus | null; + refetch: () => void; +} { + const configuredProviders = settings.UI.useResultTable( + settings.QUERIES.llmProviders, + settings.STORE_ID, + ); + + const config = configuredProviders[providerId]; + + const defaultUrl = + providerId === "ollama" + ? DEFAULT_OLLAMA_URL + : providerId === "lmstudio" + ? 
DEFAULT_LMSTUDIO_URL + : ""; + + const baseUrl = String(config?.base_url || defaultUrl).trim(); + + const checkFn = + providerId === "ollama" + ? checkOllamaConnection + : providerId === "lmstudio" + ? checkLMStudioConnection + : null; + + const query = useQuery({ + enabled: isLocalProvider(providerId) && !!checkFn, + queryKey: ["local-provider-status", providerId, baseUrl], + queryFn: async () => { + if (!checkFn) return false; + return checkFn(baseUrl); + }, + staleTime: 10_000, + refetchInterval: 15_000, + retry: false, + }); + + if (!isLocalProvider(providerId)) { + return { status: null, refetch: () => {} }; + } + + const status: LocalProviderStatus = + query.isLoading + ? "checking" + : query.data + ? "connected" + : "disconnected"; + + return { status, refetch: () => void query.refetch() };
+}