@@ -5,7 +5,7 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy  # type: ignore
+from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy  # type: ignore
 from torch.autograd.profiler import record_function
 
 from zeroband.checkpoint import CkptManager, TrainingProgress
@@ -69,10 +69,9 @@ def log_hash_training_state( |
             logger.debug(f"outer diloco optimizer hash {id} : {outer_optimizer_hash}")
             logger.debug(f"outer diloco model hash {id} : {outer_model_hash}")
 
-            metrics.update({
-                f"outer_optimizer_hash_{id}": outer_optimizer_hash,
-                f"outer_model_hash_{id}": outer_model_hash
-            })
+            metrics.update(
+                {f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash}
+            )
         if world_info.rank == 0:
             assert metric_logger is not None
             metric_logger.log(metrics)
@@ -139,13 +138,11 @@ def train(config: Config): |
         apply_ac_ckpt(model, num)
 
     elastic_device_mesh = ElasticDeviceMesh(
-        enable=config.diloco is not None,
-        live_recovery_rank_src=config.ckpt.live_recovery_rank_src
+        enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src
     )
 
     mp_policy = MixedPrecisionPolicy(
-        param_dtype=torch.bfloat16,
-        reduce_dtype=torch.float32 if config.train.reduce_fp32 else None
+        param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None
     )
 
     offload_policy = CPUOffloadPolicy(pin_memory=True) if config.train.fsdp_cpu_offload else None
@@ -366,9 +363,13 @@ def train(config: Config): |
         with record_function("Inner allreduce"):
             logger.debug("loss allreduce()")
             # Launch both allreduces at the same time to hide latency
-            loss_allreduce = dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True)
+            loss_allreduce = dist.all_reduce(
+                tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True
+            )
             if config.optim.z_loss:
-                z_loss_allreduce = dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True)
+                z_loss_allreduce = dist.all_reduce(
+                    tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True
+                )
 
             assert isinstance(loss_allreduce, torch.distributed.Work)
             loss_allreduce.wait()
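
The only functional change in this diff is the import path: on newer PyTorch releases (2.6 and later, as I understand it) the FSDP2 entry points fully_shard, MixedPrecisionPolicy, and CPUOffloadPolicy are exported from the public torch.distributed.fsdp namespace rather than the private torch.distributed._composable.fsdp module; the remaining hunks are formatter-only line wrapping. As a minimal sketch of how these entry points compose (the build_model constructor, its blocks attribute, and the process-group setup below are illustrative placeholders, not code from this PR):

    import os

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard

    # Assumes launch via torchrun, which sets RANK / LOCAL_RANK / WORLD_SIZE.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    model = build_model()  # hypothetical constructor returning an nn.Module with a .blocks list

    # bf16 parameters with fp32 gradient reduction, mirroring the policy in this diff.
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    offload_policy = CPUOffloadPolicy(pin_memory=True)

    # FSDP2 is applied per module: shard each transformer block, then the root model.
    for block in model.blocks:
        fully_shard(block, mp_policy=mp_policy, offload_policy=offload_policy)
    fully_shard(model, mp_policy=mp_policy, offload_policy=offload_policy)

The reformatted all_reduce calls keep their original behaviour: with async_op=True each call returns a torch.distributed.Work handle immediately, so the loss and z-loss reductions are both in flight before the first .wait(), which is what the "hide latency" comment refers to.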