Skip to content

Commit a13c683

Browse files
Changes the DPO + finetune scripts to provide progress updates in the Beaker description. (#1127)
* Now, we get num_attention_heads from the hf config. * Update code * Added test that we match manual values * Updated calculations * Updated code with check_calculation * Updated code * Now, tests pass. * Updated code to normalize properly * Added some fixes * Updated code * Updated code * Another fix * Cleaned up tests. * Cleaned up PR * Update MFU/MBU code. * Now, mbu tests pass. * Moved to json file * Added test data * undid changes and simplified test function. * An attempt at a fix * Update code with patches * now, tests pass * Added MFU to DPO * updated script * uses uv for dpo * Added a chat template to the DPO script. * Added tracking * Updated code to handle tracking when none * Added description updates * undid changes * Check out dpo script * updated script * Update code to remove whitespace * fix finetune timing * Fixed bugs pointed out by cursor.
1 parent a76cc4b commit a13c683

File tree

3 files changed

+38
-10
lines changed

3 files changed

+38
-10
lines changed

open_instruct/dpo_tune_cache.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
is_beaker_job,
8080
launch_ai2_evals_on_weka,
8181
maybe_get_beaker_config,
82+
maybe_update_beaker_description,
8283
maybe_use_ai2_hf_entity,
8384
maybe_use_ai2_wandb_entity,
8485
)
@@ -498,6 +499,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
498499
},
499500
)
500501
wandb_tracker = accelerator.get_tracker("wandb")
502+
maybe_update_beaker_description(wandb_url=wandb_tracker.run.get_url() if args.with_tracking else None)
501503

502504
if accelerator.is_main_process:
503505
pprint([args, tc])
@@ -813,6 +815,7 @@ def load_model():
813815
print("=============after cache logprobs; clear cache")
814816
print_gpu_stats(init_gpu_memory)
815817
# Only show the progress bar once on each machine.
818+
start_time = time.perf_counter()
816819
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
817820
# update the progress_bar if load from checkpoint
818821
progress_bar.update(completed_steps)
@@ -936,6 +939,12 @@ def load_model():
936939
logger.info(logger_str)
937940
if args.with_tracking:
938941
accelerator.log(metrics_to_log, step=completed_steps)
942+
maybe_update_beaker_description(
943+
current_step=completed_steps,
944+
total_steps=args.max_train_steps,
945+
start_time=start_time,
946+
wandb_url=wandb_tracker.run.get_url() if args.with_tracking else None,
947+
)
939948
# Reset the local metrics
940949
local_metrics.zero_()
941950

@@ -989,7 +998,7 @@ def load_model():
989998
path=args.output_dir,
990999
leaderboard_name=args.hf_repo_revision,
9911000
oe_eval_max_length=args.oe_eval_max_length,
992-
wandb_url=wandb_tracker.run.get_url(),
1001+
wandb_url=wandb_tracker.run.get_url() if args.with_tracking else None,
9931002
oe_eval_tasks=args.oe_eval_tasks,
9941003
gs_bucket_path=args.gs_bucket_path,
9951004
)

open_instruct/finetune.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
is_beaker_job,
6363
launch_ai2_evals_on_weka,
6464
maybe_get_beaker_config,
65+
maybe_update_beaker_description,
6566
maybe_use_ai2_hf_entity,
6667
maybe_use_ai2_wandb_entity,
6768
)
@@ -438,6 +439,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
438439
},
439440
)
440441
wandb_tracker = accelerator.get_tracker("wandb")
442+
maybe_update_beaker_description(wandb_url=wandb_tracker.run.get_url())
441443
else:
442444
wandb_tracker = None # for later eval launching
443445

