diff --git a/backend/app/database/faces.py b/backend/app/database/faces.py
index 0063b1bf..e681f323 100644
--- a/backend/app/database/faces.py
+++ b/backend/app/database/faces.py
@@ -1,128 +1,223 @@
+from __future__ import annotations
+
 import sqlite3
-import json
 import numpy as np
+import json
+from pathlib import Path
+from typing import List, Optional, Tuple
+import logging
+from contextlib import contextmanager

 from app.config.settings import DATABASE_PATH

+logger = logging.getLogger(__name__)

-def create_faces_table():
+@contextmanager
+def get_db_connection():
     conn = sqlite3.connect(DATABASE_PATH)
-    cursor = conn.cursor()
-    cursor.execute(
-        """
-        CREATE TABLE IF NOT EXISTS faces (
-            id INTEGER PRIMARY KEY AUTOINCREMENT,
-            image_id INTEGER,
-            embeddings TEXT,
-            FOREIGN KEY (image_id) REFERENCES image_id_mapping(id) ON DELETE CASCADE
-        )
-    """
-    )
-    conn.commit()
-    conn.close()
-
+    try:
+        yield conn
+    finally:
+        conn.close()
+
+def create_faces_table():
+    with get_db_connection() as conn:
+        cursor = conn.cursor()
+        cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS faces (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                image_id INTEGER,
+                embeddings TEXT,
+                FOREIGN KEY (image_id) REFERENCES image_id_mapping(id) ON DELETE CASCADE
+            )
+            """
+        )
+        conn.commit()

 def insert_face_embeddings(image_path, embeddings):
     from app.database.images import get_id_from_path

-    conn = sqlite3.connect(DATABASE_PATH)
-    cursor = conn.cursor()
-
-    image_id = get_id_from_path(image_path)
-    if image_id is None:
-        conn.close()
-        raise ValueError(f"Image '{image_path}' not found in the database")
-
-    embeddings_json = json.dumps([emb.tolist() for emb in embeddings])
-
-    cursor.execute(
-        """
-        INSERT OR REPLACE INTO faces (image_id, embeddings)
-        VALUES (?, ?)
-        """,
-        (image_id, embeddings_json),
-    )
-
-    conn.commit()
-    conn.close()
-
+    with get_db_connection() as conn:
+        cursor = conn.cursor()
+
+        image_id = get_id_from_path(image_path)
+        if image_id is None:
+            raise ValueError(f"Image '{image_path}' not found in the database")
+
+        # Serialize and upsert the embeddings, as the old implementation did
+        embeddings_json = json.dumps([emb.tolist() for emb in embeddings])
+        cursor.execute(
+            """
+            INSERT OR REPLACE INTO faces (image_id, embeddings)
+            VALUES (?, ?)
+            """,
+            (image_id, embeddings_json),
+        )
+        conn.commit()

+def init_face_db():
+    """Initialize the faces database with required tables."""
+    with get_db_connection() as conn:
+        conn.execute("""
+            CREATE TABLE IF NOT EXISTS face_embeddings (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                image_path TEXT NOT NULL,
+                embedding BLOB NOT NULL,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+            )
+        """)
+
+        # Add index for faster queries
+        conn.execute("""
+            CREATE INDEX IF NOT EXISTS idx_image_path
+            ON face_embeddings(image_path)
+        """)
+        conn.commit()
+
+def store_face_embeddings_batch(
+    image_paths: List[str],
+    embeddings: List[np.ndarray]
+) -> None:
+    """
+    Store multiple face embeddings in a single transaction.
+
+    Args:
+        image_paths: List of image paths
+        embeddings: List of corresponding embeddings
+    """
+    if len(image_paths) != len(embeddings):
+        raise ValueError("Number of paths must match number of embeddings")
+
+    with get_db_connection() as conn:
+        try:
+            conn.executemany(
+                "INSERT INTO face_embeddings (image_path, embedding) VALUES (?, ?)",
+                [
+                    (path, embedding.tobytes())
+                    for path, embedding in zip(image_paths, embeddings)
+                ]
+            )
+            conn.commit()
+        except sqlite3.Error as e:
+            logger.error(f"Database error: {str(e)}")
+            raise
+
+def store_face_embedding(image_path: str, embedding: np.ndarray) -> None:
+    """
+    Store a single face embedding.
+
+    Args:
+        image_path: Path to the image
+        embedding: Face embedding vector
+    """
+    store_face_embeddings_batch([image_path], [embedding])

-def get_face_embeddings(image_path):
+def get_face_embeddings(image_path: str) -> List[np.ndarray]:
+    """
+    Retrieve all face embeddings for an image.
+
+    Args:
+        image_path: Path to the image
+
+    Returns:
+        List of face embeddings (empty if the image is unknown or has no faces)
+    """
+    # First try the new performance-optimized table
+    with get_db_connection() as conn:
+        cursor = conn.execute(
+            "SELECT embedding FROM face_embeddings WHERE image_path = ?",
+            (image_path,)
+        )
+        results = cursor.fetchall()
+        if results:
+            return [
+                np.frombuffer(row[0], dtype=np.float32)
+                for row in results
+            ]
+
+    # Fall back to the old table structure if needed
     from app.database.images import get_id_from_path

-    conn = sqlite3.connect(DATABASE_PATH)
-    cursor = conn.cursor()
-
-    image_id = get_id_from_path(image_path)
-    if image_id is None:
-        conn.close()
-        return None
-
-    cursor.execute(
-        """
-        SELECT embeddings FROM faces
-        WHERE image_id = ?
-        """,
-        (image_id,),
-    )
-
-    result = cursor.fetchone()
-    conn.close()
-
-    if result:
-        embeddings_json = result[0]
-        embeddings = np.array(json.loads(embeddings_json))
-        return embeddings
-    else:
-        return None
+    with get_db_connection() as conn:
+        cursor = conn.cursor()
+
+        image_id = get_id_from_path(image_path)
+        if image_id is None:
+            return []
+
+        cursor.execute(
+            """
+            SELECT embeddings FROM faces
+            WHERE image_id = ?
+            """,
+            (image_id,),
+        )
+        result = cursor.fetchone()
+
+        if result:
+            embeddings_json = result[0]
+            embeddings = np.array(json.loads(embeddings_json))
+            return [embeddings]
+        else:
+            return []

-def get_all_face_embeddings():
-    from app.database.images import get_path_from_id
-
-    conn = sqlite3.connect(DATABASE_PATH)
-    cursor = conn.cursor()
-
-    cursor.execute(
-        """
-        SELECT image_id, embeddings FROM faces
-        """
-    )
-
-    results = cursor.fetchall()
-    all_embeddings = []
-    for image_id, embeddings_json in results:
-        image_path = get_path_from_id(image_id)
-        embeddings = np.array(json.loads(embeddings_json))
-        all_embeddings.append({"image_path": image_path, "embeddings": embeddings})
-    print("returning")
-    conn.close()
-    return all_embeddings
+def get_all_face_embeddings() -> List[Tuple[str, np.ndarray]]:
+    """
+    Retrieve all face embeddings from the database.
+
+    Returns:
+        List of tuples containing (image_path, embedding)
+    """
+    # First try the new performance-optimized table
+    with get_db_connection() as conn:
+        cursor = conn.execute(
+            "SELECT image_path, embedding FROM face_embeddings"
+        )
+        results = cursor.fetchall()
+        if results:
+            return [
+                (row[0], np.frombuffer(row[1], dtype=np.float32))
+                for row in results
+            ]
+
+    # Fall back to the old table structure if needed
+    from app.database.images import get_path_from_id
+
+    with get_db_connection() as conn:
+        cursor = conn.cursor()
+        cursor.execute(
+            """
+            SELECT image_id, embeddings FROM faces
+            """
+        )
+        results = cursor.fetchall()
+        all_embeddings = []
+        for image_id, embeddings_json in results:
+            image_path = get_path_from_id(image_id)
+            embeddings = np.array(json.loads(embeddings_json))
+            all_embeddings.append((image_path, embeddings))
+
+        return all_embeddings
+
+def clear_face_embeddings() -> None:
+    """Clear all face embeddings from the database."""
+    with get_db_connection() as conn:
+        conn.execute("DELETE FROM face_embeddings")
+        conn.commit()

 def delete_face_embeddings(image_id):
-    conn = sqlite3.connect(DATABASE_PATH)
-    cursor = conn.cursor()
-
-    cursor.execute("DELETE FROM faces WHERE image_id = ?", (image_id,))
-
-    conn.commit()
-    conn.close()
-
+    with get_db_connection() as conn:
+        cursor = conn.cursor()
+        cursor.execute("DELETE FROM faces WHERE image_id = ?", (image_id,))
+        conn.commit()

 def cleanup_face_embeddings():
-    conn = sqlite3.connect(DATABASE_PATH)
-    cursor = conn.cursor()
+    with get_db_connection() as conn:
+        cursor = conn.cursor()

-    cursor.execute("SELECT DISTINCT image_id FROM faces")
-    face_image_ids = set(row[0] for row in cursor.fetchall())
+        cursor.execute("SELECT DISTINCT image_id FROM faces")
+        face_image_ids = set(row[0] for row in cursor.fetchall())

-    cursor.execute("SELECT id FROM image_id_mapping")
-    valid_image_ids = set(row[0] for row in cursor.fetchall())
+        cursor.execute("SELECT id FROM image_id_mapping")
+        valid_image_ids = set(row[0] for row in cursor.fetchall())

-    orphaned_ids = face_image_ids - valid_image_ids
+        orphaned_ids = face_image_ids - valid_image_ids

-    for orphaned_id in orphaned_ids:
-        cursor.execute("DELETE FROM faces WHERE image_id = ?", (orphaned_id,))
+        for orphaned_id in orphaned_ids:
+            cursor.execute("DELETE FROM faces WHERE image_id = ?", (orphaned_id,))

-    conn.commit()
-    conn.close()
+        conn.commit()
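Review note on the BLOB encoding above: `store_face_embeddings_batch` persists whatever dtype it is handed via `tobytes()`, while the read path always decodes with `dtype=np.float32`, so the write path must guarantee float32. A minimal round-trip sketch (the 512-dim size is an assumption, matching the FaceNet-style embeddings elsewhere in this PR):

```python
import numpy as np

emb = np.random.rand(512).astype(np.float32)  # hypothetical embedding
blob = emb.tobytes()                          # what lands in the BLOB column

restored = np.frombuffer(blob, dtype=np.float32)
assert np.array_equal(emb, restored)

# A float64 array written without the cast reads back silently corrupted:
bad = np.frombuffer(np.random.rand(512).tobytes(), dtype=np.float32)
assert bad.shape == (1024,)  # wrong length, wrong values
```

If callers can hold float64 arrays, casting with `embedding.astype(np.float32)` before storage keeps the read path above correct.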
diff --git a/backend/app/facecluster/facecluster.py b/backend/app/facecluster/facecluster.py
index 80ebe78a..5e818340 100644
--- a/backend/app/facecluster/facecluster.py
+++ b/backend/app/facecluster/facecluster.py
@@ -5,10 +5,10 @@
 from sklearn.metrics.pairwise import cosine_distances
 import sqlite3
 import json
-from collections import defaultdict
+from collections import defaultdict, deque
 from contextlib import contextmanager
 import logging
-from typing import Dict, List, Optional, Set, Union, Any, Callable, TypeVar, ParamSpec
+from typing import Dict, List, Optional, Set, Union, Any, Callable, TypeVar, ParamSpec, Deque
 from pathlib import Path
 import time
 from functools import wraps
@@ -116,7 +116,9 @@ def __init__(
         eps: float = 0.3,
         min_samples: int = 2,
         metric: str = "cosine",
+        db_path: Union[str, Path] = DATABASE_PATH,
+        batch_size: int = 50  # Parameter for batch processing
     ) -> None:
         """
         Initialize the face cluster manager.
@@ -126,6 +128,7 @@
             min_samples: DBSCAN minimum samples parameter
             metric: Distance metric for clustering
             db_path: Path to the database
+            batch_size: Number of embeddings to process before full reclustering
         """
         self.eps = eps
         self.min_samples = min_samples
@@ -141,6 +144,11 @@
         self.image_ids: List[str] = []
         self.labels: Optional[NDArray] = None
         self.db_path = Path(db_path)
+
+        # Attributes for batch processing
+        self.batch_size = batch_size
+        self.pending_embeddings: Deque[tuple[NDArray, str]] = deque()
+        self.needs_reclustering = False

         # Initialize database
         self._init_database()
@@ -219,8 +227,8 @@ def get_clusters(self) -> Dict[int, List[str]]:

     def add_face(self, embedding: NDArray, image_path: str) -> Dict[int, List[str]]:
         """
-        Add a new face embedding to the clusters.
-
+        Add a new face embedding to the pending queue.
+
         Args:
             embedding: Face embedding vector
             image_path: Path to the image
@@ -229,16 +237,30 @@
             Updated clustering results
         """
         image_id = get_id_from_path(image_path)
+
+        # With existing clusters, assign cheaply against them; the full
+        # DBSCAN pass is deferred to _process_batch / force_recluster
+        if len(self.embeddings) > 0:
+            return self._quick_assign(embedding, image_id)
+
+        # Cold start: queue faces until the first batch is complete
+        self.pending_embeddings.append((embedding, image_id))
+        if len(self.pending_embeddings) >= self.batch_size:
+            return self._process_batch()
+
+        return self.get_clusters()

+    def _quick_assign(self, embedding: NDArray, image_id: str) -> Dict[int, List[str]]:
+        """
+        Quickly assign a new face to existing clusters without full reclustering.
+        """
         if len(self.embeddings) == 0:
             self.embeddings = np.array([embedding])
             self.image_ids = [image_id]
             self.labels = np.array([-1])
         else:
-            # Vectorized distance calculation
             distances = cosine_distances(embedding.reshape(1, -1), self.embeddings)[0]
             nearest_neighbor = np.argmin(distances)

             # Determine cluster assignment
             if distances[nearest_neighbor] <= self.eps:
                 new_label = self.labels[nearest_neighbor]
@@ -249,11 +271,47 @@
             self.embeddings = np.vstack([self.embeddings, embedding])
             self.image_ids.append(image_id)
             self.labels = np.append(self.labels, new_label)
+
+        self.needs_reclustering = True
+        return self.get_clusters()

+    def _process_batch(self) -> Dict[int, List[str]]:
+        """
+        Fold any pending embeddings into the main arrays and refit DBSCAN.
+        """
+        if not self.pending_embeddings and not self.needs_reclustering:
+            return self.get_clusters()
+
+        # Drain the pending queue into the main arrays
+        new_embeddings = []
+        new_image_ids = []
+        while self.pending_embeddings:
+            embedding, image_id = self.pending_embeddings.popleft()
+            new_embeddings.append(embedding)
+            new_image_ids.append(image_id)
+
+        if new_embeddings:
+            if len(self.embeddings) == 0:
+                self.embeddings = np.array(new_embeddings)
+                self.image_ids = new_image_ids
+            else:
+                self.embeddings = np.vstack([self.embeddings, new_embeddings])
+                self.image_ids.extend(new_image_ids)
+
+        # Perform full clustering
+        self.labels = self.dbscan.fit_predict(self.embeddings)
+        self.needs_reclustering = False
+
         self._clear_caches()
         self.save_to_db()
         return self.get_clusters()

+    def force_recluster(self) -> Dict[int, List[str]]:
+        """
+        Force immediate processing of all pending embeddings and full reclustering.
+        """
+        return self._process_batch()
+
     @TTLCache(maxsize=128, ttl=3600)
     def get_related_images(self, image_id: str) -> List[str]:
         """
@@ -363,4 +421,4 @@ def load_from_db(cls, db_path: Union[str, Path] = DATABASE_PATH) -> "FaceCluster":
         except sqlite3.OperationalError as e:
             logger.error(f"Database error: {e}")

-    return instance
+    return instance
\ No newline at end of file
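For reviewers, a usage sketch of the batching contract after this change (paths are hypothetical and would need `image_id_mapping` entries for `get_id_from_path` to resolve; embeddings are random stand-ins for real FaceNet vectors):

```python
import numpy as np
from app.facecluster.facecluster import FaceCluster

cluster = FaceCluster(eps=0.3, min_samples=2, batch_size=4)

# Cold start: the first faces are queued; no DBSCAN pass runs yet.
for i in range(3):
    emb = np.random.rand(512).astype(np.float32)
    cluster.add_face(emb, f"/photos/img_{i}.jpg")

# A fourth face would fill the batch and trigger _process_batch; calling
# force_recluster() flushes the queue and refits DBSCAN explicitly.
clusters = cluster.force_recluster()  # {label: [image_id, ...]}
```

Once clusters exist, `add_face` switches to the cheap nearest-neighbour `_quick_assign` path and leaves the expensive `fit_predict` to `_process_batch` / `force_recluster`.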
+ """ + return self._process_batch() + @TTLCache(maxsize=128, ttl=3600) def get_related_images(self, image_id: str) -> List[str]: """ @@ -363,4 +421,4 @@ def load_from_db(cls, db_path: Union[str, Path] = DATABASE_PATH) -> "FaceCluster except sqlite3.OperationalError as e: logger.error(f"Database error: {e}") - return instance + return instance \ No newline at end of file diff --git a/backend/app/facenet/batch_processor.py b/backend/app/facenet/batch_processor.py new file mode 100644 index 00000000..1d3d6185 --- /dev/null +++ b/backend/app/facenet/batch_processor.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import logging +from pathlib import Path +from typing import List, Dict, Optional +import numpy as np +from concurrent.futures import ThreadPoolExecutor, as_completed +from threading import Lock + +from app.facenet.facenet import FaceNet +from app.facecluster.facecluster import FaceCluster +from app.utils.path_id_mapping import get_id_from_path +from app.database.faces import store_face_embedding + +logger = logging.getLogger(__name__) + +class BatchProcessor: + def __init__( + self, + face_net: Optional[FaceNet] = None, + face_cluster: Optional[FaceCluster] = None, + max_workers: int = 4 + ): + self.face_net = face_net or FaceNet() + self.face_cluster = face_cluster or FaceCluster() + self.max_workers = max_workers + self.processing_lock = Lock() + + def process_images(self, image_paths: List[Path]) -> Dict[int, List[str]]: + """ + Process multiple images in parallel and update clusters. + + Args: + image_paths: List of paths to images to process + + Returns: + Updated clustering results + """ + logger.info(f"Starting batch processing of {len(image_paths)} images") + + embeddings_batch = [] + valid_paths = [] + + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + # Submit all image processing tasks + future_to_path = { + executor.submit(self._process_single_image, path): path + for path in image_paths + } + + # Collect results as they complete + for future in as_completed(future_to_path): + path = future_to_path[future] + try: + embeddings = future.result() + if embeddings is not None and len(embeddings) > 0: + embeddings_batch.extend(embeddings) + valid_paths.extend([path] * len(embeddings)) + except Exception as e: + logger.error(f"Error processing image {path}: {str(e)}") + continue + + # Process all collected embeddings + with self.processing_lock: + return self._update_clusters(embeddings_batch, valid_paths) + + def _process_single_image(self, image_path: Path) -> Optional[List[np.ndarray]]: + """ + Process a single image to extract face embeddings. + + Args: + image_path: Path to the image + + Returns: + List of face embeddings if faces found, None otherwise + """ + try: + # Detect faces and generate embeddings + faces = self.face_net.detect_faces(str(image_path)) + if not faces: + return None + + embeddings = [] + for face in faces: + embedding = self.face_net.generate_embedding(face) + if embedding is not None: + embeddings.append(embedding) + # Store embedding in database + store_face_embedding(str(image_path), embedding) + + return embeddings + + except Exception as e: + logger.error(f"Error processing image {image_path}: {str(e)}") + return None + + def _update_clusters( + self, + embeddings_batch: List[np.ndarray], + image_paths: List[Path] + ) -> Dict[int, List[str]]: + """ + Update clusters with new embeddings. 
diff --git a/backend/tests/test_batch_processor.py b/backend/tests/test_batch_processor.py
new file mode 100644
index 00000000..4c387296
--- /dev/null
+++ b/backend/tests/test_batch_processor.py
@@ -0,0 +1,80 @@
+import pytest
+import numpy as np
+from pathlib import Path
+from unittest.mock import Mock, patch
+
+from app.facenet.batch_processor import BatchProcessor
+from app.facenet.facenet import FaceNet
+from app.facecluster.facecluster import FaceCluster
+
+@pytest.fixture
+def mock_face_net():
+    face_net = Mock(spec=FaceNet)
+    # Mock detect_faces to return some dummy faces
+    face_net.detect_faces.return_value = [np.zeros((160, 160, 3))]
+    # Mock generate_embedding to return a dummy embedding
+    face_net.generate_embedding.return_value = np.random.rand(512)
+    return face_net
+
+@pytest.fixture
+def mock_face_cluster():
+    face_cluster = Mock(spec=FaceCluster)
+    face_cluster.add_face.return_value = {}
+    face_cluster.force_recluster.return_value = {0: ["test_image.jpg"]}
+    return face_cluster
+
+def test_batch_processor_init():
+    processor = BatchProcessor()
+    assert processor.max_workers == 4
+    assert processor.face_net is not None
+    assert processor.face_cluster is not None
+
+def test_process_images(mock_face_net, mock_face_cluster):
+    processor = BatchProcessor(
+        face_net=mock_face_net,
+        face_cluster=mock_face_cluster
+    )
+
+    # Create test image paths
+    image_paths = [Path(f"test_image_{i}.jpg") for i in range(5)]
+
+    # Process images
+    results = processor.process_images(image_paths)
+
+    # Verify face detection was called for each image
+    assert mock_face_net.detect_faces.call_count == len(image_paths)
+
+    # Verify embedding generation was called for each face
+    assert mock_face_net.generate_embedding.call_count == len(image_paths)
+
+    # Verify clustering was updated
+    assert mock_face_cluster.force_recluster.called
+    assert isinstance(results, dict)
+
+def test_process_images_with_errors(mock_face_net, mock_face_cluster):
+    processor = BatchProcessor(
+        face_net=mock_face_net,
+        face_cluster=mock_face_cluster
+    )
+
+    # Make detect_faces fail for some images
+    def mock_detect_faces(path):
+        if "error" in str(path):
+            raise Exception("Test error")
+        return [np.zeros((160, 160, 3))]
+
+    mock_face_net.detect_faces.side_effect = mock_detect_faces
+
+    # Create test image paths with some error cases
+    image_paths = [
+        Path("test_image_1.jpg"),
+        Path("error_image.jpg"),
+        Path("test_image_2.jpg")
+    ]
+
+    # Process images
+    results = processor.process_images(image_paths)
+
+    # Verify processing continues despite errors
+    assert isinstance(results, dict)
+    assert mock_face_net.detect_faces.call_count == len(image_paths)
\ No newline at end of file
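One gap worth flagging: `_process_single_image` calls `store_face_embedding`, so the tests above write to the real SQLite file. A hermetic variant (a sketch, not part of the diff) stubs the name as imported into `batch_processor`:

```python
from pathlib import Path
from app.facenet.batch_processor import BatchProcessor

def test_process_images_no_db(mock_face_net, mock_face_cluster, monkeypatch):
    # Swallow DB writes instead of touching DATABASE_PATH
    monkeypatch.setattr(
        "app.facenet.batch_processor.store_face_embedding",
        lambda path, emb: None,
    )
    processor = BatchProcessor(face_net=mock_face_net, face_cluster=mock_face_cluster)
    results = processor.process_images([Path("test_image.jpg")])
    assert isinstance(results, dict)
```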
0.8.5", + "rayon", "ring 0.16.20", "serde", "serde_json", diff --git a/frontend/src-tauri/Cargo.toml b/frontend/src-tauri/Cargo.toml index e89f35cd..88283253 100644 --- a/frontend/src-tauri/Cargo.toml +++ b/frontend/src-tauri/Cargo.toml @@ -28,6 +28,8 @@ tempfile = "3" arrayref = "0.3.6" directories = "4.0" chrono = { version = "0.4.26", features = ["serde"] } +rayon = "1.8" +lazy_static = "1.4.0" base64 = "0.21.0" rand = "0.8.5" diff --git a/frontend/src-tauri/src/services/mod.rs b/frontend/src-tauri/src/services/mod.rs index c6edec88..32483dda 100644 --- a/frontend/src-tauri/src/services/mod.rs +++ b/frontend/src-tauri/src/services/mod.rs @@ -22,11 +22,53 @@ use std::num::NonZeroU32; use std::process::Command; use tauri::path::BaseDirectory; use tauri::Manager; +use rayon::prelude::*; +use std::sync::Mutex; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::collections::hash_map::DefaultHasher; +use lazy_static::lazy_static; +use lru::LruCache; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::num::NonZeroUsize; + +// Constants for cache configuration +const MAX_CACHE_SIZE: usize = 100; // Maximum number of cached images +const MAX_IMAGE_DIM: u32 = 4096; // Maximum cached image dimension +const MAX_CACHE_MEMORY: usize = 1024 * 1024 * 1024; // 1GB max cache size pub const SECURE_FOLDER_NAME: &str = "secure_folder"; const SALT_LENGTH: usize = 16; const NONCE_LENGTH: usize = 12; +#[derive(Clone)] +struct ImageAdjustment { + brightness: i32, + contrast: i32, +} + +impl Hash for ImageAdjustment { + fn hash(&self, state: &mut H) { + self.brightness.hash(state); + self.contrast.hash(state); + } +} + +lazy_static! { + static ref IMAGE_CACHE: Mutex> = Mutex::new( + LruCache::new(NonZeroUsize::new(MAX_CACHE_SIZE).unwrap()) + ); + static ref CURRENT_CACHE_MEMORY: AtomicUsize = AtomicUsize::new(0); + static ref CACHE_HITS: AtomicUsize = AtomicUsize::new(0); + static ref CACHE_MISSES: AtomicUsize = AtomicUsize::new(0); + static ref CACHE_EVICTIONS: AtomicUsize = AtomicUsize::new(0); + static ref TOTAL_PROCESSING_TIME: AtomicUsize = AtomicUsize::new(0); +} + +fn calculate_image_memory_size(img: &DynamicImage) -> usize { + img.width() as usize * img.height() as usize * 4 // 4 bytes per pixel (RGBA) +} + #[derive(Serialize, Deserialize)] pub struct SecureMedia { pub id: String, @@ -967,3 +1009,76 @@ pub fn get_server_path(handle: tauri::AppHandle) -> Result { .map_err(|e| e.to_string())?; Ok(resource_path.to_string_lossy().to_string()) } + +#[tauri::command] +pub fn clear_image_cache() { + let mut cache = IMAGE_CACHE.lock().unwrap(); + cache.clear(); + CURRENT_CACHE_MEMORY.store(0, Ordering::Relaxed); +} + +#[tauri::command] +pub fn get_cache_stats() -> Result { + let cache = IMAGE_CACHE.lock().map_err(|e| { + error!("Failed to acquire cache lock for stats: {}", e); + format!("Cache lock error: {}", e) + })?; + + let current_memory = CURRENT_CACHE_MEMORY.load(Ordering::Relaxed); + let hits = CACHE_HITS.load(Ordering::Relaxed); + let misses = CACHE_MISSES.load(Ordering::Relaxed); + let total_requests = hits + misses; + let hit_rate = if total_requests > 0 { + (hits as f64 / total_requests as f64) * 100.0 + } else { + 0.0 + }; + + let stats = CacheStats { + current_memory_bytes: current_memory, + max_memory_bytes: MAX_CACHE_MEMORY, + current_items: cache.len(), + max_items: MAX_CACHE_SIZE, + cache_hits: hits, + cache_misses: misses, + cache_evictions: CACHE_EVICTIONS.load(Ordering::Relaxed), + hit_rate_percentage: hit_rate, + memory_usage_percentage: (current_memory as f64 
diff --git a/frontend/src-tauri/src/services/tests.rs b/frontend/src-tauri/src/services/tests.rs
new file mode 100644
index 00000000..8251f2f8
--- /dev/null
+++ b/frontend/src-tauri/src/services/tests.rs
@@ -0,0 +1,125 @@
+// Pulled in from services/mod.rs via `#[cfg(test)] mod tests;`, so items
+// like adjust_brightness_contrast and clear_image_cache resolve directly.
+use super::*;
+use image::{DynamicImage, ImageBuffer, Rgba};
+use std::time::Instant;
+
+fn create_test_image(width: u32, height: u32) -> DynamicImage {
+    let mut img = ImageBuffer::new(width, height);
+    for x in 0..width {
+        for y in 0..height {
+            img.put_pixel(x, y, Rgba([100, 150, 200, 255]));
+        }
+    }
+    DynamicImage::ImageRgba8(img)
+}
+
+#[test]
+fn test_brightness_contrast_correctness() {
+    let img = create_test_image(100, 100);
+
+    // Test with various brightness and contrast values
+    let test_cases = vec![
+        (0, 0),    // No change
+        (50, 0),   // Increased brightness
+        (0, 50),   // Increased contrast
+        (-50, 0),  // Decreased brightness
+        (0, -50),  // Decreased contrast
+    ];
+
+    for (brightness, contrast) in test_cases {
+        let result = adjust_brightness_contrast(&img, brightness, contrast);
+
+        // Verify dimensions are preserved
+        assert_eq!(result.width(), img.width());
+        assert_eq!(result.height(), img.height());
+
+        // Verify pixel values are within valid range
+        let result_buffer = result.to_rgba8();
+        for pixel in result_buffer.pixels() {
+            assert!(pixel[0] <= 255);
+            assert!(pixel[1] <= 255);
+            assert!(pixel[2] <= 255);
+            assert_eq!(pixel[3], 255); // Alpha should remain unchanged
+        }
+    }
+}
+
+#[test]
+fn test_caching_performance() {
+    let img = create_test_image(500, 500);
+
+    // Clear cache before testing
+    clear_image_cache();
+
+    // First run - should be slower (no cache)
+    let start = Instant::now();
+    let _ = adjust_brightness_contrast(&img, 50, 50);
+    let first_duration = start.elapsed();
+
+    // Second run - should be faster (cached)
+    let start = Instant::now();
+    let _ = adjust_brightness_contrast(&img, 50, 50);
+    let second_duration = start.elapsed();
+
+    // The second run should be significantly faster due to caching
+    assert!(second_duration < first_duration);
+    println!("First run: {:?}, Second run: {:?}", first_duration, second_duration);
+}
+
+#[test]
+fn test_parallel_processing() {
+    let img = create_test_image(1000, 1000);
+
+    // Process large image and measure time
+    let start = Instant::now();
+    let _ = adjust_brightness_contrast(&img, 30, 30);
+    let duration = start.elapsed();
+
+    // Print processing time for manual verification
+    println!("Processing time for 1000x1000 image: {:?}", duration);
+}
+
+#[test]
+fn test_different_image_sizes() {
+    let sizes = vec![(100, 100), (200, 300), (500, 500)];
+
+    for (width, height) in sizes {
+        let img = create_test_image(width, height);
+        let result = adjust_brightness_contrast(&img, 20, 20);
+
+        assert_eq!(result.width(), width);
+        assert_eq!(result.height(), height);
+    }
+}
+
+#[test]
+fn test_extreme_values() {
+    let img = create_test_image(100, 100);
+
+    // Test with extreme brightness and contrast values
+    let extreme_cases = vec![
+        (100, 100),   // Maximum brightness and contrast
+        (-100, -100), // Minimum brightness and contrast
+        (100, -100),  // Max brightness, min contrast
+        (-100, 100),  // Min brightness, max contrast
+    ];
+
+    for (brightness, contrast) in extreme_cases {
+        let result = adjust_brightness_contrast(&img, brightness, contrast);
+
+        // Verify the image is still valid
+        assert_eq!(result.width(), img.width());
+        assert_eq!(result.height(), img.height());
+
+        // Check that pixel values are clamped correctly
+        let result_buffer = result.to_rgba8();
+        for pixel in result_buffer.pixels() {
+            assert!(pixel[0] <= 255);
+            assert!(pixel[1] <= 255);
+            assert!(pixel[2] <= 255);
+            assert_eq!(pixel[3], 255);
+        }
+    }
+}
\ No newline at end of file
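`adjust_brightness_contrast` itself is not part of this diff. For reference, a formulation consistent with what these tests assert (dimensions preserved, channels clamped to 0..=255, alpha untouched) could look like the following sketch; the exact contrast curve the Rust code uses is an assumption:

```python
import numpy as np

def adjust_brightness_contrast(rgba: np.ndarray, brightness: int, contrast: int) -> np.ndarray:
    """rgba: (H, W, 4) uint8 image; brightness/contrast in -100..=100."""
    rgb = rgba[..., :3].astype(np.float32)
    factor = (contrast + 100.0) / 100.0       # 0.0 .. 2.0
    rgb = (rgb - 128.0) * factor + 128.0      # contrast pivots on mid-grey
    rgb += brightness * 255.0 / 100.0         # brightness as % of full range
    out = rgba.copy()
    out[..., :3] = np.clip(rgb, 0.0, 255.0).astype(np.uint8)  # alpha kept
    return out

img = np.full((100, 100, 4), (100, 150, 200, 255), dtype=np.uint8)
res = adjust_brightness_contrast(img, 50, 50)
assert res.shape == img.shape
assert (res[..., 3] == 255).all()
```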