Mismatch in shape When vae.enable_gradient_checkpointing() #713

Open
1 of 2 tasks
CodingWZP opened this issue Feb 20, 2025 · 0 comments
CodingWZP commented Feb 20, 2025

System Info

I decode latents with the VAE, and I have enabled gradient checkpointing in the VAE.
When I call `loss_total.backward()`, I get this error:
```
Exception has occurred: RuntimeError
Mismatch in shape: grad_output[0] has a shape of torch.Size([1, 256, 1, 136, 240]) and output[0] has a shape of torch.Size([1, 256, 2, 136, 240]).
File "/high_perf_store3/world-model/ailab_vision/wangzepeng5/code/v1_5_pilot_weather_transfer/paper_codes/test.py", line 224, in
    loss_total.backward()
RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([1, 256, 1, 136, 240]) and output[0] has a shape of torch.Size([1, 256, 2, 136, 240]).
```

My code is below:
```python
import torch
# assumed: the diffusers implementation of the CogVideoX VAE (original import not shown)
from diffusers import AutoencoderKLCogVideoX

# LossComputer is defined in my own project code (import not shown)


def load_vae(pretrained_model_name_or_path):
    vae = AutoencoderKLCogVideoX.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="vae",
    )

    return vae


def decode_latents(latents: torch.Tensor, vae) -> torch.Tensor:
    # with torch.no_grad():  # disabled on purpose: gradients must flow through the VAE
    latents = latents.to(vae.dtype).to(vae.device)

    latents = 1 / vae.config.scaling_factor * latents

    frames = vae.decode(latents).sample
    frames = (frames / 2 + 0.5).clamp(0, 1)
    # no extra dtype cast here: weight_dtype is already float32 in this script
    return frames


if __name__ == "__main__":
    device = "cuda"
    weight_dtype = torch.float32
    pretrained_model_name_or_path = "pretrain_weights/CogVideoX-Fun-V1.5-5b-InP"
    clip_model = "pretrain_weights/clip/pretrain_weights/ViT-B-32.pt"
    vgg_model_path = "pretrain_weights/vgg/vgg19-dcbb9e9d.pth"
    batch = 1
    nf = 5
    h = 272
    w = 480

    vae = load_vae(pretrained_model_name_or_path).to(device).to(weight_dtype)
    vae.enable_gradient_checkpointing()

    loss_computer = LossComputer(
        device=device,
        clip_model=clip_model,
        vgg_model_path=vgg_model_path,
        content_mse_noise=False,
        content_mse=False,
        content_contrastive=False,
        style_clip_direction_global=False,
        style_clip_direction_patch=False,
        style_clip_align_global=True,
        time_continuous=False,
    )

    latents = torch.randn([batch, 16, ((nf - 1) // 4) + 1, h // 8, w // 8]).to(device).to(weight_dtype).requires_grad_(True)
    gen_pixel_values = decode_latents(latents, vae)
    # gen_pixel_values = simple_model(latents)
    tgt_text = ["driving in the daytime"] * batch

    loss_total = loss_computer(
        gen_pixel_values=gen_pixel_values,
        og_pixel_values=None,
        tgt_text=tgt_text,
        src_text=None,
        tgt_noise=None,
        pred_noise=None,
    )

    loss_total.backward()
```
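For reference, here is the shape arithmetic behind the `latents` tensor in the script, written as a small standalone helper (`latent_shape` is a hypothetical name; the factors 4 and 8 come from the script's own `(nf - 1) // 4 + 1` and `h // 8`). The tensors named in the error are 136×240, which is 4× the 34×60 latent grid but only half of the 272×480 output, so the mismatch appears to be raised at an intermediate decoder stage rather than at its output. That reading of the shapes is an inference, not something confirmed in this thread.

```python
def latent_shape(batch, nf, h, w, channels=16, t_ratio=4, s_ratio=8):
    """Hypothetical helper mirroring the script's torch.randn shape:
    (nf - 1) // 4 + 1 latent frames and an 8x spatial downsampling."""
    return [batch, channels, (nf - 1) // t_ratio + 1, h // s_ratio, w // s_ratio]


shape = latent_shape(batch=1, nf=5, h=272, w=480)
print(shape)  # [1, 16, 2, 34, 60]

# Spatial size of the tensors named in the error, relative to the latent grid.
err_h, err_w = 136, 240
ratio = (err_h // shape[3], err_w // shape[4])
print(ratio)  # (4, 4)
```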

I suspect that when gradient checkpointing is used, the "conv cache" in the VAE has already been deleted by the time the checkpointed blocks are recomputed during backward. Is that right?
Can anyone help?
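To make the hypothesis concrete, here is a self-contained toy in plain PyTorch (no diffusers involved; `CachedLayer` and `StatelessLayer` are hypothetical names, not parts of the CogVideoX VAE). It shows one way to get this exact error message: with reentrant checkpointing (`use_reentrant=True`), the forward pass is rerun during backward and `torch.autograd.backward` is called with the gradient saved from the original pass. If a layer's output shape depends on state that the first forward pass modified, the recomputed output no longer matches that gradient. Whether the VAE's conv cache behaves this way, and which checkpoint mode the VAE uses, is an assumption not confirmed here.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CachedLayer(nn.Module):
    """Hypothetical toy layer: its output length along the time axis (dim 2)
    depends on an internal cache that forward() overwrites as a side effect."""

    def __init__(self):
        super().__init__()
        self.cache = None

    def forward(self, x):
        if self.cache is None:
            out = x * 2.0  # first call: time axis unchanged
        else:
            out = torch.repeat_interleave(x, 2, dim=2)  # later calls: time axis doubled
        self.cache = x.detach()[:, :, -1:]  # state differs on the next call
        return out


class StatelessLayer(nn.Module):
    """Same toy computation, but the cache is an explicit argument, so the
    recomputation sees exactly the inputs the original forward pass saw."""

    def forward(self, x, cache):
        if cache is None:
            return x * 2.0
        return torch.repeat_interleave(x, 2, dim=2)


def backward_through_cached_layer():
    layer = CachedLayer()
    x = torch.randn(1, 4, 1, 2, 2, requires_grad=True)
    # Original pass runs with cache=None -> output has 1 frame.
    y = checkpoint(layer, x, use_reentrant=True)
    try:
        # Recompute runs with the cache set -> output has 2 frames.
        y.sum().backward()
        return "ok"
    except RuntimeError as err:
        return str(err)


def backward_through_stateless_layer():
    layer = StatelessLayer()
    x = torch.randn(1, 4, 1, 2, 2, requires_grad=True)
    y = checkpoint(layer, x, None, use_reentrant=True)
    y.sum().backward()
    return tuple(x.grad.shape)


print(backward_through_cached_layer())  # expected to start with "Mismatch in shape"
print(backward_through_stateless_layer())  # (1, 4, 1, 2, 2)
```

The stateless variant is one way to keep a checkpointed region deterministic: anything that changes the forward computation is passed in explicitly instead of being read from mutable module state.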

Information

  • The official example scripts
  • My own modified scripts

Reproduction

As above.

Expected behavior

Backpropagating through the VAE decoder with gradient checkpointing enabled should work without this shape mismatch.
