7 changes: 6 additions & 1 deletion README.md
@@ -81,7 +81,12 @@ mkdir -p datasets
uv run python scripts/subset_data.py --dataset_name PrimeIntellect/fineweb-edu --data_world_size 1 --data_rank 0 --max_shards 32
mv fineweb-edu/ datasets/fineweb-edu/
```

## Compatibility with AMD GPUs
When using AMD GPUs, install the dependencies with:
```bash
uv sync --extra rocm --extra all
```
Also, prefix the run commands with `uv run --extra rocm`; otherwise the uv environment falls back to the default configuration. See the example below.
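
For example, a minimal sketch of the quick-check run on AMD GPUs, reusing the invocation from `scripts/install/install.sh` (the environment variables and config path are taken from that script; only the `--extra rocm` flag is new):
```bash
# Same quick-check command as on CUDA, but routed through `uv run --extra rocm`
# so that the ROCm wheels are used instead of the default ones.
GLOO_SOCKET_IFNAME=lo GLOBAL_ADDR=localhost GLOBAL_RANK=0 GLOBAL_UNIQUE_ID=0 \
GLOBAL_WORLD_SIZE=1 GLOBAL_PORT=8989 \
uv run --extra rocm torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/diloco.toml
```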

### Quick Check

26 changes: 26 additions & 0 deletions configs/150M/MI250.toml
@@ -0,0 +1,26 @@
name_model = "150M"
project = "debug_150m_prime"
type_model = "llama2"

[train]
micro_bs = 64 # change this based on the GPU
reshard_after_forward = false
torch_profiler = false

[optim]
batch_size = 256
warmup_steps = 1000
total_steps = 88_000

[optim.optim]
lr = 4e-4


[diloco]
inner_steps = 500
compression = "uint8"


[ckpt]
path = "/ckpt/outputs_1b_diloco_node_1_compress"
interval = 44000
24 changes: 24 additions & 0 deletions configs/150M/MI300.toml
@@ -0,0 +1,24 @@
name_model = "150M"
project = "150m_prime_300"
type_model = "llama2"

[train]
micro_bs = 64 # change this based on the GPU
reshard_after_forward = false
torch_profiler = false

[optim]
batch_size = 512
warmup_steps = 1000
total_steps = 88_000

[optim.optim]
lr = 4e-4

[diloco]
inner_steps = 500
compression = "uint8"

[ckpt]
path = "/ckpt"
interval = 44000
19 changes: 19 additions & 0 deletions configs/1B/MI250.toml
@@ -0,0 +1,19 @@
name_model = "1B"
project = "uneven_configuration"
type_model = "llama2"

[train]
micro_bs = 32
reshard_after_forward = true

[optim]
batch_size = 512
warmup_steps = 1000
total_steps = 88_000

[optim.optim]
lr = 7e-4

[diloco]
inner_steps = 500
compression = "uint8"
34 changes: 31 additions & 3 deletions pyproject.toml
@@ -5,12 +5,12 @@ description = "ZeroBand is a production ready codebase for decentralized training"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"torch==2.5.1",
"torch==2.6.0; extra!='rocm'",
"numpy",
"setuptools",
"transformers>=4.44.2",
"datasets>=3.0.0",
"pydantic_config @ git+https://github.com/samsja/pydantic_config.git@b7becc3",
"pydantic_config[toml] @ git+https://github.com/samsja/pydantic_config.git@b7becc3",
"torchdata>=0.8.0",
"fsspec[gcs]>=2024.3.1",
"ninja",
@@ -25,7 +25,9 @@ dependencies = [
[project.optional-dependencies]

all = ["wandb","lm-eval"]

cuda=[ "torch==2.6.0"]
rocm = [ "torch==2.6.0;extra=='rocm'",
"pytorch-triton-rocm==3.2.0; sys_platform == 'linux' and extra=='rocm'"]

[build-system]
requires = ["hatchling"]
@@ -39,3 +41,29 @@ line-length = 120

[tool.uv]
dev-dependencies = ["ruff>=0.5.0", "pre-commit>=3.0.0","pytest>=7.0.0", "faker"]
conflicts = [
[
{ extra = "cuda" },
{ extra = "rocm" },
],
]

[tool.uv.sources]
torch = [
{ index = "pytorch-rocm" ,extra="rocm" },
{index = "pytorch-cuda",extra="cuda"}
]

pytorch-triton-rocm = [
{ index = "pytorch-rocm" ,extra="rocm"},
]

[[tool.uv.index]]
name = "pytorch-rocm"
url = "https://download.pytorch.org/whl/rocm6.2.4"
explicit = true

[[tool.uv.index]]
name = "pytorch-cuda"
url = "https://download.pytorch.org/whl/cu126"
explicit = true
29 changes: 26 additions & 3 deletions scripts/install/install.sh
@@ -10,13 +10,33 @@ log_info() {
echo -e "${GREEN}[INFO]${NC} $1"
}

detect_gpu_vendor() {
if command -v rocm-smi &> /dev/null; then
echo "rocm"
elif command -v nvidia-smi &> /dev/null; then
echo "nvidia"
else
echo "none"
fi
}

main() {
# Check if sudo is installed
if ! command -v sudo &> /dev/null; then
apt update
apt install sudo -y
fi

log_info "Detecting GPU vendor..."
GPU_VENDOR=$(detect_gpu_vendor)
log_info "Detected GPU vendor: $GPU_VENDOR"

if [ "$GPU_VENDOR" = "rocm" ]; then
EXTRA_ARGS="--extra rocm"
else
EXTRA_ARGS=""
fi

log_info "Updating apt..."
sudo apt update

@@ -47,17 +67,20 @@ main() {
source .venv/bin/activate

log_info "Installing dependencies..."
uv sync --extra all
uv sync $EXTRA_ARGS --extra all

log_info "Updating git submodules..."
git submodule update --init --recursive

log_info "Downloading data..."
uv run $EXTRA_ARGS python scripts/subset_data.py \
--dataset_name PrimeIntellect/fineweb-edu \
--data_world_size 1 --data_rank 0 --max_shards 128

mkdir -p datasets
uv run python scripts/subset_data.py --dataset_name PrimeIntellect/fineweb-edu --data_world_size 1 --data_rank 0 --max_shards 128
mv fineweb-edu/ datasets/fineweb-edu/

log_info "Installation completed! You can double check that everything is install correctly by running 'GLOO_SOCKET_IFNAME=lo GLOBAL_ADDR=localhost GLOBAL_RANK=0 GLOBAL_UNIQUE_ID=0 GLOBAL_WORLD_SIZE=1 GLOBAL_PORT=8989 uv run torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/diloco.toml'"
log_info "Installation completed! You can double check that everything is install correctly by running 'GLOO_SOCKET_IFNAME=lo GLOBAL_ADDR=localhost GLOBAL_RANK=0 GLOBAL_UNIQUE_ID=0 GLOBAL_WORLD_SIZE=1 GLOBAL_PORT=8989 uv run $EXTRA_ARGS torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/diloco.toml'"
}

main
31 changes: 15 additions & 16 deletions src/zeroband/models/norms.py
@@ -14,13 +14,26 @@
import torch
import torch.nn as nn

import torch.version
import triton
import triton.language as tl

from torch.distributed._tensor import Partial, Replicate, Shard
from torch.distributed._tensor.experimental import local_map


warp_config=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16)
]

rocm = torch.version.hip is not None
if not rocm:
warp_config.append(triton.Config({}, num_warps=32))

Review comment (PR author):

Explanation:
The kernel's setting of 32 warps per block is incompatible with AMD GPUs because of AMD's larger warp size compared to NVIDIA GPUs: AMD GPUs have 64 threads per warp (wavefront).

To resolve the issue, the number of warps configured for AMD GPUs should not exceed 16. Therefore, the config setting num_warps to 32 is only added for non-AMD GPUs.

This will be fixed in a future Triton release, but having this piece of code gives us backward compatibility as well.
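
A minimal sketch of the arithmetic behind this, for illustration only (the 1024-threads-per-block limit and the list-comprehension form are assumptions; the actual change in this PR simply appends the num_warps=32 config on non-ROCm builds):
```python
import torch
import triton

MAX_THREADS_PER_BLOCK = 1024  # typical hardware limit on both vendors (assumption)

# NVIDIA warps have 32 threads, AMD wavefronts have 64,
# so the largest usable num_warps differs between the two.
threads_per_warp = 64 if torch.version.hip is not None else 32
max_num_warps = MAX_THREADS_PER_BLOCK // threads_per_warp  # 16 on ROCm, 32 on CUDA

warp_config = [
    triton.Config({}, num_warps=w)
    for w in (1, 2, 4, 8, 16, 32)
    if w <= max_num_warps
]
```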

def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
"""
Builds the specified normalization layer based on the norm_type.
@@ -114,14 +127,7 @@ def reset_parameters(self):


@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
configs=warp_config,
key=["N"],
)
@triton.jit
@@ -162,14 +168,7 @@ def _rms_norm_fwd_kernel(


@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
configs=warp_config,
key=["N"],
)
@triton.jit
13 changes: 11 additions & 2 deletions src/zeroband/utils/__init__.py
@@ -30,9 +30,18 @@ def get_sharding_strategy(sharding_strategy: str) -> ShardingStrategy:
### code above inspired and copied from https://github.com/pytorch/torchtitan/blob/4b3f2e41a084bf79a8540068ed525539d1244edd/torchtitan/utils.py#L119


# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU
# hardcoded BF16 type peak flops for NVIDIA and AMD GPUs
def get_peak_flops(device_name: str) -> int:
if "A100" in device_name:
if "MI250" in device_name:
#data from https://rocm.docs.amd.com/en/latest/conceptual/gpu-arch/mi250.html
return 362e12
elif "MI300" in device_name:
#data from https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/data-sheets/amd-instinct-mi300x-data-sheet.pdf
return 1307e12
elif "MI355" in device_name:
#data from https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/product-briefs/amd-instinct-mi355x-gpu-brochure.pdf
return 2.5e15
elif "A100" in device_name:
# data from https://www.nvidia.com/en-us/data-center/a100/
return 312e12
elif "H100" in device_name:
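
For context, a hedged sketch of how a peak-FLOPS lookup like this is commonly used to report MFU (model FLOPS utilization). `get_peak_flops` is the function extended in this diff; the ~6 × params FLOPs-per-token estimate and the example numbers are assumptions, not code from this repository:
```python
import torch

from zeroband.utils import get_peak_flops  # function extended in this diff


def mfu(tokens_per_second: float, n_params: float, device_name: str | None = None) -> float:
    """Rough model-FLOPS-utilization estimate using the ~6 * n_params FLOPs/token rule."""
    if device_name is None:
        # torch.cuda also reports AMD device names on ROCm builds of PyTorch
        device_name = torch.cuda.get_device_name(0)
    achieved_flops = 6 * n_params * tokens_per_second  # forward + backward approximation
    return achieved_flops / get_peak_flops(device_name)


# Hypothetical example: a 1B-parameter model at 24k tokens/s on an MI250
# -> 6 * 1e9 * 24e3 / 362e12 ≈ 0.40 (40% MFU)
# print(mfu(24_000, 1e9, "AMD Instinct MI250"))
```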