I have created the demonstration script `examples/wanvideo/model_inference/Wan2.2-TI2V-5B-logo-guidance.py` to showcase the new classifier guidance feature. #867
Conversation
I have created the demonstration script `examples/wanvideo/model_inference/Wan2.2-TI2V-5B-logo-guidance.py` to showcase the new classifier guidance feature.
Summary of Changes
Hello @MaulanaYusufIkhsanRobbani, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a new classifier guidance feature to the `WanVideoPipeline`, enabling the generation of videos that are influenced by a pre-trained classifier. This allows for fine-grained control over the generated content, such as ensuring the presence or absence of specific objects or characteristics. The changes include the addition of new modules for classifier and logo detection, integration of these modules into the video generation pipeline, and a demonstration script showcasing logo guidance in image-to-video generation.
Highlights
- New Classifier Module: A new `OpenLogoClassifier` module has been added, designed to load pre-trained object detection models (specifically Faster R-CNN with VGG16) and compute gradients of classification probabilities with respect to the video latents. This gradient is crucial for steering the video generation process.
- New Logo Detector Module: A `LogoDetector` module has been introduced, which utilizes a Faster R-CNN model (ResNet50 FPN) to detect logos within an image and generate a soft mask. This mask can be used to apply classifier guidance selectively to specific regions of the generated video.
- Pipeline Integration of Classifier Guidance: The core `WanVideoPipeline` has been updated to incorporate classifier guidance. This includes new parameters in the `__call__` method for controlling the guidance strength, specifying the target class ID, and providing an optional logo mask. The guidance is applied directly to the noise prediction during the denoising steps.
- Dynamic Guidance Strength Scheduling: A `get_guidance_strength` method was added to the pipeline, allowing for dynamic adjustment of the classifier guidance strength over the inference steps. This enables strategies like 'late step emphasis', where guidance is applied more strongly towards the end of the generation process (see the sketch after this list).
- Scheduler Enhancements for Guidance: The `FlowMatchScheduler` now includes `get_sigma_t` and `get_alpha_t` methods, which provide the values (sigma and alpha at a given timestep) required for the mathematical formulation of the classifier guidance applied within the pipeline.
- New Demonstration Script: A new example script, `examples/wanvideo/model_inference/Wan2.2-TI2V-5B-logo-guidance.py`, has been added. This script demonstrates how to use the new classifier guidance feature for logo guidance in an image-to-video generation scenario, providing a practical usage example.
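To make the two mechanisms above concrete, here is a minimal sketch of a 'late step emphasis' strength schedule and of the classifier-guidance update it feeds into. The function names and signatures are illustrative assumptions for this summary, not the PR's exact API:

```python
def get_guidance_strength(base_strength, step, num_steps, schedule="late_step_emphasis"):
    # Hypothetical schedule: keep guidance weak early and ramp it up late.
    progress = step / max(num_steps - 1, 1)  # 0.0 at the first step, 1.0 at the last
    if schedule == "late_step_emphasis":
        return base_strength * progress ** 2
    return base_strength  # constant fallback

def apply_classifier_guidance(noise_pred, grad_log_prob, strength, sigma_t):
    # Standard classifier-guidance form: shift the noise prediction along the
    # gradient of log p(class | latents), scaled by sigma at the current timestep
    # (the value the scheduler's get_sigma_t would provide).
    return noise_pred - strength * sigma_t * grad_log_prob
```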
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
| --- | --- | --- |
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request introduces a significant new feature: classifier guidance for logo generation. It adds an `OpenLogoClassifier` for gradient computation and a `LogoDetector` for creating logo masks, which are then integrated into the `WanVideoPipeline`. The changes are well-structured and the addition of a demonstration script is appreciated. My review focuses on correctness, performance, and best practices. I've identified a critical issue regarding model evaluation mode in the classifier, a high-severity performance issue with model loading in the pipeline, and several medium-severity issues related to deprecated APIs, code robustness, and unused imports.
self.model.train()

image_for_model = image_processed
features = self.model.backbone(image_for_model)
proposals, proposal_losses = self.model.rpn(image_for_model, features, None)
box_features = self.model.roi_heads.box_roi_pool(features, proposals, [image_for_model.shape[2:]])
box_features = self.model.roi_heads.box_head(box_features)
class_logits, box_regression = self.model.roi_heads.box_predictor(box_features)

log_probs = torch.nn.functional.log_softmax(class_logits, dim=-1)
log_prob_target = log_probs[:, class_id].mean()
log_prob_sum = log_prob_sum + log_prob_target

grad = torch.autograd.grad(log_prob_sum, latents, grad_outputs=torch.ones_like(log_prob_sum))[0]

# Apply the soft logo mask if provided
if mask is not None:
    grad = grad * mask

# Clamp the gradient to avoid artifacts
grad = torch.clamp(grad, -0.1, 0.1)

self.model.eval()
Calling `self.model.train()` inside the `compute_gradient` method is incorrect. This will put the model, including its `BatchNorm` layers, into training mode. During guidance, the model should behave as it does during inference, which means `BatchNorm` layers should use their learned running statistics, not the statistics of the current mini-batch. Calling `.train()` will cause them to use batch statistics and update their running statistics, which is not desired and will lead to incorrect gradients and potentially artifacts in the generated video. The model is already set to `eval()` mode in the `__init__` method, so these calls are not only incorrect but also unnecessary.
# Get predictions from the classifier
image_for_model = image_processed
features = self.model.backbone(image_for_model)
proposals, proposal_losses = self.model.rpn(image_for_model, features, None)
box_features = self.model.roi_heads.box_roi_pool(features, proposals, [image_for_model.shape[2:]])
box_features = self.model.roi_heads.box_head(box_features)
class_logits, box_regression = self.model.roi_heads.box_predictor(box_features)
log_probs = torch.nn.functional.log_softmax(class_logits, dim=-1)
log_prob_target = log_probs[:, class_id].mean()
log_prob_sum = log_prob_sum + log_prob_target
grad = torch.autograd.grad(log_prob_sum, latents, grad_outputs=torch.ones_like(log_prob_sum))[0]

# Apply the soft logo mask if provided
if mask is not None:
    grad = grad * mask

# Clamp the gradient to avoid artifacts
grad = torch.clamp(grad, -0.1, 0.1)
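To make the distinction concrete, here is a minimal self-contained check (not from the PR) showing that `eval()` keeps `BatchNorm` on its learned running statistics while gradients still flow to the input, which is exactly what guidance needs:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
bn.eval()  # use running statistics; do not update them

x = torch.randn(1, 3, 8, 8, requires_grad=True)
bn(x).sum().backward()

print(x.grad is not None)  # True: gradients w.r.t. the input still flow in eval mode
print(bn.running_mean)     # unchanged by the forward pass above
```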
@@ -433,6 +459,14 @@ def __call__(
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Logo detection
        if logo_detector_path is not None:
            logo_detector = LogoDetector(logo_detector_path, device=self.device)
The `LogoDetector` is instantiated inside the `__call__` method. This means the logo detection model will be loaded from disk every time the pipeline is called, which is very inefficient and will significantly slow down repeated inferences. The detector should be loaded once and stored as a pipeline attribute, similar to how `self.classifier` is handled. You should add a `load_logo_detector` method to the pipeline and use a `self.logo_detector` attribute.
-            logo_detector = LogoDetector(logo_detector_path, device=self.device)
+            logo_detector = self.logo_detector
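A minimal sketch of the suggested caching pattern, assuming the `LogoDetector` constructor signature used in this PR; the exact placement inside the pipeline class is illustrative:

```python
def load_logo_detector(self, logo_detector_path):
    # Load the detector once and cache it on the pipeline, so repeated
    # __call__ invocations reuse the same model instead of reloading from disk.
    self.logo_detector = LogoDetector(logo_detector_path, device=self.device)
```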
import torchvision.transforms.functional as F
from torchvision.models.detection import fasterrcnn_vgg16_bn
from torchvision.ops import nms
The imports `torchvision.transforms.functional as F` and `from torchvision.ops import nms` are not used in this file. They should be removed to keep the code clean.
-import torchvision.transforms.functional as F
 from torchvision.models.detection import fasterrcnn_vgg16_bn
-from torchvision.ops import nms
    self.model = torch.load(model_path, map_location=self.device)
else:
    # Load a pretrained Faster R-CNN model from torchvision
    self.model = fasterrcnn_resnet50_fpn(pretrained=True)
The `pretrained` argument for torchvision models is deprecated and will be removed in a future version. You should use the `weights` argument instead for better future compatibility.
-self.model = fasterrcnn_resnet50_fpn(pretrained=True)
+self.model = fasterrcnn_resnet50_fpn(weights="FasterRCNN_ResNet50_FPN_Weights.DEFAULT")
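Note that the weights enum must match the builder: `fasterrcnn_resnet50_fpn` pairs with `FasterRCNN_ResNet50_FPN_Weights`, while the V2 weights belong to the separate `fasterrcnn_resnet50_fpn_v2` builder. The same thing written with the enum directly, which torchvision also accepts:

```python
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

# Equivalent to weights="FasterRCNN_ResNet50_FPN_Weights.DEFAULT":
model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
```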
if model_path:
    # Load a custom model if a path is provided
    # For now, we assume the model is a Faster R-CNN model
    self.model = torch.load(model_path, map_location=self.device)
Using `torch.load(model_path)` to load an entire model is not recommended because the serialized data is bound to the specific classes and directory structure used when the model was saved. The recommended approach is to save and load only the model's `state_dict`. This makes the code more robust to refactoring.
-self.model = torch.load(model_path, map_location=self.device)
+self.model = fasterrcnn_resnet50_fpn(weights=None)  # or your custom model class
+state_dict = torch.load(model_path, map_location=self.device)
+self.model.load_state_dict(state_dict)
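For completeness, the save-side counterpart of this pattern (standard PyTorch, not part of the PR) would be:

```python
# Save only the learnable parameters and buffers, not the pickled class:
torch.save(model.state_dict(), "logo_detector.pt")
```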
# Normalize the mask to be in the range [0, 1]
if mask.max() > 0:
    mask = mask / mask.max()
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.transforms import GaussianBlur
LOGO_CLASSIFIER_PATH = "path/to/your/logo_classifier.pt"
# This is a placeholder path. Please replace it with the actual path to your logo detector model.
LOGO_DETECTOR_PATH = "path/to/your/logo_detector.pt"

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

# Load the classifier
pipe.load_classifier(LOGO_CLASSIFIER_PATH)
This example script uses placeholder paths for the logo classifier and detector models. If a user runs this script without replacing the placeholders with valid paths, it will crash with a `FileNotFoundError`. For a better user experience, the script should check if these files exist before attempting to load them and print a helpful message if they are not found.
import os

LOGO_CLASSIFIER_PATH = "path/to/your/logo_classifier.pt"
# This is a placeholder path. Please replace it with the actual path to your logo detector model.
LOGO_DETECTOR_PATH = "path/to/your/logo_detector.pt"

if not os.path.exists(LOGO_CLASSIFIER_PATH) or not os.path.exists(LOGO_DETECTOR_PATH):
    print("Please replace the placeholder paths for LOGO_CLASSIFIER_PATH and LOGO_DETECTOR_PATH")
    print("in the script with actual paths to your models.")
    exit()

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

# Load the classifier
pipe.load_classifier(LOGO_CLASSIFIER_PATH)
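Downstream of this setup, a call using the new guidance controls might look like the following. The parameter names here are illustrative assumptions based on the summary above (guidance strength, target class ID, optional logo mask); only `logo_detector_path` is confirmed by the diff:

```python
# Hypothetical parameter names for illustration only.
video = pipe(
    prompt="a product shot of a mug with the brand logo clearly visible",
    classifier_guidance_strength=1.0,       # overall strength of the classifier guidance
    classifier_class_id=1,                  # target class in the logo classifier
    logo_detector_path=LOGO_DETECTOR_PATH,  # used to build the soft logo mask
    num_inference_steps=50,
)
```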