@@ -727,7 +729,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
727729
local_total_tokens_this_log_period = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
728730
local_pred_tokens_this_log_period = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
729731
total_token_including_padding = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
730-
start_time = time.time()
732+
start_time = time.perf_counter()
731733
skipped_batches = False
732734
for epoch in range(starting_epoch, args.num_train_epochs):
733735
model.train()
@@ -824,10 +826,12 @@ def main(args: FlatArguments, tc: TokenizerConfig):
824826
"avg_tokens_per_batch": avg_tokens_per_batch,
825827
"avg_tokens_per_batch_including_padding": avg_tokens_per_batch_including_padding,
826828
"avg_pred_tokens_per_batch": avg_pred_tokens_per_batch,
827-
"per_device_tps": total_tokens / accelerator.num_processes / (time.time() - start_time),
829+
"per_device_tps": total_tokens
830+
/ accelerator.num_processes
831+
/ (time.perf_counter() - start_time),
828832
"per_device_tps_including_padding": total_tokens_including_padding
829833
/ accelerator.num_processes
830-
/ (time.time() - start_time),
834+
/ (time.perf_counter() - start_time),
831835
"reserved_mem_GiB": torch.cuda.max_memory_reserved(device=torch.cuda.current_device()) / 2**30,
832836
"allocated_mem_GiB": torch.cuda.max_memory_allocated(device=torch.cuda.current_device())
833837
/ 2**30,
@@ -855,7 +859,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
855859
avg_loss = sum_loss / total_fwd_passes
856860
metrics_to_log["train_loss"] = avg_loss
857861
if args.verbose:
858-
sec_per_step = (time.time() - start_time) / (completed_steps - resume_step)
862+
sec_per_step = (time.perf_counter() - start_time) / (completed_steps - resume_step)
859863
steps_remaining = args.max_train_steps - completed_steps
860864
secs_remaining = steps_remaining * sec_per_step
861865
accelerator.print(
@@ -869,17 +873,23 @@ def main(args: FlatArguments, tc: TokenizerConfig):
869873
/ args.logging_steps
870874
)
871875
logger.info(
872-
f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, Aux Loss: {avg_aux_loss}, TPS: {total_tokens / (time.time() - start_time)}"
876+
f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, Aux Loss: {avg_aux_loss}, TPS: {total_tokens / (time.perf_counter() - start_time)}"
873877
)
874878
metrics_to_log["aux_loss"] = avg_aux_loss
875879
else:
876880
logger.info(
877-
f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, TPS: {total_tokens / (time.time() - start_time)}"
881+
f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, TPS: {total_tokens / (time.perf_counter() - start_time)}"
878882
)
879883
if args.verbose:
880884
accelerator.print(f"{metrics_to_log=}")
881885
if args.with_tracking:
882886
accelerator.log(metrics_to_log, step=completed_steps)
887+
maybe_update_beaker_description(
888+
current_step=completed_steps,
889+
total_steps=args.max_train_steps,
890+
start_time=start_time,
891+
wandb_url=wandb_tracker.run.get_url() if wandb_tracker is not None else None,
892+
)
883893
total_loss = 0
884894
total_aux_loss = 0
885895

scripts/train/debug/dpo.sh

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
1-
python mason.py \
1+
#!/bin/bash
2+
BEAKER_IMAGE="${1:-nathanl/open_instruct_auto}"
3+
4+
uv run python mason.py \
25
--cluster ai2/neptune \
6+
--cluster ai2/saturn \
7+
--cluster ai2/jupiter \
8+
--cluster ai2/prior \
9+
--description "Single GPU DPO run, for debugging purposes." \
310
--workspace ai2/tulu-thinker \
411
--priority high \
5-
--image nathanl/open_instruct_auto --pure_docker_mode \
12+
--image "$BEAKER_IMAGE" \
13+
--pure_docker_mode \
614
--preemptible \
715
--num_nodes 1 \
816
--budget ai2/oe-adapt \
@@ -26,5 +34,6 @@ python mason.py \
2634
--logging_steps 1 \
2735
--dataset_mixer_list allenai/tulu-3-wildchat-reused-on-policy-8b 100 \
2836
--add_bos \
37+
--chat_template_name olmo \
2938
--seed 123
30-
# --with_tracking
39+
# --with_tracking

0 commit comments

Comments
 (0)