 import torch
 from torch import Tensor, nn

-from comfy.ldm.flux.math import attention
 from comfy.ldm.flux.layers import (
     MLPEmbedder,
     RMSNorm,
-    QKNorm,
-    SelfAttention,
     ModulationOut,
 )

@@ -48,124 +45,6 @@ def forward(self, x: Tensor) -> Tensor:
         return x


-class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
-        super().__init__()
-
-        mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        self.num_heads = num_heads
-        self.hidden_size = hidden_size
-        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
-
-        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
-
-        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
-
-        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
-        self.flipped_img_txt = flipped_img_txt
-
-    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
-        (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
-
-        # prepare image for attention
-        img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
-        img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
-
-        # prepare txt for attention
-        txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
-        txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
-
-        # run actual attention
-        attn = attention(torch.cat((txt_q, img_q), dim=2),
-                         torch.cat((txt_k, img_k), dim=2),
-                         torch.cat((txt_v, img_v), dim=2),
-                         pe=pe, mask=attn_mask, transformer_options=transformer_options)
-
-        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
-
-        # calculate the img blocks
-        img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
-        img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
-
-        # calculate the txt blocks
-        txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
-        txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
-
-        if txt.dtype == torch.float16:
-            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
-
-        return img, txt
-
-
-class SingleStreamBlock(nn.Module):
-    """
-    A DiT block with parallel linear layers as described in
-    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        mlp_ratio: float = 4.0,
-        qk_scale: float = None,
-        dtype=None,
-        device=None,
-        operations=None
-    ):
-        super().__init__()
-        self.hidden_dim = hidden_size
-        self.num_heads = num_heads
-        head_dim = hidden_size // num_heads
-        self.scale = qk_scale or head_dim**-0.5
-
-        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        # qkv and mlp_in
-        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
-        # proj and mlp_out
-        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
-
-        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
-
-        self.hidden_size = hidden_size
-        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-
-        self.mlp_act = nn.GELU(approximate="tanh")
-
-    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
-        mod = vec
-        x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
-        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
-
-        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        q, k = self.norm(q, k, v)
-
-        # compute attention
-        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
-        # compute activation in mlp stream, cat again and run second linear layer
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x.addcmul_(mod.gate, output)
-        if x.dtype == torch.float16:
-            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
-        return x
-
-
 class LastLayer(nn.Module):
     def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
         super().__init__()