Skip to content

Commit 02d05ef

Browse files
authored
Update fastercache_sample_multi_device_opensoraplan.py
1 parent cbd8086 commit 02d05ef

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

scripts/opensora_plan/fastercache_sample_multi_device_opensoraplan.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def fastercache_model_forward(
132132
):
133133
self.counter+=1
134134
if self.counter >=40 and self.counter%6!=0:
135-
single_output = self.fastercache_model_single_forward(hidden_states[:1],timestep[:1],encoder_hidden_states[:1],added_cond_kwargs,class_labels,cross_attention_kwargs,attention_mask,encoder_attention_mask,use_image_num,enable_temporal_attentions,return_dict)[0]
135+
single_output = self.fastercache_model_single_forward(hidden_states[1:],timestep[1:],encoder_hidden_states[1:],added_cond_kwargs,class_labels,cross_attention_kwargs,attention_mask,encoder_attention_mask,use_image_num,enable_temporal_attentions,return_dict)[0]
136136
(bb, cc, tt, hh, ww) = single_output.shape
137137
cond = rearrange(single_output, "B C T H W -> (B T) C H W", B=bb, C=cc, T=tt, H=hh, W=ww)
138138
lf_c, hf_c = fft(cond.float())
@@ -149,13 +149,13 @@ def fastercache_model_forward(
149149
combined_fft = torch.fft.ifftshift(combine_uc)
150150
recovered_uncond = torch.fft.ifft2(combined_fft).real
151151
recovered_uncond = rearrange(recovered_uncond.to(single_output.dtype), "(B T) C H W -> B C T H W", B=bb, C=cc, T=tt, H=hh, W=ww)
152-
output = torch.cat([single_output,recovered_uncond],dim=0)
152+
output = torch.cat([recovered_uncond,single_output],dim=0)
153153
else:
154154
output = self.fastercache_model_single_forward(hidden_states,timestep,encoder_hidden_states,added_cond_kwargs,class_labels,cross_attention_kwargs,attention_mask,encoder_attention_mask,use_image_num,enable_temporal_attentions,return_dict)[0]
155155
if self.counter>38:
156156
(bb, cc, tt, hh, ww) = output.shape
157-
cond = rearrange(output[0:1], "B C T H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
158-
uncond = rearrange(output[1:2], "B C T H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
157+
cond = rearrange(output[1:2], "B C T H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
158+
uncond = rearrange(output[0:1], "B C T H W -> (B T) C H W", B=bb//2, C=cc, T=tt, H=hh, W=ww)
159159

160160
lf_c, hf_c = fft(cond.float())
161161
lf_uc, hf_uc = fft(uncond.float())

0 commit comments

Comments
 (0)