
Commit f609235

Use same code for chroma and flux blocks so that optimizations are shared. (comfyanonymous#10746)
1 parent 1ef328c commit f609235
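
The commit removes Chroma's private copies of DoubleStreamBlock and SingleStreamBlock and reuses the Flux implementations, which gain a `modulation` flag: with `modulation=True` (the Flux default) a block builds its own Modulation layers and derives shift/scale/gate from the conditioning vector, while with `modulation=False` (Chroma) it expects pre-computed ModulationOut values to be passed in as `vec`. A minimal construction-side sketch, assuming `comfy.ops.disable_weight_init` as the `operations` provider and purely illustrative sizes:

    import comfy.ops
    from comfy.ldm.flux.layers import DoubleStreamBlock, SingleStreamBlock

    ops = comfy.ops.disable_weight_init  # assumed ops provider (Linear/LayerNorm/RMSNorm)

    # Flux-style block: owns its Modulation layers and derives them from `vec` at forward time.
    flux_double = DoubleStreamBlock(
        hidden_size=256, num_heads=4, mlp_ratio=4.0, qkv_bias=True,
        modulation=True, operations=ops,
    )

    # Chroma-style blocks: no Modulation layers are created; pre-computed
    # ModulationOut values are supplied by the caller instead.
    chroma_double = DoubleStreamBlock(
        hidden_size=256, num_heads=4, mlp_ratio=4.0, qkv_bias=True,
        modulation=False, operations=ops,
    )
    chroma_single = SingleStreamBlock(
        hidden_size=256, num_heads=4, mlp_ratio=4.0,
        modulation=False, operations=ops,
    )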

File tree

4 files changed: +31 -135 lines

  comfy/ldm/chroma/layers.py
  comfy/ldm/chroma/model.py
  comfy/ldm/chroma_radiance/model.py
  comfy/ldm/flux/layers.py

comfy/ldm/chroma/layers.py

Lines changed: 0 additions & 121 deletions
@@ -1,12 +1,9 @@
 import torch
 from torch import Tensor, nn

-from comfy.ldm.flux.math import attention
 from comfy.ldm.flux.layers import (
     MLPEmbedder,
     RMSNorm,
-    QKNorm,
-    SelfAttention,
     ModulationOut,
 )

@@ -48,124 +45,6 @@ def forward(self, x: Tensor) -> Tensor:
         return x


-class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
-        super().__init__()
-
-        mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        self.num_heads = num_heads
-        self.hidden_size = hidden_size
-        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
-
-        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
-
-        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
-
-        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
-        self.flipped_img_txt = flipped_img_txt
-
-    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
-        (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
-
-        # prepare image for attention
-        img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
-        img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
-
-        # prepare txt for attention
-        txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
-        txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
-
-        # run actual attention
-        attn = attention(torch.cat((txt_q, img_q), dim=2),
-                         torch.cat((txt_k, img_k), dim=2),
-                         torch.cat((txt_v, img_v), dim=2),
-                         pe=pe, mask=attn_mask, transformer_options=transformer_options)
-
-        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
-
-        # calculate the img bloks
-        img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
-        img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
-
-        # calculate the txt bloks
-        txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
-        txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
-
-        if txt.dtype == torch.float16:
-            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
-
-        return img, txt
-
-
-class SingleStreamBlock(nn.Module):
-    """
-    A DiT block with parallel linear layers as described in
-    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        mlp_ratio: float = 4.0,
-        qk_scale: float = None,
-        dtype=None,
-        device=None,
-        operations=None
-    ):
-        super().__init__()
-        self.hidden_dim = hidden_size
-        self.num_heads = num_heads
-        head_dim = hidden_size // num_heads
-        self.scale = qk_scale or head_dim**-0.5
-
-        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        # qkv and mlp_in
-        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
-        # proj and mlp_out
-        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
-
-        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
-
-        self.hidden_size = hidden_size
-        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-
-        self.mlp_act = nn.GELU(approximate="tanh")
-
-    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
-        mod = vec
-        x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
-        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
-
-        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        q, k = self.norm(q, k, v)
-
-        # compute attention
-        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
-        # compute activation in mlp stream, cat again and run second linear layer
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x.addcmul_(mod.gate, output)
-        if x.dtype == torch.float16:
-            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
-        return x
-
-
 class LastLayer(nn.Module):
     def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
         super().__init__()

comfy/ldm/chroma/model.py

Lines changed: 4 additions & 3 deletions
@@ -11,12 +11,12 @@
 from comfy.ldm.flux.layers import (
     EmbedND,
     timestep_embedding,
+    DoubleStreamBlock,
+    SingleStreamBlock,
 )

 from .layers import (
-    DoubleStreamBlock,
     LastLayer,
-    SingleStreamBlock,
     Approximator,
     ChromaModulationOut,
 )
@@ -90,6 +90,7 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
                     qkv_bias=params.qkv_bias,
+                    modulation=False,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)
@@ -98,7 +99,7 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,

         self.single_blocks = nn.ModuleList(
             [
-                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations)
                 for _ in range(params.depth_single_blocks)
             ]
         )

comfy/ldm/chroma_radiance/model.py

Lines changed: 3 additions & 4 deletions
@@ -10,12 +10,10 @@
 from einops import repeat
 import comfy.ldm.common_dit

-from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock

 from comfy.ldm.chroma.model import Chroma, ChromaParams
 from comfy.ldm.chroma.layers import (
-    DoubleStreamBlock,
-    SingleStreamBlock,
     Approximator,
 )
 from .layers import (
@@ -89,14 +87,14 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
             dtype=dtype, device=device, operations=operations
         )

-
         self.double_blocks = nn.ModuleList(
             [
                 DoubleStreamBlock(
                     self.hidden_size,
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
                     qkv_bias=params.qkv_bias,
+                    modulation=False,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)
@@ -109,6 +107,7 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
                     self.hidden_size,
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
+                    modulation=False,
                     dtype=dtype, device=device, operations=operations,
                 )
                 for _ in range(params.depth_single_blocks)

comfy/ldm/flux/layers.py

Lines changed: 24 additions & 7 deletions
@@ -130,13 +130,17 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):


 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None):
         super().__init__()

         mlp_hidden_dim = int(hidden_size * mlp_ratio)
         self.num_heads = num_heads
         self.hidden_size = hidden_size
-        self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
+        self.modulation = modulation
+
+        if self.modulation:
+            self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
+
         self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
         self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

@@ -147,7 +151,9 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
             operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
         )

-        self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
+        if self.modulation:
+            self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
+
         self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
         self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

@@ -160,8 +166,11 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
         self.flipped_img_txt = flipped_img_txt

     def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
-        img_mod1, img_mod2 = self.img_mod(vec)
-        txt_mod1, txt_mod2 = self.txt_mod(vec)
+        if self.modulation:
+            img_mod1, img_mod2 = self.img_mod(vec)
+            txt_mod1, txt_mod2 = self.txt_mod(vec)
+        else:
+            (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

         # prepare image for attention
         img_modulated = self.img_norm1(img)
@@ -236,6 +245,7 @@ def __init__(
         num_heads: int,
         mlp_ratio: float = 4.0,
         qk_scale: float = None,
+        modulation=True,
         dtype=None,
         device=None,
         operations=None
@@ -258,10 +268,17 @@ def __init__(
         self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

         self.mlp_act = nn.GELU(approximate="tanh")
-        self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
+        if modulation:
+            self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
+        else:
+            self.modulation = None

     def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
-        mod, _ = self.modulation(vec)
+        if self.modulation:
+            mod, _ = self.modulation(vec)
+        else:
+            mod = vec
+
         qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
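
At forward time the only difference is what `vec` holds. A hedged sketch of the two layouts, using placeholder tensors with sizes matching the illustrative blocks above:

    import torch
    from comfy.ldm.flux.layers import ModulationOut

    batch, hidden = 1, 256

    def mod():
        # one ModulationOut of (batch, 1, hidden) tensors so it broadcasts over the sequence
        return ModulationOut(*(torch.zeros(batch, 1, hidden) for _ in range(3)))

    # modulation=False (Chroma): pre-computed modulations are passed directly.
    double_vec = ((mod(), mod()), (mod(), mod()))  # ((img_mod1, img_mod2), (txt_mod1, txt_mod2))
    single_vec = mod()                             # a single ModulationOut for SingleStreamBlock

    # modulation=True (Flux): vec is the conditioning embedding of shape (batch, hidden);
    # the block's own Modulation layers turn it into shift/scale/gate internally.
    flux_vec = torch.zeros(batch, hidden)

Because Chroma and Chroma Radiance now run through the same block code as Flux, attention and other optimizations added to the Flux blocks apply to all three models.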
