diff --git a/Cargo.lock b/Cargo.lock index ca991ee0..10460691 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -706,8 +706,12 @@ dependencies = [ "async-trait", "coldvox-foundation", "dirs", + "hound", "parking_lot", "serde", + "serde_json", + "sha2", + "tempfile", "thiserror 2.0.17", "tokio", "tracing", diff --git a/crates/app/src/runtime.rs b/crates/app/src/runtime.rs index 3af5adc1..f85de539 100644 --- a/crates/app/src/runtime.rs +++ b/crates/app/src/runtime.rs @@ -512,7 +512,7 @@ pub async fn start( // Text injection channel #[cfg(feature = "text-injection")] - let (text_injection_tx, text_injection_rx) = mpsc::channel::(100); + let (_text_injection_tx, text_injection_rx) = mpsc::channel::(100); #[cfg(not(feature = "text-injection"))] let (_text_injection_tx, _text_injection_rx) = mpsc::channel::(100); @@ -521,10 +521,10 @@ pub async fn start( let mut stt_forward_handle: Option> = None; #[allow(unused_variables)] let (stt_handle, vad_fanout_handle) = if let Some(pm) = plugin_manager.clone() { - // This is the single, unified path for STT processing. - #[cfg(feature = "whisper")] - let (session_tx, session_rx) = mpsc::channel::(100); - let stt_audio_rx = audio_tx.subscribe(); + // This is the single, unified path for STT processing. + #[cfg(feature = "whisper")] + let (session_tx, session_rx) = mpsc::channel::(100); + let stt_audio_rx = audio_tx.subscribe(); #[cfg(feature = "whisper")] let (stt_pipeline_tx, stt_pipeline_rx) = mpsc::channel::(100); @@ -550,8 +550,8 @@ pub async fn start( Settings::default(), // Use default settings for now ); - let vad_bcast_tx_clone = vad_bcast_tx.clone(); - let activation_mode = opts.activation_mode; + let vad_bcast_tx_clone = vad_bcast_tx.clone(); + let activation_mode = opts.activation_mode; // This task is the new "translator" from VAD/Hotkey events to generic SessionEvents. let vad_fanout_handle = tokio::spawn(async move { @@ -770,11 +770,9 @@ pub async fn start( #[cfg(test)] mod tests { use super::*; - - + use coldvox_stt::plugin::{FailoverConfig, GcPolicy, PluginSelectionConfig}; use coldvox_stt::TranscriptionEvent; - /// Helper to create default runtime options for testing. fn test_opts(activation_mode: ActivationMode) -> AppRuntimeOptions { diff --git a/crates/app/tests/golden_master.rs b/crates/app/tests/golden_master.rs index e402313c..ab3551d9 100644 --- a/crates/app/tests/golden_master.rs +++ b/crates/app/tests/golden_master.rs @@ -117,20 +117,36 @@ pub mod harness { (Value::Object(ao), Value::Object(bo)) => { let kind_a = ao.get("kind").and_then(|v| v.as_str()).unwrap_or(""); let kind_b = bo.get("kind").and_then(|v| v.as_str()).unwrap_or(""); - if kind_a != kind_b { all_ok = false; break; } + if kind_a != kind_b { + all_ok = false; + break; + } if kind_a == "SpeechEnd" { - let da = ao.get("duration_ms").and_then(|v| v.as_u64()).unwrap_or(0); - let db = bo.get("duration_ms").and_then(|v| v.as_u64()).unwrap_or(0); + let da = + ao.get("duration_ms").and_then(|v| v.as_u64()).unwrap_or(0); + let db = + bo.get("duration_ms").and_then(|v| v.as_u64()).unwrap_or(0); let diff = da.abs_diff(db); - if diff > 128 { all_ok = false; break; } + if diff > 128 { + all_ok = false; + break; + } } else if kind_a == "SpeechStart" { // SpeechStart has no duration, ignore } else { // Unknown kind fallback to strict equality - if av != bv { all_ok = false; break; } + if av != bv { + all_ok = false; + break; + } + } + } + _ => { + if av != bv { + all_ok = false; + break; } } - _ => { if av != bv { all_ok = false; break; } } } } all_ok diff --git a/crates/coldvox-foundation/src/error.rs b/crates/coldvox-foundation/src/error.rs index 398fc14e..305f9cbe 100644 --- a/crates/coldvox-foundation/src/error.rs +++ b/crates/coldvox-foundation/src/error.rs @@ -100,6 +100,9 @@ pub enum SttError { #[error("Invalid configuration: {0}")] InvalidConfig(String), + + #[error("Checksum validation failed: {0}")] + ChecksumFailed(String), } #[derive(Debug, thiserror::Error)] diff --git a/crates/coldvox-stt/Cargo.toml b/crates/coldvox-stt/Cargo.toml index 8fde54a6..185fcffa 100644 --- a/crates/coldvox-stt/Cargo.toml +++ b/crates/coldvox-stt/Cargo.toml @@ -13,13 +13,19 @@ thiserror = "2.0" dirs = "5.0" serde = { version = "1.0", features = ["derive"] } coldvox-foundation = { path = "../coldvox-foundation" } -## Removed Python-dependent faster-whisper backend; will replace with pure Rust implementation +sha2 = "0.10" +serde_json = "1.0" +pyo3 = { version = "0.20", features = ["auto-initialize"], optional = true } +faster-whisper-rs = { git = "https://github.com/gmt-happy/faster-whisper-rs", rev = "319e719", optional = true } +[dev-dependencies] +tempfile = "3.8" +hound = "3.5" [features] default = [] parakeet = [] -whisper = [] # Placeholder until new backend is implemented +whisper = ["dep:pyo3", "dep:faster-whisper-rs"] coqui = [] leopard = [] silero-stt = [] diff --git a/crates/coldvox-stt/src/lib.rs b/crates/coldvox-stt/src/lib.rs index bd649189..e2e35293 100644 --- a/crates/coldvox-stt/src/lib.rs +++ b/crates/coldvox-stt/src/lib.rs @@ -19,6 +19,7 @@ pub use coldvox_foundation::error::ColdVoxError; pub use plugin::SttPlugin; pub use plugin_adapter::PluginAdapter; // adapter for plugin → StreamingStt pub use types::{TranscriptionConfig, TranscriptionEvent, WordInfo}; +pub mod validation; /// Generates unique utterance IDs static UTTERANCE_ID_COUNTER: AtomicU64 = AtomicU64::new(1); diff --git a/crates/coldvox-stt/src/plugins/whisper_plugin.rs b/crates/coldvox-stt/src/plugins/whisper_plugin.rs index a811dac5..8b1573a5 100644 --- a/crates/coldvox-stt/src/plugins/whisper_plugin.rs +++ b/crates/coldvox-stt/src/plugins/whisper_plugin.rs @@ -148,6 +148,34 @@ impl WhisperPlugin { Ok(self.model_size.model_identifier()) } + #[cfg(feature = "whisper")] + fn validate_model(&self, model_path: &Path) -> Result<(), ColdVoxError> { + // Find the checksums file in the same directory as the model + let checksum_path = model_path + .parent() + .ok_or_else(|| { + SttError::ChecksumFailed("Could not determine parent directory of model".to_string()) + })? + .join("models.sha256.json"); + + if checksum_path.exists() { + let checksums = Checksums::load(&checksum_path)?; + checksums.verify(&model_path)?; + info!( + target: "coldvox::stt::whisper", + model = %model_path.display(), + "Model checksum verified successfully" + ); + } else { + warn!( + target: "coldvox::stt::whisper", + path = %checksum_path.display(), + "Checksum file not found, skipping model validation" + ); + } + Ok(()) + } + #[cfg(feature = "whisper")] fn build_whisper_config(&self, config: &TranscriptionConfig) -> WhisperConfig { WhisperConfig { @@ -276,6 +304,13 @@ impl SttPlugin for WhisperPlugin { #[cfg(feature = "whisper")] { let model_id = self.resolve_model_identifier(&config)?; + let model_path = PathBuf::from(&model_id); + + // If the model is a file path, validate its checksum + if model_path.is_file() { + self.validate_model(&model_path)?; + } + let mut whisper_config = self.build_whisper_config(&config); if whisper_config.language.is_none() { whisper_config.language = self.language.clone(); @@ -747,6 +782,263 @@ fn check_whisper_available() -> bool { false } +#[cfg(test)] +mod tests { + use super::*; + use crate::validation::Checksums; + use std::env; + + #[test] + fn model_size_identifier_mapping() { + assert_eq!(WhisperModelSize::Tiny.model_identifier(), "tiny"); + assert_eq!(WhisperModelSize::Base.model_identifier(), "base.en"); + assert_eq!(WhisperModelSize::LargeV3.model_identifier(), "large-v3"); + } + + #[test] + fn parse_model_size() { + assert_eq!( + WhisperPluginFactory::parse_model_size("tiny").unwrap(), + WhisperModelSize::Tiny + ); + assert_eq!( + WhisperPluginFactory::parse_model_size("large-v3").unwrap(), + WhisperModelSize::LargeV3 + ); + assert!(WhisperPluginFactory::parse_model_size("invalid").is_err()); + assert!(WhisperPluginFactory::parse_model_size("").is_err()); + } + + #[test] + fn environment_detection() { + // Test CI detection + env::set_var("CI", "true"); + assert_eq!(detect_environment(), Environment::CI); + env::remove_var("CI"); + + // Test development detection + env::set_var("DEBUG", "1"); + assert_eq!(detect_environment(), Environment::Development); + env::remove_var("DEBUG"); + + // Default to production when no indicators are present + assert_eq!(detect_environment(), Environment::Production); + } + + #[test] + fn model_size_for_memory() { + // Test memory-based model selection + assert_eq!( + WhisperPluginFactory::get_model_size_for_memory(300), + WhisperModelSize::Tiny + ); + assert_eq!( + WhisperPluginFactory::get_model_size_for_memory(750), + WhisperModelSize::Base + ); + assert_eq!( + WhisperPluginFactory::get_model_size_for_memory(1500), + WhisperModelSize::Small + ); + assert_eq!( + WhisperPluginFactory::get_model_size_for_memory(3000), + WhisperModelSize::Medium + ); + assert_eq!( + WhisperPluginFactory::get_model_size_for_memory(8000), + WhisperModelSize::Base + ); + } + + #[test] + fn environment_default_model_sizes() { + // Test default model sizes for each environment + assert_eq!( + default_model_size_for_environment(Environment::CI), + WhisperModelSize::Tiny + ); + + // Development and production depend on memory, so we can't test exact values + // without mocking memory detection + } + + #[test] + fn development_env_prefers_large_on_beefy_machine() { + // Simulate development environment + env::set_var("DEBUG", "1"); + // Simulate a beefy machine with lots of available memory + env::set_var("WHISPER_AVAILABLE_MEM_MB", "16384"); + + assert_eq!(detect_environment(), Environment::Development); + let chosen = default_model_size_for_environment(Environment::Development); + assert_eq!(chosen, WhisperModelSize::LargeV3); + + env::remove_var("WHISPER_AVAILABLE_MEM_MB"); + env::remove_var("DEBUG"); + } + + #[test] + fn production_env_does_not_escalate_to_large_by_default() { + // Ensure no CI or dev markers are present + for var in [ + "CI", + "CONTINUOUS_INTEGRATION", + "GITHUB_ACTIONS", + "GITLAB_CI", + "TRAVIS", + "CIRCLECI", + "JENKINS_URL", + "BUILDKITE", + "RUST_BACKTRACE", + "DEBUG", + "DEV", + ] { + env::remove_var(var); + } + + // Simulate lots of memory + env::set_var("WHISPER_AVAILABLE_MEM_MB", "16384"); + assert_eq!(detect_environment(), Environment::Production); + let chosen = default_model_size_for_environment(Environment::Production); + assert_ne!(chosen, WhisperModelSize::LargeV3); + env::remove_var("WHISPER_AVAILABLE_MEM_MB"); + } + + #[test] + fn whisper_model_size_env_var() { + // Test that WHISPER_MODEL_SIZE environment variable is respected + env::set_var("WHISPER_MODEL_SIZE", "large-v2"); + let factory = WhisperPluginFactory::new(); + assert_eq!(factory.model_size, WhisperModelSize::LargeV2); + env::remove_var("WHISPER_MODEL_SIZE"); + + // Test with invalid value - should fall back to environment default + env::set_var("WHISPER_MODEL_SIZE", "invalid-size"); + let factory = WhisperPluginFactory::new(); + // Should not panic and should use a valid default based on environment + assert!(matches!( + factory.model_size, + WhisperModelSize::Tiny | WhisperModelSize::Base | WhisperModelSize::Small + )); + env::remove_var("WHISPER_MODEL_SIZE"); + } + + #[test] + fn gpu_detection_caching() { + // Ensure WHISPER_DEVICE is not set to test detection + env::remove_var("WHISPER_DEVICE"); + + // First call should trigger detection + let device1 = WhisperPluginFactory::detect_device(); + + // Second call should return cached result without re-running detection + let device2 = WhisperPluginFactory::detect_device(); + + // Both calls should return the same result + assert_eq!(device1, device2); + + // Verify the device is either "cuda" or "cpu" + assert!(device1 == "cuda" || device1 == "cpu"); + } + + #[test] + fn whisper_device_env_var_overrides_cache() { + // Set WHISPER_DEVICE to override detection + env::set_var("WHISPER_DEVICE", "cuda:1"); + + let factory = WhisperPluginFactory::new(); + assert_eq!(factory.device, "cuda:1"); + + env::remove_var("WHISPER_DEVICE"); + } + + #[test] + fn gpu_detection_thread_safety() { + use std::thread; + + // Ensure WHISPER_DEVICE is not set to test detection + env::remove_var("WHISPER_DEVICE"); + + let handles: Vec<_> = (0..10) + .map(|_| thread::spawn(WhisperPluginFactory::detect_device)) + .collect(); + + // All threads should get the same result + let results: Vec = handles + .into_iter() + .map(|handle| handle.join().unwrap()) + .collect(); + + // All results should be identical + let first_result = &results[0]; + assert!(results.iter().all(|r| r == first_result)); + + // Verify the device is either "cuda" or "cpu" + assert!(first_result == "cuda" || first_result == "cpu"); + } + + #[cfg(feature = "whisper")] + mod validation_tests { + use super::*; + use std::fs; + use tempfile::tempdir; + + #[test] + fn test_checksum_validation_success() { + let dir = tempdir().unwrap(); + let model_path = dir.path().join("model.bin"); + let checksum_path = dir.path().join("models.sha256.json"); + + fs::write(&model_path, "dummy model data").unwrap(); + let checksum = crate::validation::compute_sha256(&model_path).unwrap(); + fs::write( + &checksum_path, + format!(r#"{{"model.bin": "{}"}}"#, checksum), + ) + .unwrap(); + + let plugin = WhisperPlugin::new(); + let result = plugin.validate_model(&model_path); + assert!(result.is_ok()); + } + + #[test] + fn test_checksum_validation_failure() { + let dir = tempdir().unwrap(); + let model_path = dir.path().join("model.bin"); + let checksum_path = dir.path().join("models.sha256.json"); + + fs::write(&model_path, "dummy model data").unwrap(); + fs::write( + &checksum_path, + r#"{"model.bin": "invalid_checksum"}"#, + ) + .unwrap(); + + let plugin = WhisperPlugin::new(); + let result = plugin.validate_model(&model_path); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!( + err.downcast_ref::(), + Some(SttError::ChecksumFailed(_)) + )); + } + + #[test] + fn test_missing_checksum_file() { + let dir = tempdir().unwrap(); + let model_path = dir.path().join("model.bin"); + + fs::write(&model_path, "dummy model data").unwrap(); + + let plugin = WhisperPlugin::new(); + let result = plugin.validate_model(&model_path); + assert!(result.is_ok()); + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/coldvox-stt/src/validation.rs b/crates/coldvox-stt/src/validation.rs new file mode 100644 index 00000000..2f00e498 --- /dev/null +++ b/crates/coldvox-stt/src/validation.rs @@ -0,0 +1,121 @@ +//! Types and functions for validating file integrity using SHA256 checksums. + +use serde::Deserialize; +use std::collections::HashMap; +use std::fs; +use std::io; +use std::path::{Path, PathBuf}; + +use coldvox_foundation::error::{ColdVoxError, SttError}; + +/// Represents a collection of SHA256 checksums for model files. +/// The keys are filenames, and the values are the corresponding checksums. +#[derive(Debug, Deserialize, Clone)] +pub struct Checksums { + #[serde(flatten)] + pub files: HashMap, +} + +impl Checksums { + /// Loads checksums from a JSON file. + /// + /// # Arguments + /// + /// * `path` - The path to the checksums file. + /// + /// # Returns + /// + /// A `Result` containing the `Checksums` struct or an error if the file + /// cannot be read or parsed. + pub fn load>(path: P) -> Result { + let content = fs::read_to_string(path.as_ref()).map_err(|err| { + SttError::ChecksumFailed(format!( + "Failed to read checksum file at {}: {}", + path.as_ref().display(), + err + )) + })?; + serde_json::from_str(&content).map_err(|err| { + { + SttError::ChecksumFailed(format!( + "Failed to parse checksum file at {}: {}", + path.as_ref().display(), + err + )) + } + .into() + }) + } + + /// Verifies the checksum of a file. + /// + /// # Arguments + /// + /// * `file_path` - The path to the file to verify. + /// + /// # Returns + /// + /// A `Result` indicating whether the checksum is valid or an error if + /// the checksum is missing or does not match. + pub fn verify>(&self, file_path: P) -> Result<(), ColdVoxError> { + let file_path = file_path.as_ref(); + let file_name = file_path.file_name().ok_or_else(|| { + SttError::ChecksumFailed(format!( + "Could not get file name from path: {}", + file_path.display() + )) + })?; + + let expected_checksum = self.files.get(Path::new(file_name)).ok_or_else(|| { + SttError::ChecksumFailed(format!( + "No checksum found for model: {}", + file_name.to_string_lossy() + )) + })?; + + let actual_checksum = compute_sha256(file_path)?; + + if &actual_checksum == expected_checksum { + Ok(()) + } else { + Err(SttError::ChecksumFailed(format!( + "Checksum mismatch for model: {}\n Expected: {}\n Actual: {}", + file_name.to_string_lossy(), + expected_checksum, + actual_checksum + )) + .into()) + } + } +} + +/// Computes the SHA256 checksum of a file. +/// +/// # Arguments +/// +/// * `path` - The path to the file. +/// +/// # Returns +/// +/// A `Result` containing the hex-encoded SHA256 checksum or an error if +/// the file cannot be read. +pub fn compute_sha256>(path: P) -> Result { + use sha2::{Digest, Sha256}; + let mut file = fs::File::open(path.as_ref()).map_err(|err| { + SttError::ChecksumFailed(format!( + "Failed to open file for hashing at {}: {}", + path.as_ref().display(), + err + )) + })?; + let mut hasher = Sha256::new(); + io::copy(&mut file, &mut hasher).map_err(|err| { + SttError::ChecksumFailed(format!( + "Failed to read file for hashing at {}: {}", + path.as_ref().display(), + err + )) + })?; + let hash = hasher.finalize(); + Ok(format!("{:x}", hash)) +} diff --git a/crates/coldvox-text-injection/src/confirm.rs b/crates/coldvox-text-injection/src/confirm.rs index cacff8a6..d99f8cef 100644 --- a/crates/coldvox-text-injection/src/confirm.rs +++ b/crates/coldvox-text-injection/src/confirm.rs @@ -54,11 +54,11 @@ //! - Future: Could extend to clipboard/enigo fallbacks for cross-method validation use crate::types::{InjectionConfig, InjectionResult}; +use coldvox_foundation::error::InjectionError; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::Mutex; -use tracing::{info, warn, debug, trace, error}; -use coldvox_foundation::error::InjectionError; +use tracing::{debug, error, info, trace, warn}; use unicode_segmentation::UnicodeSegmentation; /// Confirmation result for text injection @@ -226,16 +226,12 @@ pub async fn text_changed( .map_err(|e| InjectionError::Other(format!("TextProxy path failed: {e}")))? .build(); - if let Ok(text_proxy) = time::timeout(Duration::from_millis(25), text_fut).await { - if let Ok(text_proxy) = text_proxy { - let get_text_fut = text_proxy.get_text(0, -1); - if let Ok(current_text) = - time::timeout(Duration::from_millis(25), get_text_fut).await - { - if let Ok(current_text) = current_text { - last_text = current_text; - } - } + if let Ok(Ok(text_proxy)) = time::timeout(Duration::from_millis(25), text_fut).await { + let get_text_fut = text_proxy.get_text(0, -1); + if let Ok(Ok(current_text)) = + time::timeout(Duration::from_millis(25), get_text_fut).await + { + last_text = current_text; } } } @@ -271,47 +267,44 @@ pub async fn text_changed( .map_err(|e| InjectionError::Other(format!("TextProxy path failed: {e}")))? .build(); - if let Ok(text_proxy) = time::timeout(poll_interval, text_fut).await { - if let Ok(text_proxy) = text_proxy { - let get_text_fut = text_proxy.get_text(0, -1); - if let Ok(current_text) = time::timeout(poll_interval, get_text_fut).await { - if let Ok(current_text) = current_text { - // Check if text has changed and matches our prefix - if current_text != last_text { - trace!( - old_text = %last_text, - new_text = %current_text, - "Text content changed during polling" + if let Ok(Ok(text_proxy)) = time::timeout(poll_interval, text_fut).await { + let get_text_fut = text_proxy.get_text(0, -1); + if let Ok(Ok(current_text)) = time::timeout(poll_interval, get_text_fut).await + { + // Check if text has changed and matches our prefix + if current_text != last_text { + trace!( + old_text = %last_text, + new_text = %current_text, + "Text content changed during polling" + ); + + // Extract the new portion (last few characters) + if current_text.len() > last_text.len() { + let new_chars = ¤t_text[last_text.len()..]; + + debug!( + new_chars = %new_chars, + expected_prefix = %expected_prefix, + "Checking if new text matches expected prefix" + ); + + if TextChangeListener::matches_prefix( + new_chars, + &expected_prefix, + ) { + info!( + new_chars = %new_chars, + expected_prefix = %expected_prefix, + elapsed_ms = %start_time.elapsed().as_millis(), + poll_count = %poll_count, + "Text change confirmed via AT-SPI polling" ); - - // Extract the new portion (last few characters) - if current_text.len() > last_text.len() { - let new_chars = ¤t_text[last_text.len()..]; - - debug!( - new_chars = %new_chars, - expected_prefix = %expected_prefix, - "Checking if new text matches expected prefix" - ); - - if TextChangeListener::matches_prefix( - new_chars, - &expected_prefix, - ) { - info!( - new_chars = %new_chars, - expected_prefix = %expected_prefix, - elapsed_ms = %start_time.elapsed().as_millis(), - poll_count = %poll_count, - "Text change confirmed via AT-SPI polling" - ); - return Ok(ConfirmationResult::Success); - } - } - - last_text = current_text; + return Ok(ConfirmationResult::Success); } } + + last_text = current_text; } } } diff --git a/crates/coldvox-text-injection/src/injectors/atspi.rs b/crates/coldvox-text-injection/src/injectors/atspi.rs index cf85d11b..c371a5a0 100644 --- a/crates/coldvox-text-injection/src/injectors/atspi.rs +++ b/crates/coldvox-text-injection/src/injectors/atspi.rs @@ -5,6 +5,8 @@ //! while providing the new TextInjector trait interface. use crate::confirm::{create_confirmation_context, ConfirmationContext}; +use crate::log_throttle::log_atspi_connection_failure; +use crate::logging::utils; use crate::types::{ InjectionConfig, InjectionContext, InjectionMethod, InjectionMode, InjectionResult, }; @@ -13,8 +15,6 @@ use async_trait::async_trait; use coldvox_foundation::error::InjectionError; use std::time::Instant; use tracing::{debug, trace, warn}; -use crate::log_throttle::log_atspi_connection_failure; -use crate::logging::utils; // Re-export the old Context type for backwards compatibility #[deprecated( diff --git a/crates/coldvox-text-injection/src/injectors/unified_clipboard.rs b/crates/coldvox-text-injection/src/injectors/unified_clipboard.rs index b3d8f858..bbcf5782 100644 --- a/crates/coldvox-text-injection/src/injectors/unified_clipboard.rs +++ b/crates/coldvox-text-injection/src/injectors/unified_clipboard.rs @@ -590,66 +590,6 @@ impl UnifiedClipboardInjector { } } - /// Helper to restore clipboard content without borrowing &self - /// Uses wl-copy if available (feature-enabled path handled earlier), otherwise xclip. - async fn restore_clipboard_direct(content: Vec) -> InjectionResult<()> { - // Try wl-copy first if present at runtime - let wl_copy_ok = tokio::process::Command::new("which") - .arg("wl-copy") - .output() - .await - .map(|o| o.status.success()) - .unwrap_or(false); - - if wl_copy_ok { - let mut child = tokio::process::Command::new("wl-copy") - .stdin(Stdio::piped()) - .spawn() - .map_err(|e| InjectionError::Process(format!("Failed to spawn wl-copy: {}", e)))?; - if let Some(mut stdin) = child.stdin.take() { - timeout(Duration::from_millis(1000), stdin.write_all(&content)) - .await - .map_err(|_| InjectionError::Timeout(1000)) - .and_then(|r| { - r.map_err(|e| InjectionError::Process(format!("wl-copy stdin: {}", e))) - })?; - } - let status = child - .wait() - .await - .map_err(|e| InjectionError::Process(format!("wl-copy wait: {}", e)))?; - return if status.success() { - Ok(()) - } else { - Err(InjectionError::Process("wl-copy failed".into())) - }; - } - - // Fallback to xclip - let mut child = tokio::process::Command::new("xclip") - .args(["-selection", "clipboard"]) - .stdin(Stdio::piped()) - .spawn() - .map_err(|e| InjectionError::Process(format!("Failed to spawn xclip: {}", e)))?; - if let Some(mut stdin) = child.stdin.take() { - timeout(Duration::from_millis(1000), stdin.write_all(&content)) - .await - .map_err(|_| InjectionError::Timeout(1000)) - .and_then(|r| { - r.map_err(|e| InjectionError::Process(format!("xclip stdin: {}", e))) - })?; - } - let status = child - .wait() - .await - .map_err(|e| InjectionError::Process(format!("xclip wait: {}", e)))?; - if status.success() { - Ok(()) - } else { - Err(InjectionError::Process("xclip failed".into())) - } - } - // ...existing code... /// Main injection method with configurable behavior diff --git a/crates/coldvox-text-injection/src/prewarm.rs b/crates/coldvox-text-injection/src/prewarm.rs index 8065bbb8..7c635526 100644 --- a/crates/coldvox-text-injection/src/prewarm.rs +++ b/crates/coldvox-text-injection/src/prewarm.rs @@ -9,7 +9,7 @@ use crate::types::{InjectionConfig, InjectionMethod, InjectionResult}; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::{Mutex, RwLock}; -use tracing::{debug, info, warn, trace}; +use tracing::{debug, info, trace, warn}; /// TTL for cached pre-warmed data (3 seconds) const CACHE_TTL: Duration = Duration::from_secs(3); @@ -161,7 +161,7 @@ impl PrewarmController { /// Pre-warm AT-SPI connection and snapshot focused element async fn prewarm_atspi(&self) -> Result { - let start_time = Instant::now(); + let start_time = Instant::now(); debug!("Starting AT-SPI pre-warming"); #[cfg(feature = "atspi")] @@ -246,7 +246,7 @@ impl PrewarmController { /// Arm the event listener for text change confirmation async fn arm_event_listener(&self) -> Result { - let start_time = Instant::now(); + let start_time = Instant::now(); debug!("Arming event listener for text change confirmation"); #[cfg(feature = "atspi")] diff --git a/crates/coldvox-text-injection/src/tests/wl_copy_stdin_test.rs b/crates/coldvox-text-injection/src/tests/wl_copy_stdin_test.rs index 8e67d069..c95ff706 100644 --- a/crates/coldvox-text-injection/src/tests/wl_copy_stdin_test.rs +++ b/crates/coldvox-text-injection/src/tests/wl_copy_stdin_test.rs @@ -11,9 +11,7 @@ use crate::types::{InjectionConfig, InjectionContext}; use std::process::Command; use std::time::Duration; -use super::test_utils::{ - command_exists, is_wayland_environment, read_clipboard_with_wl_paste, -}; +use super::test_utils::{command_exists, is_wayland_environment, read_clipboard_with_wl_paste}; /// Test that wl-copy properly receives content via stdin /// This is the core test for the stdin piping fix @@ -152,7 +150,7 @@ async fn test_wl_copy_timeout_handling() { // Create config with very short timeout to force timeout let config = InjectionConfig { - per_method_timeout_ms: 10, // Very short timeout + per_method_timeout_ms: 10, // Very short timeout paste_action_timeout_ms: 10, // Very short timeout ..Default::default() }; diff --git a/crates/coldvox-text-injection/src/ydotool_injector.rs b/crates/coldvox-text-injection/src/ydotool_injector.rs index 8b4204c0..0806bf07 100644 --- a/crates/coldvox-text-injection/src/ydotool_injector.rs +++ b/crates/coldvox-text-injection/src/ydotool_injector.rs @@ -53,12 +53,9 @@ fn candidate_socket_paths() -> Vec { } fn locate_existing_socket() -> Option { - for candidate in candidate_socket_paths() { - if Path::new(&candidate).exists() { - return Some(candidate); - } - } - None + candidate_socket_paths() + .into_iter() + .find(|candidate| Path::new(&candidate).exists()) } fn preferred_socket_path() -> Option {