train text encoder #102

Draft · wants to merge 1 commit into main
116 changes: 116 additions & 0 deletions configs/finetune_text_encoder.yaml
@@ -0,0 +1,116 @@
wandb:
    entity: williamberman

experiment:
    project: "muse"
    name: "finetune-text-encoder"
    output_dir: "finetune-text-encoder"
    max_train_examples: 8974320 # total successfully downloaded images for laiona6plus
    max_eval_examples: 8118
    save_every: 2000
    eval_every: 1000
    generate_every: 1000
    log_every: 50
    log_grad_norm_every: 500
    resume_from_checkpoint: False

model:
    vq_model:
        type: "vqgan"
        pretrained: "openMUSE/vqgan-f16-8192-laion"

    text_encoder:
        type: "clip"
        pretrained: "openMUSE/clip-vit-large-patch14-penultimate"
        pad_token_id: 49407

    architecture: "uvit"

    transformer:
        vocab_size: 8256 # 8192 codebook tokens + 1 for <mask> = 8193, padded up to 8256, the next multiple of 64
        hidden_size: 1024
        intermediate_size: 4096
        num_hidden_layers: 22
        num_attention_heads: 16
        max_position_embeddings: 256
        in_channels: 512
        block_out_channels:
            - 1024
        num_res_blocks: 3
        patch_size: 1
        encoder_hidden_size: 768 # hidden size of the CLIP ViT-L/14 text encoder
        add_cross_attention: True
        project_encoder_hidden_states: False
        codebook_size: 8192
        num_vq_tokens: 576 # 384 / 16 = 24 latent tokens per side with the f16 VQGAN; 24 * 24 = 576
        initializer_range: 0.02
        norm_type: "rmsnorm"
        layer_norm_eps: 1e-6
        use_normformer: False
        use_encoder_layernorm: True
        use_bias: False
        hidden_dropout: 0.0
        attention_dropout: 0.0
        use_codebook_size_for_output: True

    gradient_checkpointing: True
    enable_xformers_memory_efficient_attention: True


dataset:
    type: "text2image"
    params:
        train_shards_path_or_url: "pipe:aws s3 cp s3://hf-datasets-laion-5b-us-west-2/glacier/laion-data/laion-aesthetics-v2-5-plus-data/{00000..60580}.tar -"
        eval_shards_path_or_url: "pipe:aws s3 cp s3://muse-datasets/coco/2014/val/{00000..00010}.tar -"
        validation_prompts_file: "validation_prompts/dalle_mini_prompts.txt"
        batch_size: ${training.batch_size}
        shuffle_buffer_size: 1000
        num_workers: 4
        resolution: 384
        pin_memory: True
        persistent_workers: True
        use_filtered_dataset: True
    preprocessing:
        max_seq_length: 77
        resolution: 384
        center_crop: False
        random_flip: False


optimizer:
    name: adamw
    params: # default adamw params
        learning_rate: 1e-4
        scale_lr: False # scale learning rate by total batch size
        beta1: 0.9
        beta2: 0.999
        weight_decay: 0.01
        epsilon: 1e-8


lr_scheduler:
    scheduler: "constant_with_warmup"
    params:
        learning_rate: ${optimizer.params.learning_rate}
        warmup_steps: 2000


training:
    gradient_accumulation_steps: 1
    batch_size: 44
    mixed_precision: "fp16"
    enable_tf32: True
    use_ema: False
    seed: 9345104
    max_train_steps: 1000000
    overfit_one_batch: False
    cond_dropout_prob: 0.1
    min_masking_rate: 0.0
    label_smoothing: 0.1
    max_grad_norm: null
    guidance_scale: 8
    generation_timesteps: 16
    # related to vae code sampling
    use_soft_code_target: False
    use_stochastic_code: False
    soft_code_temp: 1.0
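
For reference, a minimal sketch of how a config in this shape can be loaded and resolved. This assumes OmegaConf, which matches the ${...} interpolation syntax above and the config=... / experiment.*=... overrides the slurm script below passes on the command line; the exact loading code in train_muse.py may differ.

# Hypothetical loader sketch, assuming OmegaConf.
from omegaconf import OmegaConf

config = OmegaConf.load("configs/finetune_text_encoder.yaml")

# Dotlist-style overrides, as passed by the slurm script.
overrides = OmegaConf.from_dotlist([
    "experiment.name=finetune_text_encoder",
    "experiment.output_dir=/fsx/william/finetune_text_encoder",
])
config = OmegaConf.merge(config, overrides)

# Interpolations resolve against sibling keys when accessed.
assert config.dataset.params.batch_size == config.training.batch_size  # 44
assert config.lr_scheduler.params.learning_rate == config.optimizer.params.learning_rate  # 1e-4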
86 changes: 86 additions & 0 deletions slurm_scripts/finetune_text_encoder.slurm
@@ -0,0 +1,86 @@
#!/bin/bash
#SBATCH --job-name=oai-clip-lbl-smooth-laiona6-block-attn-512
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per node; the launcher spawns one process per GPU
#SBATCH --cpus-per-task=96
#SBATCH --gres=gpu:8
#SBATCH --exclusive
#SBATCH --partition=production-cluster
#SBATCH --output=/fsx/william/logs/finetune_text_encoder.out

set -x -e

source /admin/home/william/.bashrc
source /fsx/william/miniconda3/etc/profile.d/conda.sh
conda activate base

echo "START TIME: $(date)"

MUSE_REPO=/fsx/william/open-muse
OUTPUT_DIR=/fsx/william/finetune_text_encoder
LOG_PATH=$OUTPUT_DIR/main_log.txt
CONFIG_PATH=/fsx/william/open-muse/configs/finetune_text_encoder.yaml

mkdir -p $OUTPUT_DIR
touch $LOG_PATH
pushd $MUSE_REPO

CMD=" \
$MUSE_REPO/training/train_muse.py config=$CONFIG_PATH \
experiment.name=$(basename $OUTPUT_DIR) \
experiment.output_dir=$OUTPUT_DIR \
"

GPUS_PER_NODE=8
NNODES=$SLURM_NNODES

# set the visible GPUs
export CUDA_VISIBLE_DEVICES=${SLURM_STEP_GPUS:-$SLURM_JOB_GPUS}

# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000

export LAUNCHER="python -u -m torch.distributed.run \
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
--rdzv_backend c10d \
--max_restarts 0 \
--tee 3 \
"

echo $CMD

# hide duplicated errors using this hack - will be properly fixed in pt-1.12
# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json

# force crashing on nccl issues like hanging broadcast
export NCCL_ASYNC_ERROR_HANDLING=1
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=COLL
# export NCCL_SOCKET_NTHREADS=1
# export NCCL_NSOCKS_PERTHREAD=1
# export CUDA_LAUNCH_BLOCKING=1

# AWS specific
export NCCL_PROTO=simple
export RDMAV_FORK_SAFE=1
export FI_EFA_FORK_SAFE=1
export FI_EFA_USE_DEVICE_RDMA=1
export FI_PROVIDER=efa
export FI_LOG_LEVEL=1
export NCCL_IB_DISABLE=1
export NCCL_SOCKET_IFNAME=ens


# srun error handling:
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
SRUN_ARGS=" \
--wait=60 \
--kill-on-bad-exit=1 \
"

# py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD
clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID --role \$SLURMD_NODENAME: $CMD" 2>&1 | tee $LOG_PATH
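
Usage note: the job is submitted with sbatch slurm_scripts/finetune_text_encoder.slurm, assuming the hard-coded /fsx paths and the conda environment exist on the cluster. torch.distributed.run then starts one process per GPU on the single allocated node, forwards config=... plus the two experiment.* overrides to train_muse.py, and the combined output of all ranks is written to $OUTPUT_DIR/main_log.txt via tee.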
23 changes: 17 additions & 6 deletions training/train_muse.py
@@ -27,7 +27,6 @@
import plotly.express as px
import torch
import torch.nn.functional as F
import wandb
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedType, set_seed
@@ -46,6 +45,7 @@

import muse
import muse.training_utils
import wandb
from muse import (
MOVQ,
EMAModel,
@@ -391,6 +391,7 @@ def save_model_hook(models, weights, output_dir):
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
{"params": text_encoder.parameters(), "lr": optimizer_config.learning_rate // 2},
]

optimizer = optimizer_cls(
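
For context, a minimal sketch (with illustrative module names, not the ones in train_muse.py) of how per-group learning rates behave in torch.optim.AdamW: a group that sets its own "lr" overrides the optimizer-wide default, which is how the text encoder ends up training at half the base rate while the transformer keeps the full rate.

import torch

# Stand-in modules for the transformer and the CLIP text encoder (illustrative only).
transformer = torch.nn.Linear(16, 16)
text_encoder = torch.nn.Linear(16, 16)

base_lr = 1e-4  # optimizer.params.learning_rate in the config above
grouped_parameters = [
    {"params": transformer.parameters(), "weight_decay": 0.01},
    # A per-group "lr" overrides the default lr passed to the optimizer.
    {"params": text_encoder.parameters(), "lr": base_lr / 2},
]
optimizer = torch.optim.AdamW(grouped_parameters, lr=base_lr, betas=(0.9, 0.999), weight_decay=0.01)

for group in optimizer.param_groups:
    print(group["lr"])  # 0.0001 for the transformer group, 5e-05 for the text-encoder group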
@@ -467,7 +468,7 @@ def save_model_hook(models, weights, output_dir):
# Prepare everything with accelerator
logger.info("Preparing model, optimizer and dataloaders")
# The dataloaders are already aware of distributed training, so we don't need to prepare them.
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
model, text_encoder, optimizer, lr_scheduler = accelerator.prepare(model, text_encoder, optimizer, lr_scheduler)

# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
@@ -479,7 +480,7 @@ def save_model_hook(models, weights, output_dir):
weight_dtype = torch.bfloat16

if not is_pre_encode:
text_encoder.to(device=accelerator.device, dtype=weight_dtype)
# the text encoder is now trained directly, so accelerator.prepare handles its device placement and precision
# text_encoder.to(device=accelerator.device, dtype=weight_dtype)
vq_model.to(device=accelerator.device)
if config.training.get("use_ema", False):
ema.to(accelerator.device)
@@ -587,6 +588,7 @@ def prepare_inputs_and_labels(
# reuse the same training loop with other datasets/loaders.
for epoch in range(first_epoch, num_train_epochs):
model.train()
text_encoder.train()
for batch in train_dataloader:
# TODO(Patrick) - We could definitely pre-compute the image tokens for faster training on larger datasets
pixel_values, input_ids = batch
@@ -730,7 +732,7 @@ def prepare_inputs_and_labels(

# Save model checkpoint
if (global_step + 1) % config.experiment.save_every == 0:
save_checkpoint(model, config, accelerator, global_step + 1)
save_checkpoint(model, text_encoder, config, accelerator, global_step + 1)

# Evaluate model on main process
if (global_step + 1) % config.experiment.eval_every == 0 and accelerator.is_main_process:
@@ -783,7 +785,7 @@ def prepare_inputs_and_labels(
# Evaluate and save checkpoint at the end of training
if accelerator.is_main_process:
validate_model(model, eval_dataloader, accelerator, global_step, prepare_inputs_and_labels)
save_checkpoint(model, config, accelerator, global_step)
save_checkpoint(model, text_encoder, config, accelerator, global_step)

# Save the final trained checkpoint
if accelerator.is_main_process:
@@ -917,12 +919,13 @@ def generate_images(
wandb.log({"generated_images": wandb_images}, step=global_step)


def save_checkpoint(model, config, accelerator, global_step):
def save_checkpoint(model, text_encoder, config, accelerator, global_step):
save_path = Path(config.experiment.output_dir) / f"checkpoint-{global_step}"

# retrieve the model on all processes so DeepSpeed stage 3 would work, then save on one process (we are not using stage 3 yet)
# XXX: could also make this conditional on deepspeed
state_dict = accelerator.get_state_dict(model)
text_encoder_state_dict = accelerator.get_state_dict(text_encoder)

if accelerator.is_main_process:
unwrapped_model = accelerator.unwrap_model(model)
@@ -931,6 +934,14 @@ def save_checkpoint(model, config, accelerator, global_step):
save_function=accelerator.save,
state_dict=state_dict,
)

unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)
unwrapped_text_encoder.save_pretrained(
save_path / "unwrapped_textencoder",
save_function=accelerator.save,
state_dict=text_encoder_state_dict,
)

json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+"))
logger.info(f"Saved state to {save_path}")
