
Commit a3547aa: update fsdp import
1 parent 9d9ea1f commit a3547aa
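
In short, the commit swaps PyTorch's private distributed import paths for the public ones that recent PyTorch releases expose. A minimal before/after sketch of the affected imports (module paths taken directly from the diffs below):

```python
# Old, private module paths (removed by this commit)
# from torch.distributed._tensor.api import DTensor
# from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy

# New, public module paths (added by this commit)
from torch.distributed.tensor import DTensor
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy
```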

File tree: 5 files changed, 17 additions and 16 deletions

src/zeroband/checkpoint.py
Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@
 from torch.distributed.checkpoint.stateful import Stateful
 import warnings
 import logging
-from torch.distributed._tensor.api import DTensor
+from torch.distributed.tensor import DTensor
 from zeroband.utils.state_dict_send_recv import (
     _get_sendable_state_dict,
     recv_state_dict,

src/zeroband/diloco.py
Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@
 from zeroband.utils.logging import get_logger
 from zeroband.config import DilocoConfig
 import torch.distributed as dist
-from torch.distributed._tensor.api import DTensor
+from torch.distributed.tensor import DTensor
 from functools import lru_cache



src/zeroband/train.py
Lines changed: 12 additions & 11 deletions

@@ -5,7 +5,7 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy # type: ignore
+from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy # type: ignore
 from torch.autograd.profiler import record_function
 
 from zeroband.checkpoint import CkptManager, TrainingProgress
@@ -69,10 +69,9 @@ def log_hash_training_state(
         logger.debug(f"outer diloco optimizer hash {id} : {outer_optimizer_hash}")
         logger.debug(f"outer diloco model hash {id} : {outer_model_hash}")
 
-        metrics.update({
-            f"outer_optimizer_hash_{id}": outer_optimizer_hash,
-            f"outer_model_hash_{id}": outer_model_hash
-        })
+        metrics.update(
+            {f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash}
+        )
     if world_info.rank == 0:
         assert metric_logger is not None
         metric_logger.log(metrics)
@@ -139,13 +138,11 @@ def train(config: Config):
         apply_ac_ckpt(model, num)
 
     elastic_device_mesh = ElasticDeviceMesh(
-        enable=config.diloco is not None,
-        live_recovery_rank_src=config.ckpt.live_recovery_rank_src
+        enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src
     )
 
     mp_policy = MixedPrecisionPolicy(
-        param_dtype=torch.bfloat16,
-        reduce_dtype=torch.float32 if config.train.reduce_fp32 else None
+        param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None
     )
 
     offload_policy = CPUOffloadPolicy(pin_memory=True) if config.train.fsdp_cpu_offload else None
@@ -366,9 +363,13 @@ def train(config: Config):
         with record_function("Inner allreduce"):
             logger.debug("loss allreduce()")
             # Launch both allreduces at the same time to hide latency
-            loss_allreduce = dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True)
+            loss_allreduce = dist.all_reduce(
+                tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True
+            )
             if config.optim.z_loss:
-                z_loss_allreduce = dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True)
+                z_loss_allreduce = dist.all_reduce(
+                    tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True
+                )
 
             assert isinstance(loss_allreduce, torch.distributed.Work)
             loss_allreduce.wait()
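
For context, a minimal sketch of how the newly imported public FSDP2 API pieces fit together, mirroring the policies configured in the diff above. The `shard_model` helper and its `reduce_fp32`/`cpu_offload` parameters are placeholders invented here, not code from this repository, and an initialized process group / device mesh is assumed:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy
from torch.distributed.tensor import DTensor


def shard_model(model: nn.Module, reduce_fp32: bool = True, cpu_offload: bool = False) -> nn.Module:
    # bf16 compute with optional fp32 gradient reduction, as in the diff above.
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32 if reduce_fp32 else None,
    )
    kwargs = {"mp_policy": mp_policy}
    if cpu_offload:
        # Keep sharded parameters in pinned host memory.
        kwargs["offload_policy"] = CPUOffloadPolicy(pin_memory=True)
    fully_shard(model, **kwargs)  # requires torch.distributed to be initialized
    # After fully_shard, parameters are DTensors, which is why the other files
    # touched by this commit import DTensor from the matching public module.
    assert all(isinstance(p, DTensor) for p in model.parameters())
    return model
```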

src/zeroband/utils/__init__.py
Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@
 import time
 import torch
 from torch.distributed.fsdp import ShardingStrategy
-from torch.distributed._tensor.api import DTensor
+from torch.distributed.tensor import DTensor
 from distributed_shampoo import DistributedShampoo
 
 
@@ -193,4 +193,4 @@ def __init__(self):
         self.pad_token_id = 2
 
     def __len__(self):
-        return self.vocab_size
+        return self.vocab_size

src/zeroband/utils/state_dict_send_recv.py
Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 import pickle
 import torch
 from torch.distributed import ProcessGroup
-from torch.distributed._tensor.api import DTensor
+from torch.distributed.tensor import DTensor
 
 
 def _object_to_tensor(obj):
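
As a side note, a hypothetical illustration of why a send/recv module needs the DTensor type at all: sharded values are commonly reduced to each rank's local shard before plain point-to-point transfer. The `make_sendable` and `send_tensors` helpers below are invented for illustration only and are not necessarily how `_get_sendable_state_dict` or `recv_state_dict` are implemented:

```python
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.distributed.tensor import DTensor


def make_sendable(state_dict: dict) -> dict:
    # Replace DTensor values with this rank's local shard so they can be
    # moved with ordinary point-to-point send/recv ops.
    return {
        key: value.to_local() if isinstance(value, DTensor) else value
        for key, value in state_dict.items()
    }


def send_tensors(state_dict: dict, dst: int, pg: ProcessGroup) -> None:
    # Send every tensor entry of an already-sendable state dict to rank `dst`.
    for value in make_sendable(state_dict).values():
        if isinstance(value, torch.Tensor):
            dist.send(value, dst=dst, group=pg)
```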
