-
Notifications
You must be signed in to change notification settings - Fork 84
/
3.distributed-training-mistral-mathstral.sbatch
107 lines (93 loc) · 2.92 KB
/
3.distributed-training-mistral-mathstral.sbatch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/bin/bash
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
#
# Slurm allocation: 4 nodes, job name "FSDP", with exclusive use of the
# allocated nodes (no sharing with other jobs).
#SBATCH --nodes=4
#SBATCH --job-name=FSDP
#SBATCH --exclusive
# Stop at the first failing command and echo each command before running it.
set -o errexit -o xtrace
###########################
###### User Variables #####
###########################
# GPUs per node: 4 for G5.12x, 8 for P4/P5.
# Defaults to 8 but may be overridden from the submitting shell, e.g.
#   GPUS_PER_NODE=4 sbatch <this script>
# (relies on sbatch forwarding the caller's environment, its default
# --export=ALL behaviour -- confirm if your site changes that default).
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
###########################
## Environment Variables ##
###########################
## EFA / libfabric settings.
## Comment these out for non-EFA instances (G4dn, P3).
## For G5.12x, comment out FI_EFA_USE_DEVICE_RDMA and FI_EFA_FORK_SAFE.
## For G4dn and other G5 sizes, comment out all of them.
export FI_EFA_USE_DEVICE_RDMA=1 # use for p4d
export FI_EFA_FORK_SAFE=1
export FI_LOG_LEVEL=1
export FI_PROVIDER=efa
export NCCL_DEBUG=INFO
## Setting FI_EFA_SET_CUDA_SYNC_MEMOPS to 0 disables
## CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, which reduces memory synchronizations
## and can boost throughput with FSDP.
## https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__UNIFIED.html
export FI_EFA_SET_CUDA_SYNC_MEMOPS=0
###########################
####### Torch Dist #######
###########################
# Arguments for the torchrun launcher. One torchrun process is started per
# node by srun; each spawns GPUS_PER_NODE workers. The c10d rendezvous
# endpoint is the node executing this batch script, where $(hostname) is
# evaluated. No port is given, so torchrun presumably falls back to its
# default c10d port (29400) -- confirm against the torchrun docs.
# Expansions are quoted so each flag stays exactly one array element.
# No trailing backslashes are needed inside ( ... ).
declare -a TORCHRUN_ARGS=(
  "--nproc_per_node=${GPUS_PER_NODE}"
  "--nnodes=${SLURM_JOB_NUM_NODES}"
  "--rdzv_id=${SLURM_JOB_ID}"
  --rdzv_backend=c10d
  "--rdzv_endpoint=$(hostname)"
)
export TORCHRUN=./pt_fsdp/bin/torchrun
export TRAIN_SCRIPT=./train.py
############################
# Mistral Training Params ##
############################
# Arguments forwarded verbatim to train.py, one flag per array element.
# Elements are separated by newlines, so no trailing backslashes are used.
# A backslash followed by " # comment" would escape the space, turn the
# "comment" into extra bogus arguments, and pass them to train.py.
# NOTE(review): the model-shape values below presumably mirror the
# Mathstral-7B / Mistral-7B architecture -- confirm against the model config.
declare -a TRAINING_ARGS=(
  --train_batch_size=1
  --val_batch_size=1
  --max_steps=5000
  --seed=42
  --grad_clip=1.0
  --weight_decay=0.2
  --beta1=0.9
  --beta2=0.95
  --activation_checkpointing=1
  --intermediate_size=14336
  --num_key_value_heads=8
  --logging_freq=1
  --max_context_width=32768
  --vocab_size=32768
  --hidden_width=4096
  --num_layers=32
  --num_heads=32
  --resid_pdrop=0.1
  --embd_pdrop=0.1
  --attn_pdrop=0.1
  --summary_first_pdrop=0.1
  --initializer_range=0.02
  --model_type="mistral"
  --rotary_pct=0.25
  --rotary_emb_base=10000
  --lr=0.0001
  --lr_decay_style="cosine"
  --min_lr=1e-5
  --warmup=0.0032
  --plateau=0.0
  --dataset="c4"
  --tokenizer="mistralai/mathstral-7B-v0.1"
  --epochs=3
  --checkpoint_dir="./checkpoints/mathstral-7B"
  --resume_from_checkpoint="./checkpoints/mathstral-7B"
  --checkpoint_freq=50
  --validation_freq=500
  --dataset_config_name="en"
  --limit_all_gathers=1
  # FSDP sharding strategy, see https://pytorch.org/docs/stable/fsdp.html
  --sharding_strategy="full"
  --offload_activations=1
)
## Auto-resume flag for resiliency checks: passed to srun only when
## /opt/sagemaker_cluster exists (i.e. on a SageMaker HyperPod cluster).
AUTO_RESUME=""
if [[ -d "/opt/sagemaker_cluster" ]]; then
  echo "Detected Hyperpod cluster.. enabling --auto-resume=1"
  AUTO_RESUME="--auto-resume=1"
fi
# ${AUTO_RESUME:+"$AUTO_RESUME"} expands to nothing when the flag is empty
# and to exactly one quoted word otherwise, so srun never sees an empty arg.
# -l labels each output line with the emitting task's id.
srun ${AUTO_RESUME:+"$AUTO_RESUME"} -l "${TORCHRUN}" "${TORCHRUN_ARGS[@]}" "${TRAIN_SCRIPT}" "${TRAINING_ARGS[@]}"