Skip to content

Commit

Permalink
Updated Dockerfile
Browse files Browse the repository at this point in the history
  • Loading branch information
neel04 committed Nov 21, 2024
1 parent bfc1631 commit f55d7bc
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
3 changes: 1 addition & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@ RUN apt-get update && \
RUN pip3 install Ipython matplotlib
RUN pip3 install numpy pandas scipy

RUN pip3 install -U numpy==1.26.4
RUN pip3 install -U -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN pip3 install -q transformers datasets scalax tokenizers icecream wandb einops torch tqdm jaxtyping optax optuna equinox rich
RUN pip3 install -U tensorboard-plugin-profile optuna-integration plotly lm-eval
RUN pip3 install -U tensorboard-plugin-profile optuna-integration plotly lm-eval pdbpp
RUN pip3 install git+https://github.com/deepmind/jmp
RUN pip3 install git+https://github.com/Findus23/jax-array-info.git

Expand Down
4 changes: 2 additions & 2 deletions run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
BRANCH="dev"

# arguments for train_model.py
TRAIN_ARGS="--save_dir ./ReAct/outputs/ --dataset tinystories --group debug \
TRAIN_ARGS="--save_dir ./ReAct/outputs/ --dataset owt --group debug \
--num_blocks 8 --width 1536 --n_heads 8 --epochs 1 --num_classes 50304 \
--log_interval 750 --save_interval 10000 --seqlen 512 \
--max_iters 3 --batch_size 64 --accum_steps 8 \
Expand Down Expand Up @@ -62,5 +62,5 @@ fi

echo "Executing train_model.py"
source main_env/bin/activate
XLA_FLAGS="--xla_gpu_triton_gemm_any=true --xla_gpu_enable_triton_softmax_fusion=true" python3 ReAct_Jax/train_model.py $TRAIN_ARGS
XLA_FLAGS="--xla_gpu_triton_gemm_any=true --xla_gpu_enable_triton_softmax_fusion=true --xla_gpu_enable_triton_hopper=true" python3 ReAct_Jax/train_model.py $TRAIN_ARGS
echo "Finished training!"

0 comments on commit f55d7bc

Please sign in to comment.