From 5c32d40168c0ab9a883b2d38897505bf8664ac46 Mon Sep 17 00:00:00 2001 From: GonzaloMartinGarcia <47854849+GonzaloMartinGarcia@users.noreply.github.com> Date: Mon, 28 Oct 2024 16:16:33 +0100 Subject: [PATCH] Update train.py --- training/train.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/training/train.py b/training/train.py index ceaaff5..edf8227 100644 --- a/training/train.py +++ b/training/train.py @@ -472,28 +472,21 @@ def unwrap_model(model): # RGB latent rgb_latents = encode_image(vae, batch["rgb"].to(device=accelerator.device, dtype=weight_dtype)) rgb_latents = rgb_latents * vae.config.scaling_factor - - # Depth or normals latent - if args.modality == "depth": - latents = encode_image(vae, batch["depth"].to(device=accelerator.device, dtype=weight_dtype)) - elif args.modality == "normals": - latents = encode_image(vae, batch["normals"].to(device=accelerator.device, dtype=weight_dtype)) - latents = latents * vae.config.scaling_factor - + # Validity mask val_mask = batch["val_mask"].bool().to(device=accelerator.device) # Set timesteps to the first denoising step - timesteps = torch.ones((latents.shape[0],), device=latents.device) * (noise_scheduler.config.num_train_timesteps-1) # 999 + timesteps = torch.ones((rgb_latents.shape[0],), device=rgb_latents.device) * (noise_scheduler.config.num_train_timesteps-1) # 999 timesteps = timesteps.long() # Sample noisy latent if (args.noise_type is None) or (args.noise_type == "zeros"): - noisy_latents = torch.zeros_like(latents).to(accelerator.device) + noisy_latents = torch.zeros_like(rgb_latents).to(accelerator.device) elif args.noise_type == "pyramid": - noisy_latents = pyramid_noise_like(latents).to(accelerator.device) + noisy_latents = pyramid_noise_like(rgb_latents).to(accelerator.device) elif args.noise_type == "gaussian": - noisy_latents = torch.randn_like(latents).to(accelerator.device) + noisy_latents = torch.randn_like(rgb_latents).to(accelerator.device) else: raise ValueError(f"Unknown noise type {args.noise_type}")