diff --git a/Cargo.toml b/Cargo.toml index 130b6728..36ea9981 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ wasmtime-wasi-config = "38.0.4" [dependencies] anyhow = { workspace = true } axum = "0.8" -clap = { version = "4.5", features = ["derive"] } +clap = { version = "4.5", features = ["derive", "env"] } clap_complete = "4.5" etcetera = { workspace = true } figment = { version = "0.10", features = ["env", "toml"] } diff --git a/crates/mcp-server/src/components.rs b/crates/mcp-server/src/components.rs index 7aa226ee..987ac3e5 100644 --- a/crates/mcp-server/src/components.rs +++ b/crates/mcp-server/src/components.rs @@ -11,7 +11,7 @@ use rmcp::{Peer, RoleServer}; use serde_json::{json, Value}; use tracing::{debug, error, info, instrument}; use wassette::schema::{canonicalize_output_schema, ensure_structured_result}; -use wassette::{ComponentLoadOutcome, LifecycleManager, LoadResult}; +use wassette::{ComponentLoadOutcome, LifecycleManager, LoadResult, OciCredentials}; #[instrument(skip(lifecycle_manager))] pub(crate) async fn get_component_tools(lifecycle_manager: &LifecycleManager) -> Result> { @@ -388,6 +388,16 @@ async fn handle_tool_list_notification( pub async fn handle_load_component_cli( req: &CallToolRequestParam, lifecycle_manager: &LifecycleManager, +) -> Result { + handle_load_component_cli_with_credentials(req, lifecycle_manager, None).await +} + +/// CLI-specific version of handle_load_component with optional credentials +#[instrument(skip(lifecycle_manager))] +pub async fn handle_load_component_cli_with_credentials( + req: &CallToolRequestParam, + lifecycle_manager: &LifecycleManager, + credentials: Option, ) -> Result { let args = extract_args_from_request(req)?; let path = args @@ -397,7 +407,10 @@ pub async fn handle_load_component_cli( info!(path, "Loading component (CLI mode)"); - match lifecycle_manager.load_component(path).await { + match lifecycle_manager + .load_component_with_credentials(path, credentials) + .await + { Ok(outcome) => { handle_tool_list_notification(None, &outcome.component_id, "load").await; create_load_component_success_result(&outcome) diff --git a/crates/wassette/src/lib.rs b/crates/wassette/src/lib.rs index 963f23cd..18abe5ab 100644 --- a/crates/wassette/src/lib.rs +++ b/crates/wassette/src/lib.rs @@ -42,6 +42,7 @@ use component_storage::ComponentStorage; pub use config::{LifecycleBuilder, LifecycleConfig}; pub use http::WassetteWasiState; use loader::{ComponentResource, DownloadedResource}; +pub use oci_auth::OciCredentials; use policy_internal::PolicyManager; pub use policy_internal::{PermissionGrantRequest, PermissionRule, PolicyInfo}; use runtime_context::RuntimeContext; @@ -424,15 +425,26 @@ impl LifecycleManager { self.policy_manager.restore_from_disk(component_id).await } + #[allow(dead_code)] async fn resolve_component_resource(&self, uri: &str) -> Result<(String, DownloadedResource)> { + self.resolve_component_resource_with_credentials(uri, None) + .await + } + + async fn resolve_component_resource_with_credentials( + &self, + uri: &str, + credentials: Option, + ) -> Result<(String, DownloadedResource)> { // Show progress when running in CLI mode (stderr is a TTY) let show_progress = std::io::stderr().is_terminal(); - let resource = loader::load_resource_with_progress::( + let resource = loader::load_resource_with_progress_and_credentials::( uri, &self.oci_client, &self.http_client, show_progress, + credentials, ) .await?; let id = resource.id()?; @@ -527,8 +539,29 @@ impl LifecycleManager { /// component and whether it replaced an existing instance. #[instrument(skip(self))] pub async fn load_component(&self, uri: &str) -> Result { + self.load_component_with_credentials(uri, None).await + } + + /// Loads a WebAssembly component from the specified URI with optional explicit credentials + /// + /// If a component with the given id already exists, it will be updated with the new component. + /// Returns rich [`ComponentLoadOutcome`] information describing the loaded + /// component and whether it replaced an existing instance. + /// + /// # Arguments + /// + /// * `uri` - The URI to load the component from (file://, oci://, or https://) + /// * `credentials` - Optional explicit OCI registry credentials (takes priority over Docker config) + #[instrument(skip(self))] + pub async fn load_component_with_credentials( + &self, + uri: &str, + credentials: Option, + ) -> Result { debug!(uri, "Loading component"); - let (component_id, resource) = self.resolve_component_resource(uri).await?; + let (component_id, resource) = self + .resolve_component_resource_with_credentials(uri, credentials) + .await?; let staged_path = self .stage_component_artifact(&component_id, resource) .await?; diff --git a/crates/wassette/src/loader.rs b/crates/wassette/src/loader.rs index 71a241ee..b14b2334 100644 --- a/crates/wassette/src/loader.rs +++ b/crates/wassette/src/loader.rs @@ -153,11 +153,18 @@ pub trait Loadable: Sized { const RESOURCE_TYPE: &'static str; async fn from_local_file(path: &Path) -> Result; + #[allow(dead_code)] async fn from_oci_reference_with_progress( reference: &str, oci_client: &oci_client::Client, show_progress: bool, ) -> Result; + async fn from_oci_reference_with_progress_and_credentials( + reference: &str, + oci_client: &oci_client::Client, + show_progress: bool, + credentials: Option, + ) -> Result; async fn from_url(url: &str, http_client: &reqwest::Client) -> Result; } @@ -192,6 +199,21 @@ impl Loadable for ComponentResource { reference: &str, oci_client: &oci_client::Client, show_progress: bool, + ) -> Result { + Self::from_oci_reference_with_progress_and_credentials( + reference, + oci_client, + show_progress, + None, + ) + .await + } + + async fn from_oci_reference_with_progress_and_credentials( + reference: &str, + oci_client: &oci_client::Client, + show_progress: bool, + credentials: Option, ) -> Result { let reference: oci_client::Reference = reference.parse().context("Failed to parse OCI reference")?; @@ -201,7 +223,7 @@ impl Loadable for ComponentResource { } // Get authentication credentials for this registry - let auth = crate::oci_auth::get_registry_auth(&reference) + let auth = crate::oci_auth::get_registry_auth_with_credentials(&reference, credentials) .context("Failed to get registry authentication")?; // First try oci-wasm for backwards compatibility with single-layer artifacts @@ -344,6 +366,15 @@ impl Loadable for PolicyResource { bail!("OCI references are not supported for policy resources. Use 'file://' or 'https://' schemes instead.") } + async fn from_oci_reference_with_progress_and_credentials( + _reference: &str, + _oci_client: &oci_client::Client, + _show_progress: bool, + _credentials: Option, + ) -> Result { + bail!("OCI references are not supported for policy resources. Use 'file://' or 'https://' schemes instead.") + } + async fn from_url(url: &str, http_client: &reqwest::Client) -> Result { let url_obj = reqwest::Url::parse(url)?; let filename = url_obj @@ -392,6 +423,24 @@ pub(crate) async fn load_resource_with_progress( oci_client: &oci_wasm::WasmClient, http_client: &reqwest::Client, show_progress: bool, +) -> Result { + load_resource_with_progress_and_credentials::( + uri, + oci_client, + http_client, + show_progress, + None, + ) + .await +} + +/// Generic resource loading function with optional progress reporting and credentials +pub(crate) async fn load_resource_with_progress_and_credentials( + uri: &str, + oci_client: &oci_wasm::WasmClient, + http_client: &reqwest::Client, + show_progress: bool, + credentials: Option, ) -> Result { let uri = uri.trim(); let error_message = format!( @@ -402,7 +451,15 @@ pub(crate) async fn load_resource_with_progress( match scheme { "file" => T::from_local_file(Path::new(reference)).await, - "oci" => T::from_oci_reference_with_progress(reference, oci_client, show_progress).await, + "oci" => { + T::from_oci_reference_with_progress_and_credentials( + reference, + oci_client, + show_progress, + credentials, + ) + .await + } "https" => T::from_url(uri, http_client).await, _ => bail!("Unsupported {} scheme: {}", T::RESOURCE_TYPE, scheme), } diff --git a/crates/wassette/src/oci_auth.rs b/crates/wassette/src/oci_auth.rs index fc26181e..03e21a21 100644 --- a/crates/wassette/src/oci_auth.rs +++ b/crates/wassette/src/oci_auth.rs @@ -13,6 +13,15 @@ use oci_client::secrets::RegistryAuth; use oci_client::Reference; use tracing::{debug, warn}; +/// OCI registry credentials provided explicitly via CLI or environment variables +#[derive(Debug, Clone)] +pub struct OciCredentials { + /// Registry username + pub username: String, + /// Registry password or token + pub password: String, +} + /// Get authentication credentials for an OCI registry reference /// /// This function attempts to read credentials from the Docker config file @@ -39,6 +48,46 @@ use tracing::{debug, warn}; /// Returns an error if the Docker config file exists but cannot be parsed /// or if credential retrieval fails for reasons other than missing config. pub fn get_registry_auth(reference: &Reference) -> Result { + get_registry_auth_with_credentials(reference, None) +} + +/// Get authentication credentials for an OCI registry reference with optional explicit credentials +/// +/// This function implements a priority-based credential resolution: +/// +/// 1. Use explicit credentials if provided (CLI flags or environment variables) +/// 2. Fall back to Docker config file credentials +/// 3. Fall back to Anonymous if no credentials are found +/// +/// # Arguments +/// +/// * `reference` - The OCI reference to get credentials for +/// * `explicit_credentials` - Optional explicit credentials from CLI flags +/// +/// # Returns +/// +/// Returns a `RegistryAuth` enum that can be one of: +/// - `Anonymous` - No credentials found +/// - `Basic(username, password)` - Username/password credentials +/// +/// # Errors +/// +/// Returns an error if the Docker config file exists but cannot be parsed +/// or if credential retrieval fails for reasons other than missing config. +pub fn get_registry_auth_with_credentials( + reference: &Reference, + explicit_credentials: Option, +) -> Result { + // Priority 1: Use explicit credentials if provided + if let Some(creds) = explicit_credentials { + debug!( + "Using explicit credentials for registry: {}", + reference.resolve_registry() + ); + return Ok(RegistryAuth::Basic(creds.username, creds.password)); + } + + // Priority 2: Try Docker config file // Get the registry server address from the reference // Strip trailing slash if present for consistent matching let server = reference @@ -214,4 +263,133 @@ mod tests { // Should not have trailing slash assert!(!server.ends_with('/'), "Server should not end with slash"); } + + #[test] + fn test_explicit_credentials_take_precedence() { + use temp_env; + + let temp_dir = TempDir::new().unwrap(); + + // Create a test Docker config with basic auth + let config_content = r#"{ + "auths": { + "ghcr.io": { + "auth": "ZG9ja2VydXNlcjpkb2NrZXJwYXNz" + } + } + }"#; + + let config_path = create_test_docker_config(&temp_dir, config_content); + let docker_config_dir = config_path.parent().unwrap(); + + temp_env::with_var("DOCKER_CONFIG", Some(docker_config_dir), || { + let reference: Reference = "ghcr.io/test/image:latest".parse().unwrap(); + + // Provide explicit credentials that should override Docker config + let explicit_creds = OciCredentials { + username: "explicituser".to_string(), + password: "explicitpass".to_string(), + }; + + let auth = + get_registry_auth_with_credentials(&reference, Some(explicit_creds)).unwrap(); + + match auth { + RegistryAuth::Basic(username, password) => { + assert_eq!(username, "explicituser"); + assert_eq!(password, "explicitpass"); + } + _ => panic!( + "Expected Basic auth with explicit credentials, got: {:?}", + auth + ), + } + }); + } + + #[test] + fn test_fallback_to_docker_config_when_no_explicit_credentials() { + use temp_env; + + let temp_dir = TempDir::new().unwrap(); + + // Create a test Docker config with basic auth + let config_content = r#"{ + "auths": { + "ghcr.io": { + "auth": "dGVzdHVzZXI6dGVzdHBhc3M=" + } + } + }"#; + + let config_path = create_test_docker_config(&temp_dir, config_content); + let docker_config_dir = config_path.parent().unwrap(); + + temp_env::with_var("DOCKER_CONFIG", Some(docker_config_dir), || { + let reference: Reference = "ghcr.io/test/image:latest".parse().unwrap(); + + // Call with None for explicit credentials - should use Docker config + let auth = get_registry_auth_with_credentials(&reference, None).unwrap(); + + match auth { + RegistryAuth::Basic(username, password) => { + assert_eq!(username, "testuser"); + assert_eq!(password, "testpass"); + } + _ => panic!("Expected Basic auth from Docker config, got: {:?}", auth), + } + }); + } + + #[test] + fn test_explicit_credentials_without_docker_config() { + use temp_env; + + let temp_dir = TempDir::new().unwrap(); + + // Set DOCKER_CONFIG to empty temp dir (no config.json) + temp_env::with_var("DOCKER_CONFIG", Some(temp_dir.path()), || { + let reference: Reference = "ghcr.io/test/image:latest".parse().unwrap(); + + // Provide explicit credentials + let explicit_creds = OciCredentials { + username: "explicituser".to_string(), + password: "explicitpass".to_string(), + }; + + let auth = + get_registry_auth_with_credentials(&reference, Some(explicit_creds)).unwrap(); + + match auth { + RegistryAuth::Basic(username, password) => { + assert_eq!(username, "explicituser"); + assert_eq!(password, "explicitpass"); + } + _ => panic!( + "Expected Basic auth with explicit credentials, got: {:?}", + auth + ), + } + }); + } + + #[test] + fn test_anonymous_when_no_credentials_available() { + use temp_env; + + let temp_dir = TempDir::new().unwrap(); + + // Set DOCKER_CONFIG to empty temp dir (no config.json) + temp_env::with_var("DOCKER_CONFIG", Some(temp_dir.path()), || { + let reference: Reference = "docker.io/library/nginx:latest".parse().unwrap(); + + // Call with None for explicit credentials and no Docker config + let auth = get_registry_auth_with_credentials(&reference, None).unwrap(); + + assert!( + matches!(auth, RegistryAuth::Anonymous), + "Expected Anonymous auth when no credentials available" + ); + }); + } } diff --git a/src/commands.rs b/src/commands.rs index 65240889..fe7198f2 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -177,6 +177,15 @@ pub enum ComponentCommands { /// Directory where components are stored. Defaults to $XDG_DATA_HOME/wassette/components #[arg(long)] component_dir: Option, + /// Registry username for OCI authentication + #[arg(long, env = "OCI_REGISTRY_USER")] + registry_user: Option, + /// Registry password for OCI authentication (use --registry-password-stdin for better security) + #[arg(long, env = "OCI_REGISTRY_PASSWORD")] + registry_password: Option, + /// Read registry password from stdin + #[arg(long, conflicts_with = "registry_password")] + registry_password_stdin: bool, }, /// Unload a WebAssembly component. Unload { diff --git a/src/main.rs b/src/main.rs index 5601161b..43165680 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,6 +16,7 @@ use rmcp::transport::{stdio as stdio_transport, SseServer}; use serde_json::{json, Map}; use tracing_subscriber::layer::SubscriberExt as _; use tracing_subscriber::util::SubscriberInitExt as _; +use wassette::OciCredentials; mod cli_handlers; mod commands; @@ -274,18 +275,62 @@ async fn main() -> Result<()> { ComponentCommands::Load { path, component_dir, + registry_user, + registry_password, + registry_password_stdin, } => { + // Validate credentials pairing + if registry_user.is_some() != registry_password.is_some() + && !registry_password_stdin + { + bail!("Both --registry-user and --registry-password (or --registry-password-stdin) must be provided together"); + } + + // Read password from stdin if requested + let password = if *registry_password_stdin { + if registry_user.is_none() { + bail!("--registry-user must be provided when using --registry-password-stdin"); + } + use std::io::Read; + let mut buffer = String::new(); + std::io::stdin() + .read_to_string(&mut buffer) + .context("Failed to read password from stdin")?; + Some(buffer.trim().to_string()) + } else { + registry_password.clone() + }; + + // Create credentials if username and password are provided + let credentials = + if let (Some(username), Some(pwd)) = (registry_user, &password) { + Some(OciCredentials { + username: username.clone(), + password: pwd.clone(), + }) + } else { + None + }; + let component_dir = component_dir.clone().or_else(|| cli.component_dir.clone()); let lifecycle_manager = create_lifecycle_manager(component_dir).await?; let mut args = Map::new(); args.insert("path".to_string(), json!(path)); - handle_tool_cli_command( + + // Use the new handler with credentials + use mcp_server::components::handle_load_component_cli_with_credentials; + use rmcp::model::CallToolRequestParam; + let req = CallToolRequestParam { + name: "load-component".to_string().into(), + arguments: Some(args), + }; + let result = handle_load_component_cli_with_credentials( + &req, &lifecycle_manager, - "load-component", - args, - OutputFormat::Json, + credentials, ) .await?; + print_result(&result, OutputFormat::Json)?; } ComponentCommands::Unload { id, component_dir } => { let component_dir = component_dir.clone().or_else(|| cli.component_dir.clone()); diff --git a/tests/cli_integration_test.rs b/tests/cli_integration_test.rs index ea8e8ed9..8cd7391a 100644 --- a/tests/cli_integration_test.rs +++ b/tests/cli_integration_test.rs @@ -900,3 +900,85 @@ async fn test_cli_autocomplete_includes_all_commands() -> Result<()> { Ok(()) } + +#[test(tokio::test)] +async fn test_cli_component_load_with_incomplete_credentials() -> Result<()> { + let ctx = CliTestContext::new().await?; + + // Test with only username (should fail validation) + let (_stdout, stderr, exit_code) = ctx + .run_command(&[ + "component", + "load", + "oci://ghcr.io/test/image:latest", + "--registry-user", + "testuser", + ]) + .await?; + + assert_ne!( + exit_code, 0, + "Command should fail with incomplete credentials" + ); + assert!( + stderr.contains("Both --registry-user and --registry-password") + || stderr.contains("must be provided together"), + "Error message should mention incomplete credentials. Got: {stderr}" + ); + + // Test with only password (should fail validation) + let (_stdout2, stderr2, exit_code2) = ctx + .run_command(&[ + "component", + "load", + "oci://ghcr.io/test/image:latest", + "--registry-password", + "testpass", + ]) + .await?; + + assert_ne!( + exit_code2, 0, + "Command should fail with incomplete credentials" + ); + assert!( + stderr2.contains("Both --registry-user and --registry-password") + || stderr2.contains("must be provided together"), + "Error message should mention incomplete credentials. Got: {stderr2}" + ); + + Ok(()) +} + +#[test(tokio::test)] +async fn test_cli_component_load_help_shows_registry_flags() -> Result<()> { + let ctx = CliTestContext::new().await?; + + let (stdout, _stderr, exit_code) = ctx.run_command(&["component", "load", "--help"]).await?; + + assert_eq!(exit_code, 0, "Help command should succeed"); + + // Verify all new flags are documented + assert!( + stdout.contains("--registry-user"), + "Help should document --registry-user flag" + ); + assert!( + stdout.contains("--registry-password"), + "Help should document --registry-password flag" + ); + assert!( + stdout.contains("--registry-password-stdin"), + "Help should document --registry-password-stdin flag" + ); + assert!( + stdout.contains("OCI_REGISTRY_USER"), + "Help should mention OCI_REGISTRY_USER environment variable" + ); + assert!( + stdout.contains("OCI_REGISTRY_PASSWORD"), + "Help should mention OCI_REGISTRY_PASSWORD environment variable" + ); + + Ok(()) +}