Conversation

Collaborator

chohk88 commented Mar 12, 2025

Description

Implemented support for `masked_scatter` in the lowering path, following the corresponding implementation in PyTorch Inductor.
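For reference, `masked_scatter` copies elements of `source` into `input` at the positions where `mask` is True, consuming `source` in row-major order; the lowering in this PR reproduces that behavior with broadcast, flatten, cumulative sum, gather, and where. A small illustrative example of the op's semantics (not part of the PR itself):

import torch

x = torch.zeros(2, 3)
mask = torch.tensor([[True, False, True],
                     [False, True, False]])
source = torch.tensor([10.0, 20.0, 30.0])  # must provide at least mask.sum() elements

out = torch.masked_scatter(x, mask, source)
# tensor([[10.,  0., 20.],
#         [ 0., 30.,  0.]])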

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

chohk88 requested review from apbose and peri044 on March 12, 2025 13:00
chohk88 self-assigned this Mar 12, 2025
github-actions bot added labels component: tests, component: lowering, component: conversion, component: converters, component: api [Python], component: dynamo on Mar 12, 2025
github-actions bot requested a review from narendasan on March 12, 2025 13:44
Comment on lines 589 to 615
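# 1) Broadcast the input and mask tensors to a common shape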
input_b, mask_b = aten.broadcast_tensors([input, mask])

# 2) Flatten the broadcasted tensors and the source tensor
input_flat = input_b.flatten()
mask_flat = mask_b.flatten()
source_flat = source.flatten()

# 3) Compute gather indices: (cumsum of mask as int64) - 1
source_idx = mask_flat.to(torch.int64).cumsum(0) - 1

# 4) Gather elements from source_flat using these indices
gathered = source_flat.gather(0, source_idx)

# 5) Replace positions where mask is True with gathered values, otherwise keep original
replaced = torch.where(mask_flat, gathered, input_flat)

# 6) Reshape the result back to the broadcasted shape
return replaced.view(input_b.shape)
Collaborator

@chohk88 I have a question. I tried running this code in a separate Python session with the test input size of (2, 3, 4), and I see the following error. Do you know why this happens? Am I missing something here?

import torch
shape = (2, 3, 4)

ax = torch.randn(*shape, dtype=torch.float32, device="cuda")
mask = torch.rand(*shape, device="cuda") > 0.5
num_trues = mask.sum().item()
source = torch.arange(num_trues, dtype=torch.float32, device="cuda")
ax_b, mask_b = torch.ops.aten.broadcast_tensors([ax, mask])
ax_flat = ax_b.flatten()
mask_flat = mask_b.flatten()
source_flat = source.flatten()
source_idx = mask_flat.to(torch.int64).cumsum(0) - 1
gathered = source_flat.gather(0, source_idx)
/pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [0,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [1,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
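For context, the out-of-bounds indices come from positions where the mask is False before its first True: there `mask_flat.to(torch.int64).cumsum(0) - 1` evaluates to -1, which the CUDA gather kernel rejects even though `torch.where` would discard those values later. A minimal sketch of one way to keep the gather in bounds by clamping the indices (an illustration only, not necessarily what the final lowering uses), assuming the mask has at least one True element:

import torch

shape = (2, 3, 4)
ax = torch.randn(*shape, dtype=torch.float32, device="cuda")
mask = torch.rand(*shape, device="cuda") > 0.5
source = torch.arange(mask.sum().item(), dtype=torch.float32, device="cuda")

ax_flat = ax.flatten()
mask_flat = mask.flatten()
# cumsum - 1 is -1 wherever no True has been seen yet; clamp those indices to 0
# so the gather stays in bounds. The values gathered at masked-off positions are
# meaningless, but torch.where drops them below.
source_idx = (mask_flat.to(torch.int64).cumsum(0) - 1).clamp(min=0)
gathered = source.gather(0, source_idx)
result = torch.where(mask_flat, gathered, ax_flat).view(shape)

# Matches the reference op
assert torch.equal(result, torch.masked_scatter(ax, mask, source))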

chohk88 force-pushed the lowering_masked_scatter branch from 946c3c1 to 46b1c52 on April 8, 2025 04:04
Collaborator Author

chohk88 commented Apr 15, 2025

Closing this PR due to an incorrect author. Reopened as PR

chohk88 closed this Apr 15, 2025
