diff --git a/configs/finetune_text_encoder.yaml b/configs/finetune_text_encoder.yaml
new file mode 100644
index 00000000..ee5b8060
--- /dev/null
+++ b/configs/finetune_text_encoder.yaml
@@ -0,0 +1,116 @@
+wandb:
+    entity: williamberman
+
+experiment:
+    project: "muse"
+    name: "finetune-text-encoder"
+    output_dir: "finetune-text-encoder"
+    max_train_examples: 8974320 # total successfully downloaded images for laiona6plus
+    max_eval_examples: 8118
+    save_every: 2000
+    eval_every: 1000
+    generate_every: 1000
+    log_every: 50
+    log_grad_norm_every: 500
+    resume_from_checkpoint: False
+
+model:
+    vq_model:
+        type: "vqgan"
+        pretrained: "openMUSE/vqgan-f16-8192-laion"
+
+    text_encoder:
+        type: "clip"
+        pretrained: "openMUSE/clip-vit-large-patch14-penultimate"
+        pad_token_id: 49407
+
+    architecture: "uvit"
+
+    transformer:
+        vocab_size: 8256 # (8192 + 1 for the mask token = 8193 but 8256 is the next multiple of 8)
+        hidden_size: 1024
+        intermediate_size: 4096
+        num_hidden_layers: 22
+        num_attention_heads: 16
+        max_position_embeddings: 256
+        in_channels: 512
+        block_out_channels:
+            - 1024
+        num_res_blocks: 3
+        patch_size: 1
+        encoder_hidden_size: 768
+        add_cross_attention: True
+        project_encoder_hidden_states: False
+        codebook_size: 8192
+        num_vq_tokens: 576
+        initializer_range: 0.02
+        norm_type: "rmsnorm"
+        layer_norm_eps: 1e-6
+        use_normformer: False
+        use_encoder_layernorm: True
+        use_bias: False
+        hidden_dropout: 0.0
+        attention_dropout: 0.0
+        use_codebook_size_for_output: True
+
+    gradient_checkpointing: True
+    enable_xformers_memory_efficient_attention: True
+
+
+dataset:
+    type: "text2image"
+    params:
+        train_shards_path_or_url: "pipe:aws s3 cp s3://hf-datasets-laion-5b-us-west-2/glacier/laion-data/laion-aesthetics-v2-5-plus-data/{00000..60580}.tar -"
+        eval_shards_path_or_url: "pipe:aws s3 cp s3://muse-datasets/coco/2014/val/{00000..00010}.tar -"
+        validation_prompts_file: "validation_prompts/dalle_mini_prompts.txt"
+        batch_size: ${training.batch_size}
+        shuffle_buffer_size: 1000
+        num_workers: 4
+        resolution: 384
+        pin_memory: True
+        persistent_workers: True
+        use_filtered_dataset: True
+    preprocessing:
+        max_seq_length: 77
+        resolution: 384
+        center_crop: False
+        random_flip: False
+
+
+optimizer:
+    name: adamw
+    params: # default adamw params
+        learning_rate: 1e-4
+        scale_lr: False # scale learning rate by total batch size
+        beta1: 0.9
+        beta2: 0.999
+        weight_decay: 0.01
+        epsilon: 1e-8
+
+
+lr_scheduler:
+    scheduler: "constant_with_warmup"
+    params:
+        learning_rate: ${optimizer.params.learning_rate}
+        warmup_steps: 2000
+
+
+training:
+    gradient_accumulation_steps: 1
+    batch_size: 44
+    mixed_precision: "fp16"
+    enable_tf32: True
+    use_ema: False
+    seed: 9345104
+    max_train_steps: 1000000
+    overfit_one_batch: False
+    cond_dropout_prob: 0.1
+    min_masking_rate: 0.0
+    label_smoothing: 0.1
+    max_grad_norm: null
+    guidance_scale: 8
+    generation_timesteps: 16
+    # related to vae code sampling
+    use_soft_code_target: False
+    use_stochastic_code: False
+    soft_code_temp: 1.0
\ No newline at end of file
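The `${training.batch_size}` and `${optimizer.params.learning_rate}` entries are interpolations, and the slurm script below overrides `experiment.name` / `experiment.output_dir` on the command line. A minimal sketch of how such a config resolves, assuming the training script reads it with OmegaConf (the `${...}` syntax suggests this); the dotlist override value below is illustrative only:

```python
# Minimal sketch, assuming OmegaConf resolves the ${...} interpolations in this
# config; the override value passed to from_dotlist is illustrative only.
from omegaconf import OmegaConf

config = OmegaConf.load("configs/finetune_text_encoder.yaml")

# Interpolations resolve against the root of the config on access:
assert config.dataset.params.batch_size == config.training.batch_size                      # both 44
assert config.lr_scheduler.params.learning_rate == config.optimizer.params.learning_rate   # same value

# Command-line overrides such as experiment.output_dir=... (see the slurm
# script below) can be merged in as a dotlist:
overrides = OmegaConf.from_dotlist(["experiment.output_dir=/tmp/finetune_text_encoder"])
config = OmegaConf.merge(config, overrides)
print(config.experiment.output_dir)
```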
diff --git a/slurm_scripts/finetune_text_encoder.slurm b/slurm_scripts/finetune_text_encoder.slurm
new file mode 100644
index 00000000..ae83e51a
--- /dev/null
+++ b/slurm_scripts/finetune_text_encoder.slurm
@@ -0,0 +1,86 @@
+#!/bin/bash
+#SBATCH --job-name=oai-clip-lbl-smooth-laiona6-block-attn-512
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1 # crucial - only 1 task per node; torch.distributed.run spawns the per-GPU processes
+#SBATCH --cpus-per-task=96
+#SBATCH --gres=gpu:8
+#SBATCH --exclusive
+#SBATCH --partition=production-cluster
+#SBATCH --output=/fsx/william/logs/finetune_text_encoder.out
+
+set -x -e
+
+source /admin/home/william/.bashrc
+source /fsx/william/miniconda3/etc/profile.d/conda.sh
+conda activate base
+
+echo "START TIME: $(date)"
+
+MUSE_REPO=/fsx/william/open-muse
+OUTPUT_DIR=/fsx/william/finetune_text_encoder
+LOG_PATH=$OUTPUT_DIR/main_log.txt
+CONFIG_PATH=/fsx/william/open-muse/configs/finetune_text_encoder.yaml
+
+mkdir -p $OUTPUT_DIR
+touch $LOG_PATH
+pushd $MUSE_REPO
+
+CMD=" \
+    $MUSE_REPO/training/train_muse.py config=$CONFIG_PATH \
+    experiment.name=$(basename $OUTPUT_DIR) \
+    experiment.output_dir=$OUTPUT_DIR \
+    "
+
+GPUS_PER_NODE=8
+NNODES=$SLURM_NNODES
+
+# set the visible GPUs
+export CUDA_VISIBLE_DEVICES=${SLURM_STEP_GPUS:-$SLURM_JOB_GPUS}
+
+# so processes know who to talk to
+MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
+MASTER_PORT=6000
+
+export LAUNCHER="python -u -m torch.distributed.run \
+    --nproc_per_node $GPUS_PER_NODE \
+    --nnodes $NNODES \
+    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
+    --rdzv_backend c10d \
+    --max_restarts 0 \
+    --tee 3 \
+    "
+
+echo $CMD
+
+# hide duplicated errors using this hack - will be properly fixed in pt-1.12
+# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json
+
+# force crashing on nccl issues like hanging broadcast
+export NCCL_ASYNC_ERROR_HANDLING=1
+# export NCCL_DEBUG=INFO
+# export NCCL_DEBUG_SUBSYS=COLL
+# export NCCL_SOCKET_NTHREADS=1
+# export NCCL_NSOCKS_PERTHREAD=1
+# export CUDA_LAUNCH_BLOCKING=1
+
+# AWS specific
+export NCCL_PROTO=simple
+export RDMAV_FORK_SAFE=1
+export FI_EFA_FORK_SAFE=1
+export FI_EFA_USE_DEVICE_RDMA=1
+export FI_PROVIDER=efa
+export FI_LOG_LEVEL=1
+export NCCL_IB_DISABLE=1
+export NCCL_SOCKET_IFNAME=ens
+
+
+# srun error handling:
+# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
+# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
SRUN_ARGS_PLACEHOLDER_DO_NOT_KEEP
+SRUN_ARGS=" \
+    --wait=60 \
+    --kill-on-bad-exit=1 \
+    "
+
+# py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD
+clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID --role \$SLURMD_NODENAME: $CMD" 2>&1 | tee $LOG_PATH
\ No newline at end of file
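For reference, a sketch of what each of the eight per-GPU workers sees once `torch.distributed.run` has spawned it. The environment variable names are the standard torchrun ones; the values and the print are purely illustrative:

```python
# Sketch of the per-worker view under torch.distributed.run (torchrun).
# These env vars are set by the launcher itself; accelerate's Accelerator()
# reads them, which is why train_muse.py never parses rank arguments directly.
import os

rank = int(os.environ["RANK"])              # 0..7 across the job
local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node
world_size = int(os.environ["WORLD_SIZE"])  # nnodes * nproc_per_node = 1 * 8

print(f"worker {rank}/{world_size} -> cuda:{local_rank}")
```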
diff --git a/training/train_muse.py b/training/train_muse.py
index 1fe7f8c0..5cb7114e 100644
--- a/training/train_muse.py
+++ b/training/train_muse.py
@@ -27,7 +27,6 @@
 import plotly.express as px
 import torch
 import torch.nn.functional as F
-import wandb
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedType, set_seed
@@ -46,6 +45,7 @@
 
 import muse
 import muse.training_utils
+import wandb
 from muse import (
     MOVQ,
     EMAModel,
@@ -391,6 +391,7 @@ def save_model_hook(models, weights, output_dir):
             "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
             "weight_decay": 0.0,
         },
+        {"params": text_encoder.parameters(), "lr": optimizer_config.learning_rate / 2},
     ]
 
     optimizer = optimizer_cls(
@@ -467,7 +468,7 @@ def save_model_hook(models, weights, output_dir):
     # Prepare everything with accelerator
     logger.info("Preparing model, optimizer and dataloaders")
     # The dataloader are already aware of distributed training, so we don't need to prepare them.
-    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
+    model, text_encoder, optimizer, lr_scheduler = accelerator.prepare(model, text_encoder, optimizer, lr_scheduler)
 
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
@@ -479,7 +480,7 @@ def save_model_hook(models, weights, output_dir):
         weight_dtype = torch.bfloat16
 
     if not is_pre_encode:
-        text_encoder.to(device=accelerator.device, dtype=weight_dtype)
+        # text_encoder.to(device=accelerator.device, dtype=weight_dtype)
         vq_model.to(device=accelerator.device)
     if config.training.get("use_ema", False):
         ema.to(accelerator.device)
@@ -587,6 +588,7 @@ def prepare_inputs_and_labels(
     # reuse the same training loop with other datasets/loaders.
     for epoch in range(first_epoch, num_train_epochs):
         model.train()
+        text_encoder.train()
         for batch in train_dataloader:
             # TODO(Patrick) - We could definitely pre-compute the image tokens for faster training on larger datasets
             pixel_values, input_ids = batch
@@ -730,7 +732,7 @@ def prepare_inputs_and_labels(
 
             # Save model checkpoint
             if (global_step + 1) % config.experiment.save_every == 0:
-                save_checkpoint(model, config, accelerator, global_step + 1)
+                save_checkpoint(model, text_encoder, config, accelerator, global_step + 1)
 
             # Evaluate model on main process
             if (global_step + 1) % config.experiment.eval_every == 0 and accelerator.is_main_process:
@@ -783,7 +785,7 @@ def prepare_inputs_and_labels(
     # Evaluate and save checkpoint at the end of training
     if accelerator.is_main_process:
         validate_model(model, eval_dataloader, accelerator, global_step, prepare_inputs_and_labels)
-        save_checkpoint(model, config, accelerator, global_step)
+        save_checkpoint(model, text_encoder, config, accelerator, global_step)
 
     # Save the final trained checkpoint
     if accelerator.is_main_process:
@@ -917,12 +919,13 @@ def generate_images(
     wandb.log({"generated_images": wandb_images}, step=global_step)
 
 
-def save_checkpoint(model, config, accelerator, global_step):
+def save_checkpoint(model, text_encoder, config, accelerator, global_step):
     save_path = Path(config.experiment.output_dir) / f"checkpoint-{global_step}"
 
     # retrieve the model on all processes for deepspeed stage 3 to work then save on one process (we are not using stage 3 yet)
     # XXX: could also make this conditional on deepspeed
     state_dict = accelerator.get_state_dict(model)
+    text_encoder_state_dict = accelerator.get_state_dict(text_encoder)
 
     if accelerator.is_main_process:
         unwrapped_model = accelerator.unwrap_model(model)
@@ -931,6 +934,14 @@ def save_checkpoint(model, config, accelerator, global_step):
             save_function=accelerator.save,
             state_dict=state_dict,
         )
+
+        unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)
+        unwrapped_text_encoder.save_pretrained(
+            save_path / "unwrapped_textencoder",
+            save_function=accelerator.save,
+            state_dict=text_encoder_state_dict,
+        )
+
         json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+"))
         logger.info(f"Saved state to {save_path}")
 
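The new parameter group relies on PyTorch's per-group `lr` override: a group that specifies `"lr"` uses it instead of the optimizer-level default, so the text encoder trains at half the base learning rate while the transformer keeps the full one. A toy sketch of that mechanism (the two `nn.Linear` modules are stand-ins for the real transformer and CLIP text encoder):

```python
# Toy sketch of per-parameter-group learning rates; the modules are stand-ins.
import torch

transformer = torch.nn.Linear(8, 8)
text_encoder = torch.nn.Linear(8, 8)
base_lr = 1e-4

optimizer = torch.optim.AdamW(
    [
        {"params": transformer.parameters()},                      # falls back to base_lr
        {"params": text_encoder.parameters(), "lr": base_lr / 2},  # per-group override
    ],
    lr=base_lr,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    eps=1e-8,
)

for group in optimizer.param_groups:
    print(group["lr"])  # 0.0001, then 5e-05
```

Since the text encoder now goes through `accelerator.prepare`, accelerate takes over its device placement and mixed-precision handling, which is presumably why the explicit `.to(device, dtype=weight_dtype)` cast is commented out rather than kept.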