diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index fd9ccaa5c..0a3bfe4d6 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -29,6 +29,7 @@ th { | `OvisImagePipeline` | Ovis-Image | `OvisAI/Ovis-Image` | |`LongcatImagePipeline` | LongCat-Image | `meituan-longcat/LongCat-Image` | |`LongCatImageEditPipeline` | LongCat-Image-Edit | `meituan-longcat/LongCat-Image-Edit` | +|`StableDiffusion3Pipeline` | Stable-Diffusion-3 | `stabilityai/stable-diffusion-3.5-medium` | ## List of Supported Models for NPU diff --git a/vllm_omni/diffusion/models/sd3/__init__.py b/vllm_omni/diffusion/models/sd3/__init__.py new file mode 100644 index 000000000..c7efafd94 --- /dev/null +++ b/vllm_omni/diffusion/models/sd3/__init__.py @@ -0,0 +1,15 @@ +"""Stable diffusion3 model components.""" + +from vllm_omni.diffusion.models.sd3.pipeline_sd3 import ( + StableDiffusion3Pipeline, + get_sd3_image_post_process_func, +) +from vllm_omni.diffusion.models.sd3.sd3_transformer import ( + SD3Transformer2DModel, +) + +__all__ = [ + "StableDiffusion3Pipeline", + "SD3Transformer2DModel", + "get_sd3_image_post_process_func", +] diff --git a/vllm_omni/diffusion/models/sd3/pipeline_sd3.py b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py new file mode 100644 index 000000000..2ea6b0c5f --- /dev/null +++ b/vllm_omni/diffusion/models/sd3/pipeline_sd3.py @@ -0,0 +1,666 @@ +import inspect +import json +import logging +import os +from collections.abc import Iterable + +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5Tokenizer +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.sd3.sd3_transformer import ( + SD3Transformer2DModel, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) + +logger = logging.getLogger(__name__) + + +def get_sd3_image_post_process_func( + od_config: OmniDiffusionConfig, +): + if od_config.output_type == "latent": + return lambda x: x + model_name = od_config.model + if os.path.exists(model_name): + model_path = model_name + else: + model_path = download_weights_from_hf_specific(model_name, None, ["*"]) + vae_config_path = os.path.join(model_path, "vae/config.json") + with open(vae_config_path) as f: + vae_config = json.load(f) + vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) if "block_out_channels" in vae_config else 8 + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + + def post_process_func( + images: torch.Tensor, + ): + return image_processor.postprocess(images) + + return post_process_func + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + 
b + return mu + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +) -> tuple[torch.Tensor, int]: + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3Pipeline( + nn.Module, +): + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="transformer", + revision=None, + prefix="transformer.", + fall_back_to_pt=True, + ) + ] + + self.device = get_local_device() + model = od_config.model + # Check if model is a local path + local_files_only = os.path.exists(model) + + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model, subfolder="scheduler", local_files_only=local_files_only + ) + self.tokenizer = CLIPTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) + self.tokenizer_2 = CLIPTokenizer.from_pretrained( + model, subfolder="tokenizer_2", local_files_only=local_files_only + ) + self.tokenizer_3 = T5Tokenizer.from_pretrained( + model, subfolder="tokenizer_3", local_files_only=local_files_only + ) + self.text_encoder = CLIPTextModelWithProjection.from_pretrained( + model, subfolder="text_encoder", local_files_only=local_files_only + ) + self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( + model, subfolder="text_encoder_2", local_files_only=local_files_only + ) + self.text_encoder_3 = T5EncoderModel.from_pretrained( + model, + subfolder="text_encoder_3", + local_files_only=local_files_only, + ) + self.transformer = SD3Transformer2DModel(od_config=od_config) + + self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( + self.device + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + self.patch_size = 2 + self.output_type = self.od_config.output_type + + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + max_sequence_length=None, + ): + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by " + f"{self.vae_scale_factor * self.patch_size} but are " + f"{height} and {width}. You can use height " + f"{height - height % (self.vae_scale_factor * self.patch_size)} " + f"and width {width - width % (self.vae_scale_factor * self.patch_size)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." 
+            )
+        elif prompt_3 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+        elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
+            raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+ ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + def _get_clip_prompt_embeds( + self, + prompt: str | list[str] = "", + num_images_per_prompt: int = 1, + dtype: torch.dtype | None = None, + clip_model_index: int = 0, + ): + dtype = dtype or self.text_encoder.dtype + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(self.device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.hidden_states[-2] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=self.device) + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = "", + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + dtype: torch.dtype | None = None, + ): + dtype = dtype or self.text_encoder_3.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size, + max_sequence_length, + self.transformer.joint_attention_dim, + ), + device=self.device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).to(self.device) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(self.device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device) + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + 
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str], + prompt_3: str | list[str], + prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 256, + num_images_per_prompt: int = 1, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + + prompt = [prompt] if isinstance(prompt, str) else prompt + + pooled_prompt_embeds = None + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + return prompt_embeds, pooled_prompt_embeds + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + def prepare_timesteps(self, num_inference_steps, sigmas, image_seq_len): + scheduler_kwargs = {} + if self.scheduler.config.get("use_dynamic_shifting", None): + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), + ) + scheduler_kwargs["mu"] = mu + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + sigmas=sigmas, + **scheduler_kwargs, + ) + return timesteps, num_inference_steps + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def diffuse( + self, + prompt_embeds, + pooled_prompt_embeds, + negative_prompt_embeds, + negative_pooled_prompt_embeds, + latents, + timesteps, + do_cfg, + ): + self.scheduler.set_begin_index(0) + for _, t in enumerate(timesteps): + if self.interrupt: + continue + self._current_timestep = t + + # Broadcast timestep to match batch size + timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype) + + transformer_kwargs = { + "hidden_states": latents, + "timestep": timestep, + "encoder_hidden_states": prompt_embeds, + "pooled_projections": pooled_prompt_embeds, + "return_dict": False, + } + + noise_pred = self.transformer(**transformer_kwargs)[0] + + if do_cfg: + neg_transformer_kwargs = { + "hidden_states": latents, + "timestep": timestep, + "encoder_hidden_states": negative_prompt_embeds, + "pooled_projections": negative_pooled_prompt_embeds, + "return_dict": False, + } + + neg_noise_pred = self.transformer(**neg_transformer_kwargs)[0] + noise_pred = neg_noise_pred + self.guidance_scale * (noise_pred - neg_noise_pred) + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + return latents + + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] = "", + prompt_2: str | list[str] = "", + prompt_3: str | list[str] = "", + negative_prompt: str | list[str] = "", + negative_prompt_2: str | list[str] = "", + negative_prompt_3: str | list[str] = "", + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 256, + ) -> DiffusionOutput: + # # TODO: only support single prompt now + # if req.prompt is not None: + # prompt = req.prompt[0] if isinstance(req.prompt, list) else req.prompt + prompt = req.prompt if req.prompt is not None else prompt + negative_prompt = req.negative_prompt if req.negative_prompt is not None else negative_prompt + height = req.height or self.default_sample_size * self.vae_scale_factor + width = req.width or self.default_sample_size * self.vae_scale_factor + width = req.width or self.default_sample_size * 
self.vae_scale_factor + sigmas = req.sigmas or sigmas + num_inference_steps = req.num_inference_steps or num_inference_steps + generator = req.generator or generator + req_num_outputs = getattr(req, "num_outputs_per_prompt", None) + if req_num_outputs and req_num_outputs > 0: + num_images_per_prompt = req_num_outputs + # 1. check inputs + # 2. encode prompts + # 3. prepare latents and timesteps + # 4. diffusion process + # 5. decode latents + # 6. post-process outputs + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = req.guidance_scale + self._current_timestep = None + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + prompt_embeds, pooled_prompt_embeds = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + do_cfg = self.guidance_scale > 1 + if do_cfg: + negative_prompt_embeds, negative_pooled_prompt_embeds = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_3=negative_prompt_3, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + num_channels_latents = self.transformer.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + self.device, + generator, + latents, + ) + + timesteps, num_inference_steps = self.prepare_timesteps(num_inference_steps, sigmas, latents.shape[1]) + self._num_timesteps = len(timesteps) + + latents = self.diffuse( + prompt_embeds, + pooled_prompt_embeds, + negative_prompt_embeds, + negative_pooled_prompt_embeds, + latents, + timesteps, + do_cfg, + ) + + self._current_timestep = None + if self.output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + return DiffusionOutput(output=image) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/models/sd3/sd3_transformer.py b/vllm_omni/diffusion/models/sd3/sd3_transformer.py new file mode 100644 index 000000000..6a1780af8 --- /dev/null +++ b/vllm_omni/diffusion/models/sd3/sd3_transformer.py @@ -0,0 +1,471 @@ +from collections.abc import Iterable + +import torch +import torch.nn as nn +from diffusers.models.attention import FeedForward + +# TODO replace this with vLLM implementation +from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from 
vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.data import OmniDiffusionConfig + +logger = init_logger(__name__) + + +class SD3PatchEmbed(nn.Module): + """ + 2D Image to Patch Embedding with support for SD3. + + Args: + patch_size (`int`, defaults to `16`): The size of the patches. + in_channels (`int`, defaults to `3`): The number of input channels. + embed_dim (`int`, defaults to `768`): The output dimension of the embedding. + """ + + def __init__( + self, + patch_size=16, + in_channels=3, + embed_dim=768, + ): + super().__init__() + + self.patch_size = patch_size + self.embed_dim = embed_dim + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True + ) + + def forward(self, latent): + x = self.proj(latent) # [B, embed_dim, patch_size, patch_size] + x = x.flatten(2).transpose(1, 2) # [B, num_patches, embed_dim] + return x + + +class SD3CrossAttention(nn.Module): + def __init__( + self, + dim: int, # query_dim + num_heads: int, + head_dim: int, + added_kv_proj_dim: int = 0, + out_bias: bool = True, + qk_norm=True, # rmsnorm + eps=1e-6, + pre_only=False, + context_pre_only: bool = False, + parallel_attention=False, + out_dim: int = 0, + ) -> None: + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qk_norm = qk_norm + self.eps = eps + self.parallel_attention = parallel_attention + + self.to_qkv = QKVParallelLinear( + hidden_size=dim, + head_size=self.head_dim, + total_num_heads=num_heads, + disable_tp=True, + ) + self.norm_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + self.inner_dim = out_dim if out_dim is not None else head_dim * num_heads + self.inner_kv_dim = self.inner_dim + if added_kv_proj_dim is not None: + self.add_kv_proj = QKVParallelLinear( + added_kv_proj_dim, + head_size=self.inner_kv_dim // self.num_heads, + total_num_heads=self.num_heads, + disable_tp=True, + ) + + if not context_pre_only: + self.to_add_out = ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias) + else: + self.to_add_out = None + + if not pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias)) + else: + self.to_out = None + + self.norm_added_q = RMSNorm(head_dim, eps=eps) + self.norm_added_k = RMSNorm(head_dim, eps=eps) + + self.attn = Attention( + num_heads=num_heads, + head_size=self.head_dim, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + ): + # Compute QKV for image stream (sample projections) + qkv, _ = self.to_qkv(hidden_states) + img_query, img_key, img_value = qkv.chunk(3, dim=-1) + + # Reshape for multi-head attention + img_query = img_query.unflatten(-1, (self.num_heads, -1)) + img_key = img_key.unflatten(-1, (self.num_heads, -1)) + img_value = img_value.unflatten(-1, (self.num_heads, -1)) + + # Apply QK normalization + img_query = self.norm_q(img_query) + img_key = self.norm_k(img_key) + + if encoder_hidden_states is not None: + # Compute QKV for text stream (context projections) + qkv, _ = self.add_kv_proj(encoder_hidden_states) + txt_query, txt_key, txt_value = qkv.chunk(3, dim=-1) + + txt_query = txt_query.unflatten(-1, (self.num_heads, -1)) + txt_key = txt_key.unflatten(-1, (self.num_heads, -1)) + txt_value = 
txt_value.unflatten(-1, (self.num_heads, -1)) + + txt_query = self.norm_added_q(txt_query) + txt_key = self.norm_added_k(txt_key) + + # Concatenate for joint attention + # Order: [text, image] + query = torch.cat([txt_query, img_query], dim=1) + key = torch.cat([txt_key, img_key], dim=1) + value = torch.cat([txt_value, img_value], dim=1) + else: + query = img_query + key = img_key + value = img_value + + hidden_states = self.attn( + query, + key, + value, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + # Split attention outputs back + context_seqlen = encoder_hidden_states.shape[1] + hidden_states, encoder_hidden_states = ( + hidden_states[:, context_seqlen:, :], # Image part + hidden_states[:, :context_seqlen, :], # Text part + ) + if self.to_add_out is not None: + encoder_hidden_states, _ = self.to_add_out(encoder_hidden_states) + + # Apply output projections + if self.to_out is not None: + hidden_states, _ = self.to_out[0](hidden_states) + + if encoder_hidden_states is None: + return hidden_states + else: + return hidden_states, encoder_hidden_states + + +class SD3TransformerBlock(nn.Module): + r""" + A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. + + Reference: https://huggingface.co/papers/2403.03206 + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the + processing of `context` conditions. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + context_pre_only: bool = False, + qk_norm: str | None = None, + use_dual_attention: bool = False, + ): + super().__init__() + + self.use_dual_attention = use_dual_attention + self.context_pre_only = context_pre_only + context_norm_type = "ada_norm_continuous" if context_pre_only else "ada_norm_zero" + + if use_dual_attention: + self.norm1 = SD35AdaLayerNormZeroX(dim) + else: + self.norm1 = AdaLayerNormZero(dim) + + if context_norm_type == "ada_norm_continuous": + self.norm1_context = AdaLayerNormContinuous( + dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm" + ) + elif context_norm_type == "ada_norm_zero": + self.norm1_context = AdaLayerNormZero(dim) + else: + raise ValueError( + f"Unknown context_norm_type: {context_norm_type}, currently " + f"only support `ada_norm_continuous`, `ada_norm_zero`" + ) + + self.attn = SD3CrossAttention( + dim=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + added_kv_proj_dim=dim, + context_pre_only=context_pre_only, + out_dim=dim, + qk_norm=True if qk_norm == "rms_norm" else False, + eps=1e-6, + ) + + if use_dual_attention: + self.attn2 = SD3CrossAttention( + dim=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + out_dim=dim, + qk_norm=True if qk_norm == "rms_norm" else False, + eps=1e-6, + ) + else: + self.attn2 = None + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + if not context_pre_only: + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + else: + self.norm2_context = None + 
self.ff_context = None + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.use_dual_attention: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( + hidden_states, emb=temb + ) + else: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + if self.context_pre_only: + norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) + else: + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # Attention. + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + ) + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + if self.use_dual_attention: + attn_output2 = self.attn2(hidden_states=norm_hidden_states2) + attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 + hidden_states = hidden_states + attn_output2 + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + + # Process attention outputs for the `encoder_hidden_states`. + if self.context_pre_only: + encoder_hidden_states = None + else: + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return encoder_hidden_states, hidden_states + + +class SD3Transformer2DModel(nn.Module): + """ + The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206). 
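+
+    Parameters:
+        od_config (`OmniDiffusionConfig`):
+            Omni diffusion configuration. Its `tf_model_config` supplies the transformer hyperparameters
+            (`num_layers`, `num_attention_heads`, `attention_head_dim`, `in_channels`, `out_channels`,
+            `patch_size`, `pos_embed_max_size`, `dual_attention_layers`, `qk_norm`, ...), and its
+            `parallel_config` carries the parallelism settings.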
+ """ + + def __init__( + self, + od_config: OmniDiffusionConfig, + ): + super().__init__() + model_config = od_config.tf_model_config + self.num_layers = model_config.num_layers + self.parallel_config = od_config.parallel_config + self.sample_size = model_config.sample_size + self.in_channels = model_config.in_channels + self.out_channels = model_config.out_channels + self.num_attention_heads = model_config.num_attention_heads + self.attention_head_dim = model_config.attention_head_dim + self.inner_dim = model_config.num_attention_heads * model_config.attention_head_dim + self.caption_projection_dim = model_config.caption_projection_dim + self.pooled_projection_dim = model_config.pooled_projection_dim + self.joint_attention_dim = model_config.joint_attention_dim + self.patch_size = model_config.patch_size + self.dual_attention_layers = model_config.dual_attention_layers + self.qk_norm = model_config.qk_norm + self.pos_embed_max_size = model_config.pos_embed_max_size + + self.pos_embed = PatchEmbed( + height=self.sample_size, + width=self.sample_size, + patch_size=self.patch_size, + in_channels=self.in_channels, + embed_dim=self.inner_dim, + pos_embed_max_size=self.pos_embed_max_size, + ) + + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=self.inner_dim, pooled_projection_dim=self.pooled_projection_dim + ) + self.context_embedder = nn.Linear(self.joint_attention_dim, self.caption_projection_dim) + + self.transformer_blocks = nn.ModuleList( + [ + SD3TransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.num_attention_heads, + attention_head_dim=self.attention_head_dim, + context_pre_only=i == self.num_layers - 1, + qk_norm=self.qk_norm, + use_dual_attention=True if i in self.dual_attention_layers else False, + ) + for i in range(self.num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, self.patch_size * self.patch_size * self.out_channels, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + pooled_projections: torch.Tensor, + timestep: torch.LongTensor, + return_dict: bool = True, + ) -> torch.Tensor | Transformer2DModelOutput: + """ + The [`SD3Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): + Embeddings projected from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. 
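+            The sample tensor has shape `(batch_size, out_channels, height, width)`, where `height` and
+            `width` match the spatial size of the input latents.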
+ """ + + height, width = hidden_states.shape[-2:] + + hidden_states = self.pos_embed(hidden_states) + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + ) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # unpatchify + patch_size = self.patch_size + height = height // patch_size + width = width // patch_size + + hidden_states = hidden_states.reshape( + shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) + ) + + return Transformer2DModelOutput(sample=output) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + # self-attn + (".to_qkv", ".to_q", "q"), + (".to_qkv", ".to_k", "k"), + (".to_qkv", ".to_v", "v"), + # cross-attn + (".add_kv_proj", ".add_q_proj", "q"), + (".add_kv_proj", ".add_k_proj", "k"), + (".add_kv_proj", ".add_v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + + for name, buffer in self.named_buffers(): + if name.endswith(".pos_embed"): + params_dict[name] = buffer + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 4d1cda1fd..87f674e9f 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -59,6 +59,11 @@ "pipeline_longcat_image_edit", "LongcatImageEditPipeline", ), + "StableDiffusion3Pipeline": ( + "sd3", + "pipeline_sd3", + "StableDiffusion3Pipeline", + ), } @@ -102,6 +107,7 @@ def initialize_model( "WanImageToVideoPipeline": "get_wan22_i2v_post_process_func", "LongCatImagePipeline": "get_longcat_image_post_process_func", "LongCatImageEditPipeline": "get_longcat_image_post_process_func", + "StableDiffusion3Pipeline": "get_sd3_image_post_process_func", } _DIFFUSION_PRE_PROCESS_FUNCS = {
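
A minimal, hypothetical usage sketch for the new pipeline, based only on the interfaces visible in this diff: `StableDiffusion3Pipeline(od_config=...)`, `forward(req, ...)` driven by an `OmniDiffusionRequest`, and `get_sd3_image_post_process_func` for output post-processing. The keyword arguments passed to `OmniDiffusionConfig` and `OmniDiffusionRequest` below mirror the attributes the pipeline reads (`model`, `output_type`, `prompt`, `negative_prompt`, `height`, `width`, `num_inference_steps`, `guidance_scale`, `generator`); their actual constructors live outside this diff, so treat those calls as assumptions rather than the documented API.

```python
import torch

from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.models.sd3 import (
    StableDiffusion3Pipeline,
    get_sd3_image_post_process_func,
)
from vllm_omni.diffusion.request import OmniDiffusionRequest

# Hypothetical constructor calls: the field names are the attributes the
# pipeline reads, not a documented constructor signature.
od_config = OmniDiffusionConfig(
    model="stabilityai/stable-diffusion-3.5-medium",
    output_type="pil",  # anything other than "latent" goes through VAE decode + postprocess
)

pipe = StableDiffusion3Pipeline(od_config=od_config)
# In the engine, transformer weights are streamed in via pipe.load_weights()
# using the DiffusersPipelineLoader component sources declared in __init__;
# that step is omitted here.

req = OmniDiffusionRequest(
    prompt="a photo of an astronaut riding a horse on the moon",
    negative_prompt="",
    height=1024,       # must be divisible by vae_scale_factor * patch_size (16 for SD3.5)
    width=1024,
    num_inference_steps=28,
    guidance_scale=7.0,  # > 1 enables classifier-free guidance in diffuse()
    generator=torch.Generator().manual_seed(0),
)

out = pipe(req)  # DiffusionOutput; .output holds decoded images (or latents)
post_process = get_sd3_image_post_process_func(od_config)
images = post_process(out.output)
```

This is the same flow the registry change wires up: `initialize_model` resolves `"StableDiffusion3Pipeline"` to this class and pairs it with `get_sd3_image_post_process_func` for output post-processing.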