
Conversation

MaulanaYusufIkhsanRobbani

…rence/Wan2.2-TI2V-5B-logo-guidance.py` to showcase the new classifier guidance feature.


@gemini-code-assist bot left a comment


Summary of Changes

Hello @MaulanaYusufIkhsanRobbani, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new classifier guidance feature to the WanVideoPipeline, enabling the generation of videos that are influenced by a pre-trained classifier. This allows for fine-grained control over the generated content, such as ensuring the presence or absence of specific objects or characteristics. The changes include the addition of new modules for classifier and logo detection, integration of these modules into the video generation pipeline, and a demonstration script showcasing logo guidance in image-to-video generation.

Highlights

  • New Classifier Module: A new OpenLogoClassifier module has been added, designed to load pre-trained object detection models (specifically Faster R-CNN with VGG16) and compute gradients of classification probabilities with respect to the video latents. This gradient is crucial for steering the video generation process.
  • New Logo Detector Module: A LogoDetector module has been introduced, which utilizes a Faster R-CNN model (ResNet50 FPN) to detect logos within an image and generate a soft mask. This mask can be used to apply classifier guidance selectively to specific regions of the generated video.
  • Pipeline Integration of Classifier Guidance: The core WanVideoPipeline has been updated to incorporate classifier guidance. This includes new parameters in the __call__ method for controlling the guidance strength, specifying the target class ID, and providing an optional logo mask. The guidance is applied directly to the noise prediction during the denoising steps; a minimal sketch of this update appears after this list.
  • Dynamic Guidance Strength Scheduling: A get_guidance_strength method was added to the pipeline, allowing for dynamic adjustment of the classifier guidance strength over the inference steps. This enables strategies like 'late step emphasis', where guidance is applied more strongly towards the end of the generation process.
  • Scheduler Enhancements for Guidance: The FlowMatchScheduler now includes get_sigma_t and get_alpha_t methods, which provide necessary values (sigma and alpha at a given timestep) required for the mathematical formulation of the classifier guidance applied within the pipeline.
  • New Demonstration Script: A new example script, examples/wanvideo/model_inference/Wan2.2-TI2V-5B-logo-guidance.py, has been added. This script demonstrates how to use the new classifier guidance feature for logo guidance in an image-to-video generation scenario, providing a practical usage example.
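
To make the mechanics concrete, here is a minimal, hypothetical sketch of how such a guided denoising step could be assembled. The function names (compute_gradient, get_sigma_t, get_guidance_strength) follow the descriptions above, but the signatures and the exact update rule are assumptions; the actual formulation in the PR may differ.

import torch

def get_guidance_strength(base_strength, step, num_steps, schedule="late_step_emphasis"):
    # "Late step emphasis": ramp the strength up linearly across the run so
    # the classifier mostly shapes fine detail near the end of sampling.
    if schedule == "late_step_emphasis":
        return base_strength * (step + 1) / num_steps
    return base_strength

def guided_noise_pred(noise_pred, latents, timestep, step, num_steps,
                      scheduler, classifier, class_id, base_strength, mask=None):
    # Gradient of log p(class | latents), optionally restricted to the logo
    # region by the soft mask (see OpenLogoClassifier / LogoDetector above).
    grad = classifier.compute_gradient(latents, class_id, mask=mask)
    # Classifier guidance in score form: shift the noise prediction against
    # the gradient, scaled by sigma at this timestep and the scheduled strength.
    sigma_t = scheduler.get_sigma_t(timestep)
    strength = get_guidance_strength(base_strength, step, num_steps)
    return noise_pred - strength * sigma_t * grad
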
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in issue comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the root of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check its output and use code with caution.


@gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant new feature: classifier guidance for logo generation. It adds an OpenLogoClassifier for gradient computation and a LogoDetector for creating logo masks, which are then integrated into the WanVideoPipeline. The changes are well-structured and the addition of a demonstration script is appreciated. My review focuses on correctness, performance, and best practices. I've identified a critical issue regarding model evaluation mode in the classifier, a high-severity performance issue with model loading in the pipeline, and several medium-severity issues related to deprecated APIs, code robustness, and unused imports.

Comment on lines +47 to +69
            self.model.train()

                image_for_model = image_processed
                features = self.model.backbone(image_for_model)
                proposals, proposal_losses = self.model.rpn(image_for_model, features, None)
                box_features = self.model.roi_heads.box_roi_pool(features, proposals, [image_for_model.shape[2:]])
                box_features = self.model.roi_heads.box_head(box_features)
                class_logits, box_regression = self.model.roi_heads.box_predictor(box_features)

                log_probs = torch.nn.functional.log_softmax(class_logits, dim=-1)
                log_prob_target = log_probs[:, class_id].mean()
                log_prob_sum = log_prob_sum + log_prob_target

            grad = torch.autograd.grad(log_prob_sum, latents, grad_outputs=torch.ones_like(log_prob_sum))[0]

            # Apply the soft logo mask if provided
            if mask is not None:
                grad = grad * mask

            # Clamp the gradient to avoid artifacts
            grad = torch.clamp(grad, -0.1, 0.1)

            self.model.eval()

critical

Calling self.model.train() inside the compute_gradient method is incorrect. This will put the model, including its BatchNorm layers, into training mode. During guidance, the model should behave as it does during inference, which means BatchNorm layers should use their learned running statistics, not the statistics of the current mini-batch. Calling .train() will cause them to use batch statistics and update their running statistics, which is not desired and will lead to incorrect gradients and potentially artifacts in the generated video. The model is already set to eval() mode in the __init__ method, so these calls are not only incorrect but also unnecessary.

Suggested change

                # Get predictions from the classifier

                image_for_model = image_processed
                features = self.model.backbone(image_for_model)
                proposals, proposal_losses = self.model.rpn(image_for_model, features, None)
                box_features = self.model.roi_heads.box_roi_pool(features, proposals, [image_for_model.shape[2:]])
                box_features = self.model.roi_heads.box_head(box_features)
                class_logits, box_regression = self.model.roi_heads.box_predictor(box_features)

                log_probs = torch.nn.functional.log_softmax(class_logits, dim=-1)
                log_prob_target = log_probs[:, class_id].mean()
                log_prob_sum = log_prob_sum + log_prob_target

            grad = torch.autograd.grad(log_prob_sum, latents, grad_outputs=torch.ones_like(log_prob_sum))[0]

            # Apply the soft logo mask if provided
            if mask is not None:
                grad = grad * mask

            # Clamp the gradient to avoid artifacts
            grad = torch.clamp(grad, -0.1, 0.1)
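
Note that eval() only switches layer behavior (BatchNorm running statistics, Dropout); it does not disable autograd, so gradients with respect to the latents still flow. A standalone illustration in plain PyTorch (not code from this PR):

import torch
import torchvision

model = torchvision.models.resnet18(weights=None)
model.eval()  # inference-time statistics; autograd remains active

x = torch.randn(1, 3, 224, 224, requires_grad=True)
loss = model(x).logsumexp(dim=-1).sum()
grad = torch.autograd.grad(loss, x)[0]  # works fine in eval mode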

@@ -433,6 +459,14 @@ def __call__(
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Logo detection
        if logo_detector_path is not None:
            logo_detector = LogoDetector(logo_detector_path, device=self.device)

high

The LogoDetector is instantiated inside the __call__ method. This means the logo detection model will be loaded from disk every time the pipeline is called, which is very inefficient and will significantly slow down repeated inferences. The detector should be loaded once and stored as a pipeline attribute, similar to how self.classifier is handled. You should add a load_logo_detector method to the pipeline and use a self.logo_detector attribute.

Suggested change
            logo_detector = LogoDetector(logo_detector_path, device=self.device)
            logo_detector = self.logo_detector
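
A minimal sketch of that pattern, assuming the LogoDetector constructor shown in this diff (the method name load_logo_detector is the reviewer's proposal, not an existing pipeline API):

def load_logo_detector(self, logo_detector_path):
    # Load once and cache on the pipeline, mirroring how self.classifier is
    # handled, so repeated __call__ invocations reuse the same model instead
    # of re-reading it from disk.
    self.logo_detector = LogoDetector(logo_detector_path, device=self.device)

The __call__ body would then consult self.logo_detector rather than constructing a new detector on every call.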

Comment on lines +3 to +5
import torchvision.transforms.functional as F
from torchvision.models.detection import fasterrcnn_vgg16_bn
from torchvision.ops import nms

medium

The imports torchvision.transforms.functional as F and from torchvision.ops import nms are not used in this file. They should be removed to keep the code clean.

Suggested change
import torchvision.transforms.functional as F
from torchvision.models.detection import fasterrcnn_vgg16_bn
from torchvision.ops import nms
from torchvision.models.detection import fasterrcnn_vgg16_bn

            self.model = torch.load(model_path, map_location=self.device)
        else:
            # Load a pretrained Faster R-CNN model from torchvision
            self.model = fasterrcnn_resnet50_fpn(pretrained=True)

medium

The pretrained argument for torchvision models is deprecated and will be removed in a future version. You should use the weights argument instead for better future compatibility.

Suggested change
            self.model = fasterrcnn_resnet50_fpn(pretrained=True)
            self.model = fasterrcnn_resnet50_fpn(weights="FasterRCNN_ResNet50_FPN_Weights.DEFAULT")
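
The weights can equivalently be passed as the enum itself, which makes typos fail at import time (standard torchvision API, version 0.13 or newer):

from torchvision.models.detection import (
    FasterRCNN_ResNet50_FPN_Weights,
    fasterrcnn_resnet50_fpn,
)

# Same as the string form above, but checked at import time.
model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)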

        if model_path:
            # Load a custom model if a path is provided
            # For now, we assume the model is a Faster R-CNN model
            self.model = torch.load(model_path, map_location=self.device)

medium

Using torch.load(model_path) to load an entire model is not recommended because the serialized data is bound to the specific classes and directory structure used when the model was saved. The recommended approach is to save and load only the model's state_dict. This makes the code more robust to refactoring.

Suggested change
            self.model = torch.load(model_path, map_location=self.device)
            self.model = fasterrcnn_resnet50_fpn(weights=None)  # Or your custom model class
            state_dict = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
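
For reference, the matching save-side convention is standard PyTorch and independent of this PR:

# Persist only the parameters; loading then no longer depends on the
# original class definition or module layout.
torch.save(model.state_dict(), "logo_classifier.pt")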


        # Normalize the mask to be in the range [0, 1]
        if mask.max() > 0:
            mask = mask / mask.max()

medium

The expression mask = mask / mask.max() would produce NaNs if mask.max() were 0. The surrounding if mask.max() > 0 check already prevents that exact case, but adding a small epsilon to the denominator makes the normalization robust even if the guard is later refactored away.

Suggested change
            mask = mask / mask.max()
            mask = mask / (mask.max() + 1e-8)

import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.transforms import GaussianBlur

medium

The import from torchvision.transforms import GaussianBlur is not used. F.gaussian_blur is used instead, which comes from torchvision.transforms.functional. This unused import should be removed.
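
For contrast, the two APIs side by side (standard torchvision; the tensor here is illustrative):

import torch
import torchvision.transforms.functional as F
from torchvision.transforms import GaussianBlur

mask = torch.rand(1, 1, 64, 64)
# Class-based transform (the import this file does not use):
blurred_cls = GaussianBlur(kernel_size=15, sigma=7.0)(mask)
# Functional form, which is what F.gaussian_blur refers to here:
blurred_fn = F.gaussian_blur(mask, kernel_size=[15, 15], sigma=[7.0, 7.0])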

Comment on lines +8 to +24
LOGO_CLASSIFIER_PATH = "path/to/your/logo_classifier.pt"
# This is a placeholder path. Please replace it with the actual path to your logo detector model.
LOGO_DETECTOR_PATH = "path/to/your/logo_detector.pt"

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

# Load the classifier
pipe.load_classifier(LOGO_CLASSIFIER_PATH)

medium

This example script uses placeholder paths for the logo classifier and detector models. If a user runs this script without replacing the placeholders with valid paths, it will crash with a FileNotFoundError. For a better user experience, the script should check if these files exist before attempting to load them and print a helpful message if they are not found.

Suggested change

import os
LOGO_CLASSIFIER_PATH = "path/to/your/logo_classifier.pt"
# This is a placeholder path. Please replace it with the actual path to your logo detector model.
LOGO_DETECTOR_PATH = "path/to/your/logo_detector.pt"

if not os.path.exists(LOGO_CLASSIFIER_PATH) or not os.path.exists(LOGO_DETECTOR_PATH):
    print("Please replace the placeholder paths for LOGO_CLASSIFIER_PATH and LOGO_DETECTOR_PATH")
    print("in the script with actual paths to your models.")
    exit()

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

# Load the classifier
pipe.load_classifier(LOGO_CLASSIFIER_PATH)
