From 7a38ce43e1c8d0c025114f0b420f67be71e1a4cc Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Sat, 2 Aug 2025 19:30:36 -0700 Subject: [PATCH 1/5] Added multi-layout support for attention Signed-off-by: Selvaraj Anandaraj --- .../dot_product_attention/backends.py | 19 ++++++++++++++++++- transformer_engine/pytorch/cpu_offload.py | 9 +++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index afa1bae633..27a217c6d1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1111,7 +1111,24 @@ def forward( ctx.attn_scale = attn_scale ctx.dropout_p = dropout_p ctx.fast_zero_fill = fast_zero_fill - ctx.qkv_layout = qkv_layout + + if CPUOffloadedLayer and CPUOffloadEnabled: + reload_layout = "" + split_list = qkv_layout.split("_") + for split in split_list: + temp_layout = "" + rep_count = 1 + for s in split: + if s.isalpha(): + temp_layout = temp_layout + s + else: + rep_count = int(s) + for i in range(rep_count): + reload_layout = reload_layout + temp_layout + "_" + ctx.qkv_layout = reload_layout[:-1] + else: + ctx.qkv_layout = qkv_layout + ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type ctx.window_size = window_size diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 3fdf8b14fd..48e6257c26 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -16,6 +16,7 @@ __all__ = ["get_cpu_offload_context"] CPUOffloadEnabled = False +CPUOffloadedLayer = False def mark_activation_offload(*tensors): @@ -408,6 +409,11 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: tensor.clear() else: self.tensor_tag_to_buf[tensor_tag] = t + + # Needed to differentiate non offloaded layer's attention + # QKV layout of attention of non-offloaded layer needs + # to be modified while reloading + CPUOffloadedLayer = True else: tensor_tag = (-1, self.torch_tensor_count) self.torch_tensor_count += 1 @@ -528,6 +534,9 @@ def synchronize_on_group_commit_forward(self, current_group): # Increment the offload group count to keep track self.offloaded_group_count += 1 + if current_group == (self.num_offload_group - 1): + CPUOffloadedLayer = False + if not self.double_buffer_created: # Creating second copy of double buffer for tensors that are offloaded if current_group == (self.num_layers - 1): From f93f6466c902a82adf7cf5cecf10334f02d7c978 Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Sat, 2 Aug 2025 19:34:32 -0700 Subject: [PATCH 2/5] Comment/cleanup Signed-off-by: Selvaraj Anandaraj --- .../pytorch/attention/dot_product_attention/backends.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 27a217c6d1..4eb581f6e1 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1069,6 +1069,7 @@ def forward( from transformer_engine.pytorch.cpu_offload import ( CPUOffloadEnabled, + CPUOffloadedLayer, mark_activation_offload, ) @@ -1078,7 +1079,6 @@ def forward( else: tensor_list = [q, k, v, out_save] - qkv_layout = "sbhd_sbhd_sbhd" mark_activation_offload(*tensor_list) mark_activation_offload(*aux_ctx_tensors) @@ -1112,6 +1112,9 @@ def forward( ctx.dropout_p = dropout_p ctx.fast_zero_fill = fast_zero_fill + # If interleaved tensor is offloaded, reloaded tensor will be + # non-interleaved, so we need to modify the QKV layout + # for backward if CPUOffloadedLayer and CPUOffloadEnabled: reload_layout = "" split_list = qkv_layout.split("_") From 78434e4fe4b14f62b26bee87c7fba40a90aa978a Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Sat, 2 Aug 2025 19:41:01 -0700 Subject: [PATCH 3/5] Bug fix on import time Signed-off-by: Selvaraj Anandaraj --- .../pytorch/attention/dot_product_attention/backends.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 4eb581f6e1..fdaf7f334e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1069,7 +1069,6 @@ def forward( from transformer_engine.pytorch.cpu_offload import ( CPUOffloadEnabled, - CPUOffloadedLayer, mark_activation_offload, ) @@ -1112,6 +1111,10 @@ def forward( ctx.dropout_p = dropout_p ctx.fast_zero_fill = fast_zero_fill + from transformer_engine.pytorch.cpu_offload import ( + CPUOffloadedLayer, + ) + # If interleaved tensor is offloaded, reloaded tensor will be # non-interleaved, so we need to modify the QKV layout # for backward From a92e2ba59ff7785c80da44c6a6e81f34a0b2b4b5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 3 Aug 2025 05:39:35 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/attention/dot_product_attention/backends.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index fdaf7f334e..970ccaca0f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1116,7 +1116,7 @@ def forward( ) # If interleaved tensor is offloaded, reloaded tensor will be - # non-interleaved, so we need to modify the QKV layout + # non-interleaved, so we need to modify the QKV layout # for backward if CPUOffloadedLayer and CPUOffloadEnabled: reload_layout = "" @@ -1131,7 +1131,7 @@ def forward( rep_count = int(s) for i in range(rep_count): reload_layout = reload_layout + temp_layout + "_" - ctx.qkv_layout = reload_layout[:-1] + ctx.qkv_layout = reload_layout[:-1] else: ctx.qkv_layout = qkv_layout From 582360d4f8f72f9499c0effc9d5c93fbd399ebf8 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 27 Aug 2025 12:06:35 +0000 Subject: [PATCH 5/5] fix Signed-off-by: Pawel Gadzinski --- .../pytorch/attention/dot_product_attention/backends.py | 2 +- transformer_engine/pytorch/cpu_offload.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 970ccaca0f..c941b1d84a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1129,7 +1129,7 @@ def forward( temp_layout = temp_layout + s else: rep_count = int(s) - for i in range(rep_count): + for _ in range(rep_count): reload_layout = reload_layout + temp_layout + "_" ctx.qkv_layout = reload_layout[:-1] else: diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 48e6257c26..316b78d770 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -354,6 +354,7 @@ def __init__( self.h2d_stream = torch.cuda.Stream() def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + global CPUOffloadedLayer torch_stray_tensor = isinstance( tensor, @@ -423,6 +424,8 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: def tensor_pop(self, tensor_tag, **kwargs): """Tensor pop.""" + global CPUOffloadedLayer + assert tensor_tag in self.tensor_tag_to_state tensor = self.tensor_tag_to_state.pop(tensor_tag) @@ -486,6 +489,7 @@ def bulk_offload_group(self, group_to_offload): def synchronize_on_group_commit_forward(self, current_group): """Synchronize on group commit forward.""" + global CPUOffloadedLayer # For the first group, kickstart the offload after we have # the first compute completion