Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
GonzaloMartinGarcia committed Oct 28, 2024
1 parent c86bb06 commit 5c32d40
Showing 1 changed file with 5 additions and 12 deletions.
17 changes: 5 additions & 12 deletions training/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -472,28 +472,21 @@ def unwrap_model(model):
# RGB latent
rgb_latents = encode_image(vae, batch["rgb"].to(device=accelerator.device, dtype=weight_dtype))
rgb_latents = rgb_latents * vae.config.scaling_factor

- # Depth or normals latent
- if args.modality == "depth":
- latents = encode_image(vae, batch["depth"].to(device=accelerator.device, dtype=weight_dtype))
- elif args.modality == "normals":
- latents = encode_image(vae, batch["normals"].to(device=accelerator.device, dtype=weight_dtype))
- latents = latents * vae.config.scaling_factor


# Validity mask
val_mask = batch["val_mask"].bool().to(device=accelerator.device)

# Set timesteps to the first denoising step
- timesteps = torch.ones((latents.shape[0],), device=latents.device) * (noise_scheduler.config.num_train_timesteps-1) # 999
+ timesteps = torch.ones((rgb_latents.shape[0],), device=rgb_latents.device) * (noise_scheduler.config.num_train_timesteps-1) # 999
timesteps = timesteps.long()

# Sample noisy latent
if (args.noise_type is None) or (args.noise_type == "zeros"):
- noisy_latents = torch.zeros_like(latents).to(accelerator.device)
+ noisy_latents = torch.zeros_like(rgb_latents).to(accelerator.device)
elif args.noise_type == "pyramid":
- noisy_latents = pyramid_noise_like(latents).to(accelerator.device)
+ noisy_latents = pyramid_noise_like(rgb_latents).to(accelerator.device)
elif args.noise_type == "gaussian":
- noisy_latents = torch.randn_like(latents).to(accelerator.device)
+ noisy_latents = torch.randn_like(rgb_latents).to(accelerator.device)
else:
raise ValueError(f"Unknown noise type {args.noise_type}")

Expand Down

0 comments on commit 5c32d40

Please sign in to comment.