Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions apps/desktop/src/components/settings/ai/llm/select.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import { listOllamaModels } from "../shared/list-ollama";
import { listGenericModels, listOpenAIModels } from "../shared/list-openai";
import { listOpenRouterModels } from "../shared/list-openrouter";
import { ModelCombobox } from "../shared/model-combobox";
import { useLocalProviderStatus } from "../shared/use-local-provider-status";
import { HealthStatusIndicator, useConnectionHealth } from "./health";
import { PROVIDERS } from "./shared";

Expand All @@ -44,6 +45,10 @@ export function SelectProviderAndModel() {
const isConfigured = !!(current_llm_provider && current_llm_model);
const hasError = isConfigured && health.status === "error";

// Get local provider statuses
const { status: ollamaStatus } = useLocalProviderStatus("ollama");
const { status: lmStudioStatus } = useLocalProviderStatus("lmstudio");

const handleSelectProvider = settings.UI.useSetValueCallback(
"current_llm_provider",
(provider: string) => provider,
Expand Down Expand Up @@ -116,16 +121,35 @@ export function SelectProviderAndModel() {
{PROVIDERS.map((provider) => {
const status = configuredProviders[provider.id];

// Get local provider status
let localStatus = null;
if (provider.id === "ollama") {
localStatus = ollamaStatus;
} else if (provider.id === "lmstudio") {
localStatus = lmStudioStatus;
}

// For local providers, only enable if connected
// For other providers, enable if they have listModels
const isLocalProvider =
provider.id === "ollama" || provider.id === "lmstudio";
const isEnabled = isLocalProvider
? localStatus === "connected" && !!status?.listModels
: !!status?.listModels;

return (
<SelectItem
key={provider.id}
value={provider.id}
disabled={!status?.listModels}
disabled={!isEnabled}
>
<div className="flex flex-col gap-0.5">
<div className="flex items-center gap-2">
{provider.icon}
<span>{provider.displayName}</span>
{localStatus === "connected" && (
<span className="size-1.5 rounded-full bg-green-500" />
)}
</div>
</div>
</SelectItem>
Expand Down Expand Up @@ -194,6 +218,10 @@ function useConfiguredMapping(): Record<string, ProviderStatus> {
settings.STORE_ID,
);

// Get local provider statuses here as well for the mapping
const { status: ollamaStatus } = useLocalProviderStatus("ollama");
const { status: lmStudioStatus } = useLocalProviderStatus("lmstudio");

const mapping = useMemo(() => {
return Object.fromEntries(
PROVIDERS.map((provider) => {
Expand All @@ -206,14 +234,25 @@ function useConfiguredMapping(): Record<string, ProviderStatus> {
const proLocked =
requiresEntitlement(provider.requirements, "pro") && !billing.isPro;

// Check if it's a connected local provider
const isLocalProvider =
provider.id === "ollama" || provider.id === "lmstudio";
const localStatus =
provider.id === "ollama"
? ollamaStatus
: provider.id === "lmstudio"
? lmStudioStatus
: null;
const isConnectedLocal = isLocalProvider && localStatus === "connected";

const eligible =
getProviderSelectionBlockers(provider.requirements, {
isAuthenticated: !!auth?.session,
isPro: billing.isPro,
config: { base_url: baseUrl, api_key: apiKey },
}).length === 0;
}).length === 0 || isConnectedLocal;

if (!eligible) {
if (!eligible && !isConnectedLocal) {
return [provider.id, { listModels: undefined, proLocked }];
}

Expand Down Expand Up @@ -261,7 +300,7 @@ function useConfiguredMapping(): Record<string, ProviderStatus> {
return [provider.id, { listModels: listModelsFunc, proLocked }];
}),
) as Record<string, ProviderStatus>;
}, [configuredProviders, auth, billing]);
}, [configuredProviders, auth, billing, ollamaStatus, lmStudioStatus]);

return mapping;
}
18 changes: 18 additions & 0 deletions apps/desktop/src/components/settings/ai/llm/shared.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ type Provider = {
icon: ReactNode;
baseUrl?: string;
requirements: ProviderRequirement[];
links?: {
download?: { label: string; url: string };
models?: { label: string; url: string };
};
};

const _PROVIDERS = [
Expand All @@ -44,6 +48,13 @@ const _PROVIDERS = [
icon: <LmStudio size={16} />,
baseUrl: "http://127.0.0.1:1234/v1",
requirements: [],
links: {
download: {
label: "Download LM Studio",
url: "https://lmstudio.ai/download",
},
models: { label: "Available models", url: "https://lmstudio.ai/models" },
},
},
{
id: "ollama",
Expand All @@ -52,6 +63,13 @@ const _PROVIDERS = [
icon: <Ollama size={16} />,
baseUrl: "http://127.0.0.1:11434/v1",
requirements: [],
links: {
download: {
label: "Download Ollama",
url: "https://ollama.com/download",
},
models: { label: "Available models", url: "https://ollama.com/library" },
},
},
{
id: "openrouter",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import { LMStudioClient } from "@lmstudio/sdk";
import { fetch as tauriFetch } from "@tauri-apps/plugin-http";
import { Effect, pipe } from "effect";

export type LocalProviderStatus = "connected" | "disconnected" | "checking";

const CHECK_TIMEOUT = "2 seconds";

export async function checkOllamaConnection(baseUrl: string): Promise<boolean> {
if (!baseUrl) {
return false;
}

return pipe(
Effect.tryPromise(async () => {
const ollamaHost = baseUrl.replace(/\/v1\/?$/, "");
const response = await tauriFetch(`${ollamaHost}/api/tags`, {
method: "GET",
headers: {
Origin: new URL(ollamaHost).origin,
},
});
return response.ok;
}),
Effect.timeout(CHECK_TIMEOUT),
Effect.catchAll(() => Effect.succeed(false)),
Effect.runPromise,
);
}

export async function checkLMStudioConnection(
baseUrl: string,
): Promise<boolean> {
if (!baseUrl) {
return false;
}

return pipe(
Effect.tryPromise(async () => {
const url = new URL(baseUrl);
const port = url.port || "1234";
const formattedUrl = `ws:127.0.0.1:${port}`;
const client = new LMStudioClient({ baseUrl: formattedUrl });
await client.system.listDownloadedModels();
return true;
}),
Effect.timeout(CHECK_TIMEOUT),
Effect.catchAll(() => Effect.succeed(false)),
Effect.runPromise,
);
}
95 changes: 88 additions & 7 deletions apps/desktop/src/components/settings/ai/shared/index.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { Icon } from "@iconify-icon/react";
import { type AnyFieldApi, useForm } from "@tanstack/react-form";
import { MoveUpRight } from "lucide-react";
import type { ReactNode } from "react";
import { Streamdown } from "streamdown";

Expand All @@ -11,10 +12,12 @@ import {
AccordionItem,
AccordionTrigger,
} from "@hypr/ui/components/ui/accordion";
import { Button } from "@hypr/ui/components/ui/button";
import {
InputGroup,
InputGroupInput,
} from "@hypr/ui/components/ui/input-group";
import { Spinner } from "@hypr/ui/components/ui/spinner";
import { cn } from "@hypr/utils";

import { useBillingAccess } from "../../../../billing";
Expand All @@ -25,6 +28,7 @@ import {
type ProviderRequirement,
requiresEntitlement,
} from "./eligibility";
import { useLocalProviderStatus } from "./use-local-provider-status";

export * from "./model-combobox";

Expand All @@ -38,6 +42,10 @@ type ProviderConfig = {
baseUrl?: string;
disabled?: boolean;
requirements: ProviderRequirement[];
links?: {
download?: { label: string; url: string };
models?: { label: string; url: string };
};
};

function useIsProviderConfigured(
Expand Down Expand Up @@ -94,6 +102,8 @@ export function NonHyprProviderCard({
providerType,
providers,
);
const { status: localProviderStatus, refetch: refetchStatus } =
useLocalProviderStatus(config.id);

const requiredFields = getRequiredConfigFields(config.requirements);
const showApiKey = requiredFields.includes("api_key");
Expand Down Expand Up @@ -139,13 +149,32 @@ export function NonHyprProviderCard({
(config.disabled || locked) && "cursor-not-allowed opacity-30",
])}
>
<div className="flex items-center gap-2">
{config.icon}
<span>{config.displayName}</span>
{config.badge && (
<span className="text-xs text-neutral-500 font-light border border-neutral-300 rounded-full px-2">
{config.badge}
</span>
<div className="flex items-center justify-between w-full">
<div className="flex items-center gap-2">
{config.icon}
<span>{config.displayName}</span>
{config.badge && (
<span className="text-xs text-neutral-500 font-light border border-neutral-300 rounded-full px-2">
{config.badge}
</span>
)}
{localProviderStatus && (
<LocalProviderStatusBadge status={localProviderStatus} />
)}
</div>
{localProviderStatus && localProviderStatus !== "connected" && (
<Button
variant="outline"
size="sm"
onClick={(e) => {
e.stopPropagation();
refetchStatus();
}}
disabled={localProviderStatus === "checking"}
className="mr-2"
>
Connect
</Button>
)}
</div>
</AccordionTrigger>
Expand Down Expand Up @@ -178,6 +207,32 @@ export function NonHyprProviderCard({
)}
</form.Field>
)}
{config.links && (config.links.download || config.links.models) && (
<div className="flex items-center gap-4 text-xs">
{config.links.download && (
<a
href={config.links.download.url}
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center gap-0.5 text-neutral-600 hover:text-neutral-900 hover:underline"
>
{config.links.download.label}
<MoveUpRight size={12} />
</a>
)}
{config.links.models && (
<a
href={config.links.models.url}
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center gap-0.5 text-neutral-600 hover:text-neutral-900 hover:underline"
>
{config.links.models.label}
<MoveUpRight size={12} />
</a>
)}
</div>
)}
{!showBaseUrl && config.baseUrl && (
<details className="space-y-4 pt-2">
<summary className="text-xs cursor-pointer text-neutral-600 hover:text-neutral-900 hover:underline">
Expand Down Expand Up @@ -237,6 +292,32 @@ export function StyledStreamdown({
);
}

function LocalProviderStatusBadge({
status,
}: {
status: "connected" | "disconnected" | "checking";
}) {
if (status === "checking") {
return <Spinner size={12} className="shrink-0 text-neutral-400" />;
}

if (status === "connected") {
return (
<span className="flex items-center gap-1 text-xs text-green-600 font-light">
<span className="size-1.5 rounded-full bg-green-500" />
Connected
</span>
);
}

return (
<span className="flex items-center gap-1 text-xs text-neutral-500 font-light">
<span className="size-1.5 rounded-full bg-neutral-400" />
Not Running
</span>
);
}

function useProvider(id: string) {
const providerRow = settings.UI.useRow("ai_providers", id, settings.STORE_ID);
const setProvider = settings.UI.useSetPartialRowCallback(
Expand Down
Loading
Loading