diff --git a/.gitignore b/.gitignore index 70ead23a..b8fd8616 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ # Byte-compiled / optimized / DLL files output.jpg +imagenet-vqgan-training +wandb __pycache__/ *.py[cod] *$py.class diff --git a/configs/imagenet_vqgan_training.yaml b/configs/imagenet_vqgan_training.yaml new file mode 100644 index 00000000..0663536c --- /dev/null +++ b/configs/imagenet_vqgan_training.yaml @@ -0,0 +1,95 @@ +wandb: + entity: null + +experiment: + project: "muse" + name: "imagenet-vqgan-training" + output_dir: "imagenet-vqgan-training" + max_train_examples: 1281167 # total number of imagenet examples + max_eval_examples: 12800 + save_every: 1000 + eval_every: 1000 + generate_every: 1000 + log_every: 30 + log_grad_norm_every: 500 + resume_from_checkpoint: False + resume_lr_scheduler: True + +model: + vq_model: + type: "taming_vqgan" + pretrained: "openMUSE/vqgan-f16-8192-laion" + gradient_checkpointing: True + enable_xformers_memory_efficient_attention: True + + +dataset: + type: "classification" + params: + train_shards_path_or_url: "pipe:aws s3 cp s3://s-laion/muse-imagenet/imagenet-train-{000000..000320}.tar -" + eval_shards_path_or_url: "pipe:aws s3 cp s3://s-laion/muse-imagenet/imagenet-val-{000000..000012}.tar -" + imagenet_class_mapping_path: "/fsx/Isamu/data/imagenet-class-mapping.json" + dataset.params.validation_prompts_file: null + batch_size: ${training.batch_size} + shuffle_buffer_size: 1000 + num_workers: 4 + resolution: 256 + pin_memory: True + persistent_workers: True + preprocessing: + max_seq_length: 16 + resolution: 256 + center_crop: True + random_flip: False +discriminator: + dim: 64 + channels: 3 + groups: 32 + init_kernel_size: 5 + kernel_size: 3 + act: "silu" + discr_layers: 4 +optimizer: + name: adamw + params: # default adamw params + learning_rate: 1e-4 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + discr_learning_rate: 1e-4 + + +lr_scheduler: + scheduler: "constant_with_warmup" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 1000 + + +training: + gradient_accumulation_steps: 2 + batch_size: 16 + mixed_precision: "bf16" + enable_tf32: True + use_ema: False + seed: 9345104 + max_train_steps: 200000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: null + guidance_scale: 2.0 + generation_timesteps: 8 + # related to vae code sampling + use_soft_code_target: False + use_stochastic_code: False + soft_code_temp: 1.0 + timm_discriminator_backend: "vgg19" + timm_disc_layers: "features|pre_logits|head" + timm_discr_offset: 0 + vae_loss: "l2" + num_validation_log: 4 + discriminator_warmup: 10000 diff --git a/configs/imagenet_vqgan_training_jewels.yaml b/configs/imagenet_vqgan_training_jewels.yaml new file mode 100644 index 00000000..59868769 --- /dev/null +++ b/configs/imagenet_vqgan_training_jewels.yaml @@ -0,0 +1,95 @@ +wandb: + entity: null + +experiment: + project: "muse" + name: "imagenet-vqgan-training" + output_dir: "imagenet-vqgan-training" + max_train_examples: 1281167 # total number of imagenet examples + max_eval_examples: 12800 + save_every: 1000 + eval_every: 1000 + generate_every: 1000 + log_every: 30 + log_grad_norm_every: 500 + resume_from_checkpoint: False + resume_lr_scheduler: True + +model: + vq_model: + type: "taming_vqgan" + pretrained: "openMUSE/vqgan-f16-8192-laion" + gradient_checkpointing: True + enable_xformers_memory_efficient_attention: True + + +dataset: + type: "classification" + params: + train_shards_path_or_url: "/p/scratch/ccstdl/muse/muse-imagenet/imagenet-train-{000000..000320}.tar" + eval_shards_path_or_url: "/p/scratch/ccstdl/muse/muse-imagenet/imagenet-val-{000000..000012}.tar" + imagenet_class_mapping_path: "/p/scratch/ccstdl/muse/imagenet-class-mapping.json" + dataset.params.validation_prompts_file: null + batch_size: ${training.batch_size} + shuffle_buffer_size: 1000 + num_workers: 4 + resolution: 256 + pin_memory: True + persistent_workers: True + preprocessing: + max_seq_length: 16 + resolution: 256 + center_crop: True + random_flip: False +discriminator: + dim: 64 + channels: 3 + groups: 32 + init_kernel_size: 5 + kernel_size: 3 + act: "silu" + discr_layers: 4 +optimizer: + name: adamw + params: # default adamw params + learning_rate: 1e-4 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + discr_learning_rate: 1e-4 + + +lr_scheduler: + scheduler: "constant_with_warmup" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 1000 + + +training: + gradient_accumulation_steps: 2 + batch_size: 16 + mixed_precision: "bf16" + enable_tf32: True + use_ema: False + seed: 9345104 + max_train_steps: 200000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: null + guidance_scale: 2.0 + generation_timesteps: 8 + # related to vae code sampling + use_soft_code_target: False + use_stochastic_code: False + soft_code_temp: 1.0 + timm_discriminator_backend: "vgg19" + timm_disc_layers: "features|pre_logits|head" + timm_discr_offset: 0 + vae_loss: "l2" + num_validation_log: 4 + discriminator_warmup: 10000 diff --git a/configs/imagenet_vqgan_training_jewels_f16_vqgan.yaml b/configs/imagenet_vqgan_training_jewels_f16_vqgan.yaml new file mode 100644 index 00000000..b481acd1 --- /dev/null +++ b/configs/imagenet_vqgan_training_jewels_f16_vqgan.yaml @@ -0,0 +1,95 @@ +wandb: + entity: null + +experiment: + project: "muse" + name: "imagenet-vqgan-training" + output_dir: "imagenet-vqgan-training" + max_train_examples: 1281167 # total number of imagenet examples + max_eval_examples: 12800 + save_every: 1000 + eval_every: 1000 + generate_every: 1000 + log_every: 30 + log_grad_norm_every: 500 + resume_from_checkpoint: False + resume_lr_scheduler: True + +model: + vq_model: + type: "taming_vqgan" + pretrained: "vqgan-f16-8192-laion-movq" + gradient_checkpointing: True + enable_xformers_memory_efficient_attention: True + + +dataset: + type: "classification" + params: + train_shards_path_or_url: "/p/scratch/ccstdl/muse/muse-imagenet/imagenet-train-{000000..000320}.tar" + eval_shards_path_or_url: "/p/scratch/ccstdl/muse/muse-imagenet/imagenet-val-{000000..000012}.tar" + imagenet_class_mapping_path: "/p/scratch/ccstdl/muse/imagenet-class-mapping.json" + dataset.params.validation_prompts_file: null + batch_size: ${training.batch_size} + shuffle_buffer_size: 1000 + num_workers: 4 + resolution: 256 + pin_memory: True + persistent_workers: True + preprocessing: + max_seq_length: 16 + resolution: 256 + center_crop: True + random_flip: False +discriminator: + dim: 64 + channels: 3 + groups: 32 + init_kernel_size: 5 + kernel_size: 3 + act: "silu" + discr_layers: 4 +optimizer: + name: adamw + params: # default adamw params + learning_rate: 1e-4 + scale_lr: False # scale learning rate by total batch size + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.01 + epsilon: 1e-8 + discr_learning_rate: 1e-4 + + +lr_scheduler: + scheduler: "constant_with_warmup" + params: + learning_rate: ${optimizer.params.learning_rate} + warmup_steps: 1000 + + +training: + gradient_accumulation_steps: 2 + batch_size: 16 + mixed_precision: "bf16" + enable_tf32: True + use_ema: False + seed: 9345104 + max_train_steps: 200000 + overfit_one_batch: False + cond_dropout_prob: 0.1 + min_masking_rate: 0.0 + label_smoothing: 0.0 + max_grad_norm: null + guidance_scale: 2.0 + generation_timesteps: 8 + # related to vae code sampling + use_soft_code_target: False + use_stochastic_code: False + soft_code_temp: 1.0 + timm_discriminator_backend: "vgg19" + timm_disc_layers: "features|pre_logits|head" + timm_discr_offset: 0 + vae_loss: "l2" + num_validation_log: 4 + discriminator_warmup: 10000 diff --git a/slurm_scripts/imagenet_vqgan.slurm b/slurm_scripts/imagenet_vqgan.slurm new file mode 100644 index 00000000..f636af9d --- /dev/null +++ b/slurm_scripts/imagenet_vqgan.slurm @@ -0,0 +1,83 @@ +#!/bin/bash +#SBATCH --job-name=vqgan_testing +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! +#SBATCH --cpus-per-task=96 +#SBATCH --gres=gpu:8 +#SBATCH --exclusive +#SBATCH --partition=g40 +#SBATCH --output=/fsx/Isamu/logs/maskgit-imagenet/%x-%j.out + +set -x -e + +echo "START TIME: $(date)" + +MUSE_REPO=/fsx/Isamu/open-muse +OUTPUT_DIR=/fsx/Isamu +LOG_PATH=$OUTPUT_DIR/main_log.txt + +mkdir -p $OUTPUT_DIR +touch $LOG_PATH +pushd $MUSE_REPO + +CMD=" \ + training/train_vqgan.py config=configs/imagenet_vqgan_training.yaml \ + wandb.entity=isamu \ + experiment.name=$(basename $OUTPUT_DIR) \ + experiment.output_dir=$OUTPUT_DIR \ + training.seed=9345104 \ + training.batch_size=160 \ + " + +GPUS_PER_NODE=8 +NNODES=$SLURM_NNODES + +# so processes know who to talk to +MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +MASTER_PORT=6000 + +export LAUNCHER="python -u -m torch.distributed.run \ + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --tee 3 \ + " + +echo $CMD + +# hide duplicated errors using this hack - will be properly fixed in pt-1.12 +# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json + +# force crashing on nccl issues like hanging broadcast +export NCCL_ASYNC_ERROR_HANDLING=1 +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=COLL +# export NCCL_SOCKET_NTHREADS=1 +# export NCCL_NSOCKS_PERTHREAD=1 +# export CUDA_LAUNCH_BLOCKING=1 + +# AWS specific +export NCCL_PROTO=simple +export RDMAV_FORK_SAFE=1 +export FI_EFA_FORK_SAFE=1 +export FI_EFA_USE_DEVICE_RDMA=1 +export FI_PROVIDER=efa +export FI_LOG_LEVEL=1 +export NCCL_IB_DISABLE=1 +export NCCL_SOCKET_IFNAME=ens + + +# srun error handling: +# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks +# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code +SRUN_ARGS=" \ + --wait=60 \ + --kill-on-bad-exit=1 \ + " + +# py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD +clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID --role \$SLURMD_NODENAME: $CMD" 2>&1 | tee $LOG_PATH + +echo "END TIME: $(date)" diff --git a/training/deprecate_utils.py b/training/deprecate_utils.py new file mode 100644 index 00000000..6bdda664 --- /dev/null +++ b/training/deprecate_utils.py @@ -0,0 +1,49 @@ +import inspect +import warnings +from typing import Any, Dict, Optional, Union + +from packaging import version + + +def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True): + from .. import __version__ + + deprecated_kwargs = take_from + values = () + if not isinstance(args[0], tuple): + args = (args,) + + for attribute, version_name, message in args: + if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): + raise ValueError( + f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'" + f" version {__version__} is >= {version_name}" + ) + + warning = None + if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: + values += (deprecated_kwargs.pop(attribute),) + warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." + elif hasattr(deprecated_kwargs, attribute): + values += (getattr(deprecated_kwargs, attribute),) + warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}." + elif deprecated_kwargs is None: + warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." + + if warning is not None: + warning = warning + " " if standard_warn else "" + warnings.warn(warning + message, FutureWarning, stacklevel=2) + + if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: + call_frame = inspect.getouterframes(inspect.currentframe())[1] + filename = call_frame.filename + line_number = call_frame.lineno + function = call_frame.function + key, value = next(iter(deprecated_kwargs.items())) + raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`") + + if len(values) == 0: + return + elif len(values) == 1: + return values[0] + return values diff --git a/training/discriminator.py b/training/discriminator.py new file mode 100644 index 00000000..b972a555 --- /dev/null +++ b/training/discriminator.py @@ -0,0 +1,75 @@ +""" +Ported from lucidrian's muse maksgit repository with some qol changes +""" +from torch import nn + +def leaky_relu(p=0.1): + return nn.LeakyReLU(0.1) + +def get_activation(name): + if name == "leaky_relu": + return leaky_relu + elif name == "silu": + return nn.SiLU + else: + raise NotImplementedError(f"Activation {name} is not implemented") + +class Discriminator(nn.Module): + def __init__( + self, + config + ): + super().__init__() + dim = config.discriminator.dim + discr_layers = config.discriminator.discr_layers + layer_mults = list(map(lambda t: 2**t, range(discr_layers))) + layer_dims = [dim * mult for mult in layer_mults] + dims = (dim, *layer_dims) + channels=config.discriminator.channels + groups=config.discriminator.groups + init_kernel_size=config.discriminator.init_kernel_size + kernel_size=config.discriminator.kernel_size + act=config.discriminator.act + activation = get_activation(act) + dim_pairs = zip(dims[:-1], dims[1:]) + + self.layers = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d( + channels, + dims[0], + init_kernel_size, + padding=init_kernel_size // 2, + ), + activation(), + ) + ] + ) + + for dim_in, dim_out in dim_pairs: + self.layers.append( + nn.Sequential( + nn.Conv2d( + dim_in, + dim_out, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + ), + nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2)), + nn.GroupNorm(groups, dim_out), + activation(), + ) + ) + + dim = dims[-1] + self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training + nn.Conv2d(dim, dim, 1), activation(), nn.Conv2d(dim, 1, 4) + ) + + def forward(self, x): + for net in self.layers: + x = net(x) + + return self.to_logits(x) \ No newline at end of file diff --git a/training/ema.py b/training/ema.py new file mode 100644 index 00000000..00ff6af1 --- /dev/null +++ b/training/ema.py @@ -0,0 +1,323 @@ +""" +Taken from diffusers training_utils.py https://github.com/huggingface/diffusers/blob/main/src/diffusers/training_utils.py +""" +import copy +import os +import random +from typing import Any, Dict, Iterable, Optional, Union + +import numpy as np +import torch +from training.deprecate_utils import deprecate + +def enable_full_determinism(seed: int): + """ + Helper function for reproducible behavior during distributed training. See + - https://pytorch.org/docs/stable/notes/randomness.html for pytorch + """ + # set seed first + set_seed(seed) + + # Enable PyTorch deterministic mode. This potentially requires either the environment + # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, + # depending on the CUDA version, so we set them both here + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + torch.use_deterministic_algorithms(True) + + # Enable CUDNN deterministic mode + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def set_seed(seed: int): + """ + Args: + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + seed (`int`): The seed to set. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # ^^ safe to call this function even if cuda is not available + + +# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 +class EMAModel: + """ + Exponential Moving Average of models weights + """ + + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + decay: float = 0.9999, + min_decay: float = 0.0, + update_after_step: int = 0, + use_ema_warmup: bool = False, + inv_gamma: Union[float, int] = 1.0, + power: Union[float, int] = 2 / 3, + model_cls: Optional[Any] = None, + model_config: Dict[str, Any] = None, + **kwargs, + ): + """ + Args: + parameters (Iterable[torch.nn.Parameter]): The parameters to track. + decay (float): The decay factor for the exponential moving average. + min_decay (float): The minimum decay factor for the exponential moving average. + update_after_step (int): The number of steps to wait before starting to update the EMA weights. + use_ema_warmup (bool): Whether to use EMA warmup. + inv_gamma (float): + Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. + power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. + device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA + weights will be stored on CPU. + + @crowsonkb's notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + """ + + if isinstance(parameters, torch.nn.Module): + deprecation_message = ( + "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " + "Please pass the parameters of the module instead." + ) + deprecate( + "passing a `torch.nn.Module` to `ExponentialMovingAverage`", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + parameters = parameters.parameters() + + # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility + use_ema_warmup = True + + if kwargs.get("max_value", None) is not None: + deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead." + deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False) + decay = kwargs["max_value"] + + if kwargs.get("min_value", None) is not None: + deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead." + deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False) + min_decay = kwargs["min_value"] + + parameters = list(parameters) + self.shadow_params = [p.clone().detach() for p in parameters] + + if kwargs.get("device", None) is not None: + deprecation_message = "The `device` argument is deprecated. Please use `to` instead." + deprecate("device", "1.0.0", deprecation_message, standard_warn=False) + self.to(device=kwargs["device"]) + + self.temp_stored_params = None + + self.decay = decay + self.min_decay = min_decay + self.update_after_step = update_after_step + self.use_ema_warmup = use_ema_warmup + self.inv_gamma = inv_gamma + self.power = power + self.optimization_step = 0 + self.cur_decay_value = None # set in `step()` + + self.model_cls = model_cls + self.model_config = model_config + + @classmethod + def from_pretrained(cls, path, model_cls) -> "EMAModel": + _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) + model = model_cls.from_pretrained(path) + + ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config) + + ema_model.load_state_dict(ema_kwargs) + return ema_model + + def save_pretrained(self, path): + if self.model_cls is None: + raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.") + + if self.model_config is None: + raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.") + + model = self.model_cls.from_config(self.model_config) + state_dict = self.state_dict() + state_dict.pop("shadow_params", None) + + model.register_to_config(**state_dict) + self.copy_to(model.parameters()) + model.save_pretrained(path) + + def get_decay(self, optimization_step: int) -> float: + """ + Compute the decay factor for the exponential moving average. + """ + step = max(0, optimization_step - self.update_after_step - 1) + + if step <= 0: + return 0.0 + + if self.use_ema_warmup: + cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power + else: + cur_decay_value = (1 + step) / (10 + step) + + cur_decay_value = min(cur_decay_value, self.decay) + # make sure decay is not smaller than min_decay + cur_decay_value = max(cur_decay_value, self.min_decay) + return cur_decay_value + + @torch.no_grad() + def step(self, parameters: Iterable[torch.nn.Parameter]): + if isinstance(parameters, torch.nn.Module): + deprecation_message = ( + "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. " + "Please pass the parameters of the module instead." + ) + deprecate( + "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + parameters = parameters.parameters() + + parameters = list(parameters) + + self.optimization_step += 1 + + # Compute the decay factor for the exponential moving average. + decay = self.get_decay(self.optimization_step) + self.cur_decay_value = decay + one_minus_decay = 1 - decay + + for s_param, param in zip(self.shadow_params, parameters): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) + else: + s_param.copy_(param) + + def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: + """ + Copy current averaged parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. + """ + parameters = list(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.to(param.device).data) + + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) + for p in self.shadow_params + ] + + def state_dict(self) -> dict: + r""" + Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during + checkpointing to save the ema state dict. + """ + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "min_decay": self.min_decay, + "optimization_step": self.optimization_step, + "update_after_step": self.update_after_step, + "use_ema_warmup": self.use_ema_warmup, + "inv_gamma": self.inv_gamma, + "power": self.power, + "shadow_params": self.shadow_params, + } + + def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: + r""" + Args: + Save the current parameters for restoring later. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] + + def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: + r""" + Args: + Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: + affecting the original optimization process. Store the parameters before the `copy_to()` method. After + validation (or model saving), use this to restore the former parameters. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. + """ + if self.temp_stored_params is None: + raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") + for c_param, param in zip(self.temp_stored_params, parameters): + param.data.copy_(c_param.data) + + # Better memory-wise. + self.temp_stored_params = None + + def load_state_dict(self, state_dict: dict) -> None: + r""" + Args: + Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the + ema state dict. + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + + self.decay = state_dict.get("decay", self.decay) + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.min_decay = state_dict.get("min_decay", self.min_decay) + if not isinstance(self.min_decay, float): + raise ValueError("Invalid min_decay") + + self.optimization_step = state_dict.get("optimization_step", self.optimization_step) + if not isinstance(self.optimization_step, int): + raise ValueError("Invalid optimization_step") + + self.update_after_step = state_dict.get("update_after_step", self.update_after_step) + if not isinstance(self.update_after_step, int): + raise ValueError("Invalid update_after_step") + + self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) + if not isinstance(self.use_ema_warmup, bool): + raise ValueError("Invalid use_ema_warmup") + + self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) + if not isinstance(self.inv_gamma, (float, int)): + raise ValueError("Invalid inv_gamma") + + self.power = state_dict.get("power", self.power) + if not isinstance(self.power, (float, int)): + raise ValueError("Invalid power") + + shadow_params = state_dict.get("shadow_params", None) + if shadow_params is not None: + self.shadow_params = shadow_params + if not isinstance(self.shadow_params, list): + raise ValueError("shadow_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): + raise ValueError("shadow_params must all be Tensors") \ No newline at end of file diff --git a/training/train_vqgan.py b/training/train_vqgan.py index c4e82216..d8bfe074 100644 --- a/training/train_vqgan.py +++ b/training/train_vqgan.py @@ -1 +1,696 @@ -"""Training script for VQGAN.""" +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import logging +import math +import os +import time +from pathlib import Path +from typing import Any, List, Tuple +from ema import EMAModel +import numpy as np +import torch +import torch.nn.functional as F +import wandb +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, set_seed +from data import ClassificationDataset +from omegaconf import DictConfig, ListConfig, OmegaConf +from optimizer import Lion +from PIL import Image +from torch.optim import AdamW # why is shampoo not available in PT :( + +import muse +from muse import MOVQ, MaskGitTransformer, MaskGitVQGAN, VQGANModel +from muse.lr_schedulers import get_scheduler +from muse.sampling import cosine_schedule +from training.discriminator import Discriminator +try: + import apex + + is_apex_available = True +except ImportError: + is_apex_available = False +import timm +from einops import repeat, rearrange +from tqdm import tqdm + + +logger = get_logger(__name__, log_level="INFO") + + +def get_config(): + cli_conf = OmegaConf.from_cli() + + yaml_conf = OmegaConf.load(cli_conf.config) + conf = OmegaConf.merge(yaml_conf, cli_conf) + + return conf + + +def flatten_omega_conf(cfg: Any, resolve: bool = False) -> List[Tuple[str, Any]]: + ret = [] + + def handle_dict(key: Any, value: Any, resolve: bool) -> List[Tuple[str, Any]]: + return [(f"{key}.{k1}", v1) for k1, v1 in flatten_omega_conf(value, resolve=resolve)] + + def handle_list(key: Any, value: Any, resolve: bool) -> List[Tuple[str, Any]]: + return [(f"{key}.{idx}", v1) for idx, v1 in flatten_omega_conf(value, resolve=resolve)] + + if isinstance(cfg, DictConfig): + for k, v in cfg.items_ex(resolve=resolve): + if isinstance(v, DictConfig): + ret.extend(handle_dict(k, v, resolve=resolve)) + elif isinstance(v, ListConfig): + ret.extend(handle_list(k, v, resolve=resolve)) + else: + ret.append((str(k), v)) + elif isinstance(cfg, ListConfig): + for idx, v in enumerate(cfg._iter_ex(resolve=resolve)): + if isinstance(v, DictConfig): + ret.extend(handle_dict(idx, v, resolve=resolve)) + elif isinstance(v, ListConfig): + ret.extend(handle_list(idx, v, resolve=resolve)) + else: + ret.append((str(idx), v)) + else: + assert False + + return ret + + +def get_vq_model_class(model_type): + if model_type == "movq": + return MOVQ + elif model_type == "maskgit_vqgan": + return MaskGitVQGAN + elif model_type == "taming_vqgan": + return VQGANModel + else: + raise ValueError(f"model_type {model_type} not supported for VQGAN") + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + +def _map_layer_to_idx(backbone, layers, offset=0): + """Maps set of layer names to indices of model. Ported from anomalib + + Returns: + Feature map extracted from the CNN + """ + idx = [] + features = timm.create_model( + backbone, + pretrained=False, + features_only=False, + exportable=True, + ) + for i in layers: + try: + idx.append(list(dict(features.named_children()).keys()).index(i)-offset) + except ValueError: + raise ValueError( + f"Layer {i} not found in model {backbone}. Select layer from {list(dict(features.named_children()).keys())}. The network architecture is {features}" + ) + return idx + +# From https://arxiv.org/abs/2111.01007v1 Projected Gan where instead of giving the discriminator/generator the input image, we give hierarchical features +# from a timm model +class MultiLayerTimmModel(torch.nn.Module): + def __init__(self, model, input_shape=(3, 224, 224)): + super().__init__() + self.model = model + self.input_shape = input_shape + self.image_sizes = [] + self.max_feats = self.get_layer_widths(input_shape) + self.max_feature_sizes = (self.max_feats, self.max_feats) + def get_layer_widths(self, shape=(3, 224, 224)): + output = [] + batch_size = 1 + input = torch.autograd.Variable(torch.rand(batch_size, *shape)) + output_feats = self.model(input) + for output_feat in output_feats: + output.append(output_feat.shape[-1]) + max_feats = max(output) + return max_feats + def forward(self, images): + features = self.model(images) + output_features = [] + for feature in features: + if feature.shape[-1] == self.max_feats: + output_features.append(feature) + else: + output_features.append(F.interpolate(feature, size=self.max_feature_sizes)) + return torch.cat(output_features, dim=1) + +def get_perceptual_loss(pixel_values, fmap, timm_discriminator): + img_timm_discriminator_input = pixel_values + fmap_timm_discriminator_input = fmap + + if pixel_values.shape[1] == 1: + # handle grayscale for timm_discriminator + img_timm_discriminator_input, fmap_timm_discriminator_input = map( + lambda t: repeat(t, "b 1 ... -> b c ...", c=3), + (img_timm_discriminator_input, fmap_timm_discriminator_input), + ) + + img_timm_discriminator_feats = timm_discriminator( + img_timm_discriminator_input + ) + recon_timm_discriminator_feats = timm_discriminator( + fmap_timm_discriminator_input + ) + perceptual_loss = F.mse_loss( + img_timm_discriminator_feats[0], recon_timm_discriminator_feats[0] + ) + for i in range(1, len(img_timm_discriminator_feats)): + perceptual_loss += F.mse_loss( + img_timm_discriminator_feats[i], recon_timm_discriminator_feats[i] + ) + perceptual_loss /= len(img_timm_discriminator_feats) + return perceptual_loss + +def grad_layer_wrt_loss(loss, layer): + return torch.autograd.grad( + outputs=loss, + inputs=layer, + grad_outputs=torch.ones_like(loss), + retain_graph=True, + )[0].detach() + +def gradient_penalty(images, output, weight=10): + gradients = torch.autograd.grad( + outputs=output, + inputs=images, + grad_outputs=torch.ones(output.size(), device=images.device), + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + + gradients = rearrange(gradients, "b ... -> b (...)") + return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean() + +def main(): + ######################### + # SETUP Accelerator # + ######################### + config = get_config() + + # Enable TF32 on Ampere GPUs + if config.training.enable_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + config.experiment.logging_dir = str(Path(config.experiment.output_dir) / "logs") + accelerator = Accelerator( + gradient_accumulation_steps=config.training.gradient_accumulation_steps, + mixed_precision=config.training.mixed_precision, + log_with="wandb", + logging_dir=config.experiment.logging_dir, + split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes. + ) + + if accelerator.distributed_type == DistributedType.DEEPSPEED: + accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = ( + config.training.batch_size + ) + + ##################################### + # SETUP LOGGING, SEED and CONFIG # + ##################################### + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + muse.logging.set_verbosity_info() + else: + muse.logging.set_verbosity_error() + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + resume_wandb_run = config.experiment.resume_from_checkpoint + run_id = config.wandb.get("run_id", None) + if run_id is None: + resume_wandb_run = False + run_id = wandb.util.generate_id() + config.wandb.run_id = run_id + + wandb_init_kwargs = dict( + name=config.experiment.name, + id=run_id, + resume=resume_wandb_run, + entity=config.wandb.get("entity", None), + config_exclude_keys=[], + ) + wandb_config = {k: v for k, v in flatten_omega_conf(config, resolve=True)} + wandb_config.pop("experiment.resume_from_checkpoint") + accelerator.init_trackers( + config.experiment.project, + config=wandb_config, + init_kwargs={"wandb": wandb_init_kwargs}, + ) + + if accelerator.is_main_process: + os.makedirs(config.experiment.output_dir, exist_ok=True) + config_path = Path(config.experiment.output_dir) / "config.yaml" + logging.info(f"Saving config to {config_path}") + OmegaConf.save(config, config_path) + + # If passed along, set the training seed now. + if config.training.seed is not None: + set_seed(config.training.seed) + + ######################### + # MODELS and OPTIMIZER # + ######################### + logger.info("Loading models and optimizer") + + vq_class = get_vq_model_class(config.model.vq_model.type) + model = vq_class.from_pretrained(config.model.vq_model.pretrained) + if config.training.use_ema: + ema_model = EMAModel(model.parameters(), model_cls=vq_class, model_config=model.config) + + discriminator = Discriminator(config) + # TODO: Add timm_discriminator_backend to config.training. Set default to vgg16 + idx = _map_layer_to_idx(config.training.timm_discriminator_backend,\ + config.training.timm_disc_layers.split("|"), config.training.timm_discr_offset) + + timm_discriminator = timm.create_model( + config.training.timm_discriminator_backend, + pretrained=True, + features_only=True, + exportable=True, + out_indices=idx, + ) + timm_discriminator = timm_discriminator.to(accelerator.device) + timm_discriminator.requires_grad = False + timm_discriminator.eval() + # Enable flash attention if asked + if config.model.enable_xformers_memory_efficient_attention: + model.enable_xformers_memory_efficient_attention() + + optimizer_config = config.optimizer.params + learning_rate = optimizer_config.learning_rate + if optimizer_config.scale_lr: + learning_rate = ( + learning_rate + * config.training.batch_size + * accelerator.num_processes + * config.training.gradient_accumulation_steps + ) + + optimizer_type = config.optimizer.name + if optimizer_type == "adamw": + optimizer_cls = AdamW + elif optimizer_type == "fused_adamw": + if is_apex_available: + optimizer_cls = apex.optimizers.FusedAdam + else: + raise ImportError("Please install apex to use fused_adam") + elif optimizer_type == "lion": + optimizer_cls = Lion + else: + raise ValueError(f"Optimizer {optimizer_type} not supported") + + optimizer = optimizer_cls( + list(model.parameters()), + lr=optimizer_config.learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + discr_optimizer = optimizer_cls( + list(discriminator.parameters()), + lr=optimizer_config.discr_learning_rate, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + eps=optimizer_config.epsilon, + ) + + ################################## + # DATLOADER and LR-SCHEDULER # + ################################# + logger.info("Creating dataloaders and lr_scheduler") + + total_batch_size_without_accum = config.training.batch_size * accelerator.num_processes + total_batch_size = ( + config.training.batch_size * accelerator.num_processes * config.training.gradient_accumulation_steps + ) + + # DataLoaders creation: + # We use webdataset for data loading. The dataloaders are created with sampling with replacement. + # We don't do dataset resuming here, instead we resample the shards and buffer each time. The sampling is stochastic. + # This means that the dataloading is not deterministic, but it's fast and efficient. + preproc_config = config.dataset.preprocessing + dataset_config = config.dataset.params + dataset = ClassificationDataset( + train_shards_path_or_url=dataset_config.train_shards_path_or_url, + eval_shards_path_or_url=dataset_config.eval_shards_path_or_url, + num_train_examples=config.experiment.max_train_examples, + per_gpu_batch_size=config.training.batch_size, + global_batch_size=total_batch_size_without_accum, + num_workers=dataset_config.num_workers, + resolution=preproc_config.resolution, + center_crop=preproc_config.center_crop, + random_flip=preproc_config.random_flip, + shuffle_buffer_size=dataset_config.shuffle_buffer_size, + pin_memory=dataset_config.pin_memory, + persistent_workers=dataset_config.persistent_workers, + ) + train_dataloader, eval_dataloader = dataset.train_dataloader, dataset.eval_dataloader + + lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + ) + discr_lr_scheduler = get_scheduler( + config.lr_scheduler.scheduler, + optimizer=discr_optimizer, + num_training_steps=config.training.max_train_steps, + num_warmup_steps=config.lr_scheduler.params.warmup_steps, + ) + + + # Prepare everything with accelerator + logger.info("Preparing model, optimizer and dataloaders") + # The dataloader are already aware of distributed training, so we don't need to prepare them. + model, discriminator, optimizer, discr_optimizer, lr_scheduler, discr_lr_scheduler = accelerator.prepare(model, discriminator, optimizer, discr_optimizer, lr_scheduler, discr_lr_scheduler) + + if config.training.overfit_one_batch: + train_dataloader = [next(iter(train_dataloader))] + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / config.training.gradient_accumulation_steps) + # Afterwards we recalculate our number of training epochs. + # Note: We are not doing epoch based training here, but just using this for book keeping and being able to + # reuse the same training loop with other datasets/loaders. + num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch) + + # Train! + logger.info("***** Running training *****") + logger.info(f" Num training steps = {config.training.max_train_steps}") + logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") + logger.info(f" Instantaneous batch size per device = { config.training.batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + resume_from_checkpoint = config.experiment.resume_from_checkpoint + if resume_from_checkpoint: + if resume_from_checkpoint != "latest": + path = resume_from_checkpoint + else: + # Get the most recent checkpoint + dirs = os.listdir(config.experiment.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + path = os.path.join(config.experiment.output_dir, path) + + if path is None: + accelerator.print(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.") + resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + + resume_lr_scheduler = config.experiment.get("resume_lr_scheduler", True) + if not resume_lr_scheduler: + logger.info("Not resuming the lr scheduler.") + accelerator._schedulers = [] # very hacky, but we don't want to resume the lr scheduler + accelerator.load_state(path) + accelerator.wait_for_everyone() + if not resume_lr_scheduler: + accelerator._schedulers = [lr_scheduler] + global_step = int(os.path.basename(path).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + # As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to + # reuse the same training loop with other datasets/loaders. + avg_gen_loss, avg_discr_loss = None, None + for epoch in range(first_epoch, num_train_epochs): + model.train() + for i, batch in tqdm(enumerate(train_dataloader)): + pixel_values, _ = batch + pixel_values = pixel_values.to(accelerator.device, non_blocking=True) + data_time_m.update(time.time() - end) + generator_step = ((i // config.training.gradient_accumulation_steps) % 2) == 0 and i > config.training.discriminator_warmup + # TODO: + # Add entropy to maximize codebook usage + # Train Step + # The behavior of accelerator.accumulate is to + # 1. Check if gradients are synced(reached gradient-accumulation_steps) + # 2. If so sync gradients by stopping the not syncing process + if generator_step: + if optimizer_type == "fused_adamw": + optimizer.zero_grad() + else: + optimizer.zero_grad(set_to_none=True) + else: + if optimizer_type == "fused_adamw": + discr_optimizer.zero_grad() + else: + discr_optimizer.zero_grad(set_to_none=True) + # encode images to the latent space and get the commit loss from vq tokenization + # Return commit loss + fmap, _, _, commit_loss = model(pixel_values, return_loss=True) + + if generator_step: + with accelerator.accumulate(model): + # reconstruction loss. Pixel level differences between input vs output + if config.training.vae_loss == "l2": + loss = F.mse_loss(pixel_values, fmap) + else: + loss = F.l1_loss(pixel_values, fmap) + # perceptual loss. The high level feature mean squared error loss + perceptual_loss = get_perceptual_loss(pixel_values, fmap, timm_discriminator) + # generator loss + gen_loss = -discriminator(fmap).mean() + last_dec_layer = accelerator.unwrap_model(model).decoder.conv_out.weight + norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p=2) + norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2) + + adaptive_weight = norm_grad_wrt_perceptual_loss/norm_grad_wrt_gen_loss.clamp(min=1e-8) + adaptive_weight = adaptive_weight.clamp(max=1e4) + loss += commit_loss + loss += perceptual_loss + loss += adaptive_weight*gen_loss + # Gather thexd losses across all processes for logging (if we use distributed training). + avg_gen_loss = accelerator.gather(loss.repeat(config.training.batch_size)).float().mean() + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + # log gradient norm before zeroing it + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(model, accelerator, global_step + 1) + else: + # Return discriminator loss + with accelerator.accumulate(discriminator): + fmap.detach_() + pixel_values.requires_grad_() + real = discriminator(pixel_values) + fake = discriminator(fmap) + loss = (F.relu(1 + fake) + F.relu(1 - real)).mean() + gp = gradient_penalty(pixel_values, real) + loss += gp + avg_discr_loss = accelerator.gather(loss.repeat(config.training.batch_size)).mean() + accelerator.backward(loss) + + if config.training.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(discriminator.parameters(), config.training.max_grad_norm) + + discr_optimizer.step() + discr_lr_scheduler.step() + if ( + accelerator.sync_gradients + and (global_step + 1) % config.experiment.log_grad_norm_every == 0 + and accelerator.is_main_process + ): + log_grad_norm(discriminator, accelerator, global_step + 1) + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients and not generator_step: + if config.training.use_ema: + ema_model.step(model.parameters()) + # wait for both generator and discriminator to settle + batch_time_m.update(time.time() - end) + end = time.time() + # Log metrics + if (global_step + 1) % config.experiment.log_every == 0: + samples_per_second_per_gpu = ( + config.training.gradient_accumulation_steps * config.training.batch_size / batch_time_m.val + ) + logs = { + "step_discr_loss": avg_discr_loss.item(), + "lr": lr_scheduler.get_last_lr()[0], + "samples/sec/gpu": samples_per_second_per_gpu, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + } + if avg_gen_loss is not None: + logs["step_gen_loss"] = avg_gen_loss.item() + accelerator.log(logs, step=global_step + 1) + logger.info( + f"Data (t): {data_time_m.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu " + f"Batch (t): {batch_time_m.val:0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f} " + f"Step: {global_step + 1} " + f"Discriminator Loss: {avg_discr_loss.item():0.4f} " + ) + if avg_gen_loss is not None: + logger.info(f"Generator Loss: {avg_gen_loss.item():0.4f} ") + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # Save model checkpoint + if (global_step + 1) % config.experiment.save_every == 0: + save_checkpoint(model, discriminator, config, accelerator, global_step + 1) + + # Generate images + if (global_step + 1) % config.experiment.generate_every == 0 and accelerator.is_main_process: + generate_images(model, pixel_values[:config.training.num_validation_log], accelerator, global_step + 1) + + global_step += 1 + # TODO: Add generation + + # Stop training if max steps is reached + if global_step >= config.training.max_train_steps: + break + # End for + + accelerator.wait_for_everyone() + + # Evaluate and save checkpoint at the end of training + save_checkpoint(model, discriminator, config, accelerator, global_step) + + # Save the final trained checkpoint + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + if config.training.use_ema: + ema_model.copy_to(model.parameters()) + model.save_pretrained(config.experiment.output_dir) + + accelerator.end_training() + + + +@torch.no_grad() +def generate_images(model, original_images, accelerator, global_step): + logger.info("Generating images...") + original_images = torch.clone(original_images) + # Generate images + model.eval() + dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + dtype = torch.bfloat16 + + with torch.autocast("cuda", dtype=dtype, enabled=accelerator.mixed_precision != "no"): + _, enc_token_ids = accelerator.unwrap_model(model).encode(original_images) + # In the beginning of training, the model is not fully trained and the generated token ids can be out of range + # so we clamp them to the correct range. + enc_token_ids = torch.clamp(enc_token_ids, max=accelerator.unwrap_model(model).config.num_embeddings - 1) + images = accelerator.unwrap_model(model).decode_code(enc_token_ids) + model.train() + + # Convert to PIL images + images = 2.0 * images - 1.0 + original_images = 2.0 * original_images - 1.0 + images = torch.clamp(images, -1.0, 1.0) + original_images = torch.clamp(original_images, -1.0, 1.0) + images = (images + 1.0) / 2.0 + original_images = (original_images + 1.0) / 2.0 + images *= 255.0 + original_images *= 255.0 + images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + original_images = original_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8) + images = np.concatenate([original_images, images], axis=2) + pil_images = [Image.fromarray(image) for image in images] + + # Log images + wandb_images = [wandb.Image(image, caption="Original, Generated") for image in pil_images] + wandb.log({"vae_images": wandb_images}, step=global_step) + + +def save_checkpoint(model, discriminator, config, accelerator, global_step): + save_path = Path(config.experiment.output_dir) / f"checkpoint-{global_step}" + + # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet) + # XXX: could also make this conditional on deepspeed + state_dict = accelerator.get_state_dict(model) + discr_state_dict = accelerator.get_state_dict(discriminator) + + if accelerator.is_main_process: + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + save_path / "unwrapped_model", + save_function=accelerator.save, + state_dict=state_dict, + ) + torch.save(discr_state_dict, save_path / "unwrapped_discriminator") + json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+")) + logger.info(f"Saved state to {save_path}") + + accelerator.save_state(save_path) + + +def log_grad_norm(model, accelerator, global_step): + for name, param in model.named_parameters(): + if param.grad is not None: + grads = param.grad.detach().data + grad_norm = (grads.norm(p=2) / grads.numel()).item() + accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step) + + +if __name__ == "__main__": + main()