Commit 422be88
committed
[APS] Implement __torch_function__ for KeyedTensor (pytorch#161683)
Summary:
X-link: meta-pytorch/torchrec#3329
Pull Request resolved: pytorch#161683
1. There are a bunch of `torch.ops.aten` operations that can't handle `KeyedTensor`: The error was occurring because these ops expects a regular `Tensor` but was receiving a `KeyedTensor` object.
2. Implement `__torch_function__` for `KeyedTensor`, so when these incompatible operations are called with a `KeyedTensor`, the `__torch_function__` method automatically delegates the op to the underlying values tensor from the `KeyedTensor` and returns a new `KeyedTensor` with updated values.
Test Plan:
```
buck2 run mode/opt fbcode//aps_models/ads/gmp:launcher_with_publish mode=mtml_mobile_cvr_model/managed/Y2025Q2/local_mode_mtml_mobile_cvr_model_733415799_v0_fork +training.ir_serializer=manifold
```
MAST job: https://fburl.com/mlhub/pp937uxf
Rollback Plan:
Reviewed By: malaybag
Differential Revision: D810472781 parent 55c289d commit 422be88
1 file changed
+9
-3
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
817 | 817 | | |
818 | 818 | | |
819 | 819 | | |
820 | | - | |
821 | | - | |
| 820 | + | |
| 821 | + | |
| 822 | + | |
| 823 | + | |
| 824 | + | |
| 825 | + | |
| 826 | + | |
| 827 | + | |
| 828 | + | |
822 | 829 | | |
823 | | - | |
824 | 830 | | |
825 | 831 | | |
826 | 832 | | |
| |||
0 commit comments