Skip to content
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
5165e84
try to support bagel diffusion
natureofnature Dec 12, 2025
6cd6951
try to fix bagel workflow
natureofnature Dec 12, 2025
6b410ad
try to fix output
natureofnature Dec 12, 2025
0f74329
Merge pull request #2 from natureofnature/wzliu_bagel_dev
princepride Dec 13, 2025
0adecc8
replace vllm_flash_attn
princepride Dec 13, 2025
90fb9fb
can generate perfect image
princepride Dec 14, 2025
d9623c1
Add Bagel model support and fix related issues
princepride Dec 15, 2025
66c0e31
add enable_cache_for_bagel
princepride Dec 15, 2025
cb028e8
Fix pre-commit errors and update Bagel model code
princepride Dec 18, 2025
156fb31
fix seed bug
princepride Dec 18, 2025
e97965d
fix pre-commit bug
princepride Dec 18, 2025
76f18dd
Support bagel ar in vllm-omni
natureofnature Dec 15, 2025
29f46c3
Merge pull request #3 from natureofnature/wzliu_bagel_dev
princepride Dec 20, 2025
deb7056
adjust bagel end2end.py
princepride Dec 20, 2025
3a619b5
remote useless code from diffusion bagel
princepride Dec 20, 2025
febd622
remote useless code from diffusion bagel
princepride Dec 20, 2025
1a32a04
remove duplicate code
princepride Dec 21, 2025
b893df0
remove duplicate code
princepride Dec 21, 2025
d86971e
remove duplicate code
princepride Dec 21, 2025
b8b16f4
Merge branch 'main' into wzliu_bagel_dev
natureofnature Dec 22, 2025
7f02505
fix version
natureofnature Dec 22, 2025
53468f9
try to enable sending kv cache
natureofnature Dec 22, 2025
6d6607b
Merge branch 'main' into wzliu_bagel_dev
natureofnature Dec 22, 2025
f34684a
add dit to stage
princepride Dec 22, 2025
cbe668a
fix some bug
princepride Dec 23, 2025
2a56fa1
add receiver on diffusion side
natureofnature Dec 22, 2025
929b5d4
Merge remote-tracking branch 'upstream/bagel-model' into wzliu_bagel_dev
natureofnature Dec 24, 2025
aae6b81
Add todo items, and dual test (ar+diffusion offline example test, not…
natureofnature Dec 24, 2025
a5728ba
Merge pull request #4 from natureofnature/wzliu_bagel_dev
princepride Dec 24, 2025
da2359c
Merge branch 'main' into bagel-model
princepride Dec 27, 2025
1d5d71f
Merge branch 'main' into bagel-model
hsliuustc0106 Dec 27, 2025
768f912
remove bagel stage related code
princepride Dec 27, 2025
8d82ded
remove bagel stage related code
princepride Dec 27, 2025
47a284d
remove bagel stage part code
princepride Dec 27, 2025
88ec0e4
adjust the code
princepride Dec 27, 2025
9bf1ef4
simplify qwen2_navit MoTLayer load logic
princepride Dec 27, 2025
845cb18
Refactored the Bagel file structure to align with the vLLM-omni diffu…
princepride Dec 29, 2025
60d7f2b
Merge branch 'main' into bagel-model
hsliuustc0106 Dec 31, 2025
b745719
Remove useless code
princepride Dec 29, 2025
4207714
add __init__ and use vllm_omni RotaryEmbedding
princepride Dec 31, 2025
f096aeb
Merge branch 'main' into bagel-model
hsliuustc0106 Dec 31, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ th {
|--------------|--------|-------------------|
| `Qwen3OmniMoeForConditionalGeneration` | Qwen3-Omni | `Qwen/Qwen3-Omni-30B-A3B-Instruct` |
| `Qwen2_5OmniForConditionalGeneration` | Qwen2.5-Omni | `Qwen/Qwen2.5-Omni-7B`, `Qwen/Qwen2.5-Omni-3B` |
| `BagelForConditionalGeneration` | BAGEL (DiT-only) | `ByteDance-Seed/BAGEL-7B-MoT` |
| `QwenImagePipeline` | Qwen-Image | `Qwen/Qwen-Image` |
| `QwenImageEditPipeline` | Qwen-Image-Edit | `Qwen/Qwen-Image-Edit` |
| `QwenImageEditPlusPipeline` | Qwen-Image-Edit-2509 | `Qwen/Qwen-Image-Edit-2509` |
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -165,5 +165,6 @@ extend-ignore-identifiers-re = [
".*MoBA",
".*temperal_downsample",
".*nothink.*",
".*NOTHINK.*"
".*NOTHINK.*",
".*nin.*",
]
360 changes: 360 additions & 0 deletions vllm_omni/diffusion/models/bagel/autoencoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,360 @@
# Copyright (c) 2024 Black Forest Labs.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/black-forest-labs/flux/blob/main/LICENSE.
#
# This modified file is released under the same license.

from dataclasses import dataclass

import torch
from einops import rearrange
from safetensors.torch import load_file as load_sft
from torch import Tensor, nn


@dataclass
class AutoEncoderParams:
resolution: int
in_channels: int
downsample: int
ch: int
out_ch: int
ch_mult: list[int]
num_res_blocks: int
z_channels: int
scale_factor: float
shift_factor: float


def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels

self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

def attention(self, h_: Tensor) -> Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)

b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
h_ = nn.functional.scaled_dot_product_attention(q, k, v)

return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

def forward(self, x: Tensor) -> Tensor:
return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels

self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

def forward(self, x):
h = x
h = self.norm1(h)
h = swish(h)
h = self.conv1(h)

h = self.norm2(h)
h = swish(h)
h = self.conv2(h)

if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)

