[Contrib][Openfold Triton] Use json instead of pickle #1900

Merged: 7 commits, May 15, 2025
36 changes: 20 additions & 16 deletions apex/contrib/openfold_triton/__init__.py
@@ -1,13 +1,14 @@
 # © 2023 NVIDIA CORPORATION & AFFILIATES

-import pickle
+import json
+import warnings
 from collections import OrderedDict
 from copy import deepcopy
 from io import BytesIO
 from typing import BinaryIO, Union

 import torch
-from triton.runtime.autotuner import Autotuner, Heuristics
+from triton.runtime.autotuner import Autotuner, Config, Heuristics
 from triton.runtime.jit import JITFunction

 from apex.contrib.openfold_triton._layer_norm_backward_kernels import (
@@ -58,23 +59,27 @@ def _get_tuneable_triton_func_name(f: Union[Autotuner, Heuristics, JITFunction])
 )


-def _save_triton_auto_tune_cache(f: BinaryIO, verbose: bool = False) -> None:
+def _save_triton_auto_tune_cache(strict: bool = True, verbose: bool = False) -> BytesIO:
     caches = OrderedDict()
     for func_name, func in _tuneable_triton_kernels.items():
         if len(func.cache) < 1:
-            raise ValueError(
-                f"Triton JIT kernel {func.__name__} didn't have tuning cache"
-            )
-        caches[func_name] = deepcopy(func.cache)
-    pickle.dump(caches, f)
+            msg = f"Triton JIT kernel {func_name} didn't have tuning cache"
+            if strict:
+                raise ValueError(msg)
+            else:
+                warnings.warn(msg)
+        else:
+            caches[func_name] = [(keys, vals.all_kwargs()) for keys, vals in zip(func.cache.keys(), func.cache.values())]
+    f = BytesIO(json.dumps(caches).encode('utf-8'))
     if verbose:
         print(f"Triton kernel auto-tuning caches written to {f}")
+    return f


 def _load_triton_auto_tune_cache(
     f: BinaryIO, strict: bool = True, verbose: bool = False
 ) -> None:
-    caches = pickle.load(f)
+    caches = json.load(f)
     if strict:
         loaded_func_name = set(caches.keys())
         tuneable_func_name = set(_tuneable_triton_kernels.keys())
@@ -84,23 +89,23 @@ def _load_triton_auto_tune_cache(
                 f"Missing kernel caches: {tuneable_func_name - loaded_func_name}\n"
                 f"Unexpected kernel caches: {loaded_func_name - tuneable_func_name}"
             )
-    for func_name, cache in caches.items():
+    for func_name, func_cache in caches.items():
         if func_name not in _tuneable_triton_kernels:
             raise ValueError(
                 f"{func_name} from {f} doesn't match any tuneable Triton kernels"
             )
-        _tuneable_triton_kernels[func_name].cache = cache
+        for key, val in func_cache:
+            _tuneable_triton_kernels[func_name].cache[tuple(key)] = Config(val)
     if verbose:
         print(f"Triton kernel auto-tuning caches loaded from {f}")


-def sync_triton_auto_tune_cache_across_gpus() -> None:
+def sync_triton_auto_tune_cache_across_gpus(strict: bool = True, verbose: bool = False) -> None:
     if not torch.distributed.is_initialized():
         return
     if torch.distributed.get_rank() == 0:
         print("Broadcasting Triton auto-tuning cache from rank 0 to other ranks...")
-        cache = BytesIO()
-        _save_triton_auto_tune_cache(cache)
+        cache = _save_triton_auto_tune_cache(strict=strict, verbose=verbose)
         cache.seek(0)
         cache_list = [
             cache,
@@ -113,6 +118,5 @@ def sync_triton_auto_tune_cache_across_gpus() -> None:
             None,
         ]
     torch.distributed.broadcast_object_list(cache_list)
-    cache = cache_list[0]
-    _load_triton_auto_tune_cache(cache)
+    _load_triton_auto_tune_cache(cache_list[0], strict=strict, verbose=verbose)
     print("Succeed!")
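For reference, a minimal round-trip sketch of the new JSON-based serialization, using the two private helpers changed above. The file path is illustrative, and the sketch assumes the kernels have already been auto-tuned in the current process so their caches are non-empty (with strict=False an empty cache only triggers a warning):

from apex.contrib.openfold_triton import (
    _load_triton_auto_tune_cache,
    _save_triton_auto_tune_cache,
)

# Serialize the auto-tuning caches; the updated helper returns a BytesIO holding JSON.
buf = _save_triton_auto_tune_cache(strict=False, verbose=True)
with open("openfold_triton_autotune.json", "wb") as fh:  # illustrative path
    fh.write(buf.getvalue())

# Later, or in a fresh process, restore the caches from the JSON file.
with open("openfold_triton_autotune.json", "rb") as fh:
    _load_triton_auto_tune_cache(fh, strict=False, verbose=True)

Because each cache entry is reduced to a (key, kwargs-dict) pair via all_kwargs(), the saved file is human-readable and loading it does not unpickle arbitrary objects; on load, Config(val) rebuilds the config objects as shown in the diff.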
97 additions & 0 deletions: new test file exercising sync_triton_auto_tune_cache_across_gpus
@@ -0,0 +1,97 @@
import os

import torch
import torch.distributed as dist
from torch.testing._internal import common_utils
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests

from apex.contrib.openfold_triton import (
    LayerNormSmallShapeOptImpl,
    sync_triton_auto_tune_cache_across_gpus,
    _tuneable_triton_kernels,
)

class SyncTritonAutoTuneCacheTest(MultiProcessTestCase):
    device_type = "cuda"

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self) -> None:
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        super().tearDown()

    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 2)

    @property
    def init_method(self):
        return f"{common_utils.FILE_SCHEMA}{self.file_name}"

    @property
    def destroy_pg_upon_exit(self) -> bool:
        return True

    def _create_process_group_nccl(self):
        def maybe_export(env, val):
            # Only set the variable if it is not already defined in the environment.
            if not isinstance(env, str):
                raise ValueError(f"env is expected to be str, but got {type(env)}")
            if not isinstance(val, str):
                raise ValueError(f"val is expected to be str, but got {type(val)}")
            if os.getenv(env) is None:
                os.environ[env] = val

        maybe_export("MASTER_PORT", "29500")
        maybe_export("MASTER_ADDR", "localhost")

        # Create an NCCL process group for the participating ranks.
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
        )
        pg = dist.distributed_c10d._get_default_group()
        return pg


    @requires_nccl()
    @skip_if_lt_x_gpu(1)
    def test_sync_triton_auto_tune_cache_across_gpus(self):
        pg = self._create_process_group_nccl()
        device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
        torch.cuda.set_device(device)

        # Only rank 0 runs the kernel, so only its auto-tuning cache is populated
        # before the broadcast.
        if self.rank == 0:
            eps = 1e-5
            normalized_shape = (128, 64,)

            weight = torch.ones(normalized_shape, device=device, requires_grad=True)
            bias = torch.zeros(normalized_shape, device=device, requires_grad=True)

            x = torch.randn((2, 2,) + normalized_shape, device=device)
            y = LayerNormSmallShapeOptImpl.apply(
                x, normalized_shape, weight, bias, eps
            )
            l = torch.sum(y)
            l.backward()

        sync_triton_auto_tune_cache_across_gpus(strict=False, verbose=True)

        caches_synced = 0
        for func_name, func in _tuneable_triton_kernels.items():
            if len(func.cache) > 0:
                caches_synced += 1
                print(f"caches were synchronized for {func_name} at rank = {self.rank}:", func.cache)

        self.assertTrue(caches_synced > 0)


if __name__ == '__main__':
    run_tests()
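The synchronization this test exercises is the standard rank-0 object-broadcast pattern from torch.distributed: rank 0 supplies the serialized cache, the other ranks pass a placeholder, and broadcast_object_list fills the placeholder in place. A minimal standalone sketch of that pattern follows; the helper name is ours, and it assumes a process group has already been initialized:

import torch.distributed as dist


def broadcast_bytes_from_rank0(payload: bytes) -> bytes:
    # Rank 0 contributes the real payload; the argument is ignored on other ranks.
    obj_list = [payload if dist.get_rank() == 0 else None]
    # broadcast_object_list overwrites obj_list in place on the receiving ranks.
    dist.broadcast_object_list(obj_list, src=0)
    return obj_list[0]

In sync_triton_auto_tune_cache_across_gpus the broadcast object is the BytesIO produced by _save_triton_auto_tune_cache, and every rank, including rank 0, then passes cache_list[0] to _load_triton_auto_tune_cache.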