
Commit 06306ce

Merge branch 'main' into hongbinl/split_wgrad_new
2 parents 5131080 + 61312d6 commit 06306ce

File tree: 5 files changed, +59 -77 lines


transformer_engine/pytorch/attention.py

Lines changed: 6 additions & 12 deletions
@@ -81,7 +81,7 @@
 from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
 from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log
 from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb
-from .cpu_offload import set_offloading_param
+from .cpu_offload import mark_activation_offload
 
 
 # Setup Attention Logging
@@ -4323,10 +4323,9 @@ def forward(
         from .cpu_offload import CPUOffloadEnabled
 
         if CPUOffloadEnabled:
-            tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
-            for tensor in tensor_list:
-                if tensor is not None:
-                    set_offloading_param(tensor, "activation_offloading", True)
+            mark_activation_offload(
+                query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv
+            )
 
         with self.attention_dropout_ctx():
             # | API | use cases
@@ -4729,13 +4728,8 @@ def forward(
             tensor_list = [q, k, v, out_save]
 
             qkv_layout = "sbhd_sbhd_sbhd"
-            for tensor in tensor_list:
-                if tensor is not None:
-                    set_offloading_param(tensor, "activation_offloading", True)
-
-            for tensor in aux_ctx_tensors:
-                if tensor is not None:
-                    set_offloading_param(tensor, "activation_offloading", True)
+            mark_activation_offload(*tensor_list)
+            mark_activation_offload(*aux_ctx_tensors)
 
         ctx.is_input_fp8 = is_input_fp8
         ctx.is_output_fp8 = is_output_fp8

transformer_engine/pytorch/cpu_offload.py

Lines changed: 43 additions & 30 deletions
@@ -16,18 +16,22 @@
 CPUOffloadEnabled = False
 
 
-def set_offloading_param(tensor, param_name, value):
+def mark_activation_offload(*tensors):
     """Set the type of the offloading needed for a tensor."""
-    assert param_name in ["weight_offloading", "activation_offloading"]
-    if tensor is None:
-        return
-    if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
-        setattr(tensor, param_name, value)
-    else:
-        data_tensors = tensor.get_data_tensors()
-        for tensor in data_tensors:
-            if tensor is not None:
-                setattr(tensor, param_name, value)
+    for tensor in tensors:
+        if tensor is None:
+            continue
+        if type(tensor) in [torch.Tensor, torch.nn.Parameter]:
+            tensor.activation_offloading = True
+        else:
+            data_tensors = tensor.get_data_tensors()
+            for tensor in data_tensors:
+                if tensor is not None:
+                    tensor.activation_offloading = True
+                    # This is a hack to force clear the tensor after it is offloaded.
+                    # It is needed, because .*TensorBase classes are saved in the ctx,
+                    # and they contain the reference to their data tensors.
+                    tensor.needs_force_clear = True
 
 
 def is_cpu_offload_enabled() -> bool:
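
For reference, a minimal sketch of how a call site uses the new helper (not part of this commit; the tensors are dummies and only the import path comes from the hunks above, so Transformer Engine must be importable for this to run):

```python
import torch
from transformer_engine.pytorch.cpu_offload import mark_activation_offload

q = torch.randn(2, 4, 8)
k = torch.randn(2, 4, 8)
v = torch.randn(2, 4, 8)
cu_seqlens_q = None  # None entries are simply skipped by the helper

# Old pattern (removed in this commit):
#     for t in (q, k, v, cu_seqlens_q):
#         if t is not None:
#             set_offloading_param(t, "activation_offloading", True)

# New pattern: one call marks every tensor for activation offloading.
mark_activation_offload(q, k, v, cu_seqlens_q)
assert q.activation_offloading and k.activation_offloading and v.activation_offloading
```

Quantized tensors (anything exposing get_data_tensors()) get the same flag on their underlying data tensors, plus needs_force_clear so the handler can drop their storage later.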
@@ -459,8 +463,15 @@ def synchronize_on_group_commit_forward(self, current_group):
             torch.cuda.current_stream().wait_stream(self.d2h_stream)
 
             # Time to free the activation memory after usage
-            for tensor_tag, _ in self.tensor_tag_to_buf.items():
+            for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items():
                 if tensor_tag[0] == self.offloaded_group_count:
+                    if hasattr(tensor_buf, "needs_force_clear"):
+                        # Need to clear activation tensor - sometimes references persist in the code.
+                        # This is the case for example with the Float8TensorBase class,
+                        # which is saved directly inside the ctx while its internal tensors are
+                        # saved inside save_for_backward.
+                        tensor_buf.data = torch.Tensor()
+                    # Release the pointer to the tensor
                     self.tensor_tag_to_buf[tensor_tag] = None
 
             # Time to offload the next group
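
The force-clear above relies on plain PyTorch behavior: reassigning a tensor's .data detaches the tensor object from its old storage, so the memory can be released even while another object (for example a Float8TensorBase saved in the ctx) still holds the tensor object itself. A standalone illustration of that mechanism, not code from the repository:

```python
import torch

activation = torch.empty(1024, 1024)  # stands in for an offloaded activation buffer
ctx_ref = activation                  # stands in for a *TensorBase kept alive in the autograd ctx

# Dropping only the handler's reference (tensor_tag_to_buf[tag] = None) would not
# free the storage here, because ctx_ref still points at the same tensor object.
activation.data = torch.Tensor()      # force-clear: swap in an empty storage

print(ctx_ref.shape)  # torch.Size([0]) -- the 1024x1024 buffer is now collectible
```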
@@ -538,7 +549,7 @@ def get_cpu_offload_context(
     num_layers: int = 1,
     model_layers: int = 1,
     offload_activations: bool = True,
-    offload_weights: bool = True,
+    offload_weights: bool = False,
 ):
     """
     This function returns the CPU Offload context and the synchronizer function that needs to be
@@ -570,28 +581,30 @@
 
     """
 
-    def tensor_need_offloading_checker_activations(tensor):
-        return hasattr(tensor, "activation_offloading")
-
-    # This includes the Gradient Accumulation Buffer
-    def tensor_need_offloading_checker_weights(tensor):
-        return hasattr(tensor, "weight_offloading")
-
-    def tensor_need_offloading_checker_all(tensor):
-        return hasattr(tensor, "activation_offloading") or hasattr(tensor, "weight_offloading")
-
-    if offload_activations and offload_weights:
-        tensor_need_offloading_checker = tensor_need_offloading_checker_all
-    elif offload_activations:
-        tensor_need_offloading_checker = tensor_need_offloading_checker_activations
-    elif offload_weights:
-        tensor_need_offloading_checker = tensor_need_offloading_checker_weights
-    else:
+    if not offload_weights and not offload_activations:
         raise ValueError(
             "CPU Offloading is enabled while it is not "
             "mentioned what to offload (weights/activations)"
         )
 
+    if offload_weights:
+        import warnings
+
+        warnings.warn(
+            "Offloading weights is deprecated. Using offload_weights=True does not have any"
+            " effect.",
+            DeprecationWarning,
+        )
+
+        # Weights offloading is deprecated but we maintain backward compatibility by doing nothing.
+        if not offload_activations:
+            return nullcontext(), lambda x: x
+
+    def tensor_need_offloading_checker_activations(tensor):
+        return hasattr(tensor, "activation_offloading")
+
+    tensor_need_offloading_checker = tensor_need_offloading_checker_activations
+
     cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
         num_offload_group=num_layers,
         num_model_group=model_layers,
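
A hedged usage sketch of the changed entry point: offload_weights now defaults to False, and passing True only emits a DeprecationWarning (if activations are not offloaded either, the function returns a no-op context instead). Only the keyword arguments visible in the diff are used below; any other parameters of get_cpu_offload_context are left at their defaults, the import path is assumed from the library's public API, and running this requires Transformer Engine on a GPU machine:

```python
from transformer_engine.pytorch import get_cpu_offload_context

# Returns the offload context and the synchronizer function described in the docstring.
offload_ctx, sync_fn = get_cpu_offload_context(
    num_layers=1,              # offload one group of layers at a time
    model_layers=12,           # total number of layers in the model
    offload_activations=True,  # activation offloading still works as before
    offload_weights=True,      # deprecated: warns and has no effect after this commit
)
```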

transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 2 additions & 10 deletions
@@ -64,7 +64,7 @@
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
-from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
+from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 
 from ..cpp_extensions import (
     general_gemm,
@@ -357,15 +357,7 @@ def forward(
             weightmat.update_usage(columnwise_usage=True)
 
         if cpu_offloading:
-            if fp8 and weightmat is not None:
-                set_offloading_param(weightmat, "weight_offloading", True)
-            set_offloading_param(ln_weight, "weight_offloading", True)
-            set_offloading_param(weight, "weight_offloading", True)
-
-            set_offloading_param(inputmat, "activation_offloading", True)
-            set_offloading_param(mu, "activation_offloading", True)
-            set_offloading_param(rsigma, "activation_offloading", True)
-            set_offloading_param(ln_out, "activation_offloading", True)
+            mark_activation_offload(inputmat, mu, rsigma, ln_out)
 
         # Scatter intermediate/activation tensors saved for the backward pass
         # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already

transformer_engine/pytorch/module/layernorm_mlp.py

Lines changed: 5 additions & 19 deletions
@@ -63,9 +63,9 @@
     Float8Tensor,
 )
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
-from ._common import apply_normalization, _fix_gathered_fp8_transpose, WeightGradStore
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
+from ._common import apply_normalization, _fix_gathered_fp8_transpose, WeightGradStore
+from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ..tensor.quantized_tensor import (
     QuantizedTensor,
     Quantizer,
@@ -475,23 +475,9 @@ def forward(
             clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
         else:
             if cpu_offloading:
-                if fp8 and fc1_weight_final is not None:
-                    set_offloading_param(fc1_weight_final, "weight_offloading", True)
-                if fp8 and fc2_weight_final is not None:
-                    set_offloading_param(fc2_weight_final, "weight_offloading", True)
-                set_offloading_param(ln_weight, "weight_offloading", True)
-                set_offloading_param(fc1_weight, "weight_offloading", True)
-                set_offloading_param(fc2_weight, "weight_offloading", True)
-                set_offloading_param(fc1_bias, "weight_offloading", True)
-
-                set_offloading_param(inputmat, "activation_offloading", True)
-                set_offloading_param(mu, "activation_offloading", True)
-                set_offloading_param(rsigma, "activation_offloading", True)
-                set_offloading_param(mu, "activation_offloading", True)
-                set_offloading_param(ln_out, "activation_offloading", True)
-                set_offloading_param(fc1_out, "activation_offloading", True)
-                set_offloading_param(fc1_out_without_bias, "activation_offloading", True)
-                set_offloading_param(act_out, "activation_offloading", True)
+                mark_activation_offload(
+                    inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
+                )
 
         # Scatter intermediate/activation tensors saved for the backward pass
         # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already

transformer_engine/pytorch/module/linear.py

Lines changed: 3 additions & 6 deletions
@@ -63,7 +63,7 @@
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
+from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...debug.pytorch.debug_state import TEDebugState
 from ...debug.pytorch.utils import any_feature_enabled
 
@@ -309,11 +309,8 @@ def forward(
         if isinstance(weightmat, QuantizedTensor):
             weightmat.update_usage(columnwise_usage=True)
 
-        if cpu_offloading:
-            set_offloading_param(weight, "weight_offloading", True)
-            set_offloading_param(weightmat, "weight_offloading", True)
-            if saved_inputmat is not None:
-                set_offloading_param(saved_inputmat, "activation_offloading", True)
+        if cpu_offloading and saved_inputmat is not None:
+            mark_activation_offload(saved_inputmat)
 
         # Scatter intermediate/activation tensors saved for the backward pass
         # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
