
Commit 5eefe3e

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent: 4105934 · commit: 5eefe3e

File tree: 2 files changed (+9, -13 lines)


transformer_engine/pytorch/module/base.py

Lines changed: 8 additions & 11 deletions
@@ -654,7 +654,6 @@ def __setattr__(self, name: str, value: Any) -> None:
         # Default case
         super().__setattr__(name, value)
 
-
     def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
         """
         Delayed scaling only.
@@ -990,10 +989,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
 
         if fp8_parameters or fp8_enabled:
             _original_recipe = meta.get("recipe", None)
-            if (
-                self.fp8_initialized
-                and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe
-            ):
+            if self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe:
                 # FP8 init has already been run and recipe is the same, don't do anything.
                 return
             meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
@@ -1042,9 +1038,10 @@ def prepare_forward(
         allow_non_contiguous: bool = False,
         allow_different_data_and_param_types: bool = False,
     ) -> torch.Tensor:
-        """Checks and prepare for FWD execution.
-        """
-        self.fast_set_attr("allow_different_data_and_param_types", allow_different_data_and_param_types)
+        """Checks and prepare for FWD execution."""
+        self.fast_set_attr(
+            "allow_different_data_and_param_types", allow_different_data_and_param_types
+        )
         self.fast_set_attr("forwarded_at_least_once", True)
 
         # Activation recomputation is used and this is the second forward phase.
@@ -1098,9 +1095,9 @@ def prepare_forward_ctx(
         allow_non_contiguous: bool = False,
         allow_different_data_and_param_types: bool = False,
     ) -> Generator[torch.Tensor, None, None]:
-        yield self.prepare_forward(inp, num_gemms,
-                                   allow_non_contiguous,
-                                   allow_different_data_and_param_types)
+        yield self.prepare_forward(
+            inp, num_gemms, allow_non_contiguous, allow_different_data_and_param_types
+        )
         self.end_forward()
 
     def set_nccl_overlap_warning_if_tp(self) -> None:
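
For context on the prepare_forward_ctx hunk above: it is a generator method that yields the tensor returned by prepare_forward and calls end_forward once the caller resumes it, i.e. the usual contextlib.contextmanager pattern. Below is a minimal, self-contained sketch of that pattern; the _SketchModule class, its method bodies, and the tensor shapes are illustrative stand-ins, not TransformerEngine's actual implementation or public API.

from contextlib import contextmanager
from typing import Generator

import torch


class _SketchModule:
    """Illustrative stand-in for a TransformerEngine module (not the real class)."""

    def prepare_forward(self, inp: torch.Tensor, num_gemms: int = 1) -> torch.Tensor:
        # Stand-in for the real pre-forward checks (contiguity, FP8 metadata init, ...).
        return inp.contiguous()

    def end_forward(self) -> None:
        # Stand-in for per-forward cleanup.
        pass

    @contextmanager
    def prepare_forward_ctx(
        self, inp: torch.Tensor, num_gemms: int = 1
    ) -> Generator[torch.Tensor, None, None]:
        # Same shape as the reformatted hunk: yield the prepared input,
        # then run end_forward() when the caller leaves the with-block.
        yield self.prepare_forward(inp, num_gemms)
        self.end_forward()


module = _SketchModule()
with module.prepare_forward_ctx(torch.randn(4, 8)) as prepared:
    out = prepared * 2  # forward computation would go here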

transformer_engine/pytorch/module/linear.py

Lines changed: 1 addition & 2 deletions
@@ -1397,8 +1397,7 @@ def forward(
             ).is_fp8_ubuf():
                 fp8_grad = True
 
-        inp = self.prepare_forward(inp,
-                                   allow_non_contiguous=isinstance(inp, QuantizedTensor))
+        inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))
 
         weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
 