diff --git a/src/examples/color_palette_extraction.py b/src/examples/color_palette_extraction.py new file mode 100644 index 0000000..f26864e --- /dev/null +++ b/src/examples/color_palette_extraction.py @@ -0,0 +1,626 @@ +""" +🌈 Color Palette Extraction & Application - New Feature for Qwen-Image +Extract color schemes from reference images and apply to generated content + +This feature provides advanced color palette analysis and application capabilities that allow users to +extract color schemes from reference images and apply them to new generations or existing images, +maintaining color harmony and aesthetic consistency. +""" + +import torch +import numpy as np +from PIL import Image, ImageDraw +from typing import Dict, List, Optional, Union, Tuple, Any +from dataclasses import dataclass +from collections import Counter +import colorsys +import json +import logging +from sklearn.cluster import KMeans + + +@dataclass +class ColorInfo: + """Information about a color in the palette.""" + rgb: Tuple[int, int, int] + hex: str + hsl: Tuple[float, float, float] + name: str + frequency: float + cluster_id: int + + +@dataclass +class ColorPalette: + """A complete color palette extracted from an image.""" + dominant_colors: List[ColorInfo] + color_harmony: Dict[str, Any] + mood: str + temperature: str + saturation_level: str + contrast_ratio: float + palette_type: str # "warm", "cool", "neutral", "monochromatic", etc. + + +@dataclass +class PaletteApplicationConfig: + """Configuration for palette application.""" + strength: float = 0.8 + preserve_luminance: bool = True + maintain_contrast: bool = True + color_temperature_adjustment: bool = True + saturation_matching: bool = True + + +class ColorPaletteExtractor: + """ + Advanced color palette extraction and analysis system. + Provides comprehensive color analysis and palette generation. + """ + + def __init__(self): + # Color name database (simplified) + self.color_names = { + (255, 0, 0): "Red", + (0, 255, 0): "Green", + (0, 0, 255): "Blue", + (255, 255, 0): "Yellow", + (255, 0, 255): "Magenta", + (0, 255, 255): "Cyan", + (255, 255, 255): "White", + (0, 0, 0): "Black", + (128, 128, 128): "Gray", + # Add more colors as needed + } + + # Setup logging + logging.basicConfig(level=logging.INFO) + self.logger = logging.getLogger("ColorPaletteExtractor") + + def extract_palette(self, image: Image.Image, num_colors: int = 8) -> ColorPalette: + """ + Extract color palette from an image using advanced clustering. + + Args: + image: Input PIL image + num_colors: Number of dominant colors to extract + + Returns: + ColorPalette object with analysis + """ + # Convert to RGB if necessary + if image.mode != 'RGB': + image = image.convert('RGB') + + # Get image data + img_array = np.array(image) + height, width = img_array.shape[:2] + + # Reshape for clustering (take sample for performance) + pixels = img_array.reshape(-1, 3) + + # Sample pixels for better performance (take every 10th pixel) + sample_indices = np.arange(0, len(pixels), 10) + pixels_sample = pixels[sample_indices] + + # Perform K-means clustering + kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10) + labels = kmeans.fit_predict(pixels_sample) + centers = kmeans.cluster_centers_ + + # Get color frequencies + label_counts = Counter(labels) + total_samples = len(labels) + + # Create ColorInfo objects + dominant_colors = [] + for i, center in enumerate(centers): + rgb = tuple(int(c) for c in center) + frequency = label_counts[i] / total_samples + + color_info = ColorInfo( + rgb=rgb, + hex=self._rgb_to_hex(rgb), + hsl=self._rgb_to_hsl(rgb), + name=self._get_color_name(rgb), + frequency=frequency, + cluster_id=i + ) + dominant_colors.append(color_info) + + # Sort by frequency + dominant_colors.sort(key=lambda x: x.frequency, reverse=True) + + # Analyze color harmony and properties + color_harmony = self._analyze_color_harmony(dominant_colors) + mood = self._determine_mood(dominant_colors) + temperature = self._determine_temperature(dominant_colors) + saturation_level = self._determine_saturation_level(dominant_colors) + contrast_ratio = self._calculate_contrast_ratio(dominant_colors) + palette_type = self._classify_palette_type(dominant_colors) + + return ColorPalette( + dominant_colors=dominant_colors, + color_harmony=color_harmony, + mood=mood, + temperature=temperature, + saturation_level=saturation_level, + contrast_ratio=contrast_ratio, + palette_type=palette_type + ) + + def _rgb_to_hex(self, rgb: Tuple[int, int, int]) -> str: + """Convert RGB tuple to hex string.""" + return f"#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}" + + def _rgb_to_hsl(self, rgb: Tuple[int, int, int]) -> Tuple[float, float, float]: + """Convert RGB to HSL.""" + r, g, b = [x / 255.0 for x in rgb] + h, l, s = colorsys.rgb_to_hls(r, g, b) + return (h * 360, s * 100, l * 100) # Convert to degrees and percentages + + def _get_color_name(self, rgb: Tuple[int, int, int]) -> str: + """Get approximate color name (simplified).""" + # Find closest named color + min_distance = float('inf') + closest_name = "Unknown" + + for named_rgb, name in self.color_names.items(): + distance = sum((a - b) ** 2 for a, b in zip(rgb, named_rgb)) + if distance < min_distance: + min_distance = distance + closest_name = name + + # If very close to a named color, use it + if min_distance < 1000: # Threshold for color similarity + return closest_name + + # Otherwise, generate descriptive name based on HSL + h, s, l = self._rgb_to_hsl(rgb) + + # Hue-based naming + if 0 <= h < 30 or 330 <= h <= 360: + hue_name = "Red" + elif 30 <= h < 90: + hue_name = "Yellow" + elif 90 <= h < 150: + hue_name = "Green" + elif 150 <= h < 210: + hue_name = "Cyan" + elif 210 <= h < 270: + hue_name = "Blue" + elif 270 <= h < 330: + hue_name = "Magenta" + + saturation_desc = "Vivid" if s > 70 else "Muted" if s < 30 else "Moderate" + lightness_desc = "Light" if l > 70 else "Dark" if l < 30 else "" + + parts = [lightness_desc, saturation_desc, hue_name] + return " ".join(filter(None, parts)) + + def _analyze_color_harmony(self, colors: List[ColorInfo]) -> Dict[str, Any]: + """Analyze color harmony and relationships.""" + if len(colors) < 2: + return {"harmony_type": "single_color", "complementary": False} + + # Extract hues + hues = [color.hsl[0] for color in colors] + + # Check for complementary colors (180 degrees apart) + complementary_pairs = [] + for i, h1 in enumerate(hues): + for j, h2 in enumerate(hues[i+1:], i+1): + diff = min(abs(h1 - h2), 360 - abs(h1 - h2)) + if 170 <= diff <= 190: # Allow some tolerance + complementary_pairs.append((i, j)) + + # Check for triadic harmony (120 degrees apart) + triadic_groups = [] + for i in range(len(hues)): + group = [i] + for j in range(len(hues)): + if i != j: + diff = min(abs(hues[i] - hues[j]), 360 - abs(hues[i] - hues[j])) + if 110 <= diff <= 130: + group.append(j) + if len(group) >= 3: + triadic_groups.append(group) + + # Determine harmony type + if complementary_pairs: + harmony_type = "complementary" + elif triadic_groups: + harmony_type = "triadic" + elif len(set(round(h/30) for h in hues)) == 1: # Similar hues + harmony_type = "monochromatic" + elif len(set(round(h/90) for h in hues)) <= 2: # Adjacent hues + harmony_type = "analogous" + else: + harmony_type = "mixed" + + return { + "harmony_type": harmony_type, + "complementary_pairs": complementary_pairs, + "triadic_groups": triadic_groups, + "hue_variation": np.std(hues) + } + + def _determine_mood(self, colors: List[ColorInfo]) -> str: + """Determine the overall mood of the color palette.""" + avg_saturation = np.mean([c.hsl[1] for c in colors]) + avg_lightness = np.mean([c.hsl[2] for c in colors]) + + if avg_lightness > 70 and avg_saturation > 60: + return "vibrant_cheerful" + elif avg_lightness < 30 and avg_saturation > 60: + return "dramatic_intense" + elif avg_saturation < 30: + return "calm_subdued" + elif avg_lightness > 70: + return "bright_fresh" + elif avg_lightness < 30: + return "dark_mysterious" + else: + return "balanced_neutral" + + def _determine_temperature(self, colors: List[ColorInfo]) -> str: + """Determine color temperature (warm/cool).""" + warm_hues = 0 + cool_hues = 0 + + for color in colors: + h = color.hsl[0] + # Warm: red-orange-yellow range + if (0 <= h <= 90) or (330 <= h <= 360): + warm_hues += color.frequency + # Cool: green-blue-purple range + elif 90 < h < 330: + cool_hues += color.frequency + + if warm_hues > cool_hues * 1.2: + return "warm" + elif cool_hues > warm_hues * 1.2: + return "cool" + else: + return "neutral" + + def _determine_saturation_level(self, colors: List[ColorInfo]) -> str: + """Determine overall saturation level.""" + avg_saturation = np.mean([c.hsl[1] for c in colors]) + + if avg_saturation > 70: + return "high" + elif avg_saturation < 30: + return "low" + else: + return "medium" + + def _calculate_contrast_ratio(self, colors: List[ColorInfo]) -> float: + """Calculate contrast ratio between lightest and darkest colors.""" + lightness_values = [c.hsl[2] for c in colors] + if lightness_values: + return max(lightness_values) / max(min(lightness_values), 1) + return 1.0 + + def _classify_palette_type(self, colors: List[ColorInfo]) -> str: + """Classify the palette type.""" + temperature = self._determine_temperature(colors) + saturation = self._determine_saturation_level(colors) + + if temperature == "neutral" and saturation == "low": + return "monochromatic" + elif temperature == "warm": + return "warm" + elif temperature == "cool": + return "cool" + else: + return "balanced" + + +class ColorPaletteApplicator: + """ + System for applying extracted color palettes to images and generations. + """ + + def __init__(self, model_path: str = "Qwen/Qwen-Image"): + self.model_path = model_path + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.pipeline = None + self.config = PaletteApplicationConfig() + + def load_pipeline(self): + """Load the Qwen-Image pipeline.""" + try: + from diffusers import DiffusionPipeline + + self.pipeline = DiffusionPipeline.from_pretrained( + self.model_path, + torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, + device_map="auto" if self.device == "cuda" else None, + low_cpu_mem_usage=True + ) + + if self.device == "cuda": + self.pipeline.to(self.device) + + return True + except Exception as e: + logging.error(f"Failed to load pipeline: {e}") + return False + + def apply_palette_to_generation(self, prompt: str, palette: ColorPalette, + **generation_kwargs) -> Tuple[Image.Image, Dict]: + """ + Generate an image using a specific color palette. + + Args: + prompt: Text prompt for generation + palette: ColorPalette to apply + **generation_kwargs: Additional generation parameters + + Returns: + Tuple of (generated_image, metadata) + """ + if not self.pipeline: + raise RuntimeError("Pipeline not loaded. Call load_pipeline() first.") + + # Enhance prompt with color palette information + enhanced_prompt = self._enhance_prompt_with_palette(prompt, palette) + + # Set generation parameters based on palette + gen_params = { + "prompt": enhanced_prompt, + "guidance_scale": generation_kwargs.get("guidance_scale", 7.5), + "num_inference_steps": generation_kwargs.get("num_inference_steps", 50), + **generation_kwargs + } + + # Generate image + with torch.inference_mode(): + result = self.pipeline(**gen_params) + + generated_image = result.images[0] if hasattr(result, 'images') else result + + metadata = { + "operation": "palette_guided_generation", + "original_prompt": prompt, + "enhanced_prompt": enhanced_prompt, + "palette_info": { + "type": palette.palette_type, + "temperature": palette.temperature, + "mood": palette.mood, + "dominant_colors": [c.hex for c in palette.dominant_colors[:3]] + }, + "generation_params": gen_params + } + + return generated_image, metadata + + def apply_palette_to_existing_image(self, image: Image.Image, palette: ColorPalette, + strength: float = 0.8) -> Tuple[Image.Image, Dict]: + """ + Apply color palette transformation to an existing image. + + Args: + image: Input PIL image + palette: Target color palette + strength: Transformation strength (0-1) + + Returns: + Tuple of (transformed_image, metadata) + """ + # This is a simplified implementation + # In production, you'd use more sophisticated color transfer algorithms + + # Convert to LAB color space for better color manipulation + img_array = np.array(image.convert('RGB')) + + # Simple color adjustment based on palette + # This is a placeholder - real implementation would use color transfer algorithms + adjusted_image = self._simple_color_adjustment(img_array, palette, strength) + + result_image = Image.fromarray(adjusted_image.astype(np.uint8)) + + metadata = { + "operation": "palette_application", + "strength": strength, + "palette_type": palette.palette_type, + "dominant_colors_applied": len(palette.dominant_colors) + } + + return result_image, metadata + + def _enhance_prompt_with_palette(self, prompt: str, palette: ColorPalette) -> str: + """Enhance generation prompt with palette information.""" + palette_descriptions = [] + + # Add temperature and mood + if palette.temperature != "neutral": + palette_descriptions.append(f"{palette.temperature} color scheme") + + if palette.mood: + mood_desc = palette.mood.replace("_", " ") + palette_descriptions.append(f"{mood_desc} mood") + + # Add dominant colors + color_names = [c.name for c in palette.dominant_colors[:3]] + if color_names: + palette_descriptions.append(f"featuring {', '.join(color_names)} colors") + + # Add palette type + if palette.palette_type not in ["balanced", "neutral"]: + palette_descriptions.append(f"{palette.palette_type} palette") + + palette_text = ", ".join(palette_descriptions) + + return f"{prompt}, {palette_text}" + + def _simple_color_adjustment(self, img_array: np.ndarray, palette: ColorPalette, + strength: float) -> np.ndarray: + """Simple color adjustment based on palette (placeholder implementation).""" + # This is a very basic implementation + # Real color transfer would use algorithms like Reinhard's method + + # Get average color of palette + avg_color = np.mean([c.rgb for c in palette.dominant_colors], axis=0) + + # Simple color shift towards palette average + adjusted = img_array.astype(np.float32) + color_shift = (avg_color - np.mean(img_array.reshape(-1, 3), axis=0)) * strength + + for i in range(3): # RGB channels + adjusted[:, :, i] += color_shift[i] + + # Clip to valid range + adjusted = np.clip(adjusted, 0, 255) + + return adjusted + + def create_palette_from_description(self, description: str) -> ColorPalette: + """ + Create a color palette from a text description. + + Args: + description: Text description of desired colors + + Returns: + Generated ColorPalette + """ + # This is a simplified implementation + # In production, you'd use NLP and color theory + + description_lower = description.lower() + + # Simple keyword matching for color generation + if "ocean" in description_lower or "sea" in description_lower: + base_colors = [(0, 100, 200), (0, 150, 255), (100, 200, 255)] + elif "sunset" in description_lower: + base_colors = [(255, 100, 0), (255, 150, 50), (200, 50, 100)] + elif "forest" in description_lower: + base_colors = [(50, 150, 50), (100, 200, 100), (30, 100, 30)] + elif "warm" in description_lower: + base_colors = [(255, 150, 100), (255, 200, 150), (200, 100, 50)] + elif "cool" in description_lower: + base_colors = [(100, 150, 255), (150, 200, 255), (50, 100, 200)] + else: + # Default palette + base_colors = [(200, 200, 200), (150, 150, 150), (100, 100, 100)] + + # Create ColorInfo objects + dominant_colors = [] + for i, rgb in enumerate(base_colors): + color_info = ColorInfo( + rgb=rgb, + hex=self._rgb_to_hex_static(rgb), + hsl=self._rgb_to_hsl_static(rgb), + name=f"Generated Color {i+1}", + frequency=1.0 / len(base_colors), + cluster_id=i + ) + dominant_colors.append(color_info) + + return ColorPalette( + dominant_colors=dominant_colors, + color_harmony={"harmony_type": "generated"}, + mood="custom", + temperature="neutral", + saturation_level="medium", + contrast_ratio=1.5, + palette_type="custom" + ) + + @staticmethod + def _rgb_to_hex_static(rgb: Tuple[int, int, int]) -> str: + """Static version of RGB to hex conversion.""" + return f"#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}" + + @staticmethod + def _rgb_to_hsl_static(rgb: Tuple[int, int, int]) -> Tuple[float, float, float]: + """Static version of RGB to HSL conversion.""" + r, g, b = [x / 255.0 for x in rgb] + h, l, s = colorsys.rgb_to_hls(r, g, b) + return (h * 360, s * 100, l * 100) + + +def demonstrate_color_palette_extraction(): + """ + Demonstrate the Color Palette Extraction & Application functionality. + """ + print("=" * 60) + print("QWEN-IMAGE COLOR PALETTE EXTRACTION & APPLICATION") + print("=" * 60) + print() + + print("NEW FEATURE: Extract and apply color palettes") + print() + + print("CAPABILITIES:") + print("* Advanced color palette extraction") + print("* Color harmony analysis") + print("* Palette-guided image generation") + print("* Color transfer to existing images") + print("* Custom palette creation from descriptions") + print() + + print("USAGE EXAMPLES:") + print(""" +# Initialize color palette system +from color_palette_extraction import ColorPaletteExtractor, ColorPaletteApplicator + +extractor = ColorPaletteExtractor() +applicator = ColorPaletteApplicator() +applicator.load_pipeline() + +# Extract palette from reference image +reference_image = Image.open("sunset_photo.jpg") +palette = extractor.extract_palette(reference_image, num_colors=6) + +print(f"Palette type: {palette.palette_type}") +print(f"Mood: {palette.mood}") +print(f"Dominant colors: {[c.hex for c in palette.dominant_colors[:3]]}") + +# Generate new image using extracted palette +generated_image, metadata = applicator.apply_palette_to_generation( + "A beautiful landscape", + palette +) + +# Apply palette to existing image +existing_image = Image.open("old_photo.jpg") +transformed_image, _ = applicator.apply_palette_to_existing_image( + existing_image, palette, strength=0.7 +) + +# Create custom palette from description +custom_palette = applicator.create_palette_from_description( + "warm sunset colors with golden and orange tones" +) + +# Use custom palette for generation +styled_image, _ = applicator.apply_palette_to_generation( + "A cozy living room", + custom_palette +) +""") + + print("KEY FEATURES:") + print("* K-means clustering for accurate color extraction") + print("* Color harmony and mood analysis") + print("* Temperature and saturation detection") + print("* Palette-guided content generation") + print("* Color transfer algorithms") + print("* Custom palette creation") + print("* Comprehensive color metadata") + print() + + print("PERFECT FOR:") + print("* Brand color consistency") + print("* Artistic style matching") + print("* Color scheme design") + print("* Photo enhancement") + print("* Creative color exploration") + print() + + +if __name__ == "__main__": + demonstrate_color_palette_extraction() \ No newline at end of file