diff --git a/config/train_marigold_depth_LCM.yaml b/config/train_marigold_depth_LCM.yaml
new file mode 100644
index 0000000..3a43f2f
--- /dev/null
+++ b/config/train_marigold_depth_LCM.yaml
@@ -0,0 +1,17 @@
+base_config:
+  - config/train_marigold_depth.yaml
+
+
+trainer:
+  name: MarigoldDepthLCMTrainer
+  training_noise_scheduler:
+    pretrained_path: stable-diffusion-2
+  init_seed: 2024  # use null to train w/o seeding
+  save_period: 50
+  backup_period: 2000
+  validation_period: 500
+  visualization_period: 1000
+  ema_decay: 0.95  # LCM-specific setting
+  num_ddim_timesteps: 5  # LCM-specific setting
+  save_target: false  # LCM-specific setting
+
diff --git a/src/trainer/__init__.py b/src/trainer/__init__.py
index fbee2de..516b0c3 100644
--- a/src/trainer/__init__.py
+++ b/src/trainer/__init__.py
@@ -31,14 +31,16 @@
 from .marigold_depth_trainer import MarigoldDepthTrainer
 from .marigold_iid_trainer import MarigoldIIDTrainer
 from .marigold_normals_trainer import MarigoldNormalsTrainer
+from .marigold_depth_LCM_trainer import MarigoldDepthLCMTrainer
 
 trainer_cls_name_dict = {
     "MarigoldDepthTrainer": MarigoldDepthTrainer,
     "MarigoldIIDTrainer": MarigoldIIDTrainer,
     "MarigoldNormalsTrainer": MarigoldNormalsTrainer,
+    "MarigoldDepthLCMTrainer": MarigoldDepthLCMTrainer,
 }
 
 
 def get_trainer_cls(trainer_name):
     return trainer_cls_name_dict[trainer_name]
 
diff --git a/src/trainer/marigold_depth_LCM_trainer.py b/src/trainer/marigold_depth_LCM_trainer.py
new file mode 100644
index 0000000..ff286f0
--- /dev/null
+++ b/src/trainer/marigold_depth_LCM_trainer.py
@@ -0,0 +1,855 @@
+# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# --------------------------------------------------------------------------
+# More information about Marigold:
+# https://marigoldmonodepth.github.io
+# https://marigoldcomputervision.github.io
+# Efficient inference pipelines are now part of diffusers:
+# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
+# https://huggingface.co/docs/diffusers/api/pipelines/marigold
+# Examples of trained models and live demos:
+# https://huggingface.co/prs-eth
+# Related projects:
+# https://rollingdepth.github.io/
+# https://marigolddepthcompletion.github.io/
+# Citation (BibTeX):
+# https://github.com/prs-eth/Marigold#-citation
+# If you find Marigold useful, we kindly ask you to cite our papers.
+# --------------------------------------------------------------------------
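+
+# This trainer distills the Marigold depth estimator into a Latent Consistency
+# Model (LCM). Sketch of the setup (see the training loop below): a trainable
+# student UNet is supervised to match an EMA "target" UNet that is evaluated
+# one DDIM solver step earlier on the denoising trajectory of a frozen teacher
+# UNet, following the LCM distillation recipe.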
+
+import logging
+import os
+import shutil
+from datetime import datetime
+from typing import List, Union
+
+import numpy as np
+import torch
+from diffusers import DDPMScheduler, LCMScheduler, UNet2DConditionModel
+from omegaconf import OmegaConf
+from PIL import Image
+from safetensors.torch import load_file
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import LambdaLR
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from marigold.marigold_pipeline import MarigoldDepthOutput, MarigoldDepthPipeline
+from src.trainer.marigold_depth_trainer import MarigoldDepthTrainer
+from src.util import metric
+from src.util.alignment import (
+    align_depth_least_square,
+    depth2disparity,
+    disparity2depth,
+    median_scale_align,
+)
+from src.util.data_loader import skip_first_batches
+from src.util.logging_util import eval_dic_to_text, tb_logger
+from src.util.loss import get_loss
+from src.util.lr_scheduler import IterExponential
+from src.util.metric import MetricTracker
+from src.util.seeding import generate_seed_sequence
+
+
+def extract_into_tensor(a, t, x_shape):
+    """Gather values from the 1-D tensor `a` at indices `t` and reshape them to
+    broadcast over a tensor of shape `x_shape`."""
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+class DDIMSolver:
+    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
+        # DDIM sampling parameters
+        step_ratio = timesteps // ddim_timesteps
+        self.ddim_timesteps = (
+            (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
+        )
+        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
+        self.ddim_alpha_cumprods_prev = np.asarray(
+            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
+        )
+        # Convert to torch tensors
+        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
+        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
+        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
+
+    def to(self, device):
+        self.ddim_timesteps = self.ddim_timesteps.to(device)
+        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
+        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
+        return self
+
+    def ddim_step(self, pred_x0, pred_noise, timestep_index):
+        # Deterministic DDIM update:
+        # x_{t_n} = sqrt(alpha_bar_{t_n}) * x_0 + sqrt(1 - alpha_bar_{t_n}) * eps
+        alpha_cumprod_prev = extract_into_tensor(
+            self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape
+        )
+        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
+        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
+        return x_prev
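+
+# Note: with num_train_timesteps=1000 and the num_ddim_timesteps=5 used in this
+# trainer's config, step_ratio is 200 and the solver timesteps come out as
+# [199, 399, 599, 799, 999], i.e. the trailing "skipping-step" schedule of LCM.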
+
+# From LCMScheduler.get_scalings_for_boundary_condition_discrete
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
+    scaled_timestep = timestep_scaling * timestep
+    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+    return c_skip, c_out
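+
+# At timestep 0 these scalings give c_skip = 1 and c_out = 0, so the
+# consistency parameterization f(x, t) = c_skip * x + c_out * F(x, t)
+# satisfies the boundary condition f(x, 0) = x by construction.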
+
+# Compare LCMScheduler.step, Step 4
+def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+    if prediction_type == "epsilon":
+        pred_x_0 = (sample - sigmas * model_output) / alphas
+    elif prediction_type == "sample":
+        pred_x_0 = model_output
+    elif prediction_type == "v_prediction":
+        pred_x_0 = alphas * sample - sigmas * model_output
+    else:
+        raise ValueError(
+            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`,"
+            " and `v_prediction` are supported."
+        )
+
+    return pred_x_0
+
+
+def append_dims(x, target_dims):
+    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+    dims_to_append = target_dims - x.ndim
+    if dims_to_append < 0:
+        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+    return x[(...,) + (None,) * dims_to_append]
+
+
+def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+    if prediction_type == "epsilon":
+        pred_epsilon = model_output
+    elif prediction_type == "sample":
+        pred_epsilon = (sample - alphas * model_output) / sigmas
+    elif prediction_type == "v_prediction":
+        pred_epsilon = alphas * model_output + sigmas * sample
+    else:
+        raise ValueError(
+            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`,"
+            " and `v_prediction` are supported."
+        )
+
+    return pred_epsilon
+
+
+@torch.no_grad()
+def update_ema(target_params, source_params, rate=0.99):
+    """
+    Update target parameters to be closer to those of source parameters using
+    an exponential moving average.
+
+    :param target_params: the target parameter sequence.
+    :param source_params: the source parameter sequence.
+    :param rate: the EMA rate (closer to 1 means slower).
+    """
+    for targ, src in zip(target_params, source_params):
+        targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+def huber_loss(a, b, huber_c=0.001):
+    return torch.mean(torch.sqrt((a.float() - b.float()) ** 2 + huber_c**2) - huber_c)
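+
+# huber_loss is the pseudo-Huber loss used for LCM distillation: approximately
+# squared (L2) for residuals much smaller than huber_c and approximately L1 for
+# larger residuals, while remaining smooth everywhere.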
+
+
+class MarigoldDepthLCMTrainer(MarigoldDepthTrainer):
+    def __init__(
+        self,
+        cfg: OmegaConf,
+        model: MarigoldDepthPipeline,
+        train_dataloader: DataLoader,
+        device,
+        base_ckpt_dir,
+        out_dir_ckpt,
+        out_dir_eval,
+        out_dir_vis,
+        accumulation_steps: int,
+        val_dataloaders: List[DataLoader] = None,
+        vis_dataloaders: List[DataLoader] = None,
+    ):
+        self.cfg: OmegaConf = cfg
+        self.model: MarigoldDepthPipeline = model
+        self.device = device
+        self.seed: Union[int, None] = (
+            self.cfg.trainer.init_seed
+        )  # used to generate seed sequence, set to `None` to train w/o seeding
+        self.out_dir_ckpt = out_dir_ckpt
+        self.out_dir_eval = out_dir_eval
+        self.out_dir_vis = out_dir_vis
+        self.train_loader: DataLoader = train_dataloader
+        self.val_loaders: List[DataLoader] = val_dataloaders
+        self.vis_loaders: List[DataLoader] = vis_dataloaders
+        self.accumulation_steps: int = accumulation_steps
+
+        # Load models: the student (self.model.unet) is trained, the teacher is
+        # frozen, and the target is updated only through EMA of the student
+        self.model.scheduler = LCMScheduler.from_pretrained(
+            cfg.model.pretrained_path, subfolder="scheduler"
+        )
+
+        self.teacher_unet = UNet2DConditionModel.from_pretrained(
+            cfg.model.pretrained_path, subfolder="unet"
+        )
+
+        self.target_unet = UNet2DConditionModel(**self.teacher_unet.config)
+        self.target_unet.load_state_dict(self.model.unet.state_dict())
+
+        # Encode empty text prompt
+        self.model.encode_empty_text()
+        self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)
+
+        self.model.unet.enable_xformers_memory_efficient_attention()
+        self.teacher_unet.enable_xformers_memory_efficient_attention()
+        self.target_unet.enable_xformers_memory_efficient_attention()
+
+        # Trainability
+        self.model.vae.requires_grad_(False)
+        self.model.text_encoder.requires_grad_(False)
+        self.model.unet.requires_grad_(True)
+        self.teacher_unet.requires_grad_(False)
+        self.target_unet.requires_grad_(False)
+
+        # Optimizer (note: must be created after any input-layer adaptation)
+        lr = self.cfg.lr
+        self.optimizer = AdamW(self.model.unet.parameters(), lr=lr)
+
+        # LR scheduler
+        lr_func = IterExponential(
+            total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter,
+            final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio,
+            warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps,
+        )
+        self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)
+
+        # Loss
+        self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)
+
+        # Training noise scheduler
+        self.training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_config(
+            self.model.scheduler.config,
+            rescale_betas_zero_snr=True,
+            timestep_spacing="trailing",
+        )
+        logging.info(
+            "DDPM training noise scheduler config is updated: "
+            f"rescale_betas_zero_snr = {self.training_noise_scheduler.config.rescale_betas_zero_snr}, "
+            f"timestep_spacing = {self.training_noise_scheduler.config.timestep_spacing}"
+        )
+
+        self.prediction_type = self.training_noise_scheduler.config.prediction_type
+        assert (
+            self.prediction_type == self.model.scheduler.config.prediction_type
+        ), "Different prediction types"
+        self.scheduler_timesteps = (
+            self.training_noise_scheduler.config.num_train_timesteps
+        )
+
+        # Eval metrics
+        self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics]
+        self.train_metrics = MetricTracker(*["loss"])
+        self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs])
+        # Main metric for best checkpoint saving
+        self.main_val_metric = cfg.validation.main_val_metric
+        self.main_val_metric_goal = cfg.validation.main_val_metric_goal
+        assert (
+            self.main_val_metric in cfg.eval.eval_metrics
+        ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."
+        self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8
+
+        # Settings
+        self.max_epoch = self.cfg.max_epoch
+        self.max_iter = self.cfg.max_iter
+        self.gradient_accumulation_steps = accumulation_steps
+        self.gt_depth_type = self.cfg.gt_depth_type
+        self.gt_mask_type = self.cfg.gt_mask_type
+        self.save_period = self.cfg.trainer.save_period
+        self.backup_period = self.cfg.trainer.backup_period
+        self.val_period = self.cfg.trainer.validation_period
+        self.vis_period = self.cfg.trainer.visualization_period
+
+        # Multi-resolution noise is not used for LCM distillation
+        self.apply_multi_res_noise = False
+
+        # Internal variables
+        self.epoch = 1
+        self.n_batch_in_epoch = 0  # batch index in the epoch, used when resuming training
+        self.effective_iter = 0  # how many times optimizer.step() is called
+        self.in_evaluation = False
+        self.global_seed_sequence: List = []  # consistent global seed sequence, used to seed the random generator when resuming
+
+        # LCM parameters
+        self.ema_decay = cfg.trainer.ema_decay
+        self.num_ddim_timesteps = cfg.trainer.num_ddim_timesteps
+        self.max_grad_norm = 1.0
+        self.huber_c = 0.001
+        self.timestep_scaling_factor = 10
+        self.solver = DDIMSolver(
+            self.training_noise_scheduler.alphas_cumprod.numpy(),
+            timesteps=self.training_noise_scheduler.config.num_train_timesteps,
+            ddim_timesteps=self.num_ddim_timesteps,
+        ).to(self.device)
+        self.alpha_schedule = torch.sqrt(self.training_noise_scheduler.alphas_cumprod).to(self.device)
+        self.sigma_schedule = torch.sqrt(1 - self.training_noise_scheduler.alphas_cumprod).to(self.device)
+        self.save_target = cfg.trainer.save_target
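+
+    # One distillation step, in outline: sample t_{n+k} from the DDIM solver
+    # schedule, run the student at t_{n+k}, let the frozen teacher take one DDIM
+    # step from t_{n+k} down to t_n, run the EMA target at t_n, and minimize the
+    # pseudo-Huber distance between the two consistency predictions.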
+    def train(self, t_end=None):
+        logging.info("Start training")
+
+        device = self.device
+        self.model.to(device)
+        self.teacher_unet.to(device)
+        self.target_unet.to(device)
+
+        if self.in_evaluation:
+            logging.info(
+                "Last evaluation was not finished, will do evaluation before continuing training."
+            )
+            self.validate()
+
+        self.train_metrics.reset()
+        accumulated_step = 0
+
+        for epoch in range(self.epoch, self.max_epoch + 1):
+            self.epoch = epoch
+            logging.debug(f"epoch: {self.epoch}")
+
+            # Skip previous batches when resuming
+            for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
+                self.model.unet.train()
+                self.target_unet.train()
+                self.target_unet.requires_grad_(False)
+
+                # Globally consistent random generators
+                if self.seed is not None:
+                    local_seed = self._get_next_seed()
+                    rand_num_generator = torch.Generator(device=device)
+                    rand_num_generator.manual_seed(local_seed)
+                else:
+                    rand_num_generator = None
+
+                # >>> With gradient accumulation >>>
+
+                # Get data
+                rgb = batch["rgb_norm"].to(device)
+                depth_gt_for_latent = batch[self.gt_depth_type].to(device)
+
+                if self.gt_mask_type is not None:
+                    valid_mask_for_latent = batch[self.gt_mask_type].to(device)
+                    invalid_mask = ~valid_mask_for_latent
+                    valid_mask_down = ~torch.max_pool2d(
+                        invalid_mask.float(), 8, 8
+                    ).bool()
+                    valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1))
+                else:
+                    raise NotImplementedError
+
+                batch_size = rgb.shape[0]
+
+                with torch.no_grad():
+                    # Encode image
+                    rgb_latent = self.encode_rgb(rgb)  # [B, 4, h, w]
+                    # Encode GT depth
+                    gt_depth_latent = self.encode_depth(
+                        depth_gt_for_latent
+                    )  # [B, 4, h, w]
+
+                # Sample a random timestep t_{n+k} for each image from the ODE
+                # solver timesteps, without bias. For the DDIM solver, the
+                # timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
+                topk = self.training_noise_scheduler.config.num_train_timesteps // self.num_ddim_timesteps
+                index = torch.randint(0, self.num_ddim_timesteps, (batch_size,), device=rgb_latent.device).long()
+                start_timesteps = self.solver.ddim_timesteps[index].to(device)
+                timesteps = start_timesteps - topk
+                timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
+
+                # Boundary-condition scalings at t_{n+k} (student) and t_n (target)
+                c_skip_start, c_out_start = scalings_for_boundary_conditions(
+                    start_timesteps, timestep_scaling=self.timestep_scaling_factor
+                )
+                c_skip_start, c_out_start = [append_dims(x, rgb_latent.ndim) for x in [c_skip_start, c_out_start]]
+                c_skip, c_out = scalings_for_boundary_conditions(
+                    timesteps, timestep_scaling=self.timestep_scaling_factor
+                )
+                c_skip, c_out = [append_dims(x, rgb_latent.ndim) for x in [c_skip, c_out]]
+
+                # Sample noise
+                noise = torch.randn(
+                    gt_depth_latent.shape,
+                    device=device,
+                    generator=rand_num_generator,
+                )  # [B, 4, h, w]
+
+                # Add noise to the latents (diffusion forward process)
+                noisy_latents = self.training_noise_scheduler.add_noise(
+                    gt_depth_latent, noise, start_timesteps
+                )  # [B, 4, h, w]
+
+                # Text embedding
+                text_embed = self.empty_text_embed.to(device).repeat(
+                    (batch_size, 1, 1)
+                )  # [B, 77, 1024]
+
+                # Concatenate rgb and depth latents
+                cat_latents = torch.cat(
+                    [rgb_latent, noisy_latents], dim=1
+                )  # [B, 8, h, w]
+                cat_latents = cat_latents.float()
+
+                # Student UNet prediction at t_{n+k}
+                model_pred = self.model.unet(
+                    cat_latents,
+                    start_timesteps,
+                    encoder_hidden_states=text_embed,
+                ).sample  # [B, 4, h, w]
+
+                if torch.isnan(model_pred).any():
+                    logging.warning("model_pred contains NaN.")
+
+                pred_x_0 = get_predicted_original_sample(
+                    model_pred,
+                    start_timesteps,
+                    noisy_latents,
+                    self.prediction_type,
+                    self.alpha_schedule,
+                    self.sigma_schedule,
+                )  # [B, 4, h, w]
+
+                prediction = c_skip_start * noisy_latents + c_out_start * pred_x_0  # [B, 4, h, w]
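+
+                # `prediction` is the student's consistency estimate
+                # f_theta(z_{t_{n+k}}, t_{n+k}); the target computed below is the
+                # EMA network's estimate f_theta^-(z_{t_n}, t_n), one solver step
+                # further along the teacher's ODE trajectory.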
+                with torch.no_grad():
+                    # Get the teacher model prediction on the noisy input
+                    # z_{t_{n+k}} with the conditional embedding c
+                    cond_teacher_output = self.teacher_unet(
+                        cat_latents,
+                        start_timesteps,
+                        encoder_hidden_states=text_embed,
+                    ).sample  # [B, 4, h, w]
+                    teacher_pred_x0 = get_predicted_original_sample(
+                        cond_teacher_output,
+                        start_timesteps,
+                        noisy_latents,
+                        self.prediction_type,
+                        self.alpha_schedule,
+                        self.sigma_schedule,
+                    )  # [B, 4, h, w]
+                    teacher_pred_noise = get_predicted_noise(
+                        cond_teacher_output,
+                        start_timesteps,
+                        noisy_latents,
+                        self.prediction_type,
+                        self.alpha_schedule,
+                        self.sigma_schedule,
+                    )  # [B, 4, h, w]
+
+                    # One DDIM step of the teacher: z_{t_{n+k}} -> z_{t_n}
+                    x_prev = self.solver.ddim_step(teacher_pred_x0, teacher_pred_noise, index).to(device)  # [B, 4, h, w]
+
+                with torch.no_grad():
+                    model_input_prev = torch.cat([rgb_latent, x_prev], dim=1)  # [B, 8, h, w]
+                    model_input_prev = model_input_prev.float()
+
+                    target_noise_pred = self.target_unet(
+                        model_input_prev,
+                        timesteps,
+                        encoder_hidden_states=text_embed,
+                    ).sample  # [B, 4, h, w]
+                    target_pred_x0 = get_predicted_original_sample(
+                        target_noise_pred,
+                        timesteps,
+                        x_prev,
+                        self.prediction_type,
+                        self.alpha_schedule,
+                        self.sigma_schedule,
+                    )  # [B, 4, h, w]
+                    target = c_skip * x_prev + c_out * target_pred_x0  # [B, 4, h, w]
+
+                if self.gt_mask_type is not None:
+                    loss = huber_loss(
+                        prediction[valid_mask_down],
+                        target[valid_mask_down],
+                        huber_c=self.huber_c,
+                    )
+                else:
+                    loss = huber_loss(prediction, target, huber_c=self.huber_c)
+
+                self.train_metrics.update("loss", loss.item())
+
+                loss = loss / self.gradient_accumulation_steps
+                loss.backward()
+                accumulated_step += 1
+
+                self.n_batch_in_epoch += 1
+                # Practical batch end
+
+                # Perform optimization step
+                if accumulated_step >= self.gradient_accumulation_steps:
+                    torch.nn.utils.clip_grad_norm_(self.model.unet.parameters(), self.max_grad_norm)
+                    self.optimizer.step()
+                    self.lr_scheduler.step()
+                    self.optimizer.zero_grad()
+
+                    # EMA update of the target UNet towards the student
+                    update_ema(self.target_unet.parameters(), self.model.unet.parameters(), self.ema_decay)
+
+                    accumulated_step = 0
+
+                    self.effective_iter += 1
+
+                    # Log to tensorboard
+                    accumulated_loss = self.train_metrics.result()["loss"]
+                    tb_logger.log_dic(
+                        {
+                            f"train/{k}": v
+                            for k, v in self.train_metrics.result().items()
+                        },
+                        global_step=self.effective_iter,
+                    )
+                    tb_logger.writer.add_scalar(
+                        "lr",
+                        self.lr_scheduler.get_last_lr()[0],
+                        global_step=self.effective_iter,
+                    )
+                    tb_logger.writer.add_scalar(
+                        "n_batch_in_epoch",
+                        self.n_batch_in_epoch,
+                        global_step=self.effective_iter,
+                    )
+                    logging.info(
+                        f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}"
+                    )
+                    self.train_metrics.reset()
+
+                    # Per-step callback
+                    self._train_step_callback()
+
+                    # End of training
+                    if self.max_iter > 0 and self.effective_iter >= self.max_iter:
+                        self.save_checkpoint(
+                            ckpt_name=self._get_backup_ckpt_name(),
+                            save_train_state=False,
+                        )
+                        logging.info("Training ended.")
+                        return
+                    # Time's up
+                    elif t_end is not None and datetime.now() >= t_end:
+                        self.save_checkpoint(ckpt_name="latest", save_train_state=True)
+                        logging.info("Time is up, training paused.")
+                        return
+
+                    torch.cuda.empty_cache()
+                # <<< Effective batch end <<<
+
+            # Epoch end
+            self.n_batch_in_epoch = 0
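+
+    # Validation runs the student pipeline end-to-end with the LCMScheduler set
+    # above; consistency models typically need only a few (1-4) denoising steps,
+    # so cfg.validation.denoising_steps can be set much lower than for DDIM.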
+    def validate(self):
+        for i, val_loader in enumerate(self.val_loaders):
+            val_dataset_name = val_loader.dataset.disp_name
+            val_metric_dic = self.validate_single_dataset(
+                data_loader=val_loader, metric_tracker=self.val_metrics
+            )
+            logging.info(
+                f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dic}"
+            )
+            tb_logger.log_dic(
+                {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dic.items()},
+                global_step=self.effective_iter,
+            )
+            # Save to file
+            eval_text = eval_dic_to_text(
+                val_metrics=val_metric_dic,
+                dataset_name=val_dataset_name,
+                sample_list_path=val_loader.dataset.filename_ls_path,
+            )
+            _save_to = os.path.join(
+                self.out_dir_eval,
+                f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt",
+            )
+            with open(_save_to, "w+") as f:
+                f.write(eval_text)
+
+            # Update main eval metric
+            if 0 == i:
+                main_eval_metric = val_metric_dic[self.main_val_metric]
+                if (
+                    "minimize" == self.main_val_metric_goal
+                    and main_eval_metric < self.best_metric
+                    or "maximize" == self.main_val_metric_goal
+                    and main_eval_metric > self.best_metric
+                ):
+                    self.best_metric = main_eval_metric
+                    logging.info(
+                        f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}"
+                    )
+                    # Save a checkpoint
+                    self.save_checkpoint(
+                        ckpt_name=self._get_backup_ckpt_name(), save_train_state=False
+                    )
+
+    def visualize(self):
+        for val_loader in self.vis_loaders:
+            vis_dataset_name = val_loader.dataset.disp_name
+            vis_out_dir = os.path.join(
+                self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name
+            )
+            os.makedirs(vis_out_dir, exist_ok=True)
+            _ = self.validate_single_dataset(
+                data_loader=val_loader,
+                metric_tracker=self.val_metrics,
+                save_to_dir=vis_out_dir,
+            )
+
+    @torch.no_grad()
+    def validate_single_dataset(
+        self,
+        data_loader: DataLoader,
+        metric_tracker: MetricTracker,
+        save_to_dir: str = None,
+    ):
+        self.model.to(self.device)
+        metric_tracker.reset()
+
+        # Generate seed sequence for consistent evaluation
+        val_init_seed = self.cfg.validation.init_seed
+        val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader))
+
+        for i, batch in enumerate(
+            tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"),
+            start=1,
+        ):
+            assert 1 == data_loader.batch_size
+            # Read input image
+            rgb_int = batch["rgb_int"].squeeze()  # [3, H, W]
+            # GT depth
+            depth_raw_ts = batch["depth_raw_linear"].squeeze()
+            depth_raw = depth_raw_ts.numpy()
+            depth_raw_ts = depth_raw_ts.to(self.device)
+            valid_mask_ts = batch["valid_mask_raw"].squeeze()
+            valid_mask = valid_mask_ts.numpy()
+            valid_mask_ts = valid_mask_ts.to(self.device)
+
+            # Random number generator
+            seed = val_seed_ls.pop()
+            if seed is None:
+                generator = None
+            else:
+                generator = torch.Generator(device=self.device)
+                generator.manual_seed(seed)
+
+            # Predict depth
+            pipe_out: MarigoldDepthOutput = self.model(
+                rgb_int,
+                denoising_steps=self.cfg.validation.denoising_steps,
+                ensemble_size=self.cfg.validation.ensemble_size,
+                processing_res=self.cfg.validation.processing_res,
+                match_input_res=self.cfg.validation.match_input_res,
+                generator=generator,
+                batch_size=1,  # use batch size 1 to increase reproducibility
+                color_map=None,
+                show_progress_bar=False,
+                resample_method=self.cfg.validation.resample_method,
+            )
+
+            depth_pred: np.ndarray = pipe_out.depth_np
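+
+            # The affine-invariant prediction is aligned to the metric GT before
+            # evaluation, either in depth or in disparity space, depending on
+            # cfg.eval.alignment.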
"least_square_disparity" == self.cfg.eval.alignment: + # convert GT depth -> GT disparity + gt_disparity, gt_non_neg_mask = depth2disparity( + depth=depth_raw, return_mask=True + ) + pred_non_neg_mask = depth_pred > 0 + valid_nonnegative_mask = ( + valid_mask & gt_non_neg_mask & pred_non_neg_mask + ) + # LS alignment in disparity space + disparity_pred, scale, shift = align_depth_least_square( + gt_arr=gt_disparity, + pred_arr=depth_pred, + valid_mask_arr=valid_nonnegative_mask, + return_scale_shift=True, + max_resolution=self.cfg.eval.align_max_res, + ) + # convert to depth + disparity_pred = np.clip( + disparity_pred, a_min=1e-3, a_max=None + ) # avoid 0 disparity + depth_pred = disparity2depth(disparity_pred) + elif "median_scale_disparity" == self.cfg.eval.alignment: + # convert GT depth -> GT disparity + gt_disparity, gt_non_neg_mask = depth2disparity( + depth=depth_raw, return_mask=True + ) + pred_non_neg_mask = depth_pred > 0 + valid_nonnegative_mask = ( + valid_mask & gt_non_neg_mask & pred_non_neg_mask + ) + disparity_pred, scale = median_scale_align( + gt_arr=gt_disparity, + pred_arr=depth_pred, + valid_mask_arr=valid_nonnegative_mask, + return_scale=True, + ) + # convert to depth + disparity_pred = np.clip( + disparity_pred, a_min=1e-3, a_max=None + ) # avoid 0 disparity + depth_pred = disparity2depth(disparity_pred) + + else: + raise RuntimeError(f"Unknown alignment type: {self.cfg.eval.alignment}") + + # Clip to dataset min max + depth_pred = np.clip( + depth_pred, + a_min=data_loader.dataset.min_depth, + a_max=data_loader.dataset.max_depth, + ) + + # clip to d > 0 for evaluation + depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None) + + # Evaluate + sample_metric = [] + depth_pred_ts = torch.from_numpy(depth_pred).to(self.device) + + for met_func in self.metric_funcs: + _metric_name = met_func.__name__ + _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).item() + sample_metric.append(_metric.__str__()) + metric_tracker.update(_metric_name, _metric) + + # Save as 16-bit uint png + if save_to_dir is not None: + img_name = batch["rgb_relative_path"][0].replace("/", "_") + png_save_path = os.path.join(save_to_dir, f"{img_name}.png") + depth_to_save = (pipe_out.depth_np * 65535.0).astype(np.uint16) + Image.fromarray(depth_to_save).save(png_save_path, mode="I;16") + + return metric_tracker.result() + + + def save_checkpoint(self, ckpt_name, save_train_state): + ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name) + logging.info(f"Saving checkpoint to: {ckpt_dir}") + # Backup previous checkpoint + temp_ckpt_dir = None + if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir): + temp_ckpt_dir = os.path.join( + os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}" + ) + if os.path.exists(temp_ckpt_dir): + shutil.rmtree(temp_ckpt_dir, ignore_errors=True) + os.rename(ckpt_dir, temp_ckpt_dir) + logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}") + + # Save UNet + unet_path = os.path.join(ckpt_dir, "unet") + self.model.unet.save_pretrained(unet_path, safe_serialization=True) + logging.info(f"UNet is saved to: {unet_path}") + + if self.save_target: + target_unet_path = os.path.join(ckpt_dir, "target_unet") + self.target_unet.save_pretrained(target_unet_path, safe_serialization=True) + logging.info(f"Target UNet is saved to: {target_unet_path}") + + if save_train_state: + state = { + "optimizer": self.optimizer.state_dict(), + "lr_scheduler": self.lr_scheduler.state_dict(), + "config": self.cfg, + "effective_iter": self.effective_iter, + "epoch": 
+    def load_checkpoint(
+        self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
+    ):
+        logging.info(f"Loading checkpoint from: {ckpt_path}")
+        # Load UNet (saved with safe_serialization=True, hence .safetensors)
+        _model_path = os.path.join(
+            ckpt_path, "unet", "diffusion_pytorch_model.safetensors"
+        )
+        self.model.unet.load_state_dict(load_file(_model_path))
+        self.model.unet.to(self.device)
+        logging.info(f"UNet parameters are loaded from {_model_path}")
+
+        if self.save_target:
+            _target_model_path = os.path.join(
+                ckpt_path, "target_unet", "diffusion_pytorch_model.safetensors"
+            )
+            self.target_unet.load_state_dict(load_file(_target_model_path))
+            logging.info(f"Target UNet parameters are loaded from {_target_model_path}")
+        else:
+            # No EMA target was saved; re-initialize it from the resumed student
+            self.target_unet.load_state_dict(self.model.unet.state_dict())
+        self.target_unet.to(self.device)
+
+        # Load training states
+        if load_trainer_state:
+            checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
+            self.effective_iter = checkpoint["effective_iter"]
+            self.epoch = checkpoint["epoch"]
+            self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
+            self.in_evaluation = checkpoint["in_evaluation"]
+            self.global_seed_sequence = checkpoint["global_seed_sequence"]
+
+            self.best_metric = checkpoint["best_metric"]
+
+            self.optimizer.load_state_dict(checkpoint["optimizer"])
+            logging.info(f"optimizer state is loaded from {ckpt_path}")
+
+            if resume_lr_scheduler:
+                self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+                logging.info(f"LR scheduler state is loaded from {ckpt_path}")
+
+        logging.info(
+            f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
+        )
+        return
+
+    def _get_backup_ckpt_name(self):
+        return f"iter_{self.effective_iter:06d}"