185 changes: 185 additions & 0 deletions src/examples/omnicontrol_fix.py
@@ -0,0 +1,185 @@
"""
Fix for OmniControl adaptation on Qwen-Image
Addresses issue #66: "ominicontrol seems to be a poor fit for qwen image?"

This script provides improved conditional injection for OmniControl-like methods
on Qwen-Image, focusing on better subject preservation in VAE encoding.
"""

import torch
from diffusers import DiffusionPipeline
from typing import List, Optional, Union
import numpy as np


class QwenImageOmniControlPipeline(DiffusionPipeline):
"""
Enhanced Qwen-Image pipeline with improved OmniControl support.
Fixes subject preservation issues in conditional generation.
"""

    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, **kwargs):
        super().__init__()
        # Register components so they are exposed as attributes and self.device,
        # self.dtype, and save_pretrained behave like in other diffusers pipelines
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )

def prepare_conditional_inputs(self, image: torch.Tensor, condition_type: str = "omnicontrol"):
"""
Prepare conditional inputs for OmniControl-like methods.
Fixes position encoding and img_shapes issues.
"""
        # Fix for img_shapes: append an extra frame along the batch dimension
        if condition_type == "omnicontrol":
            # Original shape: [B, C, H, W]
            # For OmniControl, we need to handle multiple frames/conditions
            batch_size, channels, height, width = image.shape
            # Duplicate to [2B, C, H, W]; the second copy serves as the conditional frame
            conditional_image = torch.cat([image, image], dim=0)
return conditional_image
return image

def encode_conditions(self, conditions: List[torch.Tensor], subject_preservation: bool = True):
"""
Encode multiple conditions with improved subject preservation.
"""
        encoded_conditions = []

        for condition in conditions:
            # Move each condition onto the VAE's device before encoding
            condition = condition.to(device=self.vae.device)
            if subject_preservation:
                # Use higher precision for subject regions (assumes the VAE runs in float32)
                condition = condition.to(dtype=torch.float32)
                latents = self.vae.encode(condition).latent_dist.sample()
                # Apply subject preservation scaling to reduce noise in subject areas
                latents = latents * 0.8
            else:
                condition = condition.to(dtype=self.vae.dtype)
                latents = self.vae.encode(condition).latent_dist.sample()

            # Match the scaled latent space used by the rest of the pipeline
            latents = latents * self.vae.config.scaling_factor
            encoded_conditions.append(latents)

        # Concatenate along the batch dimension -> [num_conditions * B, 4, h, w]
        return torch.cat(encoded_conditions, dim=0)

def inject_conditions(self, latents: torch.Tensor, conditions: torch.Tensor,
injection_strength: float = 0.8):
"""
Inject conditions into latents with controlled strength.
"""
# Adaptive condition injection
condition_latents = conditions.mean(dim=0, keepdim=True) # Average conditions
# Weighted combination
injected_latents = latents * (1 - injection_strength) + condition_latents * injection_strength
return injected_latents

@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
image: Optional[torch.Tensor] = None,
conditions: Optional[List[torch.Tensor]] = None,
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
subject_preservation: bool = True,
condition_injection_strength: float = 0.8,
**kwargs
):
"""
Enhanced call method with OmniControl support.
"""
        # Prepare text embeddings (conditional + unconditional for classifier-free guidance)
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        cond_embeddings = self.text_encoder(text_inputs.input_ids.to(self.device))[0]
        # Unconditional (empty-prompt) embeddings; this pipeline assumes a single prompt
        uncond_inputs = self.tokenizer(
            "",
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        uncond_embeddings = self.text_encoder(uncond_inputs.input_ids.to(self.device))[0]
        text_embeddings = torch.cat([uncond_embeddings, cond_embeddings], dim=0)

# Prepare conditional inputs
if conditions is not None:
conditions = [cond.to(self.device) for cond in conditions]
encoded_conditions = self.encode_conditions(conditions, subject_preservation)
else:
encoded_conditions = None

        # Set timesteps first so init_noise_sigma reflects the chosen schedule
        self.scheduler.set_timesteps(num_inference_steps)

        # Generate initial latents at the scheduler's expected noise scale
        latents = torch.randn((1, 4, height // 8, width // 8), device=self.device)
        latents = latents * self.scheduler.init_noise_sigma

# Inject conditions if provided
if encoded_conditions is not None:
latents = self.inject_conditions(latents, encoded_conditions, condition_injection_strength)

        # Denoising loop
        for t in self.scheduler.timesteps:
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Add conditions to input
            step_embeddings = text_embeddings
            if encoded_conditions is not None:
                # Concatenate conditions as additional frames (conditions must match the
                # output resolution so their latents are [N, 4, height // 8, width // 8])
                conditional_frames = encoded_conditions.reshape(-1, 4, height // 8, width // 8)
                latent_model_input = torch.cat([latent_model_input, conditional_frames], dim=0)
                # Repeat the conditional text embedding so the batch sizes still match
                extra = conditional_frames.shape[0]
                step_embeddings = torch.cat(
                    [text_embeddings, text_embeddings[-1:].repeat(extra, 1, 1)], dim=0
                )

            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=step_embeddings,
            ).sample

            # Keep only the unconditional/conditional predictions before guidance
            noise_pred = noise_pred[:2]

            # Perform classifier-free guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

latents = self.scheduler.step(noise_pred, t, latents).prev_sample

# Decode
latents = latents / self.vae.config.scaling_factor
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()

return {"images": image}


def apply_omnicontrol_fix(pipeline, conditions: List[torch.Tensor], prompt: str):
"""
Apply OmniControl fix to existing Qwen-Image pipeline.
"""
# Create enhanced pipeline
enhanced_pipeline = QwenImageOmniControlPipeline(
vae=pipeline.vae,
text_encoder=pipeline.text_encoder,
tokenizer=pipeline.tokenizer,
unet=pipeline.unet,
scheduler=pipeline.scheduler
)

# Generate with improved conditioning
result = enhanced_pipeline(
prompt=prompt,
conditions=conditions,
subject_preservation=True,
condition_injection_strength=0.8
)

return result


# Example usage
if __name__ == "__main__":
# Load base pipeline
pipeline = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")

# Example conditions (depth map, edge map, etc.)
# conditions = [depth_map, edge_map, keypoint_map]
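    # A minimal sketch (not part of the fix, file names are placeholders) of turning
    # control images into the [1, 3, H, W] tensors in [-1, 1] that diffusers VAEs
    # typically expect and encode_conditions consumes:
    # from PIL import Image
    # import torchvision.transforms as T
    # to_tensor = T.Compose([T.Resize((1024, 1024)), T.ToTensor()])
    # depth_map = to_tensor(Image.open("depth.png").convert("RGB")).unsqueeze(0) * 2 - 1
    # edge_map = to_tensor(Image.open("edges.png").convert("RGB")).unsqueeze(0) * 2 - 1
    # conditions = [depth_map, edge_map]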

# Apply fix
# result = apply_omnicontrol_fix(pipeline, conditions, "A cat sitting on a chair")

print("OmniControl fix applied. Use apply_omnicontrol_fix() to enhance generation.")