@@ -654,7 +654,6 @@ def __setattr__(self, name: str, value: Any) -> None:
         # Default case
         super().__setattr__(name, value)
 
-
     def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
         """
         Delayed scaling only.
@@ -990,10 +989,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
 
         if fp8_parameters or fp8_enabled:
             _original_recipe = meta.get("recipe", None)
-            if (
-                self.fp8_initialized
-                and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe
-            ):
+            if self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe:
                 # FP8 init has already been run and recipe is the same, don't do anything.
                 return
             meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
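For readers skimming the hunk: the collapsed condition keeps the same short-circuit behavior, i.e. FP8 metadata initialization is skipped only when it has already run and the globally active recipe still matches the one cached in `meta["recipe"]`. A minimal sketch of that decision, with hypothetical names standing in for the module state and `FP8GlobalStateManager.get_fp8_recipe()`:

```python
# Hypothetical helper mirroring the condition above; the argument names
# stand in for self.fp8_initialized, meta["recipe"], and
# FP8GlobalStateManager.get_fp8_recipe().
def should_skip_fp8_init(fp8_initialized: bool, cached_recipe, active_recipe) -> bool:
    return fp8_initialized and active_recipe == cached_recipe


assert should_skip_fp8_init(True, "recipe_a", "recipe_a")      # same recipe: skip re-init
assert not should_skip_fp8_init(True, "recipe_a", "recipe_b")  # recipe changed: re-init
assert not should_skip_fp8_init(False, None, "recipe_a")       # never initialized: init
```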
@@ -1042,9 +1038,10 @@ def prepare_forward(
         allow_non_contiguous: bool = False,
         allow_different_data_and_param_types: bool = False,
     ) -> torch.Tensor:
-        """Checks and prepare for FWD execution.
-        """
-        self.fast_set_attr("allow_different_data_and_param_types", allow_different_data_and_param_types)
+        """Checks and prepare for FWD execution."""
+        self.fast_set_attr(
+            "allow_different_data_and_param_types", allow_different_data_and_param_types
+        )
         self.fast_set_attr("forwarded_at_least_once", True)
 
         # Activation recomputation is used and this is the second forward phase.
@@ -1098,9 +1095,9 @@ def prepare_forward_ctx(
         allow_non_contiguous: bool = False,
         allow_different_data_and_param_types: bool = False,
     ) -> Generator[torch.Tensor, None, None]:
-        yield self.prepare_forward(inp, num_gemms,
-                                   allow_non_contiguous,
-                                   allow_different_data_and_param_types)
+        yield self.prepare_forward(
+            inp, num_gemms, allow_non_contiguous, allow_different_data_and_param_types
+        )
         self.end_forward()
 
     def set_nccl_overlap_warning_if_tp(self) -> None:
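`prepare_forward_ctx` pairs `prepare_forward` with `end_forward`; the `Generator` return annotation and the yield/cleanup pairing suggest it is wrapped with `contextlib.contextmanager` somewhere outside this hunk. A minimal, self-contained sketch of that call pattern, using a hypothetical `_SketchModule` in place of the real base module:

```python
from contextlib import contextmanager
from typing import Generator

import torch


class _SketchModule:
    """Hypothetical stand-in for the real base module; only the
    prepare/end pairing wrapped by prepare_forward_ctx is modeled."""

    def prepare_forward(self, inp: torch.Tensor, num_gemms: int = 1) -> torch.Tensor:
        # The real method runs FP8 setup and input checks; the sketch
        # simply passes the tensor through.
        return inp

    def end_forward(self) -> None:
        # The real method tears down per-forward state; no-op here.
        pass

    @contextmanager
    def prepare_forward_ctx(
        self, inp: torch.Tensor, num_gemms: int = 1
    ) -> Generator[torch.Tensor, None, None]:
        yield self.prepare_forward(inp, num_gemms)
        self.end_forward()


# Callers get the prepared input and the matching end_forward() call
# from a single with-statement.
mod = _SketchModule()
with mod.prepare_forward_ctx(torch.ones(2, 2)) as prepped:
    out = prepped * 2.0
```

Note that as written in the diff, `end_forward()` only runs when the with-block exits without raising, since the yield is not wrapped in try/finally.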