2 changes: 1 addition & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ description = "ZeroBand is a production ready codebase for decentralized trainin
 readme = "README.md"
 requires-python = ">=3.10"
 dependencies = [
-    "torch==2.5.1",
+    "torch==2.6.0",
     "numpy",
     "setuptools",
     "transformers>=4.44.2",
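The torch requirement is an exact pin, so the bump from 2.5.1 to 2.6.0 makes older environments fail dependency resolution rather than run against a stale wheel. As an illustrative aside (not part of the PR), a runtime guard along these lines can surface a mismatched environment early; it assumes only `torch.__version__` and the `packaging` helper that ships with pip:

```python
# Sketch: fail fast if the installed torch does not match the new exact pin.
# Assumes only torch.__version__ and the `packaging` library; not part of the PR.
from packaging.version import Version

import torch

PINNED = Version("2.6.0")
installed = Version(torch.__version__.split("+")[0])  # strip local tags like "+cu124"

if installed != PINNED:
    raise RuntimeError(f"zeroband pins torch=={PINNED}, but found {torch.__version__}")
```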
2 changes: 1 addition & 1 deletion src/zeroband/checkpoint.py
@@ -28,7 +28,7 @@
 from torch.distributed.checkpoint.stateful import Stateful
 import warnings
 import logging
-from torch.distributed._tensor.api import DTensor
+from torch.distributed.tensor import DTensor
 from zeroband.utils.state_dict_send_recv import (
     _get_sendable_state_dict,
     recv_state_dict,
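torch 2.6 graduates DTensor to the public `torch.distributed.tensor` namespace, so the private `torch.distributed._tensor.api` import is replaced here and, identically, in diloco.py, utils/__init__.py, and state_dict_send_recv.py below. A minimal sketch of the public API, assuming a CPU/gloo setup launched with torchrun and illustrative shapes; only the import path is what the PR actually touches:

```python
# Sketch of the public DTensor API behind the new import path (torch >= 2.6).
# Launch with torchrun (e.g. 2 processes); mesh size and shapes are illustrative.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# Shard a full tensor along dim 0; each rank ends up holding one slice.
full = torch.arange(16, dtype=torch.float32).reshape(4, 4)
dt = distribute_tensor(full, mesh, placements=[Shard(0)])

assert isinstance(dt, DTensor)
print(dt.to_local().shape)  # per-rank local shard

dist.destroy_process_group()
```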
2 changes: 1 addition & 1 deletion src/zeroband/diloco.py
@@ -8,7 +8,7 @@
 from zeroband.utils.logger import get_logger
 from zeroband.config import DilocoConfig
 import torch.distributed as dist
-from torch.distributed._tensor.api import DTensor
+from torch.distributed.tensor import DTensor
 from functools import lru_cache
 
 
2 changes: 1 addition & 1 deletion src/zeroband/models/llama/model.py
@@ -24,7 +24,7 @@
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask, _DEFAULT_SPARSE_BLOCK_SIZE
 from torch.nn.attention import SDPBackend, sdpa_kernel
 
-_flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
+_flex_attention_compiled = torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")
 
 
 # copied from https://github.com/pytorch/torchtune/blob/f2bd4bc25b24587aef40f486087412b9da8f1d94/torchtune/modules/attention_utils.py#L27
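The model change swaps the compile mode for the pre-compiled flex_attention kernel: "max-autotune-no-cudagraphs" lets Inductor autotune the generated kernels while skipping CUDA-graph capture. A minimal sketch of the same call in isolation, assuming a CUDA device and illustrative shapes (batch 1, 8 heads, 128 tokens, head dim 64):

```python
# Sketch: compile flex_attention with the mode the PR switches to.
# Requires a CUDA device; tensor shapes are illustrative.
import torch
from torch.nn.attention.flex_attention import flex_attention

compiled_flex = torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")

q, k, v = (torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3))

# The first call pays the autotuning cost; later calls with the same shapes reuse the tuned kernel.
out = compiled_flex(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```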
23 changes: 12 additions & 11 deletions src/zeroband/train.py
@@ -5,7 +5,7 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy  # type: ignore
+from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy  # type: ignore
 from torch.autograd.profiler import record_function
 
 from zeroband.checkpoint import CkptManager, TrainingProgress
@@ -70,10 +70,9 @@ def log_hash_training_state(
         logger.debug(f"outer diloco optimizer hash {id} : {outer_optimizer_hash}")
         logger.debug(f"outer diloco model hash {id} : {outer_model_hash}")
 
-        metrics.update({
-            f"outer_optimizer_hash_{id}": outer_optimizer_hash,
-            f"outer_model_hash_{id}": outer_model_hash
-        })
+        metrics.update(
+            {f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash}
+        )
     if world_info.rank == 0:
         assert metric_logger is not None
         metric_logger.log(metrics)
@@ -142,13 +141,11 @@ def train(config: Config):
         apply_ac_ckpt(model, num)
 
     elastic_device_mesh = ElasticDeviceMesh(
-        enable=config.diloco is not None,
-        live_recovery_rank_src=config.ckpt.live_recovery_rank_src
+        enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src
     )
 
     mp_policy = MixedPrecisionPolicy(
-        param_dtype=torch.bfloat16,
-        reduce_dtype=torch.float32 if config.train.reduce_fp32 else None
+        param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None
    )
 
     offload_policy = CPUOffloadPolicy(pin_memory=True) if config.train.fsdp_cpu_offload else None
@@ -365,9 +362,13 @@ def train(config: Config):
 
         with sw.record_block("Loss allreduce()"):
             # Launch both allreduces at the same time to hide latency
-            loss_allreduce = dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True)
+            loss_allreduce = dist.all_reduce(
+                tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True
+            )
             if config.optim.z_loss:
-                z_loss_allreduce = dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True)
+                z_loss_allreduce = dist.all_reduce(
+                    tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True
+                )
 
             assert isinstance(loss_allreduce, torch.distributed.Work)
             loss_allreduce.wait()
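Apart from the formatter-style reflows in the later hunks, the behavioural change in train.py mirrors the DTensor move: torch 2.6 exposes the FSDP2 entry points `fully_shard`, `MixedPrecisionPolicy`, and `CPUOffloadPolicy` from the public `torch.distributed.fsdp` package rather than `torch.distributed._composable.fsdp`. A minimal sketch of that public API under assumed conditions (toy two-layer model, single 1-D mesh, NCCL backend); this is not the ZeroBand wrapping code:

```python
# Sketch: FSDP2 wrapping via the public torch.distributed.fsdp API (torch >= 2.6).
# Launch with torchrun on GPUs; the toy model and shapes are illustrative.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()

# Parameters are communicated in bf16, gradient reduction stays in fp32,
# matching the mp_policy built in train.py.
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

# Shard the parameterized submodules first, then the root, as FSDP2 expects.
# An offload_policy=CPUOffloadPolicy(pin_memory=True) could also be passed when CPU offload is wanted.
for submodule in (model[0], model[2]):
    fully_shard(submodule, mesh=mesh, mp_policy=mp_policy)
fully_shard(model, mesh=mesh, mp_policy=mp_policy)

out = model(torch.randn(8, 1024, device="cuda"))
print(out.shape)

dist.destroy_process_group()
```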
4 changes: 2 additions & 2 deletions src/zeroband/utils/__init__.py
@@ -3,7 +3,7 @@
 import time
 import torch
 from torch.distributed.fsdp import ShardingStrategy
-from torch.distributed._tensor.api import DTensor
+from torch.distributed.tensor import DTensor
 from distributed_shampoo import DistributedShampoo
 
 
@@ -193,4 +193,4 @@ def __init__(self):
         self.pad_token_id = 2
 
     def __len__(self):
-        return self.vocab_size
\ No newline at end of file
+        return self.vocab_size
2 changes: 1 addition & 1 deletion src/zeroband/utils/state_dict_send_recv.py
@@ -2,7 +2,7 @@
 import pickle
 import torch
 from torch.distributed import ProcessGroup
-from torch.distributed._tensor.api import DTensor
+from torch.distributed.tensor import DTensor
 
 
 def _object_to_tensor(obj):