diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 0a3bfe4d6..ecf3afbb2 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -19,6 +19,7 @@ th { |--------------|--------|-------------------| | `Qwen3OmniMoeForConditionalGeneration` | Qwen3-Omni | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | | `Qwen2_5OmniForConditionalGeneration` | Qwen2.5-Omni | `Qwen/Qwen2.5-Omni-7B`, `Qwen/Qwen2.5-Omni-3B` | +| `BagelForConditionalGeneration` | BAGEL (DiT-only) | `ByteDance-Seed/BAGEL-7B-MoT` | | `QwenImagePipeline` | Qwen-Image | `Qwen/Qwen-Image` | | `QwenImageEditPipeline` | Qwen-Image-Edit | `Qwen/Qwen-Image-Edit` | | `QwenImageEditPlusPipeline` | Qwen-Image-Edit-2509 | `Qwen/Qwen-Image-Edit-2509` | diff --git a/pyproject.toml b/pyproject.toml index 79541c557..79d15672d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -166,5 +166,6 @@ extend-ignore-identifiers-re = [ ".*MoBA", ".*temperal_downsample", ".*nothink.*", - ".*NOTHINK.*" + ".*NOTHINK.*", + ".*nin.*", ] diff --git a/vllm_omni/diffusion/models/bagel/__init__.py b/vllm_omni/diffusion/models/bagel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vllm_omni/diffusion/models/bagel/autoencoder.py b/vllm_omni/diffusion/models/bagel/autoencoder.py new file mode 100644 index 000000000..0980f25cd --- /dev/null +++ b/vllm_omni/diffusion/models/bagel/autoencoder.py @@ -0,0 +1,324 @@ +# Copyright (c) 2024 Black Forest Labs. +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +# +# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20. +# +# Original file was released under Apache-2.0, with the full license text +# available at https://github.com/black-forest-labs/flux/blob/main/LICENSE. +# +# This modified file is released under the same license. 
+ +from dataclasses import dataclass + +import torch +from einops import rearrange +from torch import Tensor, nn + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + downsample: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, 
padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # 
upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py new file mode 100644 index 000000000..7670248c0 --- /dev/null +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -0,0 +1,893 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates. +# Copyright (c) 2024 The Qwen Team and The HuggingFace Inc. team. +# SPDX-License-Identifier: Apache-2.0 +# +# This file has been modified by ByteDance Ltd. and/or its affiliates. +# +# Original file was released under Apache-2.0, with the full license text +# available at https://github.com/huggingface/transformers/blob/main/LICENSE. + +import math +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch +from torch import nn +from torch.nn.attention.flex_attention import flex_attention +from transformers.configuration_utils import PretrainedConfig +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config as _Qwen2Config +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2Attention, + Qwen2MLP, + Qwen2PreTrainedModel, + Qwen2RMSNorm, + Qwen2RotaryEmbedding, +) +from transformers.utils import ModelOutput +from vllm.vllm_flash_attn import flash_attn_varlen_func + +from vllm_omni.diffusion.layers.rope import RotaryEmbedding + +torch._dynamo.config.cache_size_limit = 512 +torch._dynamo.config.accumulated_cache_size_limit = 4096 +flex_attention = torch.compile(flex_attention) + + +class Qwen2MoTConfig(_Qwen2Config): + """Configuration for Qwen2MoT (Mixture of Tokens) model. + + This is fundamentally different from Qwen2, hence the distinct name. 
+ """ + + model_type = "qwen2_mot" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + is_causal=True, + _attn_implementation="eager", + qk_norm=True, + layer_module="Qwen2MoTDecoderLayer", + freeze_und=False, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + initializer_range=initializer_range, + rms_norm_eps=rms_norm_eps, + use_cache=use_cache, + tie_word_embeddings=tie_word_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + use_sliding_window=use_sliding_window, + sliding_window=sliding_window, + max_window_layers=max_window_layers, + attention_dropout=attention_dropout, + is_causal=is_causal, + _attn_implementation=_attn_implementation, + **kwargs, + ) + self.qk_norm = qk_norm + self.layer_module = layer_module + + +class NaiveCache: + def __init__(self, num_layers): + self.key_cache = {k: None for k in range(num_layers)} + self.value_cache = {k: None for k in range(num_layers)} + + @property + def num_layers(self): + return len(self.key_cache) + + @property + def seq_lens(self): + if self.key_cache[0] is not None: + return self.key_cache[0].shape[0] + else: + return 0 + + +@dataclass +class BaseNavitOutputWithPast(ModelOutput): + packed_query_sequence: torch.FloatTensor = None + past_key_values: NaiveCache | None = None + + +class PackedAttentionMoT(Qwen2Attention): + def __init__(self, config, layer_idx: int | None = None): + super().__init__(config, layer_idx) + self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // config.num_attention_heads + + head_dim = self.head_dim + self.q_proj_moe_gen = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj_moe_gen = nn.Linear(config.hidden_size, config.num_key_value_heads * head_dim, bias=True) + self.v_proj_moe_gen = nn.Linear(config.hidden_size, config.num_key_value_heads * head_dim, bias=True) + self.o_proj_moe_gen = nn.Linear(config.num_attention_heads * head_dim, config.hidden_size, bias=False) + + self.rotary_op = RotaryEmbedding(is_neox_style=True) + + def forward( + self, + packed_query_sequence: torch.Tensor, + query_lens: torch.Tensor, + packed_query_position_embeddings: torch.Tensor, + packed_query_indexes: torch.Tensor, + past_key_values: NaiveCache | None = None, + key_values_lens: torch.Tensor | None = None, + packed_key_value_indexes: torch.Tensor | None = None, + update_past_key_values=True, + is_causal=True, + mode="und", + 
packed_vae_token_indexes=None, + packed_text_indexes=None, + ): + if mode == "und": + packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim) + packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim) + packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim) + packed_query_states = self.q_norm(packed_query_states) + packed_key_states = self.k_norm(packed_key_states) + elif mode == "gen": + packed_query_sequence = packed_query_sequence.to(torch.bfloat16) + packed_query_states = packed_query_sequence.new_zeros( + (packed_query_sequence.shape[0], self.num_heads * self.head_dim) + ) + packed_key_states = packed_query_sequence.new_zeros( + (packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim) + ) + packed_value_states = packed_query_sequence.new_zeros( + (packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim) + ) + + packed_text_query_sequence = packed_query_sequence[packed_text_indexes] + packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] + + packed_query_states[packed_text_indexes] = self.q_proj(packed_text_query_sequence) + packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(packed_vae_query_sequence) + + packed_key_states[packed_text_indexes] = self.k_proj(packed_text_query_sequence) + packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(packed_vae_query_sequence) + + packed_value_states[packed_text_indexes] = self.v_proj(packed_text_query_sequence) + packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(packed_vae_query_sequence) + + packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim) + packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim) + packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim) + + packed_query_states = packed_query_states.to(torch.float32) + packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes]) + packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen( + packed_query_states[packed_vae_token_indexes] + ) + + packed_key_states = packed_key_states.to(torch.float32) + packed_key_states[packed_text_indexes] = self.k_norm(packed_key_states[packed_text_indexes]) + packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen( + packed_key_states[packed_vae_token_indexes] + ) + + cos, sin = [x[..., : self.head_dim // 2] for x in packed_query_position_embeddings] + packed_query_states = self.rotary_op(packed_query_states.to(cos.dtype).unsqueeze(0), cos, sin).squeeze(0) + packed_key_states = self.rotary_op(packed_key_states.to(cos.dtype).unsqueeze(0), cos, sin).squeeze(0) + + packed_query_states = packed_query_states.to(torch.bfloat16) + packed_key_states = packed_key_states.to(torch.bfloat16) + packed_value_states = packed_value_states.to(torch.bfloat16) + + if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None: + past_key_states = past_key_values.key_cache[self.layer_idx] + past_value_states = past_key_values.value_cache[self.layer_idx] + + seqlens = sum(query_lens) + sum(key_values_lens) + merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim]) + merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim]) + merged_key_states[packed_query_indexes] = 
packed_key_states + merged_key_states[packed_key_value_indexes] = past_key_states + merged_value_states[packed_query_indexes] = packed_value_states + merged_value_states[packed_key_value_indexes] = past_value_states + key_values_lens = key_values_lens + query_lens + else: + merged_key_states = packed_key_states + merged_value_states = packed_value_states + key_values_lens = query_lens + + cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)) + cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)) + + packed_attn_output = flash_attn_varlen_func( + q=packed_query_states, + k=merged_key_states, + v=merged_value_states, + cu_seqlens_q=cu_seqlens_q.to(torch.int32), + cu_seqlens_k=cu_seqlens_k.to(torch.int32), + max_seqlen_q=max(query_lens).item(), + max_seqlen_k=max(key_values_lens).item(), + causal=is_causal, + ) + packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size) + if mode == "und": + packed_attn_output = self.o_proj(packed_attn_output) + elif mode == "gen": + packed_attn_output[packed_text_indexes] = self.o_proj(packed_attn_output[packed_text_indexes]) + packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen( + packed_attn_output[packed_vae_token_indexes] + ) + + if update_past_key_values: + past_key_values.key_cache[self.layer_idx] = merged_key_states + past_key_values.value_cache[self.layer_idx] = merged_value_states + + return packed_attn_output, past_key_values + + +class Qwen2MoTDecoderLayer(nn.Module): + def __init__( + self, + config, + layer_idx: int | None = None, + attn_module: Qwen2Attention | None = PackedAttentionMoT, + ): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = attn_module(config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.mlp_moe_gen = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + packed_query_sequence: torch.Tensor, + query_lens: torch.Tensor, + packed_query_position_embeddings: torch.Tensor, + packed_query_indexes: torch.Tensor, + past_key_values: NaiveCache | None = None, + key_values_lens: torch.Tensor | None = None, + packed_key_value_indexes: torch.Tensor | None = None, + update_past_key_values=True, + is_causal=True, + mode="und", + packed_vae_token_indexes=None, + packed_text_indexes=None, + ) -> BaseNavitOutputWithPast: + residual = packed_query_sequence + if mode == "und": + packed_query_sequence = self.input_layernorm(packed_query_sequence) + elif mode == "gen": + packed_query_sequence_ = torch.zeros_like(packed_query_sequence) + packed_query_sequence_[packed_text_indexes] = self.input_layernorm( + packed_query_sequence[packed_text_indexes] + ) + packed_query_sequence_[packed_vae_token_indexes] = self.input_layernorm_moe_gen( + packed_query_sequence[packed_vae_token_indexes] + ) + packed_query_sequence = packed_query_sequence_ + + # Self Attention + packed_query_sequence, past_key_values = self.self_attn( + packed_query_sequence=packed_query_sequence, + query_lens=query_lens, + packed_query_position_embeddings=packed_query_position_embeddings, + packed_query_indexes=packed_query_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + 
packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=update_past_key_values, + is_causal=is_causal, + mode=mode, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_text_indexes=packed_text_indexes, + ) + packed_query_sequence = residual + packed_query_sequence + + # Fully Connected + residual = packed_query_sequence + if mode == "und": + packed_query_sequence = self.post_attention_layernorm(packed_query_sequence) + packed_query_sequence = self.mlp(packed_query_sequence) + elif mode == "gen": + packed_text_query_sequence = packed_query_sequence[packed_text_indexes] + packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] + packed_text_query_sequence = self.post_attention_layernorm(packed_text_query_sequence).to(torch.bfloat16) + packed_vae_query_sequence = self.post_attention_layernorm_moe_gen(packed_vae_query_sequence).to( + torch.bfloat16 + ) + + packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16) + packed_query_sequence_[packed_text_indexes] = self.mlp(packed_text_query_sequence) + packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_vae_query_sequence) + packed_query_sequence = packed_query_sequence_ + + packed_query_sequence = residual + packed_query_sequence + + return packed_query_sequence, past_key_values + + +class Qwen2MoTModel(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.use_moe = "Mo" in config.layer_module + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [ + Qwen2MoTDecoderLayer(config, layer_idx, attn_module=PackedAttentionMoT) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if self.use_moe: + self.norm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + packed_query_sequence: torch.Tensor, + query_lens: torch.Tensor, + packed_query_position_ids: torch.Tensor, + packed_query_indexes: torch.Tensor, + past_key_values: NaiveCache | None = None, + key_values_lens: torch.Tensor | None = None, + packed_key_value_indexes: torch.Tensor | None = None, + update_past_key_values=True, + is_causal=True, + mode="und", + packed_vae_token_indexes=None, + packed_text_indexes=None, + ) -> BaseNavitOutputWithPast: + # create position embeddings to be shared across the decoder layers + cos, sin = self.rotary_emb(packed_query_sequence, packed_query_position_ids.unsqueeze(0)) + cos = cos.squeeze(0) + sin = sin.squeeze(0) + packed_query_position_embeddings = (cos, sin) + + extra_inputs = {} + if self.use_moe: + extra_inputs.update(mode=mode) + if mode == "gen": + assert packed_vae_token_indexes is not None + assert packed_text_indexes is not None + extra_inputs.update( + packed_vae_token_indexes=packed_vae_token_indexes, + packed_text_indexes=packed_text_indexes, + ) + + for layer_idx, decoder_layer in enumerate(self.layers): + packed_query_sequence, past_key_values = decoder_layer( + packed_query_sequence=packed_query_sequence, + query_lens=query_lens, + packed_query_position_embeddings=packed_query_position_embeddings, + packed_query_indexes=packed_query_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + 
packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=update_past_key_values, + is_causal=is_causal, + **extra_inputs, + ) + + if self.use_moe: + if mode == "und": + packed_query_sequence = self.norm(packed_query_sequence) + elif mode == "gen": + packed_query_sequence_ = torch.zeros_like(packed_query_sequence) + packed_query_sequence_[packed_text_indexes] = self.norm(packed_query_sequence[packed_text_indexes]) + packed_query_sequence_[packed_vae_token_indexes] = self.norm_moe_gen( + packed_query_sequence[packed_vae_token_indexes] + ) + packed_query_sequence = packed_query_sequence_ + else: + packed_query_sequence = self.norm(packed_query_sequence) + + return BaseNavitOutputWithPast( + packed_query_sequence=packed_query_sequence, + past_key_values=past_key_values, + ) + + +class Qwen2MoTForCausalLM(Qwen2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2MoTModel(config) + self.vocab_size = config.vocab_size + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + packed_query_sequence: torch.Tensor, + query_lens: torch.Tensor, + packed_query_position_ids: torch.Tensor, + packed_query_indexes: torch.Tensor, + past_key_values: NaiveCache | None = None, + key_values_lens: torch.Tensor | None = None, + packed_key_value_indexes: torch.Tensor | None = None, + update_past_key_values=True, + is_causal=True, + mode="und", + packed_vae_token_indexes=None, + packed_text_indexes=None, + ) -> BaseNavitOutputWithPast: + outputs = self.model( + packed_query_sequence=packed_query_sequence, + query_lens=query_lens, + packed_query_position_ids=packed_query_position_ids, + packed_query_indexes=packed_query_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=update_past_key_values, + is_causal=is_causal, + mode=mode, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_text_indexes=packed_text_indexes, + ) + + return outputs + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim 
/ 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=t.device + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class PositionEmbedding(nn.Module): + def __init__(self, max_num_patch_per_side, hidden_size): + super().__init__() + self.max_num_patch_per_side = max_num_patch_per_side + self.hidden_size = hidden_size + self.pos_embed = nn.Parameter(torch.zeros(max_num_patch_per_side**2, hidden_size), requires_grad=False) + self._init_weights() + + def _init_weights(self): + # Initialize (and freeze) pos_embed by sin-cos embedding: + pos_embed = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float()) + + def forward(self, position_ids): + return self.pos_embed[position_ids] + + +def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side): + num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size + coords_h = torch.arange(0, num_patches_h) + coords_w = torch.arange(0, num_patches_w) + pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten() + return pos_ids + + +class BagelConfig(PretrainedConfig): + def __init__( + self, + llm_config=None, + vae_config=None, + latent_patch_size=2, + max_latent_size=32, + timestep_shift=1.0, + **kwargs, + ): + super().__init__(**kwargs) + self.llm_config = llm_config + self.vae_config = vae_config + self.latent_patch_size = latent_patch_size + self.max_latent_size = max_latent_size + self.timestep_shift = timestep_shift + + +class Bagel(torch.nn.Module): + config_class = BagelConfig + base_model_prefix = "bagel" + + def __init__(self, language_model, config: BagelConfig): + super().__init__() + self.language_model = language_model + self.hidden_size = config.llm_config.hidden_size + self.use_moe = "Mo" in config.llm_config.layer_module + self.num_heads = config.llm_config.num_attention_heads + + self.latent_patch_size = config.latent_patch_size + self.timestep_shift = config.timestep_shift + self.latent_downsample = config.vae_config.downsample * 
config.latent_patch_size + self.max_latent_size = config.max_latent_size + self.latent_channel = config.vae_config.z_channels + self.patch_latent_dim = self.latent_patch_size**2 * self.latent_channel + self.time_embedder = TimestepEmbedder(self.hidden_size) + self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size) + self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim) + self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size) + + self.get_flattened_position_ids = get_flattened_position_ids_extrapolate + + self.config = config + self._init_weights() + + def _init_weights(self): + nn.init.constant_(self.llm2vae.weight, 0) + nn.init.constant_(self.llm2vae.bias, 0) + + def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids): + packed_text_ids = list() + packed_text_position_ids = list() + text_token_lens = list() + packed_text_indexes = list() + packed_key_value_indexes = list() + + curr = 0 + newlens, new_rope = list(), list() + for prompt, curr_kvlen, curr_position_id in zip(prompts, curr_kvlens, curr_rope): + packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) + curr += curr_kvlen + + text_ids = tokenizer.encode(prompt) + text_ids = [new_token_ids["bos_token_id"]] + text_ids + [new_token_ids["eos_token_id"]] + text_token_lens.append(len(text_ids)) + packed_text_ids.extend(text_ids) + packed_text_position_ids.extend(range(curr_position_id, curr_position_id + len(text_ids))) + packed_text_indexes.extend(range(curr, curr + len(text_ids))) + newlens.append(curr_kvlen + len(text_ids)) + new_rope.append(curr_position_id + len(text_ids)) + curr += len(text_ids) + + generation_input = { + "text_token_lens": torch.tensor(text_token_lens, dtype=torch.int), + "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), + "packed_text_position_ids": torch.tensor(packed_text_position_ids, dtype=torch.long), + "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), + "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), + "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), + } + + return generation_input, newlens, new_rope + + @torch.no_grad + def forward_cache_update_text( + self, + past_key_values: NaiveCache, + packed_text_ids: torch.IntTensor, + packed_text_position_ids: torch.LongTensor, + text_token_lens: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + packed_key_value_indexes: torch.LongTensor, + key_values_lens: torch.IntTensor, + ): + packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) + + extra_inputs = {} + if self.use_moe: + extra_inputs = {"mode": "und"} + + output = self.language_model.forward( + packed_query_sequence=packed_text_embedding, + query_lens=text_token_lens, + packed_query_position_ids=packed_text_position_ids, + packed_query_indexes=packed_text_indexes, + past_key_values=past_key_values, + packed_key_value_indexes=packed_key_value_indexes, + key_values_lens=key_values_lens, + update_past_key_values=True, + is_causal=True, + **extra_inputs, + ) + past_key_values = output.past_key_values + + return past_key_values + + def prepare_input(self, curr_kvlens, curr_rope, image_sizes, new_token_ids=None): + packed_text_ids, packed_text_indexes = list(), list() + packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list() + packed_position_ids, packed_seqlens, packed_indexes = list(), list(), list() + packed_key_value_indexes = list() + + query_curr = curr 
= 0 + for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope): + packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) + curr += curr_kvlen + + packed_text_ids.append(new_token_ids["start_of_image"]) + packed_text_indexes.append(query_curr) + + packed_indexes.append(curr) + curr += 1 + query_curr += 1 + + vae_position_ids = self.get_flattened_position_ids( + H, W, self.latent_downsample, max_num_patches_per_side=self.max_latent_size + ) + packed_vae_position_ids.append(vae_position_ids) + + h, w = H // self.latent_downsample, W // self.latent_downsample + num_image_tokens = h * w + + packed_init_noises.append(torch.randn(num_image_tokens, self.latent_channel * self.latent_patch_size**2)) + packed_vae_token_indexes.extend(range(query_curr, query_curr + num_image_tokens)) + packed_seqlens.append(num_image_tokens + 2) + + packed_indexes.extend(range(curr, curr + num_image_tokens)) + curr += num_image_tokens + query_curr += num_image_tokens + + packed_text_ids.append(new_token_ids["end_of_image"]) + packed_text_indexes.append(query_curr) + + packed_indexes.append(curr) + curr += 1 + query_curr += 1 + + packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2)) + + # Construct Output + generation_input = { + "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), + "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), + "packed_init_noises": torch.cat(packed_init_noises, dim=0), + "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0), + "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long), + "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int), + "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), + "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), + "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long), + "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), + } + + return generation_input + + def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids): + return self.prepare_input(curr_kvlens, curr_rope, image_sizes, new_token_ids) + + @torch.no_grad + def generate_image( + self, + packed_text_ids: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + packed_init_noises: torch.Tensor, + packed_vae_position_ids: torch.LongTensor, + packed_vae_token_indexes: torch.LongTensor, + packed_seqlens: torch.IntTensor, + packed_position_ids: torch.LongTensor, + packed_indexes: torch.LongTensor, + past_key_values: NaiveCache, + key_values_lens: torch.IntTensor, + packed_key_value_indexes: torch.LongTensor, + num_timesteps: int = 24, + timestep_shift: float = 1.0, + ): + model_pred_cache_dic, model_pred_current = None, None + model_pred_text_cache_dic, model_pred_text_current = None, None + model_pred_img_cache_dic, model_pred_img_current = None, None + + x_t = packed_init_noises + + timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device) + timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps) + dts = timesteps[:-1] - timesteps[1:] + timesteps = timesteps[:-1] + + for i, t in enumerate(timesteps): + timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device) + v_t = self._forward_flow( + x_t=x_t, + timestep=timestep, + packed_vae_token_indexes=packed_vae_token_indexes, + packed_vae_position_ids=packed_vae_position_ids, + packed_text_ids=packed_text_ids, + packed_text_indexes=packed_text_indexes, + 
packed_position_ids=packed_position_ids, + packed_indexes=packed_indexes, + packed_seqlens=packed_seqlens, + key_values_lens=key_values_lens, + past_key_values=past_key_values, + packed_key_value_indexes=packed_key_value_indexes, + # cache + model_pred_cache_dic=model_pred_cache_dic, + model_pred_current=model_pred_current, + model_pred_text_cache_dic=model_pred_text_cache_dic, + model_pred_text_current=model_pred_text_current, + model_pred_img_cache_dic=model_pred_img_cache_dic, + model_pred_img_current=model_pred_img_current, + ) + + x_t = x_t - v_t.to(x_t.device) * dts[i] # velocity pointing from data to noise + + unpacked_latent = x_t.split((packed_seqlens - 2).tolist()) + return unpacked_latent + + @torch.no_grad + def _forward_flow( + self, + x_t: torch.Tensor, + timestep: torch.LongTensor, + packed_vae_token_indexes: torch.LongTensor, + packed_vae_position_ids: torch.LongTensor, + packed_text_ids: torch.LongTensor, + packed_text_indexes: torch.LongTensor, + packed_indexes: torch.LongTensor, + packed_position_ids: torch.LongTensor, + packed_seqlens: torch.IntTensor, + key_values_lens: torch.IntTensor, + past_key_values: NaiveCache, + packed_key_value_indexes: torch.LongTensor, + # cache + model_pred_cache_dic: dict[str, Any] | None = None, + model_pred_current: int | None = None, + model_pred_text_cache_dic: dict[str, Any] | None = None, + model_pred_text_current: int | None = None, + model_pred_img_cache_dic: dict[str, Any] | None = None, + model_pred_img_current: int | None = None, + ): + packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) + packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) + packed_sequence[packed_text_indexes] = packed_text_embedding + + assert timestep.unique().shape[0] == 1 + packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids) + packed_timestep_embeds = self.time_embedder(timestep) + x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed + if x_t.dtype != packed_sequence.dtype: + x_t = x_t.to(packed_sequence.dtype) + packed_sequence[packed_vae_token_indexes] = x_t + + extra_inputs = {} + if self.use_moe: + extra_inputs = { + "mode": "gen", + "packed_vae_token_indexes": packed_vae_token_indexes, + "packed_text_indexes": packed_text_indexes, + } + + output = self.language_model.forward( + packed_query_sequence=packed_sequence, + query_lens=packed_seqlens, + packed_query_position_ids=packed_position_ids, + packed_query_indexes=packed_indexes, + past_key_values=past_key_values, + key_values_lens=key_values_lens, + packed_key_value_indexes=packed_key_value_indexes, + update_past_key_values=False, + is_causal=False, + **extra_inputs, + ) + v_t = self.llm2vae(output.packed_query_sequence) + v_t = v_t[packed_vae_token_indexes] + + return v_t diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py new file mode 100644 index 000000000..82e68f70b --- /dev/null +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -0,0 +1,456 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +BagelPipeline implementation for vLLM-Omni. 
+""" + +from __future__ import annotations + +import json +import os +from collections.abc import Iterable +from dataclasses import dataclass +from math import isqrt + +import torch +from PIL import Image +from torch import nn +from transformers import AutoTokenizer +from vllm.logger import init_logger +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific + +from .autoencoder import AutoEncoder, AutoEncoderParams +from .bagel_transformer import Bagel, BagelConfig, NaiveCache, Qwen2MoTConfig, Qwen2MoTForCausalLM + +logger = init_logger(__name__) + + +@dataclass +class BagelGenParams: + num_timesteps: int = 50 + timestep_shift: float = 1.0 + + +def add_special_tokens(tokenizer): + all_special_tokens = [] + for k, v in tokenizer.special_tokens_map.items(): + if isinstance(v, str): + all_special_tokens.append(v) + elif isinstance(v, list): + all_special_tokens += v + + new_tokens = [] + + if "<|im_start|>" not in all_special_tokens: + new_tokens.append("<|im_start|>") + + if "<|im_end|>" not in all_special_tokens: + new_tokens.append("<|im_end|>") + + if "<|vision_start|>" not in all_special_tokens: + new_tokens.append("<|vision_start|>") + + if "<|vision_end|>" not in all_special_tokens: + new_tokens.append("<|vision_end|>") + + num_new_tokens = tokenizer.add_tokens(new_tokens) + bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>") + eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>") + start_of_image = tokenizer.convert_tokens_to_ids("<|vision_start|>") + end_of_image = tokenizer.convert_tokens_to_ids("<|vision_end|>") + + new_token_ids = dict( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + start_of_image=start_of_image, + end_of_image=end_of_image, + ) + + return tokenizer, new_token_ids, num_new_tokens + + +def get_bagel_post_process_func(od_config: OmniDiffusionConfig): + # BagelPipeline returns PIL.Image.Image directly. + def post_process_func(x): + return x + + return post_process_func + + +@dataclass +class _VaeCfg: + z_channels: int = 16 + downsample: int = 8 + + +def default_ae_params() -> AutoEncoderParams: + return AutoEncoderParams( + resolution=256, + in_channels=3, + downsample=8, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ) + + +class BagelPipeline(nn.Module): + """Bagel generation pipeline (MoT) packaged for vllm-omni diffusion engine. + + This pipeline is self-contained and uses the ported Bagel core files. + """ + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__() + self.od_config = od_config + self.device = get_local_device() + + model = od_config.model + local_files_only = os.path.exists(model) + if local_files_only: + model_path = model + else: + # Download everything required (ema.safetensors, ae.safetensors, tokenizer files, configs). + model_path = download_weights_from_hf_specific(model, od_config.revision, ["*"]) + + # Load Bagel top-level config for VAE settings. 
+ cfg_path = os.path.join(model_path, "config.json") + with open(cfg_path, encoding="utf-8") as f: + bagel_cfg = json.load(f) + + vae_cfg_dict = bagel_cfg.get("vae_config") or {} + vae_cfg = _VaeCfg( + z_channels=int(vae_cfg_dict.get("z_channels", 16)), + downsample=int(vae_cfg_dict.get("downsample", 8)), + ) + + # LLM config: Bagel MoT requires explicitly setting layer_module + llm_cfg_path = os.path.join(model_path, "llm_config.json") + llm_config = Qwen2MoTConfig.from_json_file(llm_cfg_path) + llm_config.qk_norm = True + llm_config.tie_word_embeddings = False + # Allow overriding from vllm-omni config if user wants MoE/vanilla. + llm_config.layer_module = od_config.override_transformer_cls_name or "Qwen2MoTDecoderLayer" + + # Tokenizer and special tokens. + # Bagel uses a Qwen2 tokenizer variant; prefer trust_remote_code to get the + # correct tokenizer implementation from the checkpoint repo when available. + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, + local_files_only=True, + trust_remote_code=True, + ) + self.tokenizer, self.new_token_ids, _ = add_special_tokens(self.tokenizer) + + try: + tok_len = len(self.tokenizer) + except Exception: # pragma: no cover - very old tokenizers + tok_len = getattr(self.tokenizer, "vocab_size", llm_config.vocab_size) + required_max_id = max(int(v) for v in self.new_token_ids.values()) + llm_config.vocab_size = max( + int(getattr(llm_config, "vocab_size", tok_len)), + int(tok_len), + int(required_max_id + 1), + ) + + self.language_model = Qwen2MoTForCausalLM(llm_config) + ae_params: AutoEncoderParams = default_ae_params() + self.vae = AutoEncoder(ae_params) + + self.bagel = Bagel( + language_model=self.language_model, + config=BagelConfig( + llm_config=llm_config, + vae_config=vae_cfg, + latent_patch_size=int(bagel_cfg.get("latent_patch_size", 2)), + max_latent_size=int(bagel_cfg.get("max_latent_size", 32)), + timestep_shift=float(bagel_cfg.get("timestep_shift", 1.0)), + ), + ) + + # Let vLLM loader download and stream all *.safetensors under model root. + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder=None, + revision=od_config.revision, + prefix="", + fall_back_to_pt=False, + ) + ] + + self.to(self.device) + + @staticmethod + def _decode_image_from_latent( + bagel: Bagel, vae: AutoEncoder, latent: torch.Tensor, image_shape: tuple[int, int] + ) -> Image.Image: + H, W = image_shape + h, w = H // bagel.latent_downsample, W // bagel.latent_downsample + p = bagel.latent_patch_size + c = bagel.latent_channel + latent = latent.reshape(1, h, w, p, p, c) + latent = torch.einsum("nhwpqc->nchpwq", latent) + latent = latent.reshape(1, c, h * p, w * p) + + # Cast to VAE dtype (e.g. bfloat16) as latents might remain float32 from generation loop + vae_dtype = next(vae.parameters()).dtype + latent = latent.to(vae_dtype) + + image = vae.decode(latent) + image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255 + return Image.fromarray(image.to(torch.uint8).cpu().numpy()) + + @torch.inference_mode() + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: + prompt = req.prompt or "" + if isinstance(prompt, list): + # vllm-omni request supports list; Bagel pipeline currently supports first prompt. 
+ prompt = prompt[0] if prompt else "" + max_hw = int(self.bagel.max_latent_size * self.bagel.latent_downsample) + if req.height is None and req.width is None: + height = width = max_hw + else: + height = int(req.height) if req.height is not None else max_hw + width = int(req.width) if req.width is not None else max_hw + if height > max_hw or width > max_hw: + raise ValueError( + f"Requested resolution {height}x{width} exceeds Bagel checkpoint limit " + f"{max_hw}x{max_hw} (max_latent_size={self.bagel.max_latent_size}, " + f"latent_downsample={self.bagel.latent_downsample})." + ) + image_shape = (height, width) + + # Map request params to Bagel gen params (defaults follow Bagel inferencer) + gen_params = BagelGenParams( + num_timesteps=int(req.num_inference_steps or 50), + timestep_shift=3.0, + ) + + gen_context = { + "kv_lens": [0], + "ropes": [0], + "past_key_values": NaiveCache(self.bagel.config.llm_config.num_hidden_layers), + } + + # Add text prompt (prefill) on gen context. + # [Omni] Check for injected KV Cache from remote transfer + injected_kv = getattr(req, "past_key_values", None) + injected_metadata = getattr(req, "kv_metadata", None) + + if injected_kv is not None and injected_metadata is not None: + logger.info("Using injected KV Cache from remote transfer") + + # [Fix] Reconstruct NaiveCache if injected_kv is a dict of tensors + current_cache = gen_context["past_key_values"] + if isinstance(current_cache, NaiveCache) and isinstance(injected_kv, dict): + # injected_kv keys are like "0_k", "0_v", "1_k", ... + for key_name, tensor in injected_kv.items(): + try: + # Parse layer index and type + parts = key_name.split("_") + if len(parts) < 2: + continue + + layer_idx = int(parts[0]) + cache_type = parts[1] # 'k' or 'v' + + # Ensure tensor is on correct device + if tensor.device != self.device: + tensor = tensor.to(self.device) + + if layer_idx in current_cache.key_cache: + if cache_type == "k": + current_cache.key_cache[layer_idx] = tensor + elif cache_type == "v": + current_cache.value_cache[layer_idx] = tensor + elif cache_type == "kv": + # Fallback if sender sent mixed/packed (less ideal) + current_cache.key_cache[layer_idx] = tensor + current_cache.value_cache[layer_idx] = tensor + except Exception as e: + logger.warning(f"Failed to load injected KV part {key_name}: {e}") + + if "kv_lens" in injected_metadata: + val = injected_metadata["kv_lens"] + if isinstance(val, (int, float)): + gen_context["kv_lens"] = [int(val)] + else: + gen_context["kv_lens"] = list(val) + + if "ropes" in injected_metadata: + val = injected_metadata["ropes"] + if isinstance(val, (int, float)): + gen_context["ropes"] = [int(val)] + else: + gen_context["ropes"] = list(val) + + else: + # Standard local prefill path + generation_input, newlens, new_rope = self.bagel.prepare_prompts( + curr_kvlens=gen_context["kv_lens"], + curr_rope=gen_context["ropes"], + prompts=[prompt], + tokenizer=self.tokenizer, + new_token_ids=self.new_token_ids, + ) + # Fail fast with a clear error instead of CUDA gather OOB. + max_tid = int(generation_input["packed_text_ids"].max().item()) + emb_n = int(self.language_model.model.embed_tokens.weight.shape[0]) + if max_tid >= emb_n: + raise ValueError( + "Tokenizer/model vocab mismatch: max token id " + f"{max_tid} >= embed_tokens size {emb_n}. " + "This usually means you're not using the tokenizer shipped with the Bagel checkpoint, " + "or llm_config.vocab_size is smaller than the tokenizer vocab." 
+ ) + for k, v in generation_input.items(): + if torch.is_tensor(v): + generation_input[k] = v.to(self.device) + with torch.autocast(device_type="cuda", enabled=self.device.type == "cuda", dtype=torch.bfloat16): + gen_context["past_key_values"] = self.bagel.forward_cache_update_text( + gen_context["past_key_values"], **generation_input + ) + gen_context["kv_lens"] = newlens + gen_context["ropes"] = new_rope + + if req.seed is not None: + torch.manual_seed(req.seed) + if self.device.type == "cuda": + torch.cuda.manual_seed(req.seed) + + # Prepare latent query and run flow + generation_input = self.bagel.prepare_vae_latent( + curr_kvlens=gen_context["kv_lens"], + curr_rope=gen_context["ropes"], + image_sizes=[image_shape], + new_token_ids=self.new_token_ids, + ) + # Fail fast for special tokens used by the image path as well. + max_tid_img = int(generation_input["packed_text_ids"].max().item()) + emb_n = int(self.language_model.model.embed_tokens.weight.shape[0]) + if max_tid_img >= emb_n: + raise ValueError( + "Tokenizer/model vocab mismatch (image path): max token id " + f"{max_tid_img} >= embed_tokens size {emb_n}. " + "This indicates the tokenizer token IDs do not match the checkpoint embeddings." + ) + # Position ids must be non-negative; negative ids can trigger CUDA gather OOB inside RoPE. + min_pid = int(generation_input["packed_position_ids"].min().item()) + if min_pid < 0: + raise ValueError(f"Invalid packed_position_ids: min={min_pid} (must be >= 0)") + # Latent position embedding bounds check: ids must be < max_latent_size^2. + max_lat_pid = int(generation_input["packed_vae_position_ids"].max().item()) + max_lat_pid_allowed = int(self.bagel.max_latent_size * self.bagel.max_latent_size) - 1 + if max_lat_pid > max_lat_pid_allowed: + raise ValueError( + "Invalid packed_vae_position_ids (latent position embedding OOB): " + f"max={max_lat_pid} > allowed_max={max_lat_pid_allowed}. " + f"Requested image_shape={image_shape}, max_latent_size={self.bagel.max_latent_size}." + ) + for k, v in generation_input.items(): + if torch.is_tensor(v): + generation_input[k] = v.to(self.device) + + with torch.autocast(device_type="cuda", enabled=self.device.type == "cuda", dtype=torch.bfloat16): + latents = self.bagel.generate_image( + past_key_values=gen_context["past_key_values"], + num_timesteps=gen_params.num_timesteps, + timestep_shift=gen_params.timestep_shift, + **generation_input, + ) + + # Decode first sample + img = self._decode_image_from_latent(self.bagel, self.vae, latents[0], image_shape) + return DiffusionOutput(output=img) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + state = self.state_dict() + allowed = set(state.keys()) + shapes = {k: tuple(v.shape) for k, v in state.items()} + + def _normalize_name(name: str) -> str: + # Common wrappers/prefixes in checkpoints. + for pfx in ("module.", "model."): + if name.startswith(pfx): + name = name[len(pfx) :] + # Common component renames across repos. + if name.startswith("vae_model."): + name = "vae." + name[len("vae_model.") :] + # Bagel `ae.safetensors` commonly stores AE weights without a top-level prefix. + # Map them into this pipeline's `vae.*` namespace. + if name.startswith("encoder.") or name.startswith("decoder."): + name = "vae." + name + return name + + def _iter_candidate_names(name: str) -> Iterable[str]: + """Yield candidate parameter names in this pipeline for a checkpoint key. 
+ + The upstream Bagel repo typically stores Bagel-core layers (time_embedder, + latent_pos_embed, vae2llm, llm2vae, etc.) at the top-level of the model, + while this vllm-omni integration nests them under `self.bagel`. + """ + n = _normalize_name(name) + yield n + + # Map Bagel core layers from top-level -> `bagel.*` namespace. + for pfx in ("time_embedder.", "latent_pos_embed.", "vae2llm.", "llm2vae."): + if n.startswith(pfx): + yield "bagel." + n + break + + def _filtered_weights(): + total = 0 + kept = 0 + shape_mismatch = 0 + for name, tensor in weights: + total += 1 + picked = None + for cand in _iter_candidate_names(name): + if cand in allowed: + # Only accept if tensor shape matches target param/buffer shape. + if tuple(tensor.shape) == shapes.get(cand): + picked = cand + break + else: + if cand.endswith("bagel.latent_pos_embed.pos_embed") and tensor.ndim == 2: + npos, hdim = tensor.shape + side = isqrt(int(npos)) + if side * side == int(npos) and hdim == int(self.bagel.hidden_size): + param = self.bagel.latent_pos_embed.pos_embed + # Resize in-place to keep the same Parameter object. + param.data = param.data.new_empty((npos, hdim)) + # Update model bookkeeping so position-id generation matches. + self.bagel.max_latent_size = int(side) + if hasattr(self.bagel, "config"): + setattr(self.bagel.config, "max_latent_size", int(side)) + if hasattr(self.bagel.latent_pos_embed, "max_num_patch_per_side"): + self.bagel.latent_pos_embed.max_num_patch_per_side = int(side) + shapes[cand] = (npos, hdim) + picked = cand + break + shape_mismatch += 1 + # Keep this quiet; shape mismatches are expected for ignored modules. + if picked is not None: + kept += 1 + yield picked, tensor + # else: ignore extra weights (e.g. connector/vision/und) + logger.info_once( + "BagelPipeline weight filter kept %d/%d tensors (shape mismatches seen: %d)", + kept, + total, + shape_mismatch, + ) + + loader = AutoWeightsLoader(self) + return loader.load_weights(_filtered_weights()) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 87f674e9f..10030de4a 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -54,6 +54,11 @@ "pipeline_longcat_image", "LongCatImagePipeline", ), + "BagelPipeline": ( + "bagel", + "pipeline_bagel", + "BagelPipeline", + ), "LongCatImageEditPipeline": ( "longcat_image", "pipeline_longcat_image_edit", @@ -106,6 +111,7 @@ def initialize_model( "WanPipeline": "get_wan22_post_process_func", "WanImageToVideoPipeline": "get_wan22_i2v_post_process_func", "LongCatImagePipeline": "get_longcat_image_post_process_func", + "BagelPipeline": "get_bagel_post_process_func", "LongCatImageEditPipeline": "get_longcat_image_post_process_func", "StableDiffusion3Pipeline": "get_sd3_image_post_process_func", } diff --git a/vllm_omni/diffusion/utils/hf_utils.py b/vllm_omni/diffusion/utils/hf_utils.py index 6fc22a779..cfc1807a1 100644 --- a/vllm_omni/diffusion/utils/hf_utils.py +++ b/vllm_omni/diffusion/utils/hf_utils.py @@ -14,6 +14,19 @@ def load_diffusers_config(model_name) -> dict: return config +def _looks_like_bagel(model_name: str) -> bool: + """Best-effort detection for Bagel (non-diffusers) diffusion models.""" + try: + cfg = get_hf_file_to_dict("config.json", model_name) + except Exception: + return False + model_type = cfg.get("model_type") + if model_type == "bagel": + return True + architectures = cfg.get("architectures") or [] + return "BagelForConditionalGeneration" in architectures + + @lru_cache def is_diffusion_model(model_name: str) -> 
bool: """Check if a model is a diffusion model. @@ -61,4 +74,6 @@ def is_diffusion_model(model_name: str) -> bool: except Exception as e: logger.debug("Failed to load diffusers config via DiffusionPipeline: %s", e) - return False + # Bagel is not a diffusers pipeline (no model_index.json), but is still a + # diffusion-style model in vllm-omni. Detect it via config.json. + return _looks_like_bagel(model_name) diff --git a/vllm_omni/entrypoints/omni_diffusion.py b/vllm_omni/entrypoints/omni_diffusion.py index 9f149f37b..1d71e23cb 100644 --- a/vllm_omni/entrypoints/omni_diffusion.py +++ b/vllm_omni/entrypoints/omni_diffusion.py @@ -48,17 +48,34 @@ def __init__(self, od_config: OmniDiffusionConfig | None = None, **kwargs): self.od_config = od_config - config_dict = get_hf_file_to_dict( - "model_index.json", - od_config.model, - ) - od_config.model_class_name = config_dict.get("_class_name", None) - od_config.update_multimodal_support() - tf_config_dict = get_hf_file_to_dict( - "transformer/config.json", - od_config.model, - ) - od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) + # Diffusers-style models expose `model_index.json` with `_class_name`. + # Bagel models (and other non-diffusers) typically expose `config.json`. + try: + config_dict = get_hf_file_to_dict( + "model_index.json", + od_config.model, + ) + od_config.model_class_name = config_dict.get("_class_name", None) + od_config.update_multimodal_support() + + tf_config_dict = get_hf_file_to_dict( + "transformer/config.json", + od_config.model, + ) + od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) + except (AttributeError, OSError, ValueError): + cfg = get_hf_file_to_dict("config.json", od_config.model) + if cfg is None: + raise ValueError(f"Could not find config.json or model_index.json for model {od_config.model}") + + model_type = cfg.get("model_type") + architectures = cfg.get("architectures") or [] + if model_type == "bagel" or "BagelForConditionalGeneration" in architectures: + od_config.model_class_name = "BagelPipeline" + od_config.tf_model_config = TransformerConfig() + od_config.update_multimodal_support() + else: + raise self.engine: DiffusionEngine = DiffusionEngine.make_engine(od_config) @@ -76,11 +93,20 @@ def generate( raise ValueError("Prompt must be a string or a list of strings") requests: list[OmniDiffusionRequest] = [] - for p in prompts: + + # Check if request_id is provided in kwargs + request_id = kwargs.get("request_id") + + for i, p in enumerate(prompts): + req_kwargs = kwargs.copy() + if request_id is None: + # Generate default ID consistent with OmniLLM: "{i}_{uuid}" + req_kwargs["request_id"] = f"{i}" + requests.append( prepare_requests( p, - **kwargs, + **req_kwargs, ) ) logger.info(f"Prepared {len(requests)} requests for generation.")
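
For reference, a minimal text-to-image usage sketch for the new Bagel path might look like the following. This is a sketch under stated assumptions, not part of the patch: it assumes the entrypoint class defined in vllm_omni/entrypoints/omni_diffusion.py is named OmniDiffusion and builds its OmniDiffusionConfig from keyword arguments (its constructor accepts **kwargs), that generate() forwards num_inference_steps and seed into the OmniDiffusionRequest fields BagelPipeline reads, and that it returns DiffusionOutput objects whose .output is the PIL image the pipeline produces; exact signatures may differ.

from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion  # class name assumed

# Assumed: passing `model=` as a kwarg is enough to build the config; the Bagel
# checkpoint is then auto-detected via config.json (see _looks_like_bagel above).
od = OmniDiffusion(model="ByteDance-Seed/BAGEL-7B-MoT")

# Assumed: generate() maps these kwargs onto OmniDiffusionRequest; prompt,
# num_inference_steps, and seed are the fields BagelPipeline reads, and the
# return value is assumed to be a list of DiffusionOutput, one per prompt.
outputs = od.generate(
    "a photo of a corgi wearing sunglasses on a beach",
    num_inference_steps=50,
    seed=0,
)

# BagelPipeline places a PIL.Image in DiffusionOutput.output.
outputs[0].output.save("bagel_t2i.png")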