-
Notifications
You must be signed in to change notification settings - Fork 93
Expand file tree
/
Copy pathall_reduce.py
More file actions
69 lines (49 loc) · 1.76 KB
/
all_reduce.py
File metadata and controls
69 lines (49 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from pydantic_config import BaseConfig, parse_argv
import torch
from torch.distributed import destroy_process_group, init_process_group, ReduceOp
import torch.utils.benchmark as benchmark
from zeroband.collectives import Compression, all_reduce
from zeroband.utils.world_info import get_world_info
from zeroband.utils.logger import get_logger
from enum import Enum
class TorchDtype(str, Enum):
    """String-valued enum naming the torch dtypes selectable from a config/CLI."""

    FLOAT32 = "float32"
    FLOAT16 = "float16"
    BFLOAT16 = "bfloat16"
    UINT8 = "uint8"


# Translate a TorchDtype member (or None, meaning "no dtype specified") into the
# actual torch.dtype object. Each enum value doubles as the name of the matching
# torch module attribute, so the table is derived instead of written out by hand.
TORCH_DTYPE_MAP = {None: None, **{member: getattr(torch, member.value) for member in TorchDtype}}
class Config(BaseConfig):
    """Benchmark settings, populated from CLI arguments via pydantic_config."""

    # Number of elements in the benchmarked tensor (1e7 float32 ≈ 40 MB payload).
    size_model: int = int(1e7)
    # Number of timed all-reduce iterations the mean latency is taken over.
    n_iters: int = 4
    # Compression scheme forwarded to zeroband.collectives.all_reduce.
    compression: Compression = Compression.NO
def main(config: Config):
    """Benchmark all-reduce latency/bandwidth over the current process group.

    Builds a random (1, size_model) float32 tensor, times ``config.n_iters``
    calls to the (optionally compressed) all-reduce with
    ``torch.utils.benchmark.Timer``, and logs the mean time per iteration and
    the estimated bandwidth in MB/s.

    Requires ``init_process_group`` to have been called by the caller.
    """
    # Fetch the logger here instead of relying on a module global: the original
    # read a `logger` name assigned only under `__main__`, so calling main()
    # from an import raised NameError.
    logger = get_logger()
    world_info = get_world_info()

    mat = torch.rand(1, config.size_model)

    logger.info(
        f"\n ======== Benchmark all reduce between {world_info.world_size} gpus over {world_info.nnodes} nodes =========\n"
    )

    t0 = benchmark.Timer(
        stmt="compressed_all_reduce(compression, mat, op=op)",
        globals={
            "compressed_all_reduce": all_reduce,
            "mat": mat,
            "compression": config.compression,
            "op": ReduceOp.SUM,
        },
    )

    measured_time = t0.timeit(config.n_iters).mean
    # Payload bytes per iteration / time, in MB/s. element_size() replaces the
    # hard-coded 4 (mat is float32, so the value is unchanged). NOTE(review):
    # this counts the tensor payload once, not actual ring all-reduce wire
    # traffic (~2x payload) — confirm which figure is intended.
    bandwidth = config.size_model * mat.element_size() / 1e6 / measured_time

    logger.info(f"Average time per iteration: {measured_time:.2f} seconds, Average bandwidth: {bandwidth:.4f} MB/s")
if __name__ == "__main__":
config = Config(**parse_argv())
torch.set_float32_matmul_precision("high")
init_process_group(backend="gloo")
logger = get_logger()
main(config)
destroy_process_group()