
Commit e3d89ce

temp commit for entity control
1 parent 1b6e96a commit e3d89ce

File tree

5 files changed: +349 −44 lines changed

diffsynth/models/flux_dit.py

Lines changed: 92 additions & 36 deletions
@@ -40,7 +40,7 @@ def forward(self, ids):
         n_axes = ids.shape[-1]
         emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
         return emb.unsqueeze(1)
-    
+
 
 
 class FluxJointAttention(torch.nn.Module):
@@ -70,7 +70,7 @@ def apply_rope(self, xq, xk, freqs_cis):
         xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
         return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
 
-    def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_kwargs_list=None):
+    def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
         batch_size = hidden_states_a.shape[0]
 
         # Part A
@@ -91,7 +91,7 @@ def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_
 
         q, k = self.apply_rope(q, k, image_rotary_emb)
 
-        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
         hidden_states = hidden_states.to(q.dtype)
         hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
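
Note: the mask forwarded here is the additive float mask built by construct_mask (added further down in this diff), with 0 for allowed pairs and -inf for blocked pairs, unsqueezed to broadcast across heads. torch.nn.functional.scaled_dot_product_attention adds such a mask to the scaled scores before the softmax, which is equivalent to passing a boolean keep-mask. A minimal sketch of that equivalence, with toy shapes chosen purely for illustration (not from the commit):

    import torch
    import torch.nn.functional as F

    B, H, L, D = 1, 2, 6, 8                         # toy batch, heads, sequence, head_dim
    q = torch.randn(B, H, L, D)
    k = torch.randn(B, H, L, D)
    v = torch.randn(B, H, L, D)

    allowed = torch.ones(B, 1, L, L, dtype=torch.bool)   # (B, 1, L, L) broadcasts over heads
    allowed[..., :3, 3:] = False                          # block the first 3 queries from the last 3 keys

    additive = torch.zeros(B, 1, L, L)
    additive.masked_fill_(~allowed, float("-inf"))        # same 0 / -inf convention as construct_mask

    out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)
    out_add = F.scaled_dot_product_attention(q, k, v, attn_mask=additive)
    assert torch.allclose(out_bool, out_add, atol=1e-5)   # boolean and additive masks give the same result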
@@ -103,7 +103,7 @@ def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_
         else:
             hidden_states_b = self.b_to_out(hidden_states_b)
         return hidden_states_a, hidden_states_b
-    
+
 
 
 class FluxJointTransformerBlock(torch.nn.Module):
@@ -129,12 +129,12 @@ def __init__(self, dim, num_attention_heads):
         )
 
 
-    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
+    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
         norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
         norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
 
         # Attention
-        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, ipadapter_kwargs_list)
+        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
 
         # Part A
         hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
@@ -147,7 +147,7 @@ def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipad
         hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
 
         return hidden_states_a, hidden_states_b
-    
+
 
 
 class FluxSingleAttention(torch.nn.Module):
@@ -184,7 +184,7 @@ def forward(self, hidden_states, image_rotary_emb):
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
         hidden_states = hidden_states.to(q.dtype)
         return hidden_states
-    
+
 
 
 class AdaLayerNormSingle(torch.nn.Module):
@@ -200,7 +200,7 @@ def forward(self, x, emb):
         shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x, gate_msa
-    
+
 
 
 class FluxSingleTransformerBlock(torch.nn.Module):
@@ -225,8 +225,8 @@ def apply_rope(self, xq, xk, freqs_cis):
         xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
         return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
 
-    
-    def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_list=None):
+
+    def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
         batch_size = hidden_states.shape[0]
 
         qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
@@ -235,29 +235,29 @@ def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_li
 
         q, k = self.apply_rope(q, k, image_rotary_emb)
 
-        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
         hidden_states = hidden_states.to(q.dtype)
         if ipadapter_kwargs_list is not None:
             hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
         return hidden_states
 
 
-    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
+    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
         residual = hidden_states_a
         norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
         hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
         attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
 
-        attn_output = self.process_attention(attn_output, image_rotary_emb, ipadapter_kwargs_list)
+        attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
         mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
 
         hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
         hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
         hidden_states_a = residual + hidden_states_a
-    
+
         return hidden_states_a, hidden_states_b
-    
+
 
 
 class AdaLayerNormContinuous(torch.nn.Module):
class AdaLayerNormContinuous(torch.nn.Module):
@@ -300,7 +300,7 @@ def patchify(self, hidden_states):
300300
def unpatchify(self, hidden_states, height, width):
301301
hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
302302
return hidden_states
303-
303+
304304

305305
def prepare_image_ids(self, latents):
306306
batch_size, _, height, width = latents.shape
@@ -317,7 +317,7 @@ def prepare_image_ids(self, latents):
         latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
 
         return latent_image_ids
-    
+
 
     def tiled_forward(
         self,
@@ -337,12 +337,45 @@ def tiled_forward(
         )
         return hidden_states
 
+    def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
+        N = len(entity_masks)
+        batch_size = entity_masks[0].shape[0]
+        total_seq_len = N * prompt_seq_len + image_seq_len
+        patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
+        attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
+
+        image_start = N * prompt_seq_len
+        image_end = N * prompt_seq_len + image_seq_len
+        # prompt-image mask
+        for i in range(N):
+            prompt_start = i * prompt_seq_len
+            prompt_end = (i + 1) * prompt_seq_len
+            image_mask = torch.sum(patched_masks[i], dim=-1) > 0
+            image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
+            # prompt update with image
+            attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
+            # image update with prompt
+            attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
+        # prompt-prompt mask
+        for i in range(N):
+            for j in range(N):
+                if i != j:
+                    prompt_start_i = i * prompt_seq_len
+                    prompt_end_i = (i + 1) * prompt_seq_len
+                    prompt_start_j = j * prompt_seq_len
+                    prompt_end_j = (j + 1) * prompt_seq_len
+                    attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False
+
+        attention_mask = attention_mask.float()
+        attention_mask[attention_mask == 0] = float('-inf')
+        attention_mask[attention_mask == 1] = 0
+        return attention_mask
 
     def forward(
         self,
         hidden_states,
         timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
-        tiled=False, tile_size=128, tile_stride=64,
+        tiled=False, tile_size=128, tile_stride=64, entity_prompts=None, entity_masks=None,
         use_gradient_checkpointing=False,
         **kwargs
     ):
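
The new construct_mask builds one joint mask over the concatenated [entity prompts, global prompt, image] sequence: each prompt block attends to itself and to the image patches its entity mask covers, image patches attend back only to the prompts that cover them, different prompt blocks are blocked from each other, and image-image attention is left unrestricted. A minimal standalone sketch of that block layout (the sizes P, I and the coverage table below are toy values for illustration, not from the commit):

    import torch

    P, I = 2, 3                                        # prompt length and image sequence length (toy)
    coverage = torch.tensor([[True, True, False],      # prompt 0 covers image patches 0 and 1
                             [True, True, True]])      # prompt 1 (e.g. the appended global prompt) covers all
    N = coverage.shape[0]
    L = N * P + I

    mask = torch.ones(L, L, dtype=torch.bool)          # everything allowed by default
    for i in range(N):
        rows = slice(i * P, (i + 1) * P)
        mask[rows, N * P:] = coverage[i]               # prompt i -> image: only covered patches
        mask[N * P:, rows] = coverage[i][:, None]      # image -> prompt i: only covered patches
        for j in range(N):
            if i != j:
                mask[rows, j * P:(j + 1) * P] = False  # block cross-prompt attention

    additive = torch.zeros(L, L).masked_fill(~mask, float("-inf"))  # same 0 / -inf convention
    print(mask.int())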
@@ -353,54 +386,78 @@ def forward(
             tile_size=tile_size, tile_stride=tile_stride,
             **kwargs
         )
-        
+
         if image_ids is None:
             image_ids = self.prepare_image_ids(hidden_states)
-        
+
         conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
         if self.guidance_embedder is not None:
             guidance = guidance * 1000
             conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
-        prompt_emb = self.context_embedder(prompt_emb)
-        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
 
+        repeat_dim = hidden_states.shape[1]
         height, width = hidden_states.shape[-2:]
         hidden_states = self.patchify(hidden_states)
         hidden_states = self.x_embedder(hidden_states)
-        
+
+        max_masks = 0
+        attention_mask = None
+        prompt_embs = [prompt_emb]
+        if entity_masks is not None:
+            # entity_masks
+            batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
+            entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
+            entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
+            # global mask
+            global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
+            entity_masks = entity_masks + [global_mask] # append global to last
+            # attention mask
+            attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
+            attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
+            attention_mask = attention_mask.unsqueeze(1)
+            # embds: n_masks * b * seq * d
+            local_embs = [entity_prompts[:, i, None].squeeze(1) for i in range(max_masks)]
+            prompt_embs = local_embs + prompt_embs # append global to last
+        prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
+        prompt_emb = torch.cat(prompt_embs, dim=1)
+
+        # positional embedding
+        text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
+        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
+
         def create_custom_forward(module):
             def custom_forward(*inputs):
                 return module(*inputs)
             return custom_forward
-        
+
         for block in self.blocks:
             if self.training and use_gradient_checkpointing:
                 hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
-                    hidden_states, prompt_emb, conditioning, image_rotary_emb,
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
                     use_reentrant=False,
                 )
             else:
-                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
+                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
 
         hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
         for block in self.single_blocks:
             if self.training and use_gradient_checkpointing:
                 hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
-                    hidden_states, prompt_emb, conditioning, image_rotary_emb,
+                    hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
                     use_reentrant=False,
                 )
            else:
-                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
+                hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
         hidden_states = hidden_states[:, prompt_emb.shape[1]:]
 
         hidden_states = self.final_norm_out(hidden_states, conditioning)
         hidden_states = self.final_proj_out(hidden_states)
         hidden_states = self.unpatchify(hidden_states, height, width)
 
         return hidden_states
-    
+

     def quantize(self):
         def cast_to(weight, dtype=None, device=None, copy=False):
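
A hedged sketch of how the new forward() arguments plausibly line up with the code above; the calling code is not shown in this file, so every concrete size below (C, H, W, T, D) is an assumption chosen for illustration only:

    import torch

    B = 1
    C, H, W = 16, 128, 128          # latent channels / latent height / latent width (assumed)
    T, D = 512, 4096                # text sequence length and prompt embedding width (assumed)
    n_entities = 2

    hidden_states = torch.randn(B, C, H, W)            # latent image fed to forward()
    prompt_emb = torch.randn(B, T, D)                   # global prompt embedding
    entity_prompts = torch.randn(B, n_entities, T, D)   # one embedding per entity prompt
    entity_masks = torch.zeros(B, n_entities, 1, H, W)  # binary region per entity at latent resolution
    entity_masks[:, 0, :, :, :64] = 1.0                 # entity 0 -> left half
    entity_masks[:, 1, :, :, 64:] = 1.0                 # entity 1 -> right half

    # Inside forward(), entity_masks is repeated along the channel dim (repeat_dim = C), split per
    # entity, an all-ones global mask is appended, and construct_mask() turns the list into a
    # (B, 1, (n_entities + 1) * T + image_seq, (n_entities + 1) * T + image_seq) additive mask.
    # text_ids is tiled (max_masks + 1) times so every prompt block reuses the same text positions.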
@@ -440,24 +497,24 @@ class quantized_layer:
             class Linear(torch.nn.Linear):
                 def __init__(self, *args, **kwargs):
                     super().__init__(*args, **kwargs)
-                
+
                 def forward(self,input,**kwargs):
                     weight,bias= cast_bias_weight(self,input)
                     return torch.nn.functional.linear(input,weight,bias)
-            
+
             class RMSNorm(torch.nn.Module):
                 def __init__(self, module):
                     super().__init__()
                     self.module = module
-                
+
                 def forward(self,hidden_states,**kwargs):
                     weight= cast_weight(self.module,hidden_states)
                     input_dtype = hidden_states.dtype
                     variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
                     hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
                     hidden_states = hidden_states.to(input_dtype) * weight
                     return hidden_states
-            
+
             def replace_layer(model):
                 for name, module in model.named_children():
                     if isinstance(module, torch.nn.Linear):
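
For reference, the quantized RMSNorm wrapper above computes root-mean-square normalization around the casted weight. A minimal sketch of the same computation as a plain function, not from the commit; the eps default here is an assumed value standing in for module.eps:

    import torch

    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        input_dtype = x.dtype
        variance = x.to(torch.float32).square().mean(-1, keepdim=True)  # mean of squares over the last dim
        x = x * torch.rsqrt(variance + eps)                             # divide by the root mean square
        return x.to(input_dtype) * weight                               # rescale by the learned weight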
@@ -483,7 +540,6 @@ def replace_layer(model):
     @staticmethod
     def state_dict_converter():
         return FluxDiTStateDictConverter()
-    
 
 
 class FluxDiTStateDictConverter:
@@ -587,7 +643,7 @@ def from_diffusers(self, state_dict):
             state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
             state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
         return state_dict_
-    
+
     def from_civitai(self, state_dict):
         rename_dict = {
             "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