return x + h


class Downsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
# no asymmetric padding in torch conv, must do it ourselves
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

def forward(self, x: Tensor):
pad = (0, 1, 0, 1)
x = nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x


class Upsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

def forward(self, x: Tensor):
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x


class Encoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
block_in = self.ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)

# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

def forward(self, x: Tensor) -> Tensor:
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))

# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h


class Decoder(nn.Module):
def __init__(
self,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)

# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)

# z to block_in
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order

# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

def forward(self, z: Tensor) -> Tensor:
# z to block_in
h = self.conv_in(z)

# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)

# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)

# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h


class DiagonalGaussian(nn.Module):
def __init__(self, sample: bool = True, chunk_dim: int = 1):
super().__init__()
self.sample = sample
self.chunk_dim = chunk_dim

def forward(self, z: Tensor) -> Tensor:
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
if self.sample:
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
else:
return mean


class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.decoder = Decoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
out_ch=params.out_ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.reg = DiagonalGaussian()

self.scale_factor = params.scale_factor
self.shift_factor = params.shift_factor

def encode(self, x: Tensor) -> Tensor:
z = self.reg(self.encoder(x))
z = self.scale_factor * (z - self.shift_factor)
return z

def decode(self, z: Tensor) -> Tensor:
z = z / self.scale_factor + self.shift_factor
return self.decoder(z)

def forward(self, x: Tensor) -> Tensor:
return self.decode(self.encode(x))


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
if len(missing) > 0 and len(unexpected) > 0:
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
print("\n" + "-" * 79 + "\n")
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
elif len(missing) > 0:
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
elif len(unexpected) > 0:
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


def load_ae(local_path: str) -> AutoEncoder:
ae_params = AutoEncoderParams(
resolution=256,
in_channels=3,
downsample=8,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
)

# Loading the autoencoder
ae = AutoEncoder(ae_params)

if local_path is not None:
sd = load_sft(local_path)
missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
print_load_warning(missing, unexpected)
return ae, ae_params
Loading