feat: support masked_scatter by lowering path #3438

Closed · wants to merge 3 commits
43 changes: 43 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -344,6 +344,10 @@ def create_constant(
    with unset_fake_temporarily():

        torch_value = to_torch(value, dtype)
        if torch_value is None:
            raise ValueError(
                f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None."
            )
        if torch_value.dtype == torch.float64:
            raise ValueError(
                "TensorRT does not support float64 (double) precision. To resolve this, please set truncate_double=True in your compilation settings and re-run the model."
@@ -1065,3 +1069,42 @@ def load_tensorrt_llm() -> bool:
            )
            return False
    return False


def promote_trt_tensors_to_same_dtype(
    ctx: ConversionContext, lhs: TRTTensor, rhs: TRTTensor, name_prefix: str
) -> tuple[TRTTensor, TRTTensor]:
    """
    Promotes two TensorRT tensors to a common data type to ensure type compatibility
    during operations (e.g., select, where), following simplified PyTorch promotion rules.

    Args:
        ctx: Conversion context containing the TRT network definition.
        lhs: The left-hand-side TensorRT tensor.
        rhs: The right-hand-side TensorRT tensor.
        name_prefix: A prefix string used to name any cast operations.

    Returns:
        A tuple of (lhs_cast, rhs_cast) TensorRT tensors, both cast to the promoted dtype.
    """

    # Define supported float types (TensorRT supports float16 and float32)
    float_types = {trt.float16, trt.float32}

    # Case 1: If either tensor is a float, promote to the wider float type
    if lhs.dtype in float_types or rhs.dtype in float_types:
        # Prefer float32 if either tensor is float32
        if lhs.dtype == trt.float32 or rhs.dtype == trt.float32:
            promoted_dtype = trt.float32
        else:
            promoted_dtype = trt.float16
    else:
        # Case 2: If both tensors are int types (e.g., int32, int64), promote to int32
        # (Note: TensorRT does not support int64 for many ops like select/where)
        promoted_dtype = trt.int32

    # Cast both tensors to the promoted dtype
    lhs_cast = cast_trt_tensor(ctx, lhs, promoted_dtype, f"{name_prefix}lhs_cast")
    rhs_cast = cast_trt_tensor(ctx, rhs, promoted_dtype, f"{name_prefix}rhs_cast")

    return lhs_cast, rhs_cast
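
Reviewer note (not part of the diff): the promotion rule above reduces to a small decision table. The sketch below mirrors it in plain Python with dtype names as strings, just to make the expected outcomes easy to eyeball; the `promoted` helper is hypothetical and only illustrates the branch logic.

```python
# Minimal sketch of the promotion rule implemented above (dtype names as strings).
def promoted(lhs: str, rhs: str) -> str:
    floats = {"float16", "float32"}
    if lhs in floats or rhs in floats:
        # Prefer float32 if either side is float32, otherwise stay in float16.
        return "float32" if "float32" in (lhs, rhs) else "float16"
    # Non-float inputs (e.g. int32/int64) collapse to int32.
    return "int32"

assert promoted("float16", "float32") == "float32"
assert promoted("float16", "int64") == "float16"
assert promoted("int64", "int32") == "int32"
```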
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py
@@ -11,6 +11,7 @@
    cast_trt_tensor,
    get_trt_tensor,
    prepend_ones,
    promote_trt_tensors_to_same_dtype,
    set_layer_name,
)
from torch_tensorrt.dynamo.conversion.impl.elementwise import ne
@@ -57,6 +58,9 @@ def where(
    if diff > 0:
        other = prepend_ones(ctx, other, f"{name}_other_broadcast", diff)

    # Ensure that input and other have the same TRT dtype
    input, other = promote_trt_tensors_to_same_dtype(ctx, input, other, name)

    return select(ctx, target, source_ir, name, input, other, condition)
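
For context (an inference from the converter change above, not text from the PR): eager `torch.where` promotes mismatched branch dtypes implicitly, whereas TensorRT's select layer expects both branches to share a dtype, which is why the explicit cast is added before `select`. A small eager-mode illustration:

```python
import torch

mask = torch.tensor([True, False, True])
a = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
b = torch.tensor([10.0, 20.0, 30.0], dtype=torch.float32)

# Eager PyTorch silently promotes the float16 branch to float32.
print(torch.where(mask, a, b).dtype)  # torch.float32
# The TRT converter reproduces this by casting both branches with
# promote_trt_tensors_to_same_dtype before adding the select layer.
```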


1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -462,6 +462,7 @@ def gather(
) -> TRTTensor:
    input_shape = input.shape
    dim = get_positive_dim(dim, len(input_shape))
    index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor")
    gather_layer = ctx.net.add_gather(input, index, axis=dim)
    gather_layer.mode = trt.GatherMode.ELEMENT
    set_layer_name(gather_layer, target, name + "_gather_layer_element", source_ir)
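
Side note (hedged, not from the PR description): eager `torch.gather` insists on int64 indices, while the TRT gather layer added here is fed int32 indices, so the converter casts the index tensor down before building the layer. A quick demonstration of the eager-side dtype:

```python
import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
idx = torch.tensor([[0, 0], [1, 0]])  # int64 by default
print(torch.gather(x, 1, idx))        # tensor([[1., 1.], [4., 3.]])
# Passing an int32 index tensor here raises an error (eager gather expects int64),
# which is why the int32 cast happens on the TRT side rather than in the PyTorch graph.
```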
51 changes: 50 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -196,8 +196,10 @@ def slice_scatter_decomposition(
) -> torch.Tensor:
    dim_size = input_tensor.shape[dim]
    device_input_tensor = input_tensor.device

    start = 0 if start is None else start  # Ensure start is int
    start = get_positive_dim(start, input_tensor.shape[dim])
    if end is None:  # Ensure end is int
        end = dim_size
    end = get_positive_dim(end, input_tensor.shape[dim])
    if step is None:
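
For context, the two added lines only normalize `start` before the index math: `None` falls back to 0 and negative values are wrapped by `get_positive_dim`. A small eager-mode sketch of the behavior being mirrored (plain PyTorch, independent of the decomposition):

```python
import torch

x = torch.zeros(6)
src = torch.ones(2)

# start=-2 addresses the last two positions; the decomposition maps it to 4
# (via get_positive_dim) before computing the scatter indices.
print(torch.slice_scatter(x, src, dim=0, start=-2))
# tensor([0., 0., 0., 0., 1., 1.])
```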
@@ -575,6 +577,53 @@ def cudnn_grid_sampler_decomposition(
    return torch.grid_sampler_2d(x, grid, 0, 0, True)


@register_torch_trt_decomposition(
    aten.masked_scatter, registry=TORCH_TRT_DECOMPOSITIONS
)
def masked_scatter_decomposition(
    input: torch.Tensor,
    mask: torch.Tensor,
    source: torch.Tensor,
) -> torch.Tensor:
    """
    Decomposition of `aten.masked_scatter` for TensorRT.

    Emulates the behavior of `input[mask] = source` using only TensorRT-compatible ops.

    Steps:
        1) Broadcast `input` and `mask` to a common shape.
        2) Flatten all tensors for uniform indexing.
        3) Compute gather indices for `source` by applying cumsum to the boolean mask.
           - Use `masked_fill` to avoid invalid indices in positions where `mask` is False.
        4) Gather values from `source` at valid positions.
        5) Use `torch.where` to insert gathered values into `input` where `mask` is True.
        6) Reshape the result back to the original broadcasted shape.
    """

    # 1) Broadcast input and mask to the same shape
    input_b, mask_b = aten.broadcast_tensors([input, mask])

    # 2) Flatten tensors for element-wise operations
    input_flat = input_b.flatten()
    mask_flat = mask_b.flatten()
    source_flat = source.flatten()

    # 3) Compute gather indices from cumsum of the mask
    #    Subtract 1 so that the first True position maps to index 0 in source
    source_idx = mask_flat.cumsum(0) - 1
    # Set gather index to 0 where mask is False (these will be ignored later)
    safe_idx = source_idx.masked_fill(~mask_flat, 0)

    # 4) Gather values from source using computed indices
    gathered = source_flat.gather(0, safe_idx)

    # 5) Replace masked positions in input with gathered values
    replaced = torch.where(mask_flat, gathered, input_flat)

    # 6) Reshape the result to match the original broadcasted shape
    return replaced.view(input_b.shape)


def get_decompositions(
    enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
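
Reviewer-style sanity check (not part of the PR, uses only public PyTorch ops): the decomposition above can be mirrored in eager mode and compared against `Tensor.masked_scatter` directly, which is a quick way to convince yourself the cumsum/gather/where construction is equivalent.

```python
import torch

def masked_scatter_reference(inp, mask, source):
    # Mirrors the decomposition above using public eager-mode ops.
    inp_b, mask_b = torch.broadcast_tensors(inp, mask)
    inp_flat, mask_flat = inp_b.flatten(), mask_b.flatten()
    src_flat = source.flatten()
    # cumsum - 1 maps the k-th True position to index k of source; False slots
    # are clamped to 0 and discarded by the final where().
    safe_idx = (mask_flat.cumsum(0) - 1).masked_fill(~mask_flat, 0)
    gathered = src_flat.gather(0, safe_idx)
    return torch.where(mask_flat, gathered, inp_flat).view(inp_b.shape)

x = torch.zeros(2, 3)
mask = torch.tensor([[True, False, True], [False, True, False]])
source = torch.tensor([1.0, 2.0, 3.0])

assert torch.equal(
    masked_scatter_reference(x, mask, source),
    x.masked_scatter(mask, source),
)
```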
80 changes: 80 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py
@@ -2167,6 +2167,86 @@ def forward(self, x, grid):
msg="Cudnn_grid_sampler TRT outputs don't match with the original model.",
)

@parameterized.expand(
[
("float32_2d", torch.float32, (4, 4)),
("float16_3d", torch.float16, (2, 3, 4)),
]
)
def test_masked_scatter(self, _, dtype, shape):
"""
Test that masked_scatter.default is correctly decomposed into
(cumsum, gather, where, etc.) and that final TRT results match PyTorch.
"""

class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, mask, source):
return torch.ops.aten.masked_scatter.default(x, mask, source)

x = torch.randn(*shape, dtype=dtype, device="cuda")

mask = torch.rand(*shape, device="cuda") > 0.5
num_trues = mask.sum().item()
if num_trues == 0:
mask[0] = True
num_trues = 1
source = torch.arange(num_trues, dtype=dtype, device="cuda")

inputs = [x, mask, source]

fx_graph = torch.fx.symbolic_trace(TestModule())

expected_ops = {
torch.ops.aten.where.self,
torch.ops.aten.gather.default,
torch.ops.aten.cumsum.default,
}
unexpected_ops = {torch.ops.aten.masked_scatter.default}

unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

trt_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
pass_through_build_failures=True,
)
with torch.no_grad():
trt_results = trt_model(*inputs).detach().cpu()
torch_results = fx_graph(*inputs).detach().cpu()

max_diff = float(torch.max(torch.abs(trt_results - torch_results)))
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"Masked_scatter TRT outputs don't match with the original model. (diff={max_diff})",
)


if __name__ == "__main__":
run_tests()