[PyTorch] Add ops for dropout and constant scale #1995

Merged
merged 4 commits on Jul 25, 2025
123 changes: 100 additions & 23 deletions tests/pytorch/test_fusible_ops.py
@@ -36,9 +36,7 @@
import transformer_engine_torch as tex

# Import utility functions
-_current_file = pathlib.Path(__file__).resolve()
-sys.path.append(str(_current_file.parent))
-from utils import dtype_tols, make_recipe
+from utils import dtype_tols, make_recipe, reset_rng_states

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
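
(The reset_rng_states helper is defined in the test utilities module and is not shown in this diff. Judging only from the setup_class code it replaces below, a minimal stand-in might look like the following sketch; the real helper may reset additional RNG state as well.)

import torch

def reset_rng_states(seed: int = 1234) -> None:
    """Hypothetical stand-in: re-seed the CPU and CUDA RNGs like the removed setup code."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)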
@@ -327,10 +325,7 @@ class TestFuser:

    @staticmethod
    def setup_class(cls) -> None:
-        # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()

    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_scale_update(
@@ -544,10 +539,7 @@ class TestBasicOps:

    @staticmethod
    def setup_class(cls) -> None:
-        # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()

    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", ("cuda", "cpu"))
@@ -1693,16 +1685,107 @@ def test_swiglu(
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

    @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
    @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2)))
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("device", _devices)
    def test_constant_scale(
        self,
        *,
        scale: float,
        shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device,
    ):

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
        )
        dy_ref, dy_test = make_reference_and_test_tensors(
            shape,
            test_dtype=dtype,
            test_device=device,
            requires_grad=False,
        )

        # Plain PyTorch implementation
        y_ref = scale * x_ref
        y_ref.backward(dy_ref)

        # Implementation with fusible operation
        op = te_ops.ConstantScale(scale)
        y_test = op(x_test)
        y_test.backward(dy_test)

        # Check results
        tols = dtype_tols(dtype)
        y_test = y_test.to(dtype=torch.float64, device="cpu")
        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

    @pytest.mark.parametrize("prob", (0.1, 0.5, 0.75))
    @pytest.mark.parametrize("is_training", (True, False))
    @pytest.mark.parametrize("shape", ((101,), (2, 4, 16)))
    @pytest.mark.parametrize("dtype", _dtypes)
    def test_dropout(
        self,
        *,
        prob: float,
        is_training: bool,
        shape: Iterable[int],
        dtype: torch.dtype,
        device: torch.device = "cuda",
    ):

        # Random data
        x_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
        x_test = x_ref.clone().requires_grad_()
        dy_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
        dy_test = dy_ref.clone()

        # Apply dropout
        op = te_ops.Dropout(prob)
        if is_training:
            op.train()
        else:
            op.eval()
        y = op(x_test)
        y.backward(dy_test)

        # Check values
        if is_training:
            mask = ((y != 0) / (1 - prob)).to(dtype=dtype)
            torch.testing.assert_close(y, x_ref * mask)
            torch.testing.assert_close(x_test.grad, dy_ref * mask)
        else:
            torch.testing.assert_close(y, x_ref, rtol=0, atol=0)
            torch.testing.assert_close(x_test.grad, dy_ref, rtol=0, atol=0)

        # Hypothesis testing for number of zeros
        # Note: A Bernoulli random variable with probability p has
        # mean p and standard deviation sqrt(p*(1-p)). By the central
        # limit theorem, the mean of n iid Bernoulli variables
        # converges to a normal random variable with mean p and
        # standard deviation sqrt(p*(1-p)/n). If the observed mean is
        # below the 0.5th or above the 99.5th percentiles, then the
        # p-value is less than 1% and we assume that the dropout
        # distribution is incorrect.
        if is_training:
            prob_observed = 1 - torch.count_nonzero(y).item() / y.numel()
            z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel())
            assert abs(z_score) < 2.5758, "Number of zeros is outside 99% confidence interval"


class TestFusedOps:
    """Tests for fused operations"""

    @staticmethod
    def setup_class(cls) -> None:
-        # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()

    @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
    @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
@@ -2125,10 +2208,7 @@ class TestCheckpointing:

    @staticmethod
    def setup_class(cls) -> None:
-        # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()

    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("quantized_weight", (False, True))
@@ -2240,10 +2320,7 @@ class TestSequentialModules:

    @staticmethod
    def setup_class(cls) -> None:
-        # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()

    @pytest.mark.parametrize("requires_grad", (False, True))
    @pytest.mark.parametrize("bias", (False, True))
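
For reference, the 2.5758 threshold used in test_dropout above is the 99.5th percentile of the standard normal distribution, so the two-sided test rejects at the 1% significance level. The following self-contained sketch (standard library plus PyTorch only; the helper name is illustrative, not part of this PR) reproduces the same check against torch.nn.functional.dropout:

import math
from statistics import NormalDist

import torch

def zero_fraction_z_score(y: torch.Tensor, prob: float) -> float:
    """z-score of the observed zero fraction against a Bernoulli(prob) model."""
    prob_observed = 1 - torch.count_nonzero(y).item() / y.numel()
    return (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel())

torch.manual_seed(1234)
prob = 0.5
x = torch.rand(2, 4, 16) + 0.5  # strictly positive, so zeros can only come from dropout
y = torch.nn.functional.dropout(x, p=prob, training=True)
threshold = NormalDist().inv_cdf(0.995)  # approximately 2.5758
assert abs(zero_fraction_z_score(y, prob)) < threshold, "outside 99% confidence interval"

Like the test itself, this check is probabilistic and is expected to fail for roughly 1% of random draws.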
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/ops/basic/__init__.py
@@ -10,6 +10,8 @@
from .all_reduce import AllReduce
from .basic_linear import BasicLinear
from .bias import Bias
+from .constant_scale import ConstantScale
+from .dropout import Dropout
from .identity import Identity
from .l2normalization import L2Normalization
from .layer_norm import LayerNorm
40 changes: 40 additions & 0 deletions transformer_engine/pytorch/ops/basic/constant_scale.py
@@ -0,0 +1,40 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for constant scaling."""

from __future__ import annotations
from typing import Optional

import torch

from transformer_engine.pytorch.ops.op import (
    BasicOperation,
    OperationContext,
)
from ...tensor import Quantizer


class ConstantScale(BasicOperation):
    """Multiply by a constant"""

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = scale

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
        return input_ * self.scale

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:
        return grad_output * self.scale, ()
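
As a quick illustration of the new op (assuming te_ops is transformer_engine.pytorch.ops, as used in the test file), ConstantScale behaves like a module that multiplies by a fixed scalar in both forward and backward; a minimal sketch with arbitrary shapes and values:

import torch
import transformer_engine.pytorch.ops as te_ops

x = torch.randn(4, 4, 2, device="cuda", requires_grad=True)
op = te_ops.ConstantScale(-2.5)
y = op(x)           # forward: y = -2.5 * x
y.sum().backward()  # backward: dL/dx = -2.5 everywhere
torch.testing.assert_close(x.grad, torch.full_like(x, -2.5))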
67 changes: 67 additions & 0 deletions transformer_engine/pytorch/ops/basic/dropout.py
@@ -0,0 +1,67 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for dropout."""

from __future__ import annotations
from typing import Optional

import torch

from transformer_engine.pytorch.ops.op import (
    BasicOperation,
    OperationContext,
)
from ...tensor import Quantizer


class Dropout(BasicOperation):
    """Randomly zero out tensor entries during training

    During training, tensor entries are randomly set to zero with
    probability :math:`p` and remaining entries are scaled by
    :math:`1/(1-p)`.

    """

    def __init__(self, p: float) -> None:
        super().__init__()
        self.dropout_probability = p

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:

        # Compute dropout if training
        out = input_
        is_training = self.training
        mask = None
        if is_training:
            keep_prob = 1 - self.dropout_probability

[Review comment] negvet (Collaborator), Jul 25, 2025:
Handling this case similar to torch.nn.Dropout:

Suggested change
-            keep_prob = 1 - self.dropout_probability
+            if self.dropout_probability == 1:
+                mask = torch.zeros_like(input_)
+                out = mask
+            else:
+                keep_prob = 1 - self.dropout_probability

[Author reply] (Collaborator):
The existing impl should handle this case correctly. We will also replace this mask-based impl soon, so no need to optimize aggressively.

            mask = torch.empty_like(input_)
            mask.bernoulli_(keep_prob)
            mask *= 1 / keep_prob
            out = out * mask

        # Save context for backward
        if ctx.requires_grad:
            ctx.save_for_backward(mask)
            ctx.is_training = is_training

        return out

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:
        (mask,) = ctx.saved_tensors
        grad_input = grad_output
        if ctx.is_training:
            grad_input = grad_input * mask
        return grad_input, ()
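
A short usage sketch of the new Dropout op, mirroring what test_dropout checks (again assuming te_ops is transformer_engine.pytorch.ops): in training mode roughly a fraction p of entries are zeroed and the survivors are scaled by 1/(1-p); in eval mode the op is an identity.

import torch
import transformer_engine.pytorch.ops as te_ops

torch.manual_seed(1234)
x = torch.rand(2, 4, 16, device="cuda") + 0.5  # strictly positive input
op = te_ops.Dropout(0.5)

op.train()
y = op(x)
mask = y != 0                                     # zeros mark dropped entries
torch.testing.assert_close(y[mask], 2 * x[mask])  # survivors scaled by 1/(1-p) = 2

op.eval()
torch.testing.assert_close(op(x), x)              # identity in eval mode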