diff --git a/apex/contrib/openfold_triton/__init__.py b/apex/contrib/openfold_triton/__init__.py index 87a01236a..0edd803a2 100644 --- a/apex/contrib/openfold_triton/__init__.py +++ b/apex/contrib/openfold_triton/__init__.py @@ -1,13 +1,14 @@ # © 2023 NVIDIA CORPORATION & AFFILIATES -import pickle +import json +import warnings from collections import OrderedDict from copy import deepcopy from io import BytesIO from typing import BinaryIO, Union import torch -from triton.runtime.autotuner import Autotuner, Heuristics +from triton.runtime.autotuner import Autotuner, Config, Heuristics from triton.runtime.jit import JITFunction from apex.contrib.openfold_triton._layer_norm_backward_kernels import ( @@ -58,23 +59,27 @@ def _get_tuneable_triton_func_name(f: Union[Autotuner, Heuristics, JITFunction]) ) -def _save_triton_auto_tune_cache(f: BinaryIO, verbose: bool = False) -> None: +def _save_triton_auto_tune_cache(strict: bool = True, verbose: bool = False) -> BytesIO: caches = OrderedDict() for func_name, func in _tuneable_triton_kernels.items(): if len(func.cache) < 1: - raise ValueError( - f"Triton JIT kernel {func.__name__} didn't have tuning cache" - ) - caches[func_name] = deepcopy(func.cache) - pickle.dump(caches, f) + msg = f"Triton JIT kernel {func_name} didn't have tuning cache" + if strict: + raise ValueError(msg) + else: + warnings.warn(msg) + else: + caches[func_name] = [(keys, vals.all_kwargs()) for keys, vals in zip(func.cache.keys(), func.cache.values())] + f = BytesIO(json.dumps(caches).encode('utf-8')) if verbose: print(f"Triton kernel auto-tuning caches written to {f}") + return f def _load_triton_auto_tune_cache( f: BinaryIO, strict: bool = True, verbose: bool = False ) -> None: - caches = pickle.load(f) + caches = json.load(f) if strict: loaded_func_name = set(caches.keys()) tuneable_func_name = set(_tuneable_triton_kernels.keys()) @@ -84,23 +89,23 @@ def _load_triton_auto_tune_cache( f"Missing kernel caches: {tuneable_func_name - loaded_func_name}\n" f"Unexpected kernel caches: {loaded_func_name - tuneable_func_name}" ) - for func_name, cache in caches.items(): + for func_name, func_cache in caches.items(): if func_name not in _tuneable_triton_kernels: raise ValueError( f"{func_name} from {f} doesn't match any tuneable Triton kernels" ) - _tuneable_triton_kernels[func_name].cache = cache + for key, val in func_cache: + _tuneable_triton_kernels[func_name].cache[tuple(key)] = Config(val) if verbose: print(f"Triton kernel auto-tuning caches loaded from {f}") -def sync_triton_auto_tune_cache_across_gpus() -> None: +def sync_triton_auto_tune_cache_across_gpus(strict: bool = True, verbose: bool = False) -> None: if not torch.distributed.is_initialized(): return if torch.distributed.get_rank() == 0: print("Broadcasting Triton auto-tuning cache from rank 0 to other ranks...") - cache = BytesIO() - _save_triton_auto_tune_cache(cache) + cache = _save_triton_auto_tune_cache(strict=strict, verbose=verbose) cache.seek(0) cache_list = [ cache, @@ -113,6 +118,5 @@ def sync_triton_auto_tune_cache_across_gpus() -> None: None, ] torch.distributed.broadcast_object_list(cache_list) - cache = cache_list[0] - _load_triton_auto_tune_cache(cache) + _load_triton_auto_tune_cache(cache_list[0], strict=strict, verbose=verbose) print("Succeed!") diff --git a/apex/contrib/test/openfold_triton/test_sync_triton_auto_tune_cache_across_gpus.py b/apex/contrib/test/openfold_triton/test_sync_triton_auto_tune_cache_across_gpus.py new file mode 100644 index 000000000..94a2ca23e --- /dev/null +++ b/apex/contrib/test/openfold_triton/test_sync_triton_auto_tune_cache_across_gpus.py @@ -0,0 +1,97 @@ +import os +import torch +import torch.distributed as dist +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + requires_nccl, + skip_if_lt_x_gpu, + run_tests, +) +from apex.contrib.openfold_triton import ( + LayerNormSmallShapeOptImpl, + sync_triton_auto_tune_cache_across_gpus, + _tuneable_triton_kernels, +) + +class SyncTritonAutoTuneCacheTest(MultiProcessTestCase): + device_type = "cuda" + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + def tearDown(self) -> None: + torch.cuda.synchronize() + torch.cuda.empty_cache() + super().tearDown() + + @property + def world_size(self) -> int: + return min(torch.cuda.device_count(), 2) + + @property + def init_method(self): + return f"{common_utils.FILE_SCHEMA}{self.file_name}" + + @property + def destroy_pg_upon_exit(self) -> bool: + return True + + def _create_process_group_nccl(self): + def maybe_export(env, val): + if not type(env) == str: + raise ValueError(f"Type of type of env is expected to be str, but got {type(env)}") + if not type(val) == str: + raise ValueError(f"Type of type of val is expected to be str, but got {type(val)}") + if os.getenv(env) is None: + os.environ[env] = val + + maybe_export("MASTER_PORT", "29500") + maybe_export("MASTER_ADDR", "localhost") + + # create nccl processgroup for two ranks + dist.init_process_group( + "nccl", + world_size=self.world_size, + rank=self.rank, + ) + pg = dist.distributed_c10d._get_default_group() + return pg + + + @requires_nccl() + @skip_if_lt_x_gpu(1) + def test_sync_triton_auto_tune_cache_across_gpus(self): + pg = self._create_process_group_nccl() + device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}") + torch.cuda.set_device(device) + + if self.rank == 0: + eps = 1e-5 + normalized_shape = (128, 64,) + + weight = torch.ones(normalized_shape, device=device, requires_grad=True) + bias= torch.zeros(normalized_shape, device=device, requires_grad=True) + + x = torch.randn((2, 2,) + normalized_shape, device=device) + y = LayerNormSmallShapeOptImpl.apply( + x, normalized_shape, weight, bias, eps + ) + l = torch.sum(y) + l.backward() + + sync_triton_auto_tune_cache_across_gpus(strict=False, verbose=True) + + caches_synced = 0 + for func_name, func in _tuneable_triton_kernels.items(): + if len(func.cache) > 0: + caches_synced = caches_synced + 1 + print(f"caches were synchronized for {func_name} at rank = {self.rank}:", func.cache) + + self.assertTrue(caches_synced > 0) + + +if __name__ == '__main__': + run_tests()