@@ -40,7 +40,7 @@ def forward(self, ids):
        n_axes = ids.shape[-1]
        emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
        return emb.unsqueeze(1)
-
+


class FluxJointAttention(torch.nn.Module):
@@ -70,7 +70,7 @@ def apply_rope(self, xq, xk, freqs_cis):
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

-    def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_kwargs_list=None):
+    def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        batch_size = hidden_states_a.shape[0]

        # Part A
@@ -91,7 +91,7 @@ def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_

        q, k = self.apply_rope(q, k, image_rotary_emb)

-        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
@@ -103,7 +103,7 @@ def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_
        else:
            hidden_states_b = self.b_to_out(hidden_states_b)
        return hidden_states_a, hidden_states_b
-
+


class FluxJointTransformerBlock(torch.nn.Module):
@@ -129,12 +129,12 @@ def __init__(self, dim, num_attention_heads):
        )


-    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
+    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
        norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)

        # Attention
-        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, ipadapter_kwargs_list)
+        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)

        # Part A
        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
@@ -147,7 +147,7 @@ def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipad
        hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)

        return hidden_states_a, hidden_states_b
-
+


class FluxSingleAttention(torch.nn.Module):
@@ -184,7 +184,7 @@ def forward(self, hidden_states, image_rotary_emb):
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        return hidden_states
-
+


class AdaLayerNormSingle(torch.nn.Module):
@@ -200,7 +200,7 @@ def forward(self, x, emb):
        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa
-
+


class FluxSingleTransformerBlock(torch.nn.Module):
@@ -225,8 +225,8 @@ def apply_rope(self, xq, xk, freqs_cis):
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

-
-    def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_list=None):
+
+    def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        batch_size = hidden_states.shape[0]

        qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
@@ -235,29 +235,29 @@ def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_li

        q, k = self.apply_rope(q, k, image_rotary_emb)

-        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        if ipadapter_kwargs_list is not None:
            hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
        return hidden_states


-    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
+    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        residual = hidden_states_a
        norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
        hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
        attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]

-        attn_output = self.process_attention(attn_output, image_rotary_emb, ipadapter_kwargs_list)
+        attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
        mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")

        hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
        hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
        hidden_states_a = residual + hidden_states_a
-
+
        return hidden_states_a, hidden_states_b
-
+


class AdaLayerNormContinuous(torch.nn.Module):
@@ -300,7 +300,7 @@ def patchify(self, hidden_states):
    def unpatchify(self, hidden_states, height, width):
        hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
        return hidden_states
-
+

    def prepare_image_ids(self, latents):
        batch_size, _, height, width = latents.shape
@@ -317,7 +317,7 @@ def prepare_image_ids(self, latents):
        latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)

        return latent_image_ids
-
+

    def tiled_forward(
        self,
@@ -337,12 +337,45 @@ def tiled_forward(
        )
        return hidden_states

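+    # Regional attention mask for entity control: the joint sequence is laid out as
+    # [prompt segment 1, ..., prompt segment N, image tokens]. Each prompt segment may only
+    # interact with the image tokens selected by its spatial mask, and different prompt
+    # segments are hidden from each other.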
+    def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
+        N = len(entity_masks)
+        batch_size = entity_masks[0].shape[0]
+        total_seq_len = N * prompt_seq_len + image_seq_len
+        patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
+        attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
+
+        image_start = N * prompt_seq_len
+        image_end = N * prompt_seq_len + image_seq_len
+        # prompt-image mask
+        for i in range(N):
+            prompt_start = i * prompt_seq_len
+            prompt_end = (i + 1) * prompt_seq_len
+            image_mask = torch.sum(patched_masks[i], dim=-1) > 0
+            image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
+            # prompt update with image
+            attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
+            # image update with prompt
+            attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
+        # prompt-prompt mask
+        for i in range(N):
+            for j in range(N):
+                if i != j:
+                    prompt_start_i = i * prompt_seq_len
+                    prompt_end_i = (i + 1) * prompt_seq_len
+                    prompt_start_j = j * prompt_seq_len
+                    prompt_end_j = (j + 1) * prompt_seq_len
+                    attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False
+
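+        # Convert the boolean mask to the additive form expected by scaled_dot_product_attention:
+        # allowed positions become 0.0, blocked positions become -inf.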
+        attention_mask = attention_mask.float()
+        attention_mask[attention_mask == 0] = float('-inf')
+        attention_mask[attention_mask == 1] = 0
+        return attention_mask

    def forward(
        self,
        hidden_states,
        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
-        tiled=False, tile_size=128, tile_stride=64,
+        tiled=False, tile_size=128, tile_stride=64, entity_prompts=None, entity_masks=None,
        use_gradient_checkpointing=False,
        **kwargs
    ):
@@ -353,54 +386,78 @@ def forward(
                tile_size=tile_size, tile_stride=tile_stride,
                **kwargs
            )
-
+
        if image_ids is None:
            image_ids = self.prepare_image_ids(hidden_states)
-
+
        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
        if self.guidance_embedder is not None:
            guidance = guidance * 1000
            conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
-        prompt_emb = self.context_embedder(prompt_emb)
-        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

+        repeat_dim = hidden_states.shape[1]
        height, width = hidden_states.shape[-2:]
        hidden_states = self.patchify(hidden_states)
        hidden_states = self.x_embedder(hidden_states)
-
+
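+        # Entity-level text control: each entity prompt is embedded separately and, through the
+        # attention mask built below, only interacts with the image tokens inside its spatial mask;
+        # the original prompt is appended last as a global prompt that covers the whole image.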
+        max_masks = 0
+        attention_mask = None
+        prompt_embs = [prompt_emb]
+        if entity_masks is not None:
+            # entity_masks
+            batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
+            entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
+            entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
+            # global mask
+            global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
+            entity_masks = entity_masks + [global_mask]  # append global to last
+            # attention mask
+            attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
+            attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
+            attention_mask = attention_mask.unsqueeze(1)
+            # embds: n_masks * b * seq * d
+            local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)]
+            prompt_embs = local_embs + prompt_embs  # append global to last
+        prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
+        prompt_emb = torch.cat(prompt_embs, dim=1)
+
+        # positional embedding
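+        # text_ids is tiled once per prompt segment (entities + global) so the rotary embedding
+        # ids cover the full concatenated prompt sequence.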
+        text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
+        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward
-
+
        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
-                    hidden_states, prompt_emb, conditioning, image_rotary_emb,
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
                    use_reentrant=False,
                )
            else:
-                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
+                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)

        hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
        for block in self.single_blocks:
            if self.training and use_gradient_checkpointing:
                hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
-                    hidden_states, prompt_emb, conditioning, image_rotary_emb,
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
                    use_reentrant=False,
                )
            else:
-                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
+                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
        hidden_states = hidden_states[:, prompt_emb.shape[1]:]

        hidden_states = self.final_norm_out(hidden_states, conditioning)
        hidden_states = self.final_proj_out(hidden_states)
        hidden_states = self.unpatchify(hidden_states, height, width)

        return hidden_states
-
+

    def quantize(self):
        def cast_to(weight, dtype=None, device=None, copy=False):
@@ -440,24 +497,24 @@ class quantized_layer:
            class Linear(torch.nn.Linear):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)
-
+
                def forward(self, input, **kwargs):
                    weight, bias = cast_bias_weight(self, input)
                    return torch.nn.functional.linear(input, weight, bias)
-
+
            class RMSNorm(torch.nn.Module):
                def __init__(self, module):
                    super().__init__()
                    self.module = module
-
+
                def forward(self, hidden_states, **kwargs):
                    weight = cast_weight(self.module, hidden_states)
                    input_dtype = hidden_states.dtype
                    variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
                    hidden_states = hidden_states.to(input_dtype) * weight
                    return hidden_states
-
+
        def replace_layer(model):
            for name, module in model.named_children():
                if isinstance(module, torch.nn.Linear):
@@ -483,7 +540,6 @@ def replace_layer(model):
    @staticmethod
    def state_dict_converter():
        return FluxDiTStateDictConverter()
-


class FluxDiTStateDictConverter:
@@ -587,7 +643,7 @@ def from_diffusers(self, state_dict):
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
        return state_dict_
-
+
    def from_civitai(self, state_dict):
        rename_dict = {
            "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",