@@ -132,7 +132,7 @@ def fastercache_model_forward(
132
132
):
133
133
self .counter += 1
134
134
if self .counter >= 40 and self .counter % 6 != 0 :
135
- single_output = self .fastercache_model_single_forward (hidden_states [: 1 ],timestep [: 1 ],encoder_hidden_states [: 1 ],added_cond_kwargs ,class_labels ,cross_attention_kwargs ,attention_mask ,encoder_attention_mask ,use_image_num ,enable_temporal_attentions ,return_dict )[0 ]
135
+ single_output = self .fastercache_model_single_forward (hidden_states [1 : ],timestep [1 : ],encoder_hidden_states [1 : ],added_cond_kwargs ,class_labels ,cross_attention_kwargs ,attention_mask ,encoder_attention_mask ,use_image_num ,enable_temporal_attentions ,return_dict )[0 ]
136
136
(bb , cc , tt , hh , ww ) = single_output .shape
137
137
cond = rearrange (single_output , "B C T H W -> (B T) C H W" , B = bb , C = cc , T = tt , H = hh , W = ww )
138
138
lf_c , hf_c = fft (cond .float ())
@@ -149,13 +149,13 @@ def fastercache_model_forward(
149
149
combined_fft = torch .fft .ifftshift (combine_uc )
150
150
recovered_uncond = torch .fft .ifft2 (combined_fft ).real
151
151
recovered_uncond = rearrange (recovered_uncond .to (single_output .dtype ), "(B T) C H W -> B C T H W" , B = bb , C = cc , T = tt , H = hh , W = ww )
152
- output = torch .cat ([single_output , recovered_uncond ],dim = 0 )
152
+ output = torch .cat ([recovered_uncond , single_output ],dim = 0 )
153
153
else :
154
154
output = self .fastercache_model_single_forward (hidden_states ,timestep ,encoder_hidden_states ,added_cond_kwargs ,class_labels ,cross_attention_kwargs ,attention_mask ,encoder_attention_mask ,use_image_num ,enable_temporal_attentions ,return_dict )[0 ]
155
155
if self .counter > 38 :
156
156
(bb , cc , tt , hh , ww ) = output .shape
157
- cond = rearrange (output [0 : 1 ], "B C T H W -> (B T) C H W" , B = bb // 2 , C = cc , T = tt , H = hh , W = ww )
158
- uncond = rearrange (output [1 : 2 ], "B C T H W -> (B T) C H W" , B = bb // 2 , C = cc , T = tt , H = hh , W = ww )
157
+ cond = rearrange (output [1 : 2 ], "B C T H W -> (B T) C H W" , B = bb // 2 , C = cc , T = tt , H = hh , W = ww )
158
+ uncond = rearrange (output [0 : 1 ], "B C T H W -> (B T) C H W" , B = bb // 2 , C = cc , T = tt , H = hh , W = ww )
159
159
160
160
lf_c , hf_c = fft (cond .float ())
161
161
lf_uc , hf_uc = fft (uncond .float ())
0 commit comments