diff --git a/.gitignore b/.gitignore index 681a000..4112465 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,17 @@ .venv/ __pycache__/ .DS_Store +wandb/ +.vscode/ /data /output /input +/model-finetuned experiments/depth/* experiments/normals/* !experiments/depth/eval_args/ !experiments/normals/eval_args/ + +metadata_images_split_scene_v1.csv diff --git a/GeoWizard/geowizard/models/geowizard_pipeline.py b/GeoWizard/geowizard/models/geowizard_pipeline.py index 6f1dea0..57ea7f9 100644 --- a/GeoWizard/geowizard/models/geowizard_pipeline.py +++ b/GeoWizard/geowizard/models/geowizard_pipeline.py @@ -1,6 +1,7 @@ # Adapted from Marigold :https://github.com/prs-eth/Marigold -# @GonzaloMartinGarcia, all new additions to the GeoWizard code have been marked with # add. +# @GonzaloMartinGarcia +# All new additions to the GeoWizard code have been marked with # add. from typing import Any, Dict, Union @@ -17,7 +18,6 @@ ) from ..models.unet_2d_condition import UNet2DConditionModel from diffusers.utils import BaseOutput -from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection import torchvision.transforms.functional as TF from torchvision.transforms import InterpolationMode @@ -27,7 +27,20 @@ from ..utils.image_util import resize_max_res, chw2hwc, colorize_depth_maps from ..utils.depth_ensemble import ensemble_depths from ..utils.normal_ensemble import ensemble_normals -from ..utils.noise import pyramid_noise_like + +# add +# Pyramid noise from GeoWizard training code. +def pyramid_noise_like(x, timesteps, discount=0.9): + b, c, w_ori, h_ori = x.shape + u = nn.Upsample(size=(w_ori, h_ori), mode='bilinear') + noise = torch.randn_like(x) + scale = 1.5 + for i in range(10): + r = np.random.random()*scale + scale # Rather than always going 2x, + w, h = max(1, int(w_ori/(r**i))), max(1, int(h_ori/(r**i))) + noise += u(torch.randn(b, c, w, h).to(x)) * (timesteps[...,None,None,None]/1000) * discount**i + if w==1 or h==1: break # Lowest resolution is 1x1 + return noise/noise.std() # Scaled back to roughly unit variance class DepthNormalPipelineOutput(BaseOutput): diff --git a/GeoWizard/geowizard/training/__init__.py b/GeoWizard/geowizard/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/GeoWizard/geowizard/training/train_depth_normal.py b/GeoWizard/geowizard/training/train_depth_normal.py new file mode 100644 index 0000000..25eaf37 --- /dev/null +++ b/GeoWizard/geowizard/training/train_depth_normal.py @@ -0,0 +1,878 @@ +# A reimplemented version in public environments by Xiao Fu and Mu Hu + +# @GonzaloMartinGarcia +# Training code for the end-to-end fine-tuned GeoWizard Model from +# 'Fine-Tuning Image-Conditional Diffusion Models is Easier than You Think'. +# This training code is a modified version of the original GeoWizard training code, +# https://github.com/fuxiao0719/GeoWizard/blob/main/geowizard/training/training/train_depth_normal.py. +# Modifications have been marked with the comment # add. 
+ +import argparse +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +import os +os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" +import logging +import tqdm + +import sys + +from accelerate import Accelerator +import numpy as np +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from accelerate.utils import ProjectConfiguration, set_seed +import shutil + +from diffusers import DDPMScheduler, AutoencoderKL +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import is_wandb_available # check_min_version +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + +from packaging import version +from tqdm.auto import tqdm +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +import torchvision.transforms.functional as TF +from torchvision.transforms import InterpolationMode +import accelerate + +# add +if is_wandb_available(): + import wandb +sys.path.append(os.getcwd()) +from GeoWizard.geowizard.models.geowizard_pipeline import DepthNormalEstimationPipeline +from GeoWizard.geowizard.models.unet_2d_condition import UNet2DConditionModel +from torch.optim.lr_scheduler import LambdaLR +from training.util.lr_scheduler import IterExponential +from training.util.loss import ScaleAndShiftInvariantLoss, AngularLoss +from training.dataloaders.load import MixedDataLoader, Hypersim, VirtualKITTI2 + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +# check_min_version("0.26.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +def parse_args(): + parser = argparse.ArgumentParser(description="GeoWizard") + + # add + # End-to-end fine-tuned Settings + parser.add_argument( + "--e2e_ft", + action="store_true", + default=False + ) + parser.add_argument( + "--noise_type", + choices=["zeros", "pyramid", "gaussian"], + ) + parser.add_argument( + "--lr_total_iter_length", + type=int, + default=20000, + ) + + # GeoWizard Settings + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="training/model-finetuned", # add + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="A seed for reproducible training." + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=2, # add + help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--num_train_epochs", + type=int, + default=None # add + ) + parser.add_argument( + "--max_train_steps", + type=int, + default=20000, # add + help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=16, # add + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=3e-5, # add + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--class_embedding_lr_mult", + type=float, + default=10, + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="exponential", # add + help=( + 'The scheduler type to use. Also choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", + type=int, + default=100, # add + help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + # using EMA for improving the generalization + parser.add_argument( + "--use_ema", + action="store_true", + help="Whether to use EMA model." + ) + # dataloaderes + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
+ ), + ) + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank" + ) + # how many steps csave a checkpoints + parser.add_argument( + "--checkpointing_steps", + type=int, + default=20000, # add + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + # using xformers for efficient training + parser.add_argument( + "--enable_xformers_memory_efficient_attention", + action="store_true", + help="Whether or not to use xformers." + ) + # noise offset?::: #TODO HERE + parser.add_argument( + "--noise_offset", + type=float, + default=0, + help="The scale of noise offset." + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="e2e-ft-diffusion", # add + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + # get the local rank + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + +def pyramid_noise_like(x, timesteps, discount=0.9): + b, c, w_ori, h_ori = x.shape + u = nn.Upsample(size=(w_ori, h_ori), mode='bilinear') + noise = torch.randn_like(x) + scale = 1.5 + for i in range(10): + r = np.random.random()*scale + scale # Rather than always going 2x, + w, h = max(1, int(w_ori/(r**i))), max(1, int(h_ori/(r**i))) + noise += u(torch.randn(b, c, w, h).to(x)) * (timesteps[...,None,None,None]/1000) * discount**i + if w==1 or h==1: break # Lowest resolution is 1x1 + return noise/noise.std() # Scaled back to roughly unit variance + +def main(): + + ''' ------------------------Configs Preparation----------------------------''' + # give the args parsers + args = parse_args() + # save the tensorboard log files + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + # tell the gradient_accumulation_steps, mix precison, and tensorboard + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=True) # only the main process show the logs + + # If passed along, set the training seed now. 
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    # Do I/O on the main process
+    if accelerator.is_main_process:
+        if args.output_dir is not None:
+            os.makedirs(args.output_dir, exist_ok=True)
+
+        # add
+        # Save training arguments in a txt file
+        args_dict = vars(args)
+        args_str = '\n'.join(f"{key}: {value}" for key, value in args_dict.items())
+        args_path = os.path.join(args.output_dir, "arguments.txt")
+        os.makedirs(args.output_dir, exist_ok=True)
+        with open(args_path, 'w') as file:
+            file.write(args_str)
+
+    ''' ------------------------Non-NN Modules Definition----------------------------'''
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
+    sd_image_variations_diffusers_path = 'lambdalabs/sd-image-variations-diffusers'
+    image_encoder = CLIPVisionModelWithProjection.from_pretrained(sd_image_variations_diffusers_path, subfolder="image_encoder")
+    feature_extractor = CLIPImageProcessor.from_pretrained(sd_image_variations_diffusers_path, subfolder="feature_extractor")
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
+    # add
+    # No modifications are made to the UNet, since we fine-tune GeoWizard.
+    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
+
+    # Use EMA for improving generalization
+    if args.use_ema:
+        ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)
+
+    # Freeze the VAE and image encoder, and set the UNet to trainable.
+    vae.requires_grad_(False)
+    image_encoder.requires_grad_(False)
+    unet.train()  # only the UNet is trainable
+
+    # Use xformers for memory-efficient attention.
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            import xformers
+            xformers_version = version.parse(xformers.__version__)
+            if xformers_version == version.parse("0.0.16"):
+                logger.warn(
+                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+                )
+            unet.enable_xformers_memory_efficient_attention()
+            logger.info("Using xformers to speed up training", main_process_only=True)
+
+        else:
+            raise ValueError("xformers is not available.
Make sure it is installed correctly") + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # using checkpoint for saving the memories + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # how many cards did we use: accelerator.num_processes + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + params, params_class_embedding = [], [] + for name, param in unet.named_parameters(): + if 'class_embedding' in name: + params_class_embedding.append(param) + else: + params.append(param) + + # optimizer settings + optimizer = optimizer_cls( + [ + {"params": params, "lr": args.learning_rate}, + {"params": params_class_embedding, "lr": args.learning_rate * args.class_embedding_lr_mult} + ], + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # get the training dataset + with accelerator.main_process_first(): + + # add + # Load datasets + hypersim_root_dir = "data/hypersim/processed" + vkitti_root_dir = "data/virtual_kitti_2" + train_dataset_hypersim = Hypersim(root_dir=hypersim_root_dir, transform=True) + train_dataset_vkitti = VirtualKITTI2(root_dir=vkitti_root_dir, transform=True) + train_dataloader_vkitti = torch.utils.data.DataLoader(train_dataset_vkitti, shuffle=True, batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers) + train_dataloader_hypersim = torch.utils.data.DataLoader(train_dataset_hypersim, shuffle=True, batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers) + train_loader = MixedDataLoader(train_dataloader_hypersim, train_dataloader_vkitti, split1=9, split2=1) + + # because the optimizer not optimized every time, so we need to calculate how many steps it optimizes, + # it is usually optimized by + # Scheduler and math around the number of training steps. 
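+    # add
+    # Note on the custom scheduler used below: IterExponential (imported from training/util/lr_scheduler.py)
+    # is assumed to ramp the learning rate up linearly for `warmup_steps` iterations and then decay it
+    # exponentially towards `final_ratio` (here 1%) of the base learning rate at `total_iter_length`, roughly:
+    #     lr_lambda(i) = i / warmup_steps                                                     for i < warmup_steps
+    #     lr_lambda(i) = final_ratio ** ((i - warmup_steps) / (total_iter_length - warmup_steps))    otherwise
+    # See training/util/lr_scheduler.py for the exact curve.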
+ overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_loader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + # add + # Scheduler + if args.lr_scheduler == "exponential": + lr_func = IterExponential(total_iter_length = args.lr_total_iter_length*accelerator.num_processes, final_ratio = 0.01, warmup_steps = args.lr_warmup_steps*accelerator.num_processes) + lr_scheduler = LambdaLR(optimizer= optimizer, lr_lambda=lr_func) + elif args.lr_scheduler == "constant": + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + else: + raise ValueError(f"Unknown lr_scheduler {args.lr_scheduler}") + + # Prepare everything with our `accelerator`. + unet, optimizer, train_loader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_loader, lr_scheduler + ) + + if args.use_ema: + ema_unet.to(accelerator.device) + + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + + # Move text_encode and vae to gpu and cast to weight_dtype + vae.to(accelerator.device, dtype=weight_dtype) + image_encoder.to(accelerator.device, dtype=weight_dtype) + + clip_image_mean = torch.as_tensor(feature_extractor.image_mean)[:,None,None].to(accelerator.device, dtype=torch.float32) + clip_image_std = torch.as_tensor(feature_extractor.image_std)[:,None,None].to(accelerator.device, dtype=torch.float32) + + # We need to initialize the trackers we use, and also store our configuration. + num_update_steps_per_epoch = math.ceil(len(train_loader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Here is the DDP training: actually is 4 + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + # add + # Function for unwrapping if model was compiled with `torch.compile`. + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset_vkitti)+len(train_dataset_hypersim)}") # add + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + # Progress bar + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + # add + # Init task specific losses + ssi_loss = ScaleAndShiftInvariantLoss() + angular_loss_norm = AngularLoss() + # Init loss dictionary for logging + loss_logger = { "ssi": 0.0, # depth level loss + "ssi_count": 0.0, + "normals_angular": 0.0, # normals level loss + "normals_angular_count": 0.0 + } + + # add + # Get noise scheduling parameters for later conversion from a parameterized prediction into clean latent. + alpha_prod = noise_scheduler.alphas_cumprod.to(accelerator.device, dtype=weight_dtype) + beta_prod = 1 - alpha_prod + + # Training Loop + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + train_loss = 0.0 + for step, batch in enumerate(train_loader): + with accelerator.accumulate(unet): + + # add + # RGB + image_data_resized = batch['rgb'].to(accelerator.device,dtype=weight_dtype) + # Depth + depth_resized_normalized = batch['depth'].to(accelerator.device,dtype=weight_dtype) + # Validation mask + val_mask = batch["val_mask"].to(accelerator.device) + invalid_mask = ~val_mask + latent_mask = ~torch.max_pool2d(invalid_mask.float(), 8, 8).bool() + latent_mask = latent_mask.repeat((2, 4, 1, 1)).detach() + # Surface normals + normal_resized = batch['normals'].to(accelerator.device,dtype=weight_dtype)*-1 # GeoWizard trains on inverted normals! 
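+                    # add
+                    # Note on the masks above: the pixel-space validity mask is max-pooled by a factor of 8 to match
+                    # the spatial resolution of the VAE latents (the SD VAE downsamples by 8x), re-inverted into a
+                    # validity mask, and repeated to cover both the depth and normals branches (batch dimension x2)
+                    # as well as the 4 latent channels.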
+ + # Compute CLIP image embeddings + imgs_in_proc = TF.resize((image_data_resized +1)/2, + (feature_extractor.crop_size['height'], feature_extractor.crop_size['width']), + interpolation=InterpolationMode.BICUBIC, + antialias=True + ) + # do the normalization in float32 to preserve precision + imgs_in_proc = ((imgs_in_proc.float() - clip_image_mean) / clip_image_std).to(weight_dtype) + imgs_embed = image_encoder(imgs_in_proc).image_embeds.unsqueeze(1).to(weight_dtype) + + # encode latents + with torch.no_grad(): + if args.e2e_ft: + # add + # When E2E FT, we only need to encode the RGB image + h_batch = vae.encoder(image_data_resized) + moments_batch = vae.quant_conv(h_batch) + mean_batch, _ = torch.chunk(moments_batch, 2, dim=1) + rgb_latents = mean_batch * vae.config.scaling_factor + depth_latents, normal_latents = torch.zeros_like(rgb_latents), torch.zeros_like(rgb_latents) # dummy latents + else: + h_batch = vae.encoder(torch.cat((image_data_resized, depth_resized_normalized, normal_resized), dim=0).to(weight_dtype)) + moments_batch = vae.quant_conv(h_batch) + mean_batch, _ = torch.chunk(moments_batch, 2, dim=1) + batch_latents = mean_batch * vae.config.scaling_factor + rgb_latents, depth_latents, normal_latents = torch.chunk(batch_latents, 3, dim=0) + geo_latents = torch.cat((depth_latents, normal_latents), dim=0) + + # here is the setting batch size, in our settings, it can be 1.0 + bsz = rgb_latents.shape[0] + + # add + # Sample timesteps + if args.e2e_ft: + # Set timesteps to the first denoising step + timesteps = torch.ones((bsz,), device=depth_latents.device).repeat(2) * (noise_scheduler.config.num_train_timesteps-1) + timesteps = timesteps.long() + else: + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=depth_latents.device).repeat(2) + timesteps = timesteps.long() + + # add + # Sample noise + if args.noise_type == "zeros": + noise = torch.zeros_like(geo_latents).to(accelerator.device) + elif args.noise_type == "pyramid": + noise = pyramid_noise_like(geo_latents, timesteps).to(accelerator.device) + elif args.noise_type == "gaussian": + noise = torch.randn_like(geo_latents).to(accelerator.device) + else: + raise ValueError(f"Unknown noise type {args.noise_type}") + + # add + # Add noise to the depth latents + if args.e2e_ft: + noisy_geo_latents = noise # no ground truth when single step fine-tuning + else: + # add noise to the depth lantents + noisy_geo_latents = noise_scheduler.add_noise(geo_latents, noise, timesteps) + + # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config(prediction_type=args.prediction_type) + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(geo_latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + batch_imgs_embed = imgs_embed.repeat((2, 1, 1)) # [B*2, 1, 768] + + # hybrid hierarchical switcher + geo_class = torch.tensor([[0, 1], [1, 0]], dtype=weight_dtype, device=accelerator.device) + geo_embedding = torch.cat([torch.sin(geo_class), torch.cos(geo_class)], dim=-1).repeat_interleave(bsz, 0) + + # add + # Domain class + if batch["domain"][0] == 'indoor': + domain_class = torch.tensor([[1., 0., 0]], device=accelerator.device, dtype=weight_dtype) + elif batch["domain"][0] == 'outdoor': + 
                        domain_class = torch.tensor([[0., 1., 0]], device=accelerator.device, dtype=weight_dtype)
+                    else:
+                        raise ValueError(f"Unknown domain {batch['domain'][0]}")
+                    domain_class = domain_class.repeat(bsz, 1)
+
+                    domain_embedding = torch.cat([torch.sin(domain_class), torch.cos(domain_class)], dim=-1).repeat(2,1).to(accelerator.device)
+                    class_embedding = torch.cat((geo_embedding, domain_embedding), dim=-1)
+
+                    # predict the noise residual and compute the loss.
+                    unet_input = torch.cat((rgb_latents.repeat(2,1,1,1), noisy_geo_latents), dim=1)
+
+                    noise_pred = unet(unet_input,
+                                      timesteps,
+                                      encoder_hidden_states=batch_imgs_embed,
+                                      class_labels=class_embedding).sample  # [B*2, 4, h, w]
+
+                    # add
+                    # Compute loss
+                    loss = torch.tensor(0.0, device=accelerator.device, requires_grad=True)
+                    if latent_mask.any():
+                        if not args.e2e_ft:
+                            # Diffusion loss
+                            loss = F.mse_loss(noise_pred[latent_mask].float(), target[latent_mask].float(), reduction="mean")
+                        else:
+                            # End-to-end task-specific fine-tuning loss
+                            # Convert the parameterized prediction into a latent prediction.
+                            # Code is based on the DDIM code from diffusers,
+                            # https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py.
+                            alpha_prod_t = alpha_prod[timesteps].view(-1, 1, 1, 1)
+                            beta_prod_t = beta_prod[timesteps].view(-1, 1, 1, 1)
+                            if noise_scheduler.config.prediction_type == "v_prediction":
+                                current_latent_estimate = (alpha_prod_t**0.5) * noisy_geo_latents - (beta_prod_t**0.5) * noise_pred
+                            elif noise_scheduler.config.prediction_type == "epsilon":
+                                current_latent_estimate = (noisy_geo_latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
+                            else:
+                                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+                            # Clip or threshold the prediction (only here for completeness, not used by SD2 or our models with v_prediction)
+                            if noise_scheduler.config.thresholding:
+                                current_latent_estimate = noise_scheduler._threshold_sample(current_latent_estimate)
+                            elif noise_scheduler.config.clip_sample:
+                                current_latent_estimate = current_latent_estimate.clamp(
+                                    -noise_scheduler.config.clip_sample_range, noise_scheduler.config.clip_sample_range
+                                )
+                            # Decode the latent estimate
+                            current_latent_estimate = current_latent_estimate / vae.config.scaling_factor
+                            z = vae.post_quant_conv(current_latent_estimate)
+                            current_estimate = vae.decoder(z)
+                            current_depth_estimate, current_normal_estimate = torch.chunk(current_estimate, 2, dim=0)
+                            # Process depth and get GT
+                            current_depth_estimate = current_depth_estimate.mean(dim=1, keepdim=True)
+                            current_depth_estimate = torch.clamp(current_depth_estimate, -1, 1)
+                            depth_ground_truth = batch["metric"].to(device=accelerator.device, dtype=weight_dtype)
+                            # Process normals and get GT
+                            norm = torch.norm(current_normal_estimate, p=2, dim=1, keepdim=True) + 1e-5
+                            current_normal_estimate = current_normal_estimate / norm
+                            current_normal_estimate = torch.clamp(current_normal_estimate, -1, 1)
+                            normal_ground_truth = batch["normals"].to(device=accelerator.device, dtype=weight_dtype) * -1  # GeoWizard trains on inverted normals!
+ # Compute task-specific loss + estimation_loss = 0 + depth_scale = 0.5 # ssi loss is roughly 2x the angular loss + normal_scale = 1.0 + # Scale and shift invariant loss + estimation_loss_ssi = ssi_loss(current_depth_estimate, depth_ground_truth, val_mask) + if not torch.isnan(estimation_loss_ssi).any(): + estimation_loss = estimation_loss + (estimation_loss_ssi*depth_scale) + loss_logger["ssi"] += estimation_loss_ssi.detach().item() + loss_logger["ssi_count"] += 1 + # Angular loss + estimation_loss_ang_norm = angular_loss_norm(current_normal_estimate, normal_ground_truth, val_mask) + if not torch.isnan(estimation_loss_ang_norm).any(): + estimation_loss = estimation_loss + (estimation_loss_ang_norm*normal_scale) + loss_logger["normals_angular"] += estimation_loss_ang_norm.detach().item() + loss_logger["normals_angular_count"] += 1 + loss = loss + estimation_loss + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema_unet.step(unet.parameters()) + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + # add + accelerator.log({"lr": lr_scheduler.get_last_lr()[0]}, step=global_step) + + train_loss = 0.0 + + # add + # logg depth and normals losses separately + for key in list(loss_logger.keys()): + if "_count" not in key: + count_key = key + "_count" + if loss_logger[count_key] != 0: + # compute avg + loss_logger[key] /= loss_logger[count_key] + # log loss + loss_name = key + "_loss" + accelerator.log({loss_name: loss_logger[key]}, step=global_step) + # set all losses to 0 + for key in list(loss_logger.keys()): + loss_logger[key] = 0.0 + + # saving the checkpoints + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + # Log loss and learning rate for progress bar + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) 
+ + # Stop training + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + # validation each epoch by calculate the epe and the visualization depth + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) + + # add + # Create GeoWizard pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = unwrap_model(unet) + scheduler = DDPMScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler", + timestep_spacing="trailing" # set scheduler timestep spacing to trailing for later inference. + ) + pipeline = DepthNormalEstimationPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor + ) + logger.info(f"Saving pipeline to {args.output_dir}") + pipeline.save_pretrained(args.output_dir) + logger.info(f"Finished training.") + + accelerator.wait_for_everyone() + accelerator.end_training() + +if __name__=="__main__": + main() diff --git a/GeoWizard/geowizard/utils/depth2normal.py b/GeoWizard/geowizard/utils/depth2normal.py index bca7f1d..e01c76a 100644 --- a/GeoWizard/geowizard/utils/depth2normal.py +++ b/GeoWizard/geowizard/utils/depth2normal.py @@ -1,13 +1,8 @@ # A reimplemented version in public environments by Xiao Fu and Mu Hu -import pickle -import os -import h5py import numpy as np -import cv2 import torch import torch.nn as nn -import glob def init_image_coor(height, width): diff --git a/GeoWizard/geowizard/utils/image_util.py b/GeoWizard/geowizard/utils/image_util.py index a864ca1..2cb1b24 100644 --- a/GeoWizard/geowizard/utils/image_util.py +++ b/GeoWizard/geowizard/utils/image_util.py @@ -6,8 +6,6 @@ from PIL import Image - - def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image: """ Resize image to limit maximum edge length while keeping aspect ratio. 
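For reference, a minimal sketch of what the `resize_max_res` helper above is assumed to do (the repository's actual implementation may differ in rounding and resampling choices):

```python
from PIL import Image

def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
    # Scale so that the longer edge equals max_edge_resolution while keeping the aspect ratio.
    scale = min(max_edge_resolution / img.width, max_edge_resolution / img.height)
    return img.resize((int(img.width * scale), int(img.height * scale)))
```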
diff --git a/GeoWizard/geowizard/utils/noise.py b/GeoWizard/geowizard/utils/noise.py deleted file mode 100644 index 977a533..0000000 --- a/GeoWizard/geowizard/utils/noise.py +++ /dev/null @@ -1,19 +0,0 @@ -# add -import torch -import numpy as np -from torch import nn - - -# Pyramid Noise implementation from GeoWizard training code -# https://github.com/fuxiao0719/GeoWizard/blob/5b25910f5ceaecb4f5f3db000153052628611c9d/geowizard/training/training/train_depth_normal.py#L299 -def pyramid_noise_like(x, timesteps, discount=0.9): - b, c, w_ori, h_ori = x.shape - u = nn.Upsample(size=(w_ori, h_ori), mode='bilinear') - noise = torch.randn_like(x) - scale = 1.5 - for i in range(10): - r = np.random.random()*scale + scale # Rather than always going 2x, - w, h = max(1, int(w_ori/(r**i))), max(1, int(h_ori/(r**i))) - noise += u(torch.randn(b, c, w, h).to(x)) * (timesteps[...,None,None,None]/1000) * discount**i - if w==1 or h==1: break # Lowest resolution is 1x1 - return noise/noise.std() # Scaled back to roughly unit variance \ No newline at end of file diff --git a/GeoWizard/run_infer.py b/GeoWizard/run_infer.py index dfc0b19..76a8bc6 100644 --- a/GeoWizard/run_infer.py +++ b/GeoWizard/run_infer.py @@ -1,6 +1,7 @@ # Adapted from Marigold :https://github.com/prs-eth/Marigold -# @GonzaloMartinGarcia, all new additions to the GeoWizard code have been marked with # add. +# @GonzaloMartinGarcia, +# All new additions to the GeoWizard code have been marked with # add. import argparse import os @@ -14,11 +15,9 @@ from geowizard.models.geowizard_pipeline import DepthNormalEstimationPipeline from geowizard.utils.seed_all import seed_all from geowizard.utils.depth2normal import * -from diffusers import DiffusionPipeline, DDIMScheduler, AutoencoderKL +from diffusers import DDIMScheduler, AutoencoderKL from geowizard.models.unet_2d_condition import UNet2DConditionModel from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection -import torchvision.transforms.functional as TF -from torchvision.transforms import InterpolationMode from geowizard.utils.seed_all import seed_all if __name__=="__main__": @@ -198,14 +197,12 @@ scheduler = DDIMScheduler.from_pretrained(checkpoint_path, subfolder='scheduler') image_encoder = CLIPVisionModelWithProjection.from_pretrained(checkpoint_path, subfolder="image_encoder") feature_extractor = CLIPImageProcessor.from_pretrained(checkpoint_path, subfolder="feature_extractor") - unet = UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="unet") # load the UNet from checkpoint + unet = UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="unet") pipe = DepthNormalEstimationPipeline(vae=vae, image_encoder=image_encoder, feature_extractor=feature_extractor, unet=unet, scheduler=scheduler) - - print(unet.config) logging.info("loading pipeline whole successfully.") diff --git a/Marigold/infer.py b/Marigold/infer.py index 840ad42..c23ddcb 100644 --- a/Marigold/infer.py +++ b/Marigold/infer.py @@ -19,8 +19,8 @@ # -------------------------------------------------------------------------- # @GonzaloMartinGarcia -# ! The following code is built upon Marigold's infer.py. I have adapted it to also run evaluation inference for my monocular -# rgb to metric depth diffusion models. All my changes to the original infer.py are marked with a '# add' comment. +# The following code is built upon Marigold's infer.py, and was adapted to include some new settings. +# All additions made are marked with # add. 
import argparse diff --git a/Marigold/script/dataset_preprocess/hypersim/hypersim_util.py b/Marigold/script/dataset_preprocess/hypersim/hypersim_util.py new file mode 100644 index 0000000..2bedaed --- /dev/null +++ b/Marigold/script/dataset_preprocess/hypersim/hypersim_util.py @@ -0,0 +1,69 @@ +# Author: Bingxin Ke +# Last modified: 2024-02-19 + + +from pylab import count_nonzero, clip, np + + +# Adapted from https://github.com/apple/ml-hypersim/blob/main/code/python/tools/scene_generate_images_tonemap.py +def tone_map(rgb, entity_id_map): + assert (entity_id_map != 0).all() + + gamma = 1.0 / 2.2 # standard gamma correction exponent + inv_gamma = 1.0 / gamma + percentile = ( + 90 # we want this percentile brightness value in the unmodified image... + ) + brightness_nth_percentile_desired = 0.8 # ...to be this bright after scaling + + valid_mask = entity_id_map != -1 + + if count_nonzero(valid_mask) == 0: + scale = 1.0 # if there are no valid pixels, then set scale to 1.0 + else: + brightness = ( + 0.3 * rgb[:, :, 0] + 0.59 * rgb[:, :, 1] + 0.11 * rgb[:, :, 2] + ) # "CCIR601 YIQ" method for computing brightness + brightness_valid = brightness[valid_mask] + + eps = 0.0001 # if the kth percentile brightness value in the unmodified image is less than this, set the scale to 0.0 to avoid divide-by-zero + brightness_nth_percentile_current = np.percentile(brightness_valid, percentile) + + if brightness_nth_percentile_current < eps: + scale = 0.0 + else: + # Snavely uses the following expression in the code at https://github.com/snavely/pbrs_tonemapper/blob/master/tonemap_rgbe.py: + # scale = np.exp(np.log(brightness_nth_percentile_desired)*inv_gamma - np.log(brightness_nth_percentile_current)) + # + # Our expression below is equivalent, but is more intuitive, because it follows more directly from the expression: + # (scale*brightness_nth_percentile_current)^gamma = brightness_nth_percentile_desired + + scale = ( + np.power(brightness_nth_percentile_desired, inv_gamma) + / brightness_nth_percentile_current + ) + + rgb_color_tm = np.power(np.maximum(scale * rgb, 0), gamma) + rgb_color_tm = clip(rgb_color_tm, 0, 1) + return rgb_color_tm + + +# According to https://github.com/apple/ml-hypersim/issues/9 +def dist_2_depth(width, height, flt_focal, distance): + img_plane_x = ( + np.linspace((-0.5 * width) + 0.5, (0.5 * width) - 0.5, width) + .reshape(1, width) + .repeat(height, 0) + .astype(np.float32)[:, :, None] + ) + img_plane_y = ( + np.linspace((-0.5 * height) + 0.5, (0.5 * height) - 0.5, height) + .reshape(height, 1) + .repeat(width, 1) + .astype(np.float32)[:, :, None] + ) + img_plane_z = np.full([height, width, 1], flt_focal, np.float32) + img_plane = np.concatenate([img_plane_x, img_plane_y, img_plane_z], 2) + + depth = distance / np.linalg.norm(img_plane, 2, 2) * flt_focal + return depth diff --git a/Marigold/script/dataset_preprocess/hypersim/preprocess_hypersim.py b/Marigold/script/dataset_preprocess/hypersim/preprocess_hypersim.py new file mode 100644 index 0000000..32c8214 --- /dev/null +++ b/Marigold/script/dataset_preprocess/hypersim/preprocess_hypersim.py @@ -0,0 +1,153 @@ +# Author: Bingxin Ke +# Last modified: 2024-02-19 + +# @GonzaloMartinGarcia +# Changes have been marked with # add. 
+ +import argparse +import os + +import cv2 +import h5py +import numpy as np +import pandas as pd +from hypersim_util import dist_2_depth, tone_map +from tqdm import tqdm + +IMG_WIDTH = 1024 +IMG_HEIGHT = 768 +FOCAL_LENGTH = 886.81 + +if "__main__" == __name__: + parser = argparse.ArgumentParser() + # add + parser.add_argument( + "--split_csv", + type=str, + default="data/hypersim/metadata_images_split_scene_v1.csv", + ) + parser.add_argument("--dataset_dir", type=str, default="data/hypersim/raw_data") + parser.add_argument("--output_dir", type=str, default="data/hypersim/processed") + + args = parser.parse_args() + + split_csv = args.split_csv + dataset_dir = args.dataset_dir + output_dir = args.output_dir + + # %% + raw_meta_df = pd.read_csv(split_csv) + meta_df = raw_meta_df[raw_meta_df.included_in_public_release].copy() + + # %% + for split in ["train", "val", "test"]: + split_output_dir = os.path.join(output_dir, split) + os.makedirs(split_output_dir) + + split_meta_df = meta_df[meta_df.split_partition_name == split].copy() + split_meta_df["rgb_path"] = None + split_meta_df["rgb_mean"] = np.nan + split_meta_df["rgb_std"] = np.nan + split_meta_df["rgb_min"] = np.nan + split_meta_df["rgb_max"] = np.nan + split_meta_df["depth_path"] = None + split_meta_df["depth_mean"] = np.nan + split_meta_df["depth_std"] = np.nan + split_meta_df["depth_min"] = np.nan + split_meta_df["depth_max"] = np.nan + split_meta_df["invalid_ratio"] = np.nan + + for i, row in tqdm(split_meta_df.iterrows(), total=len(split_meta_df)): + # Load data + rgb_path = os.path.join( + row.scene_name, + "images", + f"scene_{row.camera_name}_final_hdf5", + f"frame.{row.frame_id:04d}.color.hdf5", + ) + dist_path = os.path.join( + row.scene_name, + "images", + f"scene_{row.camera_name}_geometry_hdf5", + f"frame.{row.frame_id:04d}.depth_meters.hdf5", + ) + render_entity_id_path = os.path.join( + row.scene_name, + "images", + f"scene_{row.camera_name}_geometry_hdf5", + f"frame.{row.frame_id:04d}.render_entity_id.hdf5", + ) + assert os.path.exists(os.path.join(dataset_dir, rgb_path)) + assert os.path.exists(os.path.join(dataset_dir, dist_path)) + + with h5py.File(os.path.join(dataset_dir, rgb_path), "r") as f: + rgb = np.array(f["dataset"]).astype(float) + with h5py.File(os.path.join(dataset_dir, dist_path), "r") as f: + dist_from_center = np.array(f["dataset"]).astype(float) + with h5py.File(os.path.join(dataset_dir, render_entity_id_path), "r") as f: + render_entity_id = np.array(f["dataset"]).astype(int) + + # Tone map + rgb_color_tm = tone_map(rgb, render_entity_id) + rgb_int = (rgb_color_tm * 255).astype(np.uint8) # [H, W, RGB] + + # Distance -> depth + plane_depth = dist_2_depth( + IMG_WIDTH, IMG_HEIGHT, FOCAL_LENGTH, dist_from_center + ) + valid_mask = render_entity_id != -1 + + # Record invalid ratio + invalid_ratio = (np.prod(valid_mask.shape) - valid_mask.sum()) / np.prod( + valid_mask.shape + ) + plane_depth[~valid_mask] = 0 + + # Save as png + scene_path = row.scene_name + if not os.path.exists(os.path.join(split_output_dir, row.scene_name)): + os.makedirs(os.path.join(split_output_dir, row.scene_name)) + + rgb_name = f"rgb_{row.camera_name}_fr{row.frame_id:04d}.png" + rgb_path = os.path.join(scene_path, rgb_name) + cv2.imwrite( + os.path.join(split_output_dir, rgb_path), + cv2.cvtColor(rgb_int, cv2.COLOR_RGB2BGR), + ) + + plane_depth *= 1000.0 + plane_depth = plane_depth.astype(np.uint16) + depth_name = f"depth_plane_{row.camera_name}_fr{row.frame_id:04d}.png" + depth_path = os.path.join(scene_path, depth_name) + 
cv2.imwrite(os.path.join(split_output_dir, depth_path), plane_depth) + + # Meta data + split_meta_df.at[i, "rgb_path"] = rgb_path + split_meta_df.at[i, "rgb_mean"] = np.mean(rgb_int) + split_meta_df.at[i, "rgb_std"] = np.std(rgb_int) + split_meta_df.at[i, "rgb_min"] = np.min(rgb_int) + split_meta_df.at[i, "rgb_max"] = np.max(rgb_int) + + split_meta_df.at[i, "depth_path"] = depth_path + restored_depth = plane_depth / 1000.0 + split_meta_df.at[i, "depth_mean"] = np.mean(restored_depth) + split_meta_df.at[i, "depth_std"] = np.std(restored_depth) + split_meta_df.at[i, "depth_min"] = np.min(restored_depth) + split_meta_df.at[i, "depth_max"] = np.max(restored_depth) + + split_meta_df.at[i, "invalid_ratio"] = invalid_ratio + + with open( + os.path.join(split_output_dir, f"filename_list_{split}.txt"), "w+" + ) as f: + lines = split_meta_df.apply( + lambda r: f"{r['rgb_path']} {r['depth_path']}", axis=1 + ).tolist() + f.writelines("\n".join(lines)) + + with open( + os.path.join(split_output_dir, f"filename_meta_{split}.csv"), "w+" + ) as f: + split_meta_df.to_csv(f, header=True) + + print("Preprocess finished") diff --git a/README.md b/README.md index c556942..9bbeb0b 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,7 @@ By using the correct `trailing` timestep spacing, it is possible to sample singl Our single-step deterministic E2E FT models outperform the previously mentioned diffusion estimators. -## πŸ“‹ Performance +## πŸ“‹ Metrics | Depth Method | Inference Time | NYUv2 AbsRel↓ | KITTI AbsRel↓ | ETH3D AbsRel↓| ScanNet AbsRel↓ | DIODE AbsRel↓ | |-----------------------------|------------------|---------------|----------------|--------------|-----------------|----------------| @@ -119,7 +119,7 @@ Inference time is for a single 576x768-pixel image, evaluated on an NVIDIA RTX 4 ## πŸ“Š Evaluation -We utilize the official [Marigold](https://github.com/prs-eth/Marigold) evaluation pipeline to evaluate the affine-invariant depth estimation checkpoints, and we use the official [DSINE](https://github.com/baegwangbin/DSINE) evaluation pipeline to evaluate the surface normal estimation checkpoints. The code has been streamlined to exclude unnecessary parts, and changes have been marked. +We utilize the official [Marigold](https://github.com/prs-eth/Marigold) evaluation pipeline to evaluate the affine-invariant depth estimation checkpoints, and we use the official [DSINE](https://github.com/baegwangbin/DSINE) evaluation pipeline to evaluate the surface normals estimation checkpoints. The code has been streamlined to exclude unnecessary parts, and changes have been marked. ### Depth @@ -213,6 +213,132 @@ dsine └── metrics.txt ``` +## πŸ‹οΈ Training + +### Datasets + +The fine-tuned models are trained on the [Hypersim](https://github.com/apple/ml-hypersim) and [Virtual KITTI 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/) datasets. + + +#### Hypersim + +Download the [Hypersim](https://github.com/apple/ml-hypersim) dataset using the [dataset_download_images.py](https://github.com/apple/ml-hypersim/blob/20f398f4387aeca73175494d6a2568f37f372150/code/python/tools/dataset_download_images.py) script and unzip the files to `data/hypersim/raw_data` at the root of the project. Download the scene split file from the [Hypersim repository](https://github.com/apple/ml-hypersim/blob/main/evermotion_dataset/analysis/metadata_images_split_scene_v1.csv) and place it in `data/hypersim`. 
+ +``` +data +└── hypersim + β”œβ”€β”€ metadata_images_split_scene_v1.csv + └── raw_data + β”œβ”€β”€ ai_001_001 + β”œβ”€β”€ ... + └── ai_055_010 +``` + +Run Marigold's preprocessing script, which will save the processed data to `data/hypersim/processed`. +```bash +python Marigold/script/dataset_preprocess/hypersim/preprocess_hypersim.py \ + --split_csv data/hypersim/metadata_images_split_scene_v1.csv +``` + +Download the surface normals in `png` format using Hypersim's [`download.py`](https://github.com/apple/ml-hypersim/tree/20f398f4387aeca73175494d6a2568f37f372150/contrib/99991) script. +```bash +./download.py --contains normal_cam.png --silent +``` +Place the downloaded surface normals in `data/hypersim/processed/normals`. + +The final processed file structure should look like this: +``` +data +└── hypersim + └── processed + β”œβ”€β”€ normals + β”‚ β”œβ”€β”€ ai_001_001 + β”‚ β”œβ”€β”€ ... + β”‚ └── ai_055_010 + └── train + β”œβ”€β”€ ai_001_001 + β”œβ”€β”€ ... + β”œβ”€β”€ ai_055_010 + └── filename_meta_train.csv +``` + +#### Virtual KITTI 2 + +Download the RGB (`vkitti_2.0.3_rgb.tar`) and depth (`vkitti_2.0.3_depth.tar`) files from the [official website](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/). Place them in `data/virtual_kitti_2` at the root of the project and finally extract them using the following shell commands. + +```bash +mkdir vkitti_2.0.3_rgb && tar -xf vkitti_2.0.3_rgb.tar -C vkitti_2.0.3_rgb +mkdir vkitti_2.0.3_depth && tar -xf vkitti_2.0.3_depth.tar -C vkitti_2.0.3_depth +``` + +Virtual KITTI 2 does not provide surface normals. Therefore, we estimate them from the depth maps using [discontinuity-aware gradient filters](https://github.com/fengyi233/depth-to-normal-translator). Run our provided script to generate the normals which will be saved to `data/virtual_kitti_2/vkitti_DAG_normals`. + +```bash +python depth-to-normal-translator/python/gen_vkitti_normals.py +``` + +The final processed file structure should look like this: + +``` +data +└── virtual_kitti_2 + β”œβ”€β”€ vkitti_2.0.3_depth + β”‚ β”œβ”€β”€ Scene01 + β”‚ β”œβ”€β”€ Scene02 + β”‚ β”œβ”€β”€ Scene06 + β”‚ β”œβ”€β”€ Scene18 + β”‚ └── Scene20 + β”œβ”€β”€ vkitti_2.0.3_rgb + β”‚ β”œβ”€β”€ Scene01 + β”‚ β”œβ”€β”€ Scene02 + β”‚ β”œβ”€β”€ Scene06 + β”‚ β”œβ”€β”€ Scene18 + β”‚ └── Scene20 + └── vkitti_DAG_normals + β”œβ”€β”€ Scene01 + β”œβ”€β”€ Scene02 + β”œβ”€β”€ Scene06 + β”œβ”€β”€ Scene18 + └── Scene20 +``` + +### E2E FT Model Training + +To train the end-to-end fine-tuned depth and normals models, run the scripts in the `training/scripts` directory: +```bash +./training/scripts/train_marigold_e2e_ft_depth.sh +``` +```bash +./training/scripts/train_stable_diffusion_e2e_ft_depth.sh +``` +```bash +./training/scripts/train_marigold_e2e_ft_normals.sh +``` +```bash +./training/scripts/train_stable_diffusion_e2e_ft_normals.sh +``` +```bash +./training/scripts/train_geowizard_e2e_ft.sh +``` + +The fine-tuned models will be saved to `model-finetuned` at the root of the project. 
+ +```bash +model-finetuned + └── + β”œβ”€β”€ arguments.txt + β”œβ”€β”€ model_index.json + β”œβ”€β”€ text_encoder # or image_encoder for GeoWizard + β”œβ”€β”€ tokenizer + β”œβ”€β”€ feature_extractor + β”œβ”€β”€ scheduler + β”œβ”€β”€ vae + └── unet +``` + +> [!NOTE] +> For multi GPU training, set the desired number of devices and nodes in the `training/scripts/multi_gpu.yaml` file and replace `accelerate launch` with `accelerate launch --multi_gpu --config_file training/scripts/multi_gpu.yaml` in the training scripts. + ## πŸŽ“ Citation If you use our work in your research, please use the following BibTeX entry. @@ -224,4 +350,4 @@ If you use our work in your research, please use the following BibTeX entry. journal = {arXiv preprint arXiv:2409.11355}, year = {2024} } -``` +``` \ No newline at end of file diff --git a/depth-to-normal-translator/LICENSE b/depth-to-normal-translator/LICENSE new file mode 100644 index 0000000..b27fdc5 --- /dev/null +++ b/depth-to-normal-translator/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 fengyi233 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/depth-to-normal-translator/README.md b/depth-to-normal-translator/README.md new file mode 100644 index 0000000..caeee99 --- /dev/null +++ b/depth-to-normal-translator/README.md @@ -0,0 +1,2 @@ + +Code is copied from https://github.com/fengyi233/depth-to-normal-translator. Added gen_vkitti_normals.py file, which is based on the original demo.py file. Modifications are indicated within the code. diff --git a/depth-to-normal-translator/python/gen_vkitti_normals.py b/depth-to-normal-translator/python/gen_vkitti_normals.py new file mode 100644 index 0000000..12b4722 --- /dev/null +++ b/depth-to-normal-translator/python/gen_vkitti_normals.py @@ -0,0 +1,136 @@ +# @GonzaloMartinGarcia +# This file is based on the official demo.py from the original repository +# https://github.com/fengyi233/depth-to-normal-translator. +# The code was modified to generate Virtual KITTI 2 normals from the ground truth depth maps. 
+ +from utils import * + +# add +import torch +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +import os +from PIL import Image +import numpy as np + +############################### +# Modified VKITTI 2 DATALOADER +############################### + +# add +# Modified Virtual KITTI 2.0 Dataset class to output the ground truth depth, intrinsics and normal path +class VirtualKITTI2(Dataset): + + def __init__(self, root_dir): + self.root_dir = root_dir + self.pairs = self._find_pairs() + + def _find_pairs(self): + scenes = ["Scene01", "Scene02", "Scene06", "Scene18", "Scene20"] + weather_conditions = ["15-deg-left","15-deg-right", "30-deg-left", "30-deg-right", "clone", "morning", "fog", "rain", "sunset", "overcast"] + cameras = ["Camera_0", "Camera_1"] + vkitti2_rgb_path = os.path.join(self.root_dir, "vkitti_2.0.3_rgb") + vkitti2_depth_path = os.path.join(self.root_dir, "vkitti_2.0.3_depth") + vkitti2_normal_path = os.path.join(self.root_dir, "vkitti_DAG_normals") # name of the new normals folder + pairs = [] + for scene in scenes: + for weather in weather_conditions: + for camera in cameras: + rgb_dir = os.path.join(vkitti2_rgb_path, scene, weather, "frames", "rgb" ,camera) + depth_dir = os.path.join(vkitti2_depth_path, scene, weather, "frames","depth" , camera) + normal_dir = os.path.join(vkitti2_normal_path, scene, weather, "frames", "normal", camera) + if os.path.exists(rgb_dir) and os.path.exists(depth_dir): + rgb_files = [f for f in os.listdir(rgb_dir) if f.endswith(".jpg")] + rgb_files = [file[3:] for file in rgb_files] + for file in rgb_files: + rgb_file = "rgb" + file + depth_file = "depth" + file.replace('.jpg', '.png') + normal_file = "normal" + file.replace('.jpg', '.png') + rgb_path = os.path.join(rgb_dir, rgb_file) + depth_path = os.path.join(depth_dir, depth_file) + normal_path = os.path.join(normal_dir, normal_file) + pairs.append((rgb_path, depth_path, normal_path)) + return pairs + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, idx): + _, depth_path, normal_path = self.pairs[idx] + + # get depth + depth_image = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) + depth_image = depth_image.astype(np.float32)/100.0 # cm to meters + depth_image = Image.fromarray(depth_image) # PIL + depth_tensor = transforms.ToTensor()(depth_image) + + # intrinsics (from the official vkitti_2.0.3_textgt.tar files) + fx_d = 725.0087 + fy_d = 725.0087 + cx_d = 620.5 + cy_d = 187 + K = torch.tensor([ [fx_d, 0, cx_d], + [ 0, fy_d, cy_d], + [ 0, 0, 1]]) + + return {"depth": depth_tensor, 'normal_path': normal_path, "intrinsics": K} + + +#################### +# Depth to Normals +#################### + +# version choices: ['d2nt_basic', 'd2nt_v2', 'd2nt_v3'] +VERSION = 'd2nt_v3' + +if __name__ == '__main__': + + # add + # init dataset + print(f"Generating Normals using Version {VERSION}") + root_dir = "data/virtual_kitti_2" + dataset = VirtualKITTI2(root_dir) + dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4) + + print(f"Number of samples: {len(dataset)}") + + for i, data in enumerate(dataloader): + + # add + depth = data['depth'][:,0,:,:].squeeze().numpy()*100 # [H, W] + intrinsics = data['intrinsics'].squeeze().numpy() # [3, 3] + + # get camera parameters and depth + cam_fx, cam_fy, u0, v0 = intrinsics[0,0], intrinsics[1,1], intrinsics[0,2], intrinsics[1,2] + h, w = depth.shape + u_map = np.ones((h, 1)) * np.arange(1, w + 1) - u0 # u-u0 + v_map = np.arange(1, h + 1).reshape(h, 1) * np.ones((1, 
w)) - v0 # v-v0 + + # DAG Depth to Normals: + if VERSION == 'd2nt_basic': + Gu, Gv = get_filter(depth) + else: + Gu, Gv = get_DAG_filter(depth) + + # Depth to Normal Translation + est_nx = Gu * cam_fx + est_ny = Gv * cam_fy + est_nz = -(depth + v_map * Gv + u_map * Gu) + est_normal = cv2.merge((est_nx, est_ny, est_nz)) + + # vector normalization + est_normal = vector_normalization(est_normal) + + # MRF-based Normal Refinement + if VERSION == 'd2nt_v3': + est_normal = MRF_optim(depth, est_normal) + + # redirect normals against camera + est_normal = est_normal * -1 # [H,W,3] + + # add + # save normals + est_normal_16bit = ((est_normal + 1) * 32767.5).astype(np.uint16) + est_normal_16bit = cv2.cvtColor(est_normal_16bit, cv2.COLOR_RGB2BGR) + os.makedirs(os.path.dirname(data['normal_path'][0]), exist_ok=True) + cv2.imwrite(data['normal_path'][0], est_normal_16bit) \ No newline at end of file diff --git a/depth-to-normal-translator/python/utils/__init__.py b/depth-to-normal-translator/python/utils/__init__.py new file mode 100644 index 0000000..1490489 --- /dev/null +++ b/depth-to-normal-translator/python/utils/__init__.py @@ -0,0 +1,2 @@ +from .apis import * +from .myApis import * diff --git a/depth-to-normal-translator/python/utils/apis.py b/depth-to-normal-translator/python/utils/apis.py new file mode 100644 index 0000000..a0861d9 --- /dev/null +++ b/depth-to-normal-translator/python/utils/apis.py @@ -0,0 +1,73 @@ +import struct + +import cv2 +import numpy as np + + +# uMax = 640 # w +# vMax = 480 # h + + +def get_cam_params(calib_path): + with open(calib_path, 'r') as f: + data = f.read() + params = list(map(int, (data.split())))[:-1] + return params + + +def get_normal_gt(normal_path): + # retVal: [-1,1] + normal_gt = cv2.imread(normal_path, -1) + normal_gt = normal_gt[:, :, ::-1] + normal_gt = 1 - normal_gt / 65535 * 2 + return normal_gt + + +def get_depth(depth_path, height, width): + with open(depth_path, 'rb') as f: + data_raw = struct.unpack('f' * width * height, f.read(4 * width * height)) + z = np.array(data_raw).reshape(height, width) + + # create mask, 1 for foreground, 0 for background + mask = np.ones_like(z) + mask[z == 1] = 0 + + return z, mask + + +def vector_normalization(normal, eps=1e-8): + mag = np.linalg.norm(normal, axis=2) + normal /= (np.expand_dims(mag, axis=2) + eps) + return normal + + +def visualization_map_creation(normal, mask): + mask = np.expand_dims(mask, axis=2) + vis = normal * mask + mask - 1 + vis = (1 - vis) / 2 # transform the interval from [-1, 1] to [0, 1] + return vis + + +def angle_normalization(err_map): + err_map[err_map > np.pi / 2] = np.pi - err_map[err_map > np.pi / 2] + return err_map + + +def evaluation(n_gt, n_est, mask): + scale = np.pi / 180 + error_map = np.arccos(np.sum(n_gt * n_est, axis=2)) + error_map = angle_normalization(error_map) / scale + error_map *= mask + ea = error_map.sum() / mask.sum() + return error_map, ea + +# def softmax(x): +# x_exp = np.exp(x) +# x_sum = np.sum(x_exp) +# return x_exp / x_sum +# +# +# def softmin(x): +# x_exp = np.exp(-x) +# x_sum = np.sum(x_exp) +# return x_exp / x_sum diff --git a/depth-to-normal-translator/python/utils/myApis.py b/depth-to-normal-translator/python/utils/myApis.py new file mode 100644 index 0000000..a8cadbb --- /dev/null +++ b/depth-to-normal-translator/python/utils/myApis.py @@ -0,0 +1,179 @@ +import cv2 +import numpy as np + +kernel_Gx = np.array([[0, 0, 0], + [-1, 0, 1], + [0, 0, 0]]) + +kernel_Gy = np.array([[0, -1, 0], + [0, 0, 0], + [0, 1, 0]]) + +cp2tv_Gx = np.array([[0, 0, 0], 
+ [0, -1, 1], + [0, 0, 0]]) + +cp2tv_Gy = np.array([[0, 0, 0], + [0, -1, 0], + [0, 1, 0]]) + +lap_ker_alpha = np.array([[0, -1, 0], + [-1, 4, -1], + [0, -1, 0]]) + +lap_ker_beta = np.array([[-1, -1, -1], + [-1, 8, -1], + [-1, -1, -1]]) + +lap_ker_gamma = np.array([[0.25, 0.5, 0.25], + [0.5, -3, 0.5], + [0.25, 0.5, 0.25]]) + +gradient_l = np.array([[-1, 1, 0]]) +gradient_r = np.array([[0, -1, 1]]) +gradient_u = np.array([[-1], + [1], + [0]]) +gradient_d = np.array([[0], + [-1], + [1]]) + +laplace_hor = np.array([[-1, 2, -1]]) + +laplace_ver = np.array([[-1], + [2], + [-1]]) + + +def soft_min(laplace_map, base, direction): + """ + + :param laplace_map: the horizontal laplace map or vertical laplace map, shape = [vMax, uMax] + :param base: the base of the exponent operation + :param direction: 0 for horizontal, 1 for vertical + :return: weighted map (lambda 1,2 or 3,4) + """ + h, w = laplace_map.shape + eps = 1e-8 # to avoid division by zero + + lap_power = np.power(base, -laplace_map) + if direction == 0: # horizontal + lap_pow_l = np.hstack([np.zeros((h, 1)), lap_power[:, :-1]]) + lap_pow_r = np.hstack([lap_power[:, 1:], np.zeros((h, 1))]) + return (lap_pow_l + eps * 0.5) / (eps + lap_pow_l + lap_pow_r), \ + (lap_pow_r + eps * 0.5) / (eps + lap_pow_l + lap_pow_r) + + elif direction == 1: # vertical + lap_pow_u = np.vstack([np.zeros((1, w)), lap_power[:-1, :]]) + lap_pow_d = np.vstack([lap_power[1:, :], np.zeros((1, w))]) + return (lap_pow_u + eps / 2) / (eps + lap_pow_u + lap_pow_d), \ + (lap_pow_d + eps / 2) / (eps + lap_pow_u + lap_pow_d) + + +def get_filter(Z, cp2tv=False): + """get partial u, partial v""" + if cp2tv: + Gu = cv2.filter2D(Z, -1, cp2tv_Gx) + Gv = cv2.filter2D(Z, -1, cp2tv_Gy) + else: + Gu = cv2.filter2D(Z, -1, kernel_Gx) / 2 + Gv = cv2.filter2D(Z, -1, kernel_Gy) / 2 + return Gu, Gv + + +def get_DAG_filter(Z, base=np.e, lap_conf='1D-DLF'): + # calculate gradients along four directions + grad_l = cv2.filter2D(Z, -1, gradient_l) + grad_r = cv2.filter2D(Z, -1, gradient_r) + grad_u = cv2.filter2D(Z, -1, gradient_u) + grad_d = cv2.filter2D(Z, -1, gradient_d) + + # calculate laplace along 2 directions + if lap_conf == '1D-DLF': + lap_hor = abs(grad_l - grad_r) + lap_ver = abs(grad_u - grad_d) + elif lap_conf == 'DLF-alpha': + lap_hor = abs(cv2.filter2D(Z, -1, lap_ker_alpha)) + lap_ver = abs(cv2.filter2D(Z, -1, lap_ker_alpha)) + elif lap_conf == 'DLF-beta': + lap_hor = abs(cv2.filter2D(Z, -1, lap_ker_beta)) + lap_ver = abs(cv2.filter2D(Z, -1, lap_ker_beta)) + elif lap_conf == 'DLF-gamma': + lap_hor = abs(cv2.filter2D(Z, -1, lap_ker_gamma)) + lap_ver = abs(cv2.filter2D(Z, -1, lap_ker_gamma)) + else: + raise ValueError + + lambda_map1, lambda_map2 = soft_min(lap_hor, base, 0) + lambda_map3, lambda_map4 = soft_min(lap_ver, base, 1) + + eps = 1e-8 + thresh = base + lambda_map1[lambda_map1 / (lambda_map2 + eps) > thresh] = 1 + lambda_map2[lambda_map1 / (lambda_map2 + eps) > thresh] = 0 + lambda_map1[lambda_map2 / (lambda_map1 + eps) > thresh] = 0 + lambda_map2[lambda_map2 / (lambda_map1 + eps) > thresh] = 1 + + lambda_map3[lambda_map3 / (lambda_map4 + eps) > thresh] = 1 + lambda_map4[lambda_map3 / (lambda_map4 + eps) > thresh] = 0 + lambda_map3[lambda_map4 / (lambda_map3 + eps) > thresh] = 0 + lambda_map4[lambda_map4 / (lambda_map3 + eps) > thresh] = 1 + + # lambda_maps = [lambda_map1, lambda_map2, lambda_map3, lambda_map4] + Gu = lambda_map1 * grad_l + lambda_map2 * grad_r + Gv = lambda_map3 * grad_u + lambda_map4 * grad_d + return Gu, Gv + + +def MRF_optim(depth, n_est, 
lap_conf='DLF-alpha'): + h, w = depth.shape + n_x, n_y, n_z = n_est[:, :, 0], n_est[:, :, 1], n_est[:, :, 2] + # =====================optimize the normal with MRF============================= + if lap_conf == '1D-DLF': + Z_laplace_hor = abs(cv2.filter2D(depth, -1, laplace_hor)) + Z_laplace_ver = abs(cv2.filter2D(depth, -1, laplace_ver)) + + # [x-1,y] [x+1,y] [x,y-1] [x,y+1], [x,y] + Z_laplace_stack = np.array((np.hstack((np.inf * np.ones((h, 1)), Z_laplace_hor[:, :-1])), + np.hstack((Z_laplace_hor[:, 1:], np.inf * np.ones((h, 1)))), + np.vstack((np.inf * np.ones((1, w)), Z_laplace_ver[:-1, :])), + np.vstack((Z_laplace_ver[1:, :], np.inf * np.ones((1, w)))), + (Z_laplace_hor + Z_laplace_ver) / 2)) + else: + if lap_conf == 'DLF-alpha': + Z_laplace = abs(cv2.filter2D(depth, -1, lap_ker_alpha)) + elif lap_conf == 'DLF-beta': + Z_laplace = abs(cv2.filter2D(depth, -1, lap_ker_beta)) + elif lap_conf == 'DLF-gamma': + Z_laplace = abs(cv2.filter2D(depth, -1, lap_ker_gamma)) + else: + raise ValueError + Z_laplace_stack = np.array((np.hstack((np.inf * np.ones((h, 1)), Z_laplace[:, :-1])), + np.hstack((Z_laplace[:, 1:], np.inf * np.ones((h, 1)))), + np.vstack((np.inf * np.ones((1, w)), Z_laplace[:-1, :])), + np.vstack((Z_laplace[1:, :], np.inf * np.ones((1, w)))), + Z_laplace)) + + # best_loc_map: 0 for left, 1 for right, 2 for up, 3 for down, 4 for self + best_loc_map = np.argmin(Z_laplace_stack, axis=0) + Nx_t_stack = np.array((np.hstack((np.zeros((h, 1)), n_x[:, :-1])), + np.hstack((n_x[:, 1:], np.zeros((h, 1)))), + np.vstack((np.zeros((1, w)), n_x[:-1, :])), + np.vstack((n_x[1:, :], np.zeros((1, w)))), + n_x)).reshape(5, -1) + Ny_t_stack = np.array((np.hstack((np.zeros((h, 1)), n_y[:, :-1])), + np.hstack((n_y[:, 1:], np.zeros((h, 1)))), + np.vstack((np.zeros((1, w)), n_y[:-1, :])), + np.vstack((n_y[1:, :], np.zeros((1, w)))), + n_y)).reshape(5, -1) + Nz_t_stack = np.array((np.hstack((np.zeros((h, 1)), n_z[:, :-1])), + np.hstack((n_z[:, 1:], np.zeros((h, 1)))), + np.vstack((np.zeros((1, w)), n_z[:-1, :])), + np.vstack((n_z[1:, :], np.zeros((1, w)))), + n_z)).reshape(5, -1) + + n_x = Nx_t_stack[best_loc_map.reshape(-1), np.arange(h * w)].reshape(h, w) + n_y = Ny_t_stack[best_loc_map.reshape(-1), np.arange(h * w)].reshape(h, w) + n_z = Nz_t_stack[best_loc_map.reshape(-1), np.arange(h * w)].reshape(h, w) + n_est = cv2.merge((n_x, n_y, n_z)) + return n_est diff --git a/requirements.txt b/requirements.txt index 0434926..d56cb00 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,7 @@ scipy setuptools omegaconf tabulate -pandas \ No newline at end of file +pandas +wandb +datasets +peft \ No newline at end of file diff --git a/training/dataloaders/load.py b/training/dataloaders/load.py new file mode 100644 index 0000000..3954c23 --- /dev/null +++ b/training/dataloaders/load.py @@ -0,0 +1,376 @@ +# @GonzaloMartinGarcia +# This file houses our dataset mixer and training dataset classes. 
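+#
+# MixedDataLoader interleaves batches drawn from two dataloaders (here Hypersim and Virtual
+# KITTI 2) at an approximate split1:split2 ratio; train.py uses 9:1. The effective fraction of
+# each loader is capped at 1, so neither dataset is oversampled within an epoch.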
+ +import torch +from torch.utils.data import Dataset +from torchvision import transforms +import os +from PIL import Image +import numpy as np +import random +import pandas as pd +import cv2 + +################# +# Dataset Mixer +################# + +class MixedDataLoader: + def __init__(self, loader1, loader2, split1=9, split2=1): + self.loader1 = loader1 + self.loader2 = loader2 + self.split1 = split1 + self.split2 = split2 + self.frac1, self.frac2 = self.get_split_fractions() + self.randchoice1=None + + def __iter__(self): + self.loader_iter1 = iter(self.loader1) + self.loader_iter2 = iter(self.loader2) + self.randchoice1 = self.create_split() + self.indx = 0 + return self + + def get_split_fractions(self): + size1 = len(self.loader1) + size2 = len(self.loader2) + effective_fraction1 = min((size2/size1) * (self.split1/self.split2), 1) + effective_fraction2 = min((size1/size2) * (self.split2/self.split1), 1) + print("Effective fraction for loader1: ", effective_fraction1) + print("Effective fraction for loader2: ", effective_fraction2) + return effective_fraction1, effective_fraction2 + + def create_split(self): + randchoice1 = [True]*int(len(self.loader1)*self.frac1) + [False]*int(len(self.loader2)*self.frac2) + np.random.shuffle(randchoice1) + return randchoice1 + + def __next__(self): + if self.indx == len(self.randchoice1): + raise StopIteration + if self.randchoice1[self.indx]: + self.indx += 1 + return next(self.loader_iter1) + else: + self.indx += 1 + return next(self.loader_iter2) + + def __len__(self): + return int(len(self.loader1)*self.frac1) + int(len(self.loader2)*self.frac2) + + +################# +# Transforms +################# + +# Hyperism +class SynchronizedTransform_Hyper: + def __init__(self, H, W): + self.resize = transforms.Resize((H,W)) + self.resize_depth = transforms.Resize((H,W), interpolation=Image.NEAREST) + self.horizontal_flip = transforms.RandomHorizontalFlip(p=1.0) + self.to_tensor = transforms.ToTensor() + + def __call__(self, rgb_image, depth_image, normal_image=None): + # h-flip + if random.random() > 0.5: + rgb_image = self.horizontal_flip(rgb_image) + depth_image = self.horizontal_flip(depth_image) + if normal_image is not None: + normal_image = self.horizontal_flip(normal_image) + # correct normals for horizontal flip + np_normal_image = np.array(normal_image) + np_normal_image[:, :, 0] = 255 - np_normal_image[:, :, 0] + normal_image = Image.fromarray(np_normal_image) + # resize + rgb_image = self.resize(rgb_image) + depth_image = self.resize_depth(depth_image) + if normal_image is not None: + normal_image = self.resize(normal_image) + # to tensor + rgb_tensor = self.to_tensor(rgb_image) + depth_tensor = self.to_tensor(depth_image) + if normal_image is not None: + normal_tensor = self.to_tensor(normal_image) + # retrun + if normal_image is not None: + return rgb_tensor, depth_tensor, normal_tensor + return rgb_tensor, depth_tensor + +# Virtual KITTI 2 +class SynchronizedTransform_VKITTI: + def __init__(self): + self.to_tensor = transforms.ToTensor() + self.horizontal_flip = transforms.RandomHorizontalFlip(p=1.0) + + # KITTI benchmark crop from Marigold: + # https://github.com/prs-eth/Marigold/blob/62413d56099d36573b2de1eb8c429839734b7782/src/dataset/kitti_dataset.py#L75 + @staticmethod + def kitti_benchmark_crop(input_img): + KB_CROP_HEIGHT = 352 + KB_CROP_WIDTH = 1216 + height, width = input_img.shape[-2:] + top_margin = int(height - KB_CROP_HEIGHT) + left_margin = int((width - KB_CROP_WIDTH) / 2) + if 2 == len(input_img.shape): + out = input_img[ 
+ top_margin : top_margin + KB_CROP_HEIGHT, + left_margin : left_margin + KB_CROP_WIDTH, + ] + elif 3 == len(input_img.shape): + out = input_img[ + :, + top_margin : top_margin + KB_CROP_HEIGHT, + left_margin : left_margin + KB_CROP_WIDTH, + ] + return out + + def __call__(self, rgb_image, depth_image, normal_image=None): + # h-flip + if random.random() > 0.5: + rgb_image = self.horizontal_flip(rgb_image) + depth_image = self.horizontal_flip(depth_image) + if normal_image is not None: + normal_image = self.horizontal_flip(normal_image) + # correct normals for horizontal flip + np_normal_image = np.array(normal_image) + np_normal_image[:, :, 0] = 255 - np_normal_image[:, :, 0] + normal_image = Image.fromarray(np_normal_image) + # to tensor + rgb_tensor = self.to_tensor(rgb_image) + depth_tensor = self.to_tensor(depth_image) + if normal_image is not None: + normal_tensor = self.to_tensor(normal_image) + # kitti benchmark crop + rgb_tensor = self.kitti_benchmark_crop(rgb_tensor) + depth_tensor = self.kitti_benchmark_crop(depth_tensor) + if normal_image is not None: + normal_tensor = self.kitti_benchmark_crop(normal_tensor) + # return + if normal_image is not None: + return rgb_tensor, depth_tensor, normal_tensor + return rgb_tensor, depth_tensor + + +##################### +# Training Datasets +##################### + +# Hypersim +class Hypersim(Dataset): + def __init__(self, root_dir, transform=True, near_plane=1e-5, far_plane=65.0): + self.root_dir = root_dir + self.split_path = os.path.join("data/hypersim/processed/train/filename_meta_train.csv") + self.near_plane = near_plane + self.far_plane = far_plane + self.align_cam_normal = True + self.pairs = self._find_pairs() + self.transform = SynchronizedTransform_Hyper(H=480, W=640) if transform else None + + def _find_pairs(self): + df = pd.read_csv(self.split_path) + pairs = [] + for _, row in df.iterrows(): + if row['included_in_public_release'] and (row['split_partition_name'] == "train"): + rgb_path = os.path.join(self.root_dir, "train", row['rgb_path']) + depth_path = os.path.join(self.root_dir, "train", row['depth_path']) + head, _ = os.path.split(os.path.join(self.root_dir, "train")) + normal_dir = os.path.join(os.path.join(head, 'normals'), row['scene_name'], 'images', f'scene_{row["camera_name"]}_geometry_preview',f'frame.{str(row["frame_id"]).zfill(4) }.normal_cam.png') + if os.path.exists(rgb_path) and os.path.exists(depth_path) and os.path.exists(normal_dir): + pair_info = {'rgb_path': rgb_path, 'depth_path': depth_path, 'normal_path': normal_dir} + pairs.append(pair_info) + return pairs + + def __len__(self): + return len(self.pairs) + + # Some Hypersim normals are not properly oriented towards the camera. 
+ # The align_normals and creat_uv_mesh functions are from GeoWizard + # https://github.com/fuxiao0719/GeoWizard/blob/5ff496579c6be35d9d86fe4d0760a6b5e6ba25c5/geowizard/training/dataloader/file_io.py#L115 + def align_normals(self, normal, depth, K, H, W): + ''' + Orientation of surface normals in hypersim is not always consistent + see https://github.com/apple/ml-hypersim/issues/26 + ''' + # inv K + K = np.array([[K[0], 0, K[2]], + [ 0, K[1], K[3]], + [ 0, 0, 1]]) + inv_K = np.linalg.inv(K) + # reprojection depth to camera points + xy = self.creat_uv_mesh(H, W) + points = np.matmul(inv_K[:3, :3], xy).reshape(3, H, W) + points = depth * points + points = points.transpose((1,2,0)) + # align normal + orient_mask = np.sum(normal * points, axis=2) > 0 + normal[orient_mask] *= -1 + return normal + + def creat_uv_mesh(self, H, W): + y, x = np.meshgrid(np.arange(0, H, dtype=np.float64), np.arange(0, W, dtype=np.float64), indexing='ij') + meshgrid = np.stack((x,y)) + ones = np.ones((1,H*W), dtype=np.float64) + xy = meshgrid.reshape(2, -1) + return np.concatenate([xy, ones], axis=0) + + def __getitem__(self, idx): + pairs = self.pairs[idx] + + # get RGB + rgb_path = pairs['rgb_path'] + rgb_image = Image.open(rgb_path).convert('RGB') + # get depth + depth_path = pairs['depth_path'] + depth_image = Image.open(depth_path) + depth_image = np.array(depth_image) + depth_image = depth_image / 1000 # mm to meters + depth_image = Image.fromarray(depth_image) + # get normals + normal_path = pairs['normal_path'] + normal_image = Image.open(normal_path).convert('RGB') + if self.align_cam_normal: + # align normals towards camera + normal_array = (np.array(normal_image) / 255.0) * 2.0 - 1.0 + H, W = normal_array.shape[:2] + normal_array[:,:,1:] *= -1 + normal_array = self.align_normals(normal_array, np.array(depth_image), [886.81,886.81,W/2, H/2], H, W) * -1 + normal_image = Image.fromarray(((normal_array + 1.0) / 2.0 * 255).astype(np.uint8)) + + # transfrom + if self.transform is not None: + rgb_tensor, depth_tensor, normal_tensor = self.transform(rgb_image, depth_image, normal_image) + else: + rgb_tensor = transforms.ToTensor()(rgb_image) + depth_tensor = transforms.ToTensor()(depth_image) + normal_tensor = transforms.ToTensor()(normal_image) + + # get valid depth mask + valid_depth_mask = (depth_tensor > self.near_plane) & (depth_tensor < self.far_plane) + + # Process RGB + rgb_tensor = rgb_tensor*2.0 - 1.0 # [-1,1] + + # Process depth + if valid_depth_mask.any(): + flat_depth = depth_tensor[valid_depth_mask].flatten().float() + min_depth = torch.quantile(flat_depth, 0.02) + max_depth = torch.quantile(flat_depth, 0.98) + if min_depth == max_depth: + depth_tensor = torch.zeros_like(depth_tensor) + metric_tensor = torch.zeros_like(depth_tensor) + valid_depth_mask = torch.zeros_like(depth_tensor).bool() # empty mask + else: + depth_tensor = torch.clamp(depth_tensor, min_depth, max_depth) # remove outliers + depth_tensor[~valid_depth_mask] = max_depth # set invalid depth to relative far plane + metric_tensor = depth_tensor.clone() # keep metric depth for e2e loss ft + depth_tensor = torch.clamp((((depth_tensor - min_depth) / (max_depth - min_depth))*2.0)-1.0, -1, 1) # [-1,1] + else: + depth_tensor = torch.zeros_like(depth_tensor) + metric_tensor = torch.zeros_like(depth_tensor) + depth_tensor = torch.stack([depth_tensor, depth_tensor, depth_tensor]).squeeze() # stack depth map for VAE encoder + + # Process normals + normal_tensor = normal_tensor * 2.0 - 1.0 # [-1,1] + normal_tensor = 
torch.nn.functional.normalize(normal_tensor, p=2, dim=0) # normalize + # set invalid pixels to the zero vector (color grey) + normal_tensor[0,~valid_depth_mask.squeeze()] = 0 + normal_tensor[1,~valid_depth_mask.squeeze()] = 0 + normal_tensor[2,~valid_depth_mask.squeeze()] = 0 + + return {"rgb": rgb_tensor, "depth": depth_tensor, 'metric': metric_tensor, 'normals': normal_tensor, "val_mask": valid_depth_mask, "domain": "indoor"} + + +# Virtual KITTI 2.0 +class VirtualKITTI2(Dataset): + def __init__(self, root_dir, transform=None, near_plane=1e-5, far_plane=80.0): + self.root_dir = root_dir + self.near_plane = near_plane + self.far_plane = far_plane + self.pairs = self._find_pairs() + self.transform = SynchronizedTransform_VKITTI() if transform else None + + def _find_pairs(self): + scenes = ["Scene01", "Scene02", "Scene06", "Scene18", "Scene20"] + weather_conditions = ["morning", "fog", "rain", "sunset", "overcast"] + cameras = ["Camera_0", "Camera_1"] + vkitti2_rgb_path = os.path.join(self.root_dir, "vkitti_2.0.3_rgb") + vkitti2_depth_path = os.path.join(self.root_dir, "vkitti_2.0.3_depth") + vkitti2_normal_path = os.path.join(self.root_dir, "vkitti_DAG_normals") + pairs = [] + for scene in scenes: + for weather in weather_conditions: + for camera in cameras: + rgb_dir = os.path.join(vkitti2_rgb_path, scene, weather, "frames", "rgb" ,camera) + depth_dir = os.path.join(vkitti2_depth_path, scene, weather, "frames","depth" , camera) + normal_dir = os.path.join(vkitti2_normal_path, scene, weather, "frames", "normal", camera) + if os.path.exists(rgb_dir) and os.path.exists(depth_dir): + rgb_files = [f for f in os.listdir(rgb_dir) if f.endswith(".jpg")] + rgb_files = [file[3:] for file in rgb_files] + for file in rgb_files: + rgb_file = "rgb" + file + depth_file = "depth" + file.replace('.jpg', '.png') + normal_file = "normal" + file.replace('.jpg', '.png') + rgb_path = os.path.join(rgb_dir, rgb_file) + depth_path = os.path.join(depth_dir, depth_file) + normal_path = os.path.join(normal_dir, normal_file) + pairs.append((rgb_path, depth_path, normal_path)) + return pairs + + def __len__(self): + return len(self.pairs) + + def __getitem__(self, idx): + rgb_path, depth_path, normal_path = self.pairs[idx] + + # get RGB + rgb_image = Image.open(rgb_path).convert('RGB') + # get depth + depth_image = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) + depth_image = depth_image.astype(np.float32)/100.0 # cm to meters + depth_image = Image.fromarray(depth_image) # PIL + # get normals + normal_image = Image.open(normal_path).convert('RGB') + + # transform + if self.transform is not None: + rgb_tensor, depth_tensor, normal_tensor = self.transform(rgb_image, depth_image, normal_image) + else: + rgb_tensor = transforms.ToTensor()(rgb_image) + depth_tensor = transforms.ToTensor()(depth_image) + normal_tensor = transforms.ToTensor()(normal_image) + + # get valid depth mask + valid_depth_mask = (depth_tensor > self.near_plane) & (depth_tensor < self.far_plane) + + # Process RGB + rgb_tensor = rgb_tensor*2.0 - 1.0 # [-1,1] + + # Process depth + if valid_depth_mask.any(): + flat_depth = depth_tensor[valid_depth_mask].flatten().float() + min_depth = torch.quantile(flat_depth, 0.02) + max_depth = torch.quantile(flat_depth, 0.98) + if min_depth == max_depth: + depth_tensor = torch.zeros_like(depth_tensor) + metric_tensor = torch.zeros_like(depth_tensor) + valid_depth_mask = torch.zeros_like(depth_tensor).bool() # empty mask + else: + depth_tensor = torch.clamp(depth_tensor, min_depth, max_depth) # 
remove outliers + depth_tensor[~valid_depth_mask] = max_depth # set invalid depth to relative far plane + metric_tensor = depth_tensor.clone() # keep metric depth for e2e loss ft + depth_tensor = torch.clamp((((depth_tensor - min_depth) / (max_depth - min_depth))*2.0)-1.0, -1, 1) # [-1,1] + else: + depth_tensor = torch.zeros_like(depth_tensor) + metric_tensor = torch.zeros_like(depth_tensor) + depth_tensor = torch.stack([depth_tensor, depth_tensor, depth_tensor]).squeeze() # stack depth map for VAE encoder + + # Process normals + normal_tensor = normal_tensor * 2.0 - 1.0 # [-1,1] + normal_tensor = torch.nn.functional.normalize(normal_tensor, p=2, dim=0) # normalize + # set invalid pixels to the zero vector (color grey) + normal_tensor[0,~valid_depth_mask.squeeze()] = 0 + normal_tensor[1,~valid_depth_mask.squeeze()] = 0 + normal_tensor[2,~valid_depth_mask.squeeze()] = 0 + + return {"rgb": rgb_tensor, "depth": depth_tensor, 'metric': metric_tensor, 'normals': normal_tensor, "val_mask": valid_depth_mask, "domain": "outdoor"} \ No newline at end of file diff --git a/training/scripts/multi_gpu.yaml b/training/scripts/multi_gpu.yaml new file mode 100644 index 0000000..a30d32e --- /dev/null +++ b/training/scripts/multi_gpu.yaml @@ -0,0 +1,15 @@ +compute_environment: LOCAL_MACHINE +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: 'no' +num_machines: 1 +num_processes: 4 # Number of GPUs +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/training/scripts/train_geowizard_e2e_ft.sh b/training/scripts/train_geowizard_e2e_ft.sh new file mode 100755 index 0000000..28a9cbc --- /dev/null +++ b/training/scripts/train_geowizard_e2e_ft.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +accelerate launch GeoWizard/geowizard/training/train_depth_normal.py \ + --pretrained_model_name_or_path "lemonaddie/geowizard" \ + --e2e_ft \ + --noise_type="zeros" \ + --max_train_steps 20000 \ + --checkpointing_steps 20000 \ + --train_batch_size 2 \ + --gradient_accumulation_steps 16 \ + --gradient_checkpointing \ + --learning_rate 3e-5 \ + --lr_total_iter_length 20000 \ + --lr_warmup_steps 100 \ + --mixed_precision="no" \ + --output_dir "model-finetuned/geowizard_e2e_ft" \ + --enable_xformers_memory_efficient_attention \ + "$@" \ No newline at end of file diff --git a/training/scripts/train_marigold_e2e_ft_depth.sh b/training/scripts/train_marigold_e2e_ft_depth.sh new file mode 100755 index 0000000..f5b1eef --- /dev/null +++ b/training/scripts/train_marigold_e2e_ft_depth.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +accelerate launch training/train.py \ + --pretrained_model_name_or_path "prs-eth/marigold-v1-0" \ + --modality "depth" \ + --noise_type "zeros" \ + --max_train_steps 20000 \ + --checkpointing_steps 20000 \ + --train_batch_size 2 \ + --gradient_accumulation_steps 16 \ + --gradient_checkpointing \ + --learning_rate 3e-05 \ + --lr_total_iter_length 20000 \ + --lr_exp_warmup_steps 100 \ + --mixed_precision "no" \ + --output_dir "model-finetuned/marigold_e2e_ft_depth" \ + --enable_xformers_memory_efficient_attention \ + "$@" \ No newline at end of file diff --git a/training/scripts/train_marigold_e2e_ft_normals.sh b/training/scripts/train_marigold_e2e_ft_normals.sh new file mode 100755 index 0000000..042c964 --- /dev/null +++ b/training/scripts/train_marigold_e2e_ft_normals.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +accelerate launch training/train.py \ + --pretrained_model_name_or_path 
"GonzaloMG/marigold-normals" \ + --modality "normals" \ + --noise_type "zeros" \ + --max_train_steps 20000 \ + --checkpointing_steps 20000 \ + --train_batch_size 2 \ + --gradient_accumulation_steps 16 \ + --gradient_checkpointing \ + --learning_rate 3e-05 \ + --lr_total_iter_length 20000 \ + --lr_exp_warmup_steps 100 \ + --mixed_precision "no" \ + --output_dir "model-finetuned/marigold_e2e_ft_normals" \ + --enable_xformers_memory_efficient_attention \ + "$@" \ No newline at end of file diff --git a/training/scripts/train_stable_diffusion_e2e_ft_depth.sh b/training/scripts/train_stable_diffusion_e2e_ft_depth.sh new file mode 100755 index 0000000..5a7e49f --- /dev/null +++ b/training/scripts/train_stable_diffusion_e2e_ft_depth.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +accelerate launch training/train.py \ + --pretrained_model_name_or_path "stabilityai/stable-diffusion-2" \ + --modality "depth" \ + --noise_type "zeros" \ + --max_train_steps 20000 \ + --checkpointing_steps 20000 \ + --train_batch_size 2 \ + --gradient_accumulation_steps 16 \ + --gradient_checkpointing \ + --learning_rate 3e-05 \ + --lr_total_iter_length 20000 \ + --lr_exp_warmup_steps 100 \ + --mixed_precision "no" \ + --output_dir "model-finetuned/stable_diffusion_e2e_ft_depth" \ + --enable_xformers_memory_efficient_attention \ + "$@" \ No newline at end of file diff --git a/training/scripts/train_stable_diffusion_e2e_ft_normals.sh b/training/scripts/train_stable_diffusion_e2e_ft_normals.sh new file mode 100755 index 0000000..2f95686 --- /dev/null +++ b/training/scripts/train_stable_diffusion_e2e_ft_normals.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +accelerate launch training/train.py \ + --pretrained_model_name_or_path "stabilityai/stable-diffusion-2" \ + --modality "normals" \ + --noise_type "zeros" \ + --max_train_steps 20000 \ + --checkpointing_steps 20000 \ + --train_batch_size 2 \ + --gradient_accumulation_steps 16 \ + --gradient_checkpointing \ + --learning_rate 3e-05 \ + --lr_total_iter_length 20000 \ + --lr_exp_warmup_steps 100 \ + --mixed_precision "no" \ + --output_dir "model-finetuned/stable_diffusion_e2e_ft_normals" \ + --enable_xformers_memory_efficient_attention \ + "$@" \ No newline at end of file diff --git a/training/train.py b/training/train.py new file mode 100644 index 0000000..ceaaff5 --- /dev/null +++ b/training/train.py @@ -0,0 +1,644 @@ +# @GonzaloMartinGarcia +# Training code for 'Fine-Tuning Image-Conditional Diffusion Models is Easier than You Think'. +# This training code is a modified version of the original text-to-image SD training code from the HuggingFace Inc. Team, +# https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py. 
+ +import argparse +import logging +import math +import os +import shutil + +import accelerate +import datasets +import torch +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from packaging import version +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + +from torch.optim.lr_scheduler import LambdaLR +from dataloaders.load import * +from util.noise import pyramid_noise_like +from util.loss import ScaleAndShiftInvariantLoss, AngularLoss +from util.unet_prep import replace_unet_conv_in +from util.lr_scheduler import IterExponential + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.27.0.dev0") +logger = get_logger(__name__, log_level="INFO") + +############# +# Arguments +############# + +def parse_args(): + parser = argparse.ArgumentParser(description="Training code for 'Fine-Tuning Image-Conditional Diffusion Models is Easier than You Think'.") + # Our settings: + parser.add_argument( + "--modality", + type=str, + choices=["depth", "normals"], + required=True, + ) + parser.add_argument( + "--noise_type", + type=str, + default=None, # If left as None, Stable Diffusion checkpoints can be trained without altering the input channels (i.e., only 4 input channels for the RGB input). + choices=["zeros", "gaussian", "pyramid"], + help="If left as None, Stable Diffusion checkpoints can be trained without altering the input channels (i.e., only 4 input channels for RGB input)." + ) + parser.add_argument( + "--lr_exp_warmup_steps", + type=int, + default=100, + ) + parser.add_argument( + "--lr_total_iter_length", + type=int, + default=20000, + ) + # Stable diffusion training settings + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--output_dir", + type=str, + default="training/model-finetuned", + required=True, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="A seed for reproducible training." + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=2, + help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--num_train_epochs", + type=int, + default=15, + ) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", + required=True, + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=16, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=3e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank" + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=20000, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", + action="store_true", + help="Whether or not to use xformers." 
+ ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="e2e-ft-diffusion", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + +######################## +# VAE Helper Functions +######################## + +# Apply VAE Encoder to image +def encode_image(vae, image): + h = vae.encoder(image) + moments = vae.quant_conv(h) + latent, _ = torch.chunk(moments, 2, dim=1) + return latent + +# Apply VAE Decoder to latent +def decode_image(vae, latent): + z = vae.post_quant_conv(latent) + image = vae.decoder(z) + return image + +########################## +# MAIN Training Function +########################## + +def main(): + args = parse_args() + + # Init accelerator and logger + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # Set seed + if args.seed is not None: + set_seed(args.seed) + + # Save training arguments in a .txt file + if accelerator.is_main_process: + args_dict = vars(args) + args_str = '\n'.join(f"{key}: {value}" for key, value in args_dict.items()) + args_path = os.path.join(args.output_dir, "arguments.txt") + os.makedirs(args.output_dir, exist_ok=True) + with open(args_path, 'w') as file: + file.write(args_str) + if args.noise_type is None: + logger.warning("Noise type is `None`. 
This setting is only meant for checkpoints without image conditioning, such as Stable Diffusion.") + + # Load model components + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision) + text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant) + unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=None) + if args.noise_type is not None: + # Double UNet input layers if necessary + if unet.config['in_channels'] != 8: + replace_unet_conv_in(unet, repeat=2) + logger.info("Unet conv_in layer is replaced for RGB-depth or RGB-normals input") + + # Freeze or set model components to training mode + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.train() + + # Use xformers for efficient attention + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # Diffusers model loading and saving functions + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + def load_model_hook(models, input_dir): + for _ in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + model.load_state_dict(load_model.state_dict()) + del load_model + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Gradient checkpointing + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Optimizer + optimizer_cls = torch.optim.AdamW + optimizer = optimizer_cls( + unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Learning rate scheduler + lr_func = IterExponential(total_iter_length = args.lr_total_iter_length*accelerator.num_processes, final_ratio = 0.01, warmup_steps = args.lr_exp_warmup_steps*accelerator.num_processes) + lr_scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lr_func) + + # Training datasets + hypersim_root_dir = "data/hypersim/processed" + vkitti_root_dir = "data/virtual_kitti_2" + train_dataset_hypersim = Hypersim(root_dir=hypersim_root_dir, transform=True) + train_dataset_vkitti = 
VirtualKITTI2(root_dir=vkitti_root_dir, transform=True) + train_dataloader_vkitti = torch.utils.data.DataLoader(train_dataset_vkitti, shuffle=True, batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers) + train_dataloader_hypersim = torch.utils.data.DataLoader(train_dataset_hypersim, shuffle=True, batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers) + train_dataloader = MixedDataLoader(train_dataloader_hypersim, train_dataloader_vkitti, split1=9, split2=1) + + # Prepare everything with `accelerator` (Move to GPU) + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # Mixed precision and weight dtype + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + unet.to(dtype=weight_dtype) + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + unet.to(dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # Calculate number of training steps and epochs + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Function for unwrapping if model was compiled with `torch.compile`. + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset_vkitti)+len(train_dataset_hypersim)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Resume training from checkpoint + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
+ ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + # Progress bar + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process,) + + # Init task specific losses + ssi_loss = ScaleAndShiftInvariantLoss() + angular_loss_norm = AngularLoss() + + # Pre-compute empty text CLIP encoding + empty_token = tokenizer([""], padding="max_length", truncation=True, return_tensors="pt").input_ids + empty_token = empty_token.to(accelerator.device) + empty_encoding = text_encoder(empty_token, return_dict=False)[0] + empty_encoding = empty_encoding.to(accelerator.device) + + # Get noise scheduling parameters for later conversion from a parameterized prediction into latent. + alpha_prod = noise_scheduler.alphas_cumprod.to(accelerator.device, dtype=weight_dtype) + beta_prod = 1 - alpha_prod + + # Training Loop + for epoch in range(first_epoch, args.num_train_epochs): + logger.info(f"At Epoch {epoch}:") + train_loss = 0.0 + + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + + # RGB latent + rgb_latents = encode_image(vae, batch["rgb"].to(device=accelerator.device, dtype=weight_dtype)) + rgb_latents = rgb_latents * vae.config.scaling_factor + + # Depth or normals latent + if args.modality == "depth": + latents = encode_image(vae, batch["depth"].to(device=accelerator.device, dtype=weight_dtype)) + elif args.modality == "normals": + latents = encode_image(vae, batch["normals"].to(device=accelerator.device, dtype=weight_dtype)) + latents = latents * vae.config.scaling_factor + + # Validity mask + val_mask = batch["val_mask"].bool().to(device=accelerator.device) + + # Set timesteps to the first denoising step + timesteps = torch.ones((latents.shape[0],), device=latents.device) * (noise_scheduler.config.num_train_timesteps-1) # 999 + timesteps = timesteps.long() + + # Sample noisy latent + if (args.noise_type is None) or (args.noise_type == "zeros"): + noisy_latents = torch.zeros_like(latents).to(accelerator.device) + elif args.noise_type == "pyramid": + noisy_latents = pyramid_noise_like(latents).to(accelerator.device) + elif args.noise_type == "gaussian": + noisy_latents = torch.randn_like(latents).to(accelerator.device) + else: + raise ValueError(f"Unknown noise type {args.noise_type}") + + # Generate UNet prediction + encoder_hidden_states = empty_encoding.repeat(len(batch["rgb"]), 1, 1) + unet_input = ( + torch.cat((rgb_latents, noisy_latents), dim=1).to(accelerator.device) + if args.noise_type is not None + else rgb_latents + ) + model_pred = unet(unet_input, timesteps, encoder_hidden_states, return_dict=False)[0] + + # End-to-end fine-tuning + loss = torch.tensor(0.0, device=accelerator.device, requires_grad=True) + if val_mask.any(): + + # Convert parameterized prediction into latent prediction. + # Code is based on the DDIM code from diffusers, + # https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py. 
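+                        # With alpha_bar_t = alphas_cumprod[t] and x_t the (zeroed or noised) input latents:
+                        #   v-prediction:       x_0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v
+                        #   epsilon-prediction: x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)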
+ alpha_prod_t = alpha_prod[timesteps].view(-1, 1, 1, 1) + beta_prod_t = beta_prod[timesteps].view(-1, 1, 1, 1) + if noise_scheduler.config.prediction_type == "v_prediction": + current_latent_estimate = (alpha_prod_t**0.5) * noisy_latents - (beta_prod_t**0.5) * model_pred + elif noise_scheduler.config.prediction_type == "epsilon": + current_latent_estimate = (noisy_latents - beta_prod_t ** (0.5) * model_pred) / alpha_prod_t ** (0.5) + elif noise_scheduler.config.prediction_type == "sample": + current_latent_estimate = model_pred + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + # clip and threshold prediction (only here for completeness, not used by SD2 or our models with v_prediction) + if noise_scheduler.config.thresholding: + pred_original_sample = noise_scheduler._threshold_sample(pred_original_sample) + elif noise_scheduler.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -noise_scheduler.config.clip_sample_range, noise_scheduler.config.clip_sample_range + ) + + # Decode latent prediction + current_latent_estimate = current_latent_estimate / vae.config.scaling_factor + current_estimate = decode_image(vae, current_latent_estimate) + + # Post-process predicted images and retrieve ground truth + if args.modality == "depth": + current_estimate = current_estimate.mean(dim=1, keepdim=True) + current_estimate = torch.clamp(current_estimate,-1,1) + ground_truth = batch["metric"].to(device=accelerator.device, dtype=weight_dtype) + elif args.modality == "normals": + norm = torch.norm(current_estimate, p=2, dim=1, keepdim=True) + 1e-5 + current_estimate = current_estimate / norm + current_estimate = torch.clamp(current_estimate,-1,1) + ground_truth = batch["normals"].to(device=accelerator.device, dtype=weight_dtype) + else: + raise ValueError(f"Unknown modality {args.modality}") + + # Compute task-specific loss + estimation_loss = 0 + if args.modality == "depth": + estimation_loss_ssi = ssi_loss(current_estimate, ground_truth, val_mask) + if not torch.isnan(estimation_loss_ssi).any(): + estimation_loss = estimation_loss + estimation_loss_ssi + elif args.modality == "normals": + estimation_loss_ang_norm = angular_loss_norm(current_estimate, ground_truth, val_mask) + if not torch.isnan(estimation_loss_ang_norm).any(): + estimation_loss = estimation_loss + estimation_loss_ang_norm + else: + raise ValueError(f"Unknown modality {args.modality}") + loss = loss + estimation_loss + + # Gather the losses across all processes for logging (if we use distributed training). 
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + accelerator.log({"lr": lr_scheduler.get_last_lr()[0]}, step=global_step) + train_loss = 0.0 + # Save model checkpoint + if global_step % args.checkpointing_steps == 0: + logger.info(f"Entered Saving Code at global step {global_step} checkpointing_steps {args.checkpointing_steps}") + if accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + # Log loss and learning rate for progress bar + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + # Break training + if global_step >= args.max_train_steps: + break + + # Create SD pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = unwrap_model(unet) + scheduler = DDPMScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler", + timestep_spacing="trailing", # set scheduler timestep spacing to trailing for later inference. + revision=args.revision, + variant=args.variant + ) + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=text_encoder, + vae=vae, + unet=unet, + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + ) + logger.info(f"Saving pipeline to {args.output_dir}") + pipeline.save_pretrained(args.output_dir) + + logger.info(f"Finished training.") + + accelerator.end_training() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/training/util/__init__.py b/training/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/training/util/loss.py b/training/util/loss.py new file mode 100644 index 0000000..0ad977c --- /dev/null +++ b/training/util/loss.py @@ -0,0 +1,67 @@ +# @ GonzaloMartinGarcia +# Task specific loss functions are from Depth Anything https://github.com/LiheYoung/Depth-Anything. 
+# Modifications have been made to improve numerical stability for this project (marked by '# add').
+
+import torch
+import torch.nn as nn
+
+#########
+# Losses
+#########
+
+# Scale and Shift Invariant Loss
+class ScaleAndShiftInvariantLoss(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.name = "SSILoss"
+    def forward(self, prediction, target, mask):
+        if mask.ndim == 4:
+            mask = mask.squeeze(1)
+        prediction, target = prediction.squeeze(1), target.squeeze(1)
+        # add
+        with torch.autocast(device_type='cuda', enabled=False):
+            prediction = prediction.float()
+            target = target.float()
+
+            scale, shift = compute_scale_and_shift_masked(prediction, target, mask)
+            scaled_prediction = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)
+            loss = nn.functional.l1_loss(scaled_prediction[mask], target[mask])
+        return loss
+
+def compute_scale_and_shift_masked(prediction, target, mask):
+    # system matrix: A = [[a_00, a_01], [a_10, a_11]]
+    a_00 = torch.sum(mask * prediction * prediction, (1, 2))
+    a_01 = torch.sum(mask * prediction, (1, 2))
+    a_11 = torch.sum(mask, (1, 2))
+    # right hand side: b = [b_0, b_1]
+    b_0 = torch.sum(mask * prediction * target, (1, 2))
+    b_1 = torch.sum(mask * target, (1, 2))
+    # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
+    x_0 = torch.zeros_like(b_0)
+    x_1 = torch.zeros_like(b_1)
+    det = a_00 * a_11 - a_01 * a_01
+    # A needs to be a positive definite matrix.
+    valid = det > 0  # a small positive threshold (e.g. 1e-3) could be used instead
+    x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
+    x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]
+    return x_0, x_1
+
+
+# Angular Loss
+class AngularLoss(nn.Module):
+    def __init__(self):
+        super(AngularLoss, self).__init__()
+        self.name = "Angular"
+
+    def forward(self, prediction, target, mask=None):
+        with torch.autocast(device_type='cuda', enabled=False):
+            prediction = prediction.float()
+            target = target.float()
+            dot_product = torch.sum(prediction * target, dim=1)
+            dot_product = torch.clamp(dot_product, -1.0, 1.0)
+            angle = torch.acos(dot_product)
+            if mask is not None:
+                mask = mask[:, 0, :, :]
+                angle = angle[mask]
+            loss = angle.mean()
+        return loss
\ No newline at end of file
diff --git a/training/util/lr_scheduler.py b/training/util/lr_scheduler.py
new file mode 100644
index 0000000..32e71d7
--- /dev/null
+++ b/training/util/lr_scheduler.py
@@ -0,0 +1,36 @@
+# @GonzaloMartinGarcia
+# This file contains Marigold's exponential LR scheduler.
+# https://github.com/prs-eth/Marigold/blob/main/src/util/lr_scheduler.py
+
+# Author: Bingxin Ke
+# Last modified: 2024-02-22
+
+import numpy as np
+
+class IterExponential:
+
+    def __init__(self, total_iter_length, final_ratio, warmup_steps=0) -> None:
+        """
+        Customized iteration-wise exponential scheduler.
+        Recalculated at every step to reduce error accumulation.
+
+        Args:
+            total_iter_length (int): Expected total iteration number
+            final_ratio (float): Expected LR ratio at n_iter = total_iter_length
+        """
+        self.total_length = total_iter_length
+        self.effective_length = total_iter_length - warmup_steps
+        self.final_ratio = final_ratio
+        self.warmup_steps = warmup_steps
+
+    def __call__(self, n_iter) -> float:
+        if n_iter < self.warmup_steps:
+            alpha = 1.0 * n_iter / self.warmup_steps
+        elif n_iter >= self.total_length:
+            alpha = self.final_ratio
+        else:
+            actual_iter = n_iter - self.warmup_steps
+            alpha = np.exp(
+                actual_iter / self.effective_length * np.log(self.final_ratio)
+            )
+        return alpha
\ No newline at end of file
diff --git a/training/util/noise.py b/training/util/noise.py
new file mode 100644
index 0000000..075fc27
--- /dev/null
+++ b/training/util/noise.py
@@ -0,0 +1,18 @@
+# @GonzaloMartinGarcia
+
+import torch
+import random
+
+# Multiresolution noise from
+# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2?s=31.
+def pyramid_noise_like(x, discount=0.9):
+    b, c, w, h = x.shape
+    u = torch.nn.Upsample(size=(w, h), mode='bilinear')
+    noise = torch.randn_like(x)
+    for i in range(10):
+        r = random.random()*2 + 2  # random downscale factor in [2, 4)
+        w, h = max(1, int(w/(r**i))), max(1, int(h/(r**i)))
+        noise += u(torch.randn(b, c, w, h).to(x)) * discount**i
+        if w==1 or h==1:
+            break
+    return noise / noise.std()
\ No newline at end of file
diff --git a/training/util/unet_prep.py b/training/util/unet_prep.py
new file mode 100644
index 0000000..d2d9546
--- /dev/null
+++ b/training/util/unet_prep.py
@@ -0,0 +1,21 @@
+# @GonzaloMartinGarcia
+
+from torch.nn import Conv2d, Parameter
+
+# Function is based on an early commit of the Marigold GitHub repository.
+def replace_unet_conv_in(unet, repeat=2):
+    _weight = unet.conv_in.weight.clone()
+    _bias = unet.conv_in.bias.clone()
+    _weight = _weight.repeat((1, repeat, 1, 1))
+    # scale the activation magnitude
+    _weight /= repeat
+    _bias /= repeat
+    # new conv_in channel
+    _n_convin_out_channel = unet.conv_in.out_channels
+    _new_conv_in = Conv2d(4*repeat, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+    _new_conv_in.weight = Parameter(_weight)
+    _new_conv_in.bias = Parameter(_bias)
+    unet.conv_in = _new_conv_in
+    # replace config
+    unet.config['in_channels'] = 4*repeat
+    return
\ No newline at end of file
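Note: the snippet below is a minimal usage sketch of the new training utilities (unet_prep, lr_scheduler, noise), not code taken from train_depth_normal.py. The checkpoint path, optimizer, learning rate, and the final_ratio/warmup values are placeholders, and the plain diffusers UNet2DConditionModel stands in for the repository's modified UNet class; only the scheduler length (20000) and the noise options mirror the --lr_total_iter_length and --noise_type arguments defined above.

import torch
from torch.optim.lr_scheduler import LambdaLR
from diffusers import UNet2DConditionModel

from training.util.lr_scheduler import IterExponential
from training.util.unet_prep import replace_unet_conv_in
from training.util.noise import pyramid_noise_like

# Load a UNet whose conv_in still takes 4 latent channels (path is a placeholder).
unet = UNet2DConditionModel.from_pretrained("path/to/pretrained-model", subfolder="unet")

# Duplicate conv_in so the UNet accepts 8 latent channels, i.e. the image latent
# concatenated with the depth/normal latent.
replace_unet_conv_in(unet, repeat=2)

# Iteration-wise exponential LR decay wrapped in a standard LambdaLR;
# IterExponential returns the multiplicative factor expected by lr_lambda.
optimizer = torch.optim.AdamW(unet.parameters(), lr=3e-5)  # lr is a placeholder
lr_func = IterExponential(total_iter_length=20000, final_ratio=0.01, warmup_steps=100)
lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_func)

# Noise options corresponding to --noise_type for the single-step fine-tuning objective.
latents = torch.randn(2, 4, 64, 64)
noise = torch.zeros_like(latents)        # "zeros"
# noise = torch.randn_like(latents)      # "gaussian"
# noise = pyramid_noise_like(latents)    # "pyramid"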