System Info / 系統信息
I decode latents with the VAE, and gradient checkpointing is enabled in the VAE.
When I call `loss.backward()`, I get the following error:

    Exception has occurred: RuntimeError
    Mismatch in shape: grad_output[0] has a shape of torch.Size([1, 256, 1, 136, 240]) and output[0] has a shape of torch.Size([1, 256, 2, 136, 240]).
      File "/high_perf_store3/world-model/ailab_vision/wangzepeng5/code/v1_5_pilot_weather_transfer/paper_codes/test.py", line 224, in <module>
        loss_total.backward()
    RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([1, 256, 1, 136, 240]) and output[0] has a shape of torch.Size([1, 256, 2, 136, 240]).

My code is below:

    import torch

    def load_vae(pretrained_model_name_or_path):
        vae = AutoencoderKLCogVideoX.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="vae",
        )
        return vae  # assumed; the rest of this function is missing from the post

    def decode_latents(latents: torch.Tensor, vae) -> torch.Tensor:
        # with torch.no_grad():
        latents = latents.to(vae.dtype).to(vae.device)
        latents = 1 / vae.config.scaling_factor * latents
        frames = vae.decode(latents).sample
        frames = (frames / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        return frames

    if __name__ == "__main__":
        device = "cuda"
        weight_dtype = torch.float32
        pretrained_model_name_or_path = "pretrain_weights/CogVideoX-Fun-V1.5-5b-InP"
        clip_model = "pretrain_weights/clip/pretrain_weights/ViT-B-32.pt"
        vgg_model_path = "pretrain_weights/vgg/vgg19-dcbb9e9d.pth"
        batch = 1
        nf = 5
        h = 272
        w = 480

I suspect that when gradient checkpointing is used, the "conv cache" in the VAE has already been deleted by the time the forward pass is recomputed during backward. Is that the cause?
Can anyone help me?
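
To make the suspicion concrete, here is a toy, framework-free sketch of the mechanism. It is not the CogVideoX VAE code and it does not use PyTorch. `CachedTemporalLayer`, `checkpointed_forward`, `checkpointed_forward_with_snapshot`, and `backward` are all hypothetical names invented for illustration. The assumption is only that a layer keeps state between calls and that checkpointing runs its forward pass a second time during backward.

```python
"""Toy model of activation checkpointing with a stateful layer (no PyTorch needed)."""


class CachedTemporalLayer:
    """Hypothetical stand-in for a causal layer that caches the last frame it saw."""

    def __init__(self):
        self.cache = None  # last frame of the previous call, or None

    def forward(self, frames):
        # Prepend the cached frame (if any) as temporal context.
        out = list(frames) if self.cache is None else [self.cache, *frames]
        self.cache = frames[-1]  # side effect: state changes on every call
        return out


def checkpointed_forward(layer, frames):
    """Run forward once and return a closure that re-runs it during backward."""
    out = layer.forward(frames)

    def recompute():
        # Checkpointing recomputes activations instead of storing them.
        return layer.forward(frames)

    return out, recompute


def checkpointed_forward_with_snapshot(layer, frames):
    """Same as above, but restore the cache state before recomputing."""
    saved_cache = layer.cache
    out = layer.forward(frames)

    def recompute():
        layer.cache = saved_cache  # put the layer back into its pre-forward state
        return layer.forward(frames)

    return out, recompute


def backward(out, recompute):
    """The incoming gradient is shaped like `out`; the recomputed output must match."""
    grad_len, recomputed_len = len(out), len(recompute())
    if grad_len != recomputed_len:
        raise RuntimeError(
            f"Mismatch in shape: grad_output has length {grad_len} "
            f"and output has length {recomputed_len}"
        )
    return grad_len


def demo():
    layer = CachedTemporalLayer()
    out, recompute = checkpointed_forward(layer, ["frame0"])  # cache empty here
    try:
        backward(out, recompute)  # cache is populated now, so the length differs
    except RuntimeError as err:
        return str(err)
    return "no error"
```

In this toy, `demo()` hits the mismatch because the cache differs between the first forward pass and the recomputation, while the snapshot variant passes `backward` because the state is restored first. Whether the real VAE fails for the same reason, and whether restoring or explicitly passing the cache is the right fix there, would still need to be checked against the actual implementation.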
Information / 问题信息
Reproduction / 复现过程
As above.
Expected behavior / 期待表现
Backpropagating gradients through the VAE decoder with gradient checkpointing enabled should work without this shape mismatch.