     is_non_tn_fp8_gemm_supported,
     torch_get_autocast_gpu_dtype,
     get_nvtx_range_context,
+    _nvtx_enabled,
 )
 from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
 from ...common.recipe import DelayedScaling, Recipe
@@ -640,16 +641,20 @@ def __init__(self) -> None:
640641 "fp8_parameters" ,
641642 }
642643
644+ def fast_set_attr (self , name : str , value : Any ) -> None :
645+ self .__dict__ [name ] = value
646+
643647 def __setattr__ (self , name : str , value : Any ) -> None :
644648 if name in TransformerEngineBaseModule ._fast_setattr_names :
645649 # torch.nn.Module has a custom __setattr__ that handles
646650 # modules, parameters, and buffers. This is unnecessary
647651 # overhead when setting plain attrs.
648- self .__dict__ [ name ] = value
652+ self .fast_set_attr ( name , value )
649653 else :
650654 # Default case
651655 super ().__setattr__ (name , value )
652656
657+
653658 def adjust_amax_history_length (self , length : int , fwd : Optional [bool ] = None ) -> None :
654659 """
655660 Delayed scaling only.
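
A minimal sketch (not TE code) of the fast-path idea in this hunk: `torch.nn.Module.__setattr__` inspects every assignment so it can register `Parameter`s, submodules, and buffers, which is avoidable overhead for plain Python attributes set on every forward pass. Writing straight into `__dict__` for a whitelisted set of names skips that bookkeeping:

```python
import torch

class Toy(torch.nn.Module):
    # Hypothetical whitelist, mirroring _fast_setattr_names in the diff.
    _plain_attrs = {"step_count"}

    def __setattr__(self, name, value):
        if name in Toy._plain_attrs:
            self.__dict__[name] = value          # skip nn.Module bookkeeping
        else:
            super().__setattr__(name, value)     # normal Parameter/Module handling

m = Toy()
m.step_count = 0                                  # fast path
m.weight = torch.nn.Parameter(torch.zeros(1))     # still registered as a parameter
assert "weight" in dict(m.named_parameters())
```
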
@@ -926,7 +931,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None:
926931 """Get activation data type for AMP."""
927932 # Native AMP (`torch.autocast`) gets highest priority
928933 if torch .is_autocast_enabled ():
929- self .activation_dtype = torch_get_autocast_gpu_dtype ()
934+ self .fast_set_attr ( " activation_dtype" , torch_get_autocast_gpu_dtype () )
930935 return
931936
932937 # All checks after this have already been performed once, thus skip
@@ -941,7 +946,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None:
941946 "Data types for parameters must match when outside of autocasted region. "
942947 f" Found input dtype: { dtype } and { name !r} dtype: { param .dtype } "
943948 )
944- self .activation_dtype = dtype
949+ self .fast_set_attr ( " activation_dtype" , dtype )
945950
946951 def set_tensor_parallel_group (self , tp_group : Union [dist_group_type , None ]) -> None :
947952 """
@@ -970,48 +975,54 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
     # assume FP8 execution.
     def init_fp8_metadata(self, num_gemms: int = 1) -> None:
         """Initialize fp8 related metadata and tensors during fprop."""
-        _original_recipe = self.fp8_meta.get("recipe", None)
+        meta = self.fp8_meta
 
-        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
-        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
-        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
-        fp8_enabled = self.fp8 or self.fp8_calibration
-        self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
+        fp8 = FP8GlobalStateManager.is_fp8_enabled()
+        fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
+        fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
+        self.fast_set_attr("fp8_parameters", fp8_parameters)
+        self.fast_set_attr("fp8", fp8)
+        self.fast_set_attr("fp8_calibration", fp8_calibration)
+        fp8_enabled = fp8 or fp8_calibration
+        meta["fp8_checkpoint"] = fp8_enabled
 
-        if self.fp8_parameters or fp8_enabled:
+        _original_recipe = None
+
+        if fp8_parameters or fp8_enabled:
+            _original_recipe = meta.get("recipe", None)
             if (
                 self.fp8_initialized
-                and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]
+                and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe
             ):
                 # FP8 init has already been run and recipe is the same, don't do anything.
                 return
-            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
+            meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
         else:
             # If fp8 isn't enabled, turn off and return.
-            self.fp8_initialized = False
+            self.fast_set_attr("fp8_initialized", False)
             return
 
-        if self.fp8_parameters and not self.fp8_initialized:
-            self.fp8_meta["num_gemms"] = num_gemms
-            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
+        if fp8_parameters and not self.fp8_initialized:
+            meta["num_gemms"] = num_gemms
+            self.init_fp8_meta_tensors(meta["recipe"])
 
         if fp8_enabled:
             # Set FP8 and other FP8 metadata
-            self.fp8_meta["num_gemms"] = num_gemms
-            self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
+            meta["num_gemms"] = num_gemms
+            meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
 
             # Set FP8_MAX per tensor according to recipe
-            if hasattr(self.fp8_meta["recipe"], "fp8_format"):
-                self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
-                self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
+            if hasattr(meta["recipe"], "fp8_format"):
+                meta["fp8_max_fwd"] = meta["recipe"].fp8_format.value.max_fwd
+                meta["fp8_max_bwd"] = meta["recipe"].fp8_format.value.max_bwd
 
             # Allocate scales and amaxes
-            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
+            self.init_fp8_meta_tensors(meta["recipe"])
             self.fp8_initialized = True
 
-            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
+            meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
 
-        _current_recipe = self.fp8_meta["recipe"]
+        _current_recipe = meta["recipe"]
         if _original_recipe is not None and not (
             issubclass(_current_recipe.__class__, _original_recipe.__class__)
             or issubclass(_original_recipe.__class__, _current_recipe.__class__)
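
The `meta = self.fp8_meta` rewrite in this hunk is a standard local-alias micro-optimization: bind the dict once, then do plain dict stores instead of re-resolving the attribute on every line. A self-contained sketch (names are illustrative, not TE code):

```python
class Holder:
    def __init__(self) -> None:
        self.fp8_meta = {}

    def update_many(self, num_gemms: int) -> None:
        meta = self.fp8_meta            # single attribute lookup
        meta["num_gemms"] = num_gemms   # subsequent accesses hit the local directly
        meta["fp8_checkpoint"] = True

Holder().update_many(2)
```
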
@@ -1024,22 +1035,17 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
             # Clear cached workspaces as they were created with the old recipe/quantizer type
             self._fp8_workspaces.clear()
 
-    @contextmanager
     def prepare_forward(
         self,
         inp: torch.Tensor,
         num_gemms: int = 1,
         allow_non_contiguous: bool = False,
         allow_different_data_and_param_types: bool = False,
-    ) -> Generator[torch.Tensor, None, None]:
-        """Checks and prep for FWD.
-        The context manager is needed because there isn't a way for a module to know
-        if it's the last FP8 module in the forward autocast. It is useful
-        to setup the forward aggregated amax reduction for every module
-        just in case. The autocast exit will pick up the most recent one.
+    ) -> torch.Tensor:
+        """Checks and prepares for FWD execution.
         """
-        self.allow_different_data_and_param_types = allow_different_data_and_param_types
-        self.forwarded_at_least_once = True
+        self.fast_set_attr("allow_different_data_and_param_types", allow_different_data_and_param_types)
+        self.fast_set_attr("forwarded_at_least_once", True)
 
         # Activation recomputation is used and this is the second forward phase.
         if self.fp8 and in_fp8_activation_recompute_phase():
@@ -1070,13 +1076,32 @@ def prepare_forward(
             if self.training and is_fp8_activation_recompute_enabled():
                 FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
 
-        with get_nvtx_range_context(self.__class__.__name__ + " forward"):
-            if not allow_non_contiguous and not inp.is_contiguous():
-                inp = inp.contiguous()
-            yield inp
+        # with get_nvtx_range_context(self.__class__.__name__ + " forward"):
+        if _nvtx_enabled():
+            torch.cuda.nvtx.range_push(self.__class__.__name__ + " forward")
+        if not allow_non_contiguous and not inp.is_contiguous():
+            inp = inp.contiguous()
+        return inp
 
+    def end_forward(self):
+        delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
         if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
             FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
+        if _nvtx_enabled():
+            torch.cuda.nvtx.range_pop()
+
+    @contextmanager
+    def prepare_forward_ctx(
+        self,
+        inp: torch.Tensor,
+        num_gemms: int = 1,
+        allow_non_contiguous: bool = False,
+        allow_different_data_and_param_types: bool = False,
+    ) -> Generator[torch.Tensor, None, None]:
+        yield self.prepare_forward(inp, num_gemms,
+                                   allow_non_contiguous,
+                                   allow_different_data_and_param_types)
+        self.end_forward()
 
     def set_nccl_overlap_warning_if_tp(self) -> None:
         """When using TP, the NCCL communication needs to be scheduled