
Commit 4105934

PoC of the changes
Signed-off-by: Przemek Tredak <[email protected]>
1 parent 3ff0b8d commit 4105934

File tree: 7 files changed (+378, -348 lines)

tests/pytorch/attention/test_attention.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -2751,7 +2751,7 @@ def forward(
         cu_seqlens,
         max_s,
     ) -> torch.Tensor:
-        with self.prepare_forward(inp, num_gemms=3) as inp:
+        with self.prepare_forward_ctx(inp, num_gemms=3) as inp:
             out = _custom_mha_fp8.apply(
                 inp,
                 self.qkv_weight,
```

transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1000,7 +1000,7 @@ def forward(
             cases. It is ignored for other backends and when context parallelism is enabled.
         """
 
-        with self.prepare_forward(
+        with self.prepare_forward_ctx(
             query_layer,
             num_gemms=3,
             allow_non_contiguous=True,
```

transformer_engine/pytorch/module/base.py

Lines changed: 62 additions & 37 deletions

```diff
@@ -49,6 +49,7 @@
     is_non_tn_fp8_gemm_supported,
     torch_get_autocast_gpu_dtype,
     get_nvtx_range_context,
+    _nvtx_enabled,
 )
 from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
 from ...common.recipe import DelayedScaling, Recipe
@@ -640,16 +641,20 @@ def __init__(self) -> None:
         "fp8_parameters",
     }
 
+    def fast_set_attr(self, name: str, value: Any) -> None:
+        self.__dict__[name] = value
+
     def __setattr__(self, name: str, value: Any) -> None:
         if name in TransformerEngineBaseModule._fast_setattr_names:
             # torch.nn.Module has a custom __setattr__ that handles
             # modules, parameters, and buffers. This is unnecessary
             # overhead when setting plain attrs.
-            self.__dict__[name] = value
+            self.fast_set_attr(name, value)
         else:
             # Default case
             super().__setattr__(name, value)
 
+
     def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
         """
         Delayed scaling only.
```
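
The new fast_set_attr helper exposes the existing __dict__ shortcut as a reusable method, so hot paths later in this commit can bypass torch.nn.Module.__setattr__, which handles module, parameter, and buffer registration on every attribute assignment and is pure overhead for plain attributes. A minimal standalone sketch of the pattern, using an illustrative toy class rather than the real TransformerEngineBaseModule:

```python
import torch


class FastAttrModule(torch.nn.Module):
    """Toy module illustrating the __setattr__ bypass used in this commit."""

    # Plain Python attributes that never hold Parameters, buffers, or submodules.
    _fast_setattr_names = {"activation_dtype", "fp8", "fp8_calibration"}

    def fast_set_attr(self, name, value):
        # Write straight into the instance dict, skipping
        # torch.nn.Module.__setattr__'s registry bookkeeping.
        self.__dict__[name] = value

    def __setattr__(self, name, value):
        if name in FastAttrModule._fast_setattr_names:
            self.fast_set_attr(name, value)
        else:
            super().__setattr__(name, value)


m = FastAttrModule()
m.fp8 = True                      # fast path: plain dict write
m.linear = torch.nn.Linear(4, 4)  # normal path: registered as a submodule
print(m.fp8, dict(m.named_modules()).keys())
```

Calling the helper directly, as the hunks below do, also skips the per-assignment membership check inside __setattr__ itself.
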
```diff
@@ -926,7 +931,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None:
         """Get activation data type for AMP."""
         # Native AMP (`torch.autocast`) gets highest priority
         if torch.is_autocast_enabled():
-            self.activation_dtype = torch_get_autocast_gpu_dtype()
+            self.fast_set_attr("activation_dtype", torch_get_autocast_gpu_dtype())
             return
 
         # All checks after this have already been performed once, thus skip
@@ -941,7 +946,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None:
                     "Data types for parameters must match when outside of autocasted region. "
                     f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
                 )
-        self.activation_dtype = dtype
+        self.fast_set_attr("activation_dtype", dtype)
 
     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
         """
@@ -970,48 +975,54 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
     # assume FP8 execution.
     def init_fp8_metadata(self, num_gemms: int = 1) -> None:
         """Initialize fp8 related metadata and tensors during fprop."""
-        _original_recipe = self.fp8_meta.get("recipe", None)
+        meta = self.fp8_meta
 
-        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
-        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
-        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
-        fp8_enabled = self.fp8 or self.fp8_calibration
-        self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
+        fp8 = FP8GlobalStateManager.is_fp8_enabled()
+        fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
+        fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
+        self.fast_set_attr("fp8_parameters", fp8_parameters)
+        self.fast_set_attr("fp8", fp8)
+        self.fast_set_attr("fp8_calibration", fp8_calibration)
+        fp8_enabled = fp8 or fp8_calibration
+        meta["fp8_checkpoint"] = fp8_enabled
 
-        if self.fp8_parameters or fp8_enabled:
+        _original_recipe = None
+
+        if fp8_parameters or fp8_enabled:
+            _original_recipe = meta.get("recipe", None)
             if (
                 self.fp8_initialized
-                and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]
+                and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe
             ):
                 # FP8 init has already been run and recipe is the same, don't do anything.
                 return
-            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
+            meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
         else:
             # If fp8 isn't enabled, turn off and return.
-            self.fp8_initialized = False
+            self.fast_set_attr("fp8_initialized", False)
             return
 
-        if self.fp8_parameters and not self.fp8_initialized:
-            self.fp8_meta["num_gemms"] = num_gemms
-            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
+        if fp8_parameters and not self.fp8_initialized:
+            meta["num_gemms"] = num_gemms
+            self.init_fp8_meta_tensors(meta["recipe"])
 
         if fp8_enabled:
             # Set FP8 and other FP8 metadata
-            self.fp8_meta["num_gemms"] = num_gemms
-            self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
+            meta["num_gemms"] = num_gemms
+            meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
 
             # Set FP8_MAX per tensor according to recipe
-            if hasattr(self.fp8_meta["recipe"], "fp8_format"):
-                self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
-                self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
+            if hasattr(meta["recipe"], "fp8_format"):
+                meta["fp8_max_fwd"] = meta["recipe"].fp8_format.value.max_fwd
+                meta["fp8_max_bwd"] = meta["recipe"].fp8_format.value.max_bwd
 
             # Allocate scales and amaxes
-            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
+            self.init_fp8_meta_tensors(meta["recipe"])
             self.fp8_initialized = True
 
-            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
+            meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
 
-        _current_recipe = self.fp8_meta["recipe"]
+        _current_recipe = meta["recipe"]
         if _original_recipe is not None and not (
             issubclass(_current_recipe.__class__, _original_recipe.__class__)
             or issubclass(_original_recipe.__class__, _current_recipe.__class__)
@@ -1024,22 +1035,17 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
             # Clear cached workspaces as they were created with the old recipe/quantizer type
             self._fp8_workspaces.clear()
 
-    @contextmanager
     def prepare_forward(
         self,
         inp: torch.Tensor,
         num_gemms: int = 1,
         allow_non_contiguous: bool = False,
         allow_different_data_and_param_types: bool = False,
-    ) -> Generator[torch.Tensor, None, None]:
-        """Checks and prep for FWD.
-        The context manager is needed because there isn't a way for a module to know
-        if it's the last FP8 module in the forward autocast. It is useful
-        to setup the forward aggregated amax reduction for every module
-        just in case. The autocast exit will pick up the most recent one.
+    ) -> torch.Tensor:
+        """Checks and prepare for FWD execution.
         """
-        self.allow_different_data_and_param_types = allow_different_data_and_param_types
-        self.forwarded_at_least_once = True
+        self.fast_set_attr("allow_different_data_and_param_types", allow_different_data_and_param_types)
+        self.fast_set_attr("forwarded_at_least_once", True)
 
         # Activation recomputation is used and this is the second forward phase.
         if self.fp8 and in_fp8_activation_recompute_phase():
@@ -1070,13 +1076,32 @@ def prepare_forward(
         if self.training and is_fp8_activation_recompute_enabled():
             FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
 
-        with get_nvtx_range_context(self.__class__.__name__ + " forward"):
-            if not allow_non_contiguous and not inp.is_contiguous():
-                inp = inp.contiguous()
-            yield inp
+        # with get_nvtx_range_context(self.__class__.__name__ + " forward"):
+        if _nvtx_enabled():
+            torch.cuda.nvtx.range_push(self.__class__.__name__ + " forward")
+        if not allow_non_contiguous and not inp.is_contiguous():
+            inp = inp.contiguous()
+        return inp
 
+    def end_forward(self):
+        delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
         if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
             FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
+        if _nvtx_enabled():
+            torch.cuda.nvtx.range_pop()
+
+    @contextmanager
+    def prepare_forward_ctx(
+        self,
+        inp: torch.Tensor,
+        num_gemms: int = 1,
+        allow_non_contiguous: bool = False,
+        allow_different_data_and_param_types: bool = False,
+    ) -> Generator[torch.Tensor, None, None]:
+        yield self.prepare_forward(inp, num_gemms,
+                                   allow_non_contiguous,
+                                   allow_different_data_and_param_types)
+        self.end_forward()
 
     def set_nccl_overlap_warning_if_tp(self) -> None:
         """When using TP, the NCCL communication needs to be scheduled
```

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 56 additions & 54 deletions

```diff
@@ -787,60 +787,62 @@ def forward(
 
         is_grad_enabled = torch.is_grad_enabled()
 
-        with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
-            weight_tensors = self._get_weight_tensors()
-            bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
-
-            quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()
-
-            if debug:
-                if self.no_debug_features_active(list(chain(*quantizers))):
-                    debug = False
-                    quantizers = self._get_quantizers()
-
-                if isinstance(weight_tensors, QuantizedTensorStorage):
-                    raise RuntimeError("FP8 weights are not supported in debug mode.")
-
-            (
-                input_quantizers,
-                weight_quantizers,
-                output_quantizers,
-                grad_input_quantizers,
-                grad_weight_quantizers,
-                grad_output_quantizers,
-            ) = quantizers
-
-            if is_grad_enabled:
-                linear_fn = _GroupedLinear.apply
-                autograd_ctx = []
-            else:
-                linear_fn = _GroupedLinear.forward
-                autograd_ctx = [None]
-
-            non_tensor_args = (
-                m_splits,
-                self.apply_bias,
-                is_first_microbatch,
-                self.fp8,
-                self.fp8_calibration,
-                self.wgrad_store,
-                input_quantizers,
-                weight_quantizers,
-                output_quantizers,
-                grad_input_quantizers,
-                grad_weight_quantizers,
-                grad_output_quantizers,
-                self.fuse_wgrad_accumulation,
-                is_cpu_offload_enabled(),
-                self.sequence_parallel,
-                self.activation_dtype,
-                is_grad_enabled,
-                self,
-                None,  # skip_fp8_weight_update
-                self.save_original_input,
-                debug,
-            )
-            out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
+        inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
+        weight_tensors = self._get_weight_tensors()
+        bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
+
+        quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()
+
+        if debug:
+            if self.no_debug_features_active(list(chain(*quantizers))):
+                debug = False
+                quantizers = self._get_quantizers()
+
+            if isinstance(weight_tensors, QuantizedTensorStorage):
+                raise RuntimeError("FP8 weights are not supported in debug mode.")
+
+        (
+            input_quantizers,
+            weight_quantizers,
+            output_quantizers,
+            grad_input_quantizers,
+            grad_weight_quantizers,
+            grad_output_quantizers,
+        ) = quantizers
+
+        if is_grad_enabled:
+            linear_fn = _GroupedLinear.apply
+            autograd_ctx = []
+        else:
+            linear_fn = _GroupedLinear.forward
+            autograd_ctx = [None]
+
+        non_tensor_args = (
+            m_splits,
+            self.apply_bias,
+            is_first_microbatch,
+            self.fp8,
+            self.fp8_calibration,
+            self.wgrad_store,
+            input_quantizers,
+            weight_quantizers,
+            output_quantizers,
+            grad_input_quantizers,
+            grad_weight_quantizers,
+            grad_output_quantizers,
+            self.fuse_wgrad_accumulation,
+            is_cpu_offload_enabled(),
+            self.sequence_parallel,
+            self.activation_dtype,
+            is_grad_enabled,
+            self,
+            None,  # skip_fp8_weight_update
+            self.save_original_input,
+            debug,
+        )
+        out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
+
+        self.end_forward()
 
         if self.return_bias:
             return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
```

0 commit comments