Changes from all commits (37 commits)
768e4ec  silence stopwatch log (samsja, Mar 21, 2025)
d7228d6  remove future to be useless code becase pccl rocks (samsja, Mar 21, 2025)
a0137b6  remove log hash (samsja, Mar 26, 2025)
3f9e980  remove profiler stuff (samsja, Mar 26, 2025)
8866d8c  rmove stuff (samsja, Mar 26, 2025)
da199b6  rmove stuff (samsja, Mar 26, 2025)
47b5317  remove stuff (samsja, Mar 26, 2025)
c4d7ab3  remove gloo (samsja, Mar 26, 2025)
c81d2b4  refactor metric logger (samsja, Mar 26, 2025)
689cc19  refactor perf counter (samsja, Mar 26, 2025)
3b81752  fix configs (samsja, Mar 26, 2025)
d89fd44  fix reduce scatter (samsja, Mar 26, 2025)
906ef2b  remove ce and z loss (samsja, Mar 26, 2025)
ff3a31d  remove shanpoo (samsja, Mar 26, 2025)
2c4b4cc  update depenendcies (samsja, Mar 26, 2025)
fb20315  delete retry_all_reduce config variable (mikex86, Mar 26, 2025)
098b046  bump to torch 2.6 (samsja, Mar 26, 2025)
f8ae249  use new fdsp2 public api (samsja, Mar 26, 2025)
384a172  silence tests flex (samsja, Mar 26, 2025)
7b2fbd1  Remove hf lrscheds (#234) (mikex86, Mar 27, 2025)
695771a  minor cleanup (mikex86, Mar 27, 2025)
7ccc584  delete csrc (mikex86, Mar 27, 2025)
51d582c  more cleanup (mikex86, Mar 27, 2025)
3f7b516  fix imports (mikex86, Mar 27, 2025)
1aa5ba8  move FakeTokenizer (mikex86, Mar 27, 2025)
f4d96b8  fix imports (mikex86, Mar 27, 2025)
5008d10  make memory_profiler optional and not rely on branchy definition (mikex86, Mar 27, 2025)
90a187b  use better profiler (mikex86, Mar 27, 2025)
59d55d8  remove unused varaible (mikex86, Mar 27, 2025)
48e985f  new mfu tracker (mikex86, Mar 28, 2025)
3f261f0  fix unused (mikex86, Mar 28, 2025)
2ead8a4  fix unused (mikex86, Mar 28, 2025)
0759c49  preserve default value semantics for fwd (mikex86, Mar 28, 2025)
39865bf  preserve default value semantics for fwd (mikex86, Mar 28, 2025)
e533f4a  accept hw config in apply_sharding (mikex86, Mar 28, 2025)
3134287  accept hw config in apply_sharding (mikex86, Mar 28, 2025)
2c1d841  add FlopCounter docstring (mikex86, Mar 28, 2025)
3 changes: 0 additions & 3 deletions .gitmodules

This file was deleted.

4 changes: 2 additions & 2 deletions README.md
@@ -118,7 +118,7 @@ uv run pytest
To run evals, you first need to convert the checkpoint to a Hugging Face compatible model.

```bash
uv run python scripts/export_dcp.py @configs/10B/H100.toml --ckpt.path CONVERTED_MODEL_PATH --ckpt.resume CHECKPOINT_PATH --torch_dtype bfloat16 --ckpt.interval 1
uv run python scripts/export_dcp.py @configs/10B/H100_simple.toml --ckpt.path CONVERTED_MODEL_PATH --ckpt.resume CHECKPOINT_PATH --torch_dtype bfloat16 --ckpt.interval 1
```


@@ -178,7 +178,7 @@ You may also pass the `torch_dtype` argument to either `float32` or `bfloat16` t

Example export command:
```bash
python scripts/export_dcp.py @configs/10B/H100.toml --ckpt.path /path/to/save/converted_model --ckpt.resume /path/to/ckpt/step_84000 --torch_dtype bfloat16
python scripts/export_dcp.py @configs/10B/H100_simple.toml --ckpt.path /path/to/save/converted_model --ckpt.resume /path/to/ckpt/step_84000 --torch_dtype bfloat16
```

You can then upload the model to Hugging Face using huggingface-cli:
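The diff collapses the rest of this README section, so the upload command itself is not shown. As a hedged sketch only (the repository id and local path below are placeholders, not taken from this PR):

```bash
# Hypothetical example: push the converted model directory to a Hugging Face repo.
# <org>/<model-name> and CONVERTED_MODEL_PATH are placeholders.
huggingface-cli login
huggingface-cli upload <org>/<model-name> CONVERTED_MODEL_PATH
```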
40 changes: 0 additions & 40 deletions configs/10B/H100_cooldown.toml

This file was deleted.

34 changes: 0 additions & 34 deletions configs/10B/H100_devel.toml

This file was deleted.

29 changes: 15 additions & 14 deletions configs/10B/H100.toml → configs/10B/H100_intellect1.toml
@@ -1,22 +1,25 @@
name_model = "10B"
project = "10B_zero_band"
model_name = "10B"
model_type = "llama3"

wandb_resume = false

[train]
micro_bs = 1
ac_ckpt = true
[hardware]
micro_batch_size = 1
act_ckpt = true

[optim]
sched_type = "wsd-sqrt"
[train]
batch_size = 128 #1M tokens bs
warmup_steps = 1000
total_steps = 1_000_000_000_000


z_loss = true

[optim.optim]
[train.lr_scheduler]
decay_type = "sqrt"
lr = 7.5e-5
end_lr = 0.0
num_warmup_steps = 1000
num_stable_steps = 70_000
num_decay_steps = 30_000

[train.optimizer]
betas1 = 0.9
betas2 = 0.95
weight_decay = 0.1
@@ -36,6 +39,4 @@ compression = "uint8"

[ckpt]
interval = 100
topk = 40
path = "/data/10B"
remote_data_path = "/data/10B_data_ckpt"
27 changes: 15 additions & 12 deletions configs/10B/H100_simple.toml
@@ -1,20 +1,23 @@
name_model = "10B"
project = "debug_10B_zero_band"
model_name = "10B"
model_type = "llama3"

[train]
micro_bs = 1
ac_ckpt = true

[optim]
sched_type = "wsd-sqrt"
batch_size = 128 #1M tokens bs
warmup_steps = 1000
total_steps = 1_000_000_000_000
[hardware]
micro_batch_size = 1
act_ckpt = true

z_loss = true
[train]
batch_size = 128 #1M tokens bs

[optim.optim]
[train.lr_scheduler]
decay_type = "sqrt"
lr = 7.5e-5
end_lr = 0.0
num_warmup_steps = 1000
num_decay_steps = 1_000_000_000_000

[train.optimizer]
type = 'adamw'
betas1 = 0.9
betas2 = 0.95
weight_decay = 0.1
20 changes: 9 additions & 11 deletions configs/13B/H100.toml
@@ -1,17 +1,15 @@
name_model = "13B"
project = "debug_13B_zero_band"

[train]
micro_bs = 1
ac_ckpt = true
model_name = "13B"
model_type = "llama2"

[optim]
batch_size = 1024 #2M tokens bs
warmup_steps = 1000
total_steps = 88_000
[hardware]
micro_batch_size = 64
reshard_after_forward = false

[optim.optim]
lr = 3e-4
[train]
batch_size = 512

[data]
seq_length = 2048
seq_length = 1024
dataset_name_or_paths = "datasets/fineweb-edu"
17 changes: 0 additions & 17 deletions configs/150M/3090.toml

This file was deleted.

21 changes: 21 additions & 0 deletions configs/150M/A100_debug.toml
@@ -0,0 +1,21 @@
project = "debug_150m_zero_band"

model_name = "150M"
model_type = "llama2"

wandb = false

[hardware]
micro_batch_size = 64
torch_compile = true

[train]
batch_size = 512

[train.lr_scheduler]
num_warmup_steps = 10
num_decay_steps = 1000

[data]
fake = true

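Since the reworked configs nest settings under `[hardware]`, `[train]`, `[train.lr_scheduler]`, and `[train.optimizer]`, a quick way to sanity-check the new structure is to parse the file with the standard library. A minimal sketch, assuming Python 3.11+ (for `tomllib`) and the file path shown in this diff; this one-liner is not part of the repo:

```bash
# Hypothetical sanity check: parse the new-style TOML config and dump it as JSON.
uv run python -c 'import tomllib, json; print(json.dumps(tomllib.load(open("configs/150M/A100_debug.toml", "rb")), indent=2))'
```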
16 changes: 0 additions & 16 deletions configs/150M/A40.toml

This file was deleted.

19 changes: 9 additions & 10 deletions configs/150M/H100.toml
@@ -1,16 +1,15 @@
name_model = "150M"
project = "debug_150m_zero_band"
type_model = "llama2"

[train]
micro_bs = 64 # change this base on the gpu
model_name = "150M"
model_type = "llama2"

[hardware]
micro_batch_size = 64
reshard_after_forward = false

[optim]
[train]
batch_size = 512
warmup_steps = 1000
total_steps = 88_000

[optim.optim]
lr = 4e-4

[data]
seq_length = 1024
dataset_name_or_paths = "datasets/fineweb-edu"
20 changes: 12 additions & 8 deletions configs/150M/H100-fast.toml → configs/150M/H100_best.toml
@@ -1,18 +1,22 @@
name_model = "150M"
project = "debug_150m_zero_band"
type_model = "llama2"
model_name = "150M"
model_type = "llama2"

[train]
micro_bs = 64 # change this base on the gpu
[hardware]
micro_batch_size = 64
reshard_after_forward = false

[optim]
[train]
batch_size = 512
warmup_steps = 278
total_steps = 8192

[optim.optim]
[train.lr_scheduler]
decay_type = 'cosine'
num_warmup_steps = 278
num_decay_steps = 7914 # 278 + 7914 = 8192
lr = 0.003551730141097694

[train.optimizer]
type = 'adamw'
betas1 = 0.9454835470717078
betas2 = 0.9190488086654895
weight_decay = 0.24530252977858977
16 changes: 0 additions & 16 deletions configs/150M_short/3090.toml

This file was deleted.

17 changes: 0 additions & 17 deletions configs/150M_short/A40.toml

This file was deleted.

16 changes: 0 additions & 16 deletions configs/150M_short/H100.toml

This file was deleted.

22 changes: 11 additions & 11 deletions configs/1B/H100.toml
@@ -1,15 +1,15 @@
name_model = "1B"
project = "debug_1B_zero_band"
type_model = "llama2"

[train]
micro_bs = 32
reshard_after_forward = true
model_name = "1B"
model_type = "llama2"

[hardware]
micro_batch_size = 64
reshard_after_forward = false

[optim]
batch_size = 1024
warmup_steps = 1000
total_steps = 8192
[train]
batch_size = 512

[optim.optim]
lr = 7e-4
[data]
seq_length = 1024
dataset_name_or_paths = "datasets/fineweb-edu"