From 6345451b18a8ccdb4eeb9ecb285e19b4589664eb Mon Sep 17 00:00:00 2001 From: "Deng, Daisy" Date: Thu, 27 Nov 2025 08:27:40 +0000 Subject: [PATCH 1/4] add test_nestedtensor_xpu.py --- test/xpu/skip_list_common.py | 1 + test/xpu/test_nestedtensor_xpu.py | 9373 +++++++++++++++++++++++++++-- 2 files changed, 8827 insertions(+), 547 deletions(-) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 9f810433bd..620f3516f9 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -804,4 +804,5 @@ "functorch/test_ops_xpu.py": None, "test_sparse_xpu.py": None, "test_sparse_csr_xpu.py": None, + "test_nestedtensor_xpu.py": None, } diff --git a/test/xpu/test_nestedtensor_xpu.py b/test/xpu/test_nestedtensor_xpu.py index 756b05f18a..45b061c212 100644 --- a/test/xpu/test_nestedtensor_xpu.py +++ b/test/xpu/test_nestedtensor_xpu.py @@ -1,44 +1,598 @@ -# Owner(s): ["module: intel"] - +# Owner(s): ["module: nestedtensor"] +# ruff: noqa: F841 import ast +import io +import itertools +import math +import os +import random import sys +import tempfile import unittest +from functools import partial +from typing import Optional +import numpy as np import torch +import torch._dynamo +import torch._dynamo.testing +import torch.nn import torch.nn.functional as F -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION +from torch.nested._internal.nested_tensor import ( + buffer_from_jagged, + jagged_from_list, + nested_view_from_values_offsets, + NestedTensor, + ViewNestedFromBuffer, +) +from torch.testing._internal.common_cuda import ( + PLATFORM_SUPPORTS_FUSED_ATTENTION, + SM70OrLater, + SM80OrLater, + tf32_on_and_off, +) from torch.testing._internal.common_device_type import ( dtypes, + dtypesIfCUDA, instantiate_device_type_tests, + onlyCPU, + onlyCUDA, + onlyOn, + ops, + PYTORCH_CUDA_MEMCHECK, + skipCPUIf, + skipCUDAIf, + skipCUDAIfRocm, skipMeta, ) +from torch.testing._internal.common_dtype import floating_types_and_half from torch.testing._internal.common_utils import ( + decorateIf, + freeze_rng_state, + gradcheck, instantiate_parametrized_tests, + IS_FBCODE, IS_WINDOWS, + markDynamoStrictTest, + NestedTensorTestCase, parametrize, run_tests, + serialTest, + skipIfSlowGradcheckEnv, skipIfTorchDynamo, + subtest, + TEST_WITH_ROCM, + xfailIfTorchDynamo, +) +from torch.testing._internal.opinfo.core import ( + BinaryUfuncInfo, + ReductionOpInfo, + sample_skips_and_xfails, + SkipRule, + XFailRule, ) +from torch.testing._internal.opinfo.definitions.nested import _sample_njts, njt_op_db +from torch.utils._pytree import tree_flatten, tree_map_only +from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts + +# Tests are ported from pytorch/nestedtensor. +# This makes porting as_nested_tensor easier in the future. + + +def _iter_constructors(): + # yield as_nested_tensor + yield torch.nested.nested_tensor + + +# Returns True if the function recompiles between inputs1 and inputs2 with the +# specified dynamic setting. 
+def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True): + compile_count = [0] + + def counter(gm, example_inputs): + compile_count[0] += 1 + return gm + + compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic) + compiled_f(*inputs1) + compiled_f(*inputs2) + return compile_count[0] > 1 + + +# Helper function to generate a pair of random nested tensors +# one is contiguous, the other is not, but they appear to have same entries +# an output nested tensor consists of +# * `len(ragged_sizes)` matrices +# * matrices[i].shape == (20, ragged_sizes[i]) + + +def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16): + xs = [] + for size in ragged_sizes: + xs.append(torch.randn((size, 20), device=device, dtype=dtype)) + # contiguous nested tensor + ys = [] + for x in xs: + ys.append(x.transpose(-1, -2)) + nt_contiguous = torch.nested.nested_tensor(ys) + # noncontiguous nested tensor + n = len(ragged_sizes) + nt_noncontiguous = torch.nested.nested_tensor(xs).transpose(-1, -2) + return nt_contiguous, nt_noncontiguous + + +# Helper functions to pad a noncontiguous nested tensor +# can be replaced once to_padded_tensor supports noncontiguous memory + + +def noncontiguous_to_padded_tensor(input, shape=None): + tensors = input.unbind() + ntensors = len(tensors) + assert ntensors > 0 + if shape is None: + shape = [] + for size in tensors[0].shape: + shape.append(size) + for i in range(1, ntensors): + new_shape = tensors[i].shape + for j in range(len(shape)): + shape[j] = max(shape[j], new_shape[j]) + shape = [ntensors] + shape + result = tensors[0].new_zeros(shape) + for itensor in range(ntensors): + tensor = tensors[itensor] + view = result[itensor] + for idim in range(tensor.dim()): + view = view.narrow(idim, 0, tensor.size(idim)) + view.copy_(tensor) + return result + + +# Helper function to generate a random nested tensor + + +def random_nt( + device, + dtype, + num_tensors, + max_dims, + min_dims=None, + layout=torch.strided, + require_non_empty=True, +): + if min_dims is None: + min_dims = tuple([0] * len(max_dims)) + + assert len(max_dims) == len(min_dims) + for min_dim, max_dim in zip(min_dims, max_dims): + assert max_dim > min_dim, "random_nt: max_dim must be greater than min_dim" + assert min_dim >= 0, "random_nt: min_dim must be non-negative" + if require_non_empty: + assert not ( + min_dim == 0 and max_dim == 1 + ), "random_nt: zero cannot be the only possible value if require_non_empty is True" + + if require_non_empty: + # Select a random idx that will be required to be non-empty + non_zero_idx = torch.randint(low=0, high=num_tensors, size=(1,)).item() + + ts1 = [] + for i, _ in enumerate(range(num_tensors)): + tensor_dims = [] + for min_dim, max_dim in zip(min_dims, max_dims): + new_min_dim = min_dim + if require_non_empty and i == non_zero_idx and min_dim == 0: + new_min_dim = 1 + tensor_dims.append( + torch.randint(low=new_min_dim, high=max_dim, size=(1,)).item() + ) + t1 = torch.randn(tensor_dims, device=device, dtype=dtype) + ts1.append(t1) + + return torch.nested.nested_tensor(ts1, device=device, dtype=dtype, layout=layout) + + +# Alternate approach to generating a random NT. 
+# dims should be something like [5, None, 10], with None indicating that a +# random ragged structure should be used +def random_nt_from_dims( + dims, device=None, dtype=None, layout=torch.strided, requires_grad=False +): + sizes = [ + [ + d if d is not None else torch.randint(2, 10, size=(1,)).item() + for d in dims[1:] + ] + for d in range(dims[0]) + ] + return torch.nested.nested_tensor( + [torch.randn(*size) for size in sizes], + device=device, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + ) + + +# Creates an NT matching another NT's number of components and +# shape / ragged structure for all dims specified to be -1. +def random_nt_from_similar(other, dims=None): + if dims is None: + return torch.randn_like(other) + assert len(dims) == other.dim() + assert dims[0] == -1 or dims[0] == other.size(0) + + ret_sizes = [] + for t in other.unbind(): + other_size = t.shape + ret_size = [] + for i, d in enumerate(dims[1:]): + if d == -1: + ret_size.append(other_size[i]) + else: + ret_size.append(d) + ret_sizes.append(ret_size) + + return torch.nested.nested_tensor( + [torch.randn(*size) for size in ret_sizes], device=other.device + ) + + +# makes naming nice for tests that parametrize over layout. +def layout_name(layout): + # e.g. "torch.jagged" -> "jagged" + return layout.__repr__().split(".")[-1] + + +def get_op_name(layout): + # e.g. "" -> "sum" + return layout.__name__.split(".")[0].split("_")[-1] + + +# Helper function for test_dummy_mha_with_nt +@torch.fx.wrap +def convert_dense_to_nested_tensor_legacy(values): + offsets = torch.arange( + 0, values.shape[0] * values.shape[1] + 1, values.shape[1], device=values.device + ) + metadata_cache = {"max_seqlen": values.shape[1], "min_seqlen": 1} + nt = ViewNestedFromBuffer.apply( + values.view(-1, values.shape[-1]), offsets, metadata_cache + ) + return nt + + +# Helper function for test_dummy_mha_with_nt +@torch.fx.wrap +def convert_jagged_to_nested_tensor_legacy( + values: torch.Tensor, offsets: torch.Tensor, max_length: int +) -> torch.Tensor: + metadata_cache = {"max_seqlen": max_length, "min_seqlen": 1} + nt = ViewNestedFromBuffer.apply(values, offsets, metadata_cache) + return nt + + +# Helper function for test_dummy_mha_with_nt +@torch.fx.wrap +def convert_nt_to_jagged_legacy(nt): + return buffer_from_jagged(nt) + + +# Helper function for test_dummy_mha_with_nt +@torch.fx.wrap +def convert_dense_to_nested_tensor(values): + nt = torch.nested.as_nested_tensor(values, layout=torch.jagged) + return nt + + +# Helper function for test_dummy_mha_with_nt +@torch.fx.wrap +def convert_jagged_to_nested_tensor( + values: torch.Tensor, offsets: torch.Tensor, max_length: int +) -> torch.Tensor: + nt = torch.nested.nested_tensor_from_jagged( + values, offsets, lengths=None, min_seqlen=1, max_seqlen=max_length + ) + return nt + + +# Helper function for test_dummy_mha_with_nt +def convert_nt_to_jagged(nt): + return nt.values() + + +@markDynamoStrictTest +class TestNestedTensor(NestedTensorTestCase): + @parametrize("batch_size", [2, 4]) + @parametrize("max_seq_len", [3, 5]) + @parametrize("vocab_size", [10, 20]) + def test_2d_nested_tensor(self, batch_size, max_seq_len, vocab_size): + data = [] + nested_tensor_ref_list = [] + for _ in range(batch_size): + if max_seq_len == 0: + length = 0 + else: + length = np.random.randint(low=1, high=max_seq_len) + row = list(np.random.randint(low=0, high=vocab_size, size=(length,))) + data.append(row) + nested_tensor_ref_list.append(torch.Tensor(row)) + nested_tensor = 
torch.nested.nested_tensor(data, dtype=torch.int64) + nested_tensor_list = nested_tensor.unbind() + for id in range(batch_size): + self.assertEqual( + nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64) + ) + + @parametrize("batch_size", [2, 4]) + @parametrize("max_seq_len", [3, 5]) + @parametrize("vocab_size", [10, 20]) + def test_3d_nested_tensor(self, batch_size, max_seq_len, vocab_size): + data = [] + nested_tensor_ref_list = [] + for _ in range(batch_size): + if max_seq_len == 0: + length = 0 + else: + length = np.random.randint(low=1, high=max_seq_len) + row = list(np.random.randint(low=0, high=vocab_size, size=(length,))) + row = [list(item * np.arange(max_seq_len)) for item in row] + data.append(row) + nested_tensor_ref_list.append(torch.Tensor(row)) + nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64) + nested_tensor_list = nested_tensor.unbind() + for id in range(batch_size): + self.assertEqual( + nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64) + ) + + @parametrize("batch_size", [2, 4]) + @parametrize("max_seq_len", [3, 5]) + @parametrize("vocab_size", [10, 20]) + def test_3d_nested_tensor_float(self, batch_size, max_seq_len, vocab_size): + data = [] + nested_tensor_ref_list = [] + for _ in range(batch_size): + if max_seq_len == 0: + length = 0 + else: + length = np.random.randint(low=1, high=max_seq_len) + row = list( + np.random.randint(low=0, high=vocab_size, size=(length,)).astype(float) + ) + row = [list(item * np.arange(max_seq_len)) for item in row] + data.append(row) + nested_tensor_ref_list.append(torch.Tensor(row)) + nested_tensor = torch.nested.nested_tensor(data, dtype=torch.float) + nested_tensor_list = nested_tensor.unbind() + for id in range(batch_size): + self.assertEqual( + nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.float) + ) + + @torch.inference_mode() + def _test_unbind_case(self, a, b): + nt = torch.nested.nested_tensor([a, b]) + a1, b1 = nt.unbind() + self.assertTrue(a is not a1) + self.assertTrue(b is not b1) + + nt = torch.nested.nested_tensor([a, b], dtype=a.dtype) + a1, b1 = nt.unbind(0) + self.assertEqual(a, a1) + self.assertEqual(b, b1) + + a = torch.randn((2, 3)).add_(1) + nt = torch.nested.nested_tensor([a]) + self.assertEqual(a, nt.unbind(0)[0]) + + @torch.inference_mode() + def test_unbind_0(self): + self._test_unbind_case(torch.tensor([1, 2]), torch.tensor([7, 8])) + + @torch.inference_mode() + def test_unbind_1(self): + self._test_unbind_case(torch.tensor([1]), torch.tensor([7])) + + @torch.inference_mode() + def test_unbind_3(self): + self._test_unbind_case(torch.tensor([1.0]), torch.tensor([])) + + @torch.inference_mode() + def test_unbind_4(self): + self._test_unbind_case(torch.tensor([]), torch.tensor([])) + + @torch.inference_mode() + def test_unbind_dim(self): + def _test_fn(unbind_fn): + a = torch.rand(3, 2) + b = torch.rand(2, 3) + nt = torch.nested.nested_tensor([a, b]) + self.assertRaises(RuntimeError, lambda: unbind_fn(nt, 1)) + + # Both of these tests are necessary, because we're using + # torch_function. 
+ _test_fn(lambda x, dim: x.unbind(dim)) + # TODO: Re-enable this once using torch_dispatch + # _test_fn(lambda x, dim: torch.unbind(x, dim)) + + @torch.inference_mode() + def test_nested_tensor(self): + self.assertRaises( + TypeError, lambda: torch.nested.nested_tensor(torch.tensor([3.0])) + ) + self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(4.0)) + + @torch.inference_mode() + def test_nested_tensor_matching_dim(self): + self.assertRaisesRegex( + RuntimeError, + "Found dimension 1 for Tensor at index 1 and dimension 0 for Tensor at index 0.", + lambda: torch.nested.nested_tensor([torch.tensor(1.0), torch.tensor([])]), + ) + self.assertRaisesRegex( + RuntimeError, + "Found dimension 1 for Tensor at index 2 and dimension 0 for Tensor at index 1.", + lambda: torch.nested.nested_tensor( + [torch.tensor(1.0), torch.tensor(2.0), torch.tensor([])] + ), + ) + + @torch.inference_mode() + def test_default_nested_tensor(self): + self.assertRaises(TypeError, lambda: torch.nested.nested_tensor()) + default_nested_tensor = torch.nested.nested_tensor([]) + default_tensor = torch.tensor([]) + # self.assertEqual(default_nested_tensor.nested_dim(), 1) + # self.assertEqual(default_nested_tensor.nested_size(), ()) + self.assertEqual(default_nested_tensor.dim(), default_tensor.dim()) + self.assertEqual(default_nested_tensor.layout, default_tensor.layout) + self.assertEqual(default_nested_tensor.device, default_tensor.device) + self.assertEqual(default_nested_tensor.dtype, default_tensor.dtype) + self.assertEqual( + default_nested_tensor.requires_grad, default_tensor.requires_grad + ) + self.assertIsNone(default_tensor.grad) + # TODO: Re-enable once we have a performance driven + # use case and implementation. + # self.assertEqual(default_nested_tensor.is_pinned(), + # default_tensor.is_pinned()) + + @torch.inference_mode() + def test_dim(self): + for constructor in _iter_constructors(): + a1 = constructor([]) + self.assertEqual(a1.dim(), 1) + a1 = constructor([torch.tensor(3.0)]) + self.assertEqual(a1.dim(), 1) + a1 = constructor([torch.tensor([1, 2, 3, 4])]) + self.assertEqual(a1.dim(), 2) + + @unittest.skipIf(IS_FBCODE, "numel is not virtual in fbcode.") + @torch.inference_mode() + def test_numel(self): + for constructor in _iter_constructors(): + a1 = constructor([]) + self.assertEqual(a1.numel(), 0) + a1 = constructor([torch.tensor(3.0), torch.tensor(4.0)]) + self.assertEqual(a1.numel(), 2) + a1 = constructor([torch.randn(2, 2, 2)]) + self.assertEqual(a1.numel(), 8) + a1 = constructor([torch.randn([1, 2, 3]), torch.randn(3, 2, 1)]) + self.assertEqual(a1.numel(), 12) + a1 = constructor([torch.randn([1, 1, 3]), torch.randn(3, 2, 4)]) + self.assertEqual(a1.numel(), 27) + a1 = constructor([torch.randn([5, 5, 5]), torch.randn(6, 6, 6)]) + self.assertEqual(a1.numel(), 341) + + # Interesting edge case + a1 = constructor([torch.randn([1, 2, 3]), torch.randn(1, 2, 0)]) + self.assertEqual(a1.numel(), 6) + + @torch.inference_mode() + def test_size(self): + for constructor in _iter_constructors(): + a1 = constructor([]) + self.assertRaisesRegex( + RuntimeError, + "NestedTensorImpl doesn't support sizes", + lambda: a1.size(), + ) + + def test_size_dim(self): + a = torch.nested.nested_tensor([]) + self.assertEqual(a.size(0), 0) + + a = torch.nested.nested_tensor([torch.tensor(1)]) + self.assertEqual(a.size(0), 1) + + a = torch.nested.nested_tensor([torch.tensor(1), torch.tensor(2)]) + self.assertEqual(a.size(0), 2) + + a = torch.nested.nested_tensor([torch.rand(1, 2), torch.rand(1, 8)]) + 
self.assertEqual(a.size(0), 2) + self.assertEqual(a.size(1), 1) + self.assertRaisesRegex( + RuntimeError, + "Given dimension 2 is irregular and does not have a size", + lambda: a.size(2), + ) + + a = torch.nested.nested_tensor([torch.rand(3, 4), torch.rand(5, 4)]) + self.assertEqual(a.size(0), 2) + self.assertRaisesRegex( + RuntimeError, + "Given dimension 1 is irregular and does not have a size", + lambda: a.size(1), + ) + self.assertEqual(a.size(2), 4) + + @unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.") + @torch.inference_mode() + def test_stride(self): + for constructor in _iter_constructors(): + a1 = constructor([]) + self.assertRaisesRegex( + RuntimeError, + "NestedTensorImpl doesn't support strides", + lambda: a1.stride(), + ) -try: - from xpu_test_utils import XPUPatchForImport -except Exception as e: - from .xpu_test_utils import XPUPatchForImport - -with XPUPatchForImport(False): - from test_nestedtensor import ( - convert_jagged_to_nested_tensor, - get_tolerances, - random_nt, - random_nt_noncontiguous_pair, - TestNestedTensor, - TestNestedTensorAutograd, - TestNestedTensorDeviceType, - TestNestedTensorOpInfo, - TestNestedTensorSubclass, - ) - - def _test_to(self): + @unittest.skipIf(IS_FBCODE, "is_contiguous is not virtual in fbcode.") + @torch.inference_mode() + def test_is_contiguous(self): + # Test empty case + nt_empty = torch.nested.nested_tensor([]) + assert nt_empty.is_contiguous() + self.assertEqual(nt_empty, nt_empty.contiguous()) + + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7)) + + # Test contiguous case + assert nt_contiguous.is_contiguous() + self.assertEqual(nt_contiguous, nt_contiguous.contiguous()) + + # Test non_contiguous case + assert not nt_noncontiguous.is_contiguous() + self.assertEqual(nt_contiguous, nt_noncontiguous.contiguous()) + + # Test querying by memory_format + self.assertTrue( + nt_contiguous.is_contiguous(memory_format=torch.contiguous_format) + ) + self.assertTrue( + not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format) + ) + + @torch.inference_mode() + def test_repr_string(self): + a = torch.nested.nested_tensor([]) + expected = "nested_tensor([\n\n])" + self.assertEqual(str(a), expected) + self.assertEqual(repr(a), expected) + + a = torch.nested.nested_tensor([torch.tensor(1.0)]) + expected = "nested_tensor([\n tensor(1.)\n])" + self.assertEqual(str(a), expected) + self.assertEqual(repr(a), expected) + + a = torch.nested.nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])]) + expected = "nested_tensor([\n tensor([[1, 2]]),\n tensor([[4, 5]])\n])" + self.assertEqual(str(a), expected) + self.assertEqual(repr(a), expected) + + def test_to_padded_tensor_on_empty_tensor(self): + nt = torch.nested.nested_tensor([]) + empty = torch.nested.to_padded_tensor(nt, 4) + self.assertEqual(empty, torch.tensor([])) + + def test_nested_namespace(self): + nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 5)]) + result = nt.to_padded_tensor(4) + nested_namespace_result = torch.nested.to_padded_tensor(nt, 4) + self.assertEqual(result, nested_namespace_result) + + def test_to(self): ntensors = 4 nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) @@ -53,11 +607,11 @@ def test_copy_behavior(t, non_blocking=False): ) devices = [t.device] - if t.device.type == "xpu": + if t.device.type == "cuda": if t.device.index == -1: - devices.append(f"xpu:{torch.xpu.current_device()}") - elif t.device.index == torch.xpu.current_device(): - devices.append("xpu") + 
devices.append(f"cuda:{torch.cuda.current_device()}") + elif t.device.index == torch.cuda.current_device(): + devices.append("cuda") for device in devices: self.assertIs(t, t.to(device, non_blocking=non_blocking)) self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking)) @@ -83,22 +637,22 @@ def test_data_ptr(getter): test_data_ptr(lambda nt: nt.data_ptr()) - if torch.xpu.is_available(): + if torch.cuda.is_available(): for non_blocking in [True, False]: - for xpu in [ - "xpu", - "xpu:0" if torch.xpu.device_count() == 1 else "xpu:1", + for cuda in [ + "cuda", + "cuda:0" if torch.cuda.device_count() == 1 else "cuda:1", ]: - nt2 = random_nt(xpu, torch.float32, ntensors, (4, 4)) + nt2 = random_nt(cuda, torch.float32, ntensors, (4, 4)) test_copy_behavior(nt2, non_blocking) self.assertEqual( - nt2.device, nt2.to(xpu, non_blocking=non_blocking).device + nt2.device, nt2.to(cuda, non_blocking=non_blocking).device ) self.assertEqual( nt.device, nt2.to("cpu", non_blocking=non_blocking).device ) self.assertEqual( - nt2.device, nt.to(xpu, non_blocking=non_blocking).device + nt2.device, nt.to(cuda, non_blocking=non_blocking).device ) self.assertIs( torch.int32, @@ -115,7 +669,7 @@ def test_data_ptr(getter): self.assertIs(torch.int32, nt2.to(dtype=torch.int32).dtype) self.assertEqual(nt2.device, nt2.to(dtype=torch.int32).device) - def _test_copy_(self): + def test_copy_(self): ntensors = 4 nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) nt_copy = torch.empty_like(nt) @@ -131,11 +685,11 @@ def _test_copy_(self): lambda: nt_error.copy_(nt), ) - if torch.xpu.is_available(): - nt = random_nt(torch.device("xpu"), torch.float32, ntensors, (4, 4)) + if torch.cuda.is_available(): + nt = random_nt(torch.device("cuda"), torch.float32, ntensors, (4, 4)) nt_copy = torch.empty_like(nt, device=torch.device("cpu")) nt_copy.copy_(nt, non_blocking=True) - torch.xpu.current_stream(torch.xpu.current_device()).synchronize() + torch.cuda.current_stream(torch.cuda.current_device()).synchronize() for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy): self.assertEqual(nt_ub, nt_copy_ub) @@ -144,389 +698,6173 @@ def _test_copy_(self): for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy): self.assertEqual(nt_ub, nt_copy_ub) - @skipMeta - def _test_device_checks(self, device): - nt = torch.nested.nested_tensor([], device=device) - is_xpu = "xpu" in str(device) - self.assertEqual(nt.is_xpu, is_xpu) + def test_fill_(self): + ntensors = 4 + nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) + nt.fill_(10.0) + for nt_ub in nt.unbind(): + t = torch.empty_like(nt_ub) + t.fill_(10.0) + self.assertEqual(nt_ub, t) - @dtypes(torch.float, torch.float16, torch.double) - def _test_empty_like(self, device, dtype): + fill_tensor = torch.tensor([11.0]) + self.assertRaisesRegex( + RuntimeError, + "fill_ only supports 0-dimension value tensor", + lambda: nt.fill_(fill_tensor), + ) + + nt.fill_(fill_tensor[0]) + for nt_ub in nt.unbind(): + t = torch.empty_like(nt_ub) + t.fill_(11.0) + self.assertEqual(nt_ub, t) + + def test_zero_(self): ntensors = 4 - nt = random_nt(device, dtype, ntensors, (4, 4)) + nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) + nt.zero_() + for nt_ub in nt.unbind(): + t = torch.empty_like(nt_ub) + t.fill_(0.0) + self.assertEqual(nt_ub, t) - # Create empty on same device as original nested tensor - nt_empty = torch.empty_like(nt) - assert nt.is_same_size(nt_empty) - self.assertEqual(nt.dtype, nt_empty.dtype) - self.assertEqual(nt.device, nt_empty.device) - 
self.assertEqual(nt.layout, nt_empty.layout) + @parametrize( + "func", + [torch.ones_like, torch.zeros_like, torch.randn_like], + name_fn=lambda f: f.__name__, + ) + def test_like_functions(self, func): + ntensors = 4 + nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4)) + torch.manual_seed(1) + nt_like = func(nt) - if torch.xpu.is_available(): - if device == "cpu": - nt_xpu = torch.empty_like(nt, device="xpu") - self.assertEqual(torch.device("xpu").type, nt_xpu.device.type) - else: - nt_cpu = torch.empty_like(nt, device="cpu") - self.assertEqual(torch.device("cpu").type, nt_cpu.device.type) + torch.manual_seed(1) + for nt_ub in nt_like.unbind(): + t_like = func(nt_ub) + self.assertEqual(nt_ub, t_like) - # Check changing dtype of empty_like nested tensor output - dtype_set = {torch.float, torch.float16, torch.double} - for other_dtype in dtype_set - {dtype}: - nt_empty_other_dtype = torch.empty_like(nt, dtype=other_dtype) - self.assertEqual(nt.dtype, dtype) - self.assertEqual(nt_empty_other_dtype.dtype, other_dtype) - self.assertEqual(nt.device, nt_empty.device) - self.assertEqual(nt.layout, nt_empty.layout) + def test_cat(self): + # dim=0 success case + # No constraints on ragged structures matching. + x = random_nt_from_dims([5, None, 10]) + y = random_nt_from_dims([3, 4, None]) + output = torch.cat([x, y], dim=0) + for out_component, xy_component in zip( + output.unbind(), itertools.chain(x.unbind(), y.unbind()) + ): + self.assertEqual(out_component, xy_component) - # Create tensor for autograd - nt_empty_req_grad = torch.empty_like(nt, requires_grad=True) - self.assertEqual(nt_empty_req_grad.requires_grad, True) + # dim=-1 success case + # shape (B, *, D) + x = random_nt_from_dims([5, None, 10]) + # shape (B, *, D'); same structure as x but dim=-1 differs + y = random_nt_from_similar(x, dims=[-1, -1, 8]) + # should be shape (B, *, D + D') when supported + output = torch.cat([x, y], dim=-1) + for out_component, x_component, y_component in zip( + output.unbind(), x.unbind(), y.unbind() + ): + self.assertEqual( + out_component, torch.cat([x_component, y_component], dim=-1) + ) - # Test noncontiguous tensor does not fail to copy - nt_cont, nt_noncont = random_nt_noncontiguous_pair((2, 3, 6, 7)) - nt_empty = torch.empty_like(nt_cont) - assert nt_cont.is_same_size(nt_empty) - nt_empty_non_contig = torch.empty_like(nt_noncont) - assert nt_noncont.is_same_size(nt_empty_non_contig) + # dim between 0 and -1 success case + x = random_nt_from_dims([5, None, 2, 3]) + # same structure as x but dim=2 differs + y = random_nt_from_similar(x, dims=[-1, -1, 4, -1]) + output = torch.cat([x, y], dim=2) + for out_component, x_component, y_component in zip( + output.unbind(), x.unbind(), y.unbind() + ): + self.assertEqual( + out_component, torch.cat([x_component, y_component], dim=1) + ) - # Test the contiguous memory format option - nt_empty_contig = torch.empty_like( - nt_cont, memory_format=torch.contiguous_format - ) - assert nt_cont.is_same_size(nt_empty_contig) - assert nt_empty_contig.is_contiguous() + # error case: mixed NT / dense inputs + x = random_nt_from_dims([5, None, 2]) + y = torch.randn(5, 3, 2) + with self.assertRaisesRegex( + RuntimeError, "expected each tensor in given list to be nested" + ): + torch.cat([x, y], dim=-1) - nt_empty_non_contig = torch.empty_like( - nt_noncont, memory_format=torch.contiguous_format - ) - assert nt_noncont.is_same_size(nt_empty_non_contig) - assert nt_empty_non_contig.is_contiguous() + # error case: NTs with different dims + x = 
random_nt_from_dims([5, None, 2]) + y = random_nt_from_dims([5, None, 2, 3]) + with self.assertRaisesRegex( + RuntimeError, + "expected all nested tensors to have matching ragged structures outside of the concatenated dim", + ): + torch.cat([x, y], dim=-1) - # Test other memory formats fail - self.assertRaises( + # error case: non-contiguous NT + x, y = random_nt_noncontiguous_pair((2, 3, 4), dtype=torch.float32) + # transpose to put ragged dim next to batch dim + x, y = x.transpose(-2, -1), y.transpose(-2, -1) + with self.assertRaisesRegex( + RuntimeError, "only contiguous nested tensors are supported" + ): + torch.cat([x, y], dim=-1) + + # error case: multiple ragged dims in inputs + x = random_nt_from_dims([5, None, None, 2]) + y = random_nt_from_similar(x) + with self.assertRaisesRegex( RuntimeError, - lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last), - ) - self.assertRaises( + "only nested tensors with a single ragged dim next to the batch dim are supported", + ): + torch.cat([x, y], dim=-1) + + # error case: ragged dim not next to batch dim + x = random_nt_from_dims([5, 2, None]) + y = random_nt_from_similar(x) + with self.assertRaisesRegex( RuntimeError, - lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last), + "only nested tensors with a single ragged dim next to the batch dim are supported", + ): + torch.cat([x, y], dim=1) + + # error case: NTs with different batch sizes + x = random_nt_from_dims([5, None, 2]) + y = random_nt_from_dims([3, None, 2]) + with self.assertRaisesRegex( + RuntimeError, + "expected all nested tensors to have matching ragged structures outside of the concatenated dim", + ): + torch.cat([x, y], dim=-1) + + # error case: NTs with different ragged structures + x = torch.nested.nested_tensor( + [ + torch.randn(2, 6), + torch.randn(4, 6), + torch.randn(5, 6), + ] ) - self.assertRaises( + y = torch.nested.nested_tensor( + [ + torch.randn(5, 6), + torch.randn(4, 6), + torch.randn(2, 6), + ] + ) + with self.assertRaisesRegex( RuntimeError, - lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last_3d), + "expected all nested tensors to have matching ragged structures outside of the concatenated dim", + ): + torch.cat([x, y], dim=-1) + + # https://github.com/pytorch/pytorch/issues/161812 + def test_jagged_with_dim_error(self): + x = torch.nested.nested_tensor( + [torch.ones(3, 2, 3), torch.ones(4, 2, 3)], layout=torch.jagged ) - self.assertRaises( + with self.assertRaisesRegex( RuntimeError, - lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last_3d), + "not supported for NestedTensor on dim=0", + ): + torch.cat([x, x]) + with self.assertRaisesRegex( + RuntimeError, + "not supported for NestedTensor on dim=0", + ): + torch.stack([x, x]) + + def test_nested_view_from_buffer_overflow_errors(self): + buffer = torch.tensor([1]) + sizes = torch.tensor([[2**63 - 1], [2**63 - 1], [3]], dtype=torch.int64) + strides = torch.tensor( + [[0x41414141], [0x41414141], [0x41414141]], dtype=torch.int64 ) + offsets = torch.tensor( + [[0x41414141], [0x41414141], [0x41414141]], dtype=torch.int64 + ) + with self.assertRaisesRegex( + RuntimeError, + r"Storage size calculation overflowed with sizes=\[9223372036854775807\] and strides=\[1094795585\]", + ): + nt = torch._nested_view_from_buffer(buffer, sizes, strides, offsets) - @dtypes(torch.float32) - def _test_linear_backward_memory_usage(self, device, dtype): - # Verify that linear_backward() doesn't use more memory than it should - # for higher dim input sizes. 
- # See https://github.com/pytorch/pytorch/issues/141112 - B, D, max_seq_len = 64, 512, 100 - m = torch.nn.Linear(D, D, device=device) - nt = torch.nested.as_nested_tensor( - [ - torch.rand(size=[seq_len, D]) - for seq_len in torch.randint(max_seq_len, size=(B,)) - ], - layout=torch.jagged, - device=device, + +@markDynamoStrictTest +class TestNestedTensorDeviceType(NestedTensorTestCase): + # Helper function to generate a pair of random nested tensors + # the 2 nested tensors have same shapes + def random_nt_pair(self, device, dtype, num_tensors, max_dims): + ts1 = [] + ts2 = [] + for _ in range(num_tensors): + tensor_dims = tuple( + [ + torch.randint(low=0, high=max_dim, size=(1,)).item() + for max_dim in max_dims + ] + ) + t1 = torch.randn(tensor_dims, device=device, dtype=dtype) + t2 = torch.randn(tensor_dims, device=device, dtype=dtype) + ts1.append(t1) + ts2.append(t2) + return ( + torch.nested.nested_tensor(ts1, device=device, dtype=dtype), + torch.nested.nested_tensor(ts2, device=device, dtype=dtype), ) - # (B, j1, D) -> (B, j1, 1, D) for a higher dim input size - nt = nt.unsqueeze(-2) - # linear_backward() should not explode the max memory usage - torch.xpu.reset_max_memory_allocated() - m(nt).sum().backward() - # expect under a GB for max memory allocated - max_after_gb = torch.xpu.max_memory_allocated(0) // (1024**3) - self.assertEqual(max_after_gb, 0) + @dtypes(*floating_types_and_half()) + def test_detach(self, device, dtype): + a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=False) + b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=False) + x = torch.nested.nested_tensor([a, b], requires_grad=True) - @dtypes(torch.float32) - def _test_record_stream(self, device, dtype): - def _create_nt(): - values = torch.ones(1024, 4 * 1024, device="xpu") - offsets = torch.tensor([0, 500, 1024], device="xpu", dtype=torch.int64) - lengths = offsets.diff() - nt = torch.nested.nested_tensor_from_jagged(values, offsets, lengths) - data_ptrs = { - nt._values.data_ptr(), - nt._offsets.data_ptr(), - nt._lengths.data_ptr(), - } - return nt, data_ptrs - - def fn(record_stream): - nt, data_ptrs = _create_nt() - s = torch.xpu.Stream() - - with torch.xpu.stream(s): - # emulate doing something long via sleep - per_ms = 2e7 - torch.xpu._sleep(int(per_ms * 100)) - if record_stream: - nt.record_stream(s) - return data_ptrs + x_detach = x.detach() - # expect memory reuse when record_stream() is not run - data_ptrs = fn(record_stream=False) - nt, nt_data_ptrs = _create_nt() - self.assertEqual(data_ptrs, nt_data_ptrs) - del nt - torch.xpu.synchronize() - - # expect memory to be preserved (no reuse) when record_stream() is run - data_ptrs = fn(record_stream=True) - nt, nt_data_ptrs = _create_nt() - self.assertEqual(len(data_ptrs.intersection(nt_data_ptrs)), 0) + z = x_detach * 4 + self.assertFalse(x_detach.requires_grad) + self.assertFalse(z.requires_grad) - @dtypes(torch.float32) - def _test_construction_from_list(self, device, dtype): - from torch.fx.experimental.symbolic_shapes import is_nested_int + a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=True) + b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=True) + x = torch.nested.as_nested_tensor([a, b]) - # success case: single ragged dim anywhere but the batch dim - for nt_dim in [2, 3, 4]: - for ragged_dim in range(1, nt_dim): - B = 6 - shapes = [list(range(3, 3 + nt_dim - 1)) for _ in range(B)] - for b in range(B): - # subtract 1 to convert to component dim space - shapes[b][ragged_dim - 1] = 
torch.randint( - 2, 9, (1,), device=device, dtype=torch.int64 - ).item() + y = x * 2 + y = y.detach() + self.assertFalse(y.requires_grad) + self.assertIsNone(y.grad_fn) - components = [ - torch.randn(shape, device=device, dtype=dtype) for shape in shapes - ] - nt = torch.nested.nested_tensor(components, layout=torch.jagged) + z = x + y + torch.nested.to_padded_tensor(z, 0).sum().backward() + # This is an incorrect gradient, but we assume that's what the user + # wanted. detach() is an advanced option. + self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype)) + self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype)) - self.assertEqual(nt.dim(), nt_dim) - self.assertEqual(nt._ragged_idx, ragged_dim) - for d in range(nt_dim): - self.assertEqual(d == ragged_dim, is_nested_int(nt.shape[d])) + @dtypes(torch.float, torch.double, torch.half) + @parametrize("requires_grad", [False, True]) + @parametrize("weights_only", [False, True]) + def test_serialization(self, device, dtype, requires_grad, weights_only): + def compare_metadata(nt1, nt2): + self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size()) + self.assertEqual(nt1._nested_tensor_strides(), nt2._nested_tensor_strides()) + self.assertEqual( + nt1._nested_tensor_storage_offsets(), + nt2._nested_tensor_storage_offsets(), + ) - # error case: empty list - with self.assertRaisesRegex( - RuntimeError, "Cannot construct a nested tensor from an empty tensor list" - ): - torch.nested.nested_tensor([], layout=torch.jagged) + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7)) + for a in [nt_contiguous, nt_noncontiguous]: + buffer = io.BytesIO() + serialized = torch.save(a, buffer) + buffer.seek(0) + b = torch.load(buffer, weights_only=weights_only) + # should be both conceptually equal and metadata equivalent + self.assertEqual(a, b) + compare_metadata(a, b) + # should be conceptually equal but not necessarily metadata equivalent + self.assertEqual(b, nt_contiguous) + self.assertEqual(b, nt_noncontiguous) - # error case: list of zero-dim tensors - with self.assertRaisesRegex( - RuntimeError, - "Cannot construct a nested tensor from a list of zero-dim tensors", - ): - torch.nested.nested_tensor( - [ - torch.tensor(3.0, device=device, dtype=dtype), - torch.tensor(4.0, device=device, dtype=dtype), - torch.tensor(5.0, device=device, dtype=dtype), - ], - layout=torch.jagged, - ) + @dtypes(torch.float, torch.float16, torch.double) + def test_unbind_noncontiguous(self, device, dtype): + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device, dtype + ) + ub_contiguous = nt_contiguous.unbind() + ub_noncontiguous = nt_noncontiguous.unbind() + self.assertEqual(len(ub_contiguous), len(ub_noncontiguous)) + n = len(ub_contiguous) + for i in range(n): + self.assertEqual(ub_contiguous[i], ub_noncontiguous[i]) - # error case: multiple ragged dims - with self.assertRaisesRegex( - RuntimeError, - "Cannot represent given tensor list as a nested tensor with the jagged layout", - ): - torch.nested.nested_tensor( - [ - torch.randn(2, 3, device=device, dtype=dtype), - torch.randn(4, 5, device=device, dtype=dtype), - ], - layout=torch.jagged, - ) + @dtypes(torch.float) + @skipMeta + def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype): + t = torch.randn(4, 4, 4, device=device, dtype=dtype) + ts = list(torch.unbind(t)) + ts[0] = ts[0][:-1] + nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) + padded = torch.nested.to_padded_tensor(nt, 0) - # 
error case: components on multiple devices - if "xpu" in device: - with self.assertRaisesRegex( - RuntimeError, - "When constructing a nested tensor, all tensors in list must be on the same device", - ): - torch.nested.nested_tensor( - [ - torch.randn(2, 3, device=device, dtype=dtype), - torch.randn(2, 4, device="cpu", dtype=dtype), - ], - layout=torch.jagged, - ) + nt_to = torch._nested_from_padded_and_nested_example(padded, nt) - # error case: components with multiple dtypes - with self.assertRaisesRegex( - RuntimeError, - "When constructing a nested tensor, all tensors in list must have the same dtype", - ): - torch.nested.nested_tensor( - [ - torch.randn(2, 3, device=device, dtype=dtype), - torch.randn(2, 4, device=device, dtype=torch.float64), - ], - layout=torch.jagged, - ) + for t1, t2 in zip(nt.unbind(), nt_to.unbind()): + self.assertEqual(t1, t2) + self.assertEqual(nt.device, nt_to.device) - # error case: components with multiple dims - with self.assertRaisesRegex( - RuntimeError, - "When constructing a nested tensor, all tensors in list must have the same dim", - ): - torch.nested.nested_tensor( - [ - torch.randn(2, 3, device=device, dtype=dtype), - torch.randn(2, 3, 4, device=device, dtype=dtype), - ], - layout=torch.jagged, - ) + @dtypes(torch.float) + @dtypesIfCUDA(torch.float, torch.half) + @skipMeta + @torch.inference_mode() + def test_layer_norm(self, device, dtype): + def _test(size): + # Simple shapes test + t0 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False) + t1 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False) + ts = [t0, t1, t0, t1] + nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) + layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype) + nt_result = layer_norm(nt) + for nt_subresult, t in zip(nt_result.unbind(), ts): + t_result = layer_norm(t.reshape(1, -1, size).squeeze(0)) + self.assertEqual(nt_subresult, t_result) - def _test_index_put_error(self, device): - import subprocess + # More complex nt test with different lengths for each tensor + t0 = torch.randn(4, size, device=device, dtype=dtype, requires_grad=False) + t1 = torch.randn(10, size, device=device, dtype=dtype, requires_grad=False) + t2 = torch.randn(7, size, device=device, dtype=dtype, requires_grad=False) + ts = [t0, t1, t2, t0, t2] + nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) + layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype) + nt_result = layer_norm(nt) + for nt_subresult, t in zip(nt_result.unbind(), ts): + t_result = layer_norm(t.reshape(1, -1, size).squeeze(0)) + self.assertEqual(nt_subresult, t_result) - with self.subTest(): - r = subprocess.call( - [ - sys.executable, - "-c", - """\ -import torch -offsets = torch.tensor([0, 2, 5, 7], device='xpu') -lengths = torch.tensor([2, 2, 2], device='xpu') -indices = [ - torch.tensor([0, 1, 2], device='xpu'), - torch.tensor([0, 2, 1], device='xpu'), - torch.tensor([0, 0, 0], device='xpu'), -] -a = torch.nested.nested_tensor_from_jagged( - torch.zeros(7, 3, device='xpu'), offsets, lengths -) -a[indices] = 1.0 -torch.xpu.synchronize() -""", - ] - ) - self.assertTrue(r != 0) + if size <= 128: + # Test with multidimensional tensors after irregular dim + # (run only with smaller dimensions to ensure fast execution) + t0 = torch.randn( + 4, size, size, 4, device=device, dtype=dtype, requires_grad=False + ) + t1 = torch.randn( + 10, size, size, 4, device=device, dtype=dtype, requires_grad=False + ) + t2 = torch.randn( + 7, size, size, 4, device=device, 
dtype=dtype, requires_grad=False + ) + ts = [t0, t1, t2, t0, t2] + nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) + layer_norm = torch.nn.LayerNorm( + (size, size, 4), device=device, dtype=dtype + ) + nt_result = layer_norm(nt) + for nt_subresult, t in zip(nt_result.unbind(), ts): + t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0)) + self.assertEqual(nt_subresult, t_result) - @dtypes(torch.float16, torch.bfloat16, torch.float32) - def _test_sdpa(self, device, dtype): - batch_size = 1 - emb_dims = 128 - n_heads = 8 - head_dims = emb_dims // n_heads + # Test where the normalizing dimensions are not all + layer_norm = torch.nn.LayerNorm((size, 4), device=device, dtype=dtype) + nt_result = layer_norm(nt) + for nt_subresult, t in zip(nt_result.unbind(), ts): + t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0)) + self.assertEqual(nt_subresult, t_result) - sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device) - sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device) + for size in (1024, 1023, 513, 512, 256, 128, 2, 4, 32): + _test(size) - query = torch.nn.Linear( - emb_dims, emb_dims, bias=False, device=device, dtype=dtype - ) - key = torch.nn.Linear( - emb_dims, emb_dims, bias=False, device=device, dtype=dtype + @dtypes(torch.float) + @dtypesIfCUDA(torch.float, torch.half) + @skipMeta + @torch.inference_mode() + def test_layer_norm_breaking(self, device, dtype): + size = 128 + t0 = torch.randn( + 4, size, size, 4, device=device, dtype=dtype, requires_grad=False ) - value = torch.nn.Linear( - emb_dims, emb_dims, bias=False, device=device, dtype=dtype + t1 = torch.randn( + 10, size, size, 4, device=device, dtype=dtype, requires_grad=False ) - - # Simplest case: 1 sentence, no batching - x_d1 = sen1.unsqueeze(0) - x_nt = torch.nested.as_nested_tensor([sen1], layout=torch.jagged) - - # See note below for why we detach here. 
- q_d1 = ( - query(x_d1) - .view(batch_size, -1, n_heads, head_dims) - .detach() - .requires_grad_(True) + t2 = torch.randn( + 7, size, size, 4, device=device, dtype=dtype, requires_grad=False ) - q_d1_t = q_d1.transpose(1, 2) - k_d1 = ( - key(x_d1) - .view(batch_size, -1, n_heads, head_dims) - .detach() - .requires_grad_(True) + ts = [t0, t1, t2, t0, t2] + nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) + layer_norm = torch.nn.LayerNorm((4, size, size, 4), device=device, dtype=dtype) + self.assertRaisesRegex( + RuntimeError, + "normalized_shape extends into irregular dimensions for the nested tensor", + lambda: layer_norm(nt), ) - k_d1_t = k_d1.transpose(1, 2) - v_d1 = ( - value(x_d1) - .view(batch_size, -1, n_heads, head_dims) - .detach() - .requires_grad_(True) + layer_norm = torch.nn.LayerNorm((size + 1, size, 4), device=device, dtype=dtype) + self.assertRaisesRegex( + RuntimeError, + "The shape at dimension 0", + lambda: layer_norm(nt), ) - v_d1_t = v_d1.transpose(1, 2) - q_nt = ( - query(x_nt) - .view(*x_nt.size()[0:2], n_heads, head_dims) - .detach() - .requires_grad_(True) + @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name) + def test_embedding(self, device, layout): + inputs = [ + torch.randint(100, (L,), device=device, dtype=torch.int64) + for L in torch.randint(5, 50, (8,)) + ] + x = torch.nested.nested_tensor( + inputs, device=device, dtype=torch.int64, layout=layout ) - q_nt_t = q_nt.transpose(1, 2) - k_nt = ( - key(x_nt) - .view(*x_nt.size()[0:2], n_heads, head_dims) - .detach() - .requires_grad_(True) + emb = torch.nn.Embedding(100, 8, device=device) + y = emb(x) + if layout == torch.jagged: + y.backward(torch.randn_like(y)) + + @torch._dynamo.disable + def check(inputs, y): + ys = y.unbind() + for i, inp in enumerate(inputs): + self.assertEqual(emb(inp), ys[i]) + + check(inputs, y) + + @dtypes( + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.float, + torch.float16, + torch.bfloat16, + torch.double, + ) + def test_jagged_max_dtypes(self, device, dtype): + x = torch.nested.nested_tensor( + [torch.arange(0, n, dtype=dtype, device=device) for n in (10, 20, 30)], + layout=torch.jagged, ) - k_nt_t = k_nt.transpose(1, 2) - v_nt = ( - value(x_nt) - .view(*x_nt.size()[0:2], n_heads, head_dims) - .detach() - .requires_grad_(True) + + result_max = x.max(dim=1) + expected_max = torch.tensor([9, 19, 29], dtype=dtype, device=device) + + self.assertEqual(result_max.values, expected_max) + + @dtypes( + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.float, + torch.float16, + torch.bfloat16, + torch.double, + ) + def test_jagged_min_dtypes(self, device, dtype): + x = torch.nested.nested_tensor( + [torch.arange(0, n, dtype=dtype, device=device) for n in (10, 20, 30)], + layout=torch.jagged, ) - v_nt_t = v_nt.transpose(1, 2) - # High Precision Math Reference - q_d1_f32 = q_d1.to(torch.float32) - k_d1_f32 = k_d1.to(torch.float32) - v_d1_f32 = v_d1.to(torch.float32) - q_d1_f32_t = q_d1_f32.transpose(1, 2) - k_d1_f32_t = k_d1_f32.transpose(1, 2) - v_d1_f32_t = v_d1_f32.transpose(1, 2) - out_ref = torch.ops.aten._scaled_dot_product_attention_math( - q_d1_f32_t, k_d1_f32_t, v_d1_f32_t - )[0] - grads_ref = torch.autograd.grad(out_ref.sum(), (q_d1_f32, k_d1_f32, v_d1_f32)) + result_min = x.min(dim=1) + expected_min = torch.tensor([0, 0, 0], dtype=dtype, device=device) - # Low Precision Math Reference - out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( - q_d1_t, k_d1_t, v_d1_t - 
)[0] - grads_lp_ref = torch.autograd.grad(out_lp_ref.sum(), (q_d1, k_d1, v_d1)) + self.assertEqual(result_min.values, expected_min) - # Compute tolerances - output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) - # fudge factor of 1.7 for smaller GPUs e.g., A2, A16 - grad_q_ref_atol, grad_q_ref_rtol = get_tolerances( - grads_ref[0], grads_lp_ref[0], 1.7 + @dtypes( + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.float, + torch.float16, + torch.bfloat16, + torch.double, + ) + def test_jagged_amax_dtypes(self, device, dtype): + x = torch.nested.nested_tensor( + [torch.arange(0, n, dtype=dtype, device=device) for n in (10, 20, 30)], + layout=torch.jagged, ) - grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(grads_ref[1], grads_lp_ref[1]) - grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(grads_ref[2], grads_lp_ref[2]) - grad_atols = [grad_q_ref_atol, grad_k_ref_atol, grad_v_ref_atol] - grad_rtols = [grad_q_ref_rtol, grad_k_ref_rtol, grad_v_ref_rtol] - attn_d1 = torch.nn.functional.scaled_dot_product_attention( - q_d1_t, k_d1_t, v_d1_t - ).transpose(1, 2) - attn_nt = torch.nn.functional.scaled_dot_product_attention( - q_nt_t, k_nt_t, v_nt_t - ).transpose(1, 2) + result_amax = x.amax(dim=1) + expected_amax = torch.tensor([9, 19, 29], dtype=dtype, device=device) - self.assertEqual( - attn_d1, - attn_nt.unbind()[0].unsqueeze(0), - atol=output_ref_atol, - rtol=output_ref_rtol, + self.assertEqual(result_amax, expected_amax) + + @dtypes( + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.float, + torch.float16, + torch.bfloat16, + torch.double, + ) + def test_jagged_amin_dtypes(self, device, dtype): + x = torch.nested.nested_tensor( + [torch.arange(0, n, dtype=dtype, device=device) for n in (10, 20, 30)], + layout=torch.jagged, ) - # Simple case: 2 sentences, no extra params - x_d2 = sen2.unsqueeze(0) - x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged) + result_amin = x.amin(dim=1) + expected_amin = torch.tensor([0, 0, 0], dtype=dtype, device=device) - # NB: we make sure the leaf tensor we compute gradients for is the view-ed tensor before - # it is transposed. 
This is because today we cannot backward through view or unbind a + self.assertEqual(result_amin, expected_amin) + + @dtypes( + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.float, + torch.float16, + torch.bfloat16, + torch.double, + ) + def test_jagged_argmax_dtypes(self, device, dtype): + x = torch.nested.nested_tensor( + [torch.arange(0, n, dtype=dtype, device=device) for n in (10, 20, 30)], + layout=torch.jagged, + ) + + result_argmax = x.argmax(dim=1) + expected_argmax = torch.tensor([9, 19, 29], dtype=torch.long, device=device) + + self.assertEqual(result_argmax, expected_argmax) + + @dtypes( + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + torch.float, + torch.float16, + torch.bfloat16, + torch.double, + ) + def test_jagged_argmin_dtypes(self, device, dtype): + x = torch.nested.nested_tensor( + [torch.arange(0, n, dtype=dtype, device=device) for n in (10, 20, 30)], + layout=torch.jagged, + ) + + result_argmin = x.argmin(dim=1) + expected_argmin = torch.tensor([0, 0, 0], dtype=torch.long, device=device) + + self.assertEqual(result_argmin, expected_argmin) + + @skipMeta + @torch.inference_mode() + @dtypes(*floating_types_and_half()) + def test_masked_fill(self, device, dtype): + # nested tensor * nested tensor + (nt, mask) = self.random_nt_pair(device, dtype, 4, (4, 4)) + mask = torch.nested.nested_tensor([m < 0 for m in mask.unbind()]) + ref = torch.nested.nested_tensor( + [t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())] + ) + out = nt.masked_fill(mask, 0) + self.assertEqual(ref, out) + + @dtypes(torch.float, torch.float16) + def test_to_padded_tensor_simple(self, device, dtype): + t = torch.randn(4, 4, 4, device=device, dtype=dtype) + ts = list(torch.unbind(t)) + ts[0] = ts[0][:-1] + nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) + for padding_value in (0, 1): + padded = torch.nested.to_padded_tensor(nt, padding_value) + + correct_output = t.clone() + if padding_value == 0: + correct_output[0][-1] = torch.zeros_like(correct_output[0][-1]) + else: + correct_output[0][-1] = torch.ones_like(correct_output[0][-1]) + + self.assertEqual(padded, correct_output) + self.assertEqual(padded.device, torch.device(device)) + self.assertEqual(padded.dtype, dtype) + + @dtypes(torch.float, torch.float16) + def test_to_padded_tensor_output_size(self, device, dtype): + t = torch.randn(4, 4, 4, device=device, dtype=dtype) + output_size = (4, 6, 5) + ts = list(torch.unbind(t)) + ts[0] = ts[0][:-1] + nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) + for padding_value in (0, 1): + padded = torch.nested.to_padded_tensor( + nt, padding_value, output_size=output_size + ) + correct_output = ( + torch.ones(output_size, device=device, dtype=dtype) * padding_value + ) + correct_output[:4:, :4, :4] = t.clone() + if padding_value == 0: + correct_output[0][3] = torch.zeros_like(correct_output[0][3]) + else: + correct_output[0][3] = torch.ones_like(correct_output[0][3]) + + self.assertEqual(padded, correct_output) + self.assertEqual(padded.device, torch.device(device)) + self.assertEqual(padded.dtype, dtype) + + @dtypes(torch.float, torch.float16, torch.double) + def test_to_padded_tensor_dim2(self, device, dtype): + ts = [ + torch.randn(160, device=device, dtype=dtype), + torch.randn(1240, device=device, dtype=dtype), + torch.randn(2400, device=device, dtype=dtype), + ] + nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) + pad = 42 + correct_output = [] + for t in ts: + next_output = 
torch.ones_like(ts[2]) * pad + correct_output.append(next_output) + next_output[: t.size(0)].copy_(t) + correct_output = torch.stack(correct_output) + padded = torch.nested.to_padded_tensor(nt, pad) + self.assertEqual(padded, correct_output) + + @dtypes(torch.float, torch.float16, torch.double) + def test_to_padded_tensor_dim3(self, device, dtype): + ts = [ + torch.randn(16, 21, device=device, dtype=dtype), + torch.randn(24, 32, device=device, dtype=dtype), + torch.randn(40, 53, device=device, dtype=dtype), + ] + nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) + pad = 42 + correct_output = [] + for t in ts: + next_output = torch.ones_like(ts[2]) * pad + correct_output.append(next_output) + next_output[: t.size(0), : t.size(1)].copy_(t) + correct_output = torch.stack(correct_output) + padded = torch.nested.to_padded_tensor(nt, pad) + self.assertEqual(padded, correct_output) + + @dtypes(torch.float, torch.float16, torch.double) + def test_to_padded_tensor_dim4(self, device, dtype): + ts = [ + torch.randn(16, 21, 13, device=device, dtype=dtype), + torch.randn(24, 32, 14, device=device, dtype=dtype), + torch.randn(40, 53, 16, device=device, dtype=dtype), + ] + nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) + pad = 42 + correct_output = [] + for t in ts: + next_output = torch.ones_like(ts[2]) * pad + correct_output.append(next_output) + next_output[: t.size(0), : t.size(1), : t.size(2)].copy_(t) + correct_output = torch.stack(correct_output) + padded = torch.nested.to_padded_tensor(nt, pad) + self.assertEqual(padded, correct_output) + + # TODO: test noncontiguous to_padded_tensor + # For now this tests the functionality of noncontiguous_to_padded_tensor + # and the error message of to_padded_tensor + # since to_padded_tensor does not support noncontiguous buffer yet + @dtypes(torch.float, torch.float16, torch.double) + @torch.inference_mode() + def test_to_padded_tensor_noncontiguous(self, device, dtype): + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device, dtype + ) + # test noncontiguous_to_padded_tensor functionality + self.assertEqual( + torch.nested.to_padded_tensor(nt_contiguous, 0.0), + noncontiguous_to_padded_tensor(nt_noncontiguous), + ) + # test to_padded_tensor error message + self.assertRaisesRegex( + RuntimeError, + r"for now to_padded_tensor only supports contiguous nested tensor", + lambda: torch.nested.to_padded_tensor(nt_noncontiguous, 0.0), + ) + + @skipMeta + def test_device_checks(self, device): + nt = torch.nested.nested_tensor([], device=device) + is_cuda = "cuda" in str(device) + self.assertEqual(nt.is_cuda, is_cuda) + + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") + def test_share_memory(self, device): + a = torch.randn(3, 4, device=device) + b = torch.randn(5, 4, device=device) + nt = torch.nested.nested_tensor([a, b], layout=torch.jagged) + + # Guard CUDA tensors + if device.split(":")[0] in ["cuda", "xpu"]: + result = nt.share_memory_() + self.assertIs(result, nt) + return + + result = nt.share_memory_() + self.assertIs(result, nt) + + # Verify in shared memory + self.assertTrue(nt.is_shared()) + + @dtypes(torch.float, torch.float16, torch.double) + def test_nested_tensor_indexing(self, device, dtype): + # edge case: empty nested tensor + nt0 = torch.nested.nested_tensor([]) + self.assertRaises(IndexError, lambda: nt0[0]) + # normal case + x0 = torch.randn((2, 5), device=device, dtype=dtype) + x1 = torch.randn((3, 4), device=device, dtype=dtype) + nt = torch.nested.nested_tensor([x0, 
x1]) + # single index: only support integer in the batch dimension + self.assertEqual(nt[0], x0) + self.assertEqual(nt[-1], x1) + self.assertRaises(IndexError, lambda: nt[2]) + self.assertRaises(IndexError, lambda: nt[-3]) + self.assertRaises(NotImplementedError, lambda: nt[:]) + self.assertEqual(nt[...], nt) + # tuple of indices: only support integer in the batch dimension + # + all possible indexing in the original tensor dimensions + self.assertEqual(nt[0, 0, 0], x0[0, 0]) + self.assertEqual(nt[0, 1, :], x0[1, :]) + self.assertEqual(nt[1, ...], x1) + self.assertRaises(IndexError, lambda: nt[1, 4, 2]) + self.assertRaises(NotImplementedError, lambda: nt[:, 1, 1]) + # test select on non-batch dimensions + self.assertEqual(nt.select(1, 0)[0], x0.select(0, 0)) + self.assertEqual(nt.select(1, 0)[1], x1.select(0, 0)) + self.assertRaises(IndexError, lambda: nt.select(1, 3)) + self.assertEqual(nt.select(2, 0)[0], x0.select(1, 0)) + self.assertEqual(nt.select(2, 0)[1], x1.select(1, 0)) + self.assertRaises(IndexError, lambda: nt.select(2, 5)) + # make sure indexing returns a view + nt[0].fill_(100.0) + answer = torch.tensor(100.0, device=device, dtype=dtype).expand((2, 5)) + self.assertEqual(nt[0], answer) + nt[1, 1, :].fill_(200.0) + answer = torch.tensor(200.0, device=device, dtype=dtype).expand(4) + self.assertEqual(nt[1, 1, :], answer) + + # Test that indexing works when requires_grad_(True) + # previously this was failing because the backward kernel for select.int uses .sizes() + nt = torch.nested.nested_tensor([x0, x1]).requires_grad_(True) + self.assertEqual(nt[0], x0) + self.assertEqual(nt[-1], x1) + grad_x0 = torch.randn((2, 5), device=device, dtype=dtype) + nt[0].backward(grad_x0) + expected_grad = torch.nested.nested_tensor( + [grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)] + ) + self.assertEqual(nt.grad, expected_grad) + + @parametrize( + "func", + [ + subtest(torch.nn.functional.relu, name="relu"), + subtest(torch.nn.functional.relu_, name="relu_"), + subtest(torch.nn.functional.gelu, name="gelu"), + subtest(torch._C._nn.gelu_, name="gelu_"), + subtest(torch.tanh, name="tanh"), + subtest(torch.tanh_, name="tanh_"), + subtest(torch.neg, name="neg"), + subtest(torch.nn.functional.silu, name="silu"), + subtest(partial(torch.nn.functional.silu, inplace=True), name="silu_"), + subtest(torch.abs, name="abs"), + subtest(torch.abs_, name="abs_"), + subtest(torch.sgn, name="sgn"), + subtest(torch.logical_not, name="logical_not"), + subtest(torch.sin, name="sin"), + subtest(torch.cos, name="cos"), + subtest(torch.isinf, name="isinf"), + subtest(torch.isposinf, name="isposinf"), + subtest(torch.isneginf, name="isneginf"), + subtest(torch.isnan, name="isnan"), + subtest(torch.sqrt, name="sqrt"), + ], + ) + def test_unary_funcs(self, device, func): + nt, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device=device, dtype=torch.float32 + ) + nested_result = func(nt) + self.assertTrue(nested_result.is_nested) + for t, t_res in zip(nt.unbind(), nested_result.unbind()): + self.assertEqual(func(t), t_res) + self.assertRaisesRegex( + RuntimeError, + "NestedTensor must be contiguous to get buffer.", + lambda: func(nt_noncontiguous), + ) + + def test_is_any_true_jagged(self, device): + B, Fin = 2, 6 + start = torch.zeros(B, dtype=torch.int64, device=device) + lengths = torch.tensor([3, 2], dtype=torch.int64, device=device) + + # NestedTensor reduction should operate on same data as .values(). 
+ with self.subTest("dispatch_matches_values_buffer"): + cond = torch.tensor( + [ + [True, False, False, True, True, False], + [False, False, True, False, False, False], + ], + dtype=torch.bool, + device=device, + ) + nt = torch.nested.narrow( + cond, dim=1, start=start, length=lengths, layout=torch.jagged + ) + out_nt = torch.ops.aten._is_any_true.default(nt).item() + out_vals = torch.ops.aten._is_any_true.default(nt.values()).item() + self.assertEqual(out_nt, out_vals) + + # Verify jagged boolean behavior. + with self.subTest("all_false_returns_false"): + cond_false = torch.zeros(B, Fin, dtype=torch.bool, device=device) + nt_false = torch.nested.narrow( + cond_false, dim=1, start=start, length=lengths, layout=torch.jagged + ) + self.assertFalse(torch.ops.aten._is_any_true.default(nt_false).item()) + + with self.subTest("one_true_returns_true"): + cond_mixed = torch.zeros(B, Fin, dtype=torch.bool, device=device) + cond_mixed[0, 0] = True + nt_mixed = torch.nested.narrow( + cond_mixed, dim=1, start=start, length=lengths, layout=torch.jagged + ) + self.assertTrue(torch.ops.aten._is_any_true.default(nt_mixed).item()) + + def test_is_all_true_jagged(self, device): + B, Fin = 2, 6 + start = torch.zeros(B, dtype=torch.int64, device=device) + lengths = torch.tensor([3, 2], dtype=torch.int64, device=device) + + # NestedTensor reduction should operate on same data as .values(). + with self.subTest("dispatch_matches_values_buffer"): + cond = torch.tensor( + [ + [True, True, True, False, False, False], + [True, True, False, False, False, False], + ], + dtype=torch.bool, + device=device, + ) + nt = torch.nested.narrow( + cond, dim=1, start=start, length=lengths, layout=torch.jagged + ) + out_nt = torch.ops.aten._is_all_true.default(nt).item() + out_vals = torch.ops.aten._is_all_true.default(nt.values()).item() + self.assertEqual(out_nt, out_vals) + + # Verify jagged boolean behavior. 
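+        # (an all-True mask should reduce to True, while a single False anywhere in
+        # the jagged values should reduce to False.)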
+ with self.subTest("all_true_returns_true"): + cond_true = torch.ones(B, Fin, dtype=torch.bool, device=device) + nt_true = torch.nested.narrow( + cond_true, dim=1, start=start, length=lengths, layout=torch.jagged + ) + self.assertTrue(torch.ops.aten._is_all_true.default(nt_true).item()) + + with self.subTest("any_false_returns_false"): + cond_mixed = torch.ones(B, Fin, dtype=torch.bool, device=device) + cond_mixed[0, 1] = False + nt_mixed = torch.nested.narrow( + cond_mixed, dim=1, start=start, length=lengths, layout=torch.jagged + ) + self.assertFalse(torch.ops.aten._is_all_true.default(nt_mixed).item()) + + @parametrize("func", [subtest(torch.ge, name="ge"), subtest(torch.eq, name="eq")]) + def test_binary_ops_with_scalar(self, device, func): + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device=device, dtype=torch.float32 + ) + scalar = 0.0 + + # should work regardless of contiguity + for nt in (nt_contiguous, nt_noncontiguous): + nested_result = func(nt, scalar) + self.assertTrue(nested_result.is_nested) + for t, t_res in zip(nt.unbind(), nested_result.unbind()): + self.assertEqual(func(t, scalar), t_res) + + @dtypes(*floating_types_and_half()) + def test_nested_tensor_chunk(self, device, dtype): + # Transformer use case + a = torch.randn(3, 3 * 4, device=device, dtype=dtype) + b = torch.randn(2, 3 * 4, device=device, dtype=dtype) + c = torch.randn(1, 3 * 4, device=device, dtype=dtype) + a_chunks = a.chunk(3, dim=-1) + b_chunks = b.chunk(3, dim=-1) + c_chunks = c.chunk(3, dim=-1) + + a_nt = [a_chunks[0], b_chunks[0], c_chunks[0]] + b_nt = [a_chunks[1], b_chunks[1], c_chunks[1]] + c_nt = [a_chunks[2], b_chunks[2], c_chunks[2]] + + nt = torch.nested.nested_tensor([a, b, c]) + chunked = nt.chunk(3, dim=-1) + + self.assertEqual(chunked[0], torch.nested.nested_tensor(a_nt)) + self.assertEqual(chunked[1], torch.nested.nested_tensor(b_nt)) + self.assertEqual(chunked[2], torch.nested.nested_tensor(c_nt)) + + for chunk in chunked: + self.assertFalse(chunk.is_contiguous()) + + # Failure chunking on ragged dimensions + self.assertRaisesRegex( + RuntimeError, + "Chunk for nested tensors is currently only supported for the last dimension.", + lambda: torch.chunk(nt, 5, dim=1), + ) + self.assertRaisesRegex( + RuntimeError, + "Chunk for nested tensors is currently only supported for the last dimension.", + lambda: torch.chunk(nt, 5, dim=0), + ) + + # Failure on non-contiguous nt + _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) + self.assertRaisesRegex( + RuntimeError, + "chunk expects `self` to be contiguous.", + lambda: torch.chunk(nt_noncontiguous, 5, dim=-1), + ) + + # Failure when calling non divisible n_chunks + self.assertRaisesRegex( + RuntimeError, + "Chunk for nested tensors is only supported for " + "nested tensors with trailing dimension divisible by chunks.", + lambda: torch.chunk(nt, 5, dim=-1), + ) + + # Failure when calling backward on a chunk + a = torch.randn(3, 3 * 4, device=device, dtype=dtype, requires_grad=True) + b = torch.randn(2, 3 * 4, device=device, dtype=dtype, requires_grad=True) + nt_grad = torch.nested.as_nested_tensor([a, b]) + chunked = torch.chunk(nt_grad, 2, dim=-1) + self.assertRaisesRegex( + RuntimeError, + "Nested Strided Tensor doesn't support chunk backward.", + lambda: chunked[0].backward(chunked[0].clone()), + ) + + @dtypes(*floating_types_and_half()) + def test_nested_tensor_split_with_sizes(self, device, dtype): + a = torch.randn(3, 20, device=device, dtype=dtype) + b = torch.randn(2, 20, 
device=device, dtype=dtype) + c = torch.randn(1, 20, device=device, dtype=dtype) + + split_sizes = [4, 6, 10] + a_splits = a.split_with_sizes(split_sizes, dim=-1) + b_splits = b.split_with_sizes(split_sizes, dim=-1) + c_splits = c.split_with_sizes(split_sizes, dim=-1) + + nt = torch.nested.nested_tensor([a, b, c]) + nt_splits = nt.split_with_sizes(split_sizes, dim=-1) + + for i, nt_split in enumerate(nt_splits): + self.assertEqual( + nt_split, + torch.nested.nested_tensor([a_splits[i], b_splits[i], c_splits[i]]), + ) + dense_strides = torch.stack( + [ + torch.tensor(a_splits[i].stride()), + torch.tensor(b_splits[i].stride()), + torch.tensor(c_splits[i].stride()), + ] + ) + self.assertEqual(nt_split._nested_tensor_strides(), dense_strides) + self.assertFalse(nt_split.is_contiguous()) + + # Failure calling on ragged dimensions + self.assertRaisesRegex( + RuntimeError, + "split_with_sizes for nested tensors is currently only supported for the last dimension.", + lambda: torch.split_with_sizes(nt, split_sizes, dim=1), + ) + + # Failure calling on non-last dimension + self.assertRaisesRegex( + RuntimeError, + "split_with_sizes for nested tensors is currently only supported for the last dimension.", + lambda: torch.split_with_sizes(nt, split_sizes, dim=0), + ) + + # Failure on non-contiguous nt + _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) + self.assertRaisesRegex( + RuntimeError, + "split_with_sizes expects `self` to be contiguous.", + lambda: torch.split_with_sizes(nt_noncontiguous, split_sizes, dim=-1), + ) + + # Failure when calling with split_sizes that don't cover the full dim size + bad_split_sizes = [4, 6, 9] # don't add up to 20 + self.assertRaisesRegex( + RuntimeError, + "split_with_sizes expects split_sizes to sum exactly to 20", + lambda: torch.split_with_sizes(nt, bad_split_sizes, dim=-1), + ) + + @dtypes(torch.float, torch.float16, torch.double) + @torch.inference_mode() + def test_nested_tensor_indexing_noncontiguous(self, device, dtype): + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device, dtype + ) + self.assertEqual(nt_contiguous.size(0), nt_noncontiguous.size(0)) + n = nt_contiguous.size(0) + for i in range(n): + self.assertEqual(nt_contiguous[i], nt_noncontiguous[i]) + + @dtypes(torch.float, torch.float16) + @skipMeta + @torch.inference_mode() + @parametrize("transpose", [True, False]) + def test_nested_tensor_add(self, device, dtype, transpose): + if transpose: + a = torch.randn(2, 2, 2, device=device, dtype=dtype) + b = torch.rand(2, 2, 2, device=device, dtype=dtype) + c = a.transpose(-1, -2).contiguous() + d = b.transpose(-1, -2).contiguous() + nt1 = torch.nested.nested_tensor([a, b, a, b]) + nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2) + else: + (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) + ref = torch.nested.nested_tensor( + [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] + ) + out = nt1 + nt2 + self.assertEqual(ref, out) + + @dtypes(torch.float, torch.float16) + @skipMeta + @torch.inference_mode() + @parametrize("transpose", [True, False]) + def test_nested_tensor_sub(self, device, dtype, transpose): + if transpose: + a = torch.randn(2, 2, 2, device=device, dtype=dtype) + b = torch.rand(2, 2, 2, device=device, dtype=dtype) + c = a.transpose(-1, -2).contiguous() + d = b.transpose(-1, -2).contiguous() + nt1 = torch.nested.nested_tensor([a, b, a, b]) + nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2) + else: + (nt1, nt2) = 
self.random_nt_pair(device, dtype, 4, (4, 4)) + ref = torch.nested.nested_tensor( + [t1 - t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] + ) + out = nt1 - nt2 + self.assertEqual(ref, out) + + @onlyOn(["cuda", "xpu"]) + @dtypes(torch.float, torch.float16) + @torch.inference_mode() + @parametrize("embedding_dim", [8, 128, 256, 384]) + def test_nested_tensor_dense_elementwise(self, device, dtype, embedding_dim): + def _test_add_mul(nt, t): + ref_add = torch.nested.nested_tensor( + [t1 + t2 for (t1, t2) in zip(nt.unbind(), t.unbind())] + ) + ref_mul = torch.nested.nested_tensor( + [t1 * t2 for (t1, t2) in zip(nt.unbind(), t.unbind())] + ) + self.assertEqual(nt.add(t), ref_add) + self.assertEqual(nt.mul(t), ref_mul) + + batch_size = 32 + seq_lens = torch.randint(low=0, high=10, size=(batch_size,)) + + # [B, *, D], [B, 1, D] case + ts = [torch.randn((seq_len, embedding_dim)) for seq_len in seq_lens] + nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) + t = torch.randn((batch_size, 1, embedding_dim), device=device, dtype=dtype) + _test_add_mul(nt, t) + + # [B, *], [B, 1] case + ts = [torch.randn(seq_len) for seq_len in seq_lens] + nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) + t = torch.randn((batch_size, 1), device=device, dtype=dtype) + _test_add_mul(nt, t) + + @dtypes(torch.float, torch.float16) + @skipMeta + @torch.inference_mode() + def test_nested_tensor_mul(self, device, dtype): + # nested tensor * nested tensor + (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) + ref = torch.nested.nested_tensor( + [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] + ) + out = nt1 * nt2 + self.assertEqual(ref, out) + # nested tensor * scalar + number = 10.0 + scalar = torch.tensor(number).to(dtype).to(device) + ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()]) + out_number0 = nt1 * number + out_number1 = number * nt1 + out_scalar0 = nt1 * scalar + out_scalar1 = scalar * nt1 + self.assertEqual(out_number0, ref) + self.assertEqual(out_number1, ref) + self.assertEqual(out_scalar0, ref) + self.assertEqual(out_scalar1, ref) + # error case: numel == 1 but dim > 0 + vector = torch.tensor([number]).to(dtype).to(device) + self.assertRaisesRegex( + RuntimeError, + "Expected both self and other to be nested, but got a nested self and non-nested other", + lambda: nt1.mul(vector), + ) + self.assertRaisesRegex( + RuntimeError, + "Expected both self and other to be nested, but got a non-nested self and nested other", + lambda: vector.mul(nt1), + ) + + @dtypes(torch.float, torch.float16) + @skipMeta + @torch.inference_mode() + def test_nested_tensor_div(self, device, dtype): + nt, nt2 = self.random_nt_pair(device, dtype, 4, (4, 4)) + scale = 4.0 + ref = torch.nested.nested_tensor([t / scale for t in nt.unbind()]) + out = nt / 4.0 + self.assertEqual(ref, out) + ref_transposed = ref.transpose(1, 2) + out = nt.transpose(1, 2) / 4.0 + self.assertEqual(ref_transposed, out) + + ref = torch.nested.nested_tensor( + [t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())] + ) + out = nt / nt2 + self.assertEqual(ref, out) + + out = nt.transpose(1, 2) / nt2.transpose(1, 2) + self.assertEqual(ref.transpose(1, 2), out) + + nt_transpose_copy = torch.nested.nested_tensor( + [t.transpose(0, 1) for t in nt.unbind()] + ) + + self.assertRaisesRegex( + RuntimeError, + "div requires strides to match when given NestedTensors", + lambda: nt_transpose_copy.transpose(1, 2) / nt2, + ) + + nt = torch.nested.nested_tensor( + [torch.randn(i, 4) for i in [3, 4, 5]], 
device=device, dtype=dtype + ) + nt_chunks = nt.chunk(2, -1) + self.assertRaisesRegex( + RuntimeError, + "div requires offsets to match when given NestedTensors", + lambda: nt_chunks[0] / nt_chunks[1], + ) + + @dtypes(torch.float, torch.float16) + @skipMeta + @torch.inference_mode() + def test_nested_tensor_add_in_place(self, device, dtype): + (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) + ref = torch.nested.nested_tensor( + [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] + ) + nt1 += nt2 + self.assertEqual(ref, nt1) + + @dtypes(torch.float, torch.float16) + @skipMeta + @torch.inference_mode() + def test_nested_tensor_mul_in_place(self, device, dtype): + # nested tensor * nested tensor + (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) + ref = torch.nested.nested_tensor( + [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())] + ) + nt1 *= nt2 + self.assertEqual(ref, nt1) + # nested tensor * scalar + number = 10.0 + scalar = torch.tensor(number).to(dtype).to(device) + ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()]) + out_number = nt1.clone() + out_number *= number + out_scalar = nt1.clone() + out_scalar *= scalar + self.assertEqual(out_number, ref) + self.assertEqual(out_scalar, ref) + self.assertRaisesRegex( + RuntimeError, + r"output with shape \[.*\] doesn't match the broadcast shape \[.*\]", + lambda: scalar.mul_(nt1), + ) + # error case: numel == 1 but dim > 0 + vector = torch.tensor([number]).to(dtype).to(device) + self.assertRaisesRegex( + RuntimeError, + "Expected both self and other to be nested, but got a nested self and non-nested other", + lambda: nt1.mul_(vector), + ) + self.assertRaisesRegex( + RuntimeError, + "Expected both self and other to be nested, but got a non-nested self and nested other", + lambda: vector.mul_(nt1), + ) + + @onlyCPU + @skipMeta + @dtypes(torch.float) + def test_nested_tensor_sum_dim(self, device, dtype): + params = ((2, (1, 1)), ((4), (4, 4)), (10, (3, 5, 7))) + + def test_sum(device, dtype, ntensors, max_sizes, dim, keepdim=True): + nt = random_nt(device, dtype, ntensors, max_sizes, require_non_empty=False) + nt2 = nt.clone() + ub2 = nt2.unbind() + nt.requires_grad_(True) + [t.requires_grad_(True) for t in ub2] + nt_sum = nt.sum(dim=dim, keepdim=keepdim) + ub2_sum = [t.sum(-1, keepdim=keepdim) for t in ub2] + self.assertEqual(nt_sum, torch.nested.nested_tensor(ub2_sum)) + + # test backward + # generate gradient tensor that has the same size as the output + size = nt_sum._nested_tensor_size() + gt2 = [] + for i in range(ntensors): + gt2.append(torch.randn(size[i].tolist(), device=device, dtype=dtype)) + gt = torch.nested.nested_tensor(gt2).clone() + nt_sum.backward(gt) + for t2, g2 in zip(ub2_sum, gt2): + t2.backward(g2) + self.assertEqual(nt.grad, torch.nested.nested_tensor([t.grad for t in ub2])) + return + + for ntensors, max_sizes in params: + test_sum(device, dtype, ntensors, max_sizes, len(max_sizes)) + + # Test error inputs + with self.assertRaisesRegex( + RuntimeError, "NestedTensor can only be reduced across the last" + ): + torch.nested.nested_tensor( + [torch.tensor([3, 4, 5]), torch.tensor([1, 2])] + ).sum(0, keepdim=True) + + with self.assertRaisesRegex( + RuntimeError, "NestedTensor only allows reduction of a single" + ): + torch.nested.nested_tensor( + [torch.tensor([[3, 4, 5]]), torch.tensor([[1, 2]])] + ).sum([0, 1], keepdim=True) + + with self.assertRaisesRegex( + RuntimeError, "NestedTensor always requires keepdim=True for now." 
+ ): + torch.nested.nested_tensor( + [torch.tensor([3, 4, 5]), torch.tensor([1, 2])] + ).sum(-1) + + @dtypes(torch.float, torch.float16) + def test_contiguous(self, device, dtype): + # Since we don't have access to the buffer in python this is harder to show what + # we are testing for. When we call chunk on a consistent dim of a NT + # for chunk_size > 1 the resulting tensors are views of the original NT + # whose numels is now less than the size of the buffer. Clone was + # previously creating a new NT with a buffer that was the same size as the + # original. + nt_contiguous = torch.nested.nested_tensor( + [ + torch.randn(2, 20, device=device, dtype=dtype), + torch.randn(4, 20, device=device, dtype=dtype), + ] + ) + # Split up the last dimension which has a consistent size of 20 into 5 chunks + chunks = nt_contiguous.chunk(5, dim=-1) + + # # Check chunks are contiguous after calling contiguous + for chunk in chunks: + self.assertFalse(chunk.is_contiguous()) + self.assertTrue(chunk.contiguous().is_contiguous()) + + @dtypes(torch.float, torch.float16) + @skipMeta + def test_clone(self, device, dtype): + nt1 = random_nt(device, dtype, 4, (4, 4), (1, 1)) + nt2 = nt1.clone() + # Verify the values match + self.assertEqual(nt1, nt2) + # Verify modifying nt2 doesn't affect nt1 + nt2.mul_(nt1) + ub1 = nt1.unbind() + ub2 = nt2.unbind() + for i in range(len(ub1)): + self.assertNotEqual(ub1[i], ub2[i]) + + nt1.clone(memory_format=torch.preserve_format) + msg = "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ChannelsLast" + with self.assertRaisesRegex(RuntimeError, msg): + nt1.clone(memory_format=torch.channels_last) + + # cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half' + @decorateIf(xfailIfTorchDynamo, lambda params: params["layout"] == torch.jagged) + @dtypes(torch.float, torch.double) + @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name) + def test_dropout(self, device, dtype, layout): + # edge case: empty nested tensor + # TODO: support empty NT in jagged layout + if layout == torch.strided: + nt0 = torch.nested.nested_tensor([], layout=layout) + y = torch.nn.functional.dropout(nt0, 0.5) + self.assertEqual(nt0, y) + # normal nested tensor + ntensors = 4 + if layout == torch.jagged: + nt = random_nt(device, dtype, ntensors, (4, 4), (0, 3), layout=layout) + else: + nt = random_nt(device, dtype, ntensors, (4, 4), layout=layout) + # edge case: invalid dropout + self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1)) + self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1)) + self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1)) + self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1)) + # edge case: no dropout + dropouter = torch.nn.Dropout(0.0) + y0 = dropouter(nt) + y1 = torch.nn.functional.dropout(nt, 0.0) + self.assertEqual(nt, y0) + self.assertEqual(nt, y1) + # edge case: all dropout + dropouter = torch.nn.Dropout(1.0) + y0 = dropouter(nt) + y1 = torch.nn.functional.dropout(nt, 1.0) + nt0 = torch.zeros_like(nt) + self.assertEqual(nt0, y0) + self.assertEqual(nt0, y1) + # normal case: normal dropout + p = 0.2 + y = torch.nn.functional.dropout(nt, p) + expect = nt.clone() + if layout == torch.jagged: + expect = torch.where(y == 0.0, y, nt) + expect /= 1.0 - p + self.assertEqual(y, expect) + else: + expect = nt.clone() + for i in range(ntensors): + actual_tensor = y[i].view(-1) + expect_tensor = expect[i].view(-1) + 
for j in range(actual_tensor.shape[0]): + if actual_tensor[j].item() == 0.0: + expect_tensor[j] = 0.0 + else: + expect_tensor[j] /= 1.0 - p + self.assertEqual(y, expect) + with freeze_rng_state(): + dropouter = torch.nn.Dropout(p) + y0 = dropouter(nt) + with freeze_rng_state(): + y1 = torch.nn.functional.dropout(nt, p) + self.assertEqual(y0, y1) + + @dtypes(torch.float, torch.double) + def test_dropout_noncontiguous(self, device, dtype): + ntensors = 4 + nt0 = random_nt(device, dtype, ntensors, (4, 4)) + nt1 = nt0.transpose(-1, -2) + p = 0.3 + with freeze_rng_state(): + dropouter = torch.nn.Dropout(p) + y0 = dropouter(nt0) + with freeze_rng_state(): + y1 = torch.nn.functional.dropout(nt1, p).transpose(-1, -2) + self.assertEqual(y0, y1) + + # cannot test torch.float16 because: RuntimeError: "softmax_kernel_impl" not implemented for 'Half' + @dtypes(torch.float, torch.double) + def test_softmax(self, device, dtype): + # normal nested tensor + ntensors = 4 + nt = random_nt(device, dtype, ntensors, (4, 4)) + # error case: softmax across nested dimension + self.assertRaisesRegex( + RuntimeError, + "Cannot apply softmax across nested dimension 0", + lambda: torch.nn.functional.softmax(nt, 0), + ) + self.assertRaisesRegex( + RuntimeError, + "Cannot apply softmax across nested dimension 0", + lambda: torch.nn.functional.softmax(nt, -3), + ) + # error case: dimension out of range + self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, 3)) + self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, -4)) + # normal case: should equal to padding -inf + softmaxer = torch.nn.Softmax(1) + y0 = softmaxer(nt) + y1 = torch.nn.functional.softmax(nt, 1) + self.assertEqual(y0, y1) + pt = torch.nested.to_padded_tensor(nt, float("-inf")) + # if an entire slice is padded, then softmax will return 0.0 / 0.0 = nan + # however, physically speaking that should be 0.0 + expect = torch.nn.functional.softmax(pt, 1).nan_to_num_(0.0) + self.assertEqual(torch.nested.to_padded_tensor(y0, 0.0), expect) + # edge case: empty nested tensor + nt0 = torch.nested.nested_tensor([]) + y = torch.nn.functional.softmax(nt0, 1) + self.assertEqual(nt0, y) + # edge case: nesting scalars + nt1 = torch.nested.nested_tensor([torch.tensor(0.0), torch.tensor(1.0)]) + self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt1, 0)) + self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt1, 1)) + + @dtypes(torch.float, torch.double) + @torch.inference_mode() + def test_softmax_noncontiguous(self, device, dtype): + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device, dtype + ) + self.assertEqual( + torch.nn.functional.softmax(nt_contiguous, -1), + torch.nn.functional.softmax(nt_noncontiguous, -1), + ) + + def _test_bmm(self, device, dtype): + # error case: not 3D tensors + nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype) + nt1 = torch.nested.nested_tensor( + [torch.randn(2), torch.randn(3)], device=device, dtype=dtype + ) + nt2 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) + self.assertRaisesRegex( + RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt0) + ) + self.assertRaisesRegex( + RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt1) + ) + self.assertRaisesRegex( + RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt2) + ) + self.assertRaisesRegex( + RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt0) + ) + self.assertRaisesRegex( + 
RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt1) + ) + self.assertRaisesRegex( + RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt2) + ) + self.assertRaisesRegex( + RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt0) + ) + self.assertRaisesRegex( + RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt1) + ) + # error case: incompatible batch size + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))], + device=device, + dtype=dtype, + ) + self.assertRaisesRegex( + RuntimeError, + "Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3.", + lambda: nt0.bmm(nt1), + ) + self.assertRaisesRegex( + RuntimeError, + "Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2.", + lambda: nt1.bmm(nt0), + ) + # error case: underlying matrices cannot be multiplied + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) + self.assertRaisesRegex( + RuntimeError, + r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)", + lambda: nt0.bmm(nt0), + ) + # normal nested tensor + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype + ) + actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) + expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm( + torch.nested.to_padded_tensor(nt1, 0.0) + ) + if dtype == torch.float16: + self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) + else: + self.assertEqual(actual, expect) + + # nested tensor bmm normal tensor + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 7)), torch.randn((3, 7))], device=device, dtype=dtype + ) + nt1 = torch.rand(2, 7, 5, dtype=dtype, device=device) + actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) + expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1) + if dtype == torch.float16: + self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) + else: + self.assertEqual(actual, expect) + + # nested tensor bmm normal tensor with non-contiguous view + nt1 = torch.rand(2, 5, 7, dtype=dtype, device=device) + nt1 = nt1.transpose(1, 2) + actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) + expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1) + if dtype == torch.float16: + self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) + else: + self.assertEqual(actual, expect) + + # normal tensor bmm nested tensor + nt0 = torch.rand(2, 5, 7, dtype=dtype, device=device) + nt1 = torch.nested.nested_tensor( + [torch.randn((7, 6)), torch.randn((7, 5))], device=device, dtype=dtype + ) + actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) + expect = nt0.bmm(torch.nested.to_padded_tensor(nt1, 0.0)) + if dtype == torch.float16: + self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) + else: + self.assertEqual(actual, expect) + + # test tensorcore path + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype + ) + actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) + expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm( + torch.nested.to_padded_tensor(nt1, 0.0) + ) + if dtype == 
torch.float16: + self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) + else: + self.assertEqual(actual, expect) + + @onlyOn(["cuda", "xpu"]) + @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16) + @tf32_on_and_off(0.005) + def test_bmm_cuda(self, device, dtype): + self._test_bmm(device, dtype) + + @onlyCPU + # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' + @dtypes(torch.float, torch.double) + def test_bmm_cpu(self, device, dtype): + self._test_bmm(device, dtype) + + # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' + @dtypes(torch.float, torch.double) + def test_bmm_noncontiguous(self, device, dtype): + nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair( + (2, 3), device, dtype + ) + nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair( + (6, 7), device, dtype + ) + self.assertEqual( + nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous), + nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous), + ) + + @dtypes(torch.float, torch.double) + @tf32_on_and_off(0.005) + def test_matmul_with_bmm_path(self, device, dtype): + def unbind_rebind_matmul(nt1, nt2): + t1s = nt1.unbind() + t2s = nt2.unbind() + out_ts = [t1.matmul(t2) for t1, t2 in zip(t1s, t2s)] + return torch.nested.nested_tensor(out_ts) + + # [N, n_head, *, head_dim], [N, n_head, head_dim, *] + Ns = [1, 2, 5] + n_heads = np.random.randint(2, 5) + head_dim = 3 + t1s = [] + t2s = [] + for N in Ns: + for _ in range(N): + seq_len1 = np.random.randint(2, 5) + seq_len2 = np.random.randint(2, 5) + t1s.append(torch.randn(n_heads, seq_len1, head_dim)) + t2s.append(torch.randn(n_heads, head_dim, seq_len2)) + nt1 = torch.nested.nested_tensor(t1s, device=device, dtype=dtype) + nt2 = torch.nested.nested_tensor(t2s, device=device, dtype=dtype) + self.assertEqual(torch.matmul(nt1, nt2), unbind_rebind_matmul(nt1, nt2)) + + # test with noncontiguous + t3s = [] + t4s = [] + for _ in range(N): + seq_len = np.random.randint(2, 5) + t3s.append(torch.randn(seq_len, n_heads, head_dim)) + t4s.append(torch.randn(seq_len, n_heads, head_dim)) + nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose( + 1, 2 + ) + nt4 = ( + torch.nested.nested_tensor(t4s, device=device, dtype=dtype) + .transpose(1, 2) + .transpose(2, 3) + ) + self.assertEqual(torch.matmul(nt3, nt4), unbind_rebind_matmul(nt3, nt4)) + + # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half' + @dtypes(torch.float, torch.double) + def test_matmul(self, device, dtype): + # error case: one is nested but the other is not + nt = torch.nested.nested_tensor( + [torch.randn(2), torch.randn(3)], device=device, dtype=dtype + ) + t = torch.randn(4, device=device, dtype=dtype) + self.assertRaisesRegex( + RuntimeError, + "Expected both to be nested, but got a nested self and non-nested other", + lambda: torch.matmul(nt, t), + ) + self.assertRaisesRegex( + RuntimeError, + "Expected both to be nested, but got a non-nested self and nested other", + lambda: torch.matmul(t, nt), + ) + # error case: not 3+D tensors + nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype) + nt1 = torch.nested.nested_tensor( + [torch.randn(2), torch.randn(3)], device=device, dtype=dtype + ) + nt2 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 
1st input has rank: [0-9]+", + lambda: torch.matmul(nt0, nt0), + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", + lambda: torch.matmul(nt0, nt1), + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", + lambda: torch.matmul(nt0, nt2), + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", + lambda: torch.matmul(nt1, nt0), + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", + lambda: torch.matmul(nt1, nt1), + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", + lambda: torch.matmul(nt1, nt2), + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+", + lambda: torch.matmul(nt2, nt0), + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+", + lambda: torch.matmul(nt2, nt1), + ) + # error case: incompatible batch size + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))], + device=device, + dtype=dtype, + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.", + lambda: torch.matmul(nt0, nt1), + ) + self.assertRaisesRegex( + RuntimeError, + r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.", + lambda: torch.matmul(nt1, nt0), + ) + # error case: incompatible (wrong) batch sizes that shouldn't even broadcast? 
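+        # (the components' leading batch dims differ between the two NTs, 2 vs. 3,
+        # so matmul is expected to reject the pair outright.)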
+ nt0 = torch.nested.nested_tensor( + [torch.randn((2, 2, 4)), torch.randn((2, 3, 4))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((3, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype + ) + self.assertRaisesRegex( + RuntimeError, + "matmul(): For nested tensors, batch dimensions must have the same sizes,", + lambda: torch.matmul(nt0, nt1), + ) + # error case: incompatible batch sizes that should technically broadcast + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 2, 4)), torch.randn((1, 3, 4))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((1, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype + ) + self.assertRaisesRegex( + RuntimeError, + "matmul(): For nested tensors, batch dimensions must have the same sizes,", + lambda: torch.matmul(nt0, nt1), + ) + # error case: underlying matrices cannot be multiplied + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype + ) + self.assertRaisesRegex( + RuntimeError, + "matmul(): Nested tensors cannot be matrix multiplied", + lambda: torch.matmul(nt0, nt0), + ) + # normal nested tensor: 3D + nt0 = torch.nested.nested_tensor( + [torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype + ) + actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0) + expect = torch.matmul( + torch.nested.to_padded_tensor(nt0, 0.0), + torch.nested.to_padded_tensor(nt1, 0.0), + ) + self.assertEqual(actual, expect) + # normal nested tensor: 4D (with testing for batch_size=1) + nt0 = torch.nested.nested_tensor( + [torch.randn((1, 2, 4)), torch.randn((8, 3, 7))], device=device, dtype=dtype + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((1, 4, 6)), torch.randn((8, 7, 5))], device=device, dtype=dtype + ) + actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0) + expect = torch.matmul( + torch.nested.to_padded_tensor(nt0, 0.0), + torch.nested.to_padded_tensor(nt1, 0.0), + ) + self.assertEqual(actual, expect) + # normal nested tensor: 5D + nt0 = torch.nested.nested_tensor( + [torch.randn((8, 9, 2, 4)), torch.randn((8, 9, 3, 7))], + device=device, + dtype=dtype, + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((8, 9, 4, 6)), torch.randn((8, 9, 7, 5))], + device=device, + dtype=dtype, + ) + actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0) + expect = torch.matmul( + torch.nested.to_padded_tensor(nt0, 0.0), + torch.nested.to_padded_tensor(nt1, 0.0), + ) + self.assertEqual(actual, expect) + + # only supported on CUDA for now + @dtypes(torch.float, torch.double) + def test_matmul_nt_with_broadcasted_t(self, device, dtype): + # NT (B, *, C, D) with T (D, E) broadcasting case + nt = random_nt_from_dims([3, None, 4, 5], device=device, dtype=dtype) + t = torch.randn(5, 6, device=device, dtype=dtype) + output = torch.matmul(nt, t) + + # should be equivalent to matmul-ing each component with the dense tensor + self.assertEqual(nt.size(0), output.size(0)) + for component, out_component in zip(nt, output): + self.assertEqual(out_component, torch.matmul(component, t)) + + # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half' + @dtypes(torch.float, torch.double) + def test_matmul_noncontiguous(self, device, dtype): + nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair( + (2, 3), device, dtype + ) + nt1_contiguous, 
nt1_noncontiguous = random_nt_noncontiguous_pair( + (6, 7), device, dtype + ) + self.assertEqual( + torch.matmul(nt0_contiguous.transpose(-1, -2), nt1_contiguous), + torch.matmul(nt0_noncontiguous.transpose(-1, -2), nt1_noncontiguous), + ) + + @dtypes(torch.float, torch.double) + def test_linear(self, device, dtype): + a = torch.randn(1, 2, device=device, dtype=dtype) + b = torch.randn(2, 2, device=device, dtype=dtype) + c = torch.randn(3, 2, device=device, dtype=dtype) + nt = torch.nested.nested_tensor([a, b, c]) + + weight = torch.randn(2, 2, device=device, dtype=dtype) + bias = torch.randn(2, device=device, dtype=dtype) + # success case + torch.functional.F.linear(nt, weight, bias) + + # invalid nested tensor dimension + msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. Dense tensor dim: 2" + nt1 = torch.nested.nested_tensor( + [ + torch.randn(1, device=device, dtype=dtype), + torch.randn(2, device=device, dtype=dtype), + ] + ) + with self.assertRaisesRegex(RuntimeError, msg): + torch.functional.F.linear(nt1, weight, bias) + + # invalid weight shape + msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3" + weight1 = torch.randn(2, 2, 3, device=device, dtype=dtype) + with self.assertRaisesRegex(RuntimeError, msg): + torch.functional.F.linear(nt, weight1, bias) + + # inconsistent last dim of nested tensor + msg = r"Expected all tensors in nested tensor to have the same trailing dimension, instead last dimension equals:" + nt2 = torch.nested.nested_tensor( + [ + torch.randn(1, 2, device=device, dtype=dtype), + torch.randn(2, 3, device=device, dtype=dtype), + ] + ) + with self.assertRaisesRegex(RuntimeError, msg): + torch.functional.F.linear(nt2, weight, bias) + + # Mismatch of nested tensor last dim and weight dimension + weight2 = torch.randn(2, 4, device=device, dtype=dtype) + msg = ( + r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'" + r" to equal 'weight.size\(1\), but got: last_dim = 2, and weight.size\(1\) = 4" + ) + with self.assertRaisesRegex(RuntimeError, msg): + torch.functional.F.linear(nt, weight2, bias) + + # Nested tensor input and nested weight + nt_weight = nt.clone() + msg = r"Linear does not support nested weight when input is a nested tensor." 
+ with self.assertRaisesRegex(RuntimeError, msg): + torch.functional.F.linear(nt, nt_weight, bias) + + # TODO: test noncontiguous linear + # For now this tests the error message of linear + # since linear does not support noncontiguous buffer yet + @dtypes(torch.float, torch.double) + def test_linear_noncontiguous(self, device, dtype): + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( + (2, 3, 6, 7), device, dtype + ) + weight = torch.randn((8, 5), device=device, dtype=dtype) + self.assertRaisesRegex( + RuntimeError, + r"for now linear only supports contiguous nested tensor", + lambda: torch.nn.functional.linear(nt_noncontiguous, weight), + ) + + @dtypes(torch.float, torch.float16, torch.double) + def test_to_padded_tensor_zero_numel_errors(self, device, dtype): + ts = [torch.ones(1, 0), torch.ones(0, 0)] + nt = torch.nested.nested_tensor( + ts, device=device, dtype=dtype, layout=torch.strided + ) + self.assertRaisesRegex( + RuntimeError, + r"at least one constituent tensor should have non-zero numel", + lambda: torch.nested.to_padded_tensor(nt, 0.0), + ) + + @dtypes(torch.float, torch.float16, torch.double) + def test_transpose(self, device, dtype): + nt = random_nt(device, dtype, 4, (4, 4)) + # error case: transpose nested dimension + self.assertRaisesRegex( + RuntimeError, + "Nested tensor dimension 0 cannot be transposed", + lambda: nt.transpose(0, 1), + ) + self.assertRaisesRegex( + RuntimeError, + "Nested tensor dimension 0 cannot be transposed", + lambda: nt.transpose(1, -3), + ) + # error case: dimension out of range + self.assertRaises(IndexError, lambda: nt.transpose(1, 3)) + self.assertRaises(IndexError, lambda: nt.transpose(-4, -1)) + # normal case + ntT = nt.transpose(-1, -2) + ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) + pt = torch.nested.to_padded_tensor(nt, 0.0) + ptT = pt.transpose(-1, -2) + self.assertEqual(ptT, ptT_from_ntT) + + @dtypes(torch.float, torch.float16, torch.double) + def test_squeeze_unsqueeze(self, device, dtype): + a = torch.arange(6).reshape(2, 3) + b = torch.arange(15).reshape(5, 3) + nt = torch.nested.nested_tensor([a, b], device=device, dtype=dtype) + # error case: squeeze no dimension + self.assertRaisesRegex( + RuntimeError, + "For nested tensors, squeeze without the dim argument", + lambda: nt.squeeze(), + ) + # error case: squeeze nested dimension + self.assertRaisesRegex( + RuntimeError, + "For nested tensors, squeezing dimension 0", + lambda: nt.squeeze(0), + ) + # error case: dimension out of range + self.assertRaises(IndexError, lambda: nt.squeeze(3)) + # error case: squeeze nested tensor of singleton tensors + c = torch.ones(1) + nt_singleton = torch.nested.nested_tensor([c, c], device=device, dtype=dtype) + self.assertRaisesRegex( + RuntimeError, + "For nested tensors, squeezing a nested tensor of singleton", + lambda: nt_singleton.squeeze(1), + ) + + # squeezing a dim which does not have size 1 should be a no-op + nt2 = nt.squeeze(-1) + self.assertEqual(nt, nt2) + + # test cases that should work + nt_sizes = nt._nested_tensor_size() + nt_strides = nt._nested_tensor_strides() + for i in range(-2, 4): + if i == 0: + # cannot unsqueeze batch dim + continue + nt_unsqueezed = nt.unsqueeze(i) + # negative dim will correspond to unsqueeze() applied at dim = dim + nt.dim() + 1 + wrapped_i = i + nt.dim() + 1 if i < 0 else i + # col_index into nt size tensor is requires subtraction of 1 to ignore batch dim + size_idx = wrapped_i - 1 + self.assertEqual( + nt_unsqueezed._nested_tensor_size()[:, size_idx], + torch.ones(2, 
dtype=torch.long), + ) + unsqueezed_stride = nt_unsqueezed._nested_tensor_strides()[:, size_idx] + if i == nt.ndim or i == -1: + self.assertEqual(unsqueezed_stride, torch.ones(2, dtype=torch.long)) + else: + stride_col_after = nt_strides[:, size_idx] + size_col_after = nt_sizes[:, size_idx] + self.assertEqual(unsqueezed_stride, stride_col_after * size_col_after) + nt_squeezed = nt_unsqueezed.squeeze(i) + self.assertEqual(nt_squeezed, nt) + self.assertEqual(nt_squeezed._nested_tensor_size(), nt_sizes) + self.assertEqual(nt_squeezed._nested_tensor_strides(), nt_strides) + + @dtypes(torch.float, torch.float16, torch.double) + def test_transpose_inference_mode_interaction(self, device, dtype): + nt = random_nt(device, dtype, 4, (4, 4)) + # Construct in default mode and transpose while in inference mode + with torch.inference_mode(): + ntT = nt.transpose(-1, -2) + ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) + pt = torch.nested.to_padded_tensor(nt, 0.0) + ptT = pt.transpose(-1, -2) + self.assertEqual(ptT, ptT_from_ntT) + + # Construct and transpose while in inference mode + with torch.inference_mode(): + nt = random_nt(device, dtype, 4, (4, 4)) + ntT = nt.transpose(-1, -2) + ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) + pt = torch.nested.to_padded_tensor(nt, 0.0) + ptT = pt.transpose(-1, -2) + self.assertEqual(ptT, ptT_from_ntT) + + @dtypes(torch.float, torch.float16, torch.double) + def test_view(self, device, dtype): + nt = random_nt(device, dtype, 4, (4, 4)) + # error case: empty shape + self.assertRaisesRegex( + RuntimeError, + r"shape '\[\]' is invalid for a nested tensor", + lambda: nt.view(()), + ) + # error case: empty nested tensor + nt_empty = torch.nested.nested_tensor([]) + self.assertRaisesRegex( + RuntimeError, + "empty nested tensor cannot be reshaped", + lambda: nt_empty.view(-1), + ) + # error case: -1 for batch size + self.assertRaisesRegex( + RuntimeError, + r"view: For now nested view cannot change or infer the implicit batch dimension", + lambda: nt.view(-1, 2, 3), + ) + self.assertRaisesRegex( + RuntimeError, + r"shape '\[.*\]' is invalid for input of size [0-9]+", + lambda: nt.view(4, 2, 3), + ) + # normal case + x0 = torch.randn((2, 20), device=device, dtype=dtype) + x1 = torch.randn((3, 20), device=device, dtype=dtype) + nt = torch.nested.nested_tensor([x0, x1]) + pt = torch.nested.to_padded_tensor(nt, 0.0) + # error case, trying to reshape batch dim to a legit shape + self.assertRaisesRegex( + RuntimeError, + r"For now nested view cannot change or infer the implicit batch dimension", + lambda: nt.transpose(-1, -2).view(40, -1), + ) + # inherit only the ragged dimension + # (2, 20) -> (2, 5, 4) + # (3, 20) -> (3, 5, 4) + nt1 = nt.view(2, -1, 5, 4) + # (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4) + pt1 = pt.view(2, -1, 5, 4) + self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1) + + # more than one -1 (even for "old" dims), should fail + # this attempts to do # (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2) + # but we ban "inherit old behavior" for >1 dimension + self.assertRaisesRegex( + RuntimeError, + r"only one dimension can be inferred", + lambda: nt1.view(2, -1, -1, 2, 2), + ) + + @dtypes(torch.float, torch.float16, torch.double) + def test_view_inference_mode_interaction(self, device, dtype): + # Construct in default mode and view while in inference mode + nt = torch.nested.nested_tensor( + [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype + ) + with torch.inference_mode(): + ntT = nt.view(2, -1, 4, 5) + ptT_from_ntT = 
noncontiguous_to_padded_tensor(ntT) + pt = torch.nested.to_padded_tensor(nt, 0.0) + ptT = pt.view(2, -1, 4, 5) + self.assertEqual(ptT, ptT_from_ntT) + # Construct and view while in inference mode + with torch.inference_mode(): + nt = torch.nested.nested_tensor( + [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype + ) + ntT = nt.view(2, -1, 4, 5) + ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) + pt = torch.nested.to_padded_tensor(nt, 0.0) + ptT = pt.view(2, -1, 4, 5) + self.assertEqual(ptT, ptT_from_ntT) + + @dtypes(torch.float, torch.float16, torch.double) + def test_reshape(self, device, dtype): + nt = random_nt(device, dtype, 4, (4, 4)) + # error case: empty shape + self.assertRaisesRegex( + RuntimeError, + r"shape '\[\]' is invalid for a nested tensor", + lambda: nt.reshape(()), + ) + # error case: empty nested tensor + nt_empty = torch.nested.nested_tensor([]) + self.assertRaisesRegex( + RuntimeError, + "empty nested tensor cannot be reshaped", + lambda: nt_empty.reshape(-1), + ) + # error case: -1 for batch size + self.assertRaisesRegex( + RuntimeError, + r"reshape: For now nested reshape cannot change or infer the implicit batch dimension", + lambda: nt.reshape(-1, 2, 3), + ) + self.assertRaisesRegex( + RuntimeError, + r"shape '\[.*\]' is invalid for input of size [0-9]+", + lambda: nt.reshape(4, 2, 3), + ) + # normal case + x0 = torch.randn((2, 20), device=device, dtype=dtype) + x1 = torch.randn((3, 20), device=device, dtype=dtype) + nt = torch.nested.nested_tensor([x0, x1]) # (2, (2, 3), 20) + pt = torch.nested.to_padded_tensor(nt, 0.0) + # error case, trying to reshape batch dim to a legit shape + self.assertRaisesRegex( + RuntimeError, + r"reshape: For now nested reshape cannot change or infer the implicit batch dimension", + lambda: nt.transpose(-1, -2).reshape(40, -1), + ) + # inherit only the ragged dimension + # (2, 20) -> (2, 5, 4) + # (3, 20) -> (3, 5, 4) + nt1 = nt.reshape(2, -1, 5, 4) + # (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4) + pt1 = pt.reshape(2, -1, 5, 4) + self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1) + + # more than one -1 (even for "old" dims), should fail + # this attempts to do # (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2) + # but we ban "inherit old behavior" for >1 dimension + self.assertRaisesRegex( + RuntimeError, + r"only one dimension can be inferred", + lambda: nt1.reshape(2, -1, -1, 2, 2), + ) + + def test_nested_masked_select(self, device): + t = torch.randn([3, 3], device=device) + mask = torch.tensor([False], device=device) + + njt = torch.nested.masked_select(t, mask) + self.assertEqual(njt.values(), torch.tensor([], device=device)) + self.assertEqual(njt.offsets(), torch.tensor([0, 0, 0, 0], device=device)) + + mask = torch.tensor([[False], [False], [True]], device=device) + njt = torch.nested.masked_select(t, mask) + self.assertEqual(njt.values(), t[-1], atol=0.1, rtol=0.1) + self.assertEqual(njt.offsets(), torch.tensor([0, 0, 0, 3], device=device)) + + mask = torch.tensor( + [[False, False, True], [True, False, True], [False, False, True]], + device=device, + ) + njt = torch.nested.masked_select(t, mask) + self.assertEqual(njt.values(), t.masked_select(mask)) + self.assertEqual(njt.offsets(), torch.tensor([0, 1, 3, 4], device=device)) + + t = torch.randn([2, 3, 3, 1], device=device) + mask = torch.tensor( + [ + [ + [[True], [False], [True]], + [[True], [False], [True]], + [[True], [False], [True]], + ], + [ + [[False], [True], [True]], + [[False], [True], [True]], + [[True], [True], [True]], + ], + ], + 
device=device, + ) + njt = torch.nested.masked_select(t, mask) + self.assertEqual(njt.values(), t.masked_select(mask)) + self.assertEqual( + njt.offsets(), + torch.tensor( + [0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 6, 7, 8, 8, 9, 10, 11, 12, 13], + device=device, + ), + ) + + @dtypes(torch.float, torch.float16, torch.double) + def test_narrow(self, device, dtype): + nt = random_nt_from_dims([5, None, None, None], device=device, dtype=dtype) + + # narrow on dim=0 from start to end + bounds = [(0, 5), (0, 3), (1, 2), (1, 5), (2, 4)] + for start, end in bounds: + length = end - start + narrowed = nt.narrow(dim=0, start=start, length=length) + # ensure output is a view + self.assertTrue(narrowed._base is nt) + for nc, c in zip(narrowed.unbind(), nt.unbind()[start:end]): + self.assertEqual(nc, c) + + # dim != 0 is not supported + for dim in range(1, nt.dim()): + with self.assertRaisesRegex( + RuntimeError, "only dim=0 supported for nested tensors" + ): + nt.narrow(dim=dim, start=0, length=1) + + # error case: non-contiguous NT + _, nt_noncont = random_nt_noncontiguous_pair((2, 3, 4)) + with self.assertRaisesRegex( + RuntimeError, "only contiguous nested tensors supported" + ): + nt_noncont.narrow(dim=0, start=0, length=1) + + @parametrize("input_dim", [3, 4]) + @tf32_on_and_off(0.005) + def test_scaled_dot_product_attention(self, device, input_dim): + def rand_tensor(*shape): + return torch.randn(shape, device=device) + + E = 8 + if input_dim == 3: + # Shape: (N, L, E); ragged L + query = torch.nested.nested_tensor( + [rand_tensor(2, E), rand_tensor(3, E), rand_tensor(4, E)] + ) + + # Shape: (N, S, E); ragged S + key = torch.nested.nested_tensor( + [rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)] + ) + value = torch.nested.nested_tensor( + [rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)] + ) + elif input_dim == 4: + # In the 4D case the L and S is ragged + # Shape: (N, N', L, E); ragged N' and L + query = torch.nested.nested_tensor( + [rand_tensor(2, 2, E), rand_tensor(3, 3, E), rand_tensor(4, 4, E)] + ) + # Shape: (N, N', S, E); ragged N' and S + key = torch.nested.nested_tensor( + [rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)] + ) + value = torch.nested.nested_tensor( + [rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)] + ) + else: + self.fail(f"Invalid input_dim {input_dim} encountered in SDP test") + + def rand_mask(size): + return torch.randint(0, 2, size=size, dtype=torch.bool, device=device) + + # Shape: (N, L, S); ragged L and S matching above + attn_mask = torch.nested.nested_tensor( + [rand_mask((2, 3)), rand_mask((3, 4)), rand_mask((4, 5))] + ) + + dropout_p = 0.0 # no dropout for reproducibility + + # Success case: no attn_mask set and is_causal=False. + actual = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=None, is_causal=False, dropout_p=dropout_p + ) + + expected_outputs = [] + for q, k, v in zip(query.unbind(), key.unbind(), value.unbind()): + output = torch.nn.functional.scaled_dot_product_attention( + q.unsqueeze(0), + k.unsqueeze(0), + v.unsqueeze(0), + attn_mask=None, + dropout_p=dropout_p, + ) + expected_outputs.append(output.squeeze(0)) + expected_output_nested = torch.nested.nested_tensor(expected_outputs) + self.assertEqual(actual, expected_output_nested) + + # Error case: explicit attn_mask set. 
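+        # (passing an explicit attn_mask together with nested inputs is expected to
+        # raise for now.)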
+ with self.assertRaisesRegex( + RuntimeError, "not supported when an explicit attn_mask is set" + ): + torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=dropout_p + ) + + # Error case: is_causal=True. + with self.assertRaisesRegex(RuntimeError, "not supported when is_causal=True"): + torch.nn.functional.scaled_dot_product_attention( + query, key, value, dropout_p=dropout_p, is_causal=True + ) + + @dtypes(torch.float, torch.float16, torch.double) + def test_empty_like(self, device, dtype): + ntensors = 4 + nt = random_nt(device, dtype, ntensors, (4, 4)) + + # Create empty on same device as original nested tensor + nt_empty = torch.empty_like(nt) + assert nt.is_same_size(nt_empty) + self.assertEqual(nt.dtype, nt_empty.dtype) + self.assertEqual(nt.device, nt_empty.device) + self.assertEqual(nt.layout, nt_empty.layout) + + if torch.cuda.is_available(): + if device == "cpu": + nt_cuda = torch.empty_like(nt, device="cuda") + self.assertEqual(torch.device("cuda").type, nt_cuda.device.type) + else: + nt_cpu = torch.empty_like(nt, device="cpu") + self.assertEqual(torch.device("cpu").type, nt_cpu.device.type) + + # Check changing dtype of empty_like nested tensor output + dtype_set = {torch.float, torch.float16, torch.double} + for other_dtype in dtype_set - {dtype}: + nt_empty_other_dtype = torch.empty_like(nt, dtype=other_dtype) + self.assertEqual(nt.dtype, dtype) + self.assertEqual(nt_empty_other_dtype.dtype, other_dtype) + self.assertEqual(nt.device, nt_empty.device) + self.assertEqual(nt.layout, nt_empty.layout) + + # Create tensor for autograd + nt_empty_req_grad = torch.empty_like(nt, requires_grad=True) + self.assertEqual(nt_empty_req_grad.requires_grad, True) + + # Test noncontiguous tensor does not fail to copy + nt_cont, nt_noncont = random_nt_noncontiguous_pair((2, 3, 6, 7)) + nt_empty = torch.empty_like(nt_cont) + assert nt_cont.is_same_size(nt_empty) + nt_empty_non_contig = torch.empty_like(nt_noncont) + assert nt_noncont.is_same_size(nt_empty_non_contig) + + # Test the contiguous memory format option + nt_empty_contig = torch.empty_like( + nt_cont, memory_format=torch.contiguous_format + ) + assert nt_cont.is_same_size(nt_empty_contig) + assert nt_empty_contig.is_contiguous() + + nt_empty_non_contig = torch.empty_like( + nt_noncont, memory_format=torch.contiguous_format + ) + assert nt_noncont.is_same_size(nt_empty_non_contig) + assert nt_empty_non_contig.is_contiguous() + + # Test other memory formats fail + self.assertRaises( + RuntimeError, + lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last), + ) + self.assertRaises( + RuntimeError, + lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last), + ) + self.assertRaises( + RuntimeError, + lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last_3d), + ) + self.assertRaises( + RuntimeError, + lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last_3d), + ) + + +@markDynamoStrictTest +class TestNestedTensorAutograd(NestedTensorTestCase): + # Note [Gradcheck args check_batched_grad=False] the common_utils testing version of gradcheck + # includes the default parameters used for testing ops with gradcheck. 
However nested tensor + # does not support the stack op therefore we turn it off for these tests + def _create_leaf_nested_tensor_from_list(self, tensor_device, requires_grad=False): + return torch.nested.nested_tensor( + [torch.randn(1, 2), torch.randn(7, 8)], + requires_grad=requires_grad, + device=tensor_device, + ) + + def _create_nested_tensor_from_list(self, tensor_device, requires_grad=False): + return torch.nested.as_nested_tensor( + [ + torch.randn(1, 2, requires_grad=requires_grad), + torch.randn(7, 8, requires_grad=requires_grad), + ], + device=tensor_device, + ) + + def _create_nested_tensor_from_mask(self, tensor_device, requires_grad=False): + data = torch.randn(2, 3, 4, requires_grad=requires_grad, device=tensor_device) + mask = torch.ones_like(data[:, :, 0]).bool() + return torch._nested_tensor_from_mask(data, mask) + + def test_as_nested_tensor_propagates_gradients(self, device): + a = torch.arange(3, dtype=torch.float, device=device) + b = torch.arange(5, dtype=torch.float, device=device) + nt = torch.nested.as_nested_tensor([a, b]) + # tensors with requires_grad=False are leaves + self.assertTrue(nt.is_leaf) + self.assertTrue(not nt.requires_grad) + + a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device) + b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device) + nt2 = torch.nested.as_nested_tensor([a, b]) + fake_grad = torch.nested.nested_tensor( + [torch.ones_like(a), torch.zeros_like(b)], device=device + ) + nt2.backward(fake_grad) + self.assertEqual(a.grad, fake_grad[0]) + self.assertEqual(b.grad, fake_grad[1]) + + def test_nested_tensor_generates_leaf(self, device): + a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device) + b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device) + + nt = torch.nested.nested_tensor([a, b], requires_grad=False) + self.assertTrue(nt.is_leaf) + self.assertTrue(not nt.requires_grad) + + nt2 = torch.nested.nested_tensor([a, b], requires_grad=True) + self.assertTrue(nt2.is_leaf) + self.assertTrue(nt2.requires_grad) + + fake_grad = torch.nested.nested_tensor( + [torch.ones_like(a), torch.zeros_like(b)], device=device + ) + nt2.backward(fake_grad) + self.assertEqual(nt2.grad, fake_grad) + self.assertEqual(a.grad, None) + self.assertEqual(b.grad, None) + + def test_set_requires_grad_from_list(self, device): + nt = self._create_nested_tensor_from_list(device) + nt.requires_grad_() + assert nt.requires_grad + + def test_set_requires_grad_from_mask(self, device): + nt = self._create_nested_tensor_from_mask(device) + nt.requires_grad_() + assert nt.requires_grad + + def test_backward_for_add_op(self, device): + nt_1 = self._create_nested_tensor_from_mask(device) + nt_2 = self._create_nested_tensor_from_mask(device) + + nt_1.requires_grad_() + c = nt_1 + nt_2 + + assert nt_1.requires_grad + assert c.requires_grad + grad_output = self._create_nested_tensor_from_mask(device) + c.backward(grad_output) + + # Grad check doesn't work with nested yet. 
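+        # Instead, check the add backward analytically: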
+ # d/dnt_1 (nt + nt_1) = 1*grad_output + self.assertEqual(nt_1.grad, grad_output) + + def test_backward_for_sub_op(self, device): + nt_1 = self._create_nested_tensor_from_mask(device) + nt_2 = self._create_nested_tensor_from_mask(device) + + nt_1.requires_grad_() + nt_2.requires_grad_() + c = nt_1 - nt_2 + + assert nt_1.requires_grad + assert nt_2.requires_grad + assert c.requires_grad + grad_output = self._create_nested_tensor_from_mask(device) + c.backward(grad_output) + + self.assertEqual(nt_1.grad, grad_output) + self.assertEqual(nt_2.grad, -1 * grad_output) + + def test_backward_sub_strided(self, device): + a = torch.nested.nested_tensor( + [torch.randn(9, 2, 4), torch.randn(12, 2, 4)], + requires_grad=True, + device=device, + ) + b = torch.nested.nested_tensor( + [torch.randn(9, 4, 2), torch.randn(12, 4, 2)], + requires_grad=True, + device=device, + ) + c = a - b.transpose(-1, -2) + grad_output = c.clone() + c.backward(grad_output) + self.assertEqual(a.grad, grad_output) + self.assertEqual(b.grad, -1 * grad_output.transpose(-1, -2)) + + def test_backward_add_strided(self, device): + a = torch.nested.nested_tensor( + [torch.randn(9, 2, 4), torch.randn(12, 2, 4)], + requires_grad=True, + device=device, + ) + b = torch.nested.nested_tensor( + [torch.randn(9, 4, 2), torch.randn(12, 4, 2)], + requires_grad=True, + device=device, + ) + c = a + b.transpose(-1, -2) + grad_output = c.clone() + c.backward(grad_output) + self.assertEqual(a.grad, grad_output) + self.assertEqual(b.grad, grad_output.transpose(-1, -2)) + + # Test Factory Functions + def test_nested_tensor_to_padded_tensor(self, device): + for padding_val in [0, 1]: + nt = self._create_leaf_nested_tensor_from_list( + tensor_device=device, requires_grad=True + ) + + out = torch.nested.to_padded_tensor(nt, padding_val) + grad_output = torch.ones(out.shape, device=device) + out.backward(grad_output) + + self.assertEqual( + nt.grad, + torch.nested.nested_tensor( + [torch.ones(1, 2), torch.ones(7, 8)], device=device + ), + ) + + def test_nested_tensor_from_mask_and_to_padded(self, device): + N, L, D = 2, 4, 4 + mask = torch.ones(N, L, device=device) + for i in range(1, N): + end = torch.randint(1, L - 1, (1,), device=device) + mask[i, end:] = 0 + + mask[0, :] = 1 + mask = mask.bool() + + data = torch.randn( + N, L, D, requires_grad=True, dtype=torch.float64, device=device + ) + + def grad_test_func(inpt): + nt = torch._nested_tensor_from_mask(inpt, mask) + # This implicitly tests to_padded_tensor grads + return torch.nested.to_padded_tensor(nt, 0) + + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + def test_nested_tensor_from_padded(self, device): + nested_size = torch.tensor([[1, 2], [2, 2]]) + padded_tensor = torch.randn(2, 2, 2, dtype=torch.float64, device=device) + padded_tensor[0, 1, :] = 0 + padded_tensor.requires_grad_() + + def grad_test_func(tensor, nested_size): + nt = torch._nested_from_padded( + tensor, nested_size, fuse_transform_0213=False + ) + # This implicitly tests to_padded_tensor grads + return torch.nested.to_padded_tensor(nt, 0) + + data = (padded_tensor, nested_size) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + def test_nested_tensor_from_padded_fused(self, device): + nested_size = torch.tensor([[1, 8], [2, 8]]) + padded_tensor = torch.randn(2, 2, 2, 4, dtype=torch.float64, device=device) + padded_tensor[0, 1, :] = 0 + padded_tensor.requires_grad_() + + def grad_test_func(tensor, nested_size): + nt = torch._nested_from_padded( + tensor, nested_size, 
fuse_transform_0213=True
+            )
+            # This implicitly tests to_padded_tensor grads
+            return torch.nested.to_padded_tensor(nt, 0)
+
+        data = (padded_tensor, nested_size)
+        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
+
+    def test_nested_tensor_from_list(self, device):
+        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
+        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
+        c = torch.randn(10, 2, requires_grad=True, dtype=torch.float64, device=device)
+
+        def grad_test_func(a, b, c):
+            c = torch.nested.as_nested_tensor([a, b, c])
+            # This implicitly tests to_padded_tensor grads
+            return torch.nested.to_padded_tensor(c, 0)
+
+        data = (a, b, c)
+        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
+
+    @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
+    def test_dropout_backward(self, layout):
+        if layout == torch.jagged:
+            nt = torch.nested.nested_tensor(
+                [torch.randn((2, 5)), torch.randn((3, 5))],
+                requires_grad=True,
+                layout=layout,
+            )
+        else:
+            nt = torch.nested.nested_tensor(
+                [torch.randn((2, 5)), torch.randn((3, 4))],
+                requires_grad=True,
+                layout=layout,
+            )
+        p = 0.2
+        y = torch.nn.functional.dropout(nt, p)
+        y.backward(nt.detach().clone())
+        self.assertEqual(nt.grad, y)
+
+    def test_nested_tensor_bmm_gradcheck(self, device):
+        a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
+        b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
+        c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
+        d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)
+
+        def grad_test_func(a, b, c, d):
+            nt0 = torch.nested.as_nested_tensor([a, b])
+            nt1 = torch.nested.as_nested_tensor([c, d])
+            result = nt0.bmm(nt1)
+            return torch.nested.to_padded_tensor(result, 0.0)
+
+        data = (a, b, c, d)
+        assert torch.autograd.gradcheck(grad_test_func, inputs=data)
+
+    @tf32_on_and_off(0.008)
+    def test_nested_tensor_bmm_backward(self, device):
+        nt0 = torch.nested.nested_tensor(
+            [torch.randn((2, 6)), torch.randn((3, 6))],
+            requires_grad=True,
+            device=device,
+        )
+        nt1 = torch.nested.nested_tensor(
+            [torch.randn((6, 4)), torch.randn((6, 5))],
+            requires_grad=True,
+            device=device,
+        )
+        with torch.no_grad():
+            pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True)
+            pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True)
+
+        ynt = nt0.bmm(nt1)
+        ypt = pt0.bmm(pt1)
+        ynt.backward(ynt.clone())
+        ypt.backward(ypt.clone())
+
+        self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad)
+        self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad)
+
+    def test_nested_tensor_matmul_gradcheck(self, device):
+        a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
+        b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
+        c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
+        d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)
+
+        def grad_test_func(a, b, c, d):
+            nt0 = torch.nested.as_nested_tensor([a, b])
+            nt1 = torch.nested.as_nested_tensor([c, d])
+            result = torch.matmul(nt0, nt1)
+            return torch.nested.to_padded_tensor(result, 0.0)
+
+        data = (a, b, c, d)
+        assert torch.autograd.gradcheck(grad_test_func, inputs=data)
+
+    def test_nested_tensor_matmul_backward(self, device):
+        nt0 = torch.nested.nested_tensor(
+            [torch.randn((7, 2, 6)), torch.randn((7, 3,
6))], + requires_grad=True, + device=device, + ) + nt1 = torch.nested.nested_tensor( + [torch.randn((7, 6, 4)), torch.randn((7, 6, 5))], + requires_grad=True, + device=device, + ) + with torch.no_grad(): + pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True) + pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True) + + ynt = torch.matmul(nt0, nt1) + ypt = torch.matmul(pt0, pt1) + ynt.backward(ynt.clone()) + ypt.backward(ypt.clone()) + + self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad) + self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad) + + def test_nested_tensor_transpose_gradcheck(self, device): + a = torch.randn(2, 5, requires_grad=True, device=device) + b = torch.randn(3, 4, requires_grad=True, device=device) + + def grad_test_func(a, b): + nt = torch.nested.as_nested_tensor([a, b]) + result = nt.transpose(-2, -1).transpose(-2, -1) + return torch.nested.to_padded_tensor(result, 0.0) + + data = (a, b) + assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3) + + def test_nested_tensor_transpose_backward(self, device): + nt = torch.nested.nested_tensor( + [torch.randn((2, 5)), torch.randn((3, 4))], + requires_grad=True, + device=device, + ) + with torch.no_grad(): + pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) + + ynt = nt.transpose(-2, -1) + ypt = pt.transpose(-2, -1) + ynt.backward(ynt.clone()) + ypt.backward(ypt.clone()) + + self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) + + def test_nested_tensor_reshape_gradcheck(self, device): + a = torch.randn(2, 6, requires_grad=True, device=device) + b = torch.randn(3, 6, requires_grad=True, device=device) + + def grad_test_func(a, b): + nt = torch.nested.as_nested_tensor([a, b]) + result = nt.reshape(2, -1, 2, 3) + return torch.nested.to_padded_tensor(result, 0.0) + + data = (a, b) + assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3) + + def test_nested_tensor_reshape_backward(self): + nt = torch.nested.nested_tensor( + [torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True + ) + with torch.no_grad(): + pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) + + ynt = nt.reshape(2, -1, 2, 3) + ypt = pt.reshape(2, -1, 2, 3) + ynt.backward(ynt.clone()) + ypt.backward(ypt.clone()) + + self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) + + def test_nested_tensor_squeeze_backward(self, device): + nt = torch.nested.nested_tensor( + [torch.randn((2, 6, 1)), torch.randn((3, 6, 1))], + requires_grad=True, + device=device, + ) + with torch.no_grad(): + pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) + + ynt = nt.squeeze(-1) + ypt = pt.squeeze(-1) + ynt.backward(ynt.clone()) + ypt.backward(ypt.clone()) + + self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) + + def test_nested_tensor_squeeze_gradcheck(self, device): + a = torch.randn( + (2, 6, 1), dtype=torch.float64, requires_grad=True, device=device + ) + b = torch.randn( + (3, 6, 1), dtype=torch.float64, requires_grad=True, device=device + ) + + def grad_test_func(a, b): + nt = torch.nested.as_nested_tensor([a, b]) + result = nt.squeeze(-1) + return torch.nested.to_padded_tensor(result, 0.0) + + assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3) + + def test_nested_tensor_unsqueeze_backward(self, device): + nt = torch.nested.nested_tensor( + [torch.randn((2, 6)), torch.randn((3, 6))], + requires_grad=True, + device=device, + ) + with torch.no_grad(): + pt = 
torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) + + ynt = nt.unsqueeze(2) + ypt = pt.unsqueeze(2) + ynt.backward(ynt.clone()) + ypt.backward(ypt.clone()) + + self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) + + def test_nested_tensor_unsqueeze_gradcheck(self, device): + a = torch.randn((2, 6), dtype=torch.float64, requires_grad=True, device=device) + b = torch.randn((3, 6), dtype=torch.float64, requires_grad=True, device=device) + + def grad_test_func(a, b): + nt = torch.nested.as_nested_tensor([a, b]) + result = nt.unsqueeze(-1) + return torch.nested.to_padded_tensor(result, 0.0) + + assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3) + + def test_nested_tensor_linear(self, device): + a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device) + + weight = torch.randn( + 2, 2, requires_grad=True, dtype=torch.float64, device=device + ) + bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device) + + def grad_test_func(a, b, c, weight, bias=None): + nt = torch.nested.as_nested_tensor([a, b, c]) + # This implicitly tests to_padded_tensor grads + d = torch.functional.F.linear(nt, weight, bias) + return torch.nested.to_padded_tensor(d, 0) + + data = (a, b, c, weight, bias) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + # Test linear with no bias added + data = (a, b, c, weight) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + def test_nested_tensor_linear_plus_transpose(self, device): + a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device) + + weight = torch.randn( + 2, 2, requires_grad=True, dtype=torch.float64, device=device + ) + bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device) + + def grad_test_func(a, b, c, weight, bias=None): + nt = torch.nested.as_nested_tensor([a, b, c]) + # This implicitly tests to_padded_tensor grads + d = torch.functional.F.linear(nt, weight, bias) + d = d.transpose(-1, -2).contiguous() + return torch.nested.to_padded_tensor(d, 0) + + data = (a, b, c, weight, bias) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + # Test linear with no bias added + data = (a, b, c, weight) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + def test_nested_tensor_softmax(self, device): + a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device) + + def grad_test_func(a, b, c, dim): + nt = torch.nested.as_nested_tensor([a, b, c]) + # This implicitly tests to_padded_tensor grads + d = torch.functional.F.softmax(nt, dim=dim) + return torch.nested.to_padded_tensor(d, 0) + + # softmax over last dim + data = (a, b, c, -1) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + def test_nested_tensor_linear_backward(self, device): + a = torch.randn(1, 2, requires_grad=False, device=device) + b = torch.randn(2, 2, requires_grad=False, device=device) + c = torch.randn(3, 2, requires_grad=False, device=device) + + weight = 
torch.randn(2, 2, requires_grad=True, device=device) + bias = torch.randn(2, requires_grad=True, device=device) + nt = torch.nested.as_nested_tensor([a, b, c], device=device) + + out = torch.functional.F.linear(nt, weight, bias) + + out.backward(out.clone()) + + assert weight.grad is not None + assert bias.grad is not None + + assert a.grad is None + assert b.grad is None + assert c.grad is None + + def test_values_grad_with_broadcast(self, device): + a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + + def grad_test_func(a, b, c): + nt = torch.nested.as_nested_tensor([a, b, c]) + buffer = nt.values() + return buffer.sum() + + data = (a, b, c) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + def test_to_buffer_series_ops_grad_with_broadcast(self, device): + a = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device) + + def grad_test_func(a, b, c): + nt = torch.nested.as_nested_tensor([a, b, c]) + buffer = nt.values() + buffer = buffer * 2 + return buffer.exp() + + data = (a, b, c) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + def test_unbind_flow_through(self, device): + a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + + def grad_test_func(a, b, c): + nt = torch.nested.as_nested_tensor([a, b, c]) + ntT = nt.transpose(-1, -2) + unbound = ntT.unbind() + d = unbound[0] + d = torch.pow(d, 2) + return d + + data = (a, b, c) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + def test_split_with_sizes_flow_through(self, device): + a = torch.randn(2, 5, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(3, 5, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(4, 5, requires_grad=True, dtype=torch.float64, device=device) + + def grad_test_func(a, b, c): + nt = torch.nested.as_nested_tensor([a, b, c]) + splits = nt.split_with_sizes([2, 3], dim=-1) + unbound = splits[1].unbind() + d = unbound[0] + d = torch.pow(d, 2) + return d + + data = (a, b, c) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + def test_indexing_backward(self, device): + x0 = torch.randn((2, 5)) + x1 = torch.randn((3, 4)) + nt = torch.nested.nested_tensor([x0, x1], device=device, requires_grad=True) + self.assertEqual(nt[0], x0) + self.assertEqual(nt[-1], x1) + grad_x0 = torch.randn((2, 5), device=device) + nt[0].backward(grad_x0) + expected_grad = torch.nested.nested_tensor( + [grad_x0, torch.zeros((3, 4), device=device)] + ) + self.assertEqual(nt.grad, expected_grad) + + def test_masked_fill_backward(self, device): + a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + + def grad_test_func(a, b, c): + nt = torch.nested.as_nested_tensor([a, b, c]) + mask = nt.detach().clone().to(bool) + out = 
nt.masked_fill(mask, 0) + out = torch.nested.to_padded_tensor(out, 0) + return out + + data = (a, b, c) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + def test_gelu_backward(self, device): + a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + + def grad_test_func(a, b, c): + nt = torch.nested.as_nested_tensor([a, b, c]) + nt_gelu = torch.nn.functional.gelu(nt) + return torch.nested.to_padded_tensor(nt_gelu, 0) + + data = (a, b, c) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + def test_relu_backward(self, device): + a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + + def grad_test_func(a, b, c): + nt = torch.nested.as_nested_tensor([a, b, c]) + nt_relu = torch.nn.functional.relu(nt) + return torch.nested.to_padded_tensor(nt_relu, 0) + + data = (a, b, c) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + def test_selu_backward(self, device): + a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + + def grad_test_func(a, b, c): + nt = torch.nested.as_nested_tensor([a, b, c]) + nt_relu = torch.nn.functional.silu(nt) + return torch.nested.to_padded_tensor(nt_relu, 0) + + data = (a, b, c) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + def test_abs_backward(self, device): + a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) + + def grad_test_func(a, b, c): + nt = torch.nested.as_nested_tensor([a, b, c]) + nt_abs = torch.abs(nt) + return torch.nested.to_padded_tensor(nt_abs, 0) + + data = (a, b, c) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + # Previously would error when input NT doesn't require grad + # NotImplementedError: Cannot access storage of UndefinedTensorImpl + def test_layer_norm_backward_edge_case(self, device): + size = 4 + a = torch.randn( + 1, 2, size, requires_grad=False, dtype=torch.float64, device=device + ) + nt = torch.nested.nested_tensor([a]) + nt_layer_norm = torch.nn.LayerNorm( + nt.size(-1), device=device, dtype=torch.float64 + ) + out = nt_layer_norm(nt) + out.backward(out.clone()) + + def test_accumulate_grad_different_strides(self, device): + a = torch.rand(1, 4, 2, requires_grad=True, dtype=torch.float64, device=device) + b = torch.rand(1, 8, 2, requires_grad=True, dtype=torch.float64, device=device) + + def grad_test_func(a, b): + nt_1 = torch.nested.as_nested_tensor([a, b]) + nt_2 = nt_1.clone() + out = torch.nn.functional.scaled_dot_product_attention(nt_1, nt_2, nt_2) + return torch.nested.to_padded_tensor(out, 0) + + data = (a, b) + assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) + + # https://github.com/pytorch/pytorch/issues/95562 + @skipIfSlowGradcheckEnv + @parametrize("size", [1024, 1023, 513, 512, 256, 128, 32, 
4, 2])
+    def test_layer_norm_backward(self, device, size):
+        a = torch.randn(
+            1, 2, size, requires_grad=True, dtype=torch.float64, device=device
+        )
+        b = torch.randn(
+            2, 2, size, requires_grad=True, dtype=torch.float64, device=device
+        )
+        c = torch.randn(
+            3, 2, size, requires_grad=True, dtype=torch.float64, device=device
+        )
+
+        def grad_test_func(a, b, c):
+            nt = torch.nested.as_nested_tensor([a, b, c])
+            layer_norm = torch.nn.LayerNorm(
+                nt.size(-1), device=device, dtype=torch.float64
+            )
+            nt_layer_norm = layer_norm(nt)
+            return torch.nested.to_padded_tensor(nt_layer_norm, 0)
+
+        data = (a, b, c)
+        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
+
+    # https://github.com/pytorch/pytorch/issues/95562
+    @skipIfSlowGradcheckEnv
+    # Could either mark slow or reduce size
+    @parametrize("size", [128, 32, 4, 2])
+    def test_layer_norm_backward_5d(self, device, size):
+        a = torch.randn(
+            4, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
+        )
+        b = torch.randn(
+            7, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
+        )
+        c = torch.randn(
+            10, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
+        )
+
+        def grad_test_func(a, b, c):
+            nt = torch.nested.as_nested_tensor([a, b, c])
+            layer_norm = torch.nn.LayerNorm(
+                (size, size, nt.size(-1)), device=device, dtype=torch.float64
+            )
+            nt_layer_norm = layer_norm(nt)
+            return torch.nested.to_padded_tensor(nt_layer_norm, 0)
+
+        data = (a, b, c)
+        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
+
+
+# Found in torch/testing/_comparison.py
+default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
+default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}
+
+
+def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
+    deviation = true_value - computed_value
+    deviation = torch.abs(deviation / true_value)
+    # Fill in the nans with the default rtol
+    torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype])
+    return deviation.max().item()
+
+
+def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
+    deviation = true_value - computed_value
+    atol = torch.abs(deviation).max().item()
+    return atol
+
+
+def get_tolerances(
+    true_value: torch.Tensor,
+    computed_value: torch.Tensor,
+    fudge_factor: Optional[float] = None,
+) -> tuple[float, float]:
+    """Returns the absolute and relative tolerances for comparing two tensors."""
+    fudge_factor = fudge_factor if fudge_factor is not None else 1.0
+    atol = get_atol(true_value, computed_value)
+    rtol = get_rtol(true_value, computed_value)
+
+    atol = fudge_factor * max(atol, default_atol[computed_value.dtype])
+    rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype])
+    # torch.isclose() has weird behavior in some cases; see:
+    # https://github.com/pytorch/pytorch/issues/102400
+    if rtol > 1e30:
+        rtol = default_rtol[computed_value.dtype]
+    return atol, rtol
+
+
+# We can probably parametrize existing tests instead of having a separate
+# test class as we begin to support more ops. Also maybe rewrite with OpInfos.
+@markDynamoStrictTest +class TestNestedTensorSubclass(NestedTensorTestCase): + # TODO: consolidate with the below + def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True): + Ds = nested_size[1:] + out = [] + for s in nested_size[0]: + out.append( + torch.randn( + s, + *Ds, + requires_grad=requires_grad, + device=device, + dtype=torch.float64, + ) + ) + return out + + def _get_example_tensor_lists( + self, + include_list_of_lists=True, + include_requires_grad=True, + include_inner_dim_size_1=False, + include_2d_tensor=False, + ): + def _make_tensor( + *shape, include_requires_grad=include_requires_grad, requires_grad=True + ): + return torch.randn( + *shape, + requires_grad=(requires_grad if include_requires_grad else False), + ) + + # Purposefully introduce mixed requires_grad settings for the components + # when include_requires_grad=True. + example_lists = [ + # (B, *, D) with B=4 + [ + _make_tensor(2, 5), + _make_tensor(3, 5, requires_grad=False), + _make_tensor(4, 5, requires_grad=False), + _make_tensor(6, 5), + ], + # (B, *, D_0, D_1) with B=5 + [ + _make_tensor(2, 5, 6), + _make_tensor(3, 5, 6), + _make_tensor(4, 5, 6, requires_grad=False), + _make_tensor(5, 5, 6), + _make_tensor(6, 5, 6), + ], + # (B, *, D_0, D_1, D_2) with B=6 + [ + _make_tensor(2, 5, 6, 7), + _make_tensor(3, 5, 6, 7), + _make_tensor(4, 5, 6, 7, requires_grad=False), + _make_tensor(5, 5, 6, 7), + _make_tensor(6, 5, 6, 7), + _make_tensor(7, 5, 6, 7), + ], + ] + + if include_list_of_lists: + example_lists.append( + # (B, *, D) with B=3 in list form + [ + _make_tensor(2, 5, requires_grad=False).tolist(), + _make_tensor(3, 5).tolist(), + _make_tensor(4, 5).tolist(), + ] + ) + + if include_inner_dim_size_1: + example_lists.append( + [ + _make_tensor(2, 1), + _make_tensor(3, 1, requires_grad=False), + _make_tensor(4, 1, requires_grad=False), + _make_tensor(6, 1), + ] # (B, *, 1) + ) + example_lists.append( + [ + _make_tensor(2, 5, 1), + _make_tensor(3, 5, 1, requires_grad=False), + _make_tensor(4, 5, 1, requires_grad=False), + _make_tensor(6, 5, 1), + ] # (B, *, 5, 1) + ) + + if include_2d_tensor: + example_lists.append( + [ + _make_tensor(2), + _make_tensor(3, requires_grad=False), + _make_tensor(4, requires_grad=False), + _make_tensor(6), + ] # (B, *) + ) + + return example_lists + + @dtypes(torch.float32) + @parametrize( + "contiguity", + ["contig", "noncontig_transposed", "noncontig_with_holes"], + name_fn=lambda c: c, + ) + @parametrize("weights_only", [True, False]) + def test_serialization(self, device, dtype, contiguity, weights_only): + # Test with 3 cases: + # 1. contiguous + # 2. non-contiguous transposed + # 3. non-contiguous with holes + if contiguity == "contig": + nt = random_nt_from_dims( + [4, None, 10], + device=device, + dtype=dtype, + layout=torch.jagged, + ) + elif contiguity == "noncontig_transposed": + nt = random_nt_from_dims( + [3, None, 5, 2], + device=device, + dtype=dtype, + layout=torch.jagged, + ).transpose(-3, -2) + elif contiguity == "noncontig_with_holes": + nt = torch.nested.nested_tensor_from_jagged( + values=torch.randn(10, 3, device=device, dtype=dtype), + offsets=torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int64), + # these lengths specify holes + lengths=torch.tensor([1, 2, 3], device=device, dtype=torch.int64), + ) + else: + raise ValueError("invalid contiguity specified for test_serialization()") + + # Access sizes / strides to ensure cache doesn't break serialization. 
+ # See https://github.com/pytorch/pytorch/issues/129366 + nt.size() + nt.stride() + + with tempfile.TemporaryFile() as f: + torch.save(nt, f) + f.seek(0) + nt_loaded = torch.load(f, weights_only=weights_only) + + self.assertIsNot(nt, nt_loaded) + # we expect a new offsets tensor -> different nested int upon load + self.assertEqualIgnoringNestedInts(nt, nt_loaded) + self.assertEqual(nt._ragged_idx, nt_loaded._ragged_idx) + # ensure shapes are equal except nested int + nt_rest_of_shape = ( + *nt.shape[: nt._ragged_idx], + *nt.shape[nt._ragged_idx + 1 :], + ) + nt_loaded_rest_of_shape = ( + *nt_loaded.shape[: nt_loaded._ragged_idx], + *nt_loaded.shape[nt_loaded._ragged_idx + 1 :], + ) + self.assertEqual(nt_rest_of_shape, nt_loaded_rest_of_shape) + # ensure metadata cache is carried through serialization + self.assertEqual(nt._metadata_cache, nt_loaded._metadata_cache) + # ensure lengths are carried through if present + self.assertEqual(nt._lengths, nt_loaded._lengths) + + def test_tensor_attributes(self, device): + a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) + nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) + _offsets = nt.offsets() + + for op in ( + torch.ops.aten.is_non_overlapping_and_dense.default, + torch.ops.aten.sym_size.default, + torch.ops.aten.dim.default, + torch.ops.aten.numel.default, + torch.ops.aten.sym_numel.default, + torch.ops.aten.sym_stride.default, + torch.ops.aten.sym_storage_offset.default, + ): + op(nt) + + with self.assertRaisesRegex( + RuntimeError, "directly calling torch.ops.aten.size" + ): + torch.ops.aten.size.default(nt) + + nested_int = torch.nested._internal.nested_tensor.get_tensor_symint( + _offsets, coeff=1 + ) + self.assertEqual(nt.size(), (3, nested_int, 3)) + self.assertEqual(nt.shape, (3, nested_int, 3)) + self.assertEqual(nt.dim(), 3) + self.assertEqual(nt.numel(), 27) + + @parametrize("nt_dim", [3, 4, 5]) + def test_linear(self, device, nt_dim): + if nt_dim == 3: + fixed_shape = (3,) + elif nt_dim == 4: + fixed_shape = (4, 3) + elif nt_dim == 5: + fixed_shape = (5, 4, 3) + + a = torch.randn( + 2, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device + ) + b = torch.randn( + 3, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device + ) + c = torch.randn( + 4, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device + ) + weight = torch.randn( + 4, 3, requires_grad=True, dtype=torch.float64, device=device + ) + bias = torch.randn(4, requires_grad=True, dtype=torch.float64, device=device) + + def grad_test_func(a, b, c, weight, bias): + nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) + out = torch.nn.functional.linear(nt, weight, bias) + return out.values() + + gradcheck( + grad_test_func, inputs=(a, b, c, weight, bias), check_batched_grad=False + ) + + @onlyOn(["cuda", "xpu"]) + @dtypes(torch.float32) + @serialTest() + def test_linear_backward_memory_usage(self, device, dtype): + # Verify that linear_backward() doesn't use more memory than it should + # for higher dim input sizes. 
+ # See https://github.com/pytorch/pytorch/issues/141112 + B, D, max_seq_len = 64, 512, 100 + if device == "cuda": + torch._C._cuda_clearCublasWorkspaces() + m = torch.nn.Linear(D, D, device=device) + nt = torch.nested.as_nested_tensor( + [ + torch.rand(size=[seq_len, D]) + for seq_len in torch.randint(max_seq_len, size=(B,)) + ], + layout=torch.jagged, + device=device, + ) + + # (B, j1, D) -> (B, j1, 1, D) for a higher dim input size + nt = nt.unsqueeze(-2) + # linear_backward() should not explode the max memory usage + if device == "cuda": + torch.cuda.reset_max_memory_allocated() + m(nt).sum().backward() + # expect under a GB for max memory allocated + max_after_gb = torch.cuda.max_memory_allocated(0) // (1024**3) + self.assertEqual(max_after_gb, 0) + + def test_unary_pointwise(self, device): + a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) + + def grad_test_func(a, b, c): + nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) + out = torch.nn.functional.silu(nt.sin().cos()) + return out.values() + + gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False) + + def test_unary_pointwise_transposed_inputs(self, device): + a, b, c = ( + torch.randn( + i + 2, 5, requires_grad=True, dtype=torch.float64, device=device + ) + for i in range(3) + ) + + nt = torch.nested.nested_tensor( + [a.detach(), b.detach(), c.detach()], layout=torch.jagged + ) + nt_t = nt.transpose(1, 2) + self.assertFalse(nt_t.is_contiguous()) + out = torch.nn.functional.silu(nt_t.sin().cos()) + self.assertEqual( + out.is_contiguous(), + torch.nn.functional.silu(b.transpose(-1, -2).sin().cos()).is_contiguous(), + ) + + self.assertEqual(nt_t.shape, out.shape) + + a, b, c = ( + torch.randn( + i + 2, 5, requires_grad=True, dtype=torch.float64, device=device + ) + for i in range(3) + ) + + def grad_test_func(a, b, c): + nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) + nt_t = nt.transpose(1, 2) + out = torch.nn.functional.silu(nt_t.sin().cos()) + return out.values() + + gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False) + + def test_binary_pointwise(self, device): + a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) + + # Incorrect usage: shape check will fail if the offsets tensor are not + # the same exact tensor object + nt1 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) + nt2 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) + + self.assertRaisesRegex( + RuntimeError, + "cannot call binary pointwise function .* with inputs of shapes", + lambda: nt1 * nt2, + ) + + # Correct usage: chain the calls using the same offsets tensor object + def grad_test_func(a, b, c): + nt1 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) + # TODO: Switch to public API that takes in (values, offsets) once it exists + nt2, offsets = jagged_from_list([a, b, c], nt1.offsets()) + out = nt1 * nt2 + return out.values() + + gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False) + + def test_binary_pointwise_transposed(self, device): + a, b, c = ( + torch.randn(i + 2, 5, dtype=torch.float64, device=device) for i in range(3) + ) + + nt1, offsets = 
jagged_from_list([a, b, c], None) + nt2, offsets = jagged_from_list([a, b, c], offsets) + + nt1_t = nt1.transpose(1, 2) + nt2_t = nt2.transpose(1, 2) + + # out = nt1_t * nt2_t + # self.assertFalse(nt1_t.is_contiguous()) + # self.assertEqual(out.is_contiguous(), (b.transpose(-1, -2) * b.transpose(-1, -2)).is_contiguous()) + # self.assertEqual(out.shape, nt1_t.shape) + + self.assertRaisesRegex( + RuntimeError, + "cannot call binary pointwise function mul.Tensor with inputs of shapes", + lambda: nt1 * nt2_t, + ) + + a, b, c = ( + torch.randn( + i + 2, 5, requires_grad=True, dtype=torch.float64, device=device + ) + for i in range(3) + ) + + # Correct usage: chain the calls using the same offsets tensor object + def grad_test_func(a, b, c): + nt1, offsets = jagged_from_list([a, b, c], None) + nt2, offsets = jagged_from_list([a, b, c], offsets) + nt1_t = nt1.transpose(1, 2) + nt2_t = nt2.transpose(1, 2) + out = nt1_t * nt2_t + return out.values() + + gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False) + + def test_binary_pointwise_with_nested_int_second_arg(self, device): + # See https://github.com/pytorch/pytorch/issues/138496 + nt = random_nt_from_dims( + [3, None, 5], + device=device, + dtype=torch.float32, + layout=torch.jagged, + ) + + with self.assertRaisesRegex(RuntimeError, "invalid argument"): + nt * nt.size(1) + + with self.assertRaisesRegex(RuntimeError, "invalid argument"): + nt + nt.size(1) + + def test_split(self, device): + a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) + + nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) + out = torch.split(nt, 2, -1) + self.assertEqual(len(out), 2) + self.assertEqualIgnoringNestedInts( + out[0], + torch.nested.as_nested_tensor( + [a[:, 0:2], b[:, 0:2], c[:, 0:2]], layout=torch.jagged + ), + ) + self.assertEqualIgnoringNestedInts( + out[1], + torch.nested.as_nested_tensor( + [a[:, 2:], b[:, 2:], c[:, 2:]], layout=torch.jagged + ), + ) + + with self.assertRaisesRegex( + RuntimeError, + r"split\(\): not supported for NestedTensor on ragged dim", + ): + torch.split(nt, 2, 1) + + def test_split_with_sizes(self, device): + a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) + + nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) + out = torch.split(nt, [1, 2], -1) + self.assertEqual(len(out), 2) + self.assertEqualIgnoringNestedInts( + out[0], + torch.nested.as_nested_tensor( + [a[:, 0:1], b[:, 0:1], c[:, 0:1]], layout=torch.jagged + ), + ) + self.assertEqualIgnoringNestedInts( + out[1], + torch.nested.as_nested_tensor( + [a[:, 1:], b[:, 1:], c[:, 1:]], layout=torch.jagged + ), + ) + with self.assertRaisesRegex( + RuntimeError, + r"split_with_sizes\(\): not supported for NestedTensor on ragged dim", + ): + torch.split(nt, [1, 2], 1) + + def test_softmax(self, device): + nt = random_nt_from_dims( + [3, None, 5], + device=device, + dtype=torch.float32, + layout=torch.jagged, + requires_grad=True, + ) + + # operate on dim=2 + output = nt.softmax(dim=2) + + @torch._dynamo.disable + def _compare_to_ref(nt, output, dim): + for in_component, out_component in zip(nt.unbind(), output.unbind()): + 
self.assertEqual(in_component.softmax(dim=dim), out_component) + + # dim=2 -> dim=1 after unbind + _compare_to_ref(nt, output, dim=1) + + # operate on dim=-1 + output2 = nt.softmax(dim=-1) + torch._dynamo.disable(self.assertEqual)(output, output2) + _compare_to_ref(nt, output2, dim=-1) + + def grad_test_func(a, b): + nt = torch.nested.as_nested_tensor([a, b], layout=torch.jagged) + out = nt.softmax(dim=-1) + return out.values() + + a = torch.rand(4, 5, requires_grad=True, dtype=torch.float64, device=device) + b = torch.rand(8, 5, requires_grad=True, dtype=torch.float64, device=device) + gradcheck(grad_test_func, inputs=(a, b), check_batched_grad=False) + + def test_views_inherit_ragged_dim(self, device): + # view + nt = random_nt_from_dims( + [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged + ) + # inherit ragged dim via -1 + view = nt.view(4, -1, 80) + self.assertEqual(nt.shape[1], view.shape[1]) + # inherit batch and ragged dims via -1 + view2 = nt.view(-1, -1, 80) + self.assertEqual(nt.shape[:2], view2.shape[:2]) + + # expand + nt = random_nt_from_dims( + [3, None, 1], device=device, dtype=torch.float32, layout=torch.jagged + ) + # inherit batch and ragged dims via -1 + view = nt.expand(-1, -1, 5) + self.assertEqual(nt.shape[:2], view.shape[:2]) + + def test_view_ragged_idx_not_one(self, device): + nt = random_nt_from_dims( + [2, None, 20], device=device, dtype=torch.float32, layout=torch.jagged + ) + + view_transposed = nt.transpose(1, 2).view(2, 20, nt.size(1)) + self.assertEqual((2, 20, nt.size(1)), (view_transposed.size())) + self.assertEqual(view_transposed._base, nt._base) + + def test_unsafe_view(self, device): + nt = random_nt_from_dims( + [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged + ) + # basic view + view1 = torch.ops.aten._unsafe_view(nt, (4, -1, 80)) + self.assertEqual((4, nt.size(1), 80), tuple(view1.size())) + # _unsafe_view differs from view in that the view information is not tracked + self.assertTrue(view1._base is None) + + # test an unsafe_view when ragged_idx != 1, currently only supports identity view + nt_t = nt.transpose(1, 2) + view2 = torch.ops.aten._unsafe_view(nt_t, (4, 8, nt.size(1), 10)) + self.assertEqual((4, 8, nt.size(1), 10), tuple(view2.size())) + self.assertTrue(view2._base is None) + + @xfailIfTorchDynamo + @parametrize("requires_grad", [False, True]) + def test_reshape_decomp(self, device, requires_grad): + # contiguous NT should result in view. 
+ nt = ( + random_nt_from_dims( + [3, None, 10], + device=device, + dtype=torch.float32, + layout=torch.jagged, + ) + .detach() + .requires_grad_(requires_grad) + ) + view = nt.reshape(-1, -1, 5, 2) + self.assertEqual(view.shape[:2], nt.shape[:2]) + self.assertTrue(view._is_view() and view._base is nt) + # make sure gradients flow back + if requires_grad: + view.backward(torch.ones_like(view)) + self.assertEqual(nt.grad, torch.ones_like(nt)) + + # non-contiguous NT should result in contiguous copy + nt = random_nt_from_dims( + [3, None, 5, 2], + device=device, + dtype=torch.float32, + layout=torch.jagged, + requires_grad=requires_grad, + ) + nt_noncontig = nt.transpose(-1, -2) + self.assertFalse(nt_noncontig.is_contiguous()) + copy = nt_noncontig.reshape(-1, -1, 10) + self.assertTrue(copy.is_contiguous()) + self.assertEqual(copy.shape[:2], nt.shape[:2]) + # make sure gradients flow back + if requires_grad: + copy.backward(torch.ones_like(copy)) + self.assertEqual(nt.grad, torch.ones_like(nt)) + + def test_flatten_decomp(self, device): + nt = random_nt_from_dims( + [3, None, 5, 2], device=device, dtype=torch.float32, layout=torch.jagged + ) + flattened = nt.flatten(-2, -1) + self.assertEqual(flattened.shape, nt.view(3, -1, 10).shape) + + nt = random_nt_from_dims( + [3, None, 5, 2, 6], device=device, dtype=torch.float32, layout=torch.jagged + ) + flattened = nt.flatten(-3, -2) + self.assertEqual(flattened.shape, nt.view(3, -1, 10, 6).shape) + + def test_chunk(self, device): + # none NJT case + t = torch.randn(10, 4, 5, requires_grad=True) + t_list = t.chunk(3, dim=0) + loss = t_list[0].sum() + t_list[2].sum() + loss.backward() + + # normal case + D = 30 + B = 8 + nt = random_nt_from_dims( + [B, None, D], + device=device, + dtype=torch.float32, + layout=torch.jagged, + requires_grad=True, + ) + NUM_CHUNKS = 3 + chunks = nt.chunk(NUM_CHUNKS, dim=-1) + self.assertEqual(len(chunks), NUM_CHUNKS) + for i in range(NUM_CHUNKS): + self.assertEqual(chunks[i].shape[-1], D // NUM_CHUNKS) + + # test chunk_backward + values = torch.randn( + 5, 11, dtype=torch.float64, device=device, requires_grad=True + ) + offsets = torch.tensor([0, 2, 3, 5], device=device) + + def grad_test_func(values, offsets): + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + chunks = nt.chunk(3, dim=-1) + return chunks[0].values().sum() + + assert gradcheck( + grad_test_func, + inputs=(values, offsets), + check_batched_grad=False, + ) + + # chunk on batch dim + chunks = nt.chunk(NUM_CHUNKS, dim=0) + self.assertEqual(len(chunks), NUM_CHUNKS) + chunk_size = math.ceil(B / NUM_CHUNKS) + for i in range(NUM_CHUNKS): + if i < NUM_CHUNKS - 1: + self.assertEqual(chunks[i].shape[0], chunk_size) + else: + self.assertEqual(chunks[i].shape[0], B - chunk_size * (NUM_CHUNKS - 1)) + offsets_expected = ( + nt._offsets[i * chunk_size + 1 : (i + 1) * chunk_size + 1] + - nt._offsets[i * chunk_size] + ) + self.assertEqual(chunks[i]._offsets[1:], offsets_expected) + self.assertEqual(nt._values, torch.cat([x._values for x in chunks], dim=0)) + + # doesn't support backward for chunk (dim=0) yet + loss = ( + chunks[0].values().sum() + + chunks[1].values().sum() + + chunks[2].values().sum() + ) + loss.backward() + + # chunk on ragged dim not supported + with self.assertRaisesRegex( + RuntimeError, "chunk.* not supported for NestedTensor on ragged dim" + ): + nt.chunk(2, dim=1) + + def test_squeeze(self, device): + B = 4 + D = 6 + # squeeze middle dim + nt = random_nt_from_dims( + [B, None, 1, D], device=device, dtype=torch.float32, 
layout=torch.jagged + ) + j0 = nt.shape[1] + + for dim_arg in [-2, 2]: + out = nt.squeeze(dim_arg) + self.assertEqual(out.shape, (B, j0, D)) + self.assertEqual(out.unsqueeze(-2), nt) + + # squeeze last dim + nt = random_nt_from_dims( + [B, None, 1], device=device, dtype=torch.float32, layout=torch.jagged + ) + j1 = nt.shape[1] + + for dim_arg in [-1, 2]: + out = nt.squeeze(dim_arg) + self.assertEqual(out.shape, (B, j1)) + self.assertEqual(out.unsqueeze(-1), nt) + + # squeeze on batch dim not supported + with self.assertRaisesRegex( + RuntimeError, "squeeze.* not supported for NestedTensor on dim=0" + ): + nt.squeeze(0) + + # squeeze on ragged dim not supported + with self.assertRaisesRegex( + RuntimeError, "squeeze.* not supported for NestedTensor on ragged dim" + ): + nt.squeeze(1) + + def test_binary_pointwise_broadcasting(self, device): + # (B, j0, 3, 4) + ts = self._get_list_for_jagged_tensor( + ((2, 3, 4), 3, 4), device, requires_grad=True + ) + # (B, j0, ?, ?) + (?) -> (B, j0, ?, ?) + # (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?) + # (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?) + # Unsupported: (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?) + t_sizes = ( + (4,), + (1, 4), + (3, 1), + (1, 3, 1), + (1, 1, 1, 4), + # (1, 1, 1, 1, 4), (unsupported today) + ) + + def grad_test_func(t, *ts): + nt = torch.nested.as_nested_tensor(list(ts), layout=torch.jagged) + out = nt + t + return out.values() + + for t_size in t_sizes: + t = torch.rand( + t_size, requires_grad=True, device=device, dtype=torch.float64 + ) + gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False) + + def test_threshold_backward(self, device): + ts1 = self._get_list_for_jagged_tensor( + ((2, 3, 4), 16), device=device, requires_grad=False + ) + ts2 = self._get_list_for_jagged_tensor( + ((2, 3, 4), 16), device=device, requires_grad=False + ) + + nt1, offsets = jagged_from_list(ts1, None) + nt2, offsets = jagged_from_list(ts2, offsets) + buf1 = nt1.values().detach().clone() + buf2 = nt2.values().detach().clone() + + res_nt = torch.ops.aten.threshold_backward(nt1, nt2, 0.0) + res_dense = torch.ops.aten.threshold_backward(buf1, buf2, 0.0) + + self.assertEqual(res_dense, res_nt.values()) + + @onlyCUDA + @dtypes(torch.float32) + def test_record_stream(self, device, dtype): + def _create_nt(): + values = torch.ones(1024, 4 * 1024, device="cuda") + offsets = torch.tensor([0, 500, 1024], device="cuda", dtype=torch.int64) + lengths = offsets.diff() + nt = torch.nested.nested_tensor_from_jagged(values, offsets, lengths) + data_ptrs = { + nt._values.data_ptr(), + nt._offsets.data_ptr(), + nt._lengths.data_ptr(), + } + return nt, data_ptrs + + def fn(record_stream): + nt, data_ptrs = _create_nt() + s = torch.cuda.Stream() + + with torch.cuda.stream(s): + # emulate doing something long via sleep + per_ms = 2e7 + torch.cuda._sleep(int(per_ms * 100)) + if record_stream: + nt.record_stream(s) + return data_ptrs + + # expect memory reuse when record_stream() is not run + data_ptrs = fn(record_stream=False) + nt, nt_data_ptrs = _create_nt() + self.assertEqual(data_ptrs, nt_data_ptrs) + del nt + torch.cuda.synchronize() + + # expect memory to be preserved (no reuse) when record_stream() is run + data_ptrs = fn(record_stream=True) + nt, nt_data_ptrs = _create_nt() + self.assertEqual(len(data_ptrs.intersection(nt_data_ptrs)), 0) + + @dtypes(torch.float32) + @parametrize( + "func", + [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim], + name_fn=get_op_name, + ) + @parametrize("keepdim", [False, True]) + 
@parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_jagged_op_different_output_shape_dim( + self, device, dtype, keepdim, requires_grad, components_require_grad, func + ): + """ + Operator passes when reducing on valid reduction dimensions. + This test is for operators which return an output tensor with a shape different from the input tensor. + """ + if get_op_name(func) == "mean" and not keepdim: + return + + op_name = get_op_name(func) + + ts = self._get_list_for_jagged_tensor( + ((2, 3, 4), 3, 4), device=device, requires_grad=True + ) # (B, j0, 3, 4) + + # verify correctness of shapes (assuming that ragged_idx == 1) + if op_name == "sum": + reduce_dims = ( + ((0, 1), (3, 4), (1, 1, 3, 4), (0,)), # batch, ragged + ((2, 3), (3, None), (3, None, 1, 1), (1, 2)), # non-batch, non-batch + ((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)), # batch, ragged, non-batch + ((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)), # batch, ragged, non-batch + ( + (0, 1, 2, 3), + (), + (1, 1, 1, 1), + (0, 1, 2), + ), # batch, ragged, non-batch, non-batch + ((2,), (3, None, 4), (3, None, 1, 4), (1,)), # non-batch + ) # (dims, expected shape, expected keepdim shape, reduce_dim_expected), where j0 is represented as None + elif op_name == "mean": + reduce_dims = ( + ((2,), (3, None, 4), (3, None, 1, 4), (1,)), + ((3,), (3, None, 3), (3, None, 3, 1), (2,)), + ) + + for rd, ref_shape_no_keepdim, ref_shape_keepdim, _ in reduce_dims: + nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged) + out = func(nt, dim=rd, keepdim=keepdim) + ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim + if not torch.compiler.is_compiling(): # if not using torch dynamo + self.assertEqual(len(out.shape), len(ref_shape)) + for o, r in zip(out.shape, ref_shape): + if r is not None: + self.assertEqual(o, r) + else: + self.assertTrue(isinstance(o, torch.SymInt)) + + # verify correctness of values + tensor_lists = self._get_example_tensor_lists( + include_list_of_lists=False, + include_requires_grad=components_require_grad, + include_inner_dim_size_1=True, + ) + for tensor_list, reduce_dim_tuple in itertools.product( + tensor_lists, reduce_dims + ): + nt = torch.nested.nested_tensor( + tensor_list, + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=requires_grad, + ) + + reduce_dim, _, _, reduce_dim_expected = reduce_dim_tuple + + if nt.dim() > reduce_dim[-1]: + out_actual = func(nt, dim=reduce_dim, keepdim=keepdim) + if nt._ragged_idx in reduce_dim: # raggedness reduced away + out_expected = func( + nt.values(), dim=reduce_dim_expected, keepdim=keepdim + ) + self.assertTrue(torch.allclose(out_actual, out_expected)) + else: # raggedness preserved + out_expected = func(nt.values(), dim=reduce_dim_expected) + self.assertTrue( + torch.allclose( + out_actual.values().view(-1), out_expected.view(-1) + ) + ) + + @dtypes(torch.float32) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + @parametrize( + "func", + [torch.nn.functional.softmax, torch.nn.functional.log_softmax], + name_fn=lambda func: func.__name__, + ) + def test_softmax_dim( + self, + device, + dtype, + requires_grad, + components_require_grad, + func, + ): + """ + Softmax passes when reducing on valid reduction dimensions. 
+ """ + ts = self._get_list_for_jagged_tensor( + ((2, 3, 4), 3, 4), device=device, requires_grad=True + ) # (B, j0, 3, 4) + + output_shape = (3, None, 3, 4) + + # verify correctness of shapes (assuming that ragged_idx == 1) + reduce_dims = ( + (2, 1), + (3, 2), + ) # (reduction dimension, effective reduction dimension for baseline) + + for reduce_dim, _ in reduce_dims: + nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged) + out_actual = func(nt, dim=reduce_dim) + torch._dynamo.disable(self.assertEqual)( + len(out_actual.shape), len(output_shape) + ) # disable if running on dynamo + for dim_actual, dim_expected in zip(out_actual.shape, output_shape): + if dim_expected is not None: + self.assertEqual(dim_actual, dim_expected) + else: + self.assertTrue(isinstance(dim_actual, torch.SymInt)) + + # verify correctness of values + tensor_lists = self._get_example_tensor_lists( + include_list_of_lists=False, + include_requires_grad=components_require_grad, + include_inner_dim_size_1=True, + ) + for tensor_list, reduce_dim_tuple in itertools.product( + tensor_lists, reduce_dims + ): + nt = torch.nested.nested_tensor( + tensor_list, + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=requires_grad, + ) + + reduce_dim, reduce_dim_expected = reduce_dim_tuple + + if nt.dim() > reduce_dim: + # nested tensor + out_actual = func(nt, dim=reduce_dim) + # dense tensor of dimensions 1 less than out_actual + out_expected = func(nt.values(), dim=reduce_dim_expected) + self.assertTrue( + torch.allclose(out_actual.values().view(-1), out_expected.view(-1)) + ) + + @dtypes(torch.float32) + @parametrize( + "func", + [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim], + name_fn=get_op_name, + ) + @parametrize("keepdim", [False, True]) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_op_dim_reduce_ragged_idx_1_different_output_shape( + self, device, dtype, keepdim, requires_grad, components_require_grad, func + ): + """ + Operator on NestedTensor passes when trying to reduce across ragged dimension, where ragged_idx == 1. + This test is for operators which return an output tensor with a shape different from the input tensor. + """ + if get_op_name(func) == "mean" and not keepdim: + return + + op_name = get_op_name(func) + + tensor_lists = self._get_example_tensor_lists( + include_list_of_lists=False, + include_requires_grad=components_require_grad, + include_inner_dim_size_1=True, # (B, *, 1) + ) + reduce_dim = (1,) # ragged + + for tensor_list in tensor_lists: + nt = torch.nested.nested_tensor( + tensor_list, + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=requires_grad, + ) + + out_actual = func(nt, dim=reduce_dim, keepdim=keepdim) + out_expected = torch.cat( + [func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0) for t in nt.unbind()] + ) + if keepdim: + out_expected = out_expected.unsqueeze(reduce_dim[0]) + + self.assertFalse( + out_actual.is_nested, + f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor", + ) # output is a dense tensor + self.assertEqual(out_actual, out_expected) + + @dtypes(torch.float32) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_softmax_dim_reduce_ragged_idx_1( + self, device, dtype, requires_grad, components_require_grad + ): + """ + Softmax on NestedTensor passes when trying to reduce across ragged dimension, where ragged_idx == 1. 
+ """ + tensor_lists = self._get_example_tensor_lists( + include_list_of_lists=False, + include_requires_grad=components_require_grad, + include_inner_dim_size_1=True, # (B, *, 1) + include_2d_tensor=True, # (B, *) + ) + reduce_dim = 1 # ragged + + for tensor_list in tensor_lists: + nt = torch.nested.nested_tensor( + tensor_list, + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=requires_grad, + ) + + out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim) + out_expected = torch.cat( + [ + torch.nn.functional.softmax(t, dim=reduce_dim - 1) + for t in nt.unbind() + ] + ) + + self.assertTrue( + out_actual.is_nested, + "softmax(): the result of reducing a nested tensor along the ragged dimension is a nested tensor", + ) # output is a nested tensor + self.assertTrue(torch.allclose(out_actual.values(), out_expected)) + + @dtypes(torch.float32) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + @parametrize( + "func", + [torch.nn.functional.softmax, torch.nn.functional.log_softmax], + name_fn=lambda func: func.__name__, + ) + def test_softmax_reduce_batch_dim( + self, device, dtype, requires_grad, components_require_grad, func + ): + """ + Softmax on NestedTensor fails when trying to reduce across batch dimension. + """ + tensor_lists = self._get_example_tensor_lists( + include_list_of_lists=False, + include_requires_grad=components_require_grad, + include_inner_dim_size_1=True, # (B, *, 1) + ) + reduce_dim = 0 # batch + + for tensor_list in tensor_lists: + nt = torch.nested.nested_tensor( + tensor_list, + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=requires_grad, + ) + + with self.assertRaisesRegex( + RuntimeError, + "not supported when reducing across the batch dimension for NestedTensor", + ): + out = func(nt, dim=reduce_dim) + + @dtypes(torch.float32) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_layer_norm_reduce_ragged_idx_1( + self, device, dtype, requires_grad, components_require_grad + ): + """ + Layer normalization on NestedTensor passes when trying to normalize across ragged dimension, where ragged_idx == 1. + """ + + # requires_grad = False does not currently work with dynamo tests and throws this error: + # AssertionError: SymInts must use SymNodeVariable. + # If the underlying value is static, we will create a ConstantVariable and specialize. + if torch._dynamo.is_compiling() and not requires_grad: + return + + tensor_lists = self._get_example_tensor_lists( + include_list_of_lists=False, + include_requires_grad=components_require_grad, + include_inner_dim_size_1=True, # (B, *, 1) + ) + + for tensor_list in tensor_lists: + nt = torch.nested.nested_tensor( + tensor_list, + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=requires_grad, + ) + + if ( + nt.dim() >= 3 + ): # layer norm only works for tensors with 3 or more dimensions + normalized_shape = nt.shape[nt._ragged_idx :] + + out_actual = torch.nn.functional.layer_norm( + nt, normalized_shape=normalized_shape + ) + out_expected = torch.cat( + [ + torch.nn.functional.layer_norm(t, normalized_shape=t.shape) + for t in nt.unbind() + ] + ) # e.g. 
in 3D tensor (B, *, M), performs layer normalization on B 2D tensors (*, M) + + self.assertTrue( + out_actual.is_nested, + "layer_norm(): the result of reducing a nested tensor along the ragged dimension is a nested tensor", + ) # output is a nested tensor + self.assertEqual(out_actual._values.shape, out_expected.shape) + self.assertTrue(torch.allclose(out_actual.values(), out_expected)) + + @dtypes(torch.float32) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_layer_norm_2d_input( + self, + device, + dtype, + requires_grad, + components_require_grad, + ): + """ + Layer normalization on NestedTensor fails when trying to operate on a 2-dimensional tensor + """ + tensor_lists = self._get_example_tensor_lists( + include_list_of_lists=False, + include_requires_grad=components_require_grad, + include_inner_dim_size_1=True, # (B, *, 1) + include_2d_tensor=True, # (B, *) + ) + + for tensor_list in tensor_lists: + nt = torch.nested.nested_tensor( + tensor_list, + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=requires_grad, + ) + + if nt.dim() <= 2: + with self.assertRaisesRegex( + RuntimeError, + "not supported for NestedTensor objects with 2 or fewer dimensions", + ): + out = torch.nn.functional.layer_norm( + nt, normalized_shape=(nt.shape[nt._ragged_idx],) + ) + + @dtypes(torch.float32) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_layer_norm_operate_on_batch_dim( + self, + device, + dtype, + requires_grad, + components_require_grad, + ): + """ + Layer normalization on NestedTensor fails when trying to operate on the batch dimension + """ + tensor_lists = self._get_example_tensor_lists( + include_list_of_lists=False, + include_requires_grad=components_require_grad, + include_inner_dim_size_1=True, # (B, *, 1) + include_2d_tensor=True, # (B, *) + ) + + for tensor_list in tensor_lists: + nt = torch.nested.nested_tensor( + tensor_list, + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=requires_grad, + ) + + if nt.dim() > 2: # cannot perform layer normalization on 2D tensors + with self.assertRaisesRegex( + RuntimeError, + "not supported when normalizing over the batch dimension for NestedTensor", + ): + out = torch.nn.functional.layer_norm(nt, normalized_shape=nt.shape) + + @dtypes(torch.float32) + @parametrize( + "func", + [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim], + name_fn=get_op_name, + ) + @parametrize( + "transpose_offset", [1, 2] + ) # [transpose consecutive dimensions, transpose nonconsecutive dimensions] + @parametrize("keepdim", [False, True]) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape( + self, + device, + dtype, + keepdim, + requires_grad, + components_require_grad, + func, + transpose_offset, + ): + """ + Operator on NestedTensor passes when trying to reduce across a transposed ragged dimension, i.e. ragged_idx > 1 + This test is for operators which return an output tensor with a shape different from the input tensor. 
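+
+ Illustrative sketch of the case this test exercises (hypothetical shapes, not executed here):
+
+     nt = torch.nested.nested_tensor(
+         [torch.randn(2, 5), torch.randn(3, 5)], layout=torch.jagged
+     )                          # shape (B=2, j1, 5), ragged_idx == 1
+     nt_t = nt.transpose(1, 2)  # shape (B=2, 5, j1), ragged_idx == 2
+     out = nt_t.sum(dim=(2,))   # reduce over the transposed ragged dim -> dense (2, 5)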
+ """ + if get_op_name(func) == "mean" and not keepdim: + return + + op_name = get_op_name(func) + + tensor_lists = self._get_example_tensor_lists( + include_list_of_lists=False, + include_requires_grad=components_require_grad, + include_inner_dim_size_1=True, # (B, *, 1) + include_2d_tensor=True, # (B, *) + ) + + for tensor_list in tensor_lists: + nt = torch.nested.nested_tensor( + tensor_list, + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=requires_grad, + ) + + if nt.dim() > nt._ragged_idx + transpose_offset: + nt_transposed = nt.transpose( + nt._ragged_idx, nt._ragged_idx + transpose_offset + ) + reduce_dim = (nt_transposed._ragged_idx,) # ragged + + out_actual = func(nt_transposed, dim=reduce_dim, keepdim=keepdim) + out_expected = torch.cat( + [ + func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0) + for t in nt_transposed.unbind() + ] + ) + if keepdim: + out_expected = out_expected.unsqueeze(reduce_dim[0]) + + self.assertFalse( + out_actual.is_nested, + f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor", + ) # output is a dense tensor + self.assertEqual(out_actual, out_expected) + + @dtypes(torch.float32) + @parametrize( + "transpose_offset", [1, 2] + ) # [transpose consecutive dimensions, transpose nonconsecutive dimensions] + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape( + self, + device, + dtype, + requires_grad, + components_require_grad, + transpose_offset, + ): + """ + Softmax on NestedTensor fails when trying to reduce across a transposed ragged dimension, i.e. ragged_idx > 1 + This test is for operators which return an output tensor with the same shape as the input tensor. + """ + tensor_lists = self._get_example_tensor_lists( + include_list_of_lists=False, + include_requires_grad=components_require_grad, + include_inner_dim_size_1=True, # (B, *, 1) + ) + + for tensor_list in tensor_lists: + nt = torch.nested.nested_tensor( + tensor_list, + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=requires_grad, + ) + + if nt.dim() > nt._ragged_idx + transpose_offset: + nt_transposed = nt.transpose( + nt._ragged_idx, nt._ragged_idx + transpose_offset + ) + reduce_dim = nt_transposed._ragged_idx # ragged + + with self.assertRaisesRegex( + RuntimeError, + "not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor", + ): + out = torch.nn.functional.softmax(nt_transposed, dim=reduce_dim) + + @dtypes(torch.float32) + @parametrize( + "func", + [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim], + name_fn=get_op_name, + ) + @parametrize("keepdim", [False, True]) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_op_dim_transpose_non_ragged_dim_different_output_shape( + self, device, dtype, keepdim, requires_grad, components_require_grad, func + ): + """ + Operator passes when reducing transposed nested tensors on valid reduction dimensions. + This test is for operators which return an output tensor with a shape different from the input tensor. 
+ """ + if get_op_name(func) == "mean" and not keepdim: + return + + # verify correctness of shapes (assuming that ragged_idx == 1) + if get_op_name(func) == "sum": + reduce_dims = ( + ((0, 1), (3, 4), (1, 1, 3, 4), (0,)), # batch, ragged + ((2, 3), (3, None), (3, None, 1, 1), (1, 2)), # non-batch, non-batch + ((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)), # batch, ragged, non-batch + ((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)), # batch, ragged, non-batch + ( + (0, 1, 2, 3), + (), + (1, 1, 1, 1), + (0, 1, 2), + ), # batch, ragged, non-batch, non-batch + ((2,), (3, None, 4), (3, None, 1, 4), (1,)), # non-batch + ) # (dims, expected shape, expected keepdim shape, reduce_dim_expected), where j0 is represented as None + elif get_op_name(func) == "mean": + reduce_dims = ( + ((2,), (3, None, 4), (3, None, 1, 4), (1,)), + ((3,), (3, None, 3), (3, None, 3, 1), (2,)), + ) + + # verify correctness of values + tensor_lists = self._get_example_tensor_lists( + include_list_of_lists=False, + include_requires_grad=components_require_grad, + ) + for tensor_list, reduce_dim_tuple in itertools.product( + tensor_lists, reduce_dims + ): + nt = torch.nested.nested_tensor( + tensor_list, + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=requires_grad, + ).transpose(-1, -2) + + reduce_dim, _, _, reduce_dim_expected = reduce_dim_tuple + + if nt.dim() > max( + reduce_dim[-1], nt._ragged_idx + 2 + ): # ensure that transposed dimensions are non-batch, non-ragged dimensions + out_actual = func(nt, dim=reduce_dim, keepdim=keepdim) + if nt._ragged_idx in reduce_dim: # raggedness reduced away + out_expected = func( + nt.values(), dim=reduce_dim_expected, keepdim=keepdim + ) + self.assertTrue(torch.allclose(out_actual, out_expected)) + else: # raggedness preserved + out_expected = func(nt.values(), dim=reduce_dim_expected) + self.assertTrue( + torch.allclose( + out_actual.values().view(-1), out_expected.view(-1) + ) + ) + + @dtypes(torch.float32) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_softmax_dim_transpose_non_ragged_dim( + self, + device, + dtype, + requires_grad, + components_require_grad, + ): + """ + Softmax passes when reducing transposed nested tensors on valid reduction dimensions. + This test is for operators which return an output tensor with the same shape as the input tensor. 
+ """ + # verify correctness of shapes (assuming that ragged_idx == 1) + reduce_dims = ( + (2, 1), + (3, 2), + ) # (reduction dimension, effective reduction dimension for baseline) + + # verify correctness of values + tensor_lists = self._get_example_tensor_lists( + include_list_of_lists=False, + include_requires_grad=components_require_grad, + include_inner_dim_size_1=True, # (B, *, 1) + ) + for tensor_list, reduce_dim_tuple in itertools.product( + tensor_lists, reduce_dims + ): + nt = torch.nested.nested_tensor( + tensor_list, + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=requires_grad, + ).transpose(-1, -2) + + reduce_dim, reduce_dim_expected = reduce_dim_tuple + + if nt.dim() > max(reduce_dim, nt._ragged_idx + 2): + out_actual = torch.nn.functional.softmax( + nt, dim=reduce_dim + ) # nested tensor + out_expected = torch.nn.functional.softmax( + nt.values(), dim=reduce_dim_expected + ) # dense tensor of dimensions 1 less than out_actual + + self.assertTrue( + torch.allclose(out_actual.values().view(-1), out_expected.view(-1)) + ) + + @dtypes(torch.float32) + @parametrize("keepdim", [False, True]) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_sum_dim_reduce_ragged_and_non_batch( + self, + device, + dtype, + keepdim, + requires_grad, + components_require_grad, + ): + """ + Sum on NestedTensor fails when trying to reduce across ragged and non-batch dimensions + """ + tensor_lists = self._get_example_tensor_lists( + include_list_of_lists=False, include_requires_grad=components_require_grad + ) + reduce_dims = ( + (1, 2), # ragged, non-batch + (1, 3), # ragged, non-batch + ) + + for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims): + nt = torch.nested.nested_tensor( + tensor_list, + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=requires_grad, + ) + + if nt.dim() > reduce_dim[-1]: + with self.assertRaisesRegex( + RuntimeError, + "reducing along a ragged and non-batch dimension is not supported", + ): + out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim) + + @dtypes(torch.float32) + @parametrize("keepdim", [False, True]) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_sum_dim_reduce_batch_and_non_batch( + self, + device, + dtype, + keepdim, + requires_grad, + components_require_grad, + ): + """ + Sum on NestedTensor fails when trying to reduce across batch and non-batch dimensions + """ + tensor_lists = self._get_example_tensor_lists( + include_list_of_lists=False, include_requires_grad=components_require_grad + ) + reduce_dims = ( + (0, 2), # batch, non-batch + (0, 3), # batch, non-batch + ) + + for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims): + nt = torch.nested.nested_tensor( + tensor_list, + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=requires_grad, + ) + + if nt.dim() > reduce_dim[-1]: + with self.assertRaisesRegex( + RuntimeError, + "reducing along the batch dimension but not the ragged dimension " + + "is not supported", + ): + out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim) + + @dtypes(torch.float32) + @parametrize( + "func", + [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim], + name_fn=get_op_name, + ) + @parametrize("keepdim", [False, True]) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_op_dim_reduce_batch_only_different_output_shape( + self, 
device, dtype, keepdim, requires_grad, components_require_grad, func + ): + """ + Operator on NestedTensor fails when trying to reduce across batch dimension + """ + if get_op_name(func) == "mean" and not keepdim: + return + + tensor_lists = self._get_example_tensor_lists( + include_list_of_lists=False, include_requires_grad=components_require_grad + ) + reduce_dim = (0,) # batch + + for tensor_list in tensor_lists: + nt = torch.nested.nested_tensor( + tensor_list, + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=requires_grad, + ) + + with self.assertRaisesRegex( + RuntimeError, + "reducing along the batch dimension but not the ragged dimension " + + "is not supported", + ): + out = func(nt, dim=reduce_dim, keepdim=keepdim) + + @dtypes(torch.float32) + @parametrize( + "func", + [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim], + name_fn=get_op_name, + ) + @parametrize("keepdim", [False, True]) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_op_dim_with_lengths_different_output_shape( + self, + device, + dtype, + keepdim, + requires_grad, + components_require_grad, + func, + ): + """ + Operator on NestedTensor fails when trying to reduce a nested tensor with lengths, + i.e. a nested tensor with holes, if reducing on the ragged dimension. + This test is for operators which return an output tensor with different shape than the input tensor. + """ + if get_op_name(func) == "mean" and not keepdim: + return + + reduce_dims = ((1,), (2,), (2, 3)) + + lengths = torch.randint(5, 10, (20,), device=device) + offsets = torch.zeros((21,), device=device, dtype=torch.int) + torch.cumsum(lengths, dim=0, out=offsets[1:]) + + values = torch.randn( + (offsets[-1].item(), 20), + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + + nt_with_holes = torch.nested.nested_tensor_from_jagged( + values, + offsets, + lengths=offsets.diff() - 2, # arbitrary subtraction to create holes + ) + + for reduce_dim in reduce_dims: + if nt_with_holes.dim() > reduce_dim[-1]: + if nt_with_holes._ragged_idx in reduce_dim: + with self.assertRaisesRegex( + RuntimeError, + "reducing across the ragged dimension is not supported for " + + "non-contiguous nested tensors with holes", + ): + out = func(nt_with_holes, dim=reduce_dim, keepdim=keepdim) + else: + out = func(nt_with_holes, dim=reduce_dim, keepdim=keepdim) + + @dtypes(torch.float32) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_softmax_dim_with_lengths( + self, + device, + dtype, + requires_grad, + components_require_grad, + ): + """ + Softmax on NestedTensor fails when trying to reduce a nested tensor with lengths, + i.e. a nested tensor with holes, if reducing on the ragged dimension. 
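+
+ Illustrative sketch of a nested tensor with holes (hypothetical values, not executed here):
+
+     values = torch.randn(10, 4)
+     offsets = torch.tensor([0, 3, 7, 10])
+     lengths = torch.tensor([2, 3, 2])  # shorter than offsets.diff() -> holes
+     nt = torch.nested.nested_tensor_from_jagged(values, offsets, lengths=lengths)
+     torch.nn.functional.softmax(nt, dim=1)  # raises: ragged-dim reduction with holes
+     torch.nn.functional.softmax(nt, dim=2)  # fine: reduction over a non-ragged dim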
+ """ + reduce_dims = (1, 2, 3) + + lengths = torch.randint(5, 10, (20,), device=device) + offsets = torch.zeros((21,), device=device, dtype=torch.int) + torch.cumsum(lengths, dim=0, out=offsets[1:]) + + values = torch.randn( + (offsets[-1].item(), 20), + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + + nt_with_holes = torch.nested.nested_tensor_from_jagged( + values, + offsets, + lengths=offsets.diff() - 2, # arbitrary subtraction to create holes + ) + + for reduce_dim in reduce_dims: + if nt_with_holes.dim() > reduce_dim: + if nt_with_holes._ragged_idx == reduce_dim: + with self.assertRaisesRegex( + RuntimeError, + "not supported where lengths is not None " + + "if reducing across the ragged dimension for NestedTensor", + ): + out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim) + else: + out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim) + + @skipIfTorchDynamo( + "ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx] does not currently work " + + "with dynamo tests and throws this error: `AssertionError: SymInts must use SymNodeVariable. " + + "If the underlying value is static, we will create a ConstantVariable and specialize.`" + ) + @dtypes(torch.float32) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_layer_norm_with_lengths( + self, + device, + dtype, + requires_grad, + components_require_grad, + ): + """ + Layer normalization on NestedTensor fails when trying to operate on a nested tensor with lengths, + i.e. a nested tensor with holes, if operating on the ragged dimension. + """ + + # create components for nested tensor + lengths = torch.randint(5, 10, (20,), device=device) + offsets = torch.zeros((21,), device=device, dtype=torch.int) + torch.cumsum(lengths, dim=0, out=offsets[1:]) + values = torch.randn( + (offsets[-1].item(), 10, 30), + device=device, + dtype=dtype, + requires_grad=requires_grad, + ) + + nt_with_holes = torch.nested.nested_tensor_from_jagged( + values, + offsets, + lengths=offsets.diff() - 2, # arbitrary subtraction to create holes + ) + + ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx] + + normalized_shapes = ( + (10, 30), # normalization on non-ragged dimension passes + (ragged_size, 10, 30), # normalization on ragged dimension fails + ) + + for normalized_shape in normalized_shapes: + if ragged_size in normalized_shape: + with self.assertRaisesRegex( + RuntimeError, + "not supported where lengths is not None if operating on the ragged dimension for NestedTensor", + ): + out = torch.nn.functional.layer_norm( + nt_with_holes, normalized_shape=normalized_shape + ) + else: + out = torch.nn.functional.layer_norm( + nt_with_holes, normalized_shape=normalized_shape + ) + + @unittest.skipIf( + PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property" + ) + @onlyOn(["cuda", "xpu"]) + def test_pin_memory(self, device): + nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7)) + for nt in [nt_contiguous, nt_noncontiguous]: + self.assertFalse(nt.is_pinned()) + pinned = nt.pin_memory() + self.assertTrue(pinned.is_pinned()) + self.assertEqual(nt, pinned) + self.assertNotEqual(nt.data_ptr(), pinned.data_ptr()) + # test that pin_memory on already pinned tensor has no effect + self.assertIs(pinned, pinned.pin_memory()) + self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr()) + + @torch.compiler.disable + def _validate_nt( + self, + nt, + device, + dtype, + layout, + requires_grad, + dim, + 
batch_size, + contiguous, + cached_min_seqlen=None, + cached_max_seqlen=None, + base=None, + ref_nt=None, + ): + # Validate a bunch of properties after NT construction. + device = torch.device(device) + self.assertEqual(nt.dim(), dim) + self.assertEqual(nt.device, device) + self.assertEqual(nt.dtype, dtype) + self.assertEqual(nt.layout, layout) + self.assertEqual(nt.requires_grad, requires_grad) + self.assertEqual(nt.is_contiguous(), contiguous) + + if layout == torch.jagged: + self.assertEqual(nt._values.device, device) + self.assertEqual(nt._offsets.device, device) + self.assertEqual(nt.shape[0], batch_size) + self.assertTrue(isinstance(nt.shape[1], torch.SymInt)) + + if base is not None: + self.assertTrue(nt._is_view() and nt._base is base) + replay_cache = nt._view_func(torch.randn_like(nt._base))._metadata_cache + self.assertEqual( + "min_seqlen" in replay_cache, cached_min_seqlen is not None + ) + self.assertEqual( + "max_seqlen" in replay_cache, cached_max_seqlen is not None + ) + + self.assertEqual( + "min_seqlen" in nt._metadata_cache, cached_min_seqlen is not None + ) + self.assertEqual( + "max_seqlen" in nt._metadata_cache, cached_max_seqlen is not None + ) + + if cached_min_seqlen is not None: + self.assertEqual(nt._min_seqlen, cached_min_seqlen) + + if cached_max_seqlen is not None: + self.assertEqual(nt._max_seqlen, cached_max_seqlen) + + if ref_nt is not None: + self.assertEqual(nt.size(0), ref_nt.size(0)) + for n1, n2 in zip(nt.unbind(), ref_nt.unbind()): + self.assertEqual(n1, n2) + + @dtypes(torch.float, torch.double, torch.half) + @parametrize("requires_grad", [False, True]) + @parametrize("components_require_grad", [False, True]) + def test_jagged_layout_construction_nested_tensor( + self, device, dtype, requires_grad, components_require_grad + ): + for tensor_list in self._get_example_tensor_lists( + include_list_of_lists=True, include_requires_grad=components_require_grad + ): + nt = torch.nested.nested_tensor( + tensor_list, + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=requires_grad, + ) + + expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1 + expected_batch_size = len(tensor_list) + expected_contiguous = True + expected_min_seqlen = min( + (torch.tensor(t) if isinstance(t, list) else t).shape[0] + for t in tensor_list + ) + expected_max_seqlen = max( + (torch.tensor(t) if isinstance(t, list) else t).shape[0] + for t in tensor_list + ) + self._validate_nt( + nt, + device, + dtype, + torch.jagged, + requires_grad, + expected_dim, + expected_batch_size, + expected_contiguous, + expected_min_seqlen, + expected_max_seqlen, + ) + + # Make sure grads -don't- flow back into original tensors for nested_tensor() + if requires_grad: + (nt * 2).backward(torch.ones_like(nt)) + for t in tensor_list: + t = t if isinstance(t, torch.Tensor) else torch.as_tensor(t) + self.assertTrue(t.grad is None) + + @dtypes(torch.float, torch.double, torch.half) + @parametrize("components_require_grad", [False, True]) + def test_jagged_layout_construction_as_nested_tensor( + self, device, dtype, components_require_grad + ): + # NB: as_nested_tensor(tensor_list) doesn't support lists of lists for tensor_list + for tensor_list in self._get_example_tensor_lists( + include_list_of_lists=False, include_requires_grad=components_require_grad + ): + nt = torch.nested.as_nested_tensor( + tensor_list, device=device, dtype=dtype, layout=torch.jagged + ) + + # nt.requires_grad=True should be set if at least one component requires grad + expected_dim = tensor_list[0].dim() + 1 + 
expected_batch_size = len(tensor_list) + expected_contiguous = True + expected_min_seqlen = min( + (torch.tensor(t) if isinstance(t, list) else t).shape[0] + for t in tensor_list + ) + expected_max_seqlen = max( + (torch.tensor(t) if isinstance(t, list) else t).shape[0] + for t in tensor_list + ) + self._validate_nt( + nt, + device, + dtype, + torch.jagged, + components_require_grad, + expected_dim, + expected_batch_size, + expected_contiguous, + expected_min_seqlen, + expected_max_seqlen, + ) + + # Make sure grads flow back into original tensors for as_nested_tensor() + if components_require_grad: + (nt * 2).backward(torch.ones_like(nt)) + for t in tensor_list: + if t.requires_grad: + self.assertEqual(t.grad, torch.ones_like(t) * 2) + else: + self.assertTrue(t.grad is None) + + @xfailIfTorchDynamo + @unittest.skipIf( + PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property" + ) + @onlyOn(["cuda", "xpu"]) + def test_jagged_layout_construction_with_pinned_memory(self, device): + for tensor_list in self._get_example_tensor_lists(): + nt = torch.nested.nested_tensor( + tensor_list, layout=torch.jagged, device="cpu", pin_memory=True + ) + + expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1 + expected_batch_size = len(tensor_list) + expected_min_seqlen = min( + (torch.tensor(t) if isinstance(t, list) else t).shape[0] + for t in tensor_list + ) + expected_max_seqlen = max( + (torch.tensor(t) if isinstance(t, list) else t).shape[0] + for t in tensor_list + ) + self._validate_nt( + nt, + device="cpu", + dtype=torch.float32, + layout=torch.jagged, + requires_grad=False, + dim=expected_dim, + batch_size=expected_batch_size, + contiguous=True, + cached_min_seqlen=expected_min_seqlen, + cached_max_seqlen=expected_max_seqlen, + ) + self.assertTrue(nt.is_pinned()) + + @dtypes(torch.float, torch.double, torch.half) + @parametrize("requires_grad", [False, True]) + @parametrize("values_is_view", [False, True]) + def test_jagged_view_from_values_offsets( + self, device, dtype, requires_grad, values_is_view + ): + if values_is_view: + # make values a view of base + base = torch.randn( + 2, 3, 4, 5, 6, device=device, dtype=dtype, requires_grad=requires_grad + ) + values = base.flatten(0, -2) + else: + values = torch.randn( + 10, 5, device=device, dtype=dtype, requires_grad=requires_grad + ) + offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64) + + nt = nested_view_from_values_offsets(values, offsets) + + expected_dim = values.dim() + 1 + expected_batch_size = offsets.shape[0] - 1 + expected_base = base if values_is_view else values + lengths = offsets.diff() + self._validate_nt( + nt, + device, + dtype, + torch.jagged, + requires_grad, + expected_dim, + expected_batch_size, + # ensure NT is a proper view + base=expected_base, + contiguous=True, + # if no min / max are passed, expect the metadata cache to be empty + cached_min_seqlen=None, + cached_max_seqlen=None, + ) + + if requires_grad: + # Make sure grads flow back + (nt * 2).backward(torch.ones_like(nt)) + + @torch.compiler.disable + def _check_grad(t): + self.assertTrue(t.grad is not None) + self.assertEqual(t.grad, torch.ones_like(t) * 2) + + _check_grad(base if values_is_view else values) + + @dtypes(torch.float) + @parametrize("pass_min_max", [False, True]) + def test_nested_tensor_from_jagged(self, device, dtype, pass_min_max): + # === construct from (values, offsets) === + values = torch.randn(10, 5, device=device, dtype=dtype) + offsets = torch.tensor([0, 2, 4, 6, 10], device=device, 
dtype=torch.int64) + + # compute min / max seqlen + lengths = offsets.diff() + min_seqlen = lengths.min().item() + max_seqlen = lengths.max().item() + + if pass_min_max: + nt = torch.nested.nested_tensor_from_jagged( + values, offsets=offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen + ) + else: + nt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets) + self._validate_nt( + nt, + device, + dtype, + torch.jagged, + requires_grad=False, + dim=3, + batch_size=4, + contiguous=True, + cached_min_seqlen=(min_seqlen if pass_min_max else None), + cached_max_seqlen=(max_seqlen if pass_min_max else None), + base=values, + ) + + # === construct from (values, offsets, lengths) === + lengths = torch.tensor([2, 1, 1, 2], device=device) + + # compute min / max seqlen + min_seqlen = lengths.min().item() + max_seqlen = lengths.max().item() + + if pass_min_max: + nt = torch.nested.nested_tensor_from_jagged( + values, + offsets=offsets, + lengths=lengths, + min_seqlen=min_seqlen, + max_seqlen=max_seqlen, + ) + else: + nt = torch.nested.nested_tensor_from_jagged( + values, offsets=offsets, lengths=lengths + ) + + # when both offsets / lengths are specified, expect non-contiguous + self._validate_nt( + nt, + device, + dtype, + torch.jagged, + requires_grad=False, + dim=3, + batch_size=4, + contiguous=False, + cached_min_seqlen=(min_seqlen if pass_min_max else None), + cached_max_seqlen=(max_seqlen if pass_min_max else None), + base=values, + ) + self.assertIs(nt.lengths(), lengths) + + # === construct from (values, lengths) === + values = torch.randn(14, 5, device=device, dtype=dtype) + lengths = torch.tensor([2, 3, 4, 5], device=device) + + # compute min / max seqlen + min_seqlen = lengths.min().item() + max_seqlen = lengths.max().item() + + if pass_min_max: + nt = torch.nested.nested_tensor_from_jagged( + values, lengths=lengths, min_seqlen=min_seqlen, max_seqlen=max_seqlen + ) + else: + nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths) + + # for now, if only lengths is specified, convert to offsets to integrate best with the + # existing kernels + expected_offsets = torch.tensor([0, 2, 5, 9, 14], device=device) + expected_nt = torch.nested.nested_tensor_from_jagged( + values, offsets=expected_offsets + ) + self._validate_nt( + nt, + device, + dtype, + torch.jagged, + requires_grad=False, + dim=3, + batch_size=4, + contiguous=True, + cached_min_seqlen=(min_seqlen if pass_min_max else None), + cached_max_seqlen=(max_seqlen if pass_min_max else None), + base=values, + ref_nt=expected_nt, + ) + + # error case: no offsets or lengths + with self.assertRaisesRegex( + RuntimeError, "At least one of offsets or lengths is required" + ): + torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None) + + with self.assertRaisesRegex(ValueError, "Expected jagged_dim >=1, but got 0."): + torch.nested.nested_tensor_from_jagged( + values, lengths=lengths, jagged_dim=0 + ) + + @onlyCPU + def test_nested_tensor_from_jagged_fx_trace(self, device): + def fn(x, y): + return torch.nested.nested_tensor_from_jagged(x, y) + + def user_unwrapped(x, y): + return fn(x, y) + + with self.assertRaisesRegex( + RuntimeError, + "torch.nested.nested_tensor_from_jagged does not support tracing with fx.symbolic_trace", + ): + torch.fx.symbolic_trace(user_unwrapped) + + @dtypes(torch.float, torch.double, torch.half) + @parametrize("dim", range(5)) + @parametrize( + "layout", + [torch.strided, torch.jagged], + name_fn=lambda l: f"layout_{str(l).split('.')[1]}", + ) + @parametrize("requires_grad", 
[False, True]) + @parametrize("contiguous", [False, True]) + def test_as_nested_tensor_from_tensor( + self, device, dtype, dim, layout, requires_grad, contiguous + ): + if dim == 0: + t = torch.tensor(3.0, requires_grad=requires_grad) + else: + t = torch.randn(*(3 for _ in range(dim)), requires_grad=requires_grad) + assert t.dim() == dim + + if dim < 2: + # 0-1 dim tensors can't be converted to NTs + with self.assertRaisesRegex( + RuntimeError, "Expected tensor argument to have dim" + ): + nt = torch.nested.as_nested_tensor( + t, device=device, dtype=dtype, layout=layout + ) + return + + orig_t = t + if not contiguous: + t = t.transpose(0, 1) + + nt = torch.nested.as_nested_tensor(t, device=device, dtype=dtype, layout=layout) + expected_dim = t.dim() + expected_batch_size = t.size(0) + expected_seqlen = t.size(1) if layout == torch.jagged else None + self._validate_nt( + nt, + device, + dtype, + layout, + requires_grad=requires_grad, + dim=dim, + batch_size=expected_batch_size, + contiguous=True, + cached_min_seqlen=expected_seqlen, + cached_max_seqlen=expected_seqlen, + ) + + if torch.device(device) == t.device and dtype == t.dtype and contiguous: + # should be the non-copying (view) case + self.assertTrue(nt._is_view() and nt._base is t) + + # should have equivalent components to construction from unbound tensor list + nt_from_unbind = torch.nested.as_nested_tensor( + list(t.unbind(0)), device=device, dtype=dtype, layout=layout + ) + self.assertEqualIgnoringNestedInts(nt, nt_from_unbind) + + # ensure call on a NT with the same properties returns the NT directly + nt2 = torch.nested.as_nested_tensor( + nt, device=device, dtype=dtype, layout=layout + ) + self.assertTrue(nt is nt2) + + # ensure call with device=None uses input tensor device + nt3 = torch.nested.as_nested_tensor( + t.to(device=device, dtype=dtype), + device=None, + dtype=None, + layout=layout, + ) + self._validate_nt( + nt3, + device, + dtype, + layout, + requires_grad=requires_grad, + dim=dim, + batch_size=expected_batch_size, + contiguous=True, + cached_min_seqlen=expected_seqlen, + cached_max_seqlen=expected_seqlen, + ) + + # we don't support conversion between layouts this way atm + other_layout = torch.strided if layout == torch.jagged else torch.jagged + with self.assertRaisesRegex( + RuntimeError, "Converting between nested tensor layouts is not supported" + ): + torch.nested.as_nested_tensor( + nt, device=device, dtype=dtype, layout=other_layout + ) + + if requires_grad: + # make sure gradients flow back into inputs + (nt * 2).backward(torch.ones_like(nt)) + self.assertEqual(orig_t.grad, torch.ones_like(orig_t) * 2) + + @dtypes(torch.float32) + def test_construction_from_list(self, device, dtype): + from torch.fx.experimental.symbolic_shapes import is_nested_int + + # success case: single ragged dim anywhere but the batch dim + for nt_dim in [2, 3, 4]: + for ragged_dim in range(1, nt_dim): + B = 6 + shapes = [list(range(3, 3 + nt_dim - 1)) for _ in range(B)] + for b in range(B): + # subtract 1 to convert to component dim space + shapes[b][ragged_dim - 1] = torch.randint( + 2, 9, (1,), device=device, dtype=torch.int64 + ).item() + + components = [ + torch.randn(shape, device=device, dtype=dtype) for shape in shapes + ] + nt = torch.nested.nested_tensor(components, layout=torch.jagged) + + self.assertEqual(nt.dim(), nt_dim) + self.assertEqual(nt._ragged_idx, ragged_dim) + for d in range(nt_dim): + self.assertEqual(d == ragged_dim, is_nested_int(nt.shape[d])) + + # error case: empty list + with self.assertRaisesRegex( + 
RuntimeError, "Cannot construct a nested tensor from an empty tensor list" + ): + torch.nested.nested_tensor([], layout=torch.jagged) + + # error case: list of zero-dim tensors + with self.assertRaisesRegex( + RuntimeError, + "Cannot construct a nested tensor from a list of zero-dim tensors", + ): + torch.nested.nested_tensor( + [ + torch.tensor(3.0, device=device, dtype=dtype), + torch.tensor(4.0, device=device, dtype=dtype), + torch.tensor(5.0, device=device, dtype=dtype), + ], + layout=torch.jagged, + ) + + # error case: multiple ragged dims + with self.assertRaisesRegex( + RuntimeError, + "Cannot represent given tensor list as a nested tensor with the jagged layout", + ): + torch.nested.nested_tensor( + [ + torch.randn(2, 3, device=device, dtype=dtype), + torch.randn(4, 5, device=device, dtype=dtype), + ], + layout=torch.jagged, + ) + + # error case: components on multiple devices + if "cuda" in device: + with self.assertRaisesRegex( + RuntimeError, + "When constructing a nested tensor, all tensors in list must be on the same device", + ): + torch.nested.nested_tensor( + [ + torch.randn(2, 3, device=device, dtype=dtype), + torch.randn(2, 4, device="cpu", dtype=dtype), + ], + layout=torch.jagged, + ) + + # error case: components with multiple dtypes + with self.assertRaisesRegex( + RuntimeError, + "When constructing a nested tensor, all tensors in list must have the same dtype", + ): + torch.nested.nested_tensor( + [ + torch.randn(2, 3, device=device, dtype=dtype), + torch.randn(2, 4, device=device, dtype=torch.float64), + ], + layout=torch.jagged, + ) + + # error case: components with multiple dims + with self.assertRaisesRegex( + RuntimeError, + "When constructing a nested tensor, all tensors in list must have the same dim", + ): + torch.nested.nested_tensor( + [ + torch.randn(2, 3, device=device, dtype=dtype), + torch.randn(2, 3, 4, device=device, dtype=dtype), + ], + layout=torch.jagged, + ) + + @dtypes(torch.double, torch.half) + @onlyOn(["cuda", "xpu"]) + def test_device_dtype_transfer_updates_offsets(self, device, dtype): + for tensor_list in self._get_example_tensor_lists(): + orig_device = torch.device("cpu") + orig_dtype = torch.float32 + nt = torch.nested.nested_tensor( + tensor_list, layout=torch.jagged, device=orig_device, dtype=orig_dtype + ) + + self.assertEqual(torch.int64, nt.offsets().dtype) + nt = nt.to(device=device).to(dtype=dtype) + + # offsets should still be int64 on the new device + self.assertEqual(nt.values().device, nt.offsets().device) + self.assertEqual(torch.int64, nt.offsets().dtype) + + def test_unbind(self, device): + for tensor_list in self._get_example_tensor_lists(): + nt = torch.nested.nested_tensor( + tensor_list, layout=torch.jagged, device=device + ) # ragged_idx = 1 + out = nt.unbind() + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual(t, tensor_list[i]) + + @parametrize("ragged_idx", [2, 3]) + def test_unbind_transpose(self, device, ragged_idx): + for tensor_list in self._get_example_tensor_lists(): + nt = torch.nested.nested_tensor( + tensor_list, layout=torch.jagged, device=device + ) + if ragged_idx < nt.dim(): + nt = nt.transpose(1, ragged_idx) # set ragged_idx + out = nt.unbind() + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual( + t.transpose(0, ragged_idx - 1), tensor_list[i] + ) # transpose back each element of result + + def test_unbind_transpose_ragged_idx_last_dim(self, device): + for tensor_list in self._get_example_tensor_lists(): + nt = 
torch.nested.nested_tensor( + tensor_list, layout=torch.jagged, device=device + ).transpose( + 1, -1 + ) # set ragged_idx = last dimension + out = nt.unbind() + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual( + t.transpose(0, -1), tensor_list[i] + ) # transpose back each element of result + + def test_unbind_lengths(self, device): + values = torch.randn(16, 128, device=device) + offsets = torch.tensor([0, 8, 12, 13, 16], device=device) + lengths = torch.tensor([6, 2, 1, 2], device=device) + nt = torch.nested.nested_tensor_from_jagged( + values, offsets=offsets, lengths=lengths + ) # 3D nested tensor + + tensor_list = [] + for i in range(offsets.shape[0] - 1): + tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i])]) + + out = nt.unbind() + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual(t, tensor_list[i]) + + def test_unbind_lengths_ragged_idx_1(self, device): + values = torch.randn(16, 8, 128, device=device) + offsets = torch.tensor([0, 8, 12, 13, 16], device=device) + lengths = torch.tensor([6, 2, 1, 2], device=device) + ragged_idx = 1 + nt = torch.nested._internal.nested_tensor.NestedTensor( + values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx + ) # 4D nested tensor + + tensor_list = [] + for i in range(offsets.shape[0] - 1): + tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i]), :, :]) + + out = nt.unbind() + + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual(t, tensor_list[i]) + + def test_unbind_lengths_ragged_idx_equals_2_bad_dim(self, device): + values = torch.randn(16, 8, 128, device=device) + offsets = torch.tensor([0, 8, 12, 13, 16], device=device) + lengths = torch.tensor([6, 2, 1, 2], device=device) + ragged_idx = 2 + nt = torch.nested._internal.nested_tensor.NestedTensor( + values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx + ) # 4D nested tensor + + self.assertRaisesRegex( + RuntimeError, + r"unbind\(\): nested tensor offsets and lengths.*", + lambda: nt.unbind(), + ) + + def test_unbind_lengths_ragged_idx_2(self, device): + values = torch.randn(16, 8, 128, device=device) + offsets = torch.tensor([0, 2, 4, 8], device=device) + lengths = torch.tensor([2, 1, 3], device=device) + ragged_idx = 2 + nt = torch.nested._internal.nested_tensor.NestedTensor( + values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx + ) # 4D nested tensor + + tensor_list = [] + for i in range(offsets.shape[0] - 1): + tensor_list.append(values[:, offsets[i] : (offsets[i] + lengths[i]), :]) + + out = nt.unbind() + + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual(t, tensor_list[i]) + + def test_unbind_lengths_ragged_idx_3(self, device): + values = torch.randn(16, 8, 128, device=device) + offsets = torch.tensor([0, 100, 128], device=device) + lengths = torch.tensor([50, 28], device=device) + ragged_idx = 3 + nt = torch.nested._internal.nested_tensor.NestedTensor( + values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx + ) # 4D nested tensor + + tensor_list = [] + for i in range(offsets.shape[0] - 1): + tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])]) + + out = nt.unbind() + + self.assertEqual(len(out), len(tensor_list)) + for i, t in enumerate(out): + self.assertEqual(t, tensor_list[i]) + + @skipIfTorchDynamo( + "TorchDynamo raises an error for ragged_idx == 0 earlier than Torch" + ) + def test_unbind_lengths_ragged_idx_0(self, 
device): + values = torch.randn(16, 8, 128, device=device) + offsets = torch.tensor([0, 100, 128], device=device) + lengths = torch.tensor([50, 28], device=device) + ragged_idx = 0 + nt = torch.nested._internal.nested_tensor.NestedTensor( + values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx + ) # 4D nested tensor + + tensor_list = [] + for i in range(offsets.shape[0] - 1): + tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])]) + + self.assertRaisesRegex( + RuntimeError, + r"unbind\(\): nested tensor.*out of bounds", + lambda: nt.unbind(), + ) + + def test_narrow(self, device): + starts = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64) + lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64) + buffer = ( + torch.arange(0, 10, device=device, dtype=torch.int64) + .unsqueeze(0) + .expand(5, -1) + .clone() + .detach() + ) + nt = torch.nested.narrow(buffer, 1, starts, lengths, layout=torch.jagged) + + self.assertTrue(nt._is_view() and nt._base is buffer) + + # TODO: Use this approach when unbind is functional + # unbinded_nt = nt.unbind() + # for i in range(starts.shape[0]): + # self.assertEqual(torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64), unbinded_nt[i]) + for i in range(starts.shape[0]): + self.assertEqual( + torch.arange( + starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64 + ), + nt.values()[nt.offsets()[i] : (nt.offsets()[i] + nt.lengths()[i])], + ) + + def test_njt_cat(self, device): + offsets = torch.tensor([0, 2, 3], device=device, dtype=torch.int64) + values_1 = torch.randn( + 3, 2, dtype=torch.float64, device=device, requires_grad=True + ) + values_2 = torch.randn( + 3, 4, dtype=torch.float64, device=device, requires_grad=True + ) + + def grad_test_func(values_1, values_2, offsets): + nt_1 = torch.nested.nested_tensor_from_jagged(values_1, offsets) + nt_2 = torch.nested.nested_tensor_from_jagged(values_2, offsets) + nt_3 = torch.cat([nt_1, nt_2], dim=-1) + return nt_3.values() + + assert gradcheck( + grad_test_func, + inputs=(values_1, values_2, offsets), + check_batched_grad=False, + ) + + def test_is_contiguous(self, device): + a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) + nt_contiguous = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) + + starts_nc = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64) + lengths_nc = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64) + narrow_base = ( + torch.arange(0, 10, device=device, dtype=torch.int64) + .unsqueeze(0) + .expand(5, -1) + .clone() + ) + nt_noncontiguous = torch.nested.narrow( + narrow_base, 1, starts_nc, lengths_nc, layout=torch.jagged + ) + + starts_c = torch.tensor([1, 0, 0, 0, 0], device=device, dtype=torch.int64) + lengths_c = torch.tensor([9, 10, 10, 10, 8], device=device, dtype=torch.int64) + nt_contiguous_narrow = torch.nested.narrow( + narrow_base, 1, starts_c, lengths_c, layout=torch.jagged + ) + + # Test contiguous case + assert nt_contiguous.is_contiguous() + + # Test narrow case + assert not nt_noncontiguous.is_contiguous() + assert nt_contiguous_narrow.is_contiguous() + + # Test querying by memory_format + self.assertTrue( + nt_contiguous.is_contiguous(memory_format=torch.contiguous_format) + ) + self.assertTrue( + not 
nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format) + ) + self.assertTrue( + nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format) + ) + + def test_layout_under_torch_dispatch_mode(self): + from torch.testing._internal.logging_tensor import ( + capture_logs_with_logging_tensor_mode, + ) + + nt = random_nt_from_dims( + [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged + ) + + with capture_logs_with_logging_tensor_mode(): + self.assertEqual(nt.layout, torch.jagged) + + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") + @parametrize( + "func", [torch.empty_like, torch.randn_like], name_fn=lambda f: f.__name__ + ) + def test_like_shape(self, func): + nt = random_nt_from_dims( + [2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged + ) + nt_like = func(nt) + + for nt_ub in nt_like.unbind(): + t_like = func(nt_ub) + self.assertEqual(nt_ub.shape, t_like.shape) + + @skipIfTorchDynamo("Not a suitable test for TorchDynamo") + @parametrize( + "func", + [ + torch.empty_like, + torch.full_like, + torch.ones_like, + torch.rand_like, + torch.randint_like, + torch.randn_like, + torch.zeros_like, + ], + name_fn=lambda f: f.__name__, + ) + def test_like_value(self, func, device): + dtype = torch.float32 if func is not torch.randint_like else torch.int32 + for nt in _sample_njts(device=device, dtype=dtype): + extra_kwarg_sets = [{}] + if func is torch.full_like: + extra_kwarg_sets = [{"fill_value": 4.2}] + elif func is torch.randint_like: + extra_kwarg_sets = [{"high": 5}, {"low": 4, "high": 9}] + + # only test changing dtype / device from CUDA -> CPU because CUDA might not be + # available when running this test for CPU + change_dtype_device_settings = ( + [False, True] if "cuda" in device else [False] + ) + for change_dtype_device in change_dtype_device_settings: + if change_dtype_device: + new_dtype = ( + torch.float64 if func is not torch.randint_like else torch.int64 + ) + new_device = "cpu" if "cuda" in device else device + new_layout = torch.strided + for extra_kwargs in extra_kwarg_sets: + extra_kwargs.update( + { + "dtype": new_dtype, + "device": new_device, + "layout": new_layout, + } + ) + + for extra_kwargs in extra_kwarg_sets: + nt_like = func(nt, **extra_kwargs) + self.assertEqual(nt.shape, nt_like.shape) + if change_dtype_device: + self.assertNotEqual(nt.device, nt_like.device) + self.assertNotEqual(nt.device, nt_like.dtype) + # layout should be ignored since only torch.jagged is supported + self.assertEqual(torch.jagged, nt_like.layout) + else: + self.assertEqual(nt.device, nt_like.device) + self.assertEqual(nt.dtype, nt_like.dtype) + self.assertEqual(nt.layout, nt_like.layout) + self.assertEqual(nt.layout, torch.jagged) + + # don't bother trying to compare random or empty values + if func not in [ + torch.empty_like, + torch.rand_like, + torch.randn_like, + torch.randint_like, + ]: + for nt_ub in nt_like.unbind(): + t_like = func(nt_ub, **extra_kwargs) + self.assertEqual(nt_ub, t_like) + + def test_noncontiguous_pointwise(self, device): + a = torch.randn(2, 3, 4, requires_grad=True, dtype=torch.float64, device=device) + b = torch.randn(3, 3, 4, requires_grad=True, dtype=torch.float64, device=device) + c = torch.randn(4, 3, 4, requires_grad=True, dtype=torch.float64, device=device) + nt = torch.nested.nested_tensor([a, b, c], layout=torch.jagged) + # transpose ragged dim + transposed = nt.transpose(1, 2) + self.assertFalse(transposed.is_contiguous()) + clone = transposed.clone() + + def check_nt_equality(x, y): 
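+ # Structural equality for jagged NTs: same values, offsets, ragged_idx, and shape.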
+ self.assertEqual(x.values(), y.values()) + self.assertEqual(x.offsets(), y.offsets()) + self.assertEqual(x._ragged_idx, y._ragged_idx) + self.assertEqual(x.shape, y.shape) + + self.assertFalse(clone.is_contiguous()) + check_nt_equality(clone, transposed) + + clone_contig = transposed.clone(memory_format=torch.contiguous_format) + self.assertTrue(clone_contig.is_contiguous()) + check_nt_equality(clone_contig, transposed) + + detached = transposed.detach() + self.assertFalse(clone.is_contiguous()) + check_nt_equality(detached, transposed) + + def test_permute(self, device): + nt = random_nt_from_dims( + [2, None, 3, 5], device, torch.float32, layout=torch.jagged + ) + nt_shape = nt.shape + nt_inner_shape = nt.values().shape + with self.assertRaisesRegex( + ValueError, + r"permute\(\): number of dimensions in the tensor input \(4\) " + + r"does not match the length of the desired ordering of dimensions \(3\).", + ): + nt.permute(0, 2, 1) + with self.assertRaisesRegex( + ValueError, r"permute\(\): duplicate dims are not allowed." + ): + nt.permute(0, 2, -2, 3) + with self.assertRaisesRegex( + ValueError, "Permute is not supported on the batch dimension for jagged NT" + ): + nt.permute(1, 0, 2, 3) + nt_permute = nt.permute(0, 2, 1, -1) + self.assertEqual( + nt_permute.shape, (nt_shape[0], nt_shape[2], nt_shape[1], nt_shape[3]) + ) + self.assertEqual( + nt_permute.values().shape, + (nt_inner_shape[1], nt_inner_shape[0], nt_inner_shape[2]), + ) + self.assertEqual(nt_permute._ragged_idx, 2) + self.assertEqual(nt_permute.permute(0, 2, 1, 3), nt) + + def test_to_dtype(self, device): + nt = random_nt_from_dims( + [2, None, 3], device, torch.float32, layout=torch.jagged + ) + nt_after = nt.to(torch.float64) + self.assertEqual(torch.float32, nt.dtype) + self.assertEqual(torch.float64, nt_after.dtype) + self.assertEqual(torch.float64, nt_after.values().dtype) + self.assertEqual(torch.int64, nt_after.offsets().dtype) + + noncontiguous_nt = nt.transpose(1, 2) + noncontiguous_nt_after = noncontiguous_nt.to(torch.bfloat16) + self.assertEqual(torch.bfloat16, noncontiguous_nt_after.dtype) + self.assertEqual(torch.bfloat16, noncontiguous_nt_after.values().dtype) + self.assertEqual(torch.int64, noncontiguous_nt_after.offsets().dtype) + + def test_to_copy(self, device): + nt = torch.nested.nested_tensor( + [ + torch.randn( + i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device + ) + for i in range(3) + ], + layout=torch.jagged, + ) + + nt_copy_dtype = torch.ops.aten._to_copy(nt, dtype=torch.float16) + self.assertEqual(torch.float16, nt_copy_dtype.dtype) + + nt_t = nt.transpose(1, 2) + nt_t_copy_dtype = torch.ops.aten._to_copy(nt_t, dtype=torch.float16) + self.assertEqual(torch.float16, nt_t_copy_dtype.dtype) + + def test_copy_(self, device): + offsets = torch.tensor([0, 2, 4], device=device) + a = torch.nested.nested_tensor_from_jagged( + torch.zeros(4, 3, device=device), offsets + ) + b = torch.nested.nested_tensor_from_jagged( + torch.ones(4, 3, device=device), offsets + ) + a.copy_(b) + torch._dynamo.disable(self.assertEqual)(a, b) + + offsets_2 = torch.tensor([0, 2, 4], device=device) + c = torch.nested.nested_tensor_from_jagged( + torch.ones(4, 3, device=device), offsets_2 + ) + # should work even though the nested ints are different due to unbound-based copy + a.copy_(c) + + # fail when tensors have different sizes + a = a.transpose(1, 2) + with self.assertRaisesRegex( + RuntimeError, + "expected compatible input and src shapes, but got", + ): + a.copy_(b) + + # This can't happen in the 
opinfo tests due to subprocess creation + @unittest.skipIf( + TEST_WITH_ROCM, + "In ROCm, kernel asserts are disabled due to performance overhead", + ) + def test_index_put_error(self, device): + import subprocess + + with self.subTest(): + r = subprocess.call( + [ + sys.executable, + "-c", + """\ +import torch +offsets = torch.tensor([0, 2, 5, 7], device='cuda') +lengths = torch.tensor([2, 2, 2], device='cuda') +indices = [ + torch.tensor([0, 1, 2], device='cuda'), + torch.tensor([0, 2, 1], device='cuda'), + torch.tensor([0, 0, 0], device='cuda'), +] +a = torch.nested.nested_tensor_from_jagged( + torch.zeros(7, 3, device='cuda'), offsets, lengths +) +a[indices] = 1.0 +torch.cuda.synchronize() +""", + ] + ) + self.assertTrue(r != 0) + + @skipIfTorchDynamo("Dynamo doesn't know how to trace prof.events()") + def test_profiler_sequence_nr(self): + with torch.profiler.profile() as prof: + values = torch.randn(4, 6, requires_grad=True) + offsets = torch.tensor([0, 2, 4]) + values = values * 2 + l = torch.nn.Linear(6, 8) + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + + nt = l(nt) + val = nt.values() + + loss = val.sum() + loss.backward() + + fwd_seq_nrs = [] + for evt in prof.events(): + if ( + "linear" in evt.name.lower() + and "backward" not in evt.name.lower() + and evt.sequence_nr != -1 + ): + fwd_seq_nrs.append(evt.sequence_nr) + + bwd_seq_nrs = [] + for evt in prof.events(): + if ( + "linear" in evt.name.lower() + and "backward" in evt.name.lower() + and "evaluate_function" not in evt.name.lower() + and evt.sequence_nr != -1 + ): + bwd_seq_nrs.append(evt.sequence_nr) + + # There should only be one such event with a sequence number: + # the PythonTLSSnapshot event - but, note that it's not terrible if + # we end up with multiple events with the same sequence number - so we + # could relax this check if it becomes inconvenient to maintain this + # property. 
+ self.assertEqual(len(fwd_seq_nrs), 1) + self.assertEqual(len(bwd_seq_nrs), 1) + self.assertEqual(fwd_seq_nrs[0], bwd_seq_nrs[0]) + + def test_is_same_size(self, device): + def get_3_tensors(): + return [ + torch.randn( + i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device + ) + for i in range(3) + ] + + nt1, offsets1 = jagged_from_list(get_3_tensors(), None) + nt2, offsets1 = jagged_from_list(get_3_tensors(), offsets1) + + nt3, offsets2 = jagged_from_list(get_3_tensors(), None) + nt4, offsets2 = jagged_from_list(get_3_tensors(), offsets2) + + def check_size(nt1, nt2, nt3, nt4): + self.assertTrue(torch.ops.aten.is_same_size(nt1, nt2)) + self.assertTrue(torch.ops.aten.is_same_size(nt3, nt4)) + self.assertFalse(torch.ops.aten.is_same_size(nt1, nt3)) + + check_size(nt1, nt2, nt3, nt4) + + nt1_t, nt2_t, nt3_t, nt4_t = (x.transpose(1, 2) for x in (nt1, nt2, nt3, nt4)) + check_size(nt1_t, nt2_t, nt3_t, nt4_t) + + @skipIfTorchDynamo("compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + def test_specialize_dynamic_shape(self, device): + values = torch.randn((18, 16), device=device) + offsets = torch.tensor([0, 2, 3, 6, 15, 18], device=device) + like_values = torch.randn_like(values) + + # this marks values as dynamic + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + + def fn(values, same_size): + # here, the dynamic shape is specialized by same_size's shape + # https://github.com/pytorch/pytorch/issues/127097 + # make sure this doesn't error out in torch.compile + return values + same_size + + self.assertEqual( + fn(values, like_values), + torch.compile(fn)(values, like_values), + ) + + @skipIfTorchDynamo("compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + def test_specialize_dynamic_shape_recompile(self, device): + def generate_inp(total_len): + values = torch.randn((total_len, 16), device=device) + offsets = torch.tensor([0, 2, 3, 6, 15, total_len], device=device) + like_values = torch.randn_like(values) + return values, offsets, like_values + + def check_results(ref_fn, res_fn, args): + values, offsets, like_values = args + # this may add dynamic shape markings + # goal of this test is to make sure that whatever markings are there, + # we eventually stop recompiling as shape changes. + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + + self.assertEqual(ref_fn(values, like_values), res_fn(values, like_values)) + + def fn(values, same_size): + return values + same_size + + compile_counter = torch._dynamo.testing.CompileCounter() + + compiled_fn = torch.compile(fn, backend=compile_counter, fullgraph=True) + check_results(fn, compiled_fn, generate_inp(18)) + self.assertEqual(compile_counter.frame_count, 1) + + check_results(fn, compiled_fn, generate_inp(19)) + # we'll probably recompile here with dynamic shapes - it's okay if not though. + frame_count_2 = compile_counter.frame_count + self.assertIn(frame_count_2, [1, 2]) + + # make sure that by now we've already compiled with dynamic shapes, so additional + # shapes should not trigger additional recompiles. 
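+ # A third, previously unseen total_len should reuse the dynamic-shape graph,
+ # leaving the frame count at frame_count_2.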
+ check_results(fn, compiled_fn, generate_inp(20)) + self.assertEqual(compile_counter.frame_count, frame_count_2) + + # Note 1: Math fallback doesn't work with bfloat16 on CUDA + # Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT + @unittest.skipIf( + TEST_WITH_ROCM, + "ROCm doesn't support flash attention or mem_efficient attention for NT", + ) + @tf32_on_and_off(0.005) + @dtypes( + *( + [torch.float16, torch.bfloat16, torch.float32] + if SM80OrLater + else [torch.float16, torch.float32] + ) + ) + def test_sdpa(self, device, dtype): + batch_size = 1 + emb_dims = 128 + n_heads = 8 + head_dims = emb_dims // n_heads + + sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device) + sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device) + + query = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) + key = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) + value = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) + + # Simplest case: 1 sentence, no batching + x_d1 = sen1.unsqueeze(0) + x_nt = torch.nested.as_nested_tensor([sen1], layout=torch.jagged) + + # See note below for why we detach here. + q_d1 = ( + query(x_d1) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) + q_d1_t = q_d1.transpose(1, 2) + k_d1 = ( + key(x_d1) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) + k_d1_t = k_d1.transpose(1, 2) + v_d1 = ( + value(x_d1) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) + v_d1_t = v_d1.transpose(1, 2) + + q_nt = ( + query(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) + q_nt_t = q_nt.transpose(1, 2) + k_nt = ( + key(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) + k_nt_t = k_nt.transpose(1, 2) + v_nt = ( + value(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) + v_nt_t = v_nt.transpose(1, 2) + + # High Precision Math Reference + q_d1_f32 = q_d1.to(torch.float32) + k_d1_f32 = k_d1.to(torch.float32) + v_d1_f32 = v_d1.to(torch.float32) + q_d1_f32_t = q_d1_f32.transpose(1, 2) + k_d1_f32_t = k_d1_f32.transpose(1, 2) + v_d1_f32_t = v_d1_f32.transpose(1, 2) + out_ref = torch.ops.aten._scaled_dot_product_attention_math( + q_d1_f32_t, k_d1_f32_t, v_d1_f32_t + )[0] + grads_ref = torch.autograd.grad(out_ref.sum(), (q_d1_f32, k_d1_f32, v_d1_f32)) + + # Low Precision Math Reference + out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( + q_d1_t, k_d1_t, v_d1_t + )[0] + grads_lp_ref = torch.autograd.grad(out_lp_ref.sum(), (q_d1, k_d1, v_d1)) + + # Compute tolerances + output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) + # fudge factor of 1.7 for smaller GPUs e.g., A2, A16 + grad_q_ref_atol, grad_q_ref_rtol = get_tolerances( + grads_ref[0], grads_lp_ref[0], 1.7 + ) + grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(grads_ref[1], grads_lp_ref[1]) + grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(grads_ref[2], grads_lp_ref[2]) + grad_atols = [grad_q_ref_atol, grad_k_ref_atol, grad_v_ref_atol] + grad_rtols = [grad_q_ref_rtol, grad_k_ref_rtol, grad_v_ref_rtol] + + attn_d1 = torch.nn.functional.scaled_dot_product_attention( + q_d1_t, k_d1_t, v_d1_t + ).transpose(1, 2) + attn_nt = torch.nn.functional.scaled_dot_product_attention( + q_nt_t, k_nt_t, v_nt_t + ).transpose(1, 2) + + self.assertEqual( + attn_d1, + 
attn_nt.unbind()[0].unsqueeze(0), + atol=output_ref_atol, + rtol=output_ref_rtol, + ) + + # Simple case: 2 sentences, no extra params + x_d2 = sen2.unsqueeze(0) + x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged) + + # NB: we make sure the leaf tensor we compute gradients for is the view-ed tensor before + # it is transposed. This is because today we cannot backward through view or unbind a # transposed tensor. q_d2 = ( query(x_d2) @@ -534,203 +6872,1114 @@ def _test_sdpa(self, device, dtype): .detach() .requires_grad_(True) ) - q_d2_t = q_d2.transpose(1, 2) - k_d2 = ( - key(x_d2) - .view(batch_size, -1, n_heads, head_dims) - .detach() - .requires_grad_(True) + q_d2_t = q_d2.transpose(1, 2) + k_d2 = ( + key(x_d2) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) + k_d2_t = k_d2.transpose(1, 2) + v_d2 = ( + value(x_d2) + .view(batch_size, -1, n_heads, head_dims) + .detach() + .requires_grad_(True) + ) + v_d2_t = v_d2.transpose(1, 2) + + q_nt = ( + query(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) + q_nt_t = q_nt.transpose(1, 2) + k_nt = ( + key(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) + k_nt_t = k_nt.transpose(1, 2) + v_nt = ( + value(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .requires_grad_(True) + ) + v_nt_t = v_nt.transpose(1, 2) + + attn_d2 = torch.nn.functional.scaled_dot_product_attention( + q_d2_t, k_d2_t, v_d2_t + ).transpose(1, 2) + d1_grads = torch.autograd.grad(attn_d1.sum(), (q_d1, k_d1, v_d1)) + d2_grads = torch.autograd.grad(attn_d2.sum(), (q_d2, k_d2, v_d2)) + + # Simple case 3: batch_size = 1, seq_len = 1 + q_3 = torch.randn(1, 8, 16, dtype=dtype, device=device) + q_nt_3 = torch.nested.as_nested_tensor([q_3], layout=torch.jagged) + q_nt_3 = q_nt_3.transpose(1, 2) + attn_out = torch.nn.functional.scaled_dot_product_attention( + q_nt_3, q_nt_3, q_nt_3 + ) + self.assertEqual(attn_out.shape, q_nt_3.shape) + + @parametrize("skip_backward", [True, False]) + def check_forward_backward(skip_backward=False): + if not skip_backward: + attn_nt = torch.nn.functional.scaled_dot_product_attention( + q_nt_t, k_nt_t, v_nt_t + ).transpose(1, 2) + else: + x_nt.requires_grad = False + q_nt.requires_grad = False + k_nt.requires_grad = False + v_nt.requires_grad = False + tq = q_nt_t.detach() + tk = k_nt_t.detach() + tv = v_nt_t.detach() + with torch.no_grad(): + attn_nt = torch.nn.functional.scaled_dot_product_attention( + tq, tk, tv + ).transpose(1, 2) + + attn_nts = attn_nt.unbind() + self.assertEqual( + attn_d1, + attn_nts[0].unsqueeze(0), + atol=output_ref_atol, + rtol=output_ref_rtol, + ) + self.assertEqual( + attn_d2, + attn_nts[1].unsqueeze(0), + atol=output_ref_atol, + rtol=output_ref_rtol, + ) + + if not skip_backward: + nt_grads = torch.autograd.grad( + attn_nt.values().sum(), (q_nt, k_nt, v_nt) + ) + for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip( + nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols + ): + unbound_nt_grads = nt_grad.unbind() + self.assertEqual( + d1_grad, + unbound_nt_grads[0].unsqueeze(0), + atol=grad_atol, + rtol=grad_rtol, + ) + self.assertEqual( + d2_grad, + unbound_nt_grads[1].unsqueeze(0), + atol=grad_atol, + rtol=grad_rtol, + ) + + # Default + check_forward_backward() + + # Test dispatcher works by calling only mem-effn and math (as they are safe for all devices) + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=True, enable_math=True + 
): + check_forward_backward() + + # Test math fallback + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ): + # Math fallback doesn't work with bfloat16 on CUDA because + # "group_gemm_dispatch" not implemented for 'BFloat16' + if not (str(device).startswith("cuda") and dtype == torch.bfloat16): + check_forward_backward() + check_cudnn = os.getenv("TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED", "0") == "1" + if ( + "cuda" in str(device) + and check_cudnn + and (dtype == torch.float16 or dtype == torch.bfloat16) + ): + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.CUDNN_ATTENTION + ): + check_forward_backward() + + @skipIfTorchDynamo("SDPA test compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + # Guarding with sqrt() doesn't work on ROCm? + @skipCUDAIfRocm + @onlyOn(["cuda", "xpu"]) + @dtypes( + *( + [torch.float16, torch.bfloat16, torch.float32] + if SM80OrLater + else [torch.float16, torch.float32] + ) + ) + def test_sdpa_compile(self, device, dtype): + batch_size = 1 + emb_dims = 1024 + n_heads = 8 + head_dims = emb_dims // n_heads + + sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device) + sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device) + + query = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) + key = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) + value = torch.nn.Linear( + emb_dims, emb_dims, bias=False, device=device, dtype=dtype + ) + + # Simplest case: 1 sentence, no batching + x_d1 = sen1.unsqueeze(0) + x_d2 = sen2.unsqueeze(0) + x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged) + + q_d1 = query(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + k_d1 = key(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + v_d1 = value(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + q_d2 = query(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) + + q_nt = ( + query(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .transpose(1, 2) + ) + k_nt = ( + key(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .transpose(1, 2) + ) + v_nt = ( + value(x_nt) + .view(*x_nt.size()[0:2], n_heads, head_dims) + .detach() + .transpose(1, 2) + ) + + # High Precision Math Reference + q_d1_f32 = q_d1.to(torch.float32) + k_d1_f32 = k_d1.to(torch.float32) + v_d1_f32 = v_d1.to(torch.float32) + out_ref = torch.ops.aten._scaled_dot_product_attention_math( + q_d1_f32, k_d1_f32, v_d1_f32 + )[0] + # Low Precision Math Reference + out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( + q_d1, k_d1, v_d1 + )[0] + output_ref_atol, output_ref_rtol = get_tolerances( + out_ref, out_lp_ref, fudge_factor=2 + ) + + attn_d1 = torch.nn.functional.scaled_dot_product_attention( + q_d1, k_d1, v_d1 + ).transpose(1, 2) + attn_d2 = torch.nn.functional.scaled_dot_product_attention( + q_d2, k_d2, v_d2 + ).transpose(1, 2) + + compiled_sdpa = torch.compile(torch.nn.functional.scaled_dot_product_attention) + attn_nt = compiled_sdpa(q_nt, k_nt, v_nt).transpose(1, 2) + + attn_nts = attn_nt.unbind() + self.assertEqual( + attn_d1, + attn_nts[0].unsqueeze(0), + atol=output_ref_atol, + 
rtol=output_ref_rtol, + ) + self.assertEqual( + attn_d2, + attn_nts[1].unsqueeze(0), + atol=output_ref_atol, + rtol=output_ref_rtol, + ) + + @dtypes(torch.float32, torch.double, torch.half) + def test_sdpa_with_constant_sequence_length(self, device, dtype): + # shape (B, P*, S, D) + # B: batch size + # P*: ragged number of prompts + # S: (constant) sequence length + # D: embedding size + query = random_nt_from_dims( + [4, None, 8, 10], + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=True, + ) + key = random_nt_from_similar(query) + value = random_nt_from_similar(query) + output = F.scaled_dot_product_attention(query, key, value) + self.assertTrue(isinstance(output, NestedTensor)) + output.values().sum().backward() + + query_dense = query.detach().clone().requires_grad_(True) + # should be equivalent to just running the buffers through + output_dense = F.scaled_dot_product_attention( + query_dense.values(), key.values(), value.values() + ) + torch._dynamo.disable(self.assertEqual)(output._values, output_dense) + output_dense.sum().backward() + torch._dynamo.disable(self.assertEqual)(query.grad, query_dense.grad) + + @onlyOn(["cuda", "xpu"]) + @unittest.skipIf( + not PLATFORM_SUPPORTS_FUSED_ATTENTION, + "Platform doesn't support flash or mem-efficient attention", + ) + @dtypes( + *( + [torch.float16, torch.bfloat16, torch.float32] + if SM80OrLater + else [torch.float16, torch.float32] + ) + ) + def test_sdpa_with_packed_in_proj(self, device, dtype): + # shape (B, *, D) + input_packed = random_nt_from_dims( + [5, None, 10], device=device, dtype=dtype, layout=torch.jagged + ) + + # Do input projection. + num_heads = 2 + # should be multiple of 4 for efficient kernels (e.g. flash / mem-efficient) + head_dim = 8 + qkv_linear = torch.nn.Linear(10, num_heads * head_dim * 3).to( + device=device, dtype=dtype + ) + + def in_proj(input_packed, qkv_linear=qkv_linear): + qkv_post_proj = qkv_linear(input_packed) + # these are non-contiguous to trigger _is_safe_to_get_storage_as_tensor() + q, k, v = qkv_post_proj.chunk(3, dim=-1) + q = q.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3) + k = k.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3) + v = v.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3) + return q, k, v + + q, k, v = in_proj(input_packed) + output = F.scaled_dot_product_attention(q, k, v, attn_mask=None) + + # compare to individually running unbound components through + for in_component, out_component in zip( + input_packed.unbind(), output.transpose(-2, -3).unbind() + ): + q, k, v = in_proj(in_component) + out = F.scaled_dot_product_attention(q, k, v).transpose(-2, -3) + + # Low Precision Math Reference + out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q, k, v)[ + 0 + ].transpose(-2, -3) + output_ref_atol, output_ref_rtol = get_tolerances( + out, out_lp_ref, fudge_factor=2 + ) + + self.assertEqual( + out, out_component, atol=output_ref_atol, rtol=output_ref_rtol + ) + + @skipIfTorchDynamo("SDPA test compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + # mha_varlen_fwd not supported on ROCm + @skipCUDAIfRocm + @onlyOn(["cuda", "xpu"]) + @dtypes( + *( + [torch.float16, torch.bfloat16, torch.float32] + if SM80OrLater + else [torch.float16, torch.float32] + ) + ) + def test_sdpa_backwards(self, device, dtype): + values = torch.randn(9, 3, 256, requires_grad=True, device=device, dtype=dtype) + offsets = torch.tensor([0, 1, 3, 5, 
9], device=device, dtype=torch.int64) + + @torch.compile + def f(values, offsets): + nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4) + nt = nt.transpose(-2, -3) + # purposefully graph break to trigger view replay for subclass view input + torch.tensor(1).item() + output = F.scaled_dot_product_attention(nt, nt, nt).transpose(-2, -3) + return convert_nt_to_jagged(output) + + output = f(values, offsets) + output.sum().backward() + self.assertEqual(values.grad, torch.ones_like(values)) + + @unittest.skipIf( + not PLATFORM_SUPPORTS_FUSED_ATTENTION, + "Platform doesn't support flash or mem-efficient attention", + ) + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm + @onlyOn(["cuda", "xpu"]) + @skipIfTorchDynamo() + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + def test_sdpa_autocast(self, device): + def fn_nt(values32, values16, offsets): + nt32 = convert_jagged_to_nested_tensor(values32, offsets, max_length=16) + nt16 = convert_jagged_to_nested_tensor(values16, offsets, max_length=16) + nt32 = nt32.transpose(1, 2) + nt16 = nt16.transpose(1, 2) + return F.scaled_dot_product_attention(nt32, nt16, nt32) + + def fn_dense(x32, x16): + x32 = x32.view(8, 16, 4, 16).transpose(1, 2) + x16 = x16.view(8, 16, 4, 16).transpose(1, 2) + return F.scaled_dot_product_attention(x32, x16, x32) + + values32 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float32) + values16 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float16) + offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32) + + x32 = values32.clone() + x16 = values16.clone() + + with torch.autocast(device_type="cuda", dtype=torch.float16): + out_dense_eager = fn_dense(x32, x16) + out_dense_compiled = torch.compile(fn_dense)(x32, x16) + out_nt_eager = fn_nt(values32, values16, offsets) + out_nt_compiled = torch.compile(fn_nt)(values32, values16, offsets) + + self.assertEqual(out_dense_eager, out_dense_compiled) + self.assertEqual( + out_dense_eager.transpose(1, 2), + out_nt_eager.values().transpose(0, 1).view(8, 16, 4, 16), + ) + self.assertEqual( + out_dense_eager.transpose(1, 2), + out_nt_compiled.values().transpose(0, 1).view(8, 16, 4, 16), ) - k_d2_t = k_d2.transpose(1, 2) - v_d2 = ( - value(x_d2) - .view(batch_size, -1, n_heads, head_dims) - .detach() - .requires_grad_(True) + + def get_values(): + return tuple( + x.detach().clone().requires_grad_(True) for x in (values32, values16) + ) + + v32_dense_eager, v16_dense_eager = get_values() + v32_dense_compile, v16_dense_compile = get_values() + v32_nt_eager, v16_nt_eager = get_values() + v32_nt_compile, v16_nt_compile = get_values() + + with torch.autocast(device_type="cuda", dtype=torch.float16): + loss_dense_eager = fn_dense(v32_dense_eager, v16_dense_eager).sum() + loss_dense_compile = torch.compile(fn_dense)( + v32_dense_compile, v16_dense_compile + ).sum() + loss_nt_eager = fn_nt(v32_nt_eager, v16_nt_eager, offsets).values().sum() + loss_nt_compile = ( + torch.compile(fn_nt)(v32_nt_compile, v16_nt_compile, offsets) + .values() + .sum() + ) + + loss_dense_eager.backward() + loss_dense_compile.backward() + loss_nt_eager.backward() + loss_nt_compile.backward() + + self.assertEqual(v32_dense_eager.grad, v32_dense_compile.grad) + self.assertEqual(v32_dense_eager.grad, v32_nt_eager.grad, atol=1e-4, rtol=1e-4) + self.assertEqual( + v32_dense_eager.grad, v32_nt_compile.grad, atol=1e-4, rtol=1e-4 ) - v_d2_t = v_d2.transpose(1, 2) - q_nt = ( - query(x_nt) - .view(*x_nt.size()[0:2], 
n_heads, head_dims) - .detach() - .requires_grad_(True) + self.assertEqual(v16_dense_eager.grad, v16_dense_compile.grad) + self.assertEqual(v16_dense_eager.grad, v16_nt_eager.grad, atol=1e-5, rtol=5e-3) + self.assertEqual( + v16_dense_eager.grad, v16_nt_compile.grad, atol=1e-5, rtol=5e-3 ) - q_nt_t = q_nt.transpose(1, 2) - k_nt = ( - key(x_nt) - .view(*x_nt.size()[0:2], n_heads, head_dims) - .detach() - .requires_grad_(True) + + @unittest.skipIf( + not PLATFORM_SUPPORTS_FUSED_ATTENTION, + "Platform doesn't support flash or mem-efficient attention", + ) + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm + @onlyOn(["cuda", "xpu"]) + @skipIfTorchDynamo() + def test_sdpa_flop_counter(self, device): + from torch.utils.flop_counter import FlopCounterMode + + def get_flops(nt): + flop_counter = FlopCounterMode(display=False) + with flop_counter: + ret = torch.nn.functional.scaled_dot_product_attention(nt, nt, nt) + ret.values().sum().backward() + return flop_counter.get_total_flops() + + values = torch.randn( + (8 * 16, 4, 16), requires_grad=True, device=device, dtype=torch.float16 ) - k_nt_t = k_nt.transpose(1, 2) - v_nt = ( - value(x_nt) - .view(*x_nt.size()[0:2], n_heads, head_dims) - .detach() - .requires_grad_(True) + offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32) + nt = convert_jagged_to_nested_tensor(values, offsets, max_length=16).transpose( + 1, 2 ) - v_nt_t = v_nt.transpose(1, 2) - attn_d2 = torch.nn.functional.scaled_dot_product_attention( - q_d2_t, k_d2_t, v_d2_t + values_meta = torch.randn( + (8 * 16, 4, 16), requires_grad=True, device="meta", dtype=torch.float16 + ) + offsets_meta = torch.arange(0, 8 * 16 + 1, 16, device="meta", dtype=torch.int32) + nt_meta = convert_jagged_to_nested_tensor( + values_meta, offsets_meta, max_length=16 ).transpose(1, 2) - d1_grads = torch.autograd.grad(attn_d1.sum(), (q_d1, k_d1, v_d1)) - d2_grads = torch.autograd.grad(attn_d2.sum(), (q_d2, k_d2, v_d2)) - # Simple case 3: batch_size = 1, seq_len = 1 - q_3 = torch.randn(1, 8, 16, dtype=dtype, device=device) - q_nt_3 = torch.nested.as_nested_tensor([q_3], layout=torch.jagged) - q_nt_3 = q_nt_3.transpose(1, 2) - attn_out = torch.nn.functional.scaled_dot_product_attention( - q_nt_3, q_nt_3, q_nt_3 + self.assertEqual(get_flops(nt), get_flops(nt_meta)) + + @skipIfTorchDynamo() + def test_nested_tensor_activation_checkpoint(self, device): + values = torch.randn( + 9, 3, 256, requires_grad=True, device=device, dtype=torch.float32 ) - self.assertEqual(attn_out.shape, q_nt_3.shape) + lengths = torch.tensor([1, 2, 3, 3], device=device, dtype=torch.int64) + offsets = F.pad(lengths, pad=(1, 0)).cumsum(dim=0) - def check_forward_backward(): - attn_nt = torch.nn.functional.scaled_dot_product_attention( - q_nt_t, k_nt_t, v_nt_t - ).transpose(1, 2) + def fn(values, offsets): + nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4) + return convert_nt_to_jagged(nt).sum() - attn_nts = attn_nt.unbind() - self.assertEqual( - attn_d1, - attn_nts[0].unsqueeze(0), - atol=output_ref_atol, - rtol=output_ref_rtol, + checkpoint(fn, values, offsets, use_reentrant=False).backward() + self.assertIsNotNone(values.grad) + + context_fn = partial( + create_selective_checkpoint_contexts, [torch.ops.aten.cumsum.default] + ) + + values.grad = None + + def fn(values, lengths): + offsets = F.pad(lengths, pad=(1, 0)).cumsum(dim=0) + nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4) + return convert_nt_to_jagged(nt).sum() + + checkpoint( + fn, values, 
lengths, use_reentrant=False, context_fn=context_fn + ).backward() + self.assertIsNotNone(values.grad) + + # Internally-defined NT use cases are lifted to here for maximum test realism. + # TODO: Remove these when ViewNestedFromBuffer, etc. are deprecated. + @skipCUDAIfRocm # not needed + @skipIfTorchDynamo("compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @parametrize("use_legacy_api", [True, False]) + @skipCPUIf(True, "SPDA Math NT fallback causes failure: see issue #133644") + @unittest.skipIf( + "RelWithAssert" in torch.__config__.show(), + "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context", + ) + def test_dummy_mha_with_nt(self, device, use_legacy_api): + bs = 3 + d1 = 2 + d2 = 4 + d3 = 16 + n_heads = 2 + d_head = d3 // n_heads + max_length_1 = 10 + max_length_2 = 20 + torch.manual_seed(0) + + class mha(torch.nn.Module): + def __init__(self, use_legacy_api) -> None: + super().__init__() + torch.manual_seed(0) + self.linear = torch.nn.Linear(d2, d3, device=device) + self.use_legacy_api = use_legacy_api + + def forward(self, query, value, offsets): + value = self.linear(value) + if self.use_legacy_api: + key = convert_jagged_to_nested_tensor_legacy( + value, offsets, max_length_1 + ) + value = convert_jagged_to_nested_tensor_legacy( + value, offsets, max_length_2 + ) + query = convert_dense_to_nested_tensor_legacy(query) + else: + key = convert_jagged_to_nested_tensor(value, offsets, max_length_1) + value = convert_jagged_to_nested_tensor( + value, offsets, max_length_2 + ) + query = convert_dense_to_nested_tensor(query) + q = query.view(bs, -1, n_heads, d_head).transpose(1, 2) + k = key.view(bs, -1, n_heads, d_head).transpose(1, 2) + v = value.view(bs, -1, n_heads, d_head).transpose(1, 2) + + with torch.nn.attention.sdpa_kernel( + [ + torch.nn.attention.SDPBackend.FLASH_ATTENTION, + torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, + ] + ): + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + ) + attn_output = attn_output.transpose(1, 2) + if self.use_legacy_api: + attn_output = convert_nt_to_jagged_legacy(attn_output) + else: + attn_output = convert_nt_to_jagged(attn_output) + return attn_output, key._max_seqlen, value._max_seqlen + + query = torch.rand(bs, d1, d3, device=device) + value = torch.rand(30, d2, requires_grad=True, device=device) + # total_length must > than max_length otherwise flash_attn backward will fail + offsets = torch.tensor([0, 2, 3, 30], device=device) + + m = mha(use_legacy_api) + symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(m) + m = torch.compile(symbolic_traced) + attn_output, cached_key_max_seqlen, cached_value_max_seqlen = m( + query, value, offsets + ) + loss = attn_output.sum() + # Check that NT can be fx traced and torch.compile, and backward works + loss.backward() + + # Check that value.requires_grad is not lost after tracing and compiling + value_grad = value.grad # save for comparison later + self.assertIsNotNone(value_grad) + # check that max_seqlen is cached properly + self.assertEqual(cached_key_max_seqlen, max_length_1) + self.assertEqual(cached_value_max_seqlen, max_length_2) + + # check if the output is numerically equivalent with the eager mode + m_eager = mha(use_legacy_api) + + value.grad = None + attn_output_eager, _, _ = m_eager(query, value, offsets) + 
attn_output_eager.sum().backward() + self.assertTrue(torch.allclose(attn_output_eager, attn_output)) + self.assertTrue(torch.allclose(value_grad, value.grad)) + + # Helper function to generate random query, key, value NJTs in (B, n_heads, *, D) format. + # If noncontig_with_holes is True, the results will be non-contiguous with holes (i.e. have + # both offsets and lengths specified). + def _rand_qkv(self, device, dtype, noncontig_with_holes=False, q_and_kv_match=True): + batch_size = 8 + n_heads = 8 + D = 16 + + def _rand_nt(noncontig_with_holes=noncontig_with_holes): + sentence_lengths = [random.randint(2, 1023) for _ in range(batch_size - 1)] + total = sum(sentence_lengths) + + # shape (B, *, D_total) where D_total = n_heads * D + nt = torch.nested.nested_tensor( + [ + torch.randn(l, n_heads * D, device=device, dtype=dtype) + for l in sentence_lengths + ], + layout=torch.jagged, ) - self.assertEqual( - attn_d2, - attn_nts[1].unsqueeze(0), - atol=output_ref_atol, - rtol=output_ref_rtol, + + if noncontig_with_holes: + nt = torch.nested.nested_tensor_from_jagged( + nt._values, + nt._offsets, + # -1 to introduce holes + lengths=nt._offsets.diff() - 1, + jagged_dim=nt._ragged_idx, + min_seqlen=nt._min_seqlen, + max_seqlen=nt._max_seqlen, + ) + + return nt + + query = _rand_nt() + if q_and_kv_match: + key = torch.randn_like(query) + value = torch.randn_like(query) + else: + key = _rand_nt() + value = torch.randn_like(key) + + # shape (B, *, D_total) -> (B, n_heads, *, D) + query = ( + query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() + ) + key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() + value = ( + value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() + ) + + return query, key, value + + @dtypes(torch.float32) + def test_apply_(self, device, dtype): + nt = random_nt_from_dims( + [5, None, 10], + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=True, + ) + + def f(x): + return x * 2 + + if device != "cpu": + with self.assertRaisesRegex( + TypeError, "apply_ is only implemented on CPU tensors" + ): + nt.apply_(f) + return + + before = nt._values.detach().clone() + + nt.apply_(f) + expected = f(before) + self.assertEqual(expected, nt._values) + # apply_ should swap values in-place without appending to autograd graph + self.assertIsNone(nt.grad) + self.assertIsNone(nt._values.grad_fn) + + @onlyOn(["cuda", "xpu"]) + @dtypes(torch.float64, torch.float32, torch.half) + @parametrize( + "contiguity", + ["noncontig_transposed", "noncontig_with_holes"], + name_fn=lambda c: c, + ) + def test_noncontiguous_to(self, device, dtype, contiguity): + # Dense tensors preserve non-contiguity through to() calls (i.e. strides are + # preserved). Test for the analogous behavior for NJTs: + # 1. non-contiguous transposed + # 2. 
non-contiguous with holes + if contiguity == "noncontig_transposed": + nt = random_nt_from_dims( + [3, None, 5, 2], + device=device, + dtype=dtype, + layout=torch.jagged, + ).transpose(-3, -2) + elif contiguity == "noncontig_with_holes": + nt = torch.nested.nested_tensor_from_jagged( + values=torch.randn(10, 3, device=device, dtype=dtype), + offsets=torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int64), + # these lengths specify holes + lengths=torch.tensor([1, 2, 3], device=device, dtype=torch.int64), ) + else: + raise ValueError("invalid contiguity specified for test_noncontiguous_to()") + + # test dtype conversion + dtype_conversions = { + torch.float32: torch.half, + torch.float64: torch.float32, + torch.half: torch.float32, + } + other_dtype = dtype_conversions[dtype] + nt2 = nt.to(dtype=other_dtype) + self.assertEqual(nt2.dtype, other_dtype) + self.assertEqual(nt.is_contiguous(), nt2.is_contiguous()) + self.assertEqual(nt._values.is_contiguous(), nt2._values.is_contiguous()) + self.assertEqual(nt.shape, nt2.shape) + # expect no change for offsets / lengths + self.assertEqual(nt._offsets, nt2._offsets) + self.assertEqual(nt._lengths, nt2._lengths) + + # test device conversion + other_device = torch.device("cpu") + nt3 = nt.to(device=other_device) + self.assertEqual(nt3.device, other_device) + self.assertEqual(nt.is_contiguous(), nt3.is_contiguous()) + self.assertEqual(nt._values.is_contiguous(), nt3._values.is_contiguous()) + self.assertEqual(nt.shape, nt3.shape) + # expect device change for offsets / lengths + self.assertEqual(nt3._offsets.device, other_device) + if nt._lengths is not None: + self.assertEqual(nt3._lengths.device, other_device) + + @dtypes(torch.float32) + def test_autograd_function_with_None_grad(self, device, dtype): + class MyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, inp): + ctx.save_for_backward(inp) + out1 = inp + 1 + out2 = inp * 2 + return out1, out2 + + @staticmethod + def backward(ctx, grad_out1, grad_out2): + (inp,) = ctx.saved_tensors + return grad_out1 + grad_out2 + + f = MyFunction.apply + nt = random_nt_from_dims( + [5, None, 10], + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=True, + ) + + # Only use one of the autograd.Function outputs downstream so that the grad + # for the other output is None. We're testing that the engine can allocate + # correctly-shaped (NJT) zeros for the grad of the other output in this case. 
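+        # (Sketch added for exposition, not taken from the upstream test: the
+        # dense-tensor analogue of the property described above. With the default
+        # materialize_grads behavior, backward() receives grad_out1 = ones plus a
+        # materialized zeros gradient for the unused out2, so the input gradient
+        # comes out as all ones. The 3x4 shape and dense_* names are arbitrary
+        # choices for this sketch; the NJT case itself follows below.)
+        dense_inp = torch.randn(3, 4, device=device, dtype=dtype, requires_grad=True)
+        dense_out1, _ = f(dense_inp)
+        dense_out1.backward(torch.ones_like(dense_out1))
+        self.assertEqual(dense_inp.grad, torch.ones_like(dense_inp))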
+ (out1, _) = f(nt) + out1.backward(torch.ones_like(out1)) + + @dtypes(torch.float64, torch.float32, torch.half) + def test_jagged_padded_dense_conversion_kernels(self, device, dtype): + values = torch.randn(10, 5, device=device, dtype=dtype) + offsets = torch.tensor([0, 1, 3, 8, 10], device=device, dtype=torch.int64) + max_length = offsets.diff().max().item() + padding_value = 1.3 + + # convert jagged -> padded dense + padded = torch.ops.aten._jagged_to_padded_dense_forward( + values, [offsets], [max_length], padding_value + ) + + batch_size = offsets.shape[0] - 1 + expected_padded_shape = (batch_size, max_length, values.shape[-1]) + self.assertEqual(padded.shape, expected_padded_shape) - nt_grads = torch.autograd.grad(attn_nt.values().sum(), (q_nt, k_nt, v_nt)) - for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip( - nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols + # convert padded dense -> jagged + total_L = values.shape[0] + output_jagged = torch.ops.aten._padded_dense_to_jagged_forward( + padded, [offsets], total_L + ) + + # should be equivalent to the original values + self.assertEqual(values, output_jagged) + + # success case: truncate to max length as needed + trunc_max_length = max_length - 1 + trunc_padded = torch.ops.aten._jagged_to_padded_dense_forward( + values, [offsets], [trunc_max_length], padding_value + ) + self.assertEqual(padded[:, :trunc_max_length, :], trunc_padded) + + # specific to CPU impls + if device == "cpu": + # error case: multiple offsets on cpu since CPU kernels don't support more now + with self.assertRaisesRegex( + RuntimeError, "only a single jagged dim is supported" ): - unbound_nt_grads = nt_grad.unbind() - self.assertEqual( - d1_grad, - unbound_nt_grads[0].unsqueeze(0), - atol=grad_atol, - rtol=grad_rtol, + torch.ops.aten._jagged_to_padded_dense_forward( + values, [offsets, offsets], [max_length, max_length], padding_value ) - self.assertEqual( - d2_grad, - unbound_nt_grads[1].unsqueeze(0), - atol=grad_atol, - rtol=grad_rtol, + + with self.assertRaisesRegex( + RuntimeError, "only a single jagged dim is supported" + ): + torch.ops.aten._padded_dense_to_jagged_forward( + padded, [offsets, offsets], total_L ) - # Default - check_forward_backward() + # error case: > 1D offsets + offsets2d = offsets.unsqueeze(-1) + with self.assertRaisesRegex(RuntimeError, "expected 1D offsets"): + torch.ops.aten._jagged_to_padded_dense_forward( + values, [offsets2d], [max_length], padding_value + ) - # Test dispatcher works by calling only mem-effn and math (as they are safe for all devices) - with torch.backends.xpu.sdp_kernel( - enable_flash=False, enable_mem_efficient=True, enable_math=True - ): - check_forward_backward() + with self.assertRaisesRegex(RuntimeError, "expected 1D offsets"): + torch.ops.aten._padded_dense_to_jagged_forward( + padded, [offsets2d], total_L + ) + + # error case: final offset != total_L + offsets_wrong = offsets.detach().clone() + offsets_wrong[-1] = total_L + 1 + with self.assertRaisesRegex( + RuntimeError, "final offset should match total_L value" + ): + torch.ops.aten._padded_dense_to_jagged_forward( + padded, [offsets_wrong], total_L + ) + + # error case: 1D padded input + padded_wrong = padded.flatten().detach().clone() + with self.assertRaisesRegex(RuntimeError, "expected padded dim >= 2"): + torch.ops.aten._padded_dense_to_jagged_forward( + padded_wrong, [offsets], total_L + ) + + # error case: batch item has length > max length + # max_length is 5 above; 7 here + offsets_wrong = torch.tensor( + [0, 1, 8, 9, 10], 
device=device, dtype=torch.int64 + ) + with self.assertRaisesRegex(RuntimeError, "found batch item of length"): + torch.ops.aten._padded_dense_to_jagged_forward( + padded, [offsets_wrong], total_L + ) + + @dtypes(torch.float32) + @skipIfTorchDynamo("Test compiles internally") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm + def test_compile_preserves_metadata_cache(self, device, dtype): + # shape (B, *, D) + nt = random_nt_from_dims( + [4, None, 3, 16], + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=True, + ) + + # expect min / max seqlen to be stored here + cache = dict(nt._metadata_cache) + + @torch.compile + def f(nt): + q = nt.transpose(-3, -2) + output = F.scaled_dot_product_attention(q, q, q).transpose(-3, -2) + return output + + output = f(nt) + output.backward(torch.ones_like(output)) + self.assertEqual(output._metadata_cache, cache) + + @dtypes(torch.float32) + @skipIfTorchDynamo("Test compiles internally") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm + def test_compile_with_dynamic_max_seq_len(self, device, dtype): + # shape (B, *, D) + # max seq len: 18 + nt = torch.nested.nested_tensor( + [ + torch.randn(2, 5), + torch.randn(3, 5), + torch.randn(18, 5), + ], + layout=torch.jagged, + ) + + # max seq len: 19 + nt2 = torch.nested.nested_tensor( + [ + torch.randn(2, 5), + torch.randn(3, 5), + torch.randn(19, 5), + ], + layout=torch.jagged, + ) + + def f(nt): + # TODO: Replace with public API when we can use @properties + return torch.ones_like(nt) * nt._get_max_seqlen() + + for dynamic in [False, True, None]: + self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) + + @dtypes(torch.float32) + @skipIfTorchDynamo("Test compiles internally") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm + def test_compile_with_dynamic_min_seq_len(self, device, dtype): + # shape (B, *, D) + # min seq len: 7 + nt = torch.nested.nested_tensor( + [ + torch.randn(7, 5), + torch.randn(8, 5), + torch.randn(9, 5), + ], + layout=torch.jagged, + ) + + # min seq len: 8 + nt2 = torch.nested.nested_tensor( + [ + torch.randn(8, 5), + torch.randn(9, 5), + torch.randn(10, 5), + ], + layout=torch.jagged, + ) + + def f(nt): + # TODO: Replace with public API when we can use @properties + return torch.ones_like(nt) * nt._get_min_seqlen() - # Test math fallback - with torch.backends.xpu.sdp_kernel( - enable_flash=False, enable_mem_efficient=False, enable_math=True - ): - # Math fallback doesn't work with bfloat16 on xpu because - # "group_gemm_dispatch" not implemented for 'BFloat16' - if not (str(device).startswith("xpu") and dtype == torch.bfloat16): - check_forward_backward() + for dynamic in [False, True, None]: + self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) + @dtypes(torch.float32) + @skipIfTorchDynamo("Test compiles internally") @unittest.skipIf( - not PLATFORM_SUPPORTS_FUSED_ATTENTION, - 
"Platform doesn't support flash or mem-efficient attention", + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" ) - @skipIfTorchDynamo() @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") - def _test_sdpa_autocast(self, device): - def fn_nt(values32, values16, offsets): - nt32 = convert_jagged_to_nested_tensor(values32, offsets, max_length=16) - nt16 = convert_jagged_to_nested_tensor(values16, offsets, max_length=16) - nt32 = nt32.transpose(1, 2) - nt16 = nt16.transpose(1, 2) - return F.scaled_dot_product_attention(nt32, nt16, nt32) + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm + def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype): + # shape (B, *, D) + # max seq len: 18 + nt = torch.nested.nested_tensor( + [ + torch.randn(2, 5), + torch.randn(3, 5), + torch.randn(18, 5), + ], + layout=torch.jagged, + ) - def fn_dense(x32, x16): - x32 = x32.view(8, 16, 4, 16).transpose(1, 2) - x16 = x16.view(8, 16, 4, 16).transpose(1, 2) - return F.scaled_dot_product_attention(x32, x16, x32) + # max seq len: 19 + nt2 = torch.nested.nested_tensor( + [ + torch.randn(2, 5), + torch.randn(3, 5), + torch.randn(19, 5), + ], + layout=torch.jagged, + ) - values32 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float32) - values16 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float16) - offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32) + def f(nt): + nt2 = nt.sin() + 1 + # TODO: Replace with public API when we can use @properties + return torch.ones_like(nt2) * nt2._get_max_seqlen() - x32 = values32.clone() - x16 = values16.clone() + ref = f(nt) + output = torch.compile(f, fullgraph=True, dynamic=False)(nt) + self.assertEqual(ref, output) - with torch.autocast(device_type="xpu", dtype=torch.float16): - out_dense_eager = fn_dense(x32, x16) - out_dense_compiled = torch.compile(fn_dense)(x32, x16) - out_nt_eager = fn_nt(values32, values16, offsets) - out_nt_compiled = torch.compile(fn_nt)(values32, values16, offsets) + for dynamic in [False, True, None]: + self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic)) - self.assertEqual(out_dense_eager, out_dense_compiled) - self.assertEqual( - out_dense_eager.transpose(1, 2), - out_nt_eager.values().transpose(0, 1).view(8, 16, 4, 16), + def test_dropout_inference_mode(self, device): + seq_len = 32 + embed_dim = 128 + + nt = torch.nested.nested_tensor( + [ + torch.randn(11, seq_len, embed_dim, device=device), + torch.randn(11, seq_len, embed_dim, device=device), + ], + layout=torch.jagged, + device=device, ) - self.assertEqual( - out_dense_eager.transpose(1, 2), - out_nt_compiled.values().transpose(0, 1).view(8, 16, 4, 16), + + with torch.inference_mode(): + torch.nn.functional.dropout(nt, p=0.05) + + @dtypes(torch.float32, torch.double, torch.half) + def test_unbind_backward(self, device, dtype): + nt = torch.nested.nested_tensor( + [ + torch.randn(2, 4, device=device), + torch.randn(5, 4, device=device), + torch.randn(3, 4, device=device), + ], + layout=torch.jagged, + requires_grad=True, ) - def get_values(): - return tuple( - x.detach().clone().requires_grad_(True) for x in (values32, values16) - ) + a, b, c = nt.unbind() + b.sum().backward() - v32_dense_eager, v16_dense_eager = get_values() - v32_dense_compile, v16_dense_compile = get_values() - v32_nt_eager, v16_nt_eager = get_values() - v32_nt_compile, v16_nt_compile = get_values() + @torch._dynamo.disable + def check(nt): + expected_grad 
= torch.zeros_like(nt) + expected_grad.unbind()[1].add_(1.0) + self.assertEqual(nt.grad, expected_grad) - with torch.autocast(device_type="xpu", dtype=torch.float16): - loss_dense_eager = fn_dense(v32_dense_eager, v16_dense_eager).sum() - loss_dense_compile = torch.compile(fn_dense)( - v32_dense_compile, v16_dense_compile - ).sum() - loss_nt_eager = fn_nt(v32_nt_eager, v16_nt_eager, offsets).values().sum() - loss_nt_compile = ( - torch.compile(fn_nt)(v32_nt_compile, v16_nt_compile, offsets) - .values() - .sum() - ) + check(nt) - loss_dense_eager.backward() - loss_dense_compile.backward() - loss_nt_eager.backward() - loss_nt_compile.backward() + @dtypes(torch.float32, torch.double, torch.half, torch.bool) + @parametrize("nt_dim", [2, 3, 4]) + @parametrize("requires_grad", [False, True]) + def test_to_padded_tensor(self, device, dtype, nt_dim, requires_grad): + if dtype is torch.bool and requires_grad: + # grads not supported for bool + return - self.assertEqual(v32_dense_eager.grad, v32_dense_compile.grad) - self.assertEqual(v32_dense_eager.grad, v32_nt_eager.grad, atol=1e-4, rtol=1e-4) - self.assertEqual( - v32_dense_eager.grad, v32_nt_compile.grad, atol=1e-4, rtol=1e-4 - ) + if nt_dim == 2: + post_seq_len_shape = () + elif nt_dim == 3: + post_seq_len_shape = (10,) + elif nt_dim == 4: + post_seq_len_shape = (9, 10) - self.assertEqual(v16_dense_eager.grad, v16_dense_compile.grad) - self.assertEqual(v16_dense_eager.grad, v16_nt_eager.grad, atol=1e-5, rtol=5e-3) - self.assertEqual( - v16_dense_eager.grad, v16_nt_compile.grad, atol=1e-5, rtol=5e-3 + nt = torch.nested.nested_tensor( + [ + ( + torch.randint( + 2, (n, *post_seq_len_shape), device=device, dtype=dtype + ) + if dtype is torch.bool + else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype) + ) + for n in range(2, 9) + ], + layout=torch.jagged, + requires_grad=requires_grad, ) + PADDING_VAL = 4.2 + expected_padded = nt._values.new_full((7, 8, *post_seq_len_shape), PADDING_VAL) + for i, component in enumerate(nt.unbind()): + expected_padded[i, : component.shape[0]].copy_(component) + + padded = nt.to_padded_tensor(PADDING_VAL) + self.assertEqual(expected_padded, padded) + + # convert padded dense -> NJT + from torch.nested._internal.nested_tensor import nested_from_padded + + nt2 = nested_from_padded(padded, nt.offsets()) + self.assertEqual(nt, nt2) + + if requires_grad and dtype is not torch.bool: + # ensure gradients flow through conversions + nt2.backward(torch.ones_like(nt2)) + self.assertEqual(nt.grad, torch.ones_like(nt)) + # blows up due to test parametrization otherwise @torch._dynamo.utils.disable_cache_limit() @skipIfTorchDynamo("SDPA test compiles internally") @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm @dtypes(torch.float32, torch.double, torch.half) @parametrize("nt_dim", [2, 3, 4]) @parametrize("requires_grad", [False, True]) - def _test_to_padded_tensor_compile(self, device, dtype, nt_dim, requires_grad): + def test_to_padded_tensor_compile(self, device, dtype, nt_dim, requires_grad): if dtype is torch.bool and requires_grad: # grads not supported for bool return @@ -744,9 +7993,13 @@ def _test_to_padded_tensor_compile(self, device, dtype, nt_dim, requires_grad): nt = torch.nested.nested_tensor( [ - torch.randint(2, (n, *post_seq_len_shape), device=device, dtype=dtype) - if dtype is torch.bool - else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype) + ( + torch.randint( + 2, (n, 
*post_seq_len_shape), device=device, dtype=dtype + ) + if dtype is torch.bool + else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype) + ) for n in range(2, 9) ], layout=torch.jagged, @@ -800,7 +8053,7 @@ def _g(nt): ) # NB: Fusion isn't supported on CPU. - self.assertEqual("xpu" in device, not fallback_op_calls_present) + self.assertEqual("cuda" in device, not fallback_op_calls_present) for i in range(len(generated_code)): # Examine buffer construction lines in the generated code to determine @@ -809,7 +8062,7 @@ def _g(nt): buffer_constructions = [ line.strip() for line in generated_code[i].split("\n") - if "empty_strided_xpu(" in line + if "empty_strided_cuda(" in line ] buffer_dims = [ @@ -818,40 +8071,1066 @@ def _g(nt): for t in buffer_constructions ] - if "xpu" in device: + if "cuda" in device: self.assertFalse(any(d == 3 for d in buffer_dims)) - TestNestedTensor.test_to = _test_to - TestNestedTensor.test_copy_ = _test_copy_ - TestNestedTensorDeviceType.test_device_checks = _test_device_checks - TestNestedTensorDeviceType.test_empty_like = _test_empty_like - TestNestedTensorSubclass.test_linear_backward_memory_usage = ( - _test_linear_backward_memory_usage + @dtypes(torch.float32) + @skipIfTorchDynamo("Test compiles internally") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @skipCUDAIfRocm + def test_compile_padded_dense_conversion_preserves_metadata_cache( + self, device, dtype + ): + # shape (B, *, D) + nt = random_nt_from_dims( + [4, None, 3, 16], + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=True, + ) + + # expect min / max seqlen to be stored here + cache = dict(nt._metadata_cache) + + @torch.compile + def g(nt): + padded = nt.to_padded_tensor(0.3) + intermediate = padded.sin() + 1 + + from torch.nested._internal.nested_tensor import nested_from_padded + + return nested_from_padded( + intermediate, + nt.offsets(), + min_seqlen=nt._min_seqlen, + max_seqlen=nt._max_seqlen, + sum_S=nt.values().shape[0], + ) + + output = g(nt) + output.backward(torch.ones_like(output)) + self.assertEqual(output._metadata_cache, cache) + + # See https://github.com/pytorch/pytorch/issues/128649 + @dtypes(torch.float32) + def test_composite_op_in_inference_mode(self, device, dtype): + # expect view + nt = random_nt_from_dims( + [4, None, 48], + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=True, + ) + + with torch.inference_mode(): + output = nt.reshape([4, -1, 3, 16]) + self.assertEqual(output.shape, (4, nt.shape[1], 3, 16)) + self.assertTrue(output._is_view()) + + # expect copy + nt = random_nt_from_dims( + [4, None, 3, 16], + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=True, + ).transpose(-1, -2) + + with torch.inference_mode(): + output = nt.reshape([4, -1, 48]) + self.assertEqual(output.shape, (4, nt.shape[1], 48)) + self.assertFalse(output._is_view()) + + @dtypes(torch.float32) + def test_composite_op_with_custom_mode(self, device, dtype): + from torch.utils._python_dispatch import TorchDispatchMode + + # simple passthrough TorchDispatchMode + class CustomDispatchMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=..., kwargs=None): + return func(*args, **kwargs) + + nt = random_nt_from_dims( + [4, None, 2, 3], + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=True, + 
) + with CustomDispatchMode(): + res = nt.reshape(4, -1, 6) + + self.assertEqual(res.shape, (4, nt.shape[1], 6)) + + @skipIfTorchDynamo("compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + @dtypes(torch.float32) + @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_broadcast_shapes_on_in_graph_constructed_njt(self, device, dtype): + # Tests that a guard isn't wrongly installed on a freshly-created nested int when + # broadcast_shapes() is used on NJT shapes. + # See https://github.com/pytorch/pytorch/issues/145874 for more context. + nt = torch.nested.nested_tensor( + [ + torch.randn(2), + torch.randn(3), + torch.randn(4), + ], + layout=torch.jagged, + device=device, + dtype=dtype, + ) + + values = nt._values.detach().clone() + offsets = nt._offsets.detach().clone() + + @torch.compile(fullgraph=True) + def f(values, offsets): + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + # NB: torch.where() utilizes broadcast_shapes() underneath + return torch.where(nt > 0.0, torch.ones_like(nt), torch.zeros_like(nt)) + + output = f(values, offsets) + self.assertTrue(output.is_nested) + self.assertEqual(nt.shape[:-1], output.shape[:-1]) + for nt_component, output_component in zip(nt.unbind(), output.unbind()): + self.assertEqual(nt_component.shape, output_component.shape) + + +# The following lists specify skips and xfails for particular SampleInputs. Note that +# these are attempted to be matched from top to bottom and only one at most will +# be matched, so order matters! The guiding general principle here should be one +# xfail / skip per bug if at all possible :) +FORWARD_SKIPS_AND_XFAILS = [ + # not implemented + XFailRule( + error_type=NotImplementedError, + op_match_fn=lambda device, op: op.full_name + in { + # unary + # needs log_sigmoid_forward, which returns a tuple + "nn.functional.logsigmoid", + "nn.functional.prelu", + # needs rrelu_with_noise + "nn.functional.rrelu", + # binary + "__rsub__", + "complex", + "floor_divide", + "polar", + "rsub", + # reduction + "count_nonzero", + "linalg.vector_norm", + "nansum", + "std", + "std.unbiased", + "var", + "var.unbiased", + "hash_tensor", + }, + name="not_implemented", + ), + # expected: torch.where() support has some limitations + # 1. condition must be an NJT + # 2. 
no dense tensors of higher dim than the NJT + XFailRule( + error_type=ValueError, + error_msg="expected condition to be a jagged layout NestedTensor", + op_match_fn=lambda device, op: op.full_name == "where", + sample_match_fn=lambda device, sample: not sample.kwargs["condition"].is_nested, + ), + XFailRule( + error_type=ValueError, + error_msg="broadcasting nested tensors with dense tensors of equal or higher dim", + op_match_fn=lambda device, op: op.full_name == "where", + sample_match_fn=lambda device, sample: ( + ( + not sample.input.is_nested + and sample.input.dim() >= sample.kwargs["condition"].dim() + ) + or ( + not sample.kwargs["other"].is_nested + and sample.kwargs["other"].dim() >= sample.kwargs["condition"].dim() + ) + ), + ), + # expected: masked ops don't support jagged layout + XFailRule( + error_type=ValueError, + error_msg="expects strided", + op_match_fn=lambda device, op: op.full_name + in { + "masked.amax", + "masked.amin", + "masked.argmax", + "masked.argmin", + "masked.logsumexp", + "masked.mean", + "masked.norm", + "masked.prod", + "masked.std", + "masked.sum", + "masked.var", + }, + name="no_masked_jagged_support", + ), + # Op doesn't support lengths being present + XFailRule( + error_type=ValueError, + error_msg="expected input to be a contiguous jagged layout NestedTensor", + op_match_fn=lambda device, op: (op.full_name == "nn.functional.linear"), + sample_match_fn=lambda device, sample: (sample.input._lengths is not None), + name="no_linear_noncontig_holes_support", + ), + # nanmean sometimes hits an unimplemented nansum() path and other times hits an + # unimplemented sum() path + XFailRule( + error_type=NotImplementedError, + op_match_fn=lambda device, op: (op.full_name == "nanmean"), + sample_match_fn=lambda device, sample: ( + not ( + "noncontig_holes" in sample.name + and "dim" in sample.kwargs + and ( + ( + isinstance(sample.kwargs["dim"], int) + and sample.kwargs["dim"] == sample.input._ragged_idx + ) + or ( + isinstance(sample.kwargs["dim"], (tuple, list)) + and sample.input._ragged_idx in sample.kwargs["dim"] + ) + ) + ) + ), + name="nansum_unimplemented", + ), + # expected: reducing across the ragged dimension is not supported for non-contiguous + # nested tensors with holes + XFailRule( + error_type=RuntimeError, + error_msg=( + "reducing across the ragged dimension is not supported for non-contiguous " + "nested tensors with holes" + ), + op_match_fn=lambda device, op: ( + # min.reduction_with_dim and max.reduction_with_dim aren't associated with + # ReductionOpInfo entries sadly even though they're reductions + isinstance(op, ReductionOpInfo) + or "reduction_with_dim" in op.full_name + ), + sample_match_fn=lambda device, sample: ( + "noncontig_holes" in sample.name + and "dim" in sample.kwargs + and ( + ( + isinstance(sample.kwargs["dim"], int) + and sample.kwargs["dim"] == sample.input._ragged_idx + ) + or ( + isinstance(sample.kwargs["dim"], (tuple, list)) + and sample.input._ragged_idx in sample.kwargs["dim"] + ) + ) + ), + name="ragged_dim_reduction_noncontig_holes", + ), + # expected: index_put() doesn't work on non-contiguous NJTs without ragged dimension indices + XFailRule( + error_type=RuntimeError, + error_msg="If ragged dimension is not part of indices, this only works on contiguous NJTs", + op_match_fn=lambda device, op: (op.full_name == "index_put"), + sample_match_fn=lambda device, sample: ( + not sample.input.is_contiguous() + and len(sample.kwargs["indices"]) - 1 < sample.input._ragged_idx + ), + 
name="index_put_noncontig_holes_no_ragged_dim_indices", + ), + # select() only supports dim=0 for non-contiguous with holes NJTs for now + XFailRule( + op_match_fn=lambda device, op: (op.full_name == "select"), + sample_match_fn=lambda device, sample: ( + sample.kwargs["dim"] != 0 and "noncontig_holes" in sample.name + ), + name="unsupported_select_on_non_batch_dim_with_noncontig_holes", + ), + # these don't work on non-contiguous NJTs yet + XFailRule( + error_type=ValueError, + error_msg="expected self to be a contiguous jagged layout NestedTensor", + op_match_fn=lambda device, op: ( + op.full_name + in { + "chunk", + "masked_select", + "narrow", + "split", + "split_with_sizes", + "squeeze", + } + ), + sample_match_fn=lambda device, sample: ( + sample.input._lengths is not None or sample.input._ragged_idx != 1 + ), + name="missing_noncontig_support", + ), + # these don't work on the ragged dim yet + XFailRule( + error_type=RuntimeError, + error_msg="not supported for NestedTensor on ragged dim", + op_match_fn=lambda device, op: ( + op.full_name + in { + "chunk", + "narrow", + "select", + "split", + } + ), + sample_match_fn=lambda device, sample: "ragged_dim" in sample.name, + name="ragged_dim_unsupported", + ), + XFailRule( + error_type=RuntimeError, + # error comes from usage of view() in the decomp + error_msg="does not support ragged_idx != 1 except when", + op_match_fn=lambda device, op: (op.full_name == "unflatten"), + sample_match_fn=lambda device, sample: "noncontig_transposed" in sample.name, + name="unflatten_ragged_dim_unsupported", + ), + # these don't work on the batch dim yet + XFailRule( + error_type=RuntimeError, + error_msg="not supported for NestedTensor on dim=0", + op_match_fn=lambda device, op: ( + op.full_name + in { + "narrow", + "split", + "split_with_sizes", + "unsqueeze", + } + ), + sample_match_fn=lambda device, sample: "batch_dim" in sample.name, + name="batch_dim_unsupported", + ), + XFailRule( + error_type=RuntimeError, + # error comes from usage of view() in the decomp + error_msg="cannot view shape", + op_match_fn=lambda device, op: (op.full_name == "unflatten"), + sample_match_fn=lambda device, sample: "batch_dim" in sample.name, + name="unflatten_batch_dim_unsupported", + ), + # expected: bmm / matmul sometimes use a to_padded_tensor() fallback which isn't + # supported for non-contig NJTs with holes + XFailRule( + error_type=RuntimeError, + error_msg="not supported for nested tensors with holes", + op_match_fn=lambda device, op: (op.full_name in {"bmm", "matmul"}), + sample_match_fn=lambda device, sample: ( + "noncontig_holes" in sample.name + # "other" is the name for the matmul arg and "mat2" is the name for the bmm arg + and sample.input.dim() + == sample.kwargs.get("other", sample.kwargs.get("mat2")).dim() + ), + name="mm_noncontig_holes", + ), + # some jiterator op failures due to unsupported jagged layout + XFailRule( + error_type=RuntimeError, + error_msg="unsupported tensor layout", + op_match_fn=lambda device, op: op.full_name + in { + "jiterator_binary", + "jiterator_binary_return_by_ref", + "jiterator_unary", + }, + name="no_jiterator_jagged_support", + ), + # Bug when broadcasting a binary op with non-contiguous with holes NJT + dense + # tensor with 1 in ragged dim. 
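+    # (Added clarifying example, not upstream wording: e.g. a noncontig-with-holes NJT
+    # of logical shape (B, j, D) combined with a dense tensor of shape (B, 1, D); the
+    # size-1 dense dim lines up with the ragged dim and currently trips the binary
+    # pointwise shape check.)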
+ XFailRule( + error_type=RuntimeError, + error_msg="cannot call binary pointwise function .* with inputs of shapes", + op_match_fn=lambda device, op: (isinstance(op, BinaryUfuncInfo)), + sample_match_fn=lambda device, sample: ( + "noncontig_holes" in sample.name + and "broadcasting 1 over ragged" in sample.name + ), + name="binary_noncontig_holes_broadcasting_1_over_ragged", + ), +] + +BACKWARD_SKIPS_AND_XFAILS = [ + # segfaults, so skip. It's trying to use the NST logic for NJT + SkipRule( + op_match_fn=lambda device, op: op.full_name == "split_with_sizes", + name="split_with_sizes_backward_segfault", + ), + *FORWARD_SKIPS_AND_XFAILS, + # Backwards is generally broken for non-contiguous NJTs with holes. Rather than + # determine the exceptions in detail, just skip for now. Fix is to ensure + # that summing over gradients during backwards after broadcasting takes into + # account holes / lengths. + SkipRule( + op_match_fn=lambda device, op: ( + isinstance(op, BinaryUfuncInfo) + or op.full_name in {"mean", "where", "unsqueeze"} + ), + sample_match_fn=lambda device, sample: ("noncontig_holes" in sample.name), + name="broken_noncontig_holes_backward", + ), + # mean(): need to examine backwards formula + XFailRule( + error_type=RuntimeError, + error_msg="SymIntArrayRef expected to contain only concrete integers", + op_match_fn=lambda device, op: (op.full_name in {"mean"}), + sample_match_fn=lambda device, sample: ( + "full reduction" not in sample.name + and "normal dim reduction" not in sample.name + ), + name="broken_mean_backward", + ), + # RuntimeError: expand(): cannot expand shape (3, 3, 1, j44) -> [3, 3, 7, j44] + # with noncontig transposed inputs to mean() + XFailRule( + error_type=RuntimeError, + error_msg="cannot expand shape", + op_match_fn=lambda device, op: (op.full_name == "mean"), + sample_match_fn=lambda device, sample: ( + "normal dim reduction" in sample.name + and "noncontig_transposed" in sample.name + ), + name="broken_mean_backward2", + ), + # unsqueeze() backward tries to call squeeze with noncontig transposed, + # but that's not supported + XFailRule( + error_type=ValueError, + error_msg="expected self to be a contiguous jagged layout NestedTensor", + op_match_fn=lambda device, op: (op.full_name == "unsqueeze"), + sample_match_fn=lambda device, sample: ( + "noncontig_transposed" in sample.name or "ragged_dim" in sample.name + ), + name="broken_unsqueeze_backward", + ), + # RuntimeError: view(): cannot view shape (3, j62, 1, 7, 3) as [3, j58, 7, 3] + # with unflatten() + XFailRule( + error_type=RuntimeError, + error_msg="cannot view shape", + op_match_fn=lambda device, op: (op.full_name in {"unflatten"}), + sample_match_fn=lambda device, sample: ("noncontig_holes" in sample.name), + name="broken_unflatten_backward", + ), + # sum() backward is not implemented for non-full reductions + XFailRule( + error_type=NotImplementedError, + error_msg="aten._nested_sum_backward.default", + op_match_fn=lambda device, op: (op.full_name == "sum"), + sample_match_fn=lambda device, sample: ("full reduction" not in sample.name), + name="broken_sum_backward", + ), + # squeeze(): invalid gradient shape; need to check formula + XFailRule( + error_type=RuntimeError, + error_msg="returned an invalid gradient at index 0", + op_match_fn=lambda device, op: (op.full_name == "squeeze"), + sample_match_fn=lambda device, sample: ( + sample.name == "5D_contig_with_seqlen_cache: normal_dim" + and sample.kwargs["dim"] == 3 + ), + name="broken_squeeze_backward", + ), + # sgn() / masked_select(): 
backwards formulas don't work at all + XFailRule( + error_type=RuntimeError, + error_msg="NestedTensor does not support directly calling torch.ops.aten.size", + op_match_fn=lambda device, op: (op.full_name in {"sgn", "masked_select"}), + name="broken_sgn_masked_select_backward", + ), + # select(): grad_output is an NJT for non-batch-dim operation + XFailRule( + error_type=ValueError, + error_msg="expected grad_output to be a tensor", + op_match_fn=lambda device, op: (op.full_name == "select"), + sample_match_fn=lambda device, sample: ("batch_dim" not in sample.name), + name="broken_select_backward", + ), + # prod(): completely broken in every way + XFailRule( + op_match_fn=lambda device, op: (op.full_name == "prod"), + name="broken_prod_backward", + ), + # pow() / float_power(): use where() underneath; broken for (NT, T) broadcasting cases + XFailRule( + error_type=ValueError, + error_msg="expected condition to be a jagged layout NestedTensor", + op_match_fn=lambda device, op: (op.full_name in {"pow", "float_power"}), + sample_match_fn=lambda device, sample: ("(NT, T)" in sample.name), + name="broken_pow_backward", + ), + # __rpow__() backward is also broken, but for the reverse (T, NT) broadcasting cases + XFailRule( + error_type=ValueError, + error_msg="expected condition to be a jagged layout NestedTensor", + op_match_fn=lambda device, op: (op.full_name == "__rpow__"), + sample_match_fn=lambda device, sample: ("(T, NT)" in sample.name), + name="broken_rpow_backward", + ), + # linear(): some formula problem when bias is used; seems to be platform-specific + # (fails locally but not in CI) + SkipRule( + # result2.use_count() <= 1 INTERNAL ASSERT FAILED + op_match_fn=lambda device, op: (op.full_name == "nn.functional.linear"), + sample_match_fn=lambda device, sample: ("with bias" in sample.name), + name="broken_linear_backward", + ), + # narrow(): unimplemented backward + XFailRule( + error_type=RuntimeError, + error_msg="derivative for aten::narrow is not implemented", + op_match_fn=lambda device, op: (op.full_name == "narrow"), + name="broken_narrow_backward", + ), + # min / max: need factory function support for ragged dim reductions + # where the output is dense but sizes still contain a nested int + XFailRule( + error_type=RuntimeError, + error_msg="SymIntArrayRef expected to contain only concrete integers", + op_match_fn=lambda device, op: ( + op.full_name in {"max.reduction_with_dim", "min.reduction_with_dim"} + ), + sample_match_fn=lambda device, sample: ("ragged dim" in sample.name), + name="broken_min_max_reduction_with_dim_backward_on_ragged_dim", + ), + # copysign(): formula is broken for (T, NT) broadcasting + XFailRule( + error_type=RuntimeError, + error_msg="SymIntArrayRef expected to contain only concrete integers", + op_match_fn=lambda device, op: (op.full_name == "copysign"), + sample_match_fn=lambda device, sample: ("(T, NT)" in sample.name), + name="broken_copysign_backward", + ), + # amin() / amax(): broken in a host of ways I don't think it's a good use of time + # to try to sift through + SkipRule( + op_match_fn=lambda device, op: (op.full_name in {"amin", "amax"}), + name="broken_amin_amax_backward", + ), + XFailRule( + error_type=RuntimeError, + error_msg="reducing across the ragged dimension is not supported for non-contiguous", + op_match_fn=lambda device, op: ( + isinstance(op, BinaryUfuncInfo) + # doesn't happen for these ops for some reason + and op.full_name + not in {"copysign", "max.binary", "maximum", "min.binary", "minimum"} + ), + sample_match_fn=lambda 
device, sample: ( + "(NT, T) broadcasting all 1s" in sample.name + and "noncontig_holes" in sample.name + ), + name="binary_noncontig_holes_ragged_dim_reduction", + ), + XFailRule( + error_type=RuntimeError, + error_msg="reducing across the ragged dimension is not supported for non-contiguous", + op_match_fn=lambda device, op: (op.full_name == "nn.functional.rms_norm"), + sample_match_fn=lambda device, sample: (sample.input._lengths is not None), + name="rms_norm_noncontig_holes_ragged_dim_reduction", + ), + # expected: autodiff on complex dtype is not supported + XFailRule( + error_type=RuntimeError, + error_msg=( + "_nested_view_from_jagged does not support automatic differentiation " + "for outputs with complex dtype" + ), + op_match_fn=lambda device, op: (op.full_name in {"cdouble", "cfloat", "chalf"}), + name="no_complex_autodiff", + ), + # Bug: need to use the correct nested int in the return shape + XFailRule( + error_type=RuntimeError, + error_msg="Function CloneBackward0 returned an invalid gradient", + op_match_fn=lambda device, op: (op.full_name == "clone"), + sample_match_fn=lambda device, sample: ( + sample.kwargs.get("memory_format", None) == torch.contiguous_format + ), + name="clone_wrong_nested_int_for_gradient", + ), + # some min / max ops use masked_fill_ underneath sometimes, which isn't implemented + XFailRule( + error_type=NotImplementedError, + error_msg="aten.masked_fill_.Scalar", + op_match_fn=lambda device, op: ( + op.full_name + in {"max.binary", "min.binary", "minimum", "maximum", "copysign"} + ), + name="unimplemented_masked_fill", + ), +] + +COMPILE_FORWARD_SKIPS_AND_XFAILS = [ + *FORWARD_SKIPS_AND_XFAILS, + # Bug: cross-device conversions with to() result in new nested ints within compile only + XFailRule( + error_type=AssertionError, + error_msg="The values for attribute 'shape' do not match", + op_match_fn=lambda device, op: (op.full_name == "to"), + sample_match_fn=lambda device, sample: ("-> cpu" in sample.name), + name="cross_device_transfer_wrong_nested_int_in_compile", + ), + # clone() -> preserve format on an non-contiguous NJT with holes currently uses + # unbind(), leading to data-dependent expression. Should be fixed via torch._check() + XFailRule( + error_type=torch._dynamo.exc.Unsupported, + # Ne(u1, u0) (unhinted: Ne(u1, u0)). 
(Size-like symbols: u1, u0) + error_msg="Could not guard on data-dependent expression", + op_match_fn=lambda device, op: (op.full_name == "clone"), + sample_match_fn=lambda device, sample: ( + "noncontig_holes" in sample.name + and sample.kwargs.get("memory_format", None) == torch.contiguous_format + ), + name="clone_unbind_data_dependency", + ), + # chunk(): broken in several ways on the batch dim; revisit after similar + # data-dependency issues are handled for narrow() + SkipRule( + op_match_fn=lambda device, op: (op.full_name == "chunk"), + sample_match_fn=lambda device, sample: ("batch_dim" in sample.name), + name="broken_chunk_compile_backward_on_batch_dim", + ), + # select on batch dim currently uses unbind(), leading to data-dependent error in + # torch.compile that needs to be addressed via torch._check() + XFailRule( + error_type=torch._dynamo.exc.InternalTorchDynamoError, + error_msg="Pending unbacked symbols", + op_match_fn=lambda device, op: (op.full_name == "select"), + sample_match_fn=lambda device, sample: ("batch_dim" in sample.name), + name="broken_select_backward_unbacked", + ), +] + +COMPILE_BACKWARD_SKIPS_AND_XFAILS = [ + # non-contiguous with holes inputs + torch.compile doesn't work great today; need + # torch._check() statements. Skip these and handle them later. + SkipRule( + op_match_fn=lambda device, op: True, + sample_match_fn=lambda device, sample: ("noncontig_holes" in sample.name), + name="noncontig_holes_data_dependency", + ), + # mean(): weird bug + XFailRule( + error_type=torch._dynamo.exc.BackendCompilerFailed, + error_msg="'NestedIntNode' object has no attribute 'sub'", + op_match_fn=lambda device, op: (op.full_name == "mean"), + sample_match_fn=lambda device, sample: ( + "full reduction" not in sample.name + and "normal dim reduction" not in sample.name + ), + name="broken_mean_compile_backward", + ), + # min() / max(): weird bug + XFailRule( + error_type=AttributeError, + error_msg="'NestedIntNode' object has no attribute 'add'", + op_match_fn=lambda device, op: ( + op.full_name in {"max.reduction_with_dim", "min.reduction_with_dim"} + ), + sample_match_fn=lambda device, sample: ("ragged dim" in sample.name), + name="broken_min_max_compile_backward", + ), + # to() fails with data-dependent guards OR Unknown layout in record_stream_any_impl; + # need to fix with torch._check(), etc. + XFailRule( + op_match_fn=lambda device, op: (op.full_name == "to"), + sample_match_fn=lambda device, sample: ("-> cpu" in sample.name), + name="to_data_dependency", + ), + # copysign(): formula is broken for (T, NT) broadcasting + XFailRule( + error_type=AttributeError, + error_msg="'NestedIntNode' object has no attribute 'add'", + op_match_fn=lambda device, op: (op.full_name == "copysign"), + sample_match_fn=lambda device, sample: ("(T, NT)" in sample.name), + name="broken_copysign_compile_backward", + ), + # in compile, these complex ops use view_as_real(), which isn't implemented + XFailRule( + error_type=NotImplementedError, + error_msg="aten.view_as_real.default", + op_match_fn=lambda device, op: (op.full_name in {"cdouble", "cfloat", "chalf"}), + name="unimplemented_view_as_real", + ), + *COMPILE_FORWARD_SKIPS_AND_XFAILS, + *BACKWARD_SKIPS_AND_XFAILS, +] + +COMPARE_TENSOR_COMPONENT_EQUALITY = { + # masked_select is expected to output a different shape + "masked_select", +} + + +# OpInfo-based NJT tests. These tests utilize an NJT-specific op_db generated from the standard +# op_db. Note that certain tradeoffs were made wrt coverage vs. 
time spent running tests: +# * All tests run with dtype=torch.float32 only +class TestNestedTensorOpInfo(NestedTensorTestCase): + # TODO: move this + def _gen_grad_outputs(self, out_val): + if isinstance(out_val, (list, tuple)): + need_grad_outs = tuple(o for o in out_val if o.grad_fn is not None) + grad_outputs = tuple( + torch.ones_like(o) for o in out_val if o.grad_fn is not None + ) + return need_grad_outs, grad_outputs + else: + return out_val, (torch.ones_like(out_val),) + + @ops( + [op for op in njt_op_db if op.supports_njt], + allowed_dtypes=(torch.float32,), + ) + @tf32_on_and_off(0.005) + @sample_skips_and_xfails(FORWARD_SKIPS_AND_XFAILS) + def test_forward(self, device, dtype, op): + for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs( + device=device, + dtype=dtype, + requires_grad=False, + use_subtests=True, + ): + with subtest_ctx(self), skip_xfail_ctx(self): + # compare to reference, but expect different nested int + out = op.op(sample.input, *sample.args, **sample.kwargs) + out_ref = op.ref(op, sample) + self.assertEqualIgnoringNestedInts(out, out_ref) + if op._extra_op_data.is_view: + tree_map_only( + NestedTensor, lambda x: self.assertTrue(x._is_view()), out + ) + + # TODO: Revisit once https://github.com/pytorch/pytorch/pull/138369 lands + # TODO: Add xfails for other inplace ops instead of hardcoding + if op.inplace_variant and "index_put" in op.full_name: + op.inplace_variant(sample.input, *sample.args, **sample.kwargs) + self.assertEqualIgnoringNestedInts(sample.input, out_ref) + + @ops( + [op for op in njt_op_db if op.supports_njt and op.supports_autograd], + allowed_dtypes=(torch.float32,), ) - TestNestedTensorSubclass.test_record_stream = _test_record_stream - TestNestedTensorSubclass.test_construction_from_list = _test_construction_from_list - TestNestedTensorSubclass.test_index_put_error = _test_index_put_error - TestNestedTensorSubclass.test_sdpa = _test_sdpa - TestNestedTensorSubclass.test_sdpa_autocast = _test_sdpa_autocast - TestNestedTensorSubclass.test_to_padded_tensor_compile = ( - _test_to_padded_tensor_compile + @tf32_on_and_off(0.005) + @sample_skips_and_xfails(BACKWARD_SKIPS_AND_XFAILS) + def test_backward(self, device, dtype, op): + for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs( + device=device, dtype=dtype, requires_grad=True, use_subtests=True + ): + with subtest_ctx(self), skip_xfail_ctx(self): + # compare to reference, but expect different nested int + out = op.op(sample.input, *sample.args, **sample.kwargs) + out_ref = op.ref(op, sample) + self.assertEqualIgnoringNestedInts(out, out_ref) + if op._extra_op_data.is_view: + tree_map_only( + NestedTensor, lambda x: self.assertTrue(x._is_view()), out + ) + + inps, _ = tree_flatten((sample.input, sample.args, sample.kwargs)) + g_inps = [ + inp + for inp in inps + if isinstance(inp, torch.Tensor) and inp.requires_grad + ] + if len(g_inps) > 0: + need_grad_outs, grad_outputs = self._gen_grad_outputs(out) + grads = torch.autograd.grad( + need_grad_outs, inputs=g_inps, grad_outputs=grad_outputs + ) + + need_grad_outs, grad_outputs = self._gen_grad_outputs(out_ref) + grads_ref = torch.autograd.grad( + need_grad_outs, inputs=g_inps, grad_outputs=grad_outputs + ) + + self.assertEqualNoncontigAware(grads, grads_ref) + + @ops( + [op for op in njt_op_db if op.supports_njt], + allowed_dtypes=(torch.float32,), ) + @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) + # needed to avoid "data dependent operator: aten._local_scalar_dense.default" + 
@torch._dynamo.config.patch(capture_scalar_outputs=True) + @sample_skips_and_xfails(COMPILE_FORWARD_SKIPS_AND_XFAILS) + def test_compile_forward(self, device, dtype, op): + for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs( + device=device, dtype=dtype, requires_grad=False, use_subtests=True + ): + with subtest_ctx(self), skip_xfail_ctx(self): + torch.compiler.reset() + op_fn = op.op -instantiate_parametrized_tests(TestNestedTensor) -instantiate_device_type_tests( - TestNestedTensorDeviceType, globals(), only_for="xpu", allow_xpu=True -) -instantiate_device_type_tests( - TestNestedTensorAutograd, globals(), only_for="xpu", allow_xpu=True -) -instantiate_device_type_tests( - TestNestedTensorSubclass, globals(), only_for="xpu", allow_xpu=True -) -instantiate_device_type_tests( - TestNestedTensorOpInfo, globals(), only_for="xpu", allow_xpu=True -) + def f(*args, **kwargs): + return op_fn(*args, **kwargs) + + compiled_f = torch.compile( + f, fullgraph=True, backend="aot_eager_decomp_partition" + ) + + out_ref = f(sample.input, *sample.args, **sample.kwargs) + out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs) + if op._extra_op_data.is_view: + tree_map_only( + NestedTensor, lambda x: self.assertTrue(x._is_view()), out_ref + ) + + if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY: + self.assertEqualIgnoringNestedInts(out_compile, out_ref) + else: + self.assertEqual(out_compile, out_ref) + + # TODO: Revisit once https://github.com/pytorch/pytorch/pull/138369 lands + # TODO: Add xfails for other inplace ops instead of hardcoding + if op.inplace_variant and "index_put" in op.full_name: + op_fn = op.inplace_variant + + def in_f(*args, **kwargs): + return op_fn(*args, **kwargs) + + compiled_in_f = torch.compile( + in_f, fullgraph=True, backend="aot_eager_decomp_partition" + ) + + compiled_in_f(sample.input, *sample.args, **sample.kwargs) + if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY: + self.assertEqualIgnoringNestedInts(sample.input, out_ref) + else: + self.assertEqual(sample.input, out_ref) + + @ops( + [op for op in njt_op_db if op.supports_njt and op.supports_autograd], + allowed_dtypes=(torch.float32,), + ) + @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) + # needed to avoid "data dependent operator: aten._local_scalar_dense.default" + @torch._dynamo.config.patch(capture_scalar_outputs=True) + @sample_skips_and_xfails(COMPILE_BACKWARD_SKIPS_AND_XFAILS) + def test_compile_backward(self, device, dtype, op): + for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs( + device=device, dtype=dtype, requires_grad=True, use_subtests=True + ): + with subtest_ctx(self), skip_xfail_ctx(self): + torch.compiler.reset() + + op_fn = op.op + + def f(*args, **kwargs): + return op_fn(*args, **kwargs) + + compiled_f = torch.compile( + f, fullgraph=True, backend="aot_eager_decomp_partition" + ) + out_ref = f(sample.input, *sample.args, **sample.kwargs) + out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs) + if op._extra_op_data.is_view: + tree_map_only( + NestedTensor, lambda x: self.assertTrue(x._is_view()), out_ref + ) + + if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY: + self.assertEqualIgnoringNestedInts(out_compile, out_ref) + else: + self.assertEqual(out_compile, out_ref) + + inps, _ = tree_flatten((sample.input, sample.args, sample.kwargs)) + g_inps = [ + inp + for inp in inps + if isinstance(inp, torch.Tensor) and inp.requires_grad + ] + if len(g_inps) > 0: + need_grad_outs, grad_outputs = 
self._gen_grad_outputs(out_compile) + grads_compile = torch.autograd.grad( + need_grad_outs, + inputs=g_inps, + grad_outputs=grad_outputs, + ) + + need_grad_outs, grad_outputs = self._gen_grad_outputs(out_ref) + grads_ref = torch.autograd.grad( + need_grad_outs, + inputs=g_inps, + grad_outputs=grad_outputs, + ) + + self.assertEqualNoncontigAware(grads_compile, grads_ref) + + @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) + # needed to avoid "data dependent operator: aten._local_scalar_dense.default" + @torch._dynamo.config.patch(capture_scalar_outputs=True) + @skipIfTorchDynamo( + "Dynamo fails on pending unbacked symints at assertEqual(ref_y[0][0][0].item(), 2)" + ) + def test_nested_tensor_non_contiguous_mutation(self): + def fn(x, x0): + x[0, 0, 0] = 2 + return x + + def _inp(): + base = torch.zeros(32, 3) + v = base.t() + return torch.nested.nested_tensor_from_jagged( + v, + offsets=torch.tensor([0, 2, 3]), + ), torch.ones(2, 32) + + ref_x, ref_x0 = _inp() + ref_y = fn(ref_x, ref_x0) + + self.assertEqual(ref_y[0][0][0].item(), 2) + + y = torch.compile(fn, fullgraph=True, backend="aot_eager")(*_inp()) + self.assertEqual(y[0][0][0], 2) + + def test_nested_tensor_input_mutation_backward(self): + # See Note [AOTAutograd Tangent Subclassness for mutated inputs] + # NJT tangent is always subclass, See torch/csrc/autograd/python_function.cpp, use_zeros_like. + # This test checks that AOTD correctly guess NJT tangent as NJT. + def fn(x): + x.mul_(2) + return x + 1 + + def _inp(): + v = torch.zeros(32, 3, requires_grad=True) + return torch.nested.nested_tensor_from_jagged( + v, + offsets=torch.tensor([0, 2, 3]), + ).clone() + + ref_x = _inp() + ref_y = fn(ref_x) + ref_y.sum().backward() + + x = _inp() + y = torch.compile(fn, fullgraph=True, backend="aot_eager")(x) + y.sum().backward() + + +from torch.nested._internal.nested_int import NestedIntNode + + +class TestNestedInt(torch.testing._internal.common_utils.TestCase): + def test_comparisons(self): + a = torch.SymInt(NestedIntNode(1, 1)) + b = torch.SymInt(NestedIntNode(1, 1)) + c = torch.SymInt(NestedIntNode(2, 1)) + d = 3 + + self.assertTrue(a == a) + self.assertTrue(a == b) + self.assertFalse(a != a) + self.assertFalse(a != b) + self.assertFalse(a == c) + self.assertTrue(a != c) + + self.assertFalse(a == d) + self.assertTrue(a != d) + self.assertFalse(d == a) + self.assertTrue(d != a) + + # ge + self.assertTrue(a >= a) + self.assertTrue(a >= b) + self.assertTrue(b >= a) + with self.assertRaises(ValueError): + _ = a >= c + with self.assertRaises(ValueError): + _ = c >= a + with self.assertRaises(ValueError): + _ = c >= 3 + self.assertTrue(c >= 2) + self.assertTrue(c >= 1) + self.assertFalse(c <= 1) + + # lt + self.assertFalse(a < a) + self.assertFalse(a < b) + self.assertFalse(b < a) + with self.assertRaises(ValueError): + _ = a < c + with self.assertRaises(ValueError): + _ = c < a + with self.assertRaises(ValueError): + _ = 3 < a + with self.assertRaises(ValueError): + _ = 2 < a + self.assertTrue(a > 1) + + # le + self.assertTrue(a <= a) + self.assertTrue(b <= a) + self.assertTrue(a <= b) + with self.assertRaises(ValueError): + _ = a <= c + with self.assertRaises(ValueError): + _ = c <= a + with self.assertRaises(ValueError): + _ = 3 <= c + self.assertTrue(c >= 2) + self.assertTrue(c >= 1) + self.assertFalse(c <= 1) + + # gt + self.assertFalse(a > a) + self.assertFalse(b > a) + self.assertFalse(a > b) + with self.assertRaises(ValueError): + _ = a > c + with self.assertRaises(ValueError): + _ = c > a + with 
self.assertRaises(ValueError): + _ = a > 3 + with self.assertRaises(ValueError): + _ = a > 2 + self.assertTrue(a > 1) + + def test_with_factor(self): + a = torch.SymInt(NestedIntNode(1, 5)) + b = torch.SymInt(NestedIntNode(1, 10)) + # eq + self.assertFalse(a == b) + self.assertFalse(a >= b) + self.assertTrue(b >= a) + self.assertTrue(a <= b) + self.assertFalse(b <= a) + # ne + self.assertTrue(a != b) + # mul + self.assertTrue(a * 2 == b) + self.assertTrue(a * 3 >= b) + self.assertTrue(a * 2 == 2 * a) + + +instantiate_parametrized_tests(TestNestedTensor) +instantiate_device_type_tests(TestNestedTensorDeviceType, globals(), allow_xpu=True) +instantiate_device_type_tests(TestNestedTensorAutograd, globals(), allow_xpu=True) +instantiate_device_type_tests(TestNestedTensorSubclass, globals(), allow_xpu=True) +instantiate_device_type_tests(TestNestedTensorOpInfo, globals(), allow_xpu=True) if __name__ == "__main__": run_tests() From 699084ab980fe8d4ec26d649882734e415b5a300 Mon Sep 17 00:00:00 2001 From: "Deng, Daisy" Date: Thu, 4 Dec 2025 13:49:50 +0000 Subject: [PATCH 2/4] add only_for="xpu" --- test/xpu/test_nestedtensor_xpu.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/test/xpu/test_nestedtensor_xpu.py b/test/xpu/test_nestedtensor_xpu.py index 45b061c212..75284a2474 100644 --- a/test/xpu/test_nestedtensor_xpu.py +++ b/test/xpu/test_nestedtensor_xpu.py @@ -9127,10 +9127,18 @@ def test_with_factor(self): instantiate_parametrized_tests(TestNestedTensor) -instantiate_device_type_tests(TestNestedTensorDeviceType, globals(), allow_xpu=True) -instantiate_device_type_tests(TestNestedTensorAutograd, globals(), allow_xpu=True) -instantiate_device_type_tests(TestNestedTensorSubclass, globals(), allow_xpu=True) -instantiate_device_type_tests(TestNestedTensorOpInfo, globals(), allow_xpu=True) +instantiate_device_type_tests( + TestNestedTensorDeviceType, globals(), only_for="xpu", allow_xpu=True +) +instantiate_device_type_tests( + TestNestedTensorAutograd, globals(), only_for="xpu", allow_xpu=True +) +instantiate_device_type_tests( + TestNestedTensorSubclass, globals(), only_for="xpu", allow_xpu=True +) +instantiate_device_type_tests( + TestNestedTensorOpInfo, globals(), only_for="xpu", allow_xpu=True +) if __name__ == "__main__": run_tests() From 4ee38f3ecc7c89a5d7c2062b18ff97654b1bfc0c Mon Sep 17 00:00:00 2001 From: "Deng, Daisy" Date: Sun, 7 Dec 2025 14:01:07 +0000 Subject: [PATCH 3/4] add 4 test case to improve test coverage --- .../functorch/test_eager_transforms_xpu.py | 5434 +++++++++++++++++ test/xpu/test_cpp_api_parity_xpu.py | 92 + test/xpu/test_expanded_weights_xpu.py | 1171 ++++ test/xpu/test_matmul_cuda_xpu.py | 1574 ++++- 4 files changed, 7933 insertions(+), 338 deletions(-) create mode 100644 test/xpu/functorch/test_eager_transforms_xpu.py create mode 100644 test/xpu/test_cpp_api_parity_xpu.py create mode 100644 test/xpu/test_expanded_weights_xpu.py diff --git a/test/xpu/functorch/test_eager_transforms_xpu.py b/test/xpu/functorch/test_eager_transforms_xpu.py new file mode 100644 index 0000000000..8e98815846 --- /dev/null +++ b/test/xpu/functorch/test_eager_transforms_xpu.py @@ -0,0 +1,5434 @@ +# Owner(s): ["module: functorch"] +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import copy +import math +import os +import subprocess +import sys +import unittest +import warnings +from functools import partial, wraps + +# NB: numpy is a testing dependency! +import numpy as np + +sys.path.append("../../../../test/functorch") +import functorch +import torch +import torch.autograd.forward_ad as fwAD +import torch.nn as nn +import torch.nn.functional as F +from common_utils import expectedFailureIf +from functorch import ( + combine_state_for_ensemble, + grad, + grad_and_value, + hessian, + jacfwd, + jacrev, + jvp, + make_functional, + make_functional_with_buffers, + make_fx, + vjp, + vmap, +) +from functorch.experimental import functionalize, replace_all_batch_norm_modules_ +from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet +from torch._dynamo import allow_in_graph +from torch._functorch.eager_transforms import _slice_argnums +from torch._functorch.make_functional import ( + functional_init, + functional_init_with_buffers, +) +from torch._functorch.utils import enable_single_level_autograd_function +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.func import functional_call, linearize, stack_module_state +from torch.testing import make_tensor +from torch.testing._internal.common_cuda import ( + SM70OrLater, + TEST_CUDA, + tf32_on_and_off, + with_tf32_off, +) +from torch.testing._internal.common_device_type import ( + dtypes, + instantiate_device_type_tests, + onlyCPU, + onlyOn, +) +from torch.testing._internal.common_dtype import get_all_fp_dtypes +from torch.testing._internal.common_utils import ( + freeze_rng_state, + instantiate_parametrized_tests, + IS_FBCODE, + IS_WINDOWS, + markDynamoStrictTest, + parametrize, + run_tests, + skipIfTorchDynamo, + subtest, + TEST_CUDA_MEM_LEAK_CHECK, + TEST_WITH_TORCHDYNAMO, + TestCase, + xfailIfTorchDynamo, +) +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + +USE_TORCHVISION = False +try: + import torchvision # noqa: F401 + + USE_TORCHVISION = True +except ImportError: + warnings.warn( + "Couldn't import torchvision. Some of our tests use it, try " + "to install it with commands from pytorch.org, post-fixed with " + "`--no-deps` to avoid overwriting the pytorch installation", + UserWarning, + ) + +# TestCase for _slice_argnums, an important helper function + + +class VmapTearDownMixin: + def tearDown(self): + # Ensure that in the case of a test failure, the next test won't fail + # because of a previous call to _vmap_increment_nesting that wasn't undone + # i.e. test_vmap_free_tensor fails when PYTORCH_TEST_WITH_DYNAMO=1 + # and the call to increment nesting is not undone + if not TEST_WITH_TORCHDYNAMO: + return + + warn = False + while ci := torch._C._functorch.peek_interpreter_stack(): + if ci.key() == torch._C._functorch.TransformType.Vmap: + warn = True + torch._C._functorch._vmap_decrement_nesting() + else: + break + + if warn: + msg = ( + "Interpreter stack is not empty. 
Test should have called " + "'torch._C._functorch._vmap_decrement_nesting()'" + ) + warnings.warn(msg) + + +@markDynamoStrictTest +class TestSliceArgnums(TestCase): + def test_invalid_argnum_type(self): + x = torch.randn(3) + args = (x,) + with self.assertRaisesRegex(RuntimeError, "int or Tuple"): + _slice_argnums(args, 0.0) + with self.assertRaisesRegex(RuntimeError, "int or Tuple"): + _slice_argnums(args, [0]) + with self.assertRaisesRegex(RuntimeError, "must be int"): + _slice_argnums(args, (0.0,)) + + args = (0.1, 1.1, 2.1, 3.1, 4.1) + + with self.assertRaisesRegex(RuntimeError, "must be int"): + _slice_argnums(args, ((0, 1), 2)) + + def test_out_of_bounds_argnum_values(self): + x = torch.randn(3) + args = (x,) + with self.assertRaisesRegex(RuntimeError, "positional inputs"): + _slice_argnums(args, 1) + with self.assertRaisesRegex(RuntimeError, "positional inputs"): + _slice_argnums(args, -2) + with self.assertRaisesRegex(RuntimeError, "positional inputs"): + _slice_argnums(args, (-2,)) + + def test_not_enough_argnums(self): + x = torch.randn(3) + args = (x,) + with self.assertRaisesRegex(RuntimeError, "must be non-empty"): + _slice_argnums(args, ()) + + def test_duplicate_argnums(self): + x = torch.randn(3) + args = (x, x) + with self.assertRaisesRegex(RuntimeError, "must be unique"): + _slice_argnums(args, (0, 0)) + with self.assertRaisesRegex(RuntimeError, "must be unique"): + _slice_argnums(args, (0, -2)) + + def test_flat_args_with_positive_int_argnum(self): + args = (0.1, 1.1, 2.1, 3.1, 4.1) + + res = _slice_argnums(args, 0) + self.assertEqual(res, (0.1,)) + + res = _slice_argnums(args, 4) + self.assertEqual(res, (4.1,)) + + def test_flat_args_with_negative_int_argnum(self): + args = (0.1, 1.1, 2.1, 3.1, 4.1) + + res = _slice_argnums(args, -1) + self.assertEqual(res, (4.1,)) + + res = _slice_argnums(args, -5) + self.assertEqual(res, (0.1,)) + + def test_flat_args_with_tuple_argnum(self): + args = (0.1, 1.1, 2.1, 3.1, 4.1) + + res = _slice_argnums(args, (0, 1, 2, 3, 4)) + self.assertEqual(res, args) + + res = _slice_argnums(args, (0, -3)) + self.assertEqual(res, (0.1, 2.1)) + + def test_pytree_args(self): + args = ((0.1, 1.1), 2.0, [3.1]) + + res = _slice_argnums(args, 0) + self.assertEqual(res, args[0:1]) + + res = _slice_argnums(args, (0,)) + self.assertEqual(res, args[0:1]) + + res = _slice_argnums(args, -1) + self.assertEqual(res, args[-1:]) + + res = _slice_argnums(args, (0, -2)) + self.assertEqual(res, args[0:2]) + + def test_argnums_reorders(self): + args = ((0.1, 1.1, 2.1), 3.1, 4.1) + + res = _slice_argnums(args, (1, 0)) + self.assertEqual(res, (args[1], args[0])) + + +def _get_weights_and_functional_call(net, mechanism): + if mechanism == "make_functional": + return make_functional(net) + else: + assert mechanism == "functional_call" + # this makes it so the function from make_functional and this call have the same signature + + def net_func(weights, data): + return functional_call(net, weights, (data,)) + + return net_func, dict(net.named_parameters()) + + +def _get_weights_and_functional_call_with_buffers(net, mechanism): + if mechanism == "make_functional": + return make_functional_with_buffers(net) + else: + assert mechanism == "functional_call" + + # this makes it so the function from make_functional and this call have the same signature + def net_func(weights, buffers, data): + return functional_call(net, (weights, buffers), (data,)) + + return net_func, dict(net.named_parameters()), dict(net.named_buffers()) + + +@markDynamoStrictTest +class 
TestGradTransform(TestCase): + def test_primitive(self, device): + x = torch.randn([], device=device) + result = grad(torch.sin)(x) + self.assertEqual(result, torch.cos(x)) + + def test_composite_simple(self, device): + x = torch.randn(2, 3, 4, device=device) + result = grad(lambda x: torch.flatten(x).sum())(x) + self.assertEqual(result, torch.ones_like(x)) + + def test_fn_with_kwargs(self, device): + def foo(x, y): + return (x * y).sum() + + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + expected = grad(foo)(x, y) + result = grad(foo)(x, y=y) + self.assertEqual(result, expected) + + def test_composite_complicated(self, device): + x = torch.randn(3, device=device) + y = torch.randn(3, 5, device=device) + + def foo(x, y): + result = x @ y + return result.sum() + + result = grad(foo)(x, y) + + x.requires_grad_() + out = foo(x, y) + (expected,) = torch.autograd.grad(out, x) + + self.assertEqual(result, expected) + + def test_composite_two_ops(self, device): + N, C = 2, 5 + y = torch.randn(N, C, device=device) + targets = torch.randint(0, C, (N,), device=device) + + def foo(y, targets): + return F.cross_entropy(y, targets) + + result = grad(foo)(y, targets) + + y.requires_grad_() + (expected,) = torch.autograd.grad(foo(y, targets), y) + + self.assertEqual(result, expected) + + def _test_attributes(self, get_attr_lambda, device): + x = torch.randn(2, 3, 5, dtype=torch.double, device=device) + expected = get_attr_lambda(x) + + def foo(x): + self.assertEqual(get_attr_lambda(x), expected) + return x.sum() + + grad(foo)(x) + + def test_shape(self, device): + self._test_attributes(lambda x: x.shape, device) + + def test_dtype(self, device): + self._test_attributes(lambda x: x.dtype, device) + + def test_is_cuda(self, device): + self._test_attributes(lambda x: x.is_cuda, device) + + def test_numel(self, device): + self._test_attributes(lambda x: x.numel(), device) + + def test_layout_sparse(self, device): + indices = torch.tensor([[0, 1, 1], [2, 0, 2]], device=device) + values = torch.tensor([3.0, 4.0, 5.0], device=device) + sparse_x = torch.sparse_coo_tensor(indices, values, (2, 3), device=device) + + # Verify the input is sparse + self.assertEqual(sparse_x.layout, torch.sparse_coo) + + def foo(x): + # assert GradTrackingTensor still reports sparse layout + self.assertEqual(x.layout, torch.sparse_coo) + return x.coalesce()._values().sum() + + result = grad(foo)(sparse_x) + + # The gradient should also be sparse + self.assertEqual(result.layout, torch.sparse_coo) + + def test_inplace(self, device): + x = torch.randn([], device=device) + + def foo(x): + return x.clone().sin_() + + result = grad(foo)(x) + self.assertEqual(result, x.cos()) + + def test_inplace_on_view(self, device): + x = torch.randn(3, device=device) + + def foo(x): + y = x.clone() + y0 = y[0] + y0.sin_() + return y.sum() + + result = grad(foo)(x) + + x.requires_grad_() + out = foo(x) + (expected,) = torch.autograd.grad(out, x) + + self.assertEqual(result, expected) + + def test_inplace_on_view_base(self, device): + x = torch.randn(3, device=device) + + def foo(x): + y = x.clone() + y0 = y[0] + y.sin_() + return y0 + + result = grad(foo)(x) + + x.requires_grad_() + out = foo(x) + (expected,) = torch.autograd.grad(out, x) + + self.assertEqual(result, expected) + + def test_inplace_on_captures(self, device): + x = torch.tensor([1.0, 2.0, 3.0], device=device) + captured = torch.randn(3, device=device) + + def foo(x): + captured.copy_(x) + return (x * captured).sum() + + with self.assertRaisesRegex(RuntimeError, 
"mutate a captured Tensor"): + grad(foo)(x) + + def test_nesting_simple(self, device): + x = torch.randn([], device=device) + result = grad(grad(torch.sin))(x) + self.assertEqual(result, -torch.sin(x)) + + @skipIfTorchDynamo("Ref: https://github.com/pytorch/pytorch/issues/103613") + def test_escaped_wrappers_are_marked_as_dead(self, device): + x = torch.randn([], device=device) + escaped = [] + + def foo(x): + y = x.sin() + escaped.append(y) + return y + + grad(foo)(x) + self.assertEqual(torch._C._functorch.dlevel(escaped[0]), -1) + + @skipIfTorchDynamo("Ref: https://github.com/pytorch/pytorch/issues/103613") + def test_escaped_wrappers_are_ignored(self, device): + x = torch.randn([], device=device) + escaped = [] + + def foo(x): + y = x.sin() + escaped.append(y) + return y + + grad(foo)(x) + + something = escaped[0].sum() + self.assertEqual(torch._C._functorch.dlevel(something), 0) + self.assertEqual(something, x.sin().sum()) + + def test_manual_seed_inside_grad(self, device): + x = torch.randn([], device=device) + + def f(x): + torch.manual_seed(0) + return x * torch.randn_like(x) + + with freeze_rng_state(): + result = grad(f)(x) + x.requires_grad_() + (expected,) = torch.autograd.grad(f(x), x) + self.assertEqual(result, expected) + + def test_vjp(self, device): + x = torch.randn([], device=device) + out, vjp_fn = vjp(torch.sin, x) + self.assertEqual(out, x.sin()) + + v = torch.randn([], device=device) + (result,) = vjp_fn(v) + self.assertEqual(result, v * x.cos()) + + def test_vjp_two_outputs(self, device): + def f(x): + return x, x + + result, vjp_fn = vjp(f, torch.tensor(1.0)) + vjp_fn(result) + + def test_conj_bit(self): + x = torch.tensor(1 + 1j) + + def foo(x): + assert not x.is_conj() + y = x.conj() + assert y.is_conj() + return y.abs() + + res = grad(foo)(x) + with torch.no_grad(): + self.assertEqual(res, torch.ones_like(res) * torch.sgn(x)) + + def test_composed_with_autograd(self, device): + x = torch.randn([], requires_grad=True, device=device) + + y = grad(torch.sin)(x) + (result,) = torch.autograd.grad(y, x) + self.assertEqual(result, -x.sin()) + + def test_grad_of_vjp_composition(self, device): + x = torch.randn([], device=device) + y = torch.randn([], device=device) + + def foo(x, y): + out, vjp_fn = vjp(torch.sin, x) + return grad(lambda y: vjp_fn(y)[0])(y) + + result = foo(x, y) + expected = x.cos() + self.assertEqual(result, expected) + + def test_vjp_of_grad_composition(self, device): + x = torch.randn([], device=device) + y = torch.randn([], device=device) + + def foo(x, y): + out, vjp_fn = vjp(grad(torch.sin), x) + return vjp_fn(y)[0] + + result = foo(x, y) + expected = -y * x.sin() + self.assertEqual(result, expected) + + def test_grad_of_vjp_of_grad_composition(self, device): + x = torch.randn([], device=device) + y = torch.randn([], device=device) + + def foo(x, y): + df, vjp_fn = vjp(grad(lambda x: -torch.cos(x)), x) + return grad(lambda y: vjp_fn(y)[0])(y) + + result = foo(x, y) + expected = x.cos() + self.assertEqual(result, expected) + + def test_views(self, device): + x = torch.randn([], requires_grad=True, device=device) + y = torch.randn([], requires_grad=True, device=device) + + def silly_sin(x): + x = x.view([]) + x = x.sin() + return x + + def foo(x, y): + z1 = grad(silly_sin)(x) + z2 = torch.cos(y) + return z1 + z2 + + result = foo(x, y) + grads = torch.autograd.grad(result, [x, y]) + self.assertEqual(grads[0], -x.sin()) + self.assertEqual(grads[1], -y.sin()) + + def test_view_inplace_simple(self, device): + def foo(x): + x = x.clone() + 
x.view([]).sin_() + return x + + x = torch.randn([], requires_grad=True, device=device) + result = grad(foo)(x) + self.assertEqual(result, x.cos()) + + def test_invalid_argnums(self, device): + x = torch.randn([]) + y = torch.randn([]) + with self.assertRaisesRegex(RuntimeError, "but only"): + grad(torch.mul, argnums=-3)(x, y) + with self.assertRaisesRegex(RuntimeError, "but only"): + grad(torch.mul, argnums=2)(x, y) + with self.assertRaisesRegex(RuntimeError, "int or Tuple"): + grad(torch.mul, argnums=[0])(x, y) + with self.assertRaisesRegex(RuntimeError, "must be int"): + grad(torch.mul, argnums=("0",))(x, y) + with self.assertRaisesRegex(RuntimeError, "must be unique"): + grad(torch.mul, argnums=(0, 0))(x, y) + with self.assertRaisesRegex(RuntimeError, "must be unique"): + grad(torch.mul, argnums=(0, -2))(x, y) + + def test_argnums(self, device): + x = torch.randn([]) + y = torch.randn([]) + gx = grad(torch.mul, argnums=0)(x, y) + self.assertEqual(gx, y) + + gy = grad(torch.mul, argnums=1)(x, y) + self.assertEqual(gy, x) + + (gx,) = grad(torch.mul, argnums=(0,))(x, y) + self.assertEqual(gx, y) + + gx, gy = grad(torch.mul, argnums=(0, 1))(x, y) + self.assertEqual(gx, y) + self.assertEqual(gy, x) + + def test_out_of_order_argnums(self, device): + x = torch.randn([]) + y = torch.randn([]) + gy, gx = grad(torch.mul, argnums=(1, 0))(x, y) + self.assertEqual(gx, y) + self.assertEqual(gy, x) + + def test_negative_argnums(self, device): + x = torch.randn([]) + y = torch.randn([]) + gx = grad(torch.mul, argnums=-2)(x, y) + self.assertEqual(gx, y) + + gy = grad(torch.mul, argnums=-1)(x, y) + self.assertEqual(gy, x) + + (gx,) = grad(torch.mul, argnums=(-2,))(x, y) + self.assertEqual(gx, y) + + gx, gy = grad(torch.mul, argnums=(-2, -1))(x, y) + self.assertEqual(gx, y) + self.assertEqual(gy, x) + + def test_grad_pytree_inputs(self, device): + x = torch.randn([], device=device) + + def f(a, b): + x, y = a + return 1 * x + 2 * y + 3 * b["foo"] + + args = ((x, x), {"foo": x}) + + gx, gy = grad(f)(*args) + self.assertEqual(gx, torch.tensor(1.0, device=device)) + self.assertEqual(gy, torch.tensor(2.0, device=device)) + + ((gx, gy),) = grad(f, argnums=(0,))(*args) + self.assertEqual(gx, torch.tensor(1.0, device=device)) + self.assertEqual(gy, torch.tensor(2.0, device=device)) + + (gx, gy), gz = grad(f, argnums=(0, 1))(*args) + self.assertEqual(gx, torch.tensor(1.0, device=device)) + self.assertEqual(gy, torch.tensor(2.0, device=device)) + self.assertEqual(gz["foo"], torch.tensor(3.0, device=device)) + + def test_grad_aux_tensor(self, device): + x = torch.randn(3, device=device) + + with self.assertRaisesRegex( + RuntimeError, + r"grad_and_value\(f\)\(\*args\): output of function f should be a tuple", + ): + grad(lambda t: [t, t], has_aux=True)(x) + + with self.assertRaisesRegex( + RuntimeError, + r"grad_and_value\(f\)\(\*args\): output of function f should be a tuple", + ): + grad(lambda t: (t, t + 2, t + 3), has_aux=True)(x) + + def f(t): + y = t.sin() + return y.sum(), t.cos() + + out, aux = grad(f, has_aux=True)(x) + self.assertEqual(aux, x.cos()) + self.assertEqual(out, x.cos()) + + def test_grad_aux_pytree(self, device): + def f(x): + y = x.sin() + return y.sum(), {"a": x.cos(), "b": [x.tan()]} + + x = torch.randn(3, device=device) + + out, aux = grad(f, has_aux=True)(x) + _, expected_aux = f(x) + self.assertEqual(aux, expected_aux) + self.assertEqual(out, x.cos()) + + for aux in [1, 1.0, "abc"]: + with self.assertRaisesRegex( + RuntimeError, r"Expected tensors, got unsupported type" + ): + _ = 
grad(lambda x: (x.sum(), aux), has_aux=True)(x) + with self.assertRaisesRegex( + RuntimeError, r"Expected tensors, got unsupported type" + ): + _ = grad(lambda x: (x.sum(), [x, aux]), has_aux=True)(x) + + def test_zero_grad(self, device): + def f(x): + return (x["a"] ** 2.0).sum() + + inps = { + "a": torch.randn(10, device=device) + 3, + "b": torch.randn(10, device=device), + } + grads = grad(f)(inps) + self.assertNotEqual(grads["a"].sum(), 0.0) + self.assertEqual(grads["b"].sum(), 0.0) + + def test_unrelated_grad(self, device): + x = torch.tensor(1.0, device=device) + y = torch.tensor(2.0, device=device) + + def unrelated(x): + return y + + result = grad(unrelated)(x) + self.assertEqual(result, torch.zeros_like(x)) + + def test_unrelated_vjp(self, device): + x = torch.tensor(1.0, device=device) + y = torch.tensor(2.0, device=device) + v = torch.tensor(1.0, device=device) + + def unrelated(x): + return y + + out, vjp_fn = vjp(unrelated, x) + result = vjp_fn(v) + expected = (torch.zeros_like(x),) + self.assertEqual(result, expected) + + def test_unrelated_vjp_multiple_inputs_outputs(self, device): + w = torch.tensor(3.0, device=device) + x = torch.tensor(4.0, device=device) + y = torch.tensor(2.0, device=device) + v = torch.tensor(1.0, device=device) + + def unrelated(w, x): + return y, y, x + + out, vjp_fn = vjp(unrelated, w, x) + result = vjp_fn((v, v, v)) + expected = (torch.zeros_like(x), torch.ones_like(x)) + self.assertEqual(result, expected) + + # TODO: https://github.com/pytorch/functorch/issues/12 + @onlyCPU + def test_unrelated_hessian(self, device): + N = 5 + M = 3 + W = torch.randn(N, M, device=device) + + def f(x): + return W @ x + + x = torch.randn(M) + result = jacrev(jacrev(f))(x) + expected = torch.zeros(N, M, M, device=device) + self.assertEqual(result, expected) + + def test_vjp_pytree_input(self, device): + def f(x): + return x[0] * x[1][0] + + x = torch.randn([], device=device) + v = torch.randn([], device=device) + out, vjp_fn = vjp(f, (x, (x, x))) + self.assertEqual(out, x * x) + result = vjp_fn(v) + self.assertEqual(result, ((x * v, (x * v, 0.0)),)) + + def test_vjp_pytree_output(self, device): + def f(x): + return x, (x, x) + + x = torch.randn([], device=device) + v1 = torch.randn([], device=device) + v2 = torch.randn([], device=device) + v3 = torch.randn([], device=device) + _, vjp_fn = vjp(f, x) + (result,) = vjp_fn((v1, (v2, v3))) + self.assertEqual(result, v1 + v2 + v3) + + def test_vjp_outputs_can_any_pytree(self, device): + x = torch.randn(2, 3, device=device) + t = torch.randn(2, 3, device=device) + + for output in [None, ()]: + with self.assertRaisesRegex( + RuntimeError, + r"vjp\(f, \*primals\): Expected f to be a function that has non-empty output", + ): + _, vjp_fn = vjp(lambda _: output, x) + vjp_fn(t) + + for output in [1, True, 12.2, "abc"]: + with self.assertRaisesRegex( + RuntimeError, + r"vjp\(f, \*primals\): expected f\(\*primals\) to return only tensors", + ): + _, vjp_fn = vjp(lambda _: output, x) + vjp_fn(t) + + # Check list output + output, vjp_fn = vjp(lambda x: [x, x.sum()], x) + (vjp_out,) = vjp_fn([t, t.sum()]) + assert isinstance(output, list) and len(output) == 2 + assert isinstance(vjp_out, torch.Tensor) + + # Check dict output + output, vjp_fn = vjp(lambda x: {"x": x, "xsum": x.sum()}, x) + (vjp_out,) = vjp_fn({"x": t, "xsum": t.sum()}) + assert isinstance(output, dict) and len(output) == 2 and "xsum" in output + assert isinstance(vjp_out, torch.Tensor) + + def composite_output(x): + out = x.sum() + return [ + (out, {"a": x, "out": [x, 
out]}), + ] + + output, vjp_fn = vjp(composite_output, x) + (vjp_out,) = vjp_fn( + [ + (t.sum(), {"a": t, "out": [t, t.sum()]}), + ] + ) + assert isinstance(output, list) + assert isinstance(output[0], tuple) and isinstance(output[0][1], dict) + assert isinstance(vjp_out, torch.Tensor) + + def test_vjp_pytree_error(self, device): + def f(x): + return x, (x, x) + + x = torch.randn([], device=device) + v1 = torch.randn([], device=device) + v2 = torch.randn([], device=device) + v3 = torch.randn([], device=device) + _, vjp_fn = vjp(f, x) + with self.assertRaisesRegex(RuntimeError, "Expected pytree structure"): + (result,) = vjp_fn(((v1, (v2, v3)),)) + + def test_vjp_aux_tensor(self, device): + x = torch.randn(3, device=device) + + with self.assertRaisesRegex( + RuntimeError, r"vjp\(f, \*primals\): output of function f should be a tuple" + ): + vjp(lambda t: [t, t], x, has_aux=True) + + with self.assertRaisesRegex( + RuntimeError, r"vjp\(f, \*primals\): output of function f should be a tuple" + ): + vjp(lambda t: (t, t + 2, t + 3), x, has_aux=True) + + def f(t): + y = t.sin() + return y, t.cos() + + out, vjp_fn, aux = vjp(f, x, has_aux=True) + self.assertEqual(aux, x.cos()) + self.assertEqual(out, x.sin()) + + v = torch.randn(3, device=device) + (grad_x,) = vjp_fn(v) + self.assertEqual(grad_x, v * x.cos()) + + def test_vjp_aux_pytree(self, device): + def f(x): + y = x.sin() + return y, {"a": x.cos(), "b": [x.tan()]} + + x = torch.randn(3, device=device) + + out, vjp_fn, aux = vjp(f, x, has_aux=True) + expected_out, expected_aux = f(x) + self.assertEqual(out, expected_out) + self.assertEqual(aux, expected_aux) + + v = torch.randn(3, device=device) + (grad_x,) = vjp_fn(v) + self.assertEqual(grad_x, v * x.cos()) + + for aux in [1, 1.0, "abc"]: + with self.assertRaisesRegex( + RuntimeError, r"Expected tensors, got unsupported type" + ): + _ = vjp(lambda x: (x, aux), x, has_aux=True) + with self.assertRaisesRegex( + RuntimeError, r"Expected tensors, got unsupported type" + ): + _ = vjp(lambda x: (x, [x, aux]), x, has_aux=True) + + def test_functional_init(self, device): + class MLPClassifier(nn.Module): + def __init__(self, hidden_dim=32, n_classes=2): + super().__init__() + self.hidden_dim = hidden_dim + self.n_classes = n_classes + + self.fc1 = nn.Linear(2, self.hidden_dim) + self.fc2 = nn.Linear(self.hidden_dim, self.n_classes) + + def forward(self, x): + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + x = F.log_softmax(x, -1) + return x + + B = 10 + weights, fn, _ = functional_init(MLPClassifier, (B,), device=device)(32, 2) + inputs = torch.randn(B, 7, 2, device=device) + vmap(fn)(weights, (inputs,)) + + def test_functional_init_with_buffers(self, device): + class MLPClassifier(nn.Module): + def __init__(self, hidden_dim=32, n_classes=2): + super().__init__() + self.hidden_dim = hidden_dim + self.n_classes = n_classes + + self.fc1 = nn.Linear(2, self.hidden_dim) + self.bn = nn.BatchNorm1d(self.hidden_dim, affine=True) + self.fc2 = nn.Linear(self.hidden_dim, self.n_classes) + + def forward(self, x): + x = self.fc1(x) + x = F.relu(x) + x = self.bn(x) + x = self.fc2(x) + x = F.log_softmax(x, -1) + return x + + B = 10 + weights, buffers, fn, _, _ = functional_init_with_buffers( + MLPClassifier, [B], device=device + )(32, 2) + inputs = torch.randn(B, 7, 2, device=device) + vmap(fn)(weights, buffers, (inputs,)) + + def test_advanced_indexing(self, device): + def f(value): + log_prob = torch.ones((), device=device) + val = torch.zeros(()) > 0 + log_prob[val] = 0 + return value + + result = 
grad(f)(torch.randn((), device=device)) + self.assertEqual(result, torch.ones_like(result)) + + def f2(value): + value = value.clone() + value[value > 0] = 0 + return value.sum() + + x = torch.randn(100, device=device) + result = grad(f2)(x) + self.assertEqual(result, (x <= 0).type_as(x)) + + def test_tensor_ctor_inside_grad(self, device): + def foo(x): + return x * torch.tensor(2.0, device=device) + + x = torch.tensor(3.14, device=device) + functorch.grad(foo)(x) + + @parametrize( + "op_list_data", + [ + subtest( + ( + [ + vmap, + ], + [(4, 2), (64, 3, 32, 32)], + ), + name="vmap", + ), + subtest(([vmap, vmap], [(4, 3, 2), (64, 3, 32, 32)]), name="vmap_vmap"), + subtest( + ( + [ + grad, + ], + [(0,), [], (4, 2), (64, 3, 32, 32)], + ), + name="grad", + ), + subtest( + ( + [grad, grad], + [ + [], + ], + ), + name="grad_grad", + ), + subtest(([vmap, grad], [(4, 2)]), name="vmap_grad"), + ], + ) + def test_tensor_print(self, device, op_list_data): + op_list, shapes = op_list_data + + for dt in get_all_fp_dtypes(): + data = [torch.randn(s, dtype=dt, device=device) for s in shapes] + + for x in data: + buf = None + + def foo(t): + nonlocal buf + buf = repr(t) + return t.mean() + + fn = foo + bdim = 0 + for op in reversed(op_list): + if op is vmap: + fn = op(fn, in_dims=bdim) + bdim += 1 + else: + fn = op(fn) + + expected = f"{repr(x)}" + for level, op in enumerate(op_list): + if op is grad: + expected = ( + f"GradTrackingTensor(lvl={level + 1}, value={expected})" + ) + elif op is vmap: + bdim -= 1 + expected = f"BatchedTensor(lvl={level + 1}, bdim={bdim}, value={expected})" + + fn(x) + buf = buf.replace("\n", "").replace(" ", "") + expected = expected.replace("\n", "").replace(" ", "") + self.assertEqual(expected, buf) + + def test_print_captured_tensor_inside_transform(self, device): + x = torch.tensor([1.0, 2.0, 3.0], device=device) + out = None + + def f(y): + nonlocal out + out = repr(x) + return y + + vjp(f, torch.randn(4, device=device)) + self.assertEqual(out, repr(x)) + + def test_no_grad_outside(self, device): + x = torch.randn([], device=device, requires_grad=True) + with torch.no_grad(): + y = grad(torch.sin)(x) + self.assertEqual(y, x.cos()) + self.assertFalse(y.requires_grad) + + def test_no_grad_inside(self, device): + def f(x): + with torch.no_grad(): + shift = x**2 + return x**2 - shift + + x = torch.randn([], device=device) + y = grad(f)(x) + self.assertEqual(y, 2 * x) + y = grad(grad(f))(x) + self.assertEqual(y, 2) + + x = torch.randn([], device=device, requires_grad=True) + y = grad(f)(x) + (z,) = torch.autograd.grad(y, x) + self.assertEqual(z, 2) + + def test_no_grad_mixed(self, device): + def f(x): + with torch.no_grad(): + shift = x**2 + return x**2 - shift + + x = torch.randn([], device=device, requires_grad=True) + with torch.no_grad(): + y = grad(f)(x) + + self.assertEqual(y, 2 * x) + self.assertFalse(y.requires_grad) + + def test_no_grad_nested_simple(self, device): + def h(x): + with torch.no_grad(): + shift = grad(lambda x: 0.25 * x**4)(x) + return x**3 - shift + + x = torch.tensor(1.5, device=device, requires_grad=True) + y = grad(h)(x) + self.assertEqual(y, 3 * x**2) + + (z,) = torch.autograd.grad(y, x) + self.assertEqual(z, 6 * x) + + def test_no_grad_nested_complicated(self, device): + def f(x): + with torch.no_grad(): + shift = x**3 + return x**3 - shift + + def g(x): + r1 = grad(f)(x) + with torch.no_grad(): + shift = grad(f)(x) + return r1 - shift + + x = torch.randn([], requires_grad=True, device=device) + y = grad(g)(x) + # The only differential part of g is 
x ** 3 + self.assertEqual(y, 6 * x) + + (z,) = torch.autograd.grad(y, x) + self.assertEqual(z, 6) + + def test_no_grad_value(self, device): + def h(x): + with torch.no_grad(): + gvalue, value = grad_and_value(lambda x: x**3)(x) + return x**3 - value + + x = torch.tensor(1.6, device=device, requires_grad=True) + y = grad(h)(x) + self.assertEqual(y, 3 * x**2) + + (z,) = torch.autograd.grad(y, x) + self.assertEqual(z, 6 * x) + + def test_no_grad_outside_vjp(self, device): + def h(x): + return x**2 + + x = torch.tensor(2.0, requires_grad=True, device=device) + with torch.no_grad(): + out, vjp_fn = vjp(h, x) + (y,) = vjp_fn(torch.tensor(1.0, device=device)) + + self.assertEqual(y, 2 * x) + self.assertFalse(y.requires_grad) + self.assertFalse(out.requires_grad) + + def test_no_grad_outside_vjp_fn(self, device): + def h(x): + return x**2 + + x = torch.tensor(3.14, requires_grad=True, device=device) + out, vjp_fn = vjp(h, x) + with torch.no_grad(): + (y,) = vjp_fn(torch.tensor(1.0, device=device)) + + self.assertEqual(y, 2 * x) + self.assertFalse(y.requires_grad) + self.assertTrue(out.requires_grad) + + (z,) = torch.autograd.grad(out, x) + self.assertEqual(z, 2 * x) + + def test_no_grad_outside_vjp_only(self, device): + def h(x): + return x**2 + + x = torch.tensor(3.14, requires_grad=True, device=device) + with torch.no_grad(): + out, vjp_fn = vjp(h, x) + (y,) = vjp_fn(torch.tensor(1.0, device=device)) + + self.assertEqual(y, 2 * x) + self.assertFalse(out.requires_grad) + + # This one is a little weird... + self.assertTrue(y.requires_grad) + + (z,) = torch.autograd.grad(y, x) + self.assertEqual(z, 2) + + +@markDynamoStrictTest +class TestAutogradFunction(TestCase): + def test_set_materialize_grads(self, device): + class A(torch.autograd.Function): + @staticmethod + def forward(x, y): + return x, y + + @staticmethod + def setup_context(ctx, inputs, output): + ctx.set_materialize_grads(False) + + @staticmethod + def backward(ctx, gx, gy): + self.assertIsNotNone(gx) + self.assertIsNone(gy) + return gx, gy + + def f(y, x): + x, y = A.apply(x, y) + return x**2 + + x = torch.tensor(2.0, device=device) + y = torch.tensor(3.0, device=device) + # grad differentiates w.r.t. arg 0 by default + grad(f)(y, x) + grad(grad(f))(y, x) + + @parametrize("inner_requires_grad", [True, False]) + @parametrize("save_for", ["jvp", "vjp"]) + @parametrize("save_tensors", ["input", "output", "neither"]) + @parametrize("mark_dirty", [True, False]) + def test_function_returns_input( + self, device, inner_requires_grad, save_for, save_tensors, mark_dirty + ): + class A(torch.autograd.Function): + @staticmethod + def forward(x): + return x + + @staticmethod + def setup_context(ctx, inputs, output): + if save_for == "jvp": + save_fn = ctx.save_for_forward + else: + save_fn = ctx.save_for_backward + + if mark_dirty: + ctx.mark_dirty(inputs[0]) + + if save_tensors == "input": + save_fn(inputs[0]) + elif save_tensors == "output": + save_fn(output) + elif save_tensors == "neither": + pass + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + @staticmethod + def jvp(ctx, x_t): + # NB: the logic to check ctx.save_for_forward happens + # before we reach this! 
+ if mark_dirty: + ret = x_t.add_(0) + else: + ret = x_t.view_as(x_t) + return ret + + def fn(x): + return A.apply(x.clone()) + + err_msg = "A input that has been returned as-is" + + a = torch.tensor(2.0, device=device, requires_grad=inner_requires_grad) + a_t = torch.tensor(2.0, device=device, requires_grad=inner_requires_grad) + if save_tensors in ("input", "output") and not mark_dirty: + with self.assertRaisesRegex(RuntimeError, err_msg): + grad(fn)(a) + with self.assertRaisesRegex(RuntimeError, err_msg): + jvp(fn, (a,), (a_t,)) + else: + grad(fn)(a) + jvp(fn, (a,), (a_t,)) + + a = torch.tensor(2.0, device=device, requires_grad=inner_requires_grad).clone() + a_t = torch.tensor( + 2.0, device=device, requires_grad=inner_requires_grad + ).clone() + + if save_tensors in ("input", "output") and not mark_dirty: + with self.assertRaisesRegex(RuntimeError, err_msg): + A.apply(a) + with self.assertRaisesRegex(RuntimeError, err_msg): + with fwAD.dual_level(): + A.apply(fwAD.make_dual(a, a_t)) + else: + b = A.apply(a) + if mark_dirty: + self.assertTrue(a is b) + if not ( + mark_dirty and save_for == "vjp" and save_tensors in ("input", "output") + ): + # TODO(soulitzer): https://github.com/pytorch/pytorch/issues/97827 + with fwAD.dual_level(): + a_dual = fwAD.make_dual(a, a_t) + b_dual = A.apply(a_dual) + if mark_dirty: + self.assertTrue(a_dual is b_dual) + + def test_needs_input_grads(self, device): + class A(torch.autograd.Function): + @staticmethod + def forward(x, y): + return x * y + + @staticmethod + def setup_context(ctx, inputs, output): + return + + @staticmethod + def backward(ctx, grad_output): + self.assertTrue(ctx.needs_input_grad[0]) + self.assertFalse(ctx.needs_input_grad[1]) + return None, None + + x = torch.tensor(2.0, device=device) + y = torch.tensor(3.0, device=device) + # grad differentiates w.r.t. 
arg 0 by default + grad(A.apply)(x, y) + grad(grad(A.apply))(x, y) + + def _get_NumpyCubeNotComposable(self): + class NumpyCubeNotComposable(torch.autograd.Function): + @staticmethod + def forward(input): + input_np = input.cpu().numpy() + return torch.tensor(input_np**3, device=input.device), input_np + + @staticmethod + def setup_context(ctx, inputs, output): + ctx.input_np = output[1] + ctx.device = inputs[0].device + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output, grad_saved): + result_np = 3 * (ctx.input_np**2) + return torch.tensor(result_np, device=ctx.device) + + return NumpyCubeNotComposable + + def test_once_differentiable_autograd_vjp(self, device): + NumpyCubeNotComposable = self._get_NumpyCubeNotComposable() + + def f(x): + y, _ = NumpyCubeNotComposable.apply(x) + return y + + # regular autograd x vjp + x = torch.randn([], requires_grad=True, device=device) + grad_y = torch.randn_like(x, requires_grad=True) + _, vjp_fn = vjp(f, x) + (gx,) = vjp_fn(grad_y) + + with self.assertRaisesRegex(RuntimeError, "marked with @once_differentiable"): + gx.backward() + + # TODO: support torch.autograd.function.once_differentiable + # (or, if impossible, figure out how to raise a nice error) + # https://github.com/pytorch/pytorch/issues/90224 + @unittest.expectedFailure + def test_once_differentiable_grad_vjp(self, device): + # grad x vjp + x = torch.randn([], device=device) + grad_y = torch.randn_like(x) + + def h(x, grad_y): + _, vjp_fn = vjp(f, x) # noqa: F821 + (gx,) = vjp_fn(grad_y) + return gx + + grad(h, argnums=(0, 1))(x, grad_y) + + def test_grad_fn_name(self, device): + names = [] + + class FooBar(torch.autograd.Function): + @staticmethod + def forward(x): + return x.clone() + + @staticmethod + def setup_context(ctx, inputs, output): + return + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + def f(x): + y = FooBar.apply(x) + names.append(type(y.grad_fn).__name__) + return y + + x = torch.tensor(1.0) + grad(f)(x) + self.assertEqual(names, ["FooBarGeneratedBackward"]) + + +@markDynamoStrictTest +class TestAutogradFunctionVmapAPI(TestCase): + def test_no_vmap_staticmethod_and_no_generate_vmap_rule(self, device): + class NumpyCube(torch.autograd.Function): + @staticmethod + def forward(input): + input_np = to_numpy(input) # noqa: F821 + dinput = torch.tensor(3 * input_np**2, device=input.device) + return torch.tensor(input_np**3, device=input.device), dinput + + @staticmethod + def setup_context(ctx, inputs, output): + ctx.save_for_backward(inputs, output[1]) + + @staticmethod + def backward(ctx, grad_output, grad_saved): + raise RuntimeError("foobar") + + x = torch.randn(3, device=device) + with self.assertRaisesRegex(RuntimeError, "does not have vmap support"): + vmap(NumpyCube.apply)(x) + + def test_has_vmap_staticmethod_and_has_generate_vmap_rule(self, device): + class NumpyCube(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(input): + input_np = to_numpy(input) # noqa: F821 + dinput = torch.tensor(3 * input_np**2, device=input.device) + return torch.tensor(input_np**3, device=input.device), dinput + + @staticmethod + def setup_context(ctx, outputs, input): + ctx.save_for_backward(input, outputs[1]) + + @staticmethod + def backward(ctx, grad_output, grad_saved): + raise RuntimeError("foobar") + + @staticmethod + def vmap(infos, in_dims, x): + raise RuntimeError("foobar") + + x = torch.randn(3, device=device) + with self.assertRaisesRegex(RuntimeError, 
"generate_vmap_rule=True and"): + vmap(NumpyCube.apply)(x) + + def test_info_object(self, device): + batch_size = 10 + + class Id(torch.autograd.Function): + @staticmethod + def forward(input): + pass + + @staticmethod + def setup_context(ctx, inputs, output): + pass + + @staticmethod + def backward(ctx, grad_output, grad_saved): + pass + + @staticmethod + def vmap(info, in_dims, input): + self.assertEqual(info.batch_size, batch_size) + self.assertEqual(info.randomness, randomness) + return input, in_dims[0] + + x = torch.randn(batch_size, 3, device=device) + + for randomness in ("error", "different", "same"): + vmap(Id.apply, randomness=randomness)(x) + + def test_in_dims_single_input(self, device): + class Id(torch.autograd.Function): + @staticmethod + def forward(input): + pass + + @staticmethod + def setup_context(ctx, inputs, output): + pass + + @staticmethod + def backward(ctx, grad_output, grad_saved): + pass + + @staticmethod + def vmap(info, in_dims, input): + self.assertEqual(in_dims, (1,)) + return input, in_dims[0] + + B = 10 + x = torch.randn(3, B, device=device) + vmap(Id.apply, in_dims=1)(x) + vmap(Id.apply, in_dims=(1,))(x) + + def test_in_dims_multiple_inputs(self, device): + class Id(torch.autograd.Function): + @staticmethod + def forward(x, y): + pass + + @staticmethod + def setup_context(ctx, inputs, output): + pass + + @staticmethod + def backward(ctx, grad_output, grad_saved): + pass + + @staticmethod + def vmap(info, in_dims, x, y): + self.assertEqual(in_dims, (0, [0, 0])) + self.assertTrue(isinstance(in_dims, tuple)) + self.assertTrue(isinstance(in_dims[1], list)) + return (x, y), in_dims + + x = torch.randn(2, device=device) + vmap(Id.apply)(x, [x, x]) + + def test_skips_empty_layer(self, device): + class Id(torch.autograd.Function): + @staticmethod + def forward(input): + return input + + @staticmethod + def setup_context(ctx, inputs, output): + pass + + @staticmethod + def backward(ctx, grad_output, grad_saved): + pass + + @staticmethod + def vmap(info, in_dims, input): + raise RuntimeError("expected to not be called") + + def f(x): + y = torch.tensor(1.0) + y = Id.apply(y) + return x * 1 + + x = torch.randn(2, 3) + vmap(f)(x) + + def test_none_returns(self, device): + class Zeros(torch.autograd.Function): + @staticmethod + def forward(input): + return torch.zeros(input.shape, device=input.device) + + @staticmethod + def setup_context(ctx, inputs, output): + pass + + @staticmethod + def vmap(info, in_dims, input): + assert in_dims == (0,) + return torch.zeros(input.shape[1:], device=input.device), None + + B = 2 + x = torch.randn(B, 3) + y = vmap(Zeros.apply)(x) + self.assertEqual(y, torch.zeros_like(x)) + + class TwoZeros(torch.autograd.Function): + @staticmethod + def forward(input): + r = torch.zeros(input.shape, device=input.device) + return r, r + + @staticmethod + def setup_context(ctx, inputs, output): + pass + + @staticmethod + def vmap(info, in_dims, input): + assert in_dims == (0,) + r = torch.zeros(input.shape[1:], device=input.device) + return (r, r), None + + B = 2 + x = torch.randn(B, 3) + result = vmap(TwoZeros.apply)(x) + + self.assertTrue(isinstance(result, tuple)) + y, z = result + self.assertEqual(y, torch.zeros_like(x)) + self.assertEqual(z, torch.zeros_like(x)) + + def test_should_have_two_returns(self, device): + class Zeros(torch.autograd.Function): + @staticmethod + def forward(input): + r = torch.zeros(input.shape, device=input.device) + return r + + @staticmethod + def setup_context(ctx, inputs, output): + pass + + @staticmethod + def 
vmap(info, in_dims, input): + r = torch.zeros(input.shape[1:], device=input.device) + return r + + B = 2 + x = torch.randn(B, 3) + with self.assertRaisesRegex(RuntimeError, "to have two returns"): + vmap(Zeros.apply)(x) + + class TwoZeros(torch.autograd.Function): + @staticmethod + def forward(input): + r = torch.zeros(input.shape, device=input.device) + return r, r + + @staticmethod + def setup_context(ctx, inputs, output): + pass + + @staticmethod + def vmap(info, in_dims, input): + r = torch.zeros(input.shape[1:], device=input.device) + return r, r, 0, 0 + + B = 2 + x = torch.randn(B, 3) + with self.assertRaisesRegex(RuntimeError, "to have two returns"): + vmap(Zeros.apply)(x) + + def test_incompatible_out_dims_error_msg(self, device): + class Zeros(torch.autograd.Function): + @staticmethod + def forward(input): + r = torch.zeros(input.shape, device=input.device) + return r + + @staticmethod + def setup_context(ctx, inputs, output): + pass + + @staticmethod + def vmap(info, in_dims, input): + r = torch.zeros(input.shape[1:], device=input.device) + return r, (None,) + + B = 2 + x = torch.randn(B, 3) + with self.assertRaisesRegex(RuntimeError, "returned an incompatible"): + vmap(Zeros.apply)(x) + + class Zeros(torch.autograd.Function): + @staticmethod + def forward(input): + r = torch.zeros(input.shape, device=input.device) + return [r] + + @staticmethod + def setup_context(ctx, inputs, output): + pass + + @staticmethod + def vmap(info, in_dims, input): + r = torch.zeros(input.shape[1:], device=input.device) + return [r], (None,) + + B = 2 + x = torch.randn(B, 3) + with self.assertRaisesRegex(RuntimeError, "returned an incompatible"): + vmap(Zeros.apply)(x) + + def test_kwarg_only_tensors(self, device): + with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): + + class MyClass(torch.autograd.Function): + @staticmethod + def forward(x, *, y): + return x + y + + @staticmethod + def setup_context(ctx, inputs, output): + pass + + @staticmethod + def vmap(info, in_dims, x, *, y): + assert in_dims == (0,) + return x + y, 0 + + x = torch.randn(3) + y = torch.randn(3) + + vmap(MyClass.apply)(x, y=y) + + +@markDynamoStrictTest +class TestVmapOfGrad(TestCase): + def test_per_sample_grads_inplace_view(self, device): + def compute_loss(weight, x, t): + x = x.mm(weight) + y = x.squeeze_(0) + return (y - t).sum() + + weight = torch.randn(16, 2, device=device) + x = torch.randn(64, 1, 16, device=device) + t = torch.randn(64, 2, device=device) + result = vmap(partial(grad(compute_loss), weight))(x, t) + expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)] + expected = torch.stack(expected) + # TODO: Check if the rtol is a problem + self.assertEqual(result, expected, atol=0, rtol=5e-4) + + def test_new_zeros_materializes_tensor(self, device): + N = 3 + C = 5 + + def foo(y, x): + result = x.new_zeros((C,)) + result.copy_(y) + return result.sum() + + x = torch.randn(N, device=device) + y = torch.randn(N, C, device=device) + result = vmap(grad(foo))(y, x) + self.assertEqual(result, torch.ones_like(y)) + + def test_new_empty_materializes_tensor(self, device): + N = 3 + C = 5 + + def foo(y, x): + result = x.new_empty((C,)) + result.copy_(y) + return result.sum() + + x = torch.randn(N, device=device) + y = torch.randn(N, C, device=device) + result = vmap(grad(foo))(y, x) + self.assertEqual(result, torch.ones_like(y)) + + def test_per_sample_grads_simple(self, device): + def compute_loss(weight, x, t): + y = x @ weight + return ((y - t) ** 2).sum() + + weight = 
torch.randn(16, 2, device=device) + x = torch.randn(64, 16, device=device) + t = torch.randn(64, 2, device=device) + result = vmap(partial(grad(compute_loss), weight))(x, t) + expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)] + expected = torch.stack(expected) + # TODO: Check if the rtol is a problem + self.assertEqual(result, expected, atol=0, rtol=5e-4) + + def _compare_expected_and_result(self, expected, result, mechanism): + if mechanism == "make_functional": + expected = zip(*expected) + expected = tuple(torch.stack(shards) for shards in expected) + for r, e in zip(result, expected): + self.assertEqual(r, e, atol=0, rtol=1.5e-3) + else: + assert mechanism == "functional_call" + expected = { + k: tuple(d[k] for d in expected) for k, v in expected[0].items() + } + expected = {k: torch.stack(shards) for k, shards in expected.items()} + for key in result: + self.assertEqual(result[key], expected[key], atol=0, rtol=1.5e-3) + + @tf32_on_and_off(0.005) + @parametrize("mechanism", ["make_functional", "functional_call"]) + def test_per_sample_grads_embeddingnet(self, device, mechanism): + class SampleNet(nn.Module): + def __init__(self, vocab_size: int): + super().__init__() + self.emb = nn.Embedding(vocab_size, 16) + self.fc1 = nn.Linear(16, 16) + self.fc2 = nn.Linear(16, 2) + + def forward(self, x): + x = self.emb(x) + x = torch.transpose(x, -1, -2) + x = torch.mean(x, -1) + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + return x + + def name(self): + return "SampleNet" + + # Create our inputs... + vocab_size = 1000 + batch_shape = [64] + words_per_sentence = 5 + data = torch.randint( + 0, vocab_size, (*batch_shape, words_per_sentence), device=device + ) + targets = torch.randint(0, 1, (*batch_shape,), device=device) + + # Construct our module + net = SampleNet(vocab_size).to(device=device) + criterion = nn.CrossEntropyLoss() + + net_func, weights = _get_weights_and_functional_call(net, mechanism) + + def compute_loss(weights, data, target): + output = net_func(weights, data) + result = criterion(output, target) + return result + + expected = [grad(compute_loss)(weights, data[i], targets[i]) for i in range(64)] + result = vmap(partial(grad(compute_loss), weights))(data, targets) + self._compare_expected_and_result(expected, result, mechanism) + + def test_log_softmax(self, device): + x = torch.randn(3, 5, device=device) + v = torch.randn(5, device=device) + + def foo(x, v): + _, vjp_fn = vjp(partial(torch.log_softmax, dim=-1), x) + return vjp_fn(v)[0] + + result = vmap(foo, (0, None))(x, v) + + v = v.expand_as(x) + x.requires_grad_() + output = torch.log_softmax(x, dim=-1) + output.backward(v) + self.assertEqual(result, x.grad) + + +jacrev_and_jacfwd = parametrize( + "jacapi", [subtest(jacrev, name="jacrev"), subtest(jacfwd, name="jacfwd")] +) + +FIXME_jacrev_only = parametrize("jacapi", [subtest(jacrev, name="jacrev")]) + + +@markDynamoStrictTest +class TestJac(VmapTearDownMixin, TestCase): + @jacrev_and_jacfwd + def test_simple(self, device, jacapi): + x = torch.randn(3, device=device) + y = jacapi(torch.sin)(x) + expected = torch.diagflat(x.cos()) + assert torch.allclose(y, expected) + + @jacrev_and_jacfwd + def test_simple_not_flat(self, device, jacapi): + x = torch.randn(2, 3, device=device) + y = jacapi(torch.sin)(x) + expected = torch.diagflat(x.view(-1).cos()) + expected = expected.view(2, 3, 2, 3) + assert torch.allclose(y, expected) + + @jacrev_and_jacfwd + def test_take(self, device, jacapi): + x = torch.rand(5) + + def func(x): + y = torch.ones(3, 
dtype=torch.long) + z = torch.take(x, y) + return z + + self.assertEqual(jacrev(func)(x), torch.autograd.functional.jacobian(func, x)) + + @jacrev_and_jacfwd + def test_diff_numel(self, device, jacapi): + x = torch.randn(2, 4, device=device) + + # Tensor[2, 4] -> Tensor[3, 1] + def f(x): + return x[0, 1:].unsqueeze(-1) + + y = jacapi(f)(x) + self.assertEqual(y.shape, (3, 1, 2, 4)) + + expected = x.new_zeros(3, 1, 2, 4) + expected[0, 0, 0, 1] = 1 + expected[1, 0, 0, 2] = 1 + expected[2, 0, 0, 3] = 1 + self.assertEqual(y, expected) + + @jacrev_and_jacfwd + def test_vmap_on_jac_simple(self, device, jacapi): + x = torch.randn(2, 3, device=device) + y = vmap(jacapi(torch.sin))(x) + expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)]) + assert torch.allclose(y, expected) + + @jacrev_and_jacfwd + def test_nested_jac_simple(self, device, jacapi): + def foo(x): + return x.sin().sum() + + x = torch.randn(3, device=device) + y = jacapi(jacapi(foo))(x) + expected = torch.diagflat(-x.sin()) + assert torch.allclose(y, expected) + + @jacrev_and_jacfwd + def test_multiple_args(self, device, jacapi): + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + z = jacapi(torch.multiply, argnums=1)(x, y) + expected = torch.diagflat(x) + assert torch.allclose(z, expected) + + @jacrev_and_jacfwd + def test_multiple_outputs_multiple_argnums(self, device, jacapi): + def f(x, y): + return 2 * x + 3 * y, 4 * x + 5 * y + + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + z = jacapi(f, argnums=(0, 1))(x, y) + expected_out0_x = torch.diagflat(torch.full_like(x, 2)) + expected_out0_y = torch.diagflat(torch.full_like(y, 3)) + expected_out1_x = torch.diagflat(torch.full_like(x, 4)) + expected_out1_y = torch.diagflat(torch.full_like(y, 5)) + + self.assertEqual(len(z), 2) + self.assertTrue(isinstance(z, tuple)) + self.assertEqual(len(z[0]), 2) + self.assertTrue(isinstance(z[0], tuple)) + self.assertEqual(z[0][0], expected_out0_x) + self.assertEqual(z[0][1], expected_out0_y) + self.assertEqual(z[1][0], expected_out1_x) + self.assertEqual(z[1][1], expected_out1_y) + + @jacrev_and_jacfwd + def test_multiple_outputs_single_argnums(self, device, jacapi): + def f(x, y): + return 2 * x + 3 * y, 4 * x + 5 * y + + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + expected_out0_x = torch.diagflat(torch.full_like(x, 2)) + expected_out1_x = torch.diagflat(torch.full_like(x, 4)) + + z = jacapi(f, argnums=0)(x, y) + self.assertEqual(len(z), 2) + self.assertTrue(isinstance(z, tuple)) + self.assertEqual(z, (expected_out0_x, expected_out1_x)) + + z = jacapi(f, argnums=(0,))(x, y) + self.assertEqual(len(z), 2) + self.assertTrue(isinstance(z, tuple)) + self.assertTrue(isinstance(z[0], tuple)) + self.assertEqual(z, ((expected_out0_x,), (expected_out1_x,))) + + @jacrev_and_jacfwd + def test_multiple_outputs_pytree(self, device, jacapi): + def f(x, y): + return {"left": 2 * x + 3 * y, "right": 4 * x + 5 * y} + + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + z = jacapi(f, argnums=(0, 1))(x, y) + expected_left_x = torch.diagflat(torch.full_like(x, 2)) + expected_left_y = torch.diagflat(torch.full_like(y, 3)) + expected_right_x = torch.diagflat(torch.full_like(x, 4)) + expected_right_y = torch.diagflat(torch.full_like(y, 5)) + expected = { + "left": (expected_left_x, expected_left_y), + "right": (expected_right_x, expected_right_y), + } + self.assertTrue(isinstance(z, dict)) + self.assertTrue(isinstance(z["left"], tuple)) + 
self.assertTrue(isinstance(z["right"], tuple)) + self.assertEqual(z, expected) + + @jacrev_and_jacfwd + def test_multiple_inputs_pytree(self, device, jacapi): + def f(a, b, c): + a0, a1 = a + return a0 + a1 * 2 + b * 3 + c * 4 + + x = torch.randn([], device=device) + args = ((x, x), x, x) + + result = jacapi(f, argnums=(0, 1, 2))(*args) + expected = ( + (torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)), + torch.tensor(3.0, device=device), + torch.tensor(4.0, device=device), + ) + self.assertEqual(result, expected) + + result = jacapi(f, argnums=(0,))(*args) + expected = ( + (torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)), + ) + self.assertEqual(result, expected) + + result = jacapi(f)(*args) + expected = (torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)) + self.assertEqual(result, expected) + + @jacrev_and_jacfwd + def test_dimensionality(self, device, jacapi): + def f(x): + return x + + x = torch.randn([], device=device) + result = jacapi(f)(x) + self.assertEqual(result.dim(), 0) + self.assertEqual(result, torch.ones_like(x)) + + x = torch.randn([1], device=device) + result = jacapi(f)(x) + self.assertEqual(result.dim(), 2) + self.assertEqual(result, x.new_ones(1, 1)) + + @jacrev_and_jacfwd + def test_aux_tensor(self, device, jacapi): + def f(x): + y = x.clone() + return y, y.cos() + + x = torch.randn(3, device=device) + result, aux = jacapi(f, has_aux=True)(x) + + self.assertEqual(result, torch.eye(3, 3, device=device)) + self.assertEqual(aux, x.cos()) + + @jacrev_and_jacfwd + def test_aux_pytree(self, device, jacapi): + def f(x): + y = x.clone() + return y, {"a": y.cos(), "b": [y.tan()]} + + x = torch.randn(3, device=device) + + result, aux = jacapi(f, has_aux=True)(x) + self.assertEqual(result, torch.eye(3, 3, device=device)) + _, expected_aux = f(x) + self.assertEqual(aux, expected_aux) + + for aux in [1, 1.0, "abc"]: + with self.assertRaisesRegex( + RuntimeError, r"Expected tensors, got unsupported type" + ): + _ = jacapi(lambda x: (x, aux), has_aux=True)(x) + with self.assertRaisesRegex( + RuntimeError, r"Expected tensors, got unsupported type" + ): + _ = jacapi(lambda x: (x, [x, aux]), has_aux=True)(x) + + @jacrev_and_jacfwd + def test_outputs_can_any_pytree(self, device, jacapi): + x = torch.randn(2, 3, device=device) + + for output in [None, ()]: + with self.assertRaisesRegex( + RuntimeError, + r"(vjp|jvp).+: Expected f to be a function that has non-empty output", + ): + jacapi(lambda _: output)(x) + + for output in [1, True, 12.2, "abc"]: + with self.assertRaisesRegex( + RuntimeError, + r"(vjp|jvp).+: expected f\(\*primals\) to return only tensors", + ): + jacapi(lambda _: output)(x) + + # Check list output + out = jacapi(lambda x: [x, x.sum()])(x) + assert isinstance(out, list) and len(out) == 2 + + # Check dict output + out = jacapi(lambda x: {"x": x, "xsum": x.sum()})(x) + assert isinstance(out, dict) and len(out) == 2 and "xsum" in out + + def composite_output(x): + out = x.sum() + return [ + (out, {"a": x, "out": [x, out]}), + ] + + out = jacapi(composite_output)(x) + assert isinstance(out, list) + assert isinstance(out[0], tuple) and isinstance(out[0][1], dict) + + @jacrev_and_jacfwd + def test_multiple_inputs_outputs_pytree(self, device, jacapi): + def f(a, b, c): + a0, a1 = a + return a0 + a1 * 2, {"foo": b * 3 + c * 4} + + x = torch.randn([], device=device) + zero = torch.zeros([], device=device) + args = ((x, x), x, x) + + result = jacapi(f)(*args) + expected = ( + (torch.tensor(1.0, device=device), 
torch.tensor(2.0, device=device)), + {"foo": (zero, zero)}, + ) + self.assertEqual(result, expected) + + result = jacapi(f, argnums=(0,))(*args) + expected = ( + ((torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)),), + {"foo": ((zero, zero),)}, + ) + self.assertEqual(result, expected) + + result = jacapi(f, argnums=(0, 1))(*args) + expected = ( + ( + (torch.tensor(1.0, device=device), torch.tensor(2.0, device=device)), + zero, + ), + {"foo": ((zero, zero), torch.tensor(3.0, device=device))}, + ) + self.assertEqual(result, expected) + + @jacrev_and_jacfwd + def test_multiple_inputs_outputs_pytree_multidim(self, device, jacapi): + def f(dct): + a = dct["a"] + b = dct["b"] + return {"c": a.sin(), "d": b.cos()} + + x = torch.randn(3, device=device) + args = ({"a": x, "b": x},) + + result = jacapi(f)(*args) + expected = { + "c": {"a": x.cos().diagflat(), "b": x.new_zeros(3, 3)}, + "d": {"a": x.new_zeros(3, 3), "b": -x.sin().diagflat()}, + } + self.assertEqual(result, expected) + + @jacrev_and_jacfwd + def test_unrelated_input(self, device, jacapi): + def f(x, y): + return x + + x = torch.randn(2, 3, device=device) + y = torch.randn(2, 3, device=device) + + result = jacapi(f, argnums=(0, 1))(x, y) + expected0 = torch.eye(6, 6, device=device).view(2, 3, 2, 3) + expected1 = y.new_zeros(2, 3, 2, 3) + expected = (expected0, expected1) + self.assertTrue(isinstance(result, tuple)) + self.assertEqual(result, expected) + + @jacrev_and_jacfwd + def test_unrelated_output(self, device, jacapi): + y = torch.randn(2, 3, device=device) + + def f(x): + return y + + x = torch.randn(2, 3, device=device) + + result = jacapi(f)(x) + expected = x.new_zeros(2, 3, 2, 3) + self.assertEqual(result, expected) + + @jacrev_and_jacfwd + def test_empty_output(self, device, jacapi): + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + + def f(x, y): + return () + + with self.assertRaisesRegex(RuntimeError, "xpected"): + jacapi(f)(x, y) + + @jacrev_and_jacfwd + def test_argnums_tuple(self, device, jacapi): + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + z = jacapi(torch.multiply, argnums=(0, 1))(x, y) + expected0 = torch.diagflat(y) + expected1 = torch.diagflat(x) + assert len(z) == 2 + assert torch.allclose(z[0], expected0) + assert torch.allclose(z[1], expected1) + + @jacrev_and_jacfwd + def test_argnums_effect_on_return(self, device, jacapi): + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + z = jacapi(torch.multiply, argnums=(0,))(x, y) + expected0 = torch.diagflat(y) + assert isinstance(z, tuple) + assert len(z) == 1 + assert torch.allclose(z[0], expected0) + + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + z = jacapi(torch.multiply, argnums=0)(x, y) + expected0 = torch.diagflat(y) + assert isinstance(z, torch.Tensor) + assert torch.allclose(z, expected0) + + @jacrev_and_jacfwd + def test_argnums_defaults_to_zero(self, device, jacapi): + def f(x, y): + return x * 2 + y * 3 + + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + z = jacapi(f)(x, y) + expected = torch.diagflat(torch.full_like(x, 2)) + self.assertEqual(z, expected) + + @jacrev_and_jacfwd + def test_empty_argnums(self, device, jacapi): + x = torch.randn(3, device=device) + with self.assertRaisesRegex(RuntimeError, "must be non-empty"): + jacapi(torch.sin, argnums=())(x) + + @jacrev_and_jacfwd + def test_out_of_bounds_argnums(self, device, jacapi): + x = torch.randn(3, device=device) + with self.assertRaisesRegex(RuntimeError, 
"only 1 positional inputs"): + jacapi(torch.sin, argnums=2)(x) + + @jacrev_and_jacfwd + def test_negative_argnums(self, device, jacapi): + x = torch.randn(3, device=device) + with self.assertRaisesRegex(RuntimeError, "only 1 positional inputs"): + jacapi(torch.sin, argnums=-2)(x) + + @jacrev_and_jacfwd + def test_repeated_argnums(self, device, jacapi): + x = torch.randn(3, device=device) + with self.assertRaisesRegex(RuntimeError, "must be unique"): + jacapi(torch.sin, argnums=(0, 0))(x) + + @jacrev_and_jacfwd + def test_float_argnums(self, device, jacapi): + x = torch.randn(3, device=device) + with self.assertRaisesRegex(RuntimeError, "must be int or Tuple"): + jacapi(torch.sin, argnums=0.0)(x) + with self.assertRaisesRegex(RuntimeError, "must be int"): + jacapi(torch.multiply, argnums=(1, 0.0))(x, x) + + def test_hessian_simple(self, device): + def f(x): + return x.sin() + + x = torch.randn(3, device=device) + hessian(f)(x) + + def _test_against_reference(self, f, inputs, jacapi): + def foo(inputs): + return f(*inputs) + + expected = torch.autograd.functional.jacobian(f, inputs) + result = jacapi(foo)(inputs) + self.assertEqual(result, expected) + + @jacrev_and_jacfwd + def test_against_reference_simple(self, device, jacapi): + def f(x): + return 3 * x**2 + + x = torch.randn(2, 3, 5, device=device) + self._test_against_reference(f, (x,), jacapi) + + @jacrev_and_jacfwd + def test_against_reference_multi_input(self, device, jacapi): + def f(x, y): + return (x.cos() * x) @ y.sin() + + x = torch.randn(2, 3, device=device) + y = torch.randn(3, 5, device=device) + self._test_against_reference(f, (x, y), jacapi) + + @jacrev_and_jacfwd + def test_against_reference_multi_input_multi_output(self, device, jacapi): + def f(x, y): + return (x * x) @ y, x @ (x.sum(1) * y), y.sum() + + x = torch.randn(5, 3, device=device) + y = torch.randn(3, 5, device=device) + self._test_against_reference(f, (x, y), jacapi) + + @jacrev_and_jacfwd + def test_against_reference_unrelated_outputs(self, device, jacapi): + def f(x, y): + return x, y, x, y + + x = torch.randn(2, device=device) + y = torch.randn(3, device=device) + self._test_against_reference(f, (x, y), jacapi) + + @jacrev_and_jacfwd + def test_against_reference_zero_dim(self, device, jacapi): + # zero-dim output + def f(x, y): + return x.sum(), y.sum(), x * y + + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + self._test_against_reference(f, (x, y), jacapi) + + # zero-dim input + def g(x): + return torch.stack([x, x, x]) + + x = torch.randn([], device=device) + self._test_against_reference(g, (x,), jacapi) + + # Mixed zero-dim input / zero-dim output + def h(x, y): + return y.sum(), x * y + + x = torch.randn([], device=device) + y = torch.randn(1, device=device) + self._test_against_reference(h, (x, y), jacapi) + + @jacrev_and_jacfwd + def test_against_reference_correctness_different_devices(self, device, jacapi): + def f(x, y): + return x * y, (x * y).to(device=device) + + x = torch.randn(3) + y = torch.randn(3) + self._test_against_reference(f, (x, y), jacapi) + + @jacrev_and_jacfwd + def test_against_reference_default_arg(self, device, jacapi): + def f(x, y, z=3.0): + return x * y * z + + x = torch.randn(3, device=device) + y = torch.randn(3, device=device) + self._test_against_reference(f, (x, y), jacapi) + + @jacrev_and_jacfwd + def test_inplace(self, device, jacapi): + def f(x, y): + y.copy_(x) + return y + + out = jacapi(f, argnums=0) # x is differentiable + x, y = torch.randn(2, device=device), torch.randn(2, device=device) 
+ self.assertEqual(out(x, y), torch.eye(y.shape[0])) + + # testing tuple of argnums with the example that raised this issue originally + def g(x, y, z): + x[:2] = y + return torch.vstack([(x**2).sum(), (z**3).sum()]) + + out = jacapi(g, argnums=(1, 2)) + x, y, z = ( + torch.randn(3, device=device), + torch.randn(2, device=device), + torch.randn(2, device=device), + ) + + expected_out = ( + torch.zeros(2, 1, 2, device=device), + torch.zeros(2, 1, 2, device=device), + ) + expected_out[0][0][0] = 2 * y # top left corner + expected_out[1][1][0] = 3 * (z**2) # bottom right corner + + out_val = out(x, y, z) + self.assertEqual(out_val, expected_out) + + @parametrize("_preallocate_and_copy", (True, False)) + def test_chunk_jacrev(self, device, _preallocate_and_copy): + x = torch.randn(10, 2, device=device) + y = torch.randn(1, 2, device=device) + + def f(x, y): + return (x.sin(), x + y), (x + 2, x.sum()) + + for chunk_size in (1, 2, 3, 4, 7, 10, 1000): + expected = jacrev(f, argnums=(0, 1))(x, y) + actual = jacrev( + f, + argnums=(0, 1), + chunk_size=chunk_size, + _preallocate_and_copy=_preallocate_and_copy, + )(x, y) + self.assertEqual(actual, expected) + + err_msg = "jacrev: `chunk_size` should be greater than 0." + with self.assertRaisesRegex(ValueError, err_msg): + jacrev(f, argnums=(0,), chunk_size=0)(x, y) + + with self.assertRaisesRegex(ValueError, err_msg): + jacrev(f, argnums=(0,), chunk_size=-2)(x, y) + + @parametrize("_preallocate_and_copy", (True, False)) + def test_chunk_jacrev_composition(self, device, _preallocate_and_copy): + x = torch.randn(10, 2, device=device) + chunk_size = 3 + + def f(x): + return (x.sin(), x), (x + 2, x.sum()) + + expected = vmap(jacrev(jacrev(f)))(x) + actual = vmap( + jacrev( + jacrev( + f, + chunk_size=chunk_size, + _preallocate_and_copy=_preallocate_and_copy, + ), + chunk_size=chunk_size, + ) + )(x) + self.assertEqual(actual, expected) + + # https://github.com/pytorch/pytorch/issues/127036 + @xfailIfTorchDynamo + @parametrize("_preallocate_and_copy", (True, False)) + def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy): + # With chunk_size=1, we shouldn't `vmap` and hence not be limited + # by it's constraints. + x = torch.randn(3, 3, device=device) + + # Function with Dynamic Op in Backward. + # This should cause jacrev/vmap(vjp) to fail. + class IdentityWithDynamicBackwardOp(torch.autograd.Function): + @staticmethod + def forward(input): + return input + + @staticmethod + def setup_context(ctx, inputs, output): + pass + + @staticmethod + def backward(ctx, grad_output): + # dynamic op in backward pass. + grad_output.nonzero() + return grad_output + + def f(x): + return IdentityWithDynamicBackwardOp.apply(x) + + # With `chunk_size=1`, we don't use vmap. So the following should work. + jacfn = jacrev(f, chunk_size=1, _preallocate_and_copy=_preallocate_and_copy) + actual = jacfn(x) + expected = torch.autograd.functional.jacobian(f, x, vectorize=False) + self.assertEqual(actual, expected) + + # Should fail with `chunk_size=2`. + msg = ( + r"vmap: We do not support batching operators that can output dynamic shape." 
+ ) + with self.assertRaisesRegex(RuntimeError, msg): + jacrev(f, chunk_size=2, _preallocate_and_copy=_preallocate_and_copy)(x) + + def test_complex_error(self, device): + # Verify complex input raises error + # C -> C + def fn(x): + return x.conj() + + x = torch.randn(1, device=device, dtype=torch.cfloat) + + with self.assertRaisesRegex(RuntimeError, "jacrev: Expected all inputs"): + jacrev(fn)(x) + + with self.assertRaisesRegex(RuntimeError, "jacfwd: Expected all inputs"): + jacfwd(fn)(x) + + # Verify complex output raises error + # R -> C + def fn(x): + return torch.conj(x * 0.5j) + + x = torch.randn(1, device=device, dtype=torch.float) + + with self.assertRaisesRegex(RuntimeError, "jacrev: Expected all outputs"): + jacrev(fn)(x) + + with self.assertRaisesRegex(RuntimeError, "jacfwd: Expected all outputs"): + jacfwd(fn)(x) + + @jacrev_and_jacfwd + def test_jac_with_non_tensor_args(self, device, jacapi): + def f(t, int_x): + return t + int_x + + t = torch.randn(3, 3, device=device) + + actual = jacapi(f)(t, 3) + expected = torch.autograd.functional.jacobian(partial(f, int_x=3), t) + self.assertEqual(actual, expected) + + +@markDynamoStrictTest +class TestHessian(TestCase): + def _test_against_reference(self, f, inputs): + def foo(inputs): + return f(*inputs) + + expected = torch.autograd.functional.hessian(f, inputs) + result = hessian(foo)(inputs) + self.assertEqual(result, expected) + + def test_hessian_vectorize_correctness_simple(self, device): + def f(x): + return (3 * x**2).sum() + + x = torch.randn(2, 3, 5, device=device) + self._test_against_reference(f, (x,)) + + def test_hessian_vectorize_correctness_multi_input(self, device): + def f(x, y, z): + return ((x.relu() * x) @ y.sin() @ z).sum() + + x = torch.randn(2, 3, device=device) + y = torch.randn(3, 5, device=device) + z = torch.randn(5, 5, device=device) + self._test_against_reference(f, (x, y, z)) + + def test_hessian_vectorize_correctness_unrelated_outputs(self, device): + # output unrelated to one input + def f(x, y): + return (x**2).sum() + + x = torch.randn(2, device=device) + y = torch.randn(3, device=device) + self._test_against_reference(f, (x, y)) + + # output unrelated to all inputs + def f(x, y): + return torch.ones([]) + + x = torch.randn(2, device=device) + y = torch.randn(3, device=device) + self._test_against_reference(f, (x, y)) + + def test_jacfwd_different_levels(self, device): + # Test case from: + # https://github.com/pytorch/functorch/issues/597 + b = 8 + n = 100 + d = 2 + x1 = torch.randn(b, n, d, device=device) + x2 = x1 + A = 0.1 * torch.randn(b, d, d, device=device) + + def loss(A, x1, x2): + x2_hat = (A @ (x1.T)).T + res = x2 - x2_hat + res_sqr = res**2 + return res_sqr.sum() + + hess1 = vmap(jacrev(jacrev(loss)))(A, x1, x2) + hess2 = vmap(hessian(loss))(A, x1, x2) + self.assertEqual(hess2, hess1) + + +@markDynamoStrictTest +class TestJvp(TestCase): + def test_inplace_on_captures(self, device): + x = torch.tensor([1.0, 2.0, 3.0], device=device) + captured = torch.randn(3, device=device) + + def foo(x): + captured.copy_(x) + return (x * captured).sum() + + with self.assertRaisesRegex(RuntimeError, "mutate a captured Tensor"): + grad(foo)(x) + + def test_simple(self, device): + x = torch.randn(2, 3, device=device) + t = torch.randn(2, 3, device=device) + result = jvp(torch.sin, (x,), (t,)) + expected = (x.sin(), x.cos() * t) + self.assertTrue(isinstance(result, tuple)) + self.assertEqual(result, expected) + + def test_multiple_inputs(self, device): + x = torch.randn(2, 3, device=device) + y = 
torch.randn(2, 3, device=device) + tx = torch.randn(2, 3, device=device) + ty = torch.randn(2, 3, device=device) + + def f(x, y): + return x * y + + result = jvp(f, (x, y), (tx, ty)) + expected = (x * y, y * tx + x * ty) + self.assertTrue(isinstance(result, tuple)) + self.assertEqual(result, expected) + + def test_pytree_inputs(self, device): + def f(x, y, z): + a, b = x + return a + 2 * b + 3 * y + 4 * z + + one = torch.tensor(1.0, device=device) + primal_outs, tangent_outs = jvp( + f, ((one, one), one, one), ((one, one), one, one) + ) + self.assertEqual(primal_outs, one * 10) + self.assertEqual(tangent_outs, one * 10) + + def test_pytree_inputs_error_cases(self, device): + def f(x): + return x + + one = torch.tensor(1.0, device=device) + + with self.assertRaisesRegex(RuntimeError, "Expected primals to be a tuple"): + jvp(f, one, one) + with self.assertRaisesRegex(RuntimeError, "same python structure"): + jvp(f, ((one, one), one), (one, one)) + with self.assertRaisesRegex(RuntimeError, "only contain Tensors"): + jvp(f, ((one, one), 1), ((one, one), one)) + with self.assertRaisesRegex(RuntimeError, "only contain Tensors"): + jvp(f, ((one, one), 1), ((1, one), one)) + with self.assertRaisesRegex(RuntimeError, "at least one Tensor"): + jvp(f, ((),), ((),)) + + def test_unrelated_input(self, device): + def f(x, y): + return x + + x = torch.randn(2, 3, device=device) + y = torch.randn(2, 3, device=device) + tx = torch.randn(2, 3, device=device) + ty = torch.randn(2, 3, device=device) + + result = jvp(f, (x, y), (tx, ty)) + expected = (x, tx) + self.assertTrue(isinstance(result, tuple)) + self.assertEqual(result, expected) + + def test_unrelated_output(self, device): + y = torch.randn(2, 3, device=device) + + def f(x): + return y + + x = torch.randn(2, 3, device=device) + tx = torch.randn(2, 3, device=device) + + result = jvp(f, (x,), (tx,)) + expected = (y, torch.zeros_like(y)) + self.assertTrue(isinstance(result, tuple)) + self.assertEqual(result, expected) + + def test_strict_mode(self, device): + y = torch.randn(2, 3, device=device) + + def f(x): + return x, y + + x = torch.randn(2, 3, device=device) + tx = torch.randn(2, 3, device=device) + + with self.assertRaisesRegex(RuntimeError, "strict"): + jvp(f, (x,), (tx,), strict=True) + + def test_multiple_outputs(self, device): + x = torch.randn(2, 3, device=device) + t = torch.randn(2, 3, device=device) + + def f(x): + return torch.sin(x), torch.cos(x) + + result = jvp(f, (x,), (t,)) + expected = (f(x), (x.cos() * t, -x.sin() * t)) + self.assertTrue(isinstance(result, tuple)) + self.assertEqual(result, expected) + + def test_multiple_inputs_outputs(self, device): + x = torch.randn(2, 3, device=device) + y = torch.randn(2, 3, device=device) + tx = torch.randn(2, 3, device=device) + ty = torch.randn(2, 3, device=device) + + def f(x, y): + return 2 * x + 3 * y, 4 * x + 5 * y + + result = jvp(f, (x, y), (tx, ty)) + expected = (f(x, y), f(tx, ty)) + self.assertTrue(isinstance(result, tuple)) + self.assertEqual(result, expected) + + def test_jvp_new_tensor(self): + def f(x): + y = x.new_tensor(0.5) + return x + y + + x = torch.rand(10, 10) + tangents = torch.zeros_like(x) + actual = jvp(f, (x,), (tangents,)) + expected = (f(x), torch.zeros_like(x)) + self.assertEqual(actual, expected) + + def test_primals_tangents_length_mismatch(self, device): + x = torch.randn(2, 3, device=device) + t = torch.randn(2, 3, device=device) + + msg = "same python structure" + with self.assertRaisesRegex(RuntimeError, msg): + jvp(torch.sin, (x,), (t, t)) + with 
self.assertRaisesRegex(RuntimeError, msg): + jvp(torch.sin, (x, x), (t, t, t)) + + def test_nonempty_primals_and_tangents(self, device): + with self.assertRaisesRegex(RuntimeError, "at least one Tensor"): + jvp(torch.sin, (), ()) + + def test_inputs_are_tuples_of_tensors(self, device): + x = torch.randn(2, 3, device=device) + t = torch.randn(2, 3, device=device) + + with self.assertRaisesRegex(RuntimeError, "be a tuple"): + jvp(torch.sin, x, (t,)) + with self.assertRaisesRegex(RuntimeError, "same python structure"): + jvp(torch.sin, (x,), t) + with self.assertRaisesRegex(RuntimeError, "same python structure"): + jvp(torch.sin, (x,), [t]) + with self.assertRaisesRegex(RuntimeError, "only contain Tensors"): + jvp(torch.sin, (1.0,), (t,)) + with self.assertRaisesRegex(RuntimeError, "only contain Tensors"): + jvp(torch.sin, (x,), (1.0,)) + + def test_outputs_can_any_pytree(self, device): + x = torch.randn(2, 3, device=device) + t = torch.randn(2, 3, device=device) + + for output in [None, ()]: + with self.assertRaisesRegex( + RuntimeError, + r"jvp\(f, primals, tangents\): Expected f to be a function that has non-empty output", + ): + jvp(lambda _: output, (x,), (t,)) + + for output in [1, True, 12.2, "abc"]: + with self.assertRaisesRegex( + RuntimeError, + r"jvp\(f, primals, tangents\): expected f\(\*primals\) to return only tensors", + ): + jvp(lambda _: output, (x,), (t,)) + + # Check list output + out = jvp(lambda x: [x, x.sum()], (x,), (t,)) + for i in range(2): + assert isinstance(out[i], list) and len(out[i]) == 2 + + # Check dict output + out = jvp(lambda x: {"x": x, "xsum": x.sum()}, (x,), (t,)) + for i in range(2): + assert isinstance(out[i], dict) and len(out[i]) == 2 and "xsum" in out[i] + + def composite_output(x): + out = x.sum() + return [ + (out, {"a": x, "out": [x, out]}), + ] + + out = jvp(composite_output, (x,), (t,)) + for i in range(2): + assert isinstance(out[i], list) + assert isinstance(out[i][0], tuple) and isinstance(out[i][0][1], dict) + + def test_aux_tensor(self, device): + x = torch.randn(3, device=device) + t = torch.randn(3, device=device) + + with self.assertRaisesRegex( + RuntimeError, + r"jvp\(f, primals, tangents\): output of function f should be a tuple", + ): + jvp(lambda t: [t, t], (x,), (t,), has_aux=True) + + with self.assertRaisesRegex( + RuntimeError, + r"jvp\(f, primals, tangents\): output of function f should be a tuple", + ): + jvp(lambda t: (t, t + 2, t + 3), (x,), (t,), has_aux=True) + + def f(z): + y = z.sin() + return y, z.cos() + + out, jvp_out, aux = jvp(f, (x,), (t,), has_aux=True) + self.assertEqual(aux, x.cos()) + self.assertEqual(out, x.sin()) + self.assertEqual(jvp_out, t * x.cos()) + + def test_aux_pytree(self, device): + def f(x): + y = x.sin() + return y, {"a": x.cos(), "b": [x.tan()]} + + x = torch.randn(3, device=device) + t = torch.randn(3, device=device) + + out, jvp_out, aux = jvp(f, (x,), (t,), has_aux=True) + expected_out, expected_aux = f(x) + self.assertEqual(out, expected_out) + self.assertEqual(aux, expected_aux) + self.assertEqual(jvp_out, t * x.cos()) + + for aux in [1, 1.0, "abc"]: + with self.assertRaisesRegex( + RuntimeError, r"Expected tensors, got unsupported type" + ): + _ = jvp(lambda x: (x, aux), (x,), (t,), has_aux=True) + with self.assertRaisesRegex( + RuntimeError, r"Expected tensors, got unsupported type" + ): + _ = jvp(lambda x: (x, [x, aux]), (x,), (t,), has_aux=True) + + def test_autograd_function_disables_fwd_grad(self, device): + # Sanity check. 
We don't really assume this anywhere so + # it's fine if this breaks one day. + class MySquare(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + enabled = fwAD._is_fwd_grad_enabled() + self.assertFalse(enabled) + return x * x + + @staticmethod + def backward(ctx, gx): + return gx + + x = torch.randn(3, requires_grad=True) + MySquare.apply(x) + + def test_disable_fwd_grad_outside(self, device): + x = torch.randn([], device=device) + t = torch.ones_like(x) + with fwAD._set_fwd_grad_enabled(False): + _, y = jvp(torch.sin, (x,), (t,)) + self.assertEqual(y, x.cos()) + + def test_disable_fwd_grad_inside(self, device): + def f(x): + with fwAD._set_fwd_grad_enabled(False): + shift = x**2 + return x**2 - shift + + x = torch.randn([], device=device) + t = torch.ones_like(x) + _, y = jvp(f, (x,), (t,)) + self.assertEqual(y, 2 * x) + _, y = jvp(lambda x: jvp(f, (x,), (t,))[1], (x,), (t,)) + self.assertEqual(y, 2) + + def test_disable_fwd_grad_mixed(self, device): + def f(x): + with fwAD._set_fwd_grad_enabled(False): + shift = x**2 + return x**2 - shift + + x = torch.randn([], device=device) + t = torch.ones_like(x) + with fwAD._set_fwd_grad_enabled(True): + _, y = jvp(f, (x,), (t,)) + + self.assertEqual(y, 2 * x) + + def test_jvp_inside_autograd_function(self, device): + class MySin(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + t = torch.ones_like(x) + _, neg_sin_x = jvp(torch.cos, (x,), (t,)) + ctx.save_for_backward(x) + return -neg_sin_x + + @staticmethod + def backward(ctx, gx): + (x,) = ctx.saved_tensors + t = torch.ones_like(x) + _, cos_x = jvp(torch.sin, (x,), (t,)) + return gx * cos_x + + x = torch.randn([], device=device, requires_grad=True) + y = MySin.apply(x) + self.assertEqual(y, x.sin()) + + (gx,) = torch.autograd.grad(y, x) + self.assertEqual(gx, x.cos()) + + def test_zerotensor_vmapjvp_interaction(self, device): + dummy = torch.ones(4, 1) + x = torch.randn(4, 2) + x_tangent = torch.randn(2) + + def push_jvp(dummy, x): + result = jvp(torch.cov, (x,), (x_tangent,)) + return result + + # Should not error + vmap(vmap(push_jvp, (0, None)))(dummy, x) + + +@markDynamoStrictTest +class TestLinearize(TestCase): + @dtypes(torch.float) + def test_linearize_basic(self, device, dtype): + x_p = make_tensor((3, 1), device=device, dtype=dtype) + x_t = make_tensor((3, 1), device=device, dtype=dtype) + + def fn(x): + return x.cos() + + actual_output, jvp_fn = linearize(fn, x_p) + actual_jvp = jvp_fn(x_t) + expected_output, expected_jvp = jvp(fn, (x_p,), (x_t,)) + self.assertEqual(actual_output, expected_output) + self.assertEqual(actual_jvp, expected_jvp) + + @dtypes(torch.float) + @unittest.skipIf( + TEST_CUDA_MEM_LEAK_CHECK, + "Leaking memory, see https://github.com/pytorch/pytorch/pull/150059 for example", + ) + def test_linearize_return(self, device, dtype): + x_p = make_tensor((3, 1), device=device, dtype=dtype) + x_t = make_tensor((3, 1), device=device, dtype=dtype) + + def fn(x): + return (x.cos(), x.sum()) + + actual_output, jvp_fn = linearize(fn, x_p) + actual_jvp = jvp_fn(x_t) + expected_output, expected_jvp = jvp(fn, (x_p,), (x_t,)) + self.assertEqual(actual_output, expected_output) + self.assertEqual(actual_jvp, expected_jvp) + + @dtypes(torch.float) + @unittest.skipIf( + TEST_CUDA_MEM_LEAK_CHECK, + "Leaking memory, see https://github.com/pytorch/pytorch/pull/150059 for example", + ) + def test_linearize_composition_vmap(self, device, dtype): + x_p = make_tensor((3, 1), device=device, dtype=dtype) + x_t = make_tensor((3, 3, 1), device=device, dtype=dtype) 
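+        # x_t carries an extra leading dim of size 3: the linearized jvp_fn is
+        # vmapped over this batch of tangents below.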
+ + def fn(x): + return (x.cos(), x.sum()) + + _, jvp_fn = linearize(fn, x_p) + actual_batched_jvp = vmap(jvp_fn)(x_t) + + def jvp_fn(x_t): + return jvp(fn, (x_p,), (x_t,))[1] + + expected_batched_jvp = vmap(jvp_fn)(x_t) + + self.assertEqual(actual_batched_jvp, expected_batched_jvp) + + @dtypes(torch.float) + @unittest.skipIf( + TEST_CUDA_MEM_LEAK_CHECK, + "Leaking memory, see https://github.com/pytorch/pytorch/pull/150059 for example", + ) + def test_linearize_composition_grad(self, device, dtype): + x_p = make_tensor((3,), device=device, dtype=dtype) + x_t = make_tensor((3,), device=device, dtype=dtype) + + def fn(x): + z = torch.ones(3, device=device, dtype=dtype) + return grad(lambda x: z @ x)(x) + + _, jvp_fn = linearize(fn, x_p) + actual_batched_jvp = jvp_fn(x_t) + + def jvp_fn(x_t): + return jvp(fn, (x_p,), (x_t,))[1] + + expected_batched_jvp = jvp_fn(x_t) + + self.assertEqual(actual_batched_jvp, expected_batched_jvp) + + @dtypes(torch.float) + @unittest.skipIf( + TEST_CUDA_MEM_LEAK_CHECK, + "Leaking memory, see https://github.com/pytorch/pytorch/pull/150059 for example", + ) + def test_linearize_nested_input_nested_output(self, device, dtype): + x_p = make_tensor((3, 1), device=device, dtype=dtype) + x_t = make_tensor((3, 1), device=device, dtype=dtype) + y_p = make_tensor((3, 1), device=device, dtype=dtype) + y_t = make_tensor((3, 1), device=device, dtype=dtype) + z_p = make_tensor((3, 1), device=device, dtype=dtype) + z_t = make_tensor((3, 1), device=device, dtype=dtype) + + def fn(arg): + x = arg["x"] + y = arg["yz"][0] + z = arg["yz"][1] + + return {"a": x.sum(), "b": {"c": y + z, "d": (x * z, y.exp())}} + + inp_p = {"x": x_p, "yz": (y_p, z_p)} + inp_t = {"x": x_t, "yz": (y_t, z_t)} + actual_output, jvp_fn = linearize(fn, inp_p) + actual_jvp = jvp_fn(inp_t) + + expected_output, expected_jvp = jvp(fn, (inp_p,), (inp_t,)) + + self.assertEqual(actual_output, expected_output) + self.assertEqual(actual_jvp, expected_jvp) + + @onlyOn(["cuda", "xpu"]) + def test_linearize_errors(self): + dtype = torch.float + device = torch.device("cpu") + x_p = make_tensor((3, 1), device=device, dtype=dtype) + x_t = make_tensor((3, 1), device=device, dtype=dtype) + + def fn(x): + return x.sin() + + _, jvp_fn = linearize(fn, x_p) + + with self.assertRaisesRegex( + RuntimeError, "to have the same argspec as the primals" + ): + jvp_fn((x_t, x_t)) + + with self.assertRaisesRegex( + RuntimeError, "in flattened pytree doesn't match the shape" + ): + jvp_fn(x_t.unsqueeze(0)) + + with self.assertRaisesRegex( + RuntimeError, "in flattened pytree doesn't match the dtype" + ): + jvp_fn(x_t.to(torch.double)) + + with self.assertRaisesRegex( + RuntimeError, "in flattened pytree doesn't match the device" + ): + jvp_fn(x_t.to(torch.device(device))) + + +# The tests here follow the cases in [Forward Grad View/inplace] +# https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/autograd_meta.cpp#L18-L43 +@markDynamoStrictTest +class TestVmapJvpInplaceView(TestCase): + # Case 1 in [Forward Grad View/inplace] + def test_all_dual_no_view(self, device): + B = 2 + + def push_jvp(f): + def inner(x, xt, y, yt): + return jvp(f, (x, y), (xt, yt)) + + return inner + + def f(x, y): + x.copy_(y) + return x + + x = torch.randn(3, B, device=device) + xt = torch.randn(3, B, device=device) + y = torch.randn(3, B, device=device) + yt = torch.randn(3, B, device=device) + out, out_tangent = vmap(push_jvp(f), in_dims=1)(x, xt, y, yt) + self.assertEqual(out, x.movedim(1, 0)) + self.assertEqual(out_tangent, yt.movedim(1, 0)) + + 
x = torch.randn(3, B, device=device) + xt = torch.randn(3, B, device=device) + y = torch.randn(3, 3, device=device)[:, 1] + yt = torch.randn(6, device=device)[::2] + out, out_tangent = vmap(push_jvp(f), in_dims=(1, 1, None, None))(x, xt, y, yt) + self.assertEqual(out, x.movedim(1, 0)) + self.assertEqual(out_tangent, yt.expand(B, 3)) + + # Case 2 in [Forward Grad View/inplace] + def test_all_dual_base_view_inplace(self, device): + B = 2 + + def push_jvp(f): + def inner(x, xt, y, yt): + return jvp(f, (x, y), (xt, yt)) + + return inner + + # with view, propagate from view to base + def f(x, y): + view = x[:, ::2] + view.copy_(y) + return view, x + + orig_x = torch.randn(2, 6, B, device=device) + orig_xt = torch.randn(2, 6, B, device=device) + x = orig_x.clone() + xt = orig_xt.clone() + y = torch.randn(2, B, 3, device=device) + yt = torch.randn(2, B, 3, device=device) + out, out_tangent = vmap(push_jvp(f), in_dims=(2, 2, 1, 1))(x, xt, y, yt) + + expected_out = vmap(f, in_dims=(2, 1))(orig_x.clone(), y) + self.assertEqual(out[0], expected_out[0]) + self.assertEqual(out[1], expected_out[1]) + + self.assertEqual(out_tangent[0], yt.movedim(1, 0)) + + expected_x_tangent = orig_xt.movedim(-1, 0).clone() + expected_x_tangent[:, :, ::2].copy_(yt.movedim(1, 0)) + self.assertEqual(out_tangent[1], expected_x_tangent) + + expected = orig_x.movedim(2, 0).clone() + expected[:, :, ::2] = y.movedim(1, 0) + self.assertEqual(x.movedim(2, 0), expected) + + # Case 3 in [Forward Grad View/inplace] + def test_all_dual_base_inplace(self, device): + B = 2 + + def push_jvp(f): + def inner(x, xt, y, yt): + return jvp(f, (x, y), (xt, yt)) + + return inner + + # Case 3: with view, propagate from base to view + def f(x, y): + view = x[0, ::2] + x.copy_(y) + return x, view + + x = torch.randn(2, B, 6, device=device) + xt = torch.randn(2, 6, B, device=device) + y = torch.randn(2, B, 6, device=device) + yt = torch.randn(2, B, 6, device=device) + out, out_tangent = vmap(push_jvp(f), in_dims=(1, 2, 1, 1))(x.clone(), xt, y, yt) + + expected_out = vmap(f, in_dims=(1, 1))(x.clone(), y) + self.assertEqual(out[0], expected_out[0]) + self.assertEqual(out[1], expected_out[1]) + + self.assertEqual(out_tangent[0], yt.movedim(1, 0)) + self.assertEqual(out_tangent[1], yt.movedim(1, 0)[:, 0, ::2]) + + # Case 4 in [Forward Grad View/inplace] + def test_right_dual_view_prop(self, device): + B = 2 + + # Changes on the view must propagate to its base. Also: + # - x is a regular Tensor + # - y is a dual tensor + def f(x, y): + x = x.clone() + view = x[0] + view.copy_(y) + return view, x + + def push_jvp(x, y, yt): + return jvp(partial(f, x), (y,), (yt,)) + + x = torch.randn(2, B, 6, device=device) + y = torch.randn(6, B, device=device) + yt = torch.randn(6, B, device=device) + outs, tangents = vmap(push_jvp, in_dims=(1, 1, 1))(x, y, yt) + + expected_out = vmap(f, in_dims=(1, 1))(x.clone(), y) + self.assertEqual(outs[0], expected_out[0]) + self.assertEqual(outs[1], expected_out[1]) + + self.assertEqual(tangents[0], yt.movedim(1, 0)) + + expected_tangent_1 = torch.zeros_like(x).movedim(1, 0) + expected_tangent_1[:, 0].copy_(yt.movedim(1, 0)) + self.assertEqual(tangents[1], expected_tangent_1) + + # Case 5 in [Forward Grad View/inplace] + def test_right_dual_base_prop(self, device): + B = 2 + + # Changes on the base must propagate on all its views. 
Also: + # - x is a regular Tensor + # - y is a dual tensor + def f(x, y): + x = x.clone() + view = x[0] + x.copy_(y) + return view, x + + def push_jvp(x, y, yt): + return jvp(partial(f, x), (y,), (yt,)) + + x = torch.randn(2, B, 6) + y = torch.randn(2, 6, B) + yt = torch.randn(2, 6, B) + outs, tangents = vmap(push_jvp, in_dims=(1, 2, 2))(x, y, yt) + + expected_out = vmap(f, in_dims=(1, 2))(x, y) + self.assertEqual(outs[0], expected_out[0]) + self.assertEqual(outs[1], expected_out[1]) + + self.assertEqual(tangents[0], yt.movedim(2, 0)[:, 0]) + self.assertEqual(tangents[1], yt.movedim(2, 0)) + + +# Use for testing miscellaneous helper functions +@markDynamoStrictTest +class TestHelpers(TestCase): + def test_CtxWithSavedTensors_error_if_name_collision(self, device): + x = torch.randn([], device=device, requires_grad=True) + y = torch.randn([], device=device, requires_grad=True) + + class A(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx._pt_inner_ctx = 1 + ctx.save_for_backward(x) + return x + + @staticmethod + def backward(ctx, gy): + wrapped = torch._functorch.autograd_function.CtxWithSavedTensors( # noqa: F841 + ctx, (y,) + ) + return gy + + class B(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx._pt_new_saved_tensors = 1 + ctx.save_for_backward(x) + return x + + @staticmethod + def backward(ctx, gy): + wrapped = torch._functorch.autograd_function.CtxWithSavedTensors( # noqa: F841 + ctx, (y,) + ) + return gy + + out = A.apply(x) + with self.assertRaisesRegex(RuntimeError, "name collision"): + out.backward() + out = B.apply(x) + with self.assertRaisesRegex(RuntimeError, "name collision"): + out.backward() + + def test_CtxWithSavedTensors_nesting(self, device): + CtxWithSavedTensors = torch._functorch.autograd_function.CtxWithSavedTensors + x = torch.randn([], device=device, requires_grad=True) + y = torch.randn([], device=device) + z = torch.randn([], device=device) + + class A(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x + + @staticmethod + def backward(ctx, gy): + ctx_y = CtxWithSavedTensors(ctx, (y,)) + # Can't use self.assertEqual because that relies on TLS + # that is not available in multithread autograd + assert len(ctx_y.saved_tensors) == 1 + assert torch.allclose(ctx_y.saved_tensors[0], y) + + wrapped = CtxWithSavedTensors(ctx_y, (z,)) + + assert len(wrapped.saved_tensors) == 1 + assert torch.allclose(wrapped.saved_tensors[0], z) + + assert len(ctx_y.saved_tensors) == 1 + assert torch.allclose(ctx_y.saved_tensors[0], y) + + return gy * wrapped.saved_tensors[0] + + out = A.apply(x) + out.backward() + self.assertEqual(x.grad, z) + + def test_CtxWithSavedTensors_overrides_saved_tensors(self, device): + x = torch.randn([], device=device, requires_grad=True) + + class A(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x + + @staticmethod + def backward(ctx, gy): + # The override can be literally anything + override = (1, 2, 3) + wrapped = torch._functorch.autograd_function.CtxWithSavedTensors( + ctx, override + ) + assert wrapped.saved_tensors == override + return gy + + out = A.apply(x) + out.backward() + + def test_CtxWithSavedTensors_passthrough(self, device): + x = torch.randn([], device=device, requires_grad=True) + y = torch.randn([], device=device) + + class A(torch.autograd.Function): + @staticmethod + def forward(ctx, x, y): + ctx.save_for_backward(x, y) + return x * y + + @staticmethod + def backward(ctx, gz): + # The 
override can be literally anything + override = (1, 2, 3) + wrapped = torch._functorch.autograd_function.CtxWithSavedTensors( + ctx, override + ) + + assert wrapped.needs_input_grad[0] == ctx.needs_input_grad[0] + assert wrapped.needs_input_grad[1] == ctx.needs_input_grad[1] + wrapped.foo = "bar" + assert wrapped.foo == "bar" + assert ctx.foo == "bar" + return gz, gz + + out = A.apply(x, y) + out.backward() + + def test_debug_unwrap(self): + stuff = [] + + def f(x): + stuff.append(torch.func.debug_unwrap(x)) + return x.sin() + + x = torch.randn(2, 3) + _ = vmap(vmap(f))(x) + self.assertEqual(stuff[0], x) + self.assertTrue(stuff[0] is x) + + def test_reductify_leaf(self, device): + reductify_leaf = torch._functorch.autograd_function.reductify_leaf + B = 2 + + # grad_input None case + output = reductify_leaf(None, None, 0, B) + self.assertIsNone(output) + output = reductify_leaf(None, None, None, B) + self.assertIsNone(output) + + # grad_input has bdim, input does not have bdim + grad_input = torch.randn([B, 3, 4], device=device) + output = reductify_leaf(grad_input, 0, None, B) + self.assertEqual(output, grad_input.sum(0)) + + grad_input = torch.randn([3, B, 4], device=device) + output = reductify_leaf(grad_input, 1, None, B, (3,)) + self.assertEqual(output, grad_input.sum(1)) + + # grad_input does not have bdim, input has bdim + # This can happen if the user returns a fresh Tensor from the backward pass + # that is unrelated to the input + grad_input = torch.randn([3, 4], device=device) + output = reductify_leaf(grad_input, None, 1, B) + self.assertEqual(output, grad_input.view(3, 1, 4).expand(3, B, 4)) + + grad_input = torch.randn([3, 4], device=device) + output = reductify_leaf(grad_input, None, 1, B, (4,)) + self.assertEqual(output, grad_input.view(3, 4, 1).expand(3, 4, B).sum(0)) + + # grad_input has bdim, input has bdim + grad_input = torch.randn([B, 3, 4], device=device) + output = reductify_leaf(grad_input, 0, 1, B) + self.assertEqual(output, grad_input.movedim(0, 1)) + + grad_input = torch.randn([3, 4, 5, B], device=device) + output = reductify_leaf(grad_input, 3, 0, B, (5,)) + self.assertEqual(output, grad_input.movedim(-1, 2).sum(0).sum(0)) + + +@markDynamoStrictTest +class TestComposability(TestCase): + def test_deprecation_vmap(self, device): + # functorch version of the API is deprecated + with self.assertWarnsRegex(FutureWarning, "Please use `torch.vmap`"): + vmap(torch.sin) + + # the non-functorch version is not deprecated + with warnings.catch_warnings(): + warnings.simplefilter("error") + torch.vmap(torch.sin) + + # Some of these pass, some of these don't + @parametrize( + "transform", + ["grad", "jacrev", "jacfwd", "grad_and_value", "hessian", "functionalize"], + ) + def test_deprecation_transforms(self, device, transform): + api = getattr(functorch, transform) + new_api = getattr(torch.func, transform) + + # functorch version of the API is deprecated + with self.assertWarnsRegex( + FutureWarning, f"Please use `torch.func.{transform}`" + ): + api(torch.sin) + + # the non-functorch version is not deprecated + with warnings.catch_warnings(): + warnings.simplefilter("error") + new_api(torch.sin) + + def test_grad_grad(self, device): + x = torch.randn([], device=device) + y = grad(grad(torch.sin))(x) + self.assertEqual(y, -x.sin()) + + def test_grad_vmap(self, device): + def foo(x): + y = vmap(torch.sin)(x) + return y.sum() + + x = torch.randn(3, device=device) + y = grad(foo)(x) + self.assertEqual(y, x.cos()) + + def test_grad_vjp(self, device): + x = torch.randn(3, 
device=device) + + def foo(x): + _, vjp_fn = vjp(torch.sin, x) + return vjp_fn(x)[0].sum() + + y = grad(foo)(x) + expected = grad(lambda x: (x * x.cos()).sum())(x) + self.assertEqual(y, expected) + + def test_vmap_grad(self, device): + x = torch.randn(3, device=device) + y = vmap(grad(torch.sin))(x) + self.assertEqual(y, x.cos()) + + def test_vmap_vmap(self, device): + x = torch.randn(2, 3, device=device) + y = vmap(vmap(torch.sin))(x) + self.assertEqual(y, x.sin()) + + def test_vmap_vjp(self, device): + x = torch.randn(3, device=device) + _, vjp_fn = vjp(torch.sin, x) + + def foo(x): + _, vjp_fn = vjp(torch.sin, x) + return vjp_fn(x) + + y = vmap(foo)(x) + self.assertEqual(y, vjp_fn(x)) + + # TODO: there's a very interesting error message when the following + # is on CPU + xs = torch.randn(5, 3, device=device) + expected = torch.stack([vjp_fn(x)[0] for x in xs]) + result = vmap(lambda x: vjp_fn(x)[0])(xs) + self.assertEqual(result, expected) + + def test_vjp_grad(self, device): + x = torch.randn([], device=device) + y, vjp_fn = vjp(grad(torch.sin), x) + self.assertEqual(y, x.cos()) + + v = torch.randn([]) + self.assertEqual(vjp_fn(v)[0], -x.sin() * v) + + def test_vjp_vmap(self, device): + x = torch.randn(3, device=device) + y, vjp_fn = vjp(vmap(torch.sin), x) + self.assertEqual(y, x.sin()) + + v = torch.randn(3, device=device) + self.assertEqual(vjp_fn(v)[0], x.cos() * v) + + def test_vjp_vjp(self, device): + x = torch.randn(3, device=device) + y, vjp_fn = vjp(torch.sin, x) + self.assertEqual(y, x.sin()) + + y, vjp_fn = vjp(lambda x: vjp_fn(x)[0], x) + self.assertEqual(y, x * x.cos()) + + y = vjp_fn(x)[0] + # Honestly IDK what the result here is... but at least it runs + + def test_make_fx_vmap(self, device): + def f(x): + return torch.sin(x) + + inp = torch.randn(5, 3) + f = vmap(f) + fx_f = make_fx(f)(inp) + new_inp = torch.randn(5, 3) + self.assertEqual(fx_f(new_inp), f(new_inp)) + + def test_make_fx_jacrev(self, device): + def f(x): + return x.sin().sum() + + inp = torch.randn(3) + f = jacrev(jacrev(f)) + fx_f = make_fx(f)(inp) + new_inp = torch.randn(3) + self.assertEqual(fx_f(new_inp), f(new_inp)) + + def test_make_fx_vjp(self, device): + def f(x): + return torch.sin(x).sum() + + primals = torch.randn(3) + _, vjp_fn = vjp(f, primals) + cotangent = torch.randn(()) + fx_f = make_fx(vjp_fn)(cotangent, True, True) + new_cotangent = torch.randn(()) + self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent)) + + # FIXME: test fails in Windows + @unittest.skipIf(IS_WINDOWS, "fails in Windows; needs investigation") + @unittest.skipIf(IS_FBCODE, "can't subprocess in fbcode") + # it is redundant to run this test twice on a machine that has GPUs + @onlyCPU + def test_no_warning_on_import_functorch(self, device): + out = subprocess.check_output( + [sys.executable, "-W", "always", "-c", "import functorch"], + stderr=subprocess.STDOUT, + cwd=os.path.dirname(os.path.realpath(__file__)), + ).decode("utf-8") + self.assertEqual(out, "") + + def test_requires_grad_inside_transform(self, device): + def f(x): + x.requires_grad_() + return x.sin().sum() + + x = torch.randn(3) + + with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"): + vmap(f)(x) + with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"): + grad(f)(x) + with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"): + vmap(grad(f))(x) + + x = torch.randn([]) + with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"): + grad(grad(f))(x) + + def 
test_retain_grad_inside_transform(self, device): + def f(x): + y = x.sin() + y.retain_grad() + return y.sum() + + x = torch.randn(3) + + with self.assertRaisesRegex(RuntimeError, "Tensor.retain_grad()"): + grad(f)(x) + + def test_autograd_functional_jacrev_inside_transform(self, device): + def f(x): + y = torch.autograd.functional.jacobian(lambda x: x.sin().sum(), x) + return y + + B = 5 + x = torch.randn(B, 3) + with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): + vmap(f)(x) + + x = torch.randn([]) + with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): + grad(f)(x) + + def test_autograd_functional_vjp_inside_transform(self, device): + def f(x): + y = torch.autograd.functional.vjp(lambda x: x.sin().sum(), x) + return y + + B = 5 + x = torch.randn(B, 3) + with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): + vmap(f)(x) + + x = torch.randn([]) + with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): + grad(f)(x) + + def test_autograd_functional_jvp_inside_transform(self, device): + def f(x): + t = torch.ones_like(x) + y = torch.autograd.functional.jvp(lambda x: x.sin().sum(), (x,), (t,)) + return y + + B = 5 + x = torch.randn(B, 3) + with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): + vmap(f)(x) + + x = torch.randn([]) + with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"): + grad(f)(x) + + def test_autograd_functional_jacfwd_inside_transform(self, device): + def f(x): + y = torch.autograd.functional.jacobian( + lambda x: x.sin().sum(), x, strategy="forward-mode", vectorize=True + ) + return y + + B = 5 + x = torch.randn(B, 3) + with self.assertRaisesRegex( + RuntimeError, "Batching rule not implemented for aten::_make_dual" + ): + vmap(f)(x) + + @parametrize( + "transform", + [ + "vmap", + "grad", + "jacrev", + "jacfwd", + "grad_and_value", + "hessian", + "functionalize", + ], + ) + def test_autograd_function_no_setup_context(self, device, transform): + class MySin(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x.sin() + + @staticmethod + def backward(ctx, gy): + (x,) = ctx.saved_tensors + return gy * x.cos() + + x = torch.randn(3, device=device) + transform = getattr(functorch, transform) + with self.assertRaisesRegex(RuntimeError, "must override the setup_context"): + transform(MySin.apply)(x) + + # Some of these pass, some of these don't + @parametrize( + "transform", + [ + "grad", + "jacrev", + "grad_and_value", + "hessian", + ], + ) + def test_transforms_dont_support_saved_tensor_hooks(self, device, transform): + def f(x): + return torch.sin(x).sum() + + def g(x): + with torch.autograd.graph.save_on_cpu(): + return f(x) + + x = torch.randn(3, device=device) + + if transform == "functionalize": + transform = functorch.experimental.functionalize + else: + transform = getattr(functorch, transform) + with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"): + with torch.autograd.graph.save_on_cpu(): + transform(f)(x) + + with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"): + transform(g)(x) + + def test_vjp_doesnt_support_saved_tensor_hooks(self, device): + def f(x): + return torch.sin(x).sum() + + def g(x): + with torch.autograd.graph.save_on_cpu(): + return f(x) + + x = torch.randn(3, device=device) + with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"): + with torch.autograd.graph.save_on_cpu(): + vjp(f, x) + + with self.assertRaisesRegex(RuntimeError, "saved tensor hooks"): + vjp(g, x) + 
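+    # A rough sketch of the split in behavior checked here, assuming a
+    # scalar-producing f, a primal x, and a matching tangent t: reverse-mode
+    # transforms (grad, vjp, jacrev, hessian) refuse to run under saved-tensor
+    # hooks such as torch.autograd.graph.save_on_cpu(), while forward-mode jvp
+    # tolerates them.
+    #
+    #     with torch.autograd.graph.save_on_cpu():
+    #         torch.func.jvp(f, (x,), (t,))  # works (see the next test)
+    #         torch.func.grad(f)(x)          # RuntimeError: "saved tensor hooks"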
+ def test_jvp_supports_saved_tensor_hooks(self, device): + def f(x): + return torch.sin(x).sum() + + def g(x): + with torch.autograd.graph.save_on_cpu(): + return f(x) + + x = torch.randn(3, device=device) + t = torch.randn(3, device=device) + + # smoke tests + with torch.autograd.graph.save_on_cpu(): + jvp(f, (x,), (t,)) + + # smoke tests + jvp(g, (x,), (t,)) + + def test_can_use_functionalize_when_key_is_excluded(self, device): + def f(x): + y = x.clone() + y.sin_() + return y + + x = torch.randn([], device=device) + expected = f(x) + + with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)): + gm = make_fx(functorch.functionalize(f))(x) + self.assertTrue("sin_" not in gm.code) + self.assertEqual(gm(x), expected) + + local_exclude_set = torch._C._dispatch_tls_local_exclude_set() + self.assertTrue(local_exclude_set.has(DispatchKey.Functionalize)) + + def test_can_use_vmap_when_key_is_excluded(self, device): + def f(x): + return x.sum(0) + + x = torch.randn(3, device=device) + expected = vmap(f)(x) + + with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.FuncTorchBatched)): + result = vmap(f)(x) + self.assertEqual(result, expected) + local_exclude_set = torch._C._dispatch_tls_local_exclude_set() + self.assertTrue(local_exclude_set.has(DispatchKey.FuncTorchBatched)) + + def test_can_use_grad_when_key_is_excluded(self, device): + def f(x): + return x.sin() + + x = torch.randn([], device=device) + expected = grad(f)(x) + + with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Autograd)): + result = grad(f)(x) + self.assertEqual(result, expected) + local_exclude_set = torch._C._dispatch_tls_local_exclude_set() + self.assertTrue(local_exclude_set.has(DispatchKey.Autograd)) + + +@markDynamoStrictTest +class TestMakeFunctional(TestCase): + @parametrize("disable_autograd_tracking", [True, False]) + def test_disable_autograd_tracking(self, disable_autograd_tracking): + class Foo(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(3, 3) + + def forward(self, x): + x = self.linear(x) + return x + + mod = Foo() + _, params = make_functional( + mod, disable_autograd_tracking=disable_autograd_tracking + ) + self.assertEqual(len(params), 2) + for param in params: + self.assertEqual(param.requires_grad, not disable_autograd_tracking) + + def test_parameter_tying(self): + class Foo(nn.Module): + def __init__(self) -> None: + super().__init__() + self.bias = nn.Parameter(torch.randn(3)) + self.linear = nn.Linear(3, 3) + self.linear.bias = self.bias + self.linear_tied = self.linear + + def forward(self, x): + x = self.linear(x) + x = self.linear_tied(x) + x = x + self.bias + return x + + torch.manual_seed(1) + mod = Foo() + func, _ = make_functional(mod) + + torch.manual_seed(0) + mod = Foo() + _, params = make_functional(mod) + self.assertEqual(len(params), 2) + + x = torch.randn(2, 3) + result = func(params, x) + expected = mod(x) + self.assertEqual(result, expected) + + def test_buffer_tying(self): + class Foo(nn.Module): + def __init__(self) -> None: + super().__init__() + self.bias = nn.Parameter(torch.randn(3)) + self.linear = nn.Linear(3, 3) + self.buffer = nn.Buffer(torch.randn(3)) + self.buffer_tied = self.buffer + + def forward(self, x): + x = self.linear(x) + x = x + self.bias + x = x + self.buffer + x = x + self.buffer_tied + return x + + torch.manual_seed(1) + mod = Foo() + func, _, _ = make_functional_with_buffers(mod) + + torch.manual_seed(0) + mod = Foo() + _, params, buffers = make_functional_with_buffers(mod) + 
self.assertEqual(len(params), 3) + self.assertEqual(len(buffers), 1) + + x = torch.randn(2, 3) + result = func(params, buffers, x) + expected = mod(x) + self.assertEqual(result, expected) + + @parametrize("disable_autograd_tracking", [True, False]) + def test_with_buffers_disable_autograd_tracking(self, disable_autograd_tracking): + class Foo(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(3, 3) + self.buffer = nn.Buffer(torch.randn(3)) + + def forward(self, x): + x = self.linear(x) + x = x + self.buffer + return x + + mod = Foo() + _, params, buffers = make_functional_with_buffers( + mod, disable_autograd_tracking=disable_autograd_tracking + ) + self.assertEqual(len(params), 2) + self.assertEqual(len(buffers), 1) + for param in params: + self.assertEqual(param.requires_grad, not disable_autograd_tracking) + + @parametrize("detach_params", [True, False]) + def test_using_detach_functional_call(self, detach_params): + class Foo(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(3, 3) + self.buffer = nn.Buffer(torch.randn(3)) + + def forward(self, x): + x = self.linear(x) + x = x + self.buffer + return x + + def params_dict(mod): + named_params = mod.named_parameters() + return ( + {k: v.detach() for k, v in named_params} + if detach_params + else dict(named_params) + ) + + mod = Foo() + x = torch.randn(3, 3) + d = (params_dict(mod), dict(mod.named_buffers())) + out = functional_call(mod, d, x) + self.assertEqual(out.grad_fn is None, detach_params) + + def test_parameter_tying_grad(self): + class Foo(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(3, 3) + self.weight = self.linear.weight + self.bias = self.linear.bias + + def forward(self, x): + x = self.linear(x) + x = F.linear(x, self.weight, self.bias) + return x + + x = torch.randn(2, 3) + torch.manual_seed(0) + mod = Foo() + loss = mod(x).sum() + expected = torch.autograd.grad(loss, mod.parameters()) + + mod = Foo() + fmod, _, _ = make_functional_with_buffers(mod) + torch.manual_seed(0) + mod = Foo() + _, params, buffers = make_functional_with_buffers(mod) + + def compute_loss(params, buffers, x): + return fmod(params, buffers, x).sum() + + result = grad(compute_loss)(params, buffers, x) + + self.assertEqual(result, expected) + + def test_parameter_tying_ensemble(self): + class Foo(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(3, 3) + self.weight = self.linear.weight + self.bias = self.linear.bias + self.buffer = nn.Buffer(torch.randn(3)) + self.buffer_tied = self.buffer + + def forward(self, x): + x = self.linear(x) + x = F.linear(x, self.weight, self.bias) + x = x + self.buffer + x = x + self.buffer_tied + return x + + num_models = 2 + xs = torch.randn(num_models, 64, 3) + models = [Foo() for _ in range(num_models)] + fmodel, _, _ = combine_state_for_ensemble(models) + + torch.manual_seed(0) + models = [Foo() for _ in range(num_models)] + _, params, buffers = combine_state_for_ensemble(models) + result = vmap(fmodel)(params, buffers, xs) + + torch.manual_seed(0) + models = [Foo() for _ in range(num_models)] + expected = torch.stack([model(x) for model, x in zip(models, xs)]) + + self.assertEqual(result, expected) + + @parametrize("mechanism", ["make_functional", "functional_call"]) + def test_correctness_mnist(self, mechanism): + class Net(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = 
nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x) + + x = torch.randn(64, 1, 32, 32) + torch.manual_seed(301) + fnet, _ = _get_weights_and_functional_call(Net(), mechanism) + + torch.manual_seed(0) + _, params = _get_weights_and_functional_call(Net(), mechanism) + result = fnet(params, x) + + torch.manual_seed(0) + net = Net() + expected = net(x) + + self.assertEqual(result, expected) + + def test_combine_state_for_ensemble_error(self): + in_features = 2 + out_features = 2 + + models = [] + with self.assertRaisesRegex(RuntimeError, "Expected at least one model"): + _ = combine_state_for_ensemble(models) + + num_models = 3 + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + models[1].eval() + with self.assertRaisesRegex(RuntimeError, "same training/eval mode"): + _ = combine_state_for_ensemble(models) + + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + models[1] = torch.nn.Conv2d(3, 3, (3, 3)) + with self.assertRaisesRegex(RuntimeError, "models to be of the same class"): + _ = combine_state_for_ensemble(models) + + def test_combine_state_for_ensemble_smoke(self): + in_features = 2 + out_features = 2 + num_models = 3 + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + _ = combine_state_for_ensemble(models) + + def test_stack_module_state_smoke(self): + in_features = 2 + out_features = 2 + num_models = 3 + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + _ = stack_module_state(models) + + def test_stack_module_state_leaf(self): + in_features = 2 + out_features = 2 + num_models = 3 + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + params, buffers = stack_module_state(models) + for param in params.values(): + self.assertTrue(param.requires_grad) + self.assertTrue(param.is_leaf) + + def test_stack_module_state_mismatch_error(self): + in_features = 2 + out_features = 2 + num_models = 3 + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + models[0].weight.requires_grad_(False) + with self.assertRaisesRegex(RuntimeError, "same .requires_grad"): + params, buffers = stack_module_state(models) + + def test_stack_module_state_error(self): + in_features = 2 + out_features = 2 + + models = [] + with self.assertRaisesRegex( + RuntimeError, "stack_module_state:.* Expected at least one model" + ): + _ = stack_module_state(models) + + num_models = 3 + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + models[1].eval() + with self.assertRaisesRegex( + RuntimeError, "stack_module_state:.* same training/eval mode." 
+ ): + _ = stack_module_state(models) + + models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] + models[1] = torch.nn.Conv2d(3, 3, (3, 3)) + with self.assertRaisesRegex( + RuntimeError, "stack_module_state:.* models to be of the same class" + ): + _ = stack_module_state(models) + + @parametrize("mechanism", ["make_functional", "functional_call"]) + def test_make_functional_state_correctly_returned_after_forward(self, mechanism): + class Net(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(3, 3) + + def forward(self, x): + x = self.linear(x) + return x + + def get_module_info(mod): + if mechanism == "make_functional": + return make_functional(mod) + else: + assert mechanism == "functional_call" + return mod, dict(mod.named_parameters()) + + mod = Net() + func_mod, params = get_module_info(mod) + + # state in func.names_map + mod = func_mod.stateless_model if mechanism == "make_functional" else func_mod + old_state_linear_weight = mod.linear.weight + old_state_linear_bias = mod.linear.bias + + self.assertIsNotNone(old_state_linear_weight) + self.assertIsNotNone(old_state_linear_bias) + + x = torch.randn(4, 3) + if mechanism == "make_functional": + func_mod(params, x) + else: + assert mechanism == "functional_call" + functional_call(func_mod, params, x) + + mod = func_mod.stateless_model if mechanism == "make_functional" else func_mod + new_state_linear_weight = mod.linear.weight + new_state_linear_bias = mod.linear.bias + + self.assertIsNotNone(new_state_linear_weight) + self.assertIsNotNone(new_state_linear_bias) + + self.assertEqual(old_state_linear_weight, new_state_linear_weight) + self.assertEqual(old_state_linear_bias, new_state_linear_bias) + + +@markDynamoStrictTest +class TestExamplesCorrectness(TestCase): + def _update_params(self, params, grads, alpha, mechanism): + if mechanism == "make_functional": + return [(params[i] - alpha * grads[i]) for i in range(len(params))] + else: + assert mechanism == "functional_call" + return {k: params[k] - alpha * grads[k] for k in params} + + @parametrize("mechanism", ["make_functional", "functional_call"]) + def test_maml_regression(self, device, mechanism): + class ThreeLayerNet(nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = nn.Linear(1, 40) + self.relu1 = nn.ReLU() + self.fc2 = nn.Linear(40, 40) + self.relu2 = nn.ReLU() + self.fc3 = nn.Linear(40, 1) + + def forward(self, x): + x = self.fc1(x) + x = self.relu1(x) + x = self.fc2(x) + x = self.relu2(x) + x = self.fc3(x) + return x + + # TODO: should replace with F.mse_loss + def mse_loss(x, y): + return torch.mean((x - y) ** 2) + + net, params = _get_weights_and_functional_call( + ThreeLayerNet().to(device), mechanism + ) + K = 20 + num_tasks = 4 + alpha = 0.1 + + def sample_tasks(outer_batch_size, inner_batch_size): + # Select amplitude and phase for the task + As = [] + phases = [] + for _ in range(outer_batch_size): + As.append(np.random.uniform(low=0.1, high=0.5)) + phases.append(np.random.uniform(low=0.0, high=np.pi)) + + def get_batch(): + xs, ys = [], [] + for A, phase in zip(As, phases): + x = np.random.uniform( + low=-5.0, high=5.0, size=(inner_batch_size, 1) + ) + y = A * np.sin(x + phase) + xs.append(x) + ys.append(y) + return torch.tensor(xs, dtype=torch.float, device=device), torch.tensor( + ys, dtype=torch.float, device=device + ) + + x1, y1 = get_batch() + x2, y2 = get_batch() + return x1, y1, x2, y2 + + def get_loss_for_task(use_transform, x1, y1, x2, y2): + def inner_loss(params, x1, 
y1): + f = net(params, x1) + loss = mse_loss(f, y1) + return loss + + if use_transform: + grads = grad(inner_loss)(params, x1, y1) + else: + loss = inner_loss(params, x1, y1) + grad_params, spec = tree_flatten(params) + grads = torch.autograd.grad(loss, grad_params, create_graph=True) + grads = tree_unflatten(grads, spec) + + new_params = self._update_params(params, grads, alpha, mechanism) + + v_f = net(new_params, x2) + return mse_loss(v_f, y2) + + task = sample_tasks(num_tasks, K) + list_params = ( + params if mechanism == "make_functional" else list(params.values()) + ) + + # Compute with vmap+grad + inner_losses = vmap(partial(get_loss_for_task, True))( + task[0], task[1], task[2], task[3] + ) + loss2 = sum(inner_losses) / len(inner_losses) + result_grads = torch.autograd.grad(loss2, list_params) + + # Compute without vmap+grad + inner_losses = [ + get_loss_for_task(False, task[0][i], task[1][i], task[2][i], task[3][i]) + for i in range(num_tasks) + ] + loss2 = sum(inner_losses) / len(inner_losses) + expected_grads = torch.autograd.grad(loss2, list_params) + + self.assertEqual(result_grads, expected_grads) + + @parametrize("mechanism", ["make_functional", "functional_call"]) + def test_maml_omniglot(self, device, mechanism): + # TODO: there appears to be precision issues for float32 + dtype = torch.double + + # TODO: We don't support inplace relu? + inplace_relu = False + n_way = 5 + n_inner_iter = 2 + num_tasks = 2 + + # real example uses batch norm but it's numerically unstable in the first + # iteration, when near 0, and won't produce same gradients. Uses group norm instead + net = ( + nn.Sequential( + nn.Conv2d(1, 64, 3), + nn.GroupNorm(64, 64, affine=True), + nn.ReLU(inplace=inplace_relu), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.GroupNorm(64, 64, affine=True), + nn.ReLU(inplace=inplace_relu), + nn.MaxPool2d(2, 2), + nn.Conv2d(64, 64, 3), + nn.GroupNorm(64, 64, affine=True), + nn.ReLU(inplace=inplace_relu), + nn.MaxPool2d(2, 2), + nn.Flatten(), + nn.Linear(64, n_way), + ) + .to(device) + .to(dtype) + ) + + fnet, params, buffers = _get_weights_and_functional_call_with_buffers( + net, mechanism + ) + net = (params, buffers, fnet) + + def loss_for_task(net, n_inner_iter, use_transform, x_spt, y_spt, x_qry, y_qry): + params, buffers, fnet = net + querysz = x_qry.size(0) + + def compute_loss(new_params, buffers, x, y): + logits = fnet(new_params, buffers, x) + loss = F.cross_entropy(logits, y) + return loss + + new_params = params + for _ in range(n_inner_iter): + if use_transform: + grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt) + else: + res = compute_loss(new_params, buffers, x_spt, y_spt) + grad_params, spec = tree_flatten(new_params) + grads = torch.autograd.grad(res, grad_params, create_graph=True) + grads = tree_unflatten(grads, spec) + + new_params = self._update_params(new_params, grads, 1e-1, mechanism) + + qry_logits = fnet(new_params, buffers, x_qry) + qry_loss = F.cross_entropy(qry_logits, y_qry) + qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz + + return qry_loss, qry_acc + + # Get some sample inputs... 
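+        # Each task gets a 25-example support set and a 75-example query set of
+        # fake 1x28x28 images with labels in [0, n_way); everything is batched
+        # over num_tasks so loss_for_task can be vmapped per task.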
+ x_spt = torch.randn(num_tasks, 25, 1, 28, 28, dtype=dtype, device=device) + y_spt = torch.randint(0, 5, (num_tasks, 25), device=device) + x_qry = torch.randn(num_tasks, 75, 1, 28, 28, dtype=dtype, device=device) + y_qry = torch.randint(0, 5, (num_tasks, 75), device=device) + + # compute with vmap + grad + compute_loss = partial(loss_for_task, net, n_inner_iter, True) + qry_losses, _ = vmap(compute_loss)(x_spt, y_spt, x_qry, y_qry) + list_params = ( + params if mechanism == "make_functional" else list(params.values()) + ) + result_grads = torch.autograd.grad(qry_losses.sum(), list_params) + + # compute without vmap + grad + compute_loss = partial(loss_for_task, net, n_inner_iter, False) + losses = [ + compute_loss(x_spt[i], y_spt[i], x_qry[i], y_qry[i])[0] + for i in range(num_tasks) + ] + expected_grads = torch.autograd.grad(sum(losses), list_params) + + self.assertEqual(result_grads, expected_grads) + + @parametrize("mechanism", ["make_functional", "functional_call"]) + @parametrize("originally_track_running_stats", [True, False]) + def test_update_batch_norm(self, device, originally_track_running_stats, mechanism): + dtype = torch.double + inplace_relu = False + classes = 5 + num_batches = 2 + net = ( + nn.Sequential( + nn.Conv2d(64, 64, 3), + nn.BatchNorm2d( + 64, affine=True, track_running_stats=originally_track_running_stats + ), + nn.ReLU(inplace=inplace_relu), + nn.Flatten(), + nn.Linear(43264, classes), + ) + .to(device) + .to(dtype) + ) + + replace_all_batch_norm_modules_(net) + transformed_net = net + fnet, params, buffers = _get_weights_and_functional_call_with_buffers( + transformed_net, mechanism + ) + criterion = nn.CrossEntropyLoss() + + def compute_loss(x, y, params, buffers): + return criterion(fnet(params, buffers, x), y) + + # Get some sample inputs... 
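+        # Each "sample" here is a batch of one 64-channel image; vmap with
+        # in_dims=(0, 0, None, None) maps over num_batches while sharing the
+        # params and buffers, giving per-sample gradients w.r.t. the params.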
+ x = torch.randn(num_batches, 1, 64, 28, 28, device=device, dtype=dtype) + y = torch.randint(0, classes, (num_batches, 1), device=device) + + # compute some per sample grads with vmap + grad + result_grads = vmap(grad(compute_loss, argnums=2), in_dims=(0, 0, None, None))( + x, y, params, buffers + ) + + # compute some per sample grads without vmap + grad + fnet, params, buffers = _get_weights_and_functional_call_with_buffers( + transformed_net, mechanism + ) + flat_params, spec = tree_flatten(params) + expected_grads = [ + torch.autograd.grad(compute_loss(x[i], y[i], params, buffers), flat_params) + for i in range(num_batches) + ] + expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)] + expected_grads = tree_unflatten(expected_grads, spec) + + self.assertEqual(result_grads, expected_grads) + + @parametrize("jac", ["jacfwd", "jacrev"]) + def test_lennard_jones_batched_jac(self, device, jac): + sigma = 0.5 + epsilon = 4.0 + + jac = getattr(functorch, jac) + + def lennard_jones(r): + return epsilon * ((sigma / r) ** 12 - (sigma / r) ** 6) + + def lennard_jones_force(r): + """Get magnitude of LJ force""" + return -epsilon * ( + (-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7) + ) + + r = torch.linspace(0.5, 2 * sigma, steps=100, requires_grad=True, device=device) + drs = torch.outer(r, torch.tensor([1.0, 0, 0], device=device)) + norms = torch.norm(drs, dim=1).reshape(-1, 1) + training_energies = torch.stack(list(map(lennard_jones, norms))).reshape(-1, 1) + training_forces = torch.stack( + [force * dr for force, dr in zip(map(lennard_jones_force, norms), drs)] + ) + + model = nn.Sequential( + nn.Linear(1, 16), + nn.Tanh(), + nn.Linear(16, 16), + nn.Tanh(), + nn.Linear(16, 16), + nn.Tanh(), + nn.Linear(16, 16), + nn.Tanh(), + nn.Linear(16, 1), + ).to(device) + + def make_prediction(model, drs, use_functorch): + norms = torch.norm(drs, dim=1).reshape(-1, 1) + energies = model(norms) + + if use_functorch: + network_derivs = vmap(jac(model))(norms).squeeze(-1) + forces = -network_derivs * drs / norms + else: + forces = [] + for r, dr in zip(norms, drs): + network_deriv = torch.autograd.functional.jacobian( + model, r, create_graph=True + ) + force = -network_deriv * dr / r + forces.append(force) + forces = torch.cat(forces) + return energies, forces + + def loss_fn(energies, forces, predicted_energies, predicted_forces): + return ( + F.mse_loss(energies, predicted_energies) + + 0.01 * F.mse_loss(forces, predicted_forces) / 3 + ) + + energies, forces = make_prediction(model, drs, use_functorch=True) + loss = loss_fn(training_energies, training_forces, energies, forces) + result = torch.autograd.grad(loss, model.parameters()) + + energies, forces = make_prediction(model, drs, use_functorch=False) + loss = loss_fn(training_energies, training_forces, energies, forces) + expected = torch.autograd.grad(loss, model.parameters()) + + self.assertEqual(result, expected) + + @parametrize("mechanism", ["make_functional", "functional_call"]) + def test_ensemble_regression(self, device, mechanism): + def make_spirals(n_samples, noise_std=0.0, rotations=1.0): + ts = torch.linspace(0, 1, n_samples) + rs = ts**0.5 + thetas = rs * rotations * 2 * math.pi + signs = torch.randint(0, 2, (n_samples,)) * 2 - 1 + labels = (signs > 0).to(torch.long) + + xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples) * noise_std + ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples) * noise_std + points = torch.stack([xs, ys], dim=1) + return points.to(device), labels.to(device) + + points, 
labels = make_spirals(100, noise_std=0.05) + + class MLPClassifier(nn.Module): + def __init__(self, hidden_dim=32, n_classes=2): + super().__init__() + self.hidden_dim = hidden_dim + self.n_classes = n_classes + + self.fc1 = nn.Linear(2, self.hidden_dim) + self.fc2 = nn.Linear(self.hidden_dim, self.n_classes) + + def forward(self, x): + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + x = F.log_softmax(x, -1) + return x + + loss_fn = nn.NLLLoss() + + func_model, weights = _get_weights_and_functional_call( + MLPClassifier().to(device), mechanism + ) + + def train_step_fn(use_transform, weights, batch, targets, lr=0.2): + def compute_loss(weights, batch, targets): + output = func_model(weights, batch) + loss = loss_fn(output, targets) + return loss + + if use_transform: + grad_weights, loss = grad_and_value(compute_loss)( + weights, batch, targets + ) + else: + loss = compute_loss(weights, batch, targets) + flat_weights, spec = tree_flatten(weights) + flat_grad_weights = torch.autograd.grad(loss, flat_weights) + grad_weights = tree_unflatten(flat_grad_weights, spec) + + new_weights = self._update_params(weights, grad_weights, lr, mechanism) + return (loss, new_weights) + + def unpack(train_result): + return train_result[0], train_result[1] + + def init_fn(num_models): + models = tuple(MLPClassifier().to(device) for _ in range(num_models)) + if mechanism == "make_functional": + return combine_state_for_ensemble(models)[1] + else: + return stack_module_state(models)[0] + + def slice_weights(batched_weights, index): + return tree_map( + lambda weight: weight[index].detach().requires_grad_(), batched_weights + ) + + batched_weights = init_fn(num_models=2) + parallel_train_step_fn = vmap( + partial(train_step_fn, True), in_dims=(0, None, None) + ) + + result_loss, result_weights = unpack( + parallel_train_step_fn(batched_weights, points, labels) + ) + + loss0, weights0 = unpack( + train_step_fn(False, slice_weights(batched_weights, 0), points, labels) + ) + loss1, weights1 = unpack( + train_step_fn(False, slice_weights(batched_weights, 1), points, labels) + ) + expected_loss = torch.stack([loss0, loss1]) + + weights0, spec0 = tree_flatten(weights0) + weights1, spec1 = tree_flatten(weights1) + assert spec0 == spec1 + expected_weights = tuple( + torch.stack([w0, w1]) for w0, w1 in zip(weights0, weights1) + ) + expected_weights = tree_unflatten(expected_weights, spec0) + + self.assertEqual(result_loss, expected_loss) + self.assertEqual(result_weights, expected_weights) + + @parametrize( + "dropout_layer", + [ + subtest(nn.Dropout, "Dropout"), + subtest(nn.AlphaDropout, "AlphaDropout"), + subtest(nn.FeatureAlphaDropout, "FeatureAlphaDropout"), + ], + ) + @parametrize("mechanism", ["make_functional", "functional_call"]) + def test_find_learning_rate_ensembling(self, device, dropout_layer, mechanism): + # This example mimics what a user might do when trying to find the optimal learning rate. They would + # want to run a bunch of models with the same behavior (including the same dropout!) and have them + # each run with different learning rates. 
Specifically, this is an example of using same randomness with vmap + points, labels = ( + torch.randn(100, 2, 2, 2, 2, device=device), + torch.randint(0, 2, (100,), device=device), + ) + + class MLPClassifier(nn.Module): + def __init__(self, hidden_dim=32, n_classes=2): + super().__init__() + self.hidden_dim = hidden_dim + self.n_classes = n_classes + + self.dropout = dropout_layer() + self.fc1 = nn.Linear(16, self.hidden_dim) + self.fc2 = nn.Linear(self.hidden_dim, self.n_classes) + + def forward(self, x): + x = self.dropout(x) + x = torch.flatten(x, start_dim=1) + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + x = F.log_softmax(x, -1) + return x + + loss_fn = nn.NLLLoss() + + func_model, weights = _get_weights_and_functional_call( + MLPClassifier().to(device), mechanism + ) + + def train_step_fn(weights, batch, targets, lr): + def compute_loss(weights, batch, targets): + output = func_model(weights, batch) + loss = loss_fn(output, targets) + return loss + + grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets) + new_weights = self._update_params(weights, grad_weights, lr, mechanism) + if mechanism != "make_functional": + new_weights = list(new_weights.values()) + # NB: return looks weird because torch.vmap must return Tensors + return (loss, *new_weights) + + def unpack(train_result): + return train_result[0], train_result[1:] + + def init_fn(num_models): + og_model = MLPClassifier().to(device) + models = tuple( + copy.deepcopy(og_model) for _ in range(num_models) + ) # have same initialization + if mechanism == "make_functional": + return combine_state_for_ensemble(models)[1] + else: + return stack_module_state(models)[0] + + batched_weights = init_fn(num_models=2) + parallel_train_step_fn = vmap( + train_step_fn, in_dims=(0, None, None, 0), randomness="same" + ) + + lrs = torch.tensor([0.2, 0.4], device=device) + result_loss, result_weights = unpack( + parallel_train_step_fn(batched_weights, points, labels, lrs) + ) + + self.assertEqual(result_loss[0], result_loss[1]) + self.assertNotEqual( + tuple(weight[0] for weight in result_weights), + tuple(weight[1] for weight in result_weights), + ) + + @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 + @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") + @parametrize("mechanism", ["make_functional", "functional_call"]) + def test_resnet18_per_sample_grads(self, device, mechanism): + from torchvision import models + + model = models.__dict__["resnet18"]( + pretrained=False, norm_layer=(lambda c: nn.GroupNorm(min(32, c), c)) + ).to(device) + criterion = nn.CrossEntropyLoss( + reduction="sum" + ) # avoid cross batch reductions for for loop comparison + + func_model, weights = _get_weights_and_functional_call(model, mechanism) + + def compute_loss(weights, image, target): + image = image.unsqueeze(0) + target = target.unsqueeze(0) + output = func_model(weights, image) + loss = criterion(output, target) + return loss + + batch_size = 3 + images = torch.randn(batch_size, 3, 32, 32, device=device) + targets = torch.randint(0, 10, (batch_size,), device=device) + + result_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))( + weights, images, targets + ) + + flat_weights, spec = tree_flatten(weights) + expected_grads = [ + torch.autograd.grad( + compute_loss(weights, images[i], targets[i]), flat_weights + ) + for i in range(batch_size) + ] + expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)] + expected_grads = tree_unflatten(expected_grads, spec) + + 
self.assertEqual(result_grads, expected_grads, atol=1e-3, rtol=1.0) + + +def normalize_devices(fx_g): + for node in fx_g.graph.nodes: + args = list(node.args) + for idx, arg in enumerate(args): + if isinstance(arg, torch.device): + args[idx] = "cpu" + node.args = tuple(args) + new_kwargs = {} + for k, v in node.kwargs.items(): + if isinstance(v, torch.device): + v = "cpu" + new_kwargs[k] = v + node.kwargs = new_kwargs + fx_g.recompile() + return fx_g + + +@markDynamoStrictTest +class TestFunctionalize(TestCase): + def _check_functionalize_correctness(self, f, inpt, *, skip_vmap=False): + inpt1 = inpt.clone() + inpt2 = inpt.clone() + inpt3 = inpt.clone() + + expected_outputs = f(inpt1) + if skip_vmap: + actual_outputs = functionalize(f)(inpt2) + else: + actual_outputs = vmap(functionalize(f))(inpt2.unsqueeze(0))[0].squeeze() + # Right now the flavor of functionalize that also removes view ops + # isn't being used with vmap + # That's because {view}_copy ops don't have batching rules yet + # (although we should probably fix that) + actual_outputs_view_copy = functionalize(f, remove="mutations_and_views")(inpt3) + # Check that outputs are the same + self.assertEqual(actual_outputs, expected_outputs) + self.assertEqual(actual_outputs_view_copy, expected_outputs) + + # Inputs might have been mutated by f: check that they were mutated properly + self.assertEqual(inpt1, inpt2) + self.assertEqual(inpt1, inpt3) + + def test_simple_view(self, device): + def f(x: torch.Tensor) -> torch.Tensor: + tmp = torch.ones(2, device=device) + y = x.view(4, 2) + y.add_(tmp) + return x + + self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device)) + + def test_multioutput_view(self, device): + def f(x: torch.Tensor) -> torch.Tensor: + tmp = torch.ones(2, device=device) + y1, y2 = x.split(2) + y1_view = y1.diagonal() + y1_view.add_(tmp) + return x + + self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device)) + + def test_inplace_view(self, device): + def f(x: torch.Tensor) -> torch.Tensor: + tmp = torch.ones(4, device=device) + y = x + x + y2 = y.transpose(1, 0) + z = y2[0] + z.add_(tmp) + return y + + self._check_functionalize_correctness( + f, torch.zeros(4, 2, device=device), skip_vmap=True + ) + + # See https://github.com/pytorch/functorch/issues/780 + def test_linear(self, device): + def f(x, y, z) -> torch.Tensor: + return torch._C._nn.linear(x, y, z) + + x = torch.randn(14, 1, 384, device=device) + y = torch.randn(96, 384, device=device) + z = torch.randn(96, device=device) + + out_expected = f(x, y, z) + out_actual = functionalize(f)(x, y, z) + self.assertEqual(out_expected, out_actual) + + def test_multioutput_inplace_slice_view(self, device): + def f(x: torch.Tensor) -> torch.Tensor: + tmp = torch.ones(2, 2, device=device) + y = x.view(8) + z0 = y.reshape(2, 4) + z1 = z0.transpose(1, 0) + z1.unsqueeze_(0) + z1.squeeze_() + z2, z3 = z1.split(2) + z2.add_(tmp) + return x + + # See Note [Fix vmap slice_scatter] + self._check_functionalize_correctness( + f, torch.zeros(4, 2, device=device), skip_vmap=True + ) + + # Ensure functionalize works with List[Optional[Tensor]] arguments. 
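+    # (aten::index takes its indices as an optional tensor list, which is the
+    # case exercised here.)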
+ # See the fix / discussion at https://github.com/pytorch/pytorch/pull/76085 + def test_functionalize_opt_tensor_list(self, device): + def f(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + return x[indices] + + inpta = torch.ones(4, device=device) + inptb = torch.arange(2, device=device) + out1 = f(inpta, inptb) + out2 = functionalize(f)(inpta, inptb) + self.assertEqual(out1, out2) + out = make_fx(functionalize(f))(inpta, inptb) + self.assertExpectedInline( + (out.code), + """\ + + + +def forward(self, x_1, indices_1) -> torch.Tensor: + index = torch.ops.aten.index.Tensor(x_1, [indices_1]); x_1 = indices_1 = None + return index + """, + ) + + # Ensure grad(functionalize(f)) works + def test_functionalize_grad(self, device): + def f(x: torch.Tensor) -> torch.Tensor: + tmp = torch.ones(2, device=device) + y = x + x + z = y.view(4, 2) + y.add_(tmp) + return z.sum() + + inpt1 = torch.ones(4, 2, device=device) + inpt2 = torch.ones(4, 2, device=device) + out1 = grad(f)(inpt1) + out2 = grad(functionalize(f))(inpt2) + self.assertEqual(out1, out2) + self.assertEqual(inpt1, inpt2) + + @unittest.skipIf(IS_FBCODE, "fails in fbcode") + def test_vmap_functionalize_jvp(self, device): + def f(x: torch.Tensor) -> torch.Tensor: + y = x + x + z = y.view(-1) + y.add_(1) + return z + + def jvp_wrapper(x, t): + return jvp( + f, + (x,), + (t,), + ) + + x = torch.randn(2, 3, device=device) + t = torch.randn(2, 3, device=device) + + out1 = vmap(jvp_wrapper)(x, t) + out2 = vmap(functionalize(jvp_wrapper))(x, t) + self.assertEqual(out1, out2) + + # TODO: move this test into test_fake_tensor.py + # once functionalize() can be used in core tests. + def test_functionalize_fake_tensors(self, device): + def f(x: torch.Tensor) -> torch.Tensor: + y = x.detach() + return y + y + + with FakeTensorMode(): + x = torch.ones(2, device=device, requires_grad=True) + functionalize(f)(x) + self.assertEqual(x.size(), (2,)) + + def test_functionalize_fx_simple(self, device): + def f(x: torch.Tensor) -> torch.Tensor: + tmp = torch.ones(2, device=device) + y = x.view(4, 2) + y.add_(tmp) + return x + + # There's a copy_ in the graph, because the input (x) was mutated. + # To preserve semantics, functionalize() needs to propagate the mutation. 
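+        # With remove="mutations_and_views", view ops are also replaced by their
+        # {view}_copy variants, so the expected graph below consists of
+        # aten.view_copy calls plus a trailing aten.copy_ back into x_1.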
+ fn = make_fx(functionalize(f, remove="mutations_and_views")) + out = fn(torch.zeros(4, 2, device=device)) + out = normalize_devices(out) + self.assertExpectedInline( + (out.code), + """\ + + + +def forward(self, x_1) -> torch.Tensor: + ones = torch.ops.aten.ones.default([2], device = 'cpu', pin_memory = False) + view_copy = torch.ops.aten.view_copy.default(x_1, [4, 2]) + add = torch.ops.aten.add.Tensor(view_copy, ones); view_copy = ones = None + view_copy_1 = torch.ops.aten.view_copy.default(add, [4, 2]); add = None + view_copy_2 = torch.ops.aten.view_copy.default(view_copy_1, [4, 2]); view_copy_2 = None + copy_ = torch.ops.aten.copy_.default(x_1, view_copy_1); x_1 = copy_ = None + return view_copy_1 + """, + ) + + def test_functionalize_fx_transpose_simple(self, device): + def f(x: torch.Tensor) -> torch.Tensor: + return x.transpose(1, 0) + + fn = make_fx(functionalize(f, remove="mutations_and_views")) + out = fn(torch.zeros(4, 2, device=device)) + out = normalize_devices(out) + self.assertExpectedInline( + out.code, + """\ + + + +def forward(self, x_1) -> torch.Tensor: + transpose_copy = torch.ops.aten.transpose_copy.int(x_1, 1, 0); x_1 = None + return transpose_copy + """, + ) + + def test_functionalize_fx_out_op(self, device): + def f(inpt: torch.Tensor) -> torch.Tensor: + out = torch.empty((), dtype=torch.float32) + torch.add(inpt, inpt, out=out) + out_view = out.view(4) + out_view.add_(1) + return out + + fn = make_fx(functionalize(f, remove="mutations_and_views")) + out = fn(torch.arange(4, device=device, dtype=torch.float32)) + out = normalize_devices(out) + self.assertExpectedInline( + out.code, + """\ + + + +def forward(self, inpt_1) -> torch.Tensor: + empty = torch.ops.aten.empty.memory_format([], dtype = torch.float32, device = 'cpu', pin_memory = False); empty = None + add = torch.ops.aten.add.Tensor(inpt_1, inpt_1); inpt_1 = None + view_copy = torch.ops.aten.view_copy.default(add, [4]); view_copy = None + view_copy_1 = torch.ops.aten.view_copy.default(add, [4]); add = None + add_1 = torch.ops.aten.add.Tensor(view_copy_1, 1); view_copy_1 = None + view_copy_2 = torch.ops.aten.view_copy.default(add_1, [4]); add_1 = None + view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [4]); view_copy_3 = None + return view_copy_2 + """, + ) + + def test_functionalize_fx_multi_out_op(self, device): + def f(inpt: torch.Tensor) -> torch.Tensor: + mins = torch.empty(4, dtype=torch.float32) + maxs = torch.empty(2, 2, dtype=torch.float32) + maxs_view = maxs.view(4) + inpt_view = inpt.view(2, 4) + torch.aminmax(inpt_view, dim=0, out=(mins, maxs_view)) + return (maxs, mins) + + fn = make_fx(functionalize(f, remove="mutations_and_views")) + out = fn(torch.arange(8, device=device, dtype=torch.float32)) + out = normalize_devices(out) + self.assertExpectedInline( + out.code, + """\ + + + +def forward(self, inpt_1) -> torch.Tensor: + empty = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = 'cpu', pin_memory = False); empty = None + empty_1 = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = 'cpu', pin_memory = False) + view_copy = torch.ops.aten.view_copy.default(empty_1, [4]); empty_1 = view_copy = None + view_copy_1 = torch.ops.aten.view_copy.default(inpt_1, [2, 4]); inpt_1 = None + aminmax = torch.ops.aten.aminmax.default(view_copy_1, dim = 0); view_copy_1 = None + getitem = aminmax[0] + getitem_1 = aminmax[1]; aminmax = None + view_copy_2 = torch.ops.aten.view_copy.default(getitem_1, [2, 2]); getitem_1 = None + view_copy_3 = 
torch.ops.aten.view_copy.default(view_copy_2, [4]); view_copy_3 = None + return (view_copy_2, getitem) + """, + ) + + def test_functionalize_fx_reapply_views_simple(self, device): + def f(x: torch.Tensor) -> torch.Tensor: + tmp = torch.ones(2, device=device) + y = x.view(4, 2) + y.add_(tmp) + return x + + out = make_fx(functionalize(f))(torch.zeros(4, 2, device=device)) + out = normalize_devices(out) + self.assertExpectedInline( + out.code, + """\ + + + +def forward(self, x_1) -> torch.Tensor: + ones = torch.ops.aten.ones.default([2], device = 'cpu', pin_memory = False) + view = torch.ops.aten.view.default(x_1, [4, 2]) + add = torch.ops.aten.add.Tensor(view, ones); view = ones = None + view_1 = torch.ops.aten.view.default(add, [4, 2]); add = None + view_2 = torch.ops.aten.view.default(view_1, [4, 2]); view_2 = None + copy_ = torch.ops.aten.copy_.default(x_1, view_1); x_1 = copy_ = None + return view_1 + """, + ) + + def test_functionalize_nonfunctional_output(self, device): + global_out = torch.ones(2, device=device) + + def f() -> torch.Tensor: + return global_out + + out = make_fx(functionalize(f))() + out = normalize_devices(out) + self.assertExpectedInline( + out.code, + """\ + + + +def forward(self) -> torch.Tensor: + _tensor_constant0 = self._tensor_constant0 + return _tensor_constant0 + """, + ) + + def test_functionalize_optional_tensorlist1(self, device): + def f(a, b) -> torch.Tensor: + # at::index has OptionalTensorList arguments, + # test that here + return a[b] + + a = torch.arange(4).reshape(2, 2) + b = torch.ones(2, dtype=torch.long) + out = make_fx(functionalize(f))(a, b) + out = normalize_devices(out) + self.assertExpectedInline( + out.code, + """\ + + + +def forward(self, a_1, b_1) -> torch.Tensor: + index = torch.ops.aten.index.Tensor(a_1, [b_1]); a_1 = b_1 = None + return index + """, + ) + + @unittest.skipIf(IS_FBCODE, "fails in fbcode") + def test_functionalize_optional_tensorlist2(self, device): + def f(a, b) -> torch.Tensor: + # See https://github.com/pytorch/pytorch/pull/77846 + return torch.ops.aten.index(a, b) + + a = torch.arange(4).reshape(2, 2) + b = torch.ones(2, dtype=torch.long) + out = make_fx(functionalize(f))(a, b) + self.assertExpectedInline( + out.code, + """\ + + + +def forward(self, a_1, b_1) -> torch.Tensor: + unbind = torch.ops.aten.unbind.int(b_1); b_1 = None + getitem = unbind[0] + getitem_1 = unbind[1]; unbind = None + index = torch.ops.aten.index.Tensor(a_1, [getitem, getitem_1]); a_1 = getitem = getitem_1 = None + return index + """, + ) + + def test_resize_program_inputs(self, device): + def f(x): + x.resize_(10) + x.fill_(2) + + fn = make_fx(functionalize(f)) + out = fn(torch.zeros(0, device=device)) + out = normalize_devices(out) + self.assertExpectedInline( + (out.code), + """\ + + + +def forward(self, x_1): + resize = torch.ops.aten.resize.default(x_1, [10]) + fill = torch.ops.aten.fill.Scalar(resize, 2); resize = None + resize_ = torch.ops.aten.resize_.default(x_1, [10]); x_1 = None + copy_ = torch.ops.aten.copy_.default(resize_, fill); resize_ = fill = copy_ = None + return None + """, + ) + + +def construct_sum_pyop(): + class MySum(HigherOrderOperator): + def __init__(self): + super().__init__("mysum") + + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) + + mysum = MySum() + + @mysum.py_impl(torch._C._functorch.TransformType.Vmap) + def mysum_batch_rule(interpreter, x, dim): + if not torch._C._functorch.is_batchedtensor(x): + with interpreter.lower(): + x = x.view_as(x) # unnecessary, just here to test 
the dispatch + return mysum(x, dim) + + bdim = torch._C._functorch.maybe_get_bdim(x) + value = torch._C._functorch.get_unwrapped(x) + + with interpreter.lower(): + value = value.movedim(bdim, 0) + result = mysum(value, dim + 1) + + return torch._C._functorch._add_batch_dim(result, 0, interpreter.level()) + + @mysum.py_impl(torch._C._functorch.TransformType.Grad) + def mysum_grad_rule(interpreter, x, dim): + level = interpreter.level() + + class MySum(torch.autograd.function._SingleLevelFunction): + @staticmethod + def forward(ctx, x, dim): + ctx.x_shape = x.shape + ctx.dim = dim + x = torch._C._functorch._unwrap_for_grad(x, level) + with torch.enable_grad(), interpreter.lower(): + x = x.view_as(x) # unnecessary, just here to test the dispatch + y = mysum(x, dim) + + y = torch._C._functorch._wrap_for_grad(y, level) + return y + + @staticmethod + def backward(ctx, gy): + return gy.unsqueeze(ctx.dim).expand(ctx.x_shape), None + + with enable_single_level_autograd_function(): + return MySum.apply(x, dim) + + @mysum.py_impl(torch._C.DispatchKey.AutogradCPU) + def mysum_autograd_cpu(x, dim): + return torch.sum(x, dim) + + @mysum.py_impl(torch._C.DispatchKey.AutogradCUDA) + def mysum_autograd_cuda(x, dim): + return torch.sum(x, dim) + + return mysum + + +sum_pyop = construct_sum_pyop() + + +@markDynamoStrictTest +class TestHigherOrderOperatorInteraction(TestCase): + def test_basic_sum(self, device): + x = torch.randn(2, 3, 4, device=device) + result = sum_pyop(x, 1) + self.assertEqual(result, torch.sum(x, 1)) + + def test_vmap_sum(self, device): + x = torch.randn(2, 3, 4, device=device) + result = vmap(sum_pyop, (0, None))(x, 0) + self.assertEqual(result, torch.sum(x, 1)) + + result = vmap(vmap(sum_pyop, (0, None)), (0, None))(x, 0) + self.assertEqual(result, torch.sum(x, 2)) + + def test_grad_sum(self, device): + x = torch.randn(3, device=device) + gx = grad(sum_pyop)(x, 0) + self.assertEqual(gx, torch.ones_like(x)) + + def test_grad_grad_sum(self, device): + x = torch.randn(3, requires_grad=True, device=device) + + def f(x): + # higher order grad. 
Requires a non-linearity + return sum_pyop(x.sin(), 0) + + def grad_f_sum(x): + return grad(f)(x).sum() + + ggx = grad(grad_f_sum)(x) + self.assertEqual(ggx, -x.sin()) + + def test_vmap_grad_sum(self, device): + x = torch.randn(2, 3, device=device) + gx = vmap(grad(sum_pyop), (0, None))(x, 0) + self.assertEqual(gx, torch.ones_like(x)) + + def test_no_grad_outside_grad(self, device): + x = torch.randn(3, device=device, requires_grad=True) + with torch.no_grad(): + y = grad(sum_pyop)(x, 0) + self.assertEqual(y, torch.ones_like(x)) + self.assertFalse(y.requires_grad) + + def test_no_grad_inside_grad(self, device): + def f(x): + with torch.no_grad(): + shift = sum_pyop(x**2, 0) + return sum_pyop(x**2, 0) - shift + + x = torch.randn(3, device=device) + y = grad(f)(x) + self.assertEqual(y, 2 * x) + y = grad(lambda x: grad(f)(x).sum())(x) + self.assertEqual(y, torch.full_like(x, 2)) + + x = torch.randn(3, device=device, requires_grad=True) + y = grad(f)(x) + (z,) = torch.autograd.grad(y.sum(), x) + self.assertEqual(z, torch.full_like(x, 2)) + + def test_grad_name_wrapping(self, device): + def my_fn(x): + return x.sum() + + grad_fn = grad(my_fn) + self.assertEqual(grad_fn.__name__, "my_fn") + + def test_functional_call_multiple_dicts(self): + mod = nn.Linear(1, 1) + x = torch.randn((1, 1)) + params = ({"weight": torch.zeros(1, 1)}, {"bias": torch.ones(1)}) + functional_call(mod, params, x) + + +def traceable(f): + f = allow_in_graph(f) + + @wraps(f) + def wrapper(*args, **kwargs): + return f(*args, **kwargs) + + return wrapper + + +@markDynamoStrictTest +class TestCompileTransforms(TestCase): + # torch.compile is not supported on Windows CUDA. + # Triton only supports GPU with SM70 or later. + @expectedFailureIf((IS_WINDOWS and TEST_CUDA) or (TEST_CUDA and not SM70OrLater)) + @unittest.skipIf( + TEST_CUDA_MEM_LEAK_CHECK, + "Leaking memory, see https://github.com/pytorch/pytorch/pull/150059 for example", + ) + def test_compile_vmap_hessian(self, device): + # The model and inputs are a smaller version + # of code at benchmark repo: + # https://github.com/pytorch/benchmark/blob/main/userbenchmark/functorch/vmap_hessian_fc.py + D = 2 + B = 4 + + x = torch.randn(B, D, device=device) + + model = nn.Sequential(nn.Linear(D, D), nn.ReLU()).to(device) + + params_and_buffers = ( + dict(model.named_parameters()), + dict(model.named_buffers()), + ) + + def predict(params_and_buffers, x): + out = torch.func.functional_call(model, params_and_buffers, x) + return out, out + + fn = vmap( + jacfwd(jacrev(predict, argnums=1, has_aux=True), argnums=1, has_aux=True), + in_dims=(None, 0), + ) + + expected = fn(params_and_buffers, x) + + opt_fn = torch.compile(traceable(fn)) + actual = opt_fn(params_and_buffers, x) + self.assertEqual(actual, expected) + + # torch.compile is not supported on Windows + @torch._dynamo.config.patch(suppress_errors=False) + def test_grad_deprecated_api(self, device): + x = torch.randn((), device=device) + y = torch.randn((), device=device) + + def wrapper_fn(x, y): + return functorch.grad(torch.mul)(x, y) + + actual = wrapper_fn(x, y) + expected = torch.compile(wrapper_fn, backend="eager", fullgraph=True)(x, y) + torch.compile(wrapper_fn, backend="eager", fullgraph=True) + self.assertEqual(actual, expected) + + def wrapper_fn(x, y): + return functorch.grad(torch.mul, argnums=(0, 1))(x, y) + + actual = wrapper_fn(x, y) + expected = torch.compile(wrapper_fn, backend="eager", fullgraph=True)(x, y) + self.assertEqual(actual, expected) + + +class TestGradTrackingTensorToList(TestCase): + """Tests 
for tolist() method with GradTrackingTensor (functorch tensors).""" + + def test_tolist_with_grad(self): + """Test to see if tolist works inside grad transformation.""" + + def f(x): + # inside grad, x is a GradTrackingTensor + result = x.tolist() + # tolist should return a python list and not fail + self.assertIsInstance(result, list) + self.assertEqual(result, [1.0, 2.0, 3.0]) + return (x**2).sum() + + x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) + grad_f = torch.func.grad(f) + result = grad_f(x) + self.assertIsInstance(result, torch.Tensor) + # gradients should still be computed correctly + self.assertEqual(result, [2.0, 4.0, 6.0]) + + def test_tolist_nested_grad(self): + """Test `tolist` with nested grad transformations.""" + + def f(x): + def g(y): + # y is gradTrackingTensor(lvl=1) + inner_list = y.tolist() + self.assertIsInstance(inner_list, list) + return (y**2).sum() + + # x is a gradTrackingTensor(lvl=0) + outer_list = x.tolist() + self.assertIsInstance(outer_list, list) + grad_g = torch.func.grad(g) + return grad_g(x).sum() + + x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) + grad_f = torch.func.grad(f) + result = grad_f(x) + # should compute second derivate + self.assertIsInstance(result, torch.Tensor) + # grad_f should return the derivate of g(y) which is (2*x).sum + self.assertEqual( + result, + [ + 2.0, + 2.0, + 2.0, + ], + ) + + def test_tolist_multidimensional_grad(self): + """Test tolist with multi-dimensional tensors in grad.""" + + def f(x): + result = x.tolist() + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + self.assertEqual(result, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + return x.sum() + + x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], requires_grad=True) + grad_f = torch.func.grad(f) + result = grad_f(x) + self.assertIsInstance(result, torch.Tensor) + self.assertEqual( + result, + [ + [ + 1.0, + 1.0, + 1.0, + ], + [1.0, 1.0, 1.0], + ], + ) + + def test_tolist_conj_neg_grad(self): + """Test tolist method with conjugate/negative tensors in grad context.""" + + def f(x): + # test with the conjugate view + x_conj = x.conj() + result_conj = x_conj.tolist() + self.assertIsInstance(result_conj, list) + return (x * x.conj()).real.sum() + + x = torch.tensor([1.0 + 2.0j, 3.0 + 4.0j], requires_grad=True) + grad_f = torch.func.grad(f) + result = grad_f(x) + self.assertIsInstance(result, torch.Tensor) + self.assertEqual(result, [2.0 + 4.0j, 6.0 + 8.0j]) + + +only_for = ("cpu", "cuda", "xpu") +instantiate_device_type_tests( + TestGradTransform, globals(), only_for=only_for, allow_xpu=True +) +instantiate_device_type_tests( + TestVmapOfGrad, + globals(), + only_for=only_for, + allow_xpu=True, +) +instantiate_device_type_tests( + TestJac, + globals(), + only_for=only_for, + allow_xpu=True, +) +instantiate_device_type_tests( + TestJvp, + globals(), + only_for=only_for, + allow_xpu=True, +) +instantiate_device_type_tests( + TestLinearize, + globals(), + only_for=only_for, + allow_xpu=True, +) +instantiate_device_type_tests( + TestVmapJvpInplaceView, + globals(), + only_for=only_for, + allow_xpu=True, +) +instantiate_device_type_tests( + TestHessian, + globals(), + only_for=only_for, + allow_xpu=True, +) +instantiate_device_type_tests( + TestComposability, + globals(), + only_for=only_for, + allow_xpu=True, +) +instantiate_device_type_tests( + TestExamplesCorrectness, + globals(), + only_for=only_for, + allow_xpu=True, +) +instantiate_device_type_tests( + TestHigherOrderOperatorInteraction, + globals(), + only_for=only_for, + 
allow_xpu=True, +) +instantiate_device_type_tests( + TestFunctionalize, + globals(), + only_for=only_for, + allow_xpu=True, +) +instantiate_device_type_tests( + TestAutogradFunction, + globals(), + only_for=only_for, + allow_xpu=True, +) +instantiate_device_type_tests( + TestAutogradFunctionVmapAPI, + globals(), + only_for=only_for, + allow_xpu=True, +) +instantiate_device_type_tests( + TestHelpers, + globals(), + only_for=only_for, + allow_xpu=True, +) +instantiate_parametrized_tests( + TestMakeFunctional, +) +instantiate_device_type_tests( + TestCompileTransforms, + globals(), + only_for=only_for, + allow_xpu=True, +) +instantiate_device_type_tests( + TestGradTrackingTensorToList, globals(), only_for=only_for, allow_xpu=True +) + +if __name__ == "__main__": + run_tests() diff --git a/test/xpu/test_cpp_api_parity_xpu.py b/test/xpu/test_cpp_api_parity_xpu.py new file mode 100644 index 0000000000..da3fdf3060 --- /dev/null +++ b/test/xpu/test_cpp_api_parity_xpu.py @@ -0,0 +1,92 @@ +# Owner(s): ["module: cpp"] + + +import os +import sys + +test_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../../test") +sys.path.extend([test_dir]) +import torch +import torch.testing._internal.common_nn as common_nn +import torch.testing._internal.common_utils as common +from cpp_api_parity import ( + functional_impl_check, + module_impl_check, + sample_functional, + sample_module, +) +from cpp_api_parity.parity_table_parser import parse_parity_tracker_table +from cpp_api_parity.utils import is_torch_nn_functional_test + +# NOTE: turn this on if you want to print source code of all C++ tests (e.g. for debugging purpose) +PRINT_CPP_SOURCE = False + +devices = ["cpu", "cuda", "xpu"] + +PARITY_TABLE_PATH = os.path.join( + os.path.dirname(__file__) + "/../../../../test", + "cpp_api_parity", + "parity-tracker.md", +) + +parity_table = parse_parity_tracker_table(PARITY_TABLE_PATH) + + +@torch.testing._internal.common_utils.markDynamoStrictTest +class TestCppApiParity(common.TestCase): + module_test_params_map = {} + functional_test_params_map = {} + + +expected_test_params_dicts = [] + +for test_params_dicts, test_instance_class in [ + (sample_module.module_tests, common_nn.NewModuleTest), + (sample_functional.functional_tests, common_nn.NewModuleTest), + (common_nn.module_tests, common_nn.NewModuleTest), + (common_nn.get_new_module_tests(), common_nn.NewModuleTest), + (common_nn.criterion_tests, common_nn.CriterionTest), +]: + for test_params_dict in test_params_dicts: + if test_params_dict.get("test_cpp_api_parity", True): + if is_torch_nn_functional_test(test_params_dict): + functional_impl_check.write_test_to_test_class( + TestCppApiParity, + test_params_dict, + test_instance_class, + parity_table, + devices, + ) + else: + module_impl_check.write_test_to_test_class( + TestCppApiParity, + test_params_dict, + test_instance_class, + parity_table, + devices, + ) + expected_test_params_dicts.append(test_params_dict) + +# Assert that all NN module/functional test dicts appear in the parity test +assert len( + [name for name in TestCppApiParity.__dict__ if "test_torch_nn_" in name] +) == len(expected_test_params_dicts) * len(devices) + +# Assert that there exists auto-generated tests for `SampleModule` and `sample_functional`. 
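+# With devices = ["cpu", "cuda", "xpu"], each non-skipped test dict is
+# instantiated once per device, so the counts asserted below come out to
+# 2 * 3 = 6.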
+# 6 == 2 (number of test dicts that are not skipped) * 3 (number of devices)
+print([name for name in TestCppApiParity.__dict__ if "SampleModule" in name])
+assert len([name for name in TestCppApiParity.__dict__ if "SampleModule" in name]) == 6
+# 6 == 2 (number of test dicts that are not skipped) * 3 (number of devices)
+assert (
+    len([name for name in TestCppApiParity.__dict__ if "sample_functional" in name])
+    == 6
+)
+
+module_impl_check.build_cpp_tests(TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE)
+functional_impl_check.build_cpp_tests(
+    TestCppApiParity, print_cpp_source=PRINT_CPP_SOURCE
+)
+
+if __name__ == "__main__":
+    common.TestCase._default_dtype_check_enabled = True
+    common.run_tests()
diff --git a/test/xpu/test_expanded_weights_xpu.py b/test/xpu/test_expanded_weights_xpu.py
new file mode 100644
index 0000000000..1c25de4e54
--- /dev/null
+++ b/test/xpu/test_expanded_weights_xpu.py
@@ -0,0 +1,1171 @@
+# Owner(s): ["module: nn"]
+import unittest
+from dataclasses import dataclass
+from functools import partial
+from itertools import chain, product
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss
+from torch.nn.utils._expanded_weights import ExpandedWeight
+from torch.nn.utils._expanded_weights.expanded_weights_utils import (
+    forward_helper,
+    set_grad_sample_if_exists,
+    standard_kwargs,
+    sum_over_all_but_batch_and_last_n,
+    unpack_expanded_weight_or_tensor,
+)
+from torch.nn.utils._per_sample_grad import call_for_per_sample_grads
+from torch.testing._internal.common_cuda import TEST_CUDA, tf32_off
+from torch.testing._internal.common_device_type import (
+    instantiate_device_type_tests,
+    OpDTypes,
+    ops,
+)
+from torch.testing._internal.common_methods_invocations import op_db, SampleInput
+from torch.testing._internal.common_modules import module_db, modules
+from torch.testing._internal.common_nn import (
+    get_new_module_tests,
+    module_tests,
+    TestBase,
+)
+from torch.testing._internal.common_utils import (
+    freeze_rng_state,
+    make_tensor,
+    parametrize,
+    run_tests,
+    skipIfTorchDynamo,
+    TEST_XPU,
+    TestCase,
+)
+from torch.utils._pytree import tree_map_only
+
+
+class TestContext:
+    pass
+
+
+class TestExpandedWeightHelperFunction(TestCase):
+    def test_forward_helper(self, device):
+        input = torch.randn(3, 4, device=device)
+        weight = torch.randn(5, 4, device=device)
+        bias = torch.randn(5, device=device)
+        for weight_batched, bias_batched in product([True, False], [True, False]):
+            maybe_batched_weight = weight
+            maybe_batched_bias = bias
+            if weight_batched:
+                maybe_batched_weight = ExpandedWeight(
+                    weight.clone().requires_grad_(), 3, loss_reduction="sum"
+                )
+            if bias_batched:
+                maybe_batched_bias = ExpandedWeight(
+                    bias.clone().requires_grad_(), 3, loss_reduction="sum"
+                )
+            args = (input, maybe_batched_weight, maybe_batched_bias)
+            expanded_args, expanded_kwargs = standard_kwargs(("bias",), args)
+            res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs)
+            expected = nn.functional.linear(input, weight, bias)
+            self.assertEqual(res, expected)
+
+            self.assertEqual(len(expanded_args), 2)
+            assert expanded_args[0] is args[0]  # avoids property checks in assertEqual
+            assert expanded_args[1] is args[1]  # avoids property checks in assertEqual
+            self.assertEqual(len(expanded_kwargs), 1)
+            assert (
+                expanded_kwargs["bias"] is args[2]
+            )  # avoids property checks in assertEqual
+
+    def test_forward_helper_failure_args(self, device):
+        weight = torch.randn(5, 4,
device=device) + bias = torch.randn(5, device=device) + with self.assertRaisesRegex( + RuntimeError, r"do not support inputs that are also ExpandedWeights." + ): + input = ExpandedWeight( + torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum" + ) + expanded_args, expanded_kwargs = standard_kwargs( + ("bias",), (input, weight, bias) + ) + forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) + with self.assertRaisesRegex( + RuntimeError, r"requires a Tensor as the first input" + ): + expanded_args, expanded_kwargs = standard_kwargs( + ("bias",), (3, weight, bias) + ) + forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) + with self.assertRaisesRegex( + RuntimeError, r"requires a batch dimension but got an input of size 0" + ): + expanded_args, expanded_kwargs = standard_kwargs( + ("bias",), (torch.tensor(3), weight, bias) + ) + forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) + with self.assertRaisesRegex( + RuntimeError, r"0 is not a valid batch size for Expanded Weights" + ): + expanded_args, expanded_kwargs = standard_kwargs( + ("bias",), (torch.randn(0, 1, 2), weight, bias) + ) + forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) + input = torch.randn(3, 4) + for weight_batched, bias_batched in product([True, False], [True, False]): + if not weight_batched and not bias_batched: + continue + maybe_batched_weight = weight + maybe_batched_bias = bias + if weight_batched: + maybe_batched_weight = ExpandedWeight( + weight.clone().requires_grad_(), 4, loss_reduction="sum" + ) + if bias_batched: + maybe_batched_bias = ExpandedWeight( + bias.clone().requires_grad_(), 4, loss_reduction="sum" + ) + with self.assertRaisesRegex( + RuntimeError, + r"Expected ExpandedWeights to have batch size matching input", + ): + expanded_args, expanded_kwargs = standard_kwargs( + ("bias",), (input, maybe_batched_weight, maybe_batched_bias) + ) + forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) + + def test_set_grad_sample_if_exists(self, device): + def test_fn(a): + return grad_sample + + orig_weight = torch.randn(4, device=device, requires_grad=True) + expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum") + grad_sample = torch.randn(3) + set_grad_sample_if_exists(expanded_weight, test_fn) + self.assertTrue(hasattr(orig_weight, "grad_sample")) + self.assertEqual(orig_weight.grad_sample, grad_sample) + + basic_tensor = torch.randn(4, device=device) + set_grad_sample_if_exists(basic_tensor, test_fn) + self.assertFalse(hasattr(basic_tensor, "grad_sample")) + + non_tensor = 3 + set_grad_sample_if_exists(non_tensor, test_fn) + self.assertFalse(hasattr(non_tensor, "grad_sample")) + + def test_set_grad_sample_if_exists_failure(self, device): + def test_fn(a): + return True + + grad_tensor = torch.randn(4, requires_grad=True, device=device) + with self.assertRaisesRegex( + RuntimeError, + r"does not support a mixture of ExpandedWeight parameters and normal Parameters", + ): + set_grad_sample_if_exists(grad_tensor, test_fn) + + def test_unpack_expanded_weight_or_tensor(self, device): + input = torch.randn(3, requires_grad=True, device=device) + self.assertEqual( + input, + unpack_expanded_weight_or_tensor( + ExpandedWeight(input, 3, loss_reduction="sum") + ), + ) + + input.requires_grad_(False) + self.assertEqual(input, unpack_expanded_weight_or_tensor(input)) + self.assertTrue(unpack_expanded_weight_or_tensor(4) is None) + + def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device): + 
input = torch.randn(3, requires_grad=True, device=device) + self.assertTrue( + unpack_expanded_weight_or_tensor( + ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input + ) + ) + + input.requires_grad_(False) + self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input)) + self.assertTrue( + unpack_expanded_weight_or_tensor(4, lambda x: x is input) is None + ) + + def test_unpack_expanded_weight_or_tensor_failure(self, device): + input = torch.randn(3, requires_grad=True, device=device) + with self.assertRaisesRegex( + RuntimeError, + r"does not support a mixture of ExpandedWeight parameters and normal Parameters", + ): + unpack_expanded_weight_or_tensor(input) + + with self.assertRaisesRegex( + RuntimeError, + r"does not support a mixture of ExpandedWeight parameters and normal Parameters", + ): + unpack_expanded_weight_or_tensor(input, lambda x: x is input) + + def test_sum_over_all_but_batch_and_last_n(self, device): + input = torch.randn(1, 2, 3, 4, 5, device=device) + res = sum_over_all_but_batch_and_last_n(input, 2) + expected = input.sum((1, 2)) + self.assertEqual(res, expected) + + res = sum_over_all_but_batch_and_last_n(input, 0) + expected = input.sum((1, 2, 3, 4)) + self.assertEqual(res, expected) + + res = sum_over_all_but_batch_and_last_n(input, 4) + self.assertEqual(res, input) + + +class TestExpandedWeightFunctional(TestCase): + def _compare_ew_and_for_loop_per_sample_grads(self, op, sample_input, reduction): + input = sample_input.input + args = sample_input.args + kwargs = sample_input.kwargs + batch_size = input.shape[0] if len(input.shape) > 1 else 1 + + # get per sample grads with ExpandedWeights objects + loss_reduction = "sum" if reduction == torch.sum else "mean" + (ew_input, ew_args, ew_kwargs) = make_expanded_weight( + sample_input, batch_size, loss_reduction + ) + diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values()) + diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)] + diff_input_list = [ + i.orig_weight if isinstance(i, ExpandedWeight) else i + for i in diff_input_list + ] + if not diff_input_list: + return + result = run_op(op, ew_input, *ew_args, **ew_kwargs) + reduction( + result + ).backward() # grad doesn't work with ExpandedWeight because it calls __torch_function__ + expanded_weight_grad = tuple( + i.grad_sample if hasattr(i, "grad_sample") else i.grad + for i in diff_input_list + ) + + # get per sample grads with for loop + func = partial(run_op, op) + + per_sample_grad = for_loop_per_sample_grad( + batch_size, reduction, input, func, *args, **kwargs + ) + + # check equality + self.assertEqual(len(per_sample_grad), len(expanded_weight_grad)) + if loss_reduction == "mean": + # don't check equality of `input.grad`s since these vanilla tensors won't be scaled + expanded_weight_grad = expanded_weight_grad[1:] + per_sample_grad = per_sample_grad[1:] + for result_grad, expected_grad in zip(expanded_weight_grad, per_sample_grad): + self.assertEqual(result_grad, expected_grad) + + @ops( + filter(lambda op: op.supports_expanded_weight, op_db), + dtypes=OpDTypes.supported, + allowed_dtypes=(torch.double,), + ) + def test_expanded_weight_per_sample_grad_sum(self, device, dtype, op): + sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) + for sample_input in supported_inputs(op, sample_inputs): + if ( + op.name == "nn.functional.embedding" + ): # embedding flips its argument order for autograd tests + sample_input = SampleInput( + sample_input.args[0], + 
args=(sample_input.input,), + kwargs=sample_input.kwargs, + ) + + self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum) + + @ops( + filter(lambda op: op.supports_expanded_weight, op_db), + dtypes=OpDTypes.supported, + allowed_dtypes=(torch.double,), + ) + def test_expanded_weight_per_sample_grad_mean(self, device, dtype, op): + sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) + for sample_input in supported_inputs(op, sample_inputs): + if ( + op.name == "nn.functional.embedding" + ): # embedding flips its argument order for autograd tests + sample_input = SampleInput( + sample_input.args[0], + args=(sample_input.input,), + kwargs=sample_input.kwargs, + ) + + self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean) + + @ops( + filter(lambda op: op.supports_expanded_weight, op_db), + dtypes=OpDTypes.supported, + allowed_dtypes=(torch.double,), + ) + def test_expanded_weights_per_sample_grad_input_no_grad(self, device, dtype, op): + sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) + for sample_input in supported_inputs(op, sample_inputs): + if ( + op.name == "nn.functional.embedding" + ): # embedding flips its argument order for autograd tests + sample_input = SampleInput( + sample_input.args[0], + args=(sample_input.input,), + kwargs=sample_input.kwargs, + ) + sample_input.input.requires_grad_(False) + + self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean) + + @skipIfTorchDynamo("Checking error message doesn't work with dynamo") + @ops( + filter(lambda op: op.supports_expanded_weight, op_db), + dtypes=OpDTypes.supported, + allowed_dtypes=(torch.double,), + ) + def test_unsupported_expand_weights(self, device, dtype, op): + sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) + unsupported_inputs = supported_inputs(op, sample_inputs, supported_inputs=False) + for sample_input in unsupported_inputs: + with self.assertRaisesRegex(RuntimeError, r"Expanded Weights"): + if ( + op.name == "nn.functional.embedding" + ): # embedding flips its argument order for autograd tests + sample_input = SampleInput( + sample_input.args[0], + args=(sample_input.input,), + kwargs=sample_input.kwargs, + ) + input = sample_input.input + + batch_size = input.shape[0] if len(input.shape) > 1 else 1 + + # get per sample grads with ExpandedWeights objects + (ew_input, ew_args, ew_kwargs) = make_expanded_weight( + sample_input, batch_size + ) + result = run_op(op, ew_input, *ew_args, **ew_kwargs) + diff_input_list = ( + (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values()) + ) + diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)] + diff_input_list = [ + i.orig_weight if isinstance(i, ExpandedWeight) else i + for i in diff_input_list + ] + result.sum().backward() # grad doesn't work with ExpandedWeight because it calls __torch_function__ + + @ops( + filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported + ) + def test_expanded_weight_forward(self, device, dtype, op): + sample_inputs = op.sample_inputs(device, dtype) + for sample_input in supported_inputs(op, sample_inputs): + if ( + op.name == "nn.functional.embedding" + ): # embedding flips its argument order for autograd tests + sample_input = SampleInput( + sample_input.args[0].clone(), + args=(sample_input.input.clone(),), + kwargs=sample_input.kwargs, + ) + if ( + "cuda" in device + and "max_norm" in sample_input.kwargs + and "padding_idx" in sample_input.kwargs + ): + self.skipTest( + "embedding is 
non-determinstic in this case, see issue #74679" + ) + batch_size = ( + sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1 + ) + for loss_reduction in ["sum", "mean"]: + (ew_input, ew_args, ew_kwargs) = make_expanded_weight( + sample_input, batch_size, loss_reduction + ) + expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs) + normal_result = run_op( + op, sample_input.input, *sample_input.args, **sample_input.kwargs + ) + self.assertEqual(expanded_weight_result, normal_result) + + def test_expanded_weight_error(self, device): + batch_size = 3 + sample_input = make_tensor( + (batch_size, 4), dtype=torch.float32, device=device, requires_grad=True + ) + sample_weight = make_tensor( + (4), dtype=torch.float32, device=device, requires_grad=True + ) + with self.assertRaisesRegex( + RuntimeError, r"Expanded Weights encountered but cannot handle function" + ): + torch.add( + sample_input, + ExpandedWeight(sample_weight, batch_size, loss_reduction="sum"), + ) + + def _test_embedding_model(self, model, num_embedding, device): + batch_size = 32 + input = torch.randint(0, num_embedding, (batch_size, 5, 5), device=device) + return self._test_model( + partial(model, num_embedding=num_embedding), batch_size, input, device + ) + + def _test_conv_model( + self, + model, + input_size, + num_dim, + device, + loss_reduction="sum", + atol=1e-4, + rtol=5e-5, + ): + batch_size = 32 + input_ending = [input_size] * num_dim + input = torch.randn([batch_size, 3] + input_ending, device=device) + return self._test_model( + partial(model, num_dim=num_dim), + batch_size, + input, + device, + loss_reduction, + atol, + rtol, + ) + + def _test_model( + self, + model, + batch_size, + input, + device, + loss_reduction="sum", + atol=1e-4, + rtol=5e-5, + ): + model = model(10).to(device) + targets = torch.randint(0, 10, (batch_size,), device=device) + criterion = CrossEntropyLoss(reduction=loss_reduction) + result = call_for_per_sample_grads(model, loss_reduction=loss_reduction)(input) + loss = criterion(result, targets) + loss.backward() + result = [] + for weight in model.parameters(): + result.append(weight.grad_sample) + del weight.grad_sample + + expected = [] + for i in range(batch_size): + loss = criterion(model(input[i].unsqueeze(0)), targets[i].unsqueeze(0)) + expected.append( + torch.autograd.grad(loss, model.parameters(), torch.ones_like(loss)) + ) + + expected = [torch.stack(grad) for grad in zip(*expected)] + for res, exp in zip(result, expected): + self.assertEqual(res, exp, atol=atol, rtol=rtol) + + def _compute_tolerances(self, device): + is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability( + 0 + ) == (8, 6) + return (9e-3, 5e-5) if is_cuda_sm86 else (1e-4, 5e-5) + + @tf32_off() + def test_cnn_model_sum(self, device): + def convnet(num_classes, num_dim): + return nn.Sequential( + nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.AvgPool2d(kernel_size=2, stride=2), + nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.AvgPool2d(kernel_size=2, stride=2), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.AvgPool2d(kernel_size=2, stride=2), + nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear(128, num_classes, bias=True), + ) + + atol, rtol = self._compute_tolerances(device) + return self._test_conv_model(convnet, 28, 2, device, atol=atol, rtol=rtol) + + @tf32_off() + def 
test_cnn_model_mean(self, device): + def convnet(num_classes, num_dim): + return nn.Sequential( + nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.AvgPool2d(kernel_size=2, stride=2), + nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.AvgPool2d(kernel_size=2, stride=2), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.AvgPool2d(kernel_size=2, stride=2), + nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear(128, num_classes, bias=True), + ) + + atol, rtol = self._compute_tolerances(device) + return self._test_conv_model( + convnet, 28, 2, device, loss_reduction="mean", atol=atol, rtol=rtol + ) + + @parametrize("num_dim", [1, 2, 3]) + @tf32_off() + def test_instance_norm_model(self, num_dim, device): + def instance_norm_model(num_classes, num_dim): + conv_layer = ( + nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d + ) + norm_layer = ( + nn.InstanceNorm1d + if num_dim == 1 + else nn.InstanceNorm2d + if num_dim == 2 + else nn.InstanceNorm3d + ) + return nn.Sequential( + conv_layer(3, 32, kernel_size=3, stride=1, padding=1), + norm_layer(32, affine=True), + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear(32 * (7**num_dim), num_classes, bias=True), + ) + + atol, rtol = self._compute_tolerances(device) + return self._test_conv_model( + instance_norm_model, 7, num_dim, device, atol=atol, rtol=rtol + ) + + @parametrize("num_dim", [1, 2, 3]) + @tf32_off() + def test_group_norm_model(self, num_dim, device): + def group_norm_model(num_classes, num_dim): + conv_layer = ( + nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d + ) + return nn.Sequential( + conv_layer(3, 32, kernel_size=3, stride=1, padding=1), + nn.GroupNorm(8, 32, affine=True), + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear(32 * (7**num_dim), num_classes, bias=True), + ) + + atol, rtol = self._compute_tolerances(device) + return self._test_conv_model( + group_norm_model, 7, num_dim, device, atol=atol, rtol=rtol + ) + + @parametrize("num_dim", [1, 2, 3]) + @tf32_off() + def test_layer_norm_model(self, num_dim, device): + def layer_norm_model(num_classes, num_dim): + conv_layer = ( + nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d + ) + normalized_shape = [7] * num_dim + return nn.Sequential( + conv_layer(3, 32, kernel_size=3, stride=1, padding=1), + nn.LayerNorm(normalized_shape, elementwise_affine=True), + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear(32 * (7**num_dim), num_classes, bias=True), + ) + + atol, rtol = self._compute_tolerances(device) + return self._test_conv_model( + layer_norm_model, 7, num_dim, device, atol=atol, rtol=rtol + ) + + def test_embedding_model(self, device): + def embedding_model(num_classes, num_embedding): + return nn.Sequential( + nn.Embedding(num_embedding, 15), + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear(375, num_classes, bias=True), + ) + + return self._test_embedding_model(embedding_model, 16, device) + + def test_group_norm_error(self, device): + # group norm has to call native_group_norm. 
This checks that it hits the same errors + # that normal group norm would + + N = 3 + C = 5 + inp = torch.randn(N, C) + with self.assertRaisesRegex( + RuntimeError, r"Expected number of channels in input to be divisible" + ): + F.group_norm(inp, 2) # 5 is not divisible by 2 + + +class TestExpandedWeightModule(TestCase): + def _do_test( + self, + module, + input, + args=None, + kwargs=None, + batch_first=True, + atol=None, + rtol=None, + ): + args = args or () + kwargs = kwargs or {} + + batch_dim = 0 if batch_first else 1 + batch_size = input.shape[batch_dim] + diff_input = input.dtype == torch.float or input.dtype == torch.double + if diff_input: + input.requires_grad_() + + with freeze_rng_state(): + # get per sample grads with ExpandedWeights context manager + actual_res = call_for_per_sample_grads( + module, + batch_size=batch_size, + loss_reduction="sum", + batch_first=batch_first, + )(input, *args, **kwargs).sum() + actual_res.backward() + actual_grads = [] + for param in module.parameters(): + actual_grads.append(param.grad_sample) + del param.grad_sample + if diff_input: + actual_grads.append(input.grad.clone()) + input.grad = torch.zeros_like(input.grad) + + # get per sample grads with a for loop + expected_res = torch.tensor( + 0.0, device=input.device, dtype=actual_res.dtype + ) + expected_grads = [] + for i in range(batch_size): + input_slice = input.narrow(batch_dim, i, 1) + input_slice = input_slice.squeeze(batch_dim) + + # h's batch dim is always the first dim. Must be contiguous for CUDA + sliced_args = tree_map_only( + torch.Tensor, lambda t: t.narrow(1, i, 1).contiguous(), args + ) + diff_params = module.parameters() + if diff_input: + diff_params = chain(diff_params, (input_slice,)) + res = module( + input_slice.unsqueeze(batch_dim).contiguous(), + *sliced_args, + **kwargs, + ).sum() + out_grads = torch.autograd.grad( + res, diff_params, torch.ones_like(res), allow_unused=True + ) + expected_grads.append(out_grads) + expected_res += res + expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)] + if not batch_first: + expected_grads[-1] = expected_grads[-1].transpose(0, 1) + self.assertEqual(actual_res, expected_res, atol=atol, rtol=rtol) + [ + self.assertEqual(actual, expected, atol=atol, rtol=rtol) + for (actual, expected) in zip(actual_grads, expected_grads) + ] + + def _do_test_multi_input(self, module, input): + class TestModule(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, input): + return self.module(input) + self.module(input) + + batch_size = input.shape[0] + diff_input = input.dtype == torch.float or input.dtype == torch.double + if diff_input: + input.requires_grad_() + with freeze_rng_state(): + # get per sample grads with ExpandedWeights context manager, calling .backward() twice + test_module = TestModule(module) + actual_res = call_for_per_sample_grads(test_module, loss_reduction="sum")( + input + ).sum() + actual_res.backward() + actual_grads = [] + for param in module.parameters(): + actual_grads.append(param.grad_sample) + del param.grad_sample + if diff_input: + actual_grads.append(input.grad.clone()) + input.grad = torch.zeros_like(input.grad) + + # get per sample grads with a for loop, running over the input twice + expected_grads = [] + for i in range(batch_size): + input_slice = input[i] + diff_params = module.parameters() + if diff_input: + diff_params = chain(diff_params, (input_slice,)) + res = module(input_slice.unsqueeze(0)).sum() + out_grads = torch.autograd.grad( + 
res, diff_params, torch.ones_like(res), allow_unused=True + ) + expected_grads.append(out_grads) + expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads)) + expected_grads = tuple( + expected_grad + for expected_grad in expected_grads + if expected_grad is not None + ) + assert [ + self.assertEqual(actual, 2 * expected) + for (actual, expected) in zip(actual_grads, expected_grads) + ] + + def _do_test_rnn_packed_sequence( + self, module, input, args=None, kwargs=None, atol=None, rtol=None + ): + args = args if args is not None else () + kwargs = kwargs if kwargs is not None else {} + + batch_size = max(tuple(input.batch_sizes)).item() + + with freeze_rng_state(): + # get per sample grads with ExpandedWeights context manager + actual_res = call_for_per_sample_grads( + module, batch_size=batch_size, loss_reduction="sum" + )(input, *args, **kwargs).data.sum() + actual_res.backward() + actual_grads = [] + for param in module.parameters(): + self.assertEqual(param.grad_sample.shape[0], batch_size) + actual_grads.append(param.grad_sample) + del param.grad_sample + + input.data.grad = torch.zeros_like(input.data) + + # compute the per sample grads with a for loop + expected_res = torch.zeros_like(actual_res) + expected_grads = [] + padded_input, seq_sizes = torch.nn.utils.rnn.pad_packed_sequence( + input, batch_first=True + ) + for i in range(len(seq_sizes)): + input_slice = padded_input[i].narrow(0, 0, seq_sizes[i]) + diff_params = module.parameters() + batch_dim = 0 if module.m.batch_first else 1 + res = module(input_slice.unsqueeze(batch_dim), *args, **kwargs).sum() + expected_res += res + out_grads = torch.autograd.grad( + res, diff_params, torch.ones_like(res), allow_unused=True + ) + expected_grads.append(out_grads) + + expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)] + self.assertEqual(actual_res, expected_res, atol=atol, rtol=rtol) + [ + self.assertEqual(actual, expected, atol=atol, rtol=rtol) + for (actual, expected) in zip(actual_grads, expected_grads) + ] + + @modules( + filter( + lambda m_info: m_info.module_cls + in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), + module_db, + ) + ) + @tf32_off() + def test_module(self, device, dtype, module_info, training): + class RNNWrapper(torch.nn.Module): + def __init__(self, m_cons, args, kwargs): + super().__init__() + self.m = m_cons(*args, **kwargs) + + def forward(self, *inps): + ret = self.m(*inps) + assert isinstance(ret, tuple) + return ret[0] + + def batch_hidden(h): + new_h_shape = [1] * (len(h.shape) + 1) + new_h_shape[1] = 2 + return h.unsqueeze(1).repeat(new_h_shape) + + module_cls = module_info.module_cls + atol, rtol = (1e-3, 1e-4) if dtype == torch.float32 else (None, None) + module_inputs = module_info.module_inputs_func( + module_info, + device=device, + dtype=dtype, + requires_grad=True, + training=training, + with_packed_sequence=True, + ) + for module_input in module_inputs: + if module_input.forward_input is None: + continue + args, kwargs = ( + module_input.constructor_input.args, + module_input.constructor_input.kwargs, + ) + m = RNNWrapper(module_cls, args, kwargs) + batch_first = m.m.batch_first + m.to(device).to(dtype) + + args, kwargs = ( + module_input.forward_input.args, + module_input.forward_input.kwargs, + ) + + # if the RNN tests use unbatched inputs--batch the inputs + input = args[0] + if isinstance(input, torch.Tensor) and input.dim() == 2: + input = input.detach() + new_input_shape = [1] * (len(input.shape) + 1) + if batch_first: + new_input_shape[0] = 2 + input = 
input.repeat(new_input_shape) + else: + new_input_shape[1] = 2 + input = input.unsqueeze(1).repeat(new_input_shape) + + h = args[1] if len(args) > 1 else None + if h is not None: + h = ( + batch_hidden(h) + if isinstance(h, torch.Tensor) + else tuple(batch_hidden(hx) for hx in h) + ) + args = list(args) + args[1] = h + + if isinstance(input, torch.nn.utils.rnn.PackedSequence): + self._do_test_rnn_packed_sequence( + m, input, args[1:], kwargs, atol=atol, rtol=rtol + ) + else: + self._do_test( + m, + input, + args[1:], + kwargs, + batch_first=batch_first, + atol=atol, + rtol=rtol, + ) + + def test_per_sample_api_failing(self): + module = nn.Linear(10, 10) + input = torch.randn(64, 10) + with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"): + call_for_per_sample_grads("fail")(input) + with self.assertRaisesRegex( + RuntimeError, r"Batch size passed must be None or an integer" + ): + call_for_per_sample_grads(module, batch_size=6.4)(input) + with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"): + call_for_per_sample_grads(module, batch_size=-64)(input) + with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"): + loss = call_for_per_sample_grads(module)(input).sum() + loss.backward() # populate grad_sample fields + call_for_per_sample_grads(module)(input) + + module = nn.Linear(10, 10) # reset to not have grad_sample fields + with self.assertRaisesRegex( + RuntimeError, r"Expected loss_reduction argument to be sum or mean" + ): + call_for_per_sample_grads(module, loss_reduction="")(input) + + def test_per_sample_api_compute_batch_size(self): + class CustomModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(5, 5) + + def forward(self, input1, input2): + return self.linear(input1) + self.linear(input2) + + module = CustomModule() + input1 = torch.randn(4, 5) + input2 = torch.randn(5, 5) + + with self.assertRaisesRegex( + RuntimeError, + "found at least one input with batch size 4 and one with batch size 5", + ): + call_for_per_sample_grads(module)(input1, input2) + + input2 = torch.randn(4, 5) + call_for_per_sample_grads(module)(input1, input2) + + module = CustomModule() + call_for_per_sample_grads(module)(input1, input2=input2) + + module = CustomModule() + call_for_per_sample_grads(module)(input1=input1, input2=input2) + + def test_per_sample_api_compute_batch_size_not_pytreeable(self): + @dataclass + class NonPytreeableTuple: + elem1: torch.Tensor + elem2: torch.Tensor + + class CustomModule(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(5, 5) + + def forward(self, input1, input2): + return self.linear(input1.elem1) + self.linear(input1.elem2) + + input = NonPytreeableTuple(torch.randn(4, 5), torch.randn(4, 5)) + model = CustomModule() + with self.assertRaisesRegex( + RuntimeError, + "ExpandedWeights cannot compute the batch size from the inputs", + ): + call_for_per_sample_grads(model)(input, "") + + # would prefer for it to error because input is not pytree-able but that's hard to detect + with self.assertRaisesRegex( + RuntimeError, "Expected ExpandedWeights to have batch size matching input" + ): + call_for_per_sample_grads(model)(input, torch.randn(5)) + + model = CustomModule() # TODO: functional call bug, sam will fix + call_for_per_sample_grads(model)(input, torch.randn(4, 5)) + model = CustomModule() + call_for_per_sample_grads(model, batch_size=4)(input, torch.randn(5)) + + +class ContextManagerTests(TestBase): + def __init__(self, 
*args, **kwargs): + self.test_cpu = kwargs.get("test_cpu", True) + self.test_cuda = kwargs.get("test_cuda", True) + self.test_xpu = kwargs.get("test_xpu", False) + super().__init__(*args, **kwargs) + + @property + def constructor_args(self): + return self._get_arg("constructor_args", False) + + def test_context_manager(self, test_case, device): + kwargs = {"device": device, "dtype": torch.double} + module = self.constructor(*self.constructor_args).to(**kwargs) + if "Embedding" in self.get_name(): + kwargs["dtype"] = torch.long + input = self._get_input().to(**kwargs) + if len(input.shape) == 0 or input.shape[0] == 0: + raise unittest.SkipTest( + "Can't get per sample gradients when no batch dim or batch dim is 0" + ) + if self.constructor == torch.nn.Linear and len(input.shape) == 1: + raise unittest.SkipTest( + "Can't get per sample gradients for input of rank 1" + ) + test_case._do_test(module, input) + + def test_context_manager_multiple_inputs(self, test_case, device): + module = self.constructor(*self.constructor_args).to(device) + input = self._get_input() + if len(input.shape) == 0 or input.shape[0] == 0: + raise unittest.SkipTest( + "Can't get per sample gradients when no batch dim or batch dim is 0" + ) + if self.constructor == torch.nn.Linear and len(input.shape) == 1: + raise unittest.SkipTest( + "Can't get per sample gradients for input of rank 1" + ) + test_case._do_test_multi_input(module, input) + + +def filter_supported_tests(t): + supported_modules = [ + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "Embedding", + "LayerNorm", + "GroupNorm", + "InstanceNorm", + ] + if "module_name" in t and t["module_name"] in supported_modules: + return True + + +# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests +# These currently use the legacy nn tests +supported_tests = [ + t for t in module_tests + get_new_module_tests() if filter_supported_tests(t) +] +for test_param in supported_tests: + if "constructor" not in test_param: + name = test_param.pop("module_name") + test_param["constructor"] = getattr(nn, name) + decorator = test_param.pop("decorator", lambda test: test) + test = ContextManagerTests(**test_param) + test_name = test.get_name() + if hasattr(TestExpandedWeightModule, test_name): + raise RuntimeError("Found two tests with the same name: " + test_name) + test_name_multi_input = test.get_name() + "_multiple_inputs" + if hasattr(TestExpandedWeightModule, test_name_multi_input): + raise RuntimeError("Found two tests with the same name: " + test_name) + if test.test_cpu: + setattr( + TestExpandedWeightModule, + test_name, + decorator(lambda self, test=test: test.test_context_manager(self, "cpu")), + ) + setattr( + TestExpandedWeightModule, + test_name_multi_input, + decorator( + lambda self, test=test: test.test_context_manager_multiple_inputs( + self, "cpu" + ) + ), + ) + if TEST_CUDA and test.test_cuda: + # since this checks derivatives, only use double for precision + setattr( + TestExpandedWeightModule, + test_name + "_cuda_double", + decorator(lambda self, test=test: test.test_context_manager(self, "cuda")), + ) + if TEST_XPU and test.test_xpu: + # since this checks derivatives, only use double for precision + setattr( + TestExpandedWeightModule, + test_name + "_xpu_double", + decorator(lambda self, test=test: test.test_context_manager(self, "xpu")), + ) + +# ------------- HELPER FUNCTIONS ----------------- + + +def run_op(op, input, *args, **kwargs): + r""" + OpInfo for Embedding switches the input and weight so autograd tests will only check the 
derivative + of the weight, not the input, which can't be differentiable since its dtype is int. Calls op, + using the special ordering that Embedding's OpInfo expects for that case. + """ + if op.name == "nn.functional.embedding": + return op(args[0], input, **kwargs) + else: + return op(input, *args, **kwargs) + + +def make_expanded_weight(sample_input, batch_size, loss_reduction="sum"): + def expanded_weight_or_clone(arg): + if is_diff_tensor(arg): + return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction) + return clone_if_tensor(arg) + + ew_input = clone_if_tensor(sample_input.input) + ew_args = tuple(expanded_weight_or_clone(arg) for arg in sample_input.args) + ew_kwargs = { + name: expanded_weight_or_clone(arg) + for (name, arg) in sample_input.kwargs.items() + } + return ew_input, ew_args, ew_kwargs + + +def supported_inputs(op, sample_inputs, supported_inputs=True): + r""" + ExpandedWeights currently does not support some use cases when there's no batch dimension or + operations that would cause inter-batch operations. Removes all of the cases it cannot deal with + """ + + def filter_fn(input): + convolutions = [ + "nn.functional.conv1d", + "nn.functional.conv2d", + "nn.functional.conv3d", + ] + batched_input_size = dict(zip(convolutions, [3, 4, 5])) + if op.name == "nn.functional.linear": + is_supported_input = ( + input.input.dim() > 1 + ) # input of rank 1 means no batch dim + elif op.name == "nn.functional.layer_norm": + normalized_shape = input.args[0] + is_supported_input = ( + input.input.shape != normalized_shape + ) # would cause inter-batch operations + elif op.name in convolutions: + # currently can't deal with padding computation on Python level + is_supported_input = input.input.dim() == batched_input_size[op.name] + elif op.name == "nn.functional.embedding": + idx = input.args[0] + is_supported_input = len(idx.shape) > 1 # there's no batch size + else: + is_supported_input = True + is_supported_input = ( + is_supported_input and input.input.shape[0] > 0 + ) # 0 is not a valid batch size + return is_supported_input if supported_inputs else not is_supported_input + + return [input for input in sample_inputs if filter_fn(input)] + + +def for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs): + # get per sample grads by getting derivative for each input in a for loop + per_sample_grad = [] + for i in range(batch_size): + per_sample_input = input[i] + result = reduction(func(per_sample_input.unsqueeze(0), *args, **kwargs)) + diff_input_list = (per_sample_input,) + tuple(args) + tuple(kwargs.values()) + diff_input_list = [ + i + for i in diff_input_list + if isinstance(i, torch.Tensor) and i.requires_grad + ] + per_sample_grad.append( + torch.autograd.grad( + result, diff_input_list, torch.ones_like(result), allow_unused=True + ) + ) + if len(per_sample_grad) == batch_size: + per_sample_grad = tuple(torch.stack(grad) for grad in zip(*per_sample_grad)) + return per_sample_grad + + +def is_diff_tensor(t): + return isinstance(t, ExpandedWeight) or ( + isinstance(t, torch.Tensor) and t.requires_grad + ) + + +def clone_if_tensor(t): + if isinstance(t, torch.Tensor): + res = torch.clone(t).detach() + res.requires_grad_(t.requires_grad) + return res + else: + return t + + +instantiate_device_type_tests( + TestExpandedWeightHelperFunction, globals(), allow_xpu=True +) +instantiate_device_type_tests(TestExpandedWeightFunctional, globals(), allow_xpu=True) +instantiate_device_type_tests(TestExpandedWeightModule, globals(), allow_xpu=True) +if 
__name__ == "__main__": + run_tests() diff --git a/test/xpu/test_matmul_cuda_xpu.py b/test/xpu/test_matmul_cuda_xpu.py index 0a44d69207..747d0fe37d 100644 --- a/test/xpu/test_matmul_cuda_xpu.py +++ b/test/xpu/test_matmul_cuda_xpu.py @@ -1,397 +1,1295 @@ -# Owner(s): ["module: intel"] +# Owner(s): ["module: linear algebra"] -import re +import contextlib +import time import unittest +from collections.abc import Callable from functools import partial +from itertools import product import torch +import torch.nn.functional as F +from torch._inductor.test_case import TestCase as InductorTestCase +from torch.quantization._quantized_conversions import ( + pack_int4_to_int8, + quantized_weight_reorder_for_mixed_dtypes_linear_cutlass, +) from torch.testing import make_tensor +from torch.testing._internal.common_cuda import ( + _get_torch_cuda_version, + PLATFORM_SUPPORTS_BF16, + PLATFORM_SUPPORTS_GREEN_CONTEXT, + SM100OrLater, + SM53OrLater, + SM80OrLater, + SM90OrLater, +) from torch.testing._internal.common_device_type import ( dtypes, instantiate_device_type_tests, + onlyCUDA, + onlyOn, tol as xtol, toleranceOverride, ) from torch.testing._internal.common_utils import ( + decorateIf, + getRocmVersion, + IS_JETSON, IS_WINDOWS, + isRocmArchAnyOf, + MI200_ARCH, + NAVI_ARCH, parametrize, run_tests, + runOnRocmArch, + serialTest, + skipIfRocm, + TEST_CUDA, + TEST_WITH_ROCM, + TEST_XPU, TestCase, ) +from torch.testing._internal.inductor_utils import IS_BIG_GPU -try: - from xpu_test_utils import XPUPatchForImport -except Exception as e: - from .xpu_test_utils import XPUPatchForImport - +_IS_SM8X = False +if TEST_CUDA: + _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8 -def get_device_capability(device=None): - return (9, 0) +# Protects against includes accidentally setting the default dtype +assert torch.get_default_dtype() is torch.float32 +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" -torch.cuda.get_device_capability = get_device_capability -torch.testing._internal.common_cuda.SM90OrLater = True -with XPUPatchForImport(False): - from test_matmul_cuda import ( - e4m3_type, - e5m2_type, - f8_msg, - mm_float8, - mm_float8_emulated, - tensor_to_scale, - TestFP8MatmulCuda, - TestMatmulCuda, - TestMixedDtypesLinearCuda, - to_fp8_saturated, +def xfailIfSM100OrLaterNonRTXAndCondition(condition_fn): + """ + Conditionally xfail tests on SM100+ datacenter SKUs based on a condition function. + The condition function receives the test parameters dict and returns True to xfail. + """ + computeCapabilityCheck = ( + SM100OrLater and torch.cuda.get_device_capability()[0] != 12 + ) + return decorateIf( + unittest.expectedFailure, + lambda params: computeCapabilityCheck and condition_fn(params), ) -def cublas_addmm( - self, - size: int, - dtype: torch.dtype, - reduced_precision: bool = False, - fp16_accumulate: bool = False, -): - # - # Check for catastrophic cuBLAS inaccuracy by measuring the deviation between - # results from the CUDA invocation of torch.addmm and the CPU invocation - # (which does not use CUDA backend). 
- # - # Get dims - n, m, p = (size + 1, size, size + 2) - # Disable reduced precision reductions in BFloat16 to bypass some kernels - # which fail the threshold check - orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction - orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction - orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation - torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = ( - reduced_precision +@contextlib.contextmanager +def blas_library_context(backend): + prev_backend = torch.backends.cuda.preferred_blas_library() + torch.backends.cuda.preferred_blas_library(backend) + try: + yield + finally: + torch.backends.cuda.preferred_blas_library(prev_backend) + + +class TestMatmulCuda(InductorTestCase): + def setUp(self): + super().setUp() + torch.backends.cuda.matmul.allow_tf32 = False + + def tearDown(self): + torch.backends.cuda.matmul.allow_tf32 = True + super().tearDown() + + def cublas_addmm( + self, + size: int, + dtype: torch.dtype, + reduced_precision: bool = False, + fp16_accumulate: bool = False, + bias_shape_modifier: Callable | None = None, + ): + # + # Check for catastrophic cuBLAS inaccuracy by measuring the deviation between + # results from the CUDA invocation of torch.addmm and the CPU invocation + # (which does not use CUDA backend). + # + # Get dims + m, k, n = (size + 1, size, size + 2) + # Disable reduced precision reductions in BFloat16 to bypass some kernels + # which fail the threshold check + orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction + orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction + orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = ( + reduced_precision + ) + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = ( + reduced_precision + ) + torch.backends.cuda.matmul.allow_fp16_accumulation = fp16_accumulate + # Make random tensors on CPU (seed set on common_utils.py import) + # (Not using numpy because it does not support bfloat16) + make_arg = partial(make_tensor, dtype=dtype, device="cpu") + + bias_shape_modifier = ( + (lambda shape: shape) + if bias_shape_modifier is None + else bias_shape_modifier + ) + m_input = make_arg(bias_shape_modifier((m, n))) + m_1 = make_arg((m, k)) + m_2 = make_arg((k, n)) + m_beta = make_arg(1) + # scale to abate overflows in fp16 accum + if fp16_accumulate: + m_1 = m_1 / 100 + m_2 = m_2 / 100 + # *(B)FLOAT16 Special Handling* + # Backend does not tensorize float16 on CPU, + # and bloat16 may present accuracy issues, + # so convert to float32 for these cases + # (but keep same for other types, e.g. 
float32 and int*) + if dtype == torch.float16 or dtype == torch.bfloat16: + m_beta = m_beta.to(dtype=torch.float32) + m_input = m_input.to(dtype=torch.float32) + m_1 = m_1.to(dtype=torch.float32) + m_2 = m_2.to(dtype=torch.float32) + # Get CPU result + res_cpu = torch.addmm(m_input, m_1, m_2, beta=m_beta.item()) + # *(B)FLOAT16 Special Handling*`` + # Convert back to (b)float16 + if dtype == torch.float16 or dtype == torch.bfloat16: + m_beta = m_beta.to(dtype=dtype) + m_input = m_input.to(dtype=dtype) + m_1 = m_1.to(dtype=dtype) + m_2 = m_2.to(dtype=dtype) + res_cpu = res_cpu.to(dtype=dtype) + # Move arg tensors to CUDA + m_beta = m_beta.to("cuda") + m_input = m_input.to("cuda") + m_1 = m_1.to("cuda") + m_2 = m_2.to("cuda") + # Get CUDA result + res_cuda = torch.addmm(m_input, m_1, m_2, beta=m_beta.item()) + # Move to CPU for comparison + res_cuda = res_cuda.to("cpu") + # Compare + self.assertEqual(res_cpu, res_cuda) + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16 + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16 + torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate + + @onlyCUDA + # imported 'tol' as 'xtol' to avoid aliasing in code above + @toleranceOverride( + { + torch.float16: xtol(atol=1e-1, rtol=1e-1), + torch.bfloat16: xtol(atol=1e-1, rtol=1e-1), + torch.float32: xtol(atol=1e-1, rtol=1e-1), + } ) - torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = ( - reduced_precision + @dtypes(torch.float16, torch.bfloat16, torch.float32) + @parametrize("size", [100, 1000, 10000]) + @parametrize("backend", ["cublas", "cublaslt"]) + def test_cublas_addmm(self, size: int, dtype: torch.dtype, backend): + with blas_library_context(backend): + if ( + TEST_WITH_ROCM + and backend == "cublas" + and isRocmArchAnyOf(NAVI_ARCH) + and getRocmVersion() < (6, 4) + and dtype == torch.float16 + and size >= 10000 + ): + self.skipTest( + f"failed on Navi for ROCm6.3 due to hipblas backend, dtype={dtype} and size={size}" + ) + self.cublas_addmm(size, dtype, False) + + @onlyCUDA + @xfailIfSM100OrLaterNonRTXAndCondition( + lambda params: params.get("dtype") == torch.bfloat16 + and params.get("size") == 10000 ) - torch.backends.cuda.matmul.allow_fp16_accumulation = fp16_accumulate - # Make random tensors on CPU (seed set on common_utils.py import) - # (Not using numpy because it does not support bfloat16) - make_arg = partial(make_tensor, dtype=dtype, device="cpu") - m_beta = make_arg(1) - m_input = make_arg((n, p)) - m_1 = make_arg((n, m)) - m_2 = make_arg((m, p)) - # scale to abate overflows in fp16 accum - if fp16_accumulate: - m_1 = m_1 / 100 - m_2 = m_2 / 100 - # *(B)FLOAT16 Special Handling* - # Backend does not tensorize float16 on CPU, - # and bloat16 may present accuracy issues, - # so convert to float32 for these cases - # (but keep same for other types, e.g. 
float32 and int*) - if dtype == torch.float16 or dtype == torch.bfloat16: - m_beta = m_beta.to(dtype=torch.float32) - m_input = m_input.to(dtype=torch.float32) - m_1 = m_1.to(dtype=torch.float32) - m_2 = m_2.to(dtype=torch.float32) - # Get CPU result - res_cpu = torch.addmm(m_input, m_1, m_2, beta=m_beta.item()) - # *(B)FLOAT16 Special Handling*`` - # Convert back to (b)float16 - if dtype == torch.float16 or dtype == torch.bfloat16: - m_beta = m_beta.to(dtype=dtype) - m_input = m_input.to(dtype=dtype) - m_1 = m_1.to(dtype=dtype) - m_2 = m_2.to(dtype=dtype) - res_cpu = res_cpu.to(dtype=dtype) - # Move arg tensors to CUDA - m_beta = m_beta.to("xpu") - m_input = m_input.to("xpu") - m_1 = m_1.to("xpu") - m_2 = m_2.to("xpu") - # Get CUDA result - res_cuda = torch.addmm(m_input, m_1, m_2, beta=m_beta.item()) - # Move to CPU for comparison - res_cuda = res_cuda.to("cpu") - # Compare - self.assertEqual(res_cpu, res_cuda) - torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16 - torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16 - torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate - - -@toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)}) -@dtypes(torch.float16) -def cublas_addmm_alignment(self, dtype): - device = "xpu" - # perturb X, A, or B alignment - for idx in range(0, 3): - for offset in range(1, 3): - offsets = [0, 0, 0] - offsets[idx] = offset - x_offset, a_offset, b_offset = offsets - A = torch.rand( - (5120 * 2560 + a_offset), requires_grad=True, dtype=dtype, device=device - ) - A = A[a_offset:].reshape(5120, 2560) - X = torch.rand( - (26 * 2560 + x_offset), requires_grad=True, dtype=dtype, device=device + # imported 'tol' as 'xtol' to avoid aliasing in code above + @toleranceOverride( + { + torch.float16: xtol(atol=7e-1, rtol=2e-1), + torch.bfloat16: xtol(atol=1e1, rtol=2e-1), + } + ) + @dtypes(torch.float16, torch.bfloat16) + @parametrize("size", [100, 1000, 10000]) + @parametrize("backend", ["cublas", "cublaslt"]) + def test_cublas_addmm_reduced_precision( + self, size: int, dtype: torch.dtype, backend + ): + with blas_library_context(backend): + self.cublas_addmm(size, dtype, True) + + @onlyCUDA + # imported 'tol' as 'xtol' to avoid aliasing in code above + @toleranceOverride( + { + torch.float16: xtol(atol=1e-3, rtol=1e-4), + torch.bfloat16: xtol(atol=1e-3, rtol=1e-4), + torch.float32: xtol(atol=1e-3, rtol=1e-4), + } + ) + @dtypes(torch.bfloat16, torch.float16, torch.float32) + @parametrize("size", [128]) + @parametrize("backend", ["cublas", "cublaslt"]) + def test_cublas_addmm_bias_shapes(self, size: int, dtype: torch.dtype, backend): + with blas_library_context(backend): + # 2D bias + self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: shape) + # 1D bias which is row-broadcast to 2D + self.cublas_addmm( + size, dtype, bias_shape_modifier=lambda shape: (1, shape[-1]) ) - X = X[x_offset:].reshape(26, 1, 2560) - B = torch.rand( - (5120 + b_offset), requires_grad=True, dtype=dtype, device=device + # 1D bias which row-broadcasts + self.cublas_addmm( + size, dtype, bias_shape_modifier=lambda shape: (shape[-1],) ) - B = B[b_offset:].reshape(5120) - out = torch.nn.functional.linear(X, A, B) - self.assertEqual(out, torch.matmul(X, A.transpose(1, 0)) + B) - - -TestMatmulCuda.cublas_addmm = cublas_addmm -TestMatmulCuda.test_cublas_addmm_alignment = cublas_addmm_alignment - - -@parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32]) -def _test_scaled_mm_vs_emulated(self, 
base_dtype): - torch.manual_seed(42) - input_dtype = e4m3_type - output_dtype = base_dtype - compare_type = torch.float32 - - x = torch.randn(16, 16, device="xpu", dtype=base_dtype) - y = torch.randn(32, 16, device="xpu", dtype=base_dtype).t() - - x_scale = tensor_to_scale(x, input_dtype).float() - y_scale = tensor_to_scale(y, input_dtype).float() - - x_fp8 = to_fp8_saturated(x * x_scale, input_dtype) - y_fp8 = to_fp8_saturated(y * y_scale, input_dtype) - - # Calculate actual F8 mm - out_scaled_mm = mm_float8( - x_fp8, y_fp8, a_scale=x_scale, b_scale=y_scale, output_dtype=output_dtype - ) - - # Calculate emulated F8 mm - out_emulated = mm_float8_emulated(x_fp8, x_scale, y_fp8, y_scale, output_dtype) - - if output_dtype != base_dtype: - out_scaled_mm = out_scaled_mm.to(compare_type) - out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype) - out_emulated = out_emulated.to(compare_type) - out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype) - - if base_dtype in {torch.bfloat16, torch.float16}: - atol, rtol = 7e-2, 7e-2 - else: - atol, rtol = 3e-3, 3e-3 - - torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - - -TestFP8MatmulCuda.test_scaled_mm_vs_emulated = _test_scaled_mm_vs_emulated - - -@parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32]) -def _test_scaled_mm_change_stride(self, base_dtype): - torch.manual_seed(42) - input_dtype = e4m3_type - output_dtype = base_dtype - compare_type = torch.float32 - - x = torch.empty_strided((16, 16), (16, 1), device="xpu", dtype=base_dtype) - y = torch.empty_strided((16, 32), (1, 64), device="xpu", dtype=base_dtype) - - x.normal_() - y.normal_() - - x_scale = tensor_to_scale(x, input_dtype).float() - y_scale = tensor_to_scale(y, input_dtype).float() - - x_fp8 = to_fp8_saturated(x * x_scale, input_dtype) - y_fp8 = to_fp8_saturated(y * y_scale, input_dtype) + @onlyCUDA + @dtypes(torch.float16) + # m == 4 chooses OUTPUT_TYPE reduction on H200 + # m == 8 chooses OUTPUT_TYPE reduction on A100 + @parametrize("small_size", [4, 8]) + @parametrize("size", [32768]) + @parametrize("backend", ["cublaslt", "cublas"]) + def test_cublas_addmm_no_reduced_precision( + self, small_size: int, size: int, dtype: torch.dtype, backend + ): + with blas_library_context(backend): + torch.backends.cuda.preferred_blas_library(backend) + orig_precision = ( + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction + ) + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + m1 = torch.full((small_size, size), 65504.0, dtype=dtype, device="cuda") + m2 = torch.ones((size, small_size), dtype=dtype, device="cuda") + m2[size // 2 :, :] = -1.0 + b = torch.zeros((small_size,), dtype=dtype, device="cuda") + out = torch.addmm(b, m1, m2, beta=1.0) + self.assertEqual(out.sum().item(), 0.0) + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = ( + orig_precision + ) - # Calculate actual F8 mm - out_scaled_mm = mm_float8( - x_fp8, y_fp8, a_scale=x_scale, b_scale=y_scale, output_dtype=output_dtype + @onlyCUDA + # imported 'tol' as 'xtol' to avoid aliasing in code above + @toleranceOverride( + { + torch.float16: xtol(atol=7e-1, rtol=2e-1), + torch.bfloat16: xtol(atol=1e1, rtol=2e-1), + } ) - - # Calculate emulated F8 mm - out_emulated = mm_float8_emulated(x_fp8, x_scale, y_fp8, y_scale, output_dtype) - - if output_dtype != base_dtype: - out_scaled_mm = out_scaled_mm.to(compare_type) - out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype) - - 
out_emulated = out_emulated.to(compare_type) - out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype) - - if base_dtype in {torch.bfloat16, torch.float16}: - atol, rtol = 7e-2, 7e-2 - else: - atol, rtol = 3e-3, 3e-3 - - torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - - -TestFP8MatmulCuda.test_scaled_mm_change_stride = _test_scaled_mm_change_stride - - -@unittest.skipIf(IS_WINDOWS, f8_msg) -def _test_float8_error_messages(self, device) -> None: - M, K, N = (1024, 512, 2048) - fill_value = 0.5 - x = torch.full((M, K), fill_value, device=device) - y = torch.full((N, K), fill_value, device=device) - - x_fp8 = x.to(e4m3_type) - y_fp8 = y.to(e4m3_type).t() - - with self.assertRaisesRegex( - RuntimeError, - re.escape( - "For RowWise scaling, scale_a should be (1024, 1) and scale_b " - "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)" - ), + @dtypes(torch.float16, torch.bfloat16) + @parametrize("size", [100, 1000, 10000]) + @parametrize("backend", ["cublas", "cublaslt"]) + def test_cublas_addmm_reduced_precision_fp16_accumulate( + self, size: int, dtype: torch.dtype, backend ): - torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=torch.ones((1, 1), device="xpu"), - scale_b=torch.ones((1, 2), device="xpu"), - out_dtype=torch.bfloat16, + with blas_library_context(backend): + self.cublas_addmm(size, dtype, False, True) + + @onlyOn(["cuda", "xpu"]) + def test_cublas_and_lt_reduced_precision_fp16_accumulate(self): + orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation + torch.backends.cuda.matmul.allow_fp16_accumulation = True + x = torch.rand(32, 512, 512, device=device_type, dtype=torch.half) + w = torch.rand(512, 512, device=device_type, dtype=torch.half) + b = torch.rand(512, device=device_type, dtype=torch.half) + out = torch.nn.functional.linear(x, w, b) + out_cpu = torch.nn.functional.linear(x.cpu(), w.cpu(), b.cpu()) + self.assertEqual(out, out_cpu, atol=5e-3, rtol=8e-3) + + a = torch.rand(16, 128, 128, device=device_type, dtype=torch.half) + b = torch.rand(16, 128, 128, device=device_type, dtype=torch.half) + c = torch.rand(16, 128, 128, device=device_type, dtype=torch.half) + out = torch.baddbmm(a, b, c) + out_cpu = torch.baddbmm(a.cpu(), b.cpu(), c.cpu()) + self.assertEqual(out, out_cpu, atol=1e-3, rtol=5e-3) + if device_type == "cuda": + torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate + + @onlyOn(["cuda", "xpu"]) + @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)}) + @dtypes(torch.float16) + def test_cublas_addmm_alignment(self, dtype): + device = device_type + # perturb X, A, or B alignment + for idx in range(3): + for offset in range(1, 3): + offsets = [0, 0, 0] + offsets[idx] = offset + x_offset, a_offset, b_offset = offsets + A = torch.rand( + (5120 * 2560 + a_offset), + requires_grad=True, + dtype=dtype, + device=device, + ) + A = A[a_offset:].reshape(5120, 2560) + X = torch.rand( + (26 * 2560 + x_offset), + requires_grad=True, + dtype=dtype, + device=device, + ) + X = X[x_offset:].reshape(26, 1, 2560) + B = torch.rand( + (5120 + b_offset), requires_grad=True, dtype=dtype, device=device + ) + B = B[b_offset:].reshape(5120) + out = torch.nn.functional.linear(X, A, B) + self.assertEqual(out, torch.matmul(X, A.transpose(1, 0)) + B) + + @onlyOn(["cuda", "xpu"]) + @unittest.skipIf(IS_JETSON, "Too large for Jetson") + @toleranceOverride({torch.float32: xtol(atol=1e-5, rtol=1.1e-5)}) + @dtypes( + *( + [torch.float32, torch.float16] + [torch.bfloat16] + if 
TEST_WITH_ROCM or SM53OrLater or TEST_XPU + else [] + ) + ) + @parametrize( + "batch_size, N, M, P", + [ + (2, 100, 100, 100), + (2, 1000, 1000, 1000), + (1, 10000, 1000, 10000), + (1, 10000, 10000, 10000), + ], + name_fn=lambda batch_size, N, M, P: f"{batch_size}_{N}_{M}_{P}", + ) + def test_cublas_baddbmm_large_input(self, device, batch_size, N, M, P, dtype): + cpu_dtype = dtype + if dtype == torch.float16 or dtype == torch.bfloat16: + cpu_dtype = torch.float32 + + M1 = torch.rand((N, M), device=device, dtype=dtype) + M2 = torch.rand((M, P), device=device, dtype=dtype) + A = torch.rand((N, P), device=device, dtype=dtype) + + def _convert_to_cpu(t): + return t.to(device="cpu", dtype=cpu_dtype) + + M1_cpu, M2_cpu, A_cpu = map(_convert_to_cpu, [M1, M2, A]) + + # linear + out1_cpu = torch.nn.functional.linear(M1_cpu, M2_cpu.t(), A_cpu).to(dtype=dtype) + out1_gpu = torch.nn.functional.linear(M1, M2.t(), A).cpu() + self.assertEqual(out1_cpu, out1_gpu) + # test multiply the identity matrix + if N == M and M == P: + M2_eye = torch.eye(N, device=device, dtype=dtype) + out1_eye_gpu = torch.nn.functional.linear( + M1, M2_eye.t(), torch.zeros_like(A) + ) + if runOnRocmArch(MI200_ARCH) and dtype == torch.float16: + self.assertEqual( + M1_cpu.to(dtype=dtype), out1_eye_gpu.cpu(), atol=1e-4, rtol=0.001 + ) + else: + self.assertEqual(M1_cpu.to(dtype=dtype), out1_eye_gpu.cpu()) + + # baddbmm + def _expand_to_batch(t: torch.Tensor): + return t.expand((batch_size,) + t.size()) + + alpha, beta = 1.0, 1.0 + M1, M2, A, M1_cpu, M2_cpu, A_cpu = map( + _expand_to_batch, [M1, M2, A, M1_cpu, M2_cpu, A_cpu] ) - with self.assertRaisesRegex( - RuntimeError, - re.escape( - " For RowWise scaling, scale_a should be (1024, 1) and scale_b " - "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)" - ), - ): - torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=torch.ones((M, 1), device="xpu"), - scale_b=torch.ones((1, N + 1), device="xpu"), - out_dtype=torch.bfloat16, + out2_cpu = torch.baddbmm(A_cpu, M1_cpu, M2_cpu, beta=beta, alpha=alpha).to( + dtype=dtype ) - with self.assertRaisesRegex( - RuntimeError, - re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"), - ): - torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=torch.ones((M), device="xpu"), - scale_b=torch.ones((N, N), device="xpu"), - out_dtype=torch.bfloat16, + out2_gpu = torch.baddbmm(A, M1, M2, beta=beta, alpha=alpha).cpu() + self.assertEqual(out2_cpu, out2_gpu) + # test multiply the identity matrix + if N == M and M == P: + M2_eye = torch.eye(N, device=device, dtype=dtype).expand(batch_size, N, N) + out2_eye_gpu = torch.baddbmm( + torch.zeros_like(A), M1, M2_eye, beta=beta, alpha=alpha + ) + if runOnRocmArch(MI200_ARCH) and dtype == torch.float16: + self.assertEqual( + M1_cpu.to(dtype=dtype), out2_eye_gpu.cpu(), atol=1e-4, rtol=0.001 + ) + else: + self.assertEqual(M1_cpu.to(dtype=dtype), out2_eye_gpu.cpu()) + + # cross comparison + self.assertEqual(out1_gpu, out2_gpu[0]) + + @onlyOn(["cuda", "xpu"]) + @skipIfRocm + @parametrize("shape", [2**i for i in range(5, 14)]) + @dtypes(torch.float, torch.half, torch.bfloat16) + def test_cublas_deterministic(self, device, shape, dtype): + inp = torch.randn(shape, shape, device=device, dtype=dtype) + first = torch.matmul(inp, inp) + for _ in range(10): + self.assertEqual(first, torch.matmul(inp, inp), atol=0.0, rtol=0.0) + + def grouped_mm_helper(self, alist, blist, gOlist, agradlist, bgradlist, outlist): + for a, b, gO, agrad, bgrad, out in zip( + alist, blist, gOlist, agradlist, 
bgradlist, outlist + ): + a = a.clone().detach().requires_grad_() + b = b.clone().detach().requires_grad_() + out_ref = torch.mm(a, b.t()) + out_ref.backward(gO) + self.assertEqual(out, out_ref) + if agrad is not None: + self.assertEqual(agrad, a.grad) + self.assertEqual(bgrad, b.grad) + + @onlyCUDA + @skipIfRocm + @dtypes(torch.half, torch.bfloat16) + @unittest.skipIf( + not SM100OrLater, "cuBLAS integration for batch invariance is only on Blackwell" + ) + @serialTest() + def test_cublas_batch_invariance_blackwell(self, device, dtype): + orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction + orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = ( + False, + False, ) - - with self.assertRaisesRegex( - RuntimeError, - re.escape("Both scale_a and scale_b must be contiguous for RowWise scaling."), - ): - torch._scaled_mm( - x_fp8, - y_fp8, - scale_a=torch.ones((M, 1), device="xpu"), - scale_b=torch.ones((1, N * 2), device="xpu")[:, ::2], - out_dtype=torch.bfloat16, + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = ( + False, + False, ) - - # Note re.compile is used, not re.escape. This is to accomodate fn vs fnuz type message. - with self.assertRaisesRegex( - RuntimeError, - r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\.", - ): - torch._scaled_mm( - x_fp8, - y_fp8.to(e5m2_type), - scale_a=torch.ones((M, 1), device="xpu"), - scale_b=torch.ones((1, N), device="xpu"), - out_dtype=torch.bfloat16, + with blas_library_context("cublaslt"): + N = 2048 + K = 6144 + M_max = 32 + x = torch.randn(M_max, K, device="cuda", dtype=torch.bfloat16) + w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).t() + full = x @ w + xx = x[:1] + out = xx @ w + self.assertEqual(full[:1], out, atol=0.0, rtol=0.0) + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16 + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16 + + @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + @dtypes(torch.bfloat16, torch.float32, torch.float16) + def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, dtype): + device = "cuda" + m, n, k, n_groups = 16, 32, 64, 4 + if a_row_major: + a = torch.randn( + m, k * n_groups + k * int(strided), device=device, dtype=dtype + )[:, : k * n_groups] + else: + a = torch.randn( + k * n_groups + k * int(strided), m, device=device, dtype=dtype + ).t()[:, : k * n_groups] + + if b_row_major: + b = torch.randn( + n, k * n_groups + k * int(strided), device=device, dtype=dtype + )[:, : k * n_groups] + else: + b = torch.randn( + k * n_groups + k * int(strided), n, device=device, dtype=dtype + ).t()[:, : k * n_groups] + + a.requires_grad_(True) + b.requires_grad_(True) + offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32) + + f = F.grouped_mm + out = f(a, b.t(), offs=offs, out_dtype=dtype) + gO = torch.rand_like(out) + out.backward(gO) + offs_cpu = offs.cpu() + alist, blist, agradlist, bgradlist = [], [], [], [] + start = 0 + for i in range(n_groups): + alist.append(a[:, start : offs_cpu[i]]) + blist.append(b[:, start : offs_cpu[i]]) + agradlist.append(a.grad[:, start : offs_cpu[i]]) + bgradlist.append(b.grad[:, start : offs_cpu[i]]) + start = offs_cpu[i] + self.grouped_mm_helper(alist, 
blist, gO, agradlist, bgradlist, out) + + @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + @dtypes(torch.bfloat16, torch.float32, torch.float16) + def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major, dtype): + device = "cuda" + s_int = int(strided) + m, n, k, n_groups = 16, 32, 64, 4 + if a_row_major: + a = torch.randn(m * n_groups, k * (1 + s_int), device=device, dtype=dtype)[ + :, :k + ] + else: + a = torch.randn( + k, (m + 2 * s_int) * n_groups, device=device, dtype=dtype + ).t()[: m * n_groups, :] + + if b_row_major: + b = torch.randn( + n_groups * (1 + s_int), n, k * (1 + s_int), device=device, dtype=dtype + )[:: (1 + s_int), :, :k] + else: + b = torch.randn( + n_groups * (1 + s_int), k * (1 + s_int), n, device=device, dtype=dtype + ).transpose(-2, -1)[:: (1 + s_int), :, :k] + + a.requires_grad_(True) + b.requires_grad_(True) + + a_contig = a if a_row_major else a.t() + self.assertTrue(a_contig.is_contiguous() is not strided) + b_contig = b if b_row_major else b.transpose(-2, -1) + self.assertTrue(b_contig.is_contiguous() is not strided) + for check_zero_size in (False, True): + if check_zero_size and n_groups <= 1: + continue + + a.grad = None + b.grad = None + offs = torch.arange( + m, n_groups * m + 1, m, device=device, dtype=torch.int32 + ) + if check_zero_size: + offs[0] = offs[1] + + f = F.grouped_mm + out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype) + gO = torch.rand_like(out) + if not check_zero_size: + out.backward(gO) + offs_cpu = offs.cpu() + alist, agradlist, gOlist, outlist = [], [], [], [] + bgradlist = [None] * n_groups if check_zero_size else b.grad + start = 0 + for i in range(n_groups): + alist.append(a[start : offs_cpu[i]]) + agradlist.append( + None if check_zero_size else a.grad[start : offs_cpu[i]] + ) + outlist.append(out[start : offs_cpu[i]]) + gOlist.append(gO[start : offs_cpu[i]]) + start = offs_cpu[i] + self.grouped_mm_helper(alist, b, gOlist, agradlist, bgradlist, outlist) + + @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + @dtypes(torch.bfloat16, torch.float32, torch.float16) + def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major, dtype): + device = "cuda" + s_int = int(strided) + m, n, k, n_groups = 16, 32, 64, 4 + if a_row_major: + a = torch.randn( + n_groups * (1 + s_int), m, k * (1 + s_int), device=device, dtype=dtype + )[:: (1 + s_int), :, :k] + else: + a = torch.randn( + n_groups * (1 + s_int), k * (1 + s_int), m, device=device, dtype=dtype + ).transpose(-2, -1)[:: (1 + s_int), :, :k] + if b_row_major: + b = torch.randn( + n_groups * (1 + s_int), n, k * (1 + s_int), device=device, dtype=dtype + )[:: (1 + s_int), :, :k] + else: + b = torch.randn( + n_groups * (1 + s_int), k * (1 + s_int), n, device=device, dtype=dtype + ).transpose(-2, -1)[:: (1 + s_int), :, :k] + a.requires_grad_(True) + b.requires_grad_(True) + + a_contig = a if a_row_major else a.transpose(-2, -1) + self.assertTrue(a_contig.is_contiguous() is not strided) + b_contig = b if b_row_major else b.transpose(-2, -1) + self.assertTrue(b_contig.is_contiguous() is not strided) + + f = F.grouped_mm + out = f(a, b.transpose(-2, -1), out_dtype=dtype) + gO = torch.rand_like(out) + out.backward(gO) + 
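+        # No offsets in the 3d/3d case: the helper zips the 3D tensors over the
+        # group (batch) dimension and checks every group against a dense
+        # torch.mm reference, including the gradients.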
self.grouped_mm_helper(a, b, gO, a.grad, b.grad, out) + + @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + @dtypes(torch.bfloat16, torch.float32, torch.float16) + def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype): + device = "cuda" + s_int = int(strided) + m, n, k, n_groups = 16, 32, 64, 4 + if a_row_major: + a = torch.randn( + n_groups * (1 + s_int), m, k * (1 + s_int), device=device, dtype=dtype + )[:: (1 + s_int), :, :k] + else: + a = torch.randn( + n_groups * (1 + s_int), k * (1 + s_int), m, device=device, dtype=dtype + ).transpose(-2, -1)[:: (1 + s_int), :, :k] + if b_row_major: + b = torch.randn(n * n_groups, k * (1 + s_int), device=device, dtype=dtype)[ + :, :k + ] + else: + b = torch.randn( + k, n * (n_groups + s_int), device=device, dtype=dtype + ).transpose(-2, -1)[: n * n_groups, :] + + a.requires_grad_(True) + b.requires_grad_(True) + + a_contig = a if a_row_major else a.transpose(-2, -1) + self.assertTrue(a_contig.is_contiguous() is not strided) + b_contig = b if b_row_major else b.transpose(-2, -1) + self.assertTrue(b_contig.is_contiguous() is not strided) + for check_zero_size in (False, True): + if check_zero_size and n_groups <= 1: + continue + + offs = torch.arange( + n, n_groups * n + 1, n, device=device, dtype=torch.int32 + ) + if check_zero_size: + offs[0] = offs[1] + + f = F.grouped_mm + out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype) + gO = torch.rand_like(out) + if not check_zero_size: + out.backward(gO) + offs_cpu = offs.cpu() + blist, outlist, bgradlist, gOlist = [], [], [], [] + agradlist = [None] * n_groups if check_zero_size else a.grad + start = 0 + for i in range(n_groups): + blist.append(b[start : offs_cpu[i]]) + bgradlist.append(b.grad[start : offs_cpu[i]]) + outlist.append(out[:, start : offs_cpu[i]]) + gOlist.append(gO[:, start : offs_cpu[i]]) + start = offs_cpu[i] + self.grouped_mm_helper(a, blist, gOlist, agradlist, bgradlist, outlist) + + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") + # TODO(future PR): enable compile for torch.nn.functional.grouped_mm fallback path + @unittest.skipIf(not SM90OrLater, "Grouped gemm with compile supported on SM90") + @parametrize("op", ["2d/2d", "2d/3d", "3d/2d", "3d/3d"]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + @parametrize("max_autotune", [False, True]) + def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune): + device = "cuda" + dtype_AB = torch.bfloat16 + dtype_offset = torch.int32 + + align = 16 // dtype_AB.itemsize + + f_ref = F.grouped_mm + + options = {} + if max_autotune: + options.update( + { + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + } + ) + f = torch.compile( + f_ref, + options=options, ) + if op == "2d/2d": + m, n = 3, 7 + m_align = (m + align - 1) // align * align + n_align = (n + align - 1) // align * align + if not a_row_major and not b_row_major: + offs = torch.tensor([0, 1, 6, 6, 7], device=device, dtype=dtype_offset) + else: + offs = torch.tensor( + [0, 8, 16, 16, 27], device=device, dtype=dtype_offset + ) + ngroups = offs.shape[0] + k = offs[-1] + k_align = (k + align - 1) // align * align + + if a_row_major: + A = torch.randn(m, k_align, device=device, dtype=dtype_AB)[:, :k] + else: + A = torch.randn(k, m_align, device=device, dtype=dtype_AB).t()[:m, :] + if b_row_major: + B = 
torch.randn(n, k_align, device=device, dtype=dtype_AB)[:, :k] + else: + B = torch.randn(k, n_align, device=device, dtype=dtype_AB).t()[:n, :] + elif op == "2d/3d": + n, k = ( + 7, + 259, + ) # k is larger here, to validate iterating over k tiles on an op + n_align = (n + align - 1) // align * align + k_align = (k + align - 1) // align * align + if a_row_major: + offs = torch.tensor([0, 1, 3, 3, 5], device=device, dtype=dtype_offset) + else: + offs = torch.tensor( + [0, 8, 16, 16, 19], device=device, dtype=dtype_offset + ) + ngroups = offs.shape[0] + m = offs[-1] + m_align = (m + align - 1) // align * align + + if a_row_major: + A = torch.randn(m, k_align, device=device, dtype=dtype_AB)[:, :k] + else: + A = torch.randn(k, m_align, device=device, dtype=dtype_AB).t()[:m, :] + if b_row_major: + B = torch.randn(ngroups, n, k_align, device=device, dtype=dtype_AB)[ + :, :, :k + ] + else: + B = torch.randn( + ngroups, k, n_align, device=device, dtype=dtype_AB + ).transpose(-2, -1)[:, :n, :] + elif op == "3d/2d": + m, k = 3, 13 + m_align = (m + align - 1) // align * align + k_align = (k + align - 1) // align * align + offs = torch.tensor([0, 8, 16, 16, 19], device=device, dtype=dtype_offset) + ngroups = offs.shape[0] + n = offs[-1] + n_align = (n + align - 1) // align * align + + if a_row_major: + A = torch.randn(ngroups, m, k_align, device=device, dtype=dtype_AB)[ + :, :, :k + ] + else: + A = torch.randn( + ngroups, k, m_align, device=device, dtype=dtype_AB + ).transpose(-2, -1)[:, :m, :] + if b_row_major: + B = torch.randn(n, k_align, device=device, dtype=dtype_AB)[:, :k] + else: + B = torch.randn(k, n_align, device=device, dtype=dtype_AB).t()[:n, :] + elif op == "3d/3d": + offs = None + ngroups = 5 + m, n, k = 3, 7, 13 + m_align = (m + align - 1) // align * align + n_align = (n + align - 1) // align * align + k_align = (k + align - 1) // align * align + if a_row_major: + A = torch.randn(ngroups, m, k_align, device=device, dtype=dtype_AB)[ + :, :, :k + ] + else: + A = torch.randn( + ngroups, k, m_align, device=device, dtype=dtype_AB + ).transpose(-2, -1)[:, :m, :] + if b_row_major: + B = torch.randn(ngroups, n, k_align, device=device, dtype=dtype_AB)[ + :, :, :k + ] + else: + B = torch.randn( + ngroups, k, n_align, device=device, dtype=dtype_AB + ).transpose(-2, -1)[:, :n, :] + else: + raise AssertionError(f"Invalid op: {op}") + + C_ref = f_ref(A, B.transpose(-2, -1), offs=offs) + if not IS_BIG_GPU and max_autotune: + with self.assertRaisesRegex( + torch._inductor.exc.InductorError, "NoValidChoicesError" + ): + C = f(A, B.transpose(-2, -1), offs=offs) + else: + C = f(A, B.transpose(-2, -1), offs=offs) + self.assertEqual(C, C_ref) + + @onlyCUDA + @parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + @parametrize("M", [1, 32, 64]) + @parametrize("N", [1, 32, 64]) + @parametrize("K", [1, 32, 64]) + @parametrize("batch_size", [None, 1, 16]) + @parametrize("backend", ["cublas", "cublaslt"]) + def test_mm_bmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend): + if torch.version.hip: + msg = "accuracy regression in hipblas and hipblaslt in ROCm 7.0 for certain shapes" + if input_dtype == torch.bfloat16 and N == 1 and K == 32 and batch_size: + raise unittest.SkipTest(msg) + if input_dtype == torch.bfloat16 and N == 1 and K == 64 and batch_size: + raise unittest.SkipTest(msg) + if ( + input_dtype == torch.float16 + and M == 32 + and N == 1 + and K == 64 + and batch_size == 1 + ): + raise unittest.SkipTest(msg) + if ( + input_dtype == torch.float16 + and M == 64 + and 
N == 1 + and K == 64 + and batch_size == 1 + ): + raise unittest.SkipTest(msg) + + device = "cuda" + dtype = input_dtype + with blas_library_context(backend): + + def create_inputs(B=None): + if B is None: + a = torch.randn(M, K, device=device, dtype=dtype) + b = torch.randn(K, N, device=device, dtype=dtype) + else: + a = torch.randn(B, M, K, device=device, dtype=dtype) + b = torch.randn(B, K, N, device=device, dtype=dtype) + return a, b + + a, b = create_inputs(batch_size) + + a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32) + + output_dtypes = [torch.float32] + + if input_dtype != torch.float32: + output_dtypes.append(input_dtype) + + for output_dtype in output_dtypes: + # Catch edge case of incompat with bfloat16 and major version < 8 + if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16: + if output_dtype == torch.bfloat16: + continue + + if batch_size: + with self.assertRaises(RuntimeError): + torch.bmm(a, b, out_dtype=output_dtype) + else: + with self.assertRaises(RuntimeError): + torch.mm(a, b, out_dtype=output_dtype) + else: + if batch_size: + out = torch.bmm(a, b, out_dtype=output_dtype) + baseline = ( + torch.bmm(a_fp32, b_fp32) + if output_dtype == torch.float32 + else torch.bmm(a, b) + ) + else: + out = torch.mm(a, b, out_dtype=output_dtype) + baseline = ( + torch.mm(a_fp32, b_fp32) + if output_dtype == torch.float32 + else torch.mm(a, b) + ) + + self.assertEqual(out.dtype, output_dtype) + + torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3) + + @onlyCUDA + @parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) + @parametrize("M", [1, 32, 64]) + @parametrize("N", [1, 64]) + @parametrize("K", [1, 32, 64]) + @parametrize("batch_size", [None, 1]) + @parametrize("broadcast_self", [False, True]) + @parametrize("high_precision_self", [False, True]) + @parametrize("backend", ["cublas", "cublaslt"]) + def test_addmm_baddmm_dtype_overload( + self, + input_dtype, + M, + N, + K, + batch_size, + broadcast_self, + high_precision_self, + backend, + ): + if torch.version.hip: + msg = "accuracy regression in hipblas and hipblaslt in ROCm 7.0 for certain shapes" + if input_dtype == torch.bfloat16 and N == 1 and K == 32 and batch_size: + raise unittest.SkipTest(msg) + if input_dtype == torch.bfloat16 and N == 1 and K == 64 and batch_size: + raise unittest.SkipTest(msg) + if ( + input_dtype == torch.float16 + and M == 32 + and N == 1 + and K == 64 + and batch_size == 1 + ): + raise unittest.SkipTest(msg) + if ( + input_dtype == torch.float16 + and M == 64 + and N == 1 + and K == 64 + and batch_size == 1 + ): + raise unittest.SkipTest(msg) + + device = "cuda" + dtype = input_dtype + with blas_library_context(backend): + + def create_inputs(B, broadcast_self): + if B is None: + a = torch.randn(M, K, device=device, dtype=dtype) + b = torch.randn(K, N, device=device, dtype=dtype) + c_shape = (M, N) if not broadcast_self else (N) + c = torch.randn(c_shape, device=device, dtype=dtype) + else: + a = torch.randn(B, M, K, device=device, dtype=dtype) + b = torch.randn(B, K, N, device=device, dtype=dtype) + c_shape = (B, M, N) if not broadcast_self else (N) + c = torch.randn(c_shape, device=device, dtype=dtype) + + return a, b, c + + a, b, c = create_inputs(batch_size, broadcast_self) + + a_fp32, b_fp32, c_fp32 = ( + a.to(torch.float32), + b.to(torch.float32), + c.to(torch.float32), + ) -TestFP8MatmulCuda.test_float8_error_messages = _test_float8_error_messages - - -@unittest.skipIf(IS_WINDOWS, f8_msg) -@parametrize("base_dtype", [torch.bfloat16]) -def 
_test_scaled_mm_vs_emulated_row_wise(self, base_dtype): - torch.manual_seed(42) - input_dtype = e4m3_type - output_dtype = base_dtype - - x = torch.randn(16, 16, device="xpu", dtype=base_dtype) - y = torch.randn(32, 16, device="xpu", dtype=base_dtype).t() - - x_scales = tensor_to_scale(x, input_dtype, dim=1).float() - y_scales = tensor_to_scale(y, input_dtype, dim=0).float() - - x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type) - y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type) - - # Calculate actual F8 mm - out_scaled_mm = mm_float8( - x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype + output_dtypes = [torch.float32] + + if input_dtype != torch.float32: + output_dtypes.append(input_dtype) + + for output_dtype in output_dtypes: + # Catch edge case of incompat with bfloat16 and major version < 8 + if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16: + if output_dtype == torch.bfloat16: + continue + + if batch_size: + with self.assertRaises(RuntimeError): + torch.baddbmm(c, a, b, out_dtype=output_dtype) + else: + with self.assertRaises(RuntimeError): + torch.addmm(c, a, b, out_dtype=output_dtype) + else: + if c.dtype != output_dtype and high_precision_self: + c = c.to(output_dtype) + if batch_size: + out = torch.baddbmm(c, a, b, out_dtype=output_dtype) + if output_dtype == torch.float32: + baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32) + else: + baseline = torch.baddbmm(c, a, b) + # test out variant + out_ten = torch.full_like(out, float("nan")) + torch.baddbmm(c, a, b, out_dtype=output_dtype, out=out_ten) + else: + out = torch.addmm(c, a, b, out_dtype=output_dtype) + if output_dtype == torch.float32: + baseline = torch.addmm(c_fp32, a_fp32, b_fp32) + else: + baseline = torch.addmm(c, a, b) + # test out variant + out_ten = torch.full_like(out, float("nan")) + torch.addmm(c, a, b, out_dtype=output_dtype, out=out_ten) + + self.assertEqual(out.dtype, output_dtype) + self.assertEqual(out_ten.dtype, output_dtype) + torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(out_ten, out, atol=0, rtol=0) + + @onlyCUDA + @parametrize("batch_size", [1, 32]) + @parametrize("backend", ["cublas", "cublaslt"]) + def test_fp16_accum_and_fp32_out_failure(self, batch_size, backend): + M, N, K = 32, 32, 32 + device = "cuda" + dtype = torch.float16 + with blas_library_context(backend): + torch.backends.cuda.preferred_blas_library(backend) + + orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation + torch.backends.cuda.matmul.allow_fp16_accumulation = True + + def create_inputs(): + a = torch.randn(M, K, device=device, dtype=dtype) + b = torch.randn(K, N, device=device, dtype=dtype) + c = torch.randn(M, N, device=device, dtype=dtype) + return a, b, c + + def expand(tensor): + return tensor.unsqueeze(0).expand(batch_size, *tensor.shape) + + a, b, c = create_inputs() + + with self.assertRaises(Exception): + torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32) + + with self.assertRaises(Exception): + torch.addmm(c, a, b, out_dtype=torch.float32) + + with self.assertRaises(Exception): + torch.bmm( + expand( + a, + ), + expand(b), + out_dtype=torch.float32, + ) + + with self.assertRaises(Exception): + torch.mm(a, b, out_dtype=torch.float32) + + torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum + + @onlyCUDA + @parametrize( + "ops", + [ + ("mm", torch.mm), + ("bmm", torch.bmm), + ("addmm", torch.addmm), + ("baddbmm", torch.baddbmm), + ], ) + def test_input_dimension_checking_out_dtype(self, 
ops): + op_name, op = ops + B = 2 + M, N, K = 32, 32, 32 - # Calculate emulated F8 mm - out_emulated = mm_float8_emulated(x_fp8, x_scales, y_fp8, y_scales, output_dtype) - - if base_dtype in {torch.bfloat16, torch.float16}: - atol, rtol = 7e-2, 7e-2 - else: - atol, rtol = 2e-3, 2e-3 - - torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - - -TestFP8MatmulCuda.test_scaled_mm_vs_emulated_row_wise = ( - _test_scaled_mm_vs_emulated_row_wise -) + def is_addmm(): + return "add" in op_name + def is_batched(): + return "bmm" in op_name -def _cublas_and_lt_reduced_precision_fp16_accumulate(self): - orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation - torch.backends.cuda.matmul.allow_fp16_accumulation = True - x = torch.rand(32, 512, 512, device="xpu", dtype=torch.half) - w = torch.rand(512, 512, device="xpu", dtype=torch.half) - b = torch.rand(512, device="xpu", dtype=torch.half) - out = torch.nn.functional.linear(x, w, b) - out_cpu = torch.nn.functional.linear(x.cpu(), w.cpu(), b.cpu()) - self.assertEqual(out, out_cpu, atol=5e-3, rtol=8e-3) + if is_batched(): + a = torch.randn(B, M, K, device=device_type, dtype=torch.bfloat16) + mismatch_k_b = torch.randn( + B, K + 1, N, device=device_type, dtype=torch.bfloat16 + ) + c = torch.randn(B, M, N, device=device_type, dtype=torch.bfloat16) + extra_dim_b = a.clone().unsqueeze(0) - a = torch.rand(16, 128, 128, device="xpu", dtype=torch.half) - b = torch.rand(16, 128, 128, device="xpu", dtype=torch.half) - c = torch.rand(16, 128, 128, device="xpu", dtype=torch.half) - out = torch.baddbmm(a, b, c) - out_cpu = torch.baddbmm(a.cpu(), b.cpu(), c.cpu()) - self.assertEqual(out, out_cpu, atol=1e-3, rtol=5e-3) - torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate + mismatch_k_err = ( + "Expected size for first two dimensions of batch2 tensor to be" + ) + extra_dim_err = "batch2 must be a 3D tensor" + else: + a = torch.randn(M, K, device=device_type, dtype=torch.bfloat16) + mismatch_k_b = torch.randn( + K + 1, N, device=device_type, dtype=torch.bfloat16 + ) + c = torch.randn(M, N, device=device_type, dtype=torch.bfloat16) + extra_dim_b = a.clone().unsqueeze(0) + + mismatch_k_err = "mat1 and mat2 shapes cannot be multiplied" + extra_dim_err = "mat2 must be a matrix, got 3-D tensor" + + # Test mismatch K + with self.assertRaisesRegex(RuntimeError, mismatch_k_err): + if is_addmm(): + op(c, a, mismatch_k_b, out_dtype=torch.float32) + else: + op(a, mismatch_k_b, out_dtype=torch.float32) + + # Test extra dimension + with self.assertRaisesRegex(RuntimeError, extra_dim_err): + if is_addmm(): + op(c, a, extra_dim_b, out_dtype=torch.float32) + else: + op(c, extra_dim_b, out_dtype=torch.float32) + + if is_batched(): + with self.assertRaisesRegex( + RuntimeError, + "Expected size for first two dimensions of batch2 tensor to be", + ): + # Test mismatch B for bmm/baddbmm + mismatch_batch_dim_b = torch.randn( + B + 1, K, N, device=device_type, dtype=torch.bfloat16 + ) + if is_addmm(): + op(c, a, mismatch_batch_dim_b, out_dtype=torch.float32) + else: + op(a, mismatch_batch_dim_b, out_dtype=torch.float32) + + @unittest.skipIf( + not PLATFORM_SUPPORTS_GREEN_CONTEXT, "Green contexts are not supported" + ) + @serialTest() + def test_greencontext_carveout(self): + a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16) + ctx = torch.cuda.green_contexts.GreenContext.create(1, 0) + ctx.set_context() + torch.matmul(a, a) + torch.cuda.synchronize() + t0 = time.perf_counter() + partial_res = torch.matmul(a, a) + 
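+        # partial_res above ran inside the green context (a carved-out subset of
+        # SMs), so the t0 -> t1 timing is expected to exceed the full-device
+        # t2 -> t3 timing asserted at the end of this test.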
torch.cuda.synchronize() + t1 = time.perf_counter() + ctx.pop_context() + torch.matmul(a, a) + torch.cuda.synchronize() + t2 = time.perf_counter() + full_res = torch.matmul(a, a) + torch.cuda.synchronize() + t3 = time.perf_counter() + self.assertEqual(partial_res, full_res) + self.assertGreater(t1 - t0, t3 - t2) + + +@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") +@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions") +@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x") +class TestMixedDtypesLinearCuda(TestCase): + @dtypes(torch.float16, torch.bfloat16) + def test_mixed_dtypes_linear(self, dtype: torch.dtype, device: str = "cuda"): + version = _get_torch_cuda_version() + if version < (11, 8): + self.skipTest("_mixed_dtypes_linear only compiled for CUDA 11.8+") + + def run_test( + batch_shape, + m, + n, + k, + add_bias, + activation, + dtype, + dtypeq, + device, + rtol, + atol, + ): + if not add_bias and activation != "none": + return + + val_lo, val_hi = -1, 1 + valq_lo, valq_hi = -2, 2 + input = make_tensor( + *batch_shape, m, k, low=val_lo, high=val_hi, dtype=dtype, device=device + ) + weight = make_tensor( + n, k, low=valq_lo, high=valq_hi, dtype=torch.int8, device=device + ) + scale = make_tensor( + (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device + ) + bias = ( + make_tensor( + (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device + ) + if add_bias + else None + ) + input_ref = input.reshape(-1, input.shape[-1]) -TestMatmulCuda.test_cublas_and_lt_reduced_precision_fp16_accumulate = ( - _cublas_and_lt_reduced_precision_fp16_accumulate -) + # First, test plain multiplication. + weight_ref = weight.T.to(input.dtype) * scale.view(1, n) + weightq = ( + pack_int4_to_int8(weight.T) if dtypeq == torch.quint4x2 else weight.T + ) + output_ref = torch.mm(input_ref, weight_ref).reshape(*input.shape[:-1], n) + output = torch.ops.aten._mixed_dtypes_linear( + input, + quantized_weight_reorder_for_mixed_dtypes_linear_cutlass( + weightq, dtypeq, transpose=False + ), + scale, + ) + torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol) + + # Second, test the linear operator itself. 
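+            # The weight is reordered with transpose=True and the optional bias and
+            # activation ("relu"/"silu") are passed straight to _mixed_dtypes_linear;
+            # the result is compared against a plain F.linear reference.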
+ weight_ref = weight.to(input.dtype) * scale.view(n, 1) + weightq = pack_int4_to_int8(weight) if dtypeq == torch.quint4x2 else weight + bias_ref = bias.view(1, n) if add_bias else None + output_ref = torch.nn.functional.linear( + input_ref, weight_ref, bias=bias_ref + ).reshape(*input.shape[:-1], n) + if activation == "relu": + relu = torch.nn.ReLU() + output_ref = relu(output_ref) + elif activation == "silu": + silu = torch.nn.SiLU() + output_ref = silu(output_ref) + output = torch.ops.aten._mixed_dtypes_linear( + input, + quantized_weight_reorder_for_mixed_dtypes_linear_cutlass( + weightq, dtypeq, transpose=True + ), + scale, + bias=bias, + activation=activation, + ) + torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol) + + dtypeqs = [torch.int8, torch.quint4x2] + batch_shapes = [[], [2], [2, 1]] + shapes = [ + [8, 64, 64], + [8, 64, 128], + [8, 128, 64], + [8, 128, 128], + [8, 128, 192], + [8, 128, 256], + [8, 256, 128], + [8, 256, 384], + [8, 384, 256], + ] + activations = [None, "relu", "silu"] + rtol, atol = 1e-3, 1e-3 + if dtype == torch.bfloat16: + rtol, atol = 1e-2, 1e-3 + for dtypeq, batch_shape, (m, n, k), add_bias, activation in product( + dtypeqs, batch_shapes, shapes, (False, True), activations + ): + run_test( + batch_shape, + m, + n, + k, + add_bias, + activation, + dtype, + dtypeq, + device, + rtol, + atol, + ) -TestMixedDtypesLinearCuda._default_dtype_check_enabled = True -TestFP8MatmulCuda._default_dtype_check_enabled = True -TestMatmulCuda._default_dtype_check_enabled = True instantiate_device_type_tests( - TestMixedDtypesLinearCuda, globals(), only_for=("xpu"), allow_xpu=True + TestMatmulCuda, globals(), allow_xpu=True, except_for="cpu" ) - instantiate_device_type_tests( - TestFP8MatmulCuda, globals(), only_for=("xpu"), allow_xpu=True + TestMixedDtypesLinearCuda, globals(), allow_xpu=True, except_for="cpu" ) -instantiate_device_type_tests( - TestMatmulCuda, globals(), only_for=("xpu"), allow_xpu=True -) if __name__ == "__main__": TestCase._default_dtype_check_enabled = True run_tests() From 27b05fe17123fbb70ae6702a5e82e2e36c64e707 Mon Sep 17 00:00:00 2001 From: "Deng, Daisy" Date: Mon, 8 Dec 2025 06:16:10 +0000 Subject: [PATCH 4/4] add test_fake_tensor_xpu.py updated skiplist --- test/xpu/skip_list_common.py | 5 + test/xpu/test_expanded_weights_xpu.py | 10 +- test/xpu/test_fake_tensor_xpu.py | 2660 +++++++++++++++++++++++++ 3 files changed, 2672 insertions(+), 3 deletions(-) create mode 100644 test/xpu/test_fake_tensor_xpu.py diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 620f3516f9..8dae2d6791 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -805,4 +805,9 @@ "test_sparse_xpu.py": None, "test_sparse_csr_xpu.py": None, "test_nestedtensor_xpu.py": None, + "functorch/test_eager_transforms_xpu.py": None, + "test_cpp_api_parity_xpu.py": None, + "test_expanded_weights_xpu.py": None, + "test_fake_tensor_xpu.py": None, + "test_matmul_cuda_xpu.py": None, } diff --git a/test/xpu/test_expanded_weights_xpu.py b/test/xpu/test_expanded_weights_xpu.py index 1c25de4e54..1229fd3cc2 100644 --- a/test/xpu/test_expanded_weights_xpu.py +++ b/test/xpu/test_expanded_weights_xpu.py @@ -1163,9 +1163,13 @@ def clone_if_tensor(t): instantiate_device_type_tests( - TestExpandedWeightHelperFunction, globals(), allow_xpu=True + TestExpandedWeightHelperFunction, globals(), only_for=("xpu"), allow_xpu=True +) +instantiate_device_type_tests( + TestExpandedWeightFunctional, globals(), only_for=("xpu"), allow_xpu=True +) 
+instantiate_device_type_tests( + TestExpandedWeightModule, globals(), only_for=("xpu"), allow_xpu=True ) -instantiate_device_type_tests(TestExpandedWeightFunctional, globals(), allow_xpu=True) -instantiate_device_type_tests(TestExpandedWeightModule, globals(), allow_xpu=True) if __name__ == "__main__": run_tests() diff --git a/test/xpu/test_fake_tensor_xpu.py b/test/xpu/test_fake_tensor_xpu.py new file mode 100644 index 0000000000..a7896fcc9e --- /dev/null +++ b/test/xpu/test_fake_tensor_xpu.py @@ -0,0 +1,2660 @@ +# Owner(s): ["module: meta tensors"] +# ruff: noqa: F841 + + +import contextlib +import copy +import dataclasses +import gc +import inspect +import io +import itertools +import pickle +import unittest +import weakref +from unittest.mock import patch + +import numpy as np +import torch +import torch._dynamo +import torch._functorch.config +import torch._prims as prims +import torch.testing._internal.optests as optests +import torch.utils._pytree as pytree +from torch import distributed as dist +from torch._C._functorch import _add_batch_dim, get_unwrapped, is_batchedtensor +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.testing import make_test_cls_with_patches, rand_strided +from torch._guards import tracing, TracingContext +from torch._higher_order_ops.scan import scan +from torch._subclasses.fake_tensor import ( + _CacheKeyState, + DynamicOutputShapeException, + extract_tensor_metadata, + FakeTensor, + FakeTensorConverter, + FakeTensorMode, + MetadataMismatchError, + unset_fake_temporarily, + UnsupportedOperatorException, +) +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.symbolic_shapes import ( + DimDynamic, + free_symbols, + ShapeEnv, + ShapeEnvSettings, + StatelessSymbolicContext, + statically_known_true, +) +from torch.fx.passes.fake_tensor_prop import FakeTensorProp +from torch.testing import FileCheck +from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + OpDTypes, + ops, +) +from torch.testing._internal.common_dtype import all_types_complex_float8_and +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, + skipIfCrossRef, + skipIfRocm, + skipIfTorchDynamo, + skipIfWindows, + TemporaryFileName, + TEST_WITH_TORCHDYNAMO, + TEST_XPU, + TestCase, + xfailIfTorchDynamo, +) +from torch.testing._internal.custom_op_db import custom_op_db +from torch.testing._internal.inductor_utils import GPU_TYPE +from torch.testing._internal.jit_utils import RUN_CUDA +from torch.testing._internal.two_tensor import TwoTensor +from torch.utils._mode_utils import no_dispatch +from torch.utils._python_dispatch import TorchDispatchMode + +aten = torch.ops.aten + +torch._dynamo.config.fake_tensor_cache_enabled = True +torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True + + +def expectedFailurePropagateRealTensors(fn): + fn._expected_failure_propagate_real_tensors = True + return fn + + +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" + + +class FakeTensorTest(TestCase): + def checkType(self, t, device_str, size): + self.assertTrue(isinstance(t, FakeTensor)) + self.assertEqual(t.device.type, device_str) + self.assertEqual(list(t.size()), size) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_cuda_initialized(self): + # doesn't error + with FakeTensorMode(): + p = 
torch.randn(4, 2, requires_grad=True, device=device_type) + x = torch.randn(8, 4, device=device_type) + y = torch.mm(x, p).square().sum() + y.backward() + + def test_basic(self): + x = torch.empty(2, 2, device="cpu") + y = torch.empty(4, 2, 2, device="cpu") + with FakeTensorMode() as mode: + x = mode.from_tensor(x) + y = mode.from_tensor(y) + z = x + y + self.assertEqual(z.shape, (4, 2, 2)) + self.assertEqual(z.device, torch.device("cpu")) + self.assertTrue(isinstance(z, FakeTensor)) + + def test_custom_op_fallback(self): + from torch.library import impl, Library + + try: + test_lib = Library("my_test_op", "DEF") # noqa: TOR901 + test_lib.define("foo(Tensor self) -> Tensor") + + @impl(test_lib, "foo", "CPU") + def foo_impl(self): + return self.cos() + + x = torch.empty(2, 2, device="cpu") + with self.assertRaisesRegex( + UnsupportedOperatorException, "my_test_op.foo.default" + ): + with FakeTensorMode(allow_fallback_kernels=True) as mode: + x = mode.from_tensor(x) + torch.ops.my_test_op.foo(x) + + finally: + test_lib._destroy() + + def test_parameter_instantiation(self): + with FakeTensorMode(): + x = torch.rand([4]) + y = torch.nn.parameter.Parameter(x) + self.assertTrue(isinstance(y, torch.nn.Parameter)) + + @unittest.skipIf(not dist.is_available(), "requires distributed") + def test_fsdp_flat_param(self): + from torch.distributed.fsdp._flat_param import FlatParameter + + with FakeTensorMode() as m: + data = torch.randn(2, 2) + param = FlatParameter(data, requires_grad=True) + self.assertIsInstance(param, FlatParameter) + self.assertIsInstance(param, torch.nn.Parameter) + self.assertIsInstance(param, FakeTensor) + + def test_non_parameter_grad(self): + mode = FakeTensorMode() + t = torch.rand([4], requires_grad=True) + fake_t = mode.from_tensor(t) + self.assertEqual(fake_t.requires_grad, t.requires_grad) + + @unittest.skipIf( + TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" + ) + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + @parametrize( + "dtype", + all_types_complex_float8_and(), + ) + def test_index_cuda_with_cpu(self, dtype): + with FakeTensorMode(): + x = torch.ones([2048], device=device_type, dtype=dtype) + out = x[torch.zeros([36], dtype=torch.int64)] + self.checkType(out, device_type, [36]) + self.assertEqual(out.dtype, dtype) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_shape_take_not_device(self): + with FakeTensorMode(): + x = torch.empty(1, device="cpu") + y = torch.empty(8, 8, device=device_type) + out = x.resize_as_(y) + self.assertEqual(out.shape, (8, 8)) + self.assertEqual(out.device.type, "cpu") + self.assertTrue(isinstance(out, FakeTensor)) + + def test_repr(self): + with FakeTensorMode(): + x = torch.empty(2, 2, device="cpu") + self.assertEqual(repr(x), "FakeTensor(..., size=(2, 2))") + x = torch.empty(2, 2, device="meta") + self.assertEqual(repr(x), "FakeTensor(..., device='meta', size=(2, 2))") + + def test_convert_fake_to_real(self): + x = torch.ones([20]) + with FakeTensorMode(allow_non_fake_inputs=True) as m: + _ = x + 1 + + out = torch._subclasses.fake_utils.try_convert_fake_to_real([x[0:10]]) + + self.assertEqual(torch.ones([10]), out[0]) + + def test_conv_nhwc(self): + x = torch.randn([1, 1024, 16, 16]).to(memory_format=torch.channels_last) + w = torch.randn([256, 1024, 4, 4]).to(memory_format=torch.channels_last) + b = torch.randn([256]) + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, w, b): + 
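+                # aten.convolution args: (input, weight, bias, stride, padding,
+                # dilation, transposed, output_padding, groups)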
return torch.ops.aten.convolution( + x, w, b, [1, 1], [0, 0], [1, 1], False, [0, 0], 1 + ) + + model = Model() + with FakeTensorMode(allow_non_fake_inputs=True) as mode: + fake_out = model.forward(x, w, b) + eager_out = model.forward(x, w, b) + self.assertEqual(fake_out.stride(), eager_out.stride()) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_zero_dim(self): + with FakeTensorMode() as mode: + x = torch.tensor(0.0) + y = torch.rand([4, 4], device=device_type) + out = x + y + self.assertEqual(out.shape, (4, 4)) + self.assertEqual(out.device, y.device) + self.assertTrue(isinstance(out, FakeTensor)) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_op_with_zero_dim_bypassed(self): + if torch._functorch.config.fake_tensor_propagate_real_tensors: + self.skipTest("Propagate real tensor not supported") + shape_env = ShapeEnv() + mode = FakeTensorMode(shape_env=shape_env) + x = torch.tensor(1.0, device=device_type) + y = torch.tensor(2.0) + fake_x = mode.from_tensor(x) + fake_y = mode.from_tensor(y) + + with self.assertRaisesRegex( + RuntimeError, "Unhandled FakeTensor Device Propagation for.*" + ) as exc: + torch.nextafter(fake_x, fake_y) + + def test_nan_to_num(self): + with FakeTensorMode(): + for dtype in [torch.float16, torch.float32]: + x = torch.rand([4], dtype=dtype) + y = torch.nan_to_num(x, nan=None) + z = torch.nan_to_num(x, 0.0) + self.assertEqual(dtype, y.dtype) + self.assertEqual(dtype, z.dtype) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_throw(self): + x = torch.tensor(0.0) # TODO: tensor() errors + with FakeTensorMode() as mode: + x_conv = mode.from_tensor(x) + y = torch.rand([4, 4], device=device_type) + z = torch.rand([4, 4], device="cpu") + self.assertRaises(Exception, lambda: torch.lerp(x_conv, y, z)) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_type_as(self): + with FakeTensorMode(): + x = torch.rand([16, 1], device="cpu") + y = torch.rand([4, 4], device=device_type) + out = x.type_as(y) + self.assertEqual(out.device.type, device_type) + self.assertTrue(isinstance(out, FakeTensor)) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_setitem(self): + for device in ["cpu", device_type]: + with FakeTensorMode(): + x = torch.rand([16, 1], device=device) + x[..., 0] = 0 + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_device_inplace_copy(self): + with FakeTensorMode(): + x = torch.rand([8, 8], device="cpu") + y = torch.rand([8, 8], device=device_type) + assert x.copy_(y).device.type == "cpu" + assert y.copy_(x).device.type == device_type + + def test_fake_device(self): + t = torch.ones(3) + t = t.view(1, 3) + + fake_mode1 = FakeTensorMode(allow_non_fake_inputs=True) + fake_t = fake_mode1.from_tensor(t) + fake_t.fake_device = torch.device(device_type) + + fake_mode2 = FakeTensorMode(allow_non_fake_inputs=True) + new_fake_t = fake_mode2.from_tensor(fake_t) + + self.assertEqual(new_fake_t.device, fake_t.device) + + def test_fake_dispatch_keys(self): + with FakeTensorMode(): + x = torch.rand([4]) + f = ( + FileCheck() + .check("CPU") + .check("ADInplaceOrView") + .check("AutogradCPU") + .check("AutocastCPU") + ) + f.run(torch._C._dispatch_key_set(x)) + + with torch.inference_mode(): + x = torch.rand([4]) + y = x + x + FileCheck().check("CPU").check("AutocastCPU").run( + torch._C._dispatch_key_set(y) + ) + 
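+            # Results produced under inference mode should carry no autograd-related
+            # dispatch keys, which the check_not assertions below verify.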
FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run( + torch._C._dispatch_key_set(y) + ) + + def test_batch_tensor(self): + x = torch.rand((3, 4, 5)) + b = _add_batch_dim(x, 0, 0) + mode = FakeTensorMode() + fake_b = mode.from_tensor(b) + prims.utils.compare_tensor_meta(b, fake_b, check_strides=True) + + b1 = _add_batch_dim(x, 1, 1) + b2 = _add_batch_dim(b1, 0, 2) + fake_b2 = mode.from_tensor(b2) + prims.utils.compare_tensor_meta(b2, fake_b2, check_strides=True) + self.assertTrue(is_batchedtensor(fake_b2)) + fake_b1 = get_unwrapped(fake_b2) + self.assertTrue(is_batchedtensor(fake_b1)) + fake_tensor = get_unwrapped(fake_b1) + self.assertIsInstance(fake_tensor, FakeTensor) + + def test_constructor(self): + with FakeTensorMode(): + x = torch.rand([4, 4], device="cpu") + + self.assertTrue(isinstance(x, FakeTensor)) + self.assertTrue(x.device.type == "cpu") + + def test_mode(self): + with FakeTensorMode(): + y = torch.rand([4], device="cpu") + out = y + y + + self.assertTrue(isinstance(out, FakeTensor)) + + def test_full(self): + # Test torch.full returns tensor with correct dtype + with torch._subclasses.CrossRefFakeMode(): + y = torch.full((4, 4), 1) + + def check_function_with_fake(self, fn): + out = fn() + with torch._subclasses.FakeTensorMode(): + out_fake = fn() + + for a, b in zip(pytree.tree_leaves(out), pytree.tree_leaves(out_fake)): + if not isinstance(a, torch.Tensor): + self.assertTrue(not isinstance(b, torch.Tensor)) + continue + + prims.utils.compare_tensor_meta(a, b, check_strides=True) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_non_kwarg_device(self): + with FakeTensorMode(): + x = torch.rand([16, 1], device="cpu") + y = x.to(torch.device("cpu")) + self.assertIs(x, y) + z = x.to(torch.device(device_type)) + self.assertEqual(z.device.type, device_type) + + def test_non_overlapping_stride_zero(self): + def foo(): + x = torch.empty_strided([1, 3, 427, 640], (0, 1, 1920, 3)) + return x.half() + + self.check_function_with_fake(foo) + + def test_fake_mode_error(self): + x = torch.rand([4, 4]) + + with self.assertRaisesRegex(Exception, "Please convert all Tensors"): + with FakeTensorMode(): + y = x[0] + + def test_no_tag_func(self): + import functools + + from torch.nn.attention.flex_attention import _identity, flex_attention + + def create_attention(score_mod, block_mask, enable_gqa=False): + return functools.partial( + flex_attention, + score_mod=score_mod, + block_mask=block_mask, + enable_gqa=enable_gqa, + ) + + input_shape = (4, 16, 128, 64) + q = torch.randn( + input_shape, + dtype=torch.bfloat16, + device="cpu", + requires_grad=False, + ) + k = torch.randn( + input_shape, + dtype=torch.bfloat16, + device="cpu", + requires_grad=False, + ) + v = torch.randn( + input_shape, + dtype=torch.bfloat16, + device="cpu", + requires_grad=False, + ) + sdpa_partial = create_attention(_identity, None) + with FakeTensorMode(allow_non_fake_inputs=True): + sdpa_partial(q, k, v, return_lse=False) + + @unittest.skipIf( + TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" + ) + def test_fake_grad_copy(self): + x = torch.rand([4, 4], requires_grad=True) + x.grad = torch.rand([4, 4]) + mode = FakeTensorMode() + fake_x = mode.from_tensor(x) + prims.utils.compare_tensor_meta(fake_x, x) + prims.utils.compare_tensor_meta(fake_x.grad, x.grad) + + self.assertTrue(isinstance(fake_x.grad, FakeTensor)) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_index_put_error(self): + mode = 
FakeTensorMode() + for context in [contextlib.nullcontext, lambda: mode]: + with context(): + y = torch.randn(2, 2, 3) + x = torch.randn(2, 2, 3).to(device_type) + with self.assertRaises(RuntimeError): + x[[1, 1]] = y + + with self.assertRaises(RuntimeError): + torch.ops.aten.index_put( + x, torch.tensor([1, 1], device=device_type), y + ) + + # no error + torch.ops.aten.index_put( + x, torch.tensor([1, 1], device=device_type), torch.tensor(5.0) + ) + torch.ops.aten.index_put_( + x, torch.tensor([1, 1], device=device_type), torch.tensor(5.0) + ) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_like_constructor(self): + with FakeTensorMode(): + x = torch.rand([4, 4]) + y = torch.ones_like(x) + self.assertTrue(isinstance(y, FakeTensor)) + self.assertEqual(y.device.type, "cpu") + z = torch.ones_like(x, device=device_type) + self.assertTrue(isinstance(z, FakeTensor)) + self.assertEqual(z.device.type, device_type) + + def test_binary_op_type_promotion(self): + with FakeTensorMode(): + x = torch.empty([2, 2], dtype=torch.float) + y = torch.empty([2, 2], dtype=torch.int64) + out = x / y + self.assertEqual(out.dtype, torch.float) + self.assertEqual(out.device.type, "cpu") + + @unittest.skipIf( + TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" + ) + def test_from_numpy(self): + with FakeTensorMode(): + x = torch.tensor(np.zeros([4, 4])) + self.checkType(x, "cpu", [4, 4]) + + def test_randperm(self): + x = torch.randperm(10) + y = torch.randperm(5, device="cpu") + with FakeTensorMode(): + x1 = torch.randperm(10) + prims.utils.compare_tensor_meta(x, x1) + y1 = torch.randperm(5, device="cpu") + prims.utils.compare_tensor_meta(y, y1) + + def test_print_in_fake_mode(self): + x = torch.zeros(2) + # does not fail + with FakeTensorMode(): + out = str(x) + assert "FakeTensor" not in out + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_upsample_bilinear_small_channels(self): + out = [] + mode = FakeTensorMode() + for context in [contextlib.nullcontext, lambda: mode]: + with context(): + arg0_1 = torch.empty_strided( + (3, 427, 640), (1, 1920, 3), dtype=torch.float32, device=device_type + ) + unsqueeze = torch.ops.aten.unsqueeze.default(arg0_1, 0) + out.append( + torch.ops.aten.upsample_bilinear2d.default( + unsqueeze, [800, 1199], False + ) + ) + + self.assertTrue(out[1].is_contiguous()) + self.checkMetaProps(out[0], out[1]) + + def test_split_return_self(self): + def fn(x): + return torch.functional.split(x, 0)[0] + + # meta should not return self + with FakeTensorMode(), enable_python_dispatcher(): + out_fake = fn(torch.empty((0,))) + + out_eager = fn(torch.empty((0,))) + self.checkMetaProps(out_fake, out_eager) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_cpu_fallback(self): + with FakeTensorMode(allow_fallback_kernels=False): + filters = torch.randn(8, 4, 3, 3).to(device_type) + inputs = torch.randn(1, 4, 5, 5).to(device_type) + out = torch.nn.functional.conv2d(inputs, filters, padding=1) + self.assertEqual(out.device.type, device_type) + self.assertEqual(list(out.size()), [1, 8, 5, 5]) + + with FakeTensorMode(allow_fallback_kernels=True): + # intentionally bad inputs + filters = torch.randn(8, 20, 3, 3).to(device_type) + inputs = torch.randn(1, 7, 10, 5).to(device_type) + with self.assertRaises(RuntimeError): + torch.nn.functional.conv2d(inputs, filters, padding=1) + + with FakeTensorMode(allow_fallback_kernels=True): + filters = torch.randn(8, 4, 3, 
3).to(device_type) + inputs = torch.randn(1, 4, 5, 5).to(device_type) + + out = torch.nn.functional.conv2d(inputs, filters, padding=1) + self.assertEqual(out.device.type, device_type) + self.assertEqual(list(out.size()), [1, 8, 5, 5]) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_out_multi_device(self): + with FakeTensorMode(): + x = torch.rand([4]) + y = torch.rand([4], device=device_type) + + with self.assertRaisesRegex(Exception, "found.+two.+devices"): + torch.sin(x, out=y) + + with self.assertRaisesRegex(Exception, "found.+two.+devices"): + x.add_(y) + + @unittest.skipIf( + TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" + ) + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_normalize_device(self): + with FakeTensorMode(): + x = torch.empty(1, device=device_type) + y = torch.empty( + 1, device=f"{device_type}:{torch.accelerator.current_device_idx()}" + ) + out = x + y + self.checkType(out, device_type, [1]) + + def test_recursive_invocation(self): + mode = FakeTensorMode() + with mode: + x = torch.tensor(2) + mode.in_kernel_invocation = True + y = x + x + self.assertTrue(mode.in_kernel_invocation) + + @unittest.skipIf( + TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" + ) + @skipIfRocm + @parametrize( + "allow_fallback_kernels", + [False, True], + lambda a: "with_fallback" if a else "without_fallback", + ) + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_cudnn_rnn(self, allow_fallback_kernels): + def fn( + a0, + b0, + b1, + b2, + b3, + b4, + b5, + b6, + b7, + b8, + b9, + b10, + b11, + b12, + b13, + b14, + b15, + a3, + a4, + a5, + ): + a1 = [ + b0, + b1, + b2, + b3, + b4, + b5, + b6, + b7, + b8, + b9, + b10, + b11, + b12, + b13, + b14, + b15, + ] + return torch.ops.aten._cudnn_rnn( + a0, + a1, + 4, + a3, + a4, + a5, + 2, + 2048, + 0, + 2, + False, + 0.0, + False, + True, + [], + None, + ) + + mode = FakeTensorMode(allow_fallback_kernels=allow_fallback_kernels) + for i, context in enumerate([contextlib.nullcontext, lambda: mode]): + with context(): + inps1 = [ + torch.randn([92, 8, 2048]).to(device_type), + torch.randn([8192, 2048]).to(device_type), + torch.randn([8192, 2048]).to(device_type), + torch.randn([8192]).to(device_type), + torch.randn([8192]).to(device_type), + torch.randn([8192, 2048]).to(device_type), + torch.randn([8192, 2048]).to(device_type), + torch.randn([8192]).to(device_type), + torch.randn([8192]).to(device_type), + torch.randn([8192, 4096]).to(device_type), + torch.randn([8192, 2048]).to(device_type), + torch.randn([8192]).to(device_type), + torch.randn([8192]).to(device_type), + torch.randn([8192, 4096]).to(device_type), + torch.randn([8192, 2048]).to(device_type), + torch.randn([8192]).to(device_type), + torch.randn([8192]).to(device_type), + torch.randn([167837696]).to(device_type), + torch.randn([4, 8, 2048]).to(device_type), + torch.randn([4, 8, 2048]).to(device_type), + ] + inps2 = inps1 + inps2[len(inps2) - 1] = None # argument `cx` can be None + + for inps in [inps1, inps2]: + out = fn(*inps) + self.assertIs(out[4], inps[-3]) + for ten in out: + if i == 1: + self.assertTrue(isinstance(ten, FakeTensor)) + self.assertEqual(ten.device.type, device_type) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_cuda_lstm(self): + # Ensure CUDA (non-cuDNN) impl succeeds with fake tensors. 
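+        # With cuDNN disabled below, the shape checks follow the nn.LSTM contract:
+        # output is (L, N, D * H_out), where H_out == proj_size when proj_size > 0.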
+ with torch.backends.cudnn.flags(enabled=False): + fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False) + with fake_tensor_mode: + N = 5 + L = 4 + H_in = 2 + hidden_size = 3 + proj_size = 2 + num_layers = 2 + bidir = False + D = 2 if bidir else 1 + H_out = proj_size if proj_size > 0 else hidden_size + + lstm = torch.nn.LSTM( + input_size=H_in, + hidden_size=hidden_size, + num_layers=num_layers, + proj_size=proj_size, + batch_first=False, + bias=True, + bidirectional=bidir, + device=device_type, + ) + + h_0 = torch.randn((num_layers * D, N, H_out), device=device_type) + c_0 = torch.randn((num_layers * D, N, hidden_size), device=device_type) + inp = torch.randn((L, N, H_in), device=device_type) + (output, (h_n, c_n)) = lstm(inp, (h_0, c_0)) + output.sum().backward() + + self.assertEqual(output.shape, (L, N, D * H_out)) + self.assertEqual(h_n.shape, (D * num_layers, N, H_out)) + self.assertEqual(c_n.shape, (D * num_layers, N, hidden_size)) + + def test_data_dependent_operator(self): + with FakeTensorMode(allow_fallback_kernels=False): + x = torch.rand([10, 10]) + + self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x)) + + def test_parameter_view(self): + x = torch.nn.Parameter(torch.randn(4)) + x_view = x.view(4) + mode = FakeTensorMode() + fake_x_view = mode.from_tensor(x_view) + fake_x = mode.from_tensor(x) + self.assertFalse(isinstance(fake_x_view, torch.nn.Parameter)) + self.assertTrue(isinstance(fake_x, torch.nn.Parameter)) + + def test_tolist(self): + shape_env = ShapeEnv() + with FakeTensorMode(allow_fallback_kernels=False, shape_env=shape_env): + x = torch.rand([10]) + x.tolist() + + # Propagate real tensors doesn't work with fake-on-fake + @expectedFailurePropagateRealTensors + def test_same_shape_env_preserved(self): + shape_env = ShapeEnv() + mode1 = FakeTensorMode(shape_env=shape_env) + t1 = mode1.from_tensor( + torch.randn(10), + symbolic_context=StatelessSymbolicContext( + dynamic_sizes=[DimDynamic.DYNAMIC], constraint_sizes=[None] + ), + ) + mode2 = FakeTensorMode(shape_env=shape_env) + t2 = mode2.from_tensor(t1) + # t2.size(0) is still dynamic, even though we didn't pass DYNAMIC here + self.assertIsNot(t2, t1) + self.assertIs(t1.fake_mode, mode1) + self.assertIs(t2.fake_mode, mode2) + self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env) + self.assertEqual(str(t2.size(0)), str(t1.size(0))) + + # TODO: Support NJT. There's also some funny business with dynamic shapes + # which would need to be dealt with as well + @expectedFailurePropagateRealTensors + def test_jagged_fake_to_fake_preserved(self): + from torch.nested._internal.nested_tensor import jagged_from_list + + S0, S1, S2 = 3, 4, 5 + D = 4 + a = torch.randn(S0, D, requires_grad=True, dtype=torch.float64) + b = torch.randn(S1, D, requires_grad=True, dtype=torch.float64) + c = torch.randn(S2, D, requires_grad=True, dtype=torch.float64) + offsets = None + jt, _ = jagged_from_list([a, b, c], offsets) + shape_env = ShapeEnv() + mode1 = FakeTensorMode(shape_env=shape_env) + t1 = mode1.from_tensor(jt) + mode2 = FakeTensorMode(shape_env=shape_env) + t2 = mode2.from_tensor(t1) + # It's not obvious that the invocation above makes it dynamic but it + # does! 
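+        # Fakeifying the jagged NT gives its ragged dimension a symbolic size, so
+        # free_symbols(t1.size()) below is expected to be non-empty.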
+ self.assertTrue(free_symbols(t1.size())) + self.assertIsNot(t2, t1) + self.assertIs(t1.offsets().fake_mode, mode1) + self.assertIs(t2.offsets().fake_mode, mode2) + self.assertIs(t2.size(1).node.shape_env, t1.size(1).node.shape_env) + self.assertEqual(str(t2.size(1)), str(t1.size(1))) + + def checkMetaProps(self, t1, t2): + prims.utils.compare_tensor_meta(t1, t2, check_strides=True) + + @skipIfCrossRef + def test_deepcopy(self): + with FakeTensorMode() as mode: + pass + mod = torch.nn.BatchNorm2d(10) + with torch._subclasses.fake_tensor.FakeCopyMode(mode): + mod_copied = copy.deepcopy(mod) + + def check_copy(mod, mod_copied): + for name, param in itertools.chain( + mod.named_parameters(), mod.named_buffers() + ): + param_copied = getattr(mod_copied, name) + self.checkMetaProps(param, param_copied) + self.assertTrue(isinstance(param_copied, FakeTensor)) + self.assertEqual( + isinstance(param, torch.nn.Parameter), + isinstance(param_copied, torch.nn.Parameter), + ) + self.assertEqual(param.requires_grad, param_copied.requires_grad) + + check_copy(mod, mod_copied) + + class ModuleNew(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.a = torch.rand([10, 2]) + self.b = self.a + self.c = self.a[0] + + mod = ModuleNew() + with torch._subclasses.fake_tensor.FakeCopyMode(mode): + mod_copied = copy.deepcopy(mod) + + self.assertIs(mod_copied.a, mod_copied.b) + self.assertEqual(mod_copied.b.storage()._cdata, mod_copied.a.storage()._cdata) + + @unittest.skipIf( + TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" + ) + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_new(self): + with FakeTensorMode(): + a = torch.rand([16, 1]) + self.checkType(a.new(10, 10), "cpu", [10, 10]) + self.checkType(a.new([1, 2, 3, 4]), "cpu", [4]) + b = torch.rand([4, 4], device=device_type) + self.checkType(b.new(device=device_type), device_type, [0]) + self.checkType(a.new(torch.rand([1])), "cpu", [1]) + + @unittest.skipIf( + TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" + ) + def test_scalar_inputs(self): + with FakeTensorMode(): + self.checkType(torch.div(3, 2), "cpu", []) + ten = torch.zeros(2, dtype=torch.int32) * 2.0 + self.assertEqual(ten.dtype, torch.float) + self.checkType(ten, "cpu", [2]) + + @unittest.skipIf( + TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" + ) + def test_allow_meta(self): + def run_meta(): + with FakeTensorMode(): + x = torch.rand([4], device="meta") + return x + x + + self.checkType(run_meta(), "meta", [4]) + + with patch.object(torch._functorch.config, "fake_tensor_allow_meta", False): + self.assertRaises(Exception, run_meta) + + def test_embedding_bag_meta(self): + def f(): + # This behavior was originally unintentional but we see people + # relying on it + embedding = torch.nn.EmbeddingBag(10, 3, mode="sum", device="meta") + input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long) + offsets = torch.tensor([0, 4], dtype=torch.long) + return embedding(input, offsets) + + real_out = f() + with FakeTensorMode(): + fake_out = f() + + for r, f in zip(real_out, fake_out): + self.assertEqual(r.size(), f.size()) + self.assertEqual(r.device, f.device) + + @unittest.skipIf( + TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" + ) + def test_mixed_real_and_fake_inputs(self): + class _TestPattern(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(1, 1, 1) + 
self.bn = torch.nn.BatchNorm2d(1) + + def forward(self, input): + running_std = torch.sqrt(self.bn.running_var + self.bn.eps) + scale_factor = self.bn.weight / running_std + weight_shape = [1] * len(self.conv.weight.shape) + weight_shape[0] = -1 + bias_shape = [1] * len(self.conv.weight.shape) + bias_shape[1] = -1 + scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape) + zero_bias = torch.zeros_like(self.conv.bias, dtype=input.dtype) + conv = self.conv._conv_forward(input, scaled_weight, zero_bias) + conv_orig = conv / scale_factor.reshape(bias_shape) + conv_orig = conv_orig + self.conv.bias.reshape(bias_shape) + conv = self.bn(conv_orig) + return conv + + example_inputs = (torch.randn(1, 1, 3, 3),) + mod = _TestPattern() + with FakeTensorMode(allow_non_fake_inputs=True): + out = mod(torch.randn(1, 1, 3, 3)) + self.checkType(out, "cpu", (1, 1, 3, 3)) + + @unittest.skipIf( + TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" + ) + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_aten_copy_multi_device(self): + with FakeTensorMode(): + x1 = torch.rand(4, device="cpu") + x2 = torch.rand(4, device=device_type) + copy1 = torch.ops.aten.copy.default(x1, x2) + copy2 = torch.ops.aten.copy.default(x2, x1) + out = torch.empty(4, device="cpu") + torch.ops.aten.copy.out(x1, x2, out=out) + self.checkType(copy1, "cpu", (4,)) + self.checkType(copy2, device_type, (4,)) + self.checkType(out, "cpu", (4,)) + + @unittest.skipIf( + TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" + ) + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_aten_index_multi_device(self): + with FakeTensorMode(): + x1 = torch.rand(4, 4, device="cpu") + x2 = torch.rand(4, 4, device=device_type) + i1 = torch.tensor([0, 1], device=device_type) + i2 = torch.tensor([0, 1], device="cpu") + # NB: This one does not work: cuda indices not allowed on cpu + # tensor + # r1 = torch.ops.aten.index(x1, i1) + r2 = torch.ops.aten.index(x2, i2) + + y1 = torch.rand(4, device="cpu") + y2 = torch.rand(4, device=device_type) + j1 = torch.tensor([2], device=device_type) + j2 = torch.tensor([2], device="cpu") + r3 = torch.ops.aten.index_put.default(x1, j1, y1) + r4 = torch.ops.aten.index_put.default(x2, j2, y2) + # self.checkType(r1, "cpu", ()) + self.checkType(r2, device_type, ()) + self.checkType(r3, "cpu", (4, 4)) + self.checkType(r4, device_type, (4, 4)) + + @unittest.skipIf( + TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile" + ) + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_aten_slice_scatter_multi_device(self): + with FakeTensorMode(): + x1 = torch.rand(4, 4, device="cpu") + y1 = torch.rand(2, 4, device=device_type) + x2 = torch.rand(4, 4, device=device_type) + y2 = torch.rand(2, 4, device="cpu") + out = torch.empty(4, 4, device="cpu") + r1 = torch.ops.aten.slice_scatter.default(x1, y1, start=2) + r2 = torch.ops.aten.slice_scatter.default(x2, y2, start=2) + r3 = torch.ops.aten.slice_scatter.out(x1, y1, out=out, start=2) + self.checkType(r1, "cpu", (4, 4)) + self.checkType(r2, device_type, (4, 4)) + self.checkType(r3, "cpu", (4, 4)) + self.checkType(out, "cpu", (4, 4)) + + def test__adaptive_avg_pool2d_backward(self): + with FakeTensorMode(): + grad_out = torch.rand(2, 3, 4, 4) + inp = torch.rand(2, 3, 4, 4).to(memory_format=torch.channels_last) + grad_in = torch.ops.aten._adaptive_avg_pool2d_backward(grad_out, inp) + self.assertTrue( + 
torch._prims_common.suggest_memory_format(grad_in) + == torch.channels_last + ) + + def test_export_numpy(self): + class MyNumpyModel(torch.nn.Module): + def forward(self, input): + input = input.numpy() + return input + np.random.randn(*input.shape) + + with FakeTensorMode(): + ep = torch.export.export( + MyNumpyModel(), args=(torch.randn(1000),), strict=True + ) + self.assertTrue(isinstance(ep, torch.export.ExportedProgram)) + + def test_unsqueeze_copy(self): + shape_env = ShapeEnv() + t1 = torch.ones(2, 2, 768) + with FakeTensorMode(shape_env=shape_env) as fake_mode: + t = fake_mode.from_tensor( + t1, + symbolic_context=StatelessSymbolicContext( + dynamic_sizes=[ + DimDynamic.DYNAMIC, + DimDynamic.STATIC, + DimDynamic.STATIC, + ], + ), + ) + + self.assertEqual(t.shape[0], torch.ops.aten.unsqueeze_copy(t, 1).shape[0]) + + def test_alias_call(self): + fwAD = torch.autograd.forward_ad + + def f(x): + return 4312491 * x + + with torch._subclasses.fake_tensor.FakeTensorMode(): + with fwAD.dual_level(): + x = torch.randn(3, device="cpu") + y = torch.ones_like(x) + dual = fwAD.make_dual(x, y) + r = f(dual) + + self.assertIsInstance(r, FakeTensor) + self.assertEqual(r.size(), [3]) + + @parametrize("reverse", [False, True]) + def test_scan(self, reverse): + def add(x, y): + return x + y, x + y + + with torch._subclasses.fake_tensor.FakeTensorMode(): + x = torch.randn((3, 5, 7), device="cpu") + init = torch.randn((3, 7), device="cpu") + r = scan(add, init, x, dim=1, reverse=reverse) + + self.assertIsInstance(r[0], FakeTensor) + self.assertIsInstance(r[1], FakeTensor) + + def test_fast_div_int_to_float(self): + mode = FakeTensorMode() + with mode: + x = torch.empty(2, 2, device="cpu", dtype=torch.int32) + y = torch.empty(2, 2, device="cpu", dtype=torch.int32) + from torch._subclasses.fake_impls import get_fast_op_impls + + fast_div = get_fast_op_impls()[torch.ops.aten.div.Tensor] + z = fast_div(mode, x, y) + self.assertEqual(z.dtype, torch.float32) + + def test_fast_div(self): + mode = FakeTensorMode() + with mode: + x = torch.empty(2, 2, device="cpu", dtype=torch.int32) + from torch._subclasses.fake_impls import get_fast_op_impls + + fast_div = get_fast_op_impls()[torch.ops.aten.div.Tensor] + y = fast_div(mode, x, 2) + self.assertEqual(y.dtype, torch.float32) + + def test_nanmean_out(self): + # Regression test to ensure we don't error out. + with torch._subclasses.fake_tensor.FakeTensorMode() as mode: + x = torch.randn(10) + out = torch.empty(()) + torch.nanmean(x, out=out) + + self.assertEqual(out.dtype, x.dtype) + + def test_unbind_copy_out(self): + # Regression test to ensure we don't error out. 
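+        # The out= argument is a tuple of preallocated tensors, one per
+        # unbound slice of the 3x3 identity matrix below.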
+ with torch._subclasses.fake_tensor.FakeTensorMode() as mode: + eye = torch.eye(3) + out = (torch.zeros(3), torch.zeros(3), torch.zeros(3)) + torch.unbind_copy(eye, out=out) + + self.assertEqual(out[0].dtype, eye.dtype) + self.assertEqual(out[1].dtype, eye.dtype) + self.assertEqual(out[2].dtype, eye.dtype) + + +instantiate_parametrized_tests(FakeTensorTest) + + +def make_propagate_real_tensors_cls(cls): + cls = make_test_cls_with_patches( + cls, + "PropagateRealTensors", + "_propagate_real_tensors", + (torch._functorch.config, "fake_tensor_propagate_real_tensors", True), + xfail_prop="_expected_failure_propagate_real_tensors", + decorator=skipIfTorchDynamo("propagate_real_tensors affects Dynamo"), + ) + cls.__file__ = __file__ + cls.__module__ = __name__ + globals()[cls.__name__] = cls + + +make_propagate_real_tensors_cls(FakeTensorTest) + + +class FakeTensorConstHandling(TestCase): + def assertConst(self, *args): + for arg in args: + self.assertTrue(arg.constant is not None) + + def assertNotConst(self, *args): + for arg in args: + self.assertTrue(arg.constant is None) + + def test_simple(self): + with FakeTensorMode(): + x = torch.tensor(4.0) + self.assertEqual(x.item(), 4.0) + + def test_inplace_add(self): + with FakeTensorMode(): + x = torch.tensor(4.0) + y = x.add_(1) + self.assertEqual(x.item(), 5.0) + self.assertEqual(y.item(), 5.0) + self.assertConst(x, y) + + def test_shared_storages(self): + with FakeTensorMode(): + x = torch.tensor([4.0]) + y = x[:] + + self.assertEqual(x.storage()._cdata, y.storage()._cdata) + self.assertEqual(x.constant.storage()._cdata, y.constant.storage()._cdata) + + def test_constant_invalidation(self): + with FakeTensorMode(): + x = torch.tensor([1.0]) + self.assertConst(x) + y = torch.rand([1]) + x.add_(y) + self.assertNotConst(x) + + def test_inplace_view_invalidation(self): + with FakeTensorMode(): + x = torch.tensor([1]) + self.assertConst(x) + x.resize_([2]) + self.assertEqual(x.size(0), 2) + self.assertNotConst(x) + + def test_fake_tensor_in_intlist_repro(self): + def fn(tensors): + max_size = torch.tensor([800, 1216], dtype=torch.int64) + batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size) + return tensors[0].new_full(batch_shape, 0.0) + + with self.assertRaises( + torch._subclasses.fake_tensor.DataDependentOutputException + ): + with torch._subclasses.fake_tensor.FakeTensorMode(): + a = torch.randn(3, 800, 1199) + b = torch.randn(3, 800, 800) + inputs = [a, b] + ref = fn(inputs) + + def test_fake_tensor_batch_norm_cpu(self): + with torch._subclasses.CrossRefFakeMode(): + m = torch.nn.Sequential( + torch.nn.BatchNorm2d(10), + torch.nn.ReLU(), + ) + m.eval() + out = m(torch.randn([2, 10, 8, 8])) + + def test_shared_storage_invalidation(self): + with FakeTensorMode(): + x = torch.tensor([1.0]) + y = x[:] + self.assertConst(x, y) + y.add_(torch.rand([1])) + self.assertNotConst(x, y) + + def test_aliased_const_write(self): + with FakeTensorMode(): + x = torch.tensor([1]) + y = x.expand([4]) + self.assertNotConst(y) + y[0] = 1 + self.assertNotConst(x) + + def test_constant_propagate_through_functions(self): + with FakeTensorMode(): + y = torch.div(4, 4, rounding_mode="trunc") + self.assertConst(y) + + +make_propagate_real_tensors_cls(FakeTensorConstHandling) + + +def contains_type(type: torch.Type, maybe_contained_type: torch.Type): + return maybe_contained_type.isSubtypeOf(type) or any( + contains_type(e, maybe_contained_type) for e in type.containedTypes() + ) + + +class FakeTensorOpInfoTest(TestCase): + @ops(custom_op_db, 
dtypes=OpDTypes.any_one) + def test_fake(self, device, dtype, op): + sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) + for sample_input in sample_inputs_itr: + args = (sample_input.input,) + sample_input.args + kwargs = sample_input.kwargs + optests.fake_check(op, args, kwargs) + + +make_propagate_real_tensors_cls(FakeTensorOpInfoTest) +instantiate_device_type_tests( + FakeTensorOpInfoTest, globals(), only_for=("cpu", device_type) +) +instantiate_device_type_tests( + PropagateRealTensorsFakeTensorOpInfoTest, # noqa: F821 + globals(), + only_for=("cpu",), +) + + +class FakeTensorConverterTest(TestCase): + def test_memoized_conversion_to_meta(self): + x = torch.rand(2, 2, 2) + mode = FakeTensorMode() + self.assertTrue(mode.from_tensor(x) is mode.from_tensor(x)) + + def test_memoized_conversion_from_meta(self): + x = torch.rand(2, 2).to(device="meta") + mode = FakeTensorMode() + converter = mode.fake_tensor_converter + self.assertTrue( + converter.from_meta_and_device(mode, x, "cpu") + is converter.from_meta_and_device(mode, x, "cpu") + ) + + def test_separate_tensor_storages_view(self): + x = torch.rand(2, 2, 2) + y = x[0] + mode = FakeTensorMode() + converter = mode.fake_tensor_converter + x_conv = converter.from_real_tensor(mode, x) + y_conv = converter.from_real_tensor(mode, y) + self.assertEqual(torch._C._storage_id(x_conv), torch._C._storage_id(y_conv)) + + @xfailIfTorchDynamo + def test_separate_tensor_storages_non_view(self): + x = torch.rand(2, 2, 2) + y = torch.rand(4, 2) + y.set_(x.storage()) + mode = FakeTensorMode() + converter = mode.fake_tensor_converter + x_conv = converter.from_real_tensor(mode, x) + y_conv = converter.from_real_tensor(mode, y) + stor_id = torch._C._storage_id(x_conv) + self.assertEqual(stor_id, torch._C._storage_id(y_conv)) + del x + del x_conv + self.assertEqual(len(converter.tensor_memo), 1) + self.assertEqual(len(converter.meta_converter.storage_memo), 1) + del y + del y_conv + self.assertEqual(len(converter.tensor_memo), 0) + self.assertEqual(len(converter.meta_converter.storage_memo), 0) + + def test_dead_weak_ref(self): + x = torch.rand(2, 2, 2) + y = x[0] + mode = FakeTensorMode() + converter = FakeTensorConverter() + x_conv = converter.from_real_tensor(mode, x) + x_conv_storage = x_conv.untyped_storage() + del x_conv + self.assertFalse(x in converter.tensor_memo) + y_conv = converter.from_real_tensor(mode, y) + self.assertIs(x_conv_storage, y_conv.untyped_storage()) + + @xfailIfTorchDynamo + def test_dead_key(self): + x = torch.rand(2, 2, 2) + mode = FakeTensorMode() + converter = FakeTensorConverter() + x_conv = converter.from_real_tensor(mode, x) + self.assertEqual(len(converter.tensor_memo), 1) + x_conv2 = converter.from_real_tensor(mode, x) + assert x_conv2 is x_conv + del x + del x_conv + del x_conv2 + self.assertEqual(len(converter.tensor_memo), 0) + + def test_no_active_mode(self): + with FakeTensorMode() as mode: + x = torch.empty(2, 2, device="cpu") + y = torch.empty(2, 2, device="cpu") + + out = x + y + self.assertEqual(mode, out.fake_mode) + self.assertTrue(isinstance(out, FakeTensor)) + self.assertEqual(out.device.type, "cpu") + + def test_multiple_modes(self): + t = torch.rand([4]) + t2 = torch.rand([4]) + with FakeTensorMode() as m: + with FakeTensorMode() as m2: + t_fake = m.from_tensor(t) + t2_fake = m2.from_tensor(t2) + + with self.assertRaisesRegex(Exception, "Mixing fake modes"): + t_fake + t2_fake + + def test_separate_mode_error(self): + with FakeTensorMode(): + x = torch.empty(2, 2, device="cpu") + with 
FakeTensorMode(): + y = torch.empty(2, 2, device="cpu") + self.assertRaises(Exception, lambda: x, y) + + @xfailIfTorchDynamo + def test_no_ref_cycle(self): + x = torch.rand([4]) + mode = FakeTensorMode() + y = mode.from_tensor(x) + self.assertEqual(len(mode.fake_tensor_converter.tensor_memo), 1) + mode_weak = weakref.ref(mode) + y_weak = weakref.ref(mode) + del mode + del y + assert mode_weak() is None + assert y_weak() is None + + +make_propagate_real_tensors_cls(FakeTensorConverterTest) + + +class FakeTensorOperatorInvariants(TestCase): + def get_aten_op(self, schema): + namespace, name = schema.name.split("::") + overload = schema.overload_name if schema.overload_name else "default" + assert namespace == "aten" + return getattr(getattr(torch.ops.aten, name), overload) + + def get_all_aten_schemas(self): + for schema in torch._C._jit_get_all_schemas(): + namespace = schema.name.split("::")[0] + if namespace != "aten": + continue + yield schema + + def test_non_kwarg_only_device(self): + for schema in self.get_all_aten_schemas(): + ten_type = torch._C.TensorType.get() + if not any( + contains_type(arg.type, ten_type) + for arg in itertools.chain(schema.arguments, schema.returns) + ): + continue + + opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get()) + has_non_kwarg_device = any( + not arg.kwarg_only and arg.type.isSubtypeOf(opt_device) + for arg in schema.arguments + ) + if has_non_kwarg_device: + self.assertTrue( + self.get_aten_op(schema) + in torch._subclasses.fake_tensor._device_not_kwarg_ops + ) + + def test_tensor_constructors_all_have_kwarg_device(self): + for schema in self.get_all_aten_schemas(): + op = self.get_aten_op(schema) + if not torch._subclasses.fake_tensor._is_tensor_constructor(op): + continue + + opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get()) + has_kwarg_device = any( + arg.kwarg_only and arg.type.isSubtypeOf(opt_device) + for arg in schema.arguments + ) + + self.assertTrue( + has_kwarg_device or op == torch.ops.aten._list_to_tensor.default + ) + + @unittest.expectedFailure + def test_sparse_new(self): + with FakeTensorMode(): + indices = torch.randn(1, 1, dtype=torch.int64) + values = torch.randn(1) + extra = (2,) + sparse = torch.randn(1).to_sparse() + # This used to segfault, now it does not, but it still raises an + # error + sparse2 = sparse.new(indices, values, extra) + + def test_tensor_new(self): + with FakeTensorMode(): + x = torch.Tensor([1, 2, 3]) + self.assertIsInstance(x, FakeTensor) + + def test_like_ops(self): + for schema in self.get_all_aten_schemas(): + if "_like" == schema.name[-5:]: + op = self.get_aten_op(schema) + self.assertIn( + op, torch._subclasses.fake_tensor._like_tensor_constructors + ) + + def test_str_storage(self): + x = torch.zeros(3) + with FakeTensorMode() as m: + y = m.from_tensor(x) + self.assertExpectedInline( + str(x.storage()), + """\ + 0.0 + 0.0 + 0.0 +[torch.storage.TypedStorage(dtype=torch.float32, device=cpu) of size 3]""", + ) + self.assertExpectedInline( + str(y.storage()), + """\ +... +[torch.storage.TypedStorage(dtype=torch.float32, device=meta) of size 3]""", + ) + + self.assertExpectedInline( + str(y.storage()), + """\ +... 
+[torch.storage.TypedStorage(dtype=torch.float32, device=meta) of size 3]""", + ) + + # at::_embedding_bag has no op info, + # and returns extra tensors that at::embedding bag throws away + def test_embedding_bag_private(self): + args = [ + torch.ones(6, 1), + torch.ones(6, dtype=torch.int64), + torch.arange(2, dtype=torch.int64), + False, + 2, # mode = max + ] + + ref_out = torch.ops.aten._embedding_bag(*args) + with FakeTensorMode() as m: + meta_args = [ + m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args + ] + meta_out = torch.ops.aten._embedding_bag(*meta_args) + + self.assertEqual(len(ref_out), len(meta_out)) + for ref_o, meta_o in zip(ref_out, meta_out): + self.assertEqual(ref_o.size(), meta_o.size()) + + def test_cross_entropy_loss(self): + inp = torch.randn(3, 5) + target = torch.randint(5, (3,), dtype=torch.long) + weight = torch.rand(5) + fn = torch.nn.functional.cross_entropy + for w in (weight, None): + args = (inp, target, w) + ref = fn(*args) + with FakeTensorMode() as m: + meta_args = [ + m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args + ] + meta_out = torch.nn.functional.cross_entropy( + *meta_args, label_smoothing=0.5 + ) + + self.assertEqual(ref.size(), meta_out.size()) + + @unittest.skipIf( + not PLATFORM_SUPPORTS_FLASH_ATTENTION, + "Does not support SDPA or pre-SM80 hardware", + ) + def test_flash_attention(self): + class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, arg1, arg2, arg3): + torch.ops.aten._scaled_dot_product_flash_attention( + arg1, arg2, arg3, scale=0.17677669529663687 + ) + + args_new = [ + [ + ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, device_type), + ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, device_type), + ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, device_type), + ], + [ + ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, device_type), + ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, device_type), + ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, device_type), + ], + ] + for args_list in args_new: + args = [ + rand_strided(bsz, num_heads, seq_len, head_dim) + for (bsz, num_heads, seq_len, head_dim) in args_list + ] + try: + with torch._subclasses.CrossRefFakeMode(): + Repro()(*args) + except MetadataMismatchError as e: + # We expect the cross ref to succeed for the first output to fail + # for the rng state, see Note [Seed and Offset] + self.assertTrue("output[0]" not in str(e)) + if self.__class__.__name__.startswith("PropagateRealTensors"): + self.assertTrue( + "Real tensor propagation found a metadata mismatch" in str(e) + ) + else: + self.assertTrue( + "found mismatched tensor metadata for output" in str(e) + ) + + # IMPORTANT!!! 
Always run even if CUDA is not available + def test_fake_gpu_no_init(self): + # Skip this test, we will try to run CUDA operations to real prop so + # it clearly will not work on CPU runner + if torch._functorch.config.fake_tensor_propagate_real_tensors: + self.skipTest("Propagate real tensor not supported") + + with FakeTensorMode(allow_non_fake_inputs=True): + self.assertEqual(torch.empty(10, device=GPU_TYPE).device.type, GPU_TYPE) + self.assertEqual(torch.ones(10, device=GPU_TYPE).device.type, GPU_TYPE) + self.assertEqual(torch.zeros(10, device=GPU_TYPE).device.type, GPU_TYPE) + self.assertEqual(torch.rand(10, device=GPU_TYPE).device.type, GPU_TYPE) + self.assertEqual(torch.tensor(3.14, device=GPU_TYPE).device.type, GPU_TYPE) + self.assertEqual( + torch.tensor([[3.14, 2], [1, 2]], device=GPU_TYPE).device.type, GPU_TYPE + ) + + @unittest.skipIf(not torch.backends.cuda.is_built(), "requires CUDA build") + def test_move_module_under_fake(self): + if torch._functorch.config.fake_tensor_propagate_real_tensors: + self.skipTest("Propagate real tensor not supported") + + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2) + self.buffer = torch.nn.Buffer(torch.rand(2, 2)) + self.param = torch.nn.Parameter(torch.rand(2, 2)) + + def forward(self, x): + return self.linear(x) + self.buffer + self.param + + m = Module() + input = torch.rand(2, 2) + gpu_device = torch.device(GPU_TYPE, 0) + + with FakeTensorMode(allow_non_fake_inputs=True): + m.to(device=gpu_device) + arg = input.to(device=gpu_device) + out = m(arg) + + for p in m.parameters(): + self.assertTrue(isinstance(p, FakeTensor)) + self.assertEqual(p.device, gpu_device) + for b in m.buffers(): + self.assertTrue(isinstance(b, FakeTensor)) + self.assertEqual(b.device, gpu_device) + + self.assertTrue(isinstance(out, FakeTensor)) + self.assertEqual(out.device, gpu_device) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_move_meta_tensor(self): + if torch._functorch.config.fake_tensor_propagate_real_tensors: + self.skipTest("Propagate real tensor not supported") + + meta_tensor = torch.ones(2, device="meta") + with FakeTensorMode(allow_non_fake_inputs=True): + self.assertEqual(meta_tensor.to(device="cpu").device.type, "cpu") + self.assertEqual(meta_tensor.to(device=GPU_TYPE).device.type, GPU_TYPE) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_conv_c1_backward(self): + class Repro(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, arg1, arg2, arg3): + torch.ops.aten.convolution_backward.default( + arg1, + arg2, + arg3, + [1], + [1, 1], + [1, 1], + [1, 1], + False, + [0, 0], + 1, + [True, True, False], + ) + + args_new = [ + ((16, 1, 128, 128), (16384, 16384, 128, 1), torch.float16, device_type), + ((16, 64, 128, 128), (1048576, 1, 8192, 64), torch.float16, device_type), + ((1, 64, 3, 3), (576, 9, 3, 1), torch.float16, device_type), + ] + args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args_new] + + with torch._subclasses.CrossRefFakeMode(): + Repro()(*args) + + def test_no_dispatch_with_like_function(self): + class CountingMode(TorchDispatchMode): + def __init__(self) -> None: + self.count = 0 + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + self.count += 1 + return func(*args, **kwargs) + + with FakeTensorMode(): + x = torch.randn(2) + with CountingMode() as mode: + with no_dispatch(): + torch.zeros_like(x) + + 
self.assertEqual(mode.count, 0) + + # PropagateRealTensors installs weakrefs + @expectedFailurePropagateRealTensors + @unittest.skipIf(not RUN_CUDA, "requires cuda") + def test_module_to(self): + def _check_device(sd, device_type): + for v in sd.values(): + self.assertEqual(v.device.type, device_type) + + with FakeTensorMode(): + m = torch.nn.Linear(2, 2) + _check_device(m.state_dict(), "cpu") + m.to("cuda") + _check_device(m.state_dict(), "cuda") + + +make_propagate_real_tensors_cls(FakeTensorOperatorInvariants) + + +class FakeTensorPropTest(TestCase): + def test_fake_tensor_prop_on_nn_module(self): + class ToyNnModuleWithParameters(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.layer1 = torch.nn.Linear(4, 3) + self.layer2 = torch.nn.Linear(3, 2) + + def forward(self, value): + value = self.layer1(value) + value = torch.relu(value) + value = self.layer2(value) + return value + + model = ToyNnModuleWithParameters() + value = torch.randn(5, 4) + # Convert nn.Module to GraphModule so that FakeTensorProp runs. + graph_model = torch.fx.symbolic_trace(model, (value,)) + # The following block runs FakeTensorProp on graph_module w/to the same FakeTensorMode + # + # TODO(wschin): there should be an API to run FakeTensorProp for GraphModule + # with parameters and buffers. + with FakeTensorMode() as fake_tensor_mode: + + def to_fake_tensor(x): + if isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor): + return fake_tensor_mode.from_tensor(x) + return x + + fake_parameters_and_buffers = { + k: to_fake_tensor(v) + for k, v in itertools.chain( + graph_model.named_parameters(), graph_model.named_buffers() + ) + } + with torch.nn.utils.stateless._reparametrize_module( + graph_model, fake_parameters_and_buffers + ): + # This case uses the **same** fake tensor mode to + # 1. create fake parameters and fake buffers, and + # 2. run FakeTensorProp + # The result should be correct. + result = FakeTensorProp(graph_model, fake_tensor_mode).propagate(value) + self.assertTrue(isinstance(result, FakeTensor)) + self.assertEqual(result.shape, (5, 2)) + # This case uses the **different** fake tensor modes to + # 1. create fake parameters and fake buffers, and + # 2. run FakeTensorProp + # The following code should fail. + failed = False + try: + FakeTensorProp(graph_model).propagate(value) + except AssertionError: + # AssertionError: tensor's device must be `meta`, got cpu instead + failed = True + self.assertTrue(failed) + + def test_fake_tensor_prop_on_nn_module_with_optional_args(self): + class OptionalArgumentInBetween(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.layer1 = torch.nn.Linear(4, 3) + self.layer2 = torch.nn.Linear(3, 2) + + def forward(self, value, another_value=None, another_optional_value=None): + # Mimic huggingface's `forward` methods which have several optional arguments. + # For example, GPT accepts forward(self, input_ids, None, attention_mask, ...). + # To apply FakeTensorProp, its from_real_tensor(...) needs to accept None. 
+ if another_value is None: + another_value = torch.rand_like(value) + if another_optional_value is None: + another_optional_value = torch.rand_like(value) + value = value + another_value + another_optional_value + return value * value + + fake_mode = FakeTensorMode( + allow_non_fake_inputs=True, allow_fallback_kernels=False + ) + with fake_mode: + model = OptionalArgumentInBetween() + value = torch.randn(5, 4) + another_optional_value = torch.randn(5, 4) + graph_model = torch.fx.symbolic_trace( + model, (value, None, another_optional_value) + ) + FakeTensorProp(graph_model, fake_mode).propagate( + value, None, another_optional_value + ) + + def test_unbacked_shape_realloc(self): + def f(x): + return x.nonzero() + + shape_env = ShapeEnv() + fake_mode = FakeTensorMode(shape_env=shape_env) + with fake_mode: + value = torch.randn(5) + gm = make_fx(f)(value) + nonzero_nodes = [ + n for n in gm.graph.nodes if n.target is torch.ops.aten.nonzero.default + ] + self.assertEqual(len(nonzero_nodes), 1) + self.assertIsInstance(nonzero_nodes[0].meta["val"].shape[0], torch.SymInt) + u0 = nonzero_nodes[0].meta["val"].shape[0] + FakeTensorProp(gm, fake_mode).propagate(value) + u1 = nonzero_nodes[0].meta["val"].shape[0] + # Test that this test is actually doing something in that the + # FakeTensorProp actually triggered a reallocation. If this assert is + # failing, it could be because we started memoizing the nnz count for + # nonzero, which is nice in some sense (no reallocation) but not + # helpful for this test, which is checking what we do when we have + # to reallocate. If so, you need to make this example more + # complicated (e.g., maybe have a nontrivial computation on the input + # before feeding it into nonzero, or have some sort of randomness) + self.assertIsNot(u0, u1) + self.assertTrue(statically_known_true(u0 == u1)) + + def test_nonzero_stride(self): + shape_env = ShapeEnv() + fake_mode = FakeTensorMode(shape_env=shape_env) + with fake_mode: + value = torch.ones(5) + fake_r = value.nonzero() + + r = torch.ones(5).nonzero() + + self.assertEqual(fake_r.T.is_contiguous(), r.T.is_contiguous()) + + def test_nan_to_num(self): + shape_env = ShapeEnv() + fake_mode = FakeTensorMode(shape_env=shape_env) + with fake_mode: + x = torch.randn(5, 10).t() + y = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0) + + self.assertEqual(x.size(), y.size()) + self.assertEqual(x.stride(), y.stride()) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_torch_load_with_fake_mode(self): + model = torch.nn.Linear(5, 10) + sd = model.state_dict() + sd["tt"] = TwoTensor(torch.randn(2), torch.randn(2)) + + def _read_tensor_and_check(key, sd_loaded, all_bytes, device): + dtype = torch.float32 + t = sd_loaded[key] + self.assertEqual(t.device.type, device) + if isinstance(t, TwoTensor): + untyped_storage_a, untyped_storage_b = ( + t.a.untyped_storage(), + t.b.untyped_storage(), + ) + offset_a, offset_b = ( + untyped_storage_a._checkpoint_offset, + untyped_storage_b._checkpoint_offset, + ) + nbytes_a, nbytes_b = ( + untyped_storage_a.nbytes() // 4, + untyped_storage_b.nbytes() // 4, + ) + result_a = torch.frombuffer( + all_bytes, dtype=dtype, count=nbytes_a, offset=offset_a + ).resize_(t.a.size()) + result_b = torch.frombuffer( + all_bytes, dtype=dtype, count=nbytes_b, offset=offset_b + ).resize_(t.b.size()) + self.assertEqual(TwoTensor(result_a, result_b), sd[key]) + else: + untyped_storage = t.untyped_storage() + offset = untyped_storage._checkpoint_offset + nbytes = 
untyped_storage.nbytes() // 4 + result = torch.frombuffer( + all_bytes, dtype=dtype, count=nbytes, offset=offset + ).resize_(t.size()) + self.assertEqual(result, sd[key]) + + with TemporaryFileName() as f, torch.serialization.safe_globals([TwoTensor]): + # Create state_dict to be loaded later + torch.save(sd, f) + with open(f, "rb") as g: + all_bytes = g.read() + + fake_mode = FakeTensorMode() + with fake_mode: + sd_loaded = torch.load(f) + for k in sd: + _read_tensor_and_check(k, sd_loaded, all_bytes, "cpu") + with fake_mode: + sd_loaded = torch.load(f, map_location=device_type) + for k in sd: + _read_tensor_and_check(k, sd_loaded, all_bytes, device_type) + + for k in sd: + sd[k] = sd[k].to(device_type) + + with TemporaryFileName() as f, torch.serialization.safe_globals([TwoTensor]): + torch.save(sd, f) + with open(f, "rb") as g: + all_bytes = g.read() + + fake_mode = FakeTensorMode() + with fake_mode: + sd_loaded = torch.load(f) + for k in sd: + _read_tensor_and_check(k, sd_loaded, all_bytes, device_type) + with fake_mode: + sd_loaded = torch.load(f, map_location="cpu") + for k in sd: + _read_tensor_and_check(k, sd_loaded, all_bytes, "cpu") + + +make_propagate_real_tensors_cls(FakeTensorPropTest) + + +class FakeTensorSerialization(TestCase): + def test_serialization(self): + x = torch.tensor([0], device="cpu") + with FakeTensorMode(): + y = pickle.loads(pickle.dumps(x)) + self.assertEqual(type(y), FakeTensor) + self.assertEqual(y.device.type, "meta") + + with unset_fake_temporarily(): + y = pickle.loads(pickle.dumps(x)) + self.assertEqual(x.device, y.device) + + def test_serialization_with_tracing(self): + x = torch.tensor([0], device="cpu") + with tracing(TracingContext(FakeTensorMode())): + y = pickle.loads(pickle.dumps(x)) + self.assertEqual(x.device, y.device) + + +class FakeTensorDispatchCache(TestCase): + def test_shape_env_settings(self): + """ + Validation that any boolean settings in ShapeEnv are present in the + ShapeEnvSettings. We hope to ensure that any new settings that might + affect FakeTensor dispatch are included in the cache key calculation. + If this test fails, consider updating ShapeEnvSettings or change this + test to omit checking for the new field. + """ + init_sig = inspect.signature(ShapeEnv._init) + args = [ + name + for name, param in init_sig.parameters.items() + if type(param.default) is bool + ] + + settings = [f.name for f in dataclasses.fields(ShapeEnvSettings)] + for arg in args: + self.assertTrue(arg in settings) + + def _test_cache_key(self, fm, x, y, z): + """ + Helper for all test_cache_key_* tests below. Assert that the + cache keys for inputs x and y are the same, but z is different. 
+ """ + func = aten.add.Tensor + state = _CacheKeyState() + key_x = fm._cache_key(state, func, [x], {}) + key_y = fm._cache_key(state, func, [y], {}) + key_z = fm._cache_key(state, func, [z], {}) + + self.assertEqual(key_x, key_y) + self.assertNotEqual(key_x, key_z) + + def test_cache_key_dtype(self): + with FakeTensorMode() as fm: + x = torch.randn(4, 3, dtype=torch.float16) + y = torch.randn(4, 3, dtype=torch.float16) + z = x.to(dtype=torch.float32) + self._test_cache_key(fm, x, y, z) + + def test_cache_key_shape(self): + with FakeTensorMode() as fm: + x = torch.randn(4, 3) + y = torch.randn(4, 3) + z = torch.randn(4, 2) + self._test_cache_key(fm, x, y, z) + + def test_cache_key_stride(self): + with FakeTensorMode() as fm: + x = torch.randn(4, 2) + y = torch.randn(4, 2) + z = x.as_strided((4, 2), (1, 2)) + self._test_cache_key(fm, x, y, z) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_cache_key_device(self): + with FakeTensorMode() as fm: + x = torch.randn(4, 3) + y = torch.randn(4, 3) + z = x.to(device=device_type) + self._test_cache_key(fm, x, y, z) + + def test_cache_key_memory_format(self): + with FakeTensorMode() as fm: + x = torch.randn(1, 2, 3, 4) + y = torch.randn(1, 2, 3, 4) + z = x.to(memory_format=torch.channels_last) + self._test_cache_key(fm, x, y, z) + + def test_cache_key_storage_offset(self): + with FakeTensorMode() as fm: + x = torch.randn(3)[1:] + y = torch.randn(3)[1:] + z = torch.randn(2) + self._test_cache_key(fm, x, y, z) + + def test_cache_key_requires_grad(self): + with FakeTensorMode() as fm: + x = torch.randn(4, 3) + y = torch.randn(4, 3) + z = torch.randn(4, 3, requires_grad=True) + self._test_cache_key(fm, x, y, z) + + def test_cache_key_is_conj(self): + with FakeTensorMode() as fm: + x = torch.randn(4, 3, dtype=torch.complex64) + y = torch.randn(4, 3, dtype=torch.complex64) + z = torch.randn(4, 3, dtype=torch.complex64) + torch._C._set_conj(z, not z.is_conj()) + self._test_cache_key(fm, x, y, z) + + def test_cache_key_is_neg(self): + with FakeTensorMode() as fm: + x = torch.randn(4, 3, dtype=torch.complex64) + y = torch.randn(4, 3, dtype=torch.complex64) + z = torch.randn(4, 3, dtype=torch.complex64) + torch._C._set_neg(z, not z.is_neg()) + self._test_cache_key(fm, x, y, z) + + def test_cache_key_is_inference(self): + with torch.inference_mode(True): + t = torch.randn(4, 3) + with FakeTensorMode() as fm: + x = torch.randn(4, 3) + y = torch.randn(4, 3) + z = fm.from_tensor(t) + self._test_cache_key(fm, x, y, z) + + def test_cache_key_constants(self): + with FakeTensorMode() as fm: + # Python hashes 1.0 to the same value as 1. Make sure the + # cache key calculation differentiates them. + self._test_cache_key(fm, 1.0, 1.0, 1) + self._test_cache_key(fm, 0.0, 0.0, 0) + + def test_empty_list(self): + with FakeTensorMode() as fm: + func = aten.any.dims + state = _CacheKeyState() + x = torch.ones((2, 3)) + key_x = fm._cache_key(state, func, [x, []], {}) + key_y = fm._cache_key(state, func, [x], {}) + + self.assertNotEqual(key_x, key_y) + + def assertHitsMisses(self, hits, misses): + """ + Helper to assert on the number of recorded hits and misses. + """ + info = FakeTensorMode.cache_info() + self.assertEqual(info.hits, hits) + self.assertEqual(info.misses, misses) + + def assertBypasses(self, reason, count): + """ + Helper to assert on the number of recorded bypasses. 
+ """ + info = FakeTensorMode.cache_info() + if count > 0: + self.assertIn(reason, info.bypasses) + self.assertEqual(info.bypasses[reason], count) + else: + self.assertNotIn(reason, info.bypasses) + + def test_cache_hit(self): + """ + Test that cache hit/miss counters are updated correctly. + """ + with FakeTensorMode(): + x = torch.randn(4, 3) + y = torch.randn(4, 3) + + FakeTensorMode.cache_clear() + self.assertHitsMisses(0, 0) + res1 = x + y + self.assertHitsMisses(0, 1) + res2 = x + y + self.assertHitsMisses(1, 1) + + self.assertEqual( + extract_tensor_metadata(res1), + extract_tensor_metadata(res2), + ) + + def test_cache_bypass(self): + """ + Test that cache bypass counters are updated correctly. + """ + with FakeTensorMode(): + x = torch.randn(1, 2) + + FakeTensorMode.cache_clear() + self.assertBypasses("inplace view", 0) + + x.unsqueeze_(0) + self.assertBypasses("inplace view", 1) + + def test_cache_default_dtype(self): + """ + Test that the default dtype is respected when serving cached results. + """ + with FakeTensorMode(): + x = torch.tensor([1, 2], dtype=torch.int32) + torch.set_default_dtype(torch.float32) + + FakeTensorMode.cache_clear() + self.assertHitsMisses(0, 0) + + y = x + 1.0 + self.assertEqual(y.dtype, torch.float32) + self.assertHitsMisses(0, 1) + + torch.set_default_dtype(torch.float16) + y = x + 1.0 + self.assertEqual(y.dtype, torch.float16) + self.assertHitsMisses(0, 2) + + torch.set_default_dtype(torch.float32) + y = x + 1.0 + self.assertEqual(y.dtype, torch.float32) + self.assertHitsMisses(1, 2) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_cache_default_device(self): + """ + Test that the default device is respected when serving cached results. + """ + with FakeTensorMode(): + FakeTensorMode.cache_clear() + self.assertHitsMisses(0, 0) + + torch.set_default_device("cpu") + x = torch.tensor([1, 2]) + y = x + 1.0 + self.assertEqual(y.device.type, "cpu") + self.assertHitsMisses(0, 1) + + torch.set_default_device(device_type) + x = torch.tensor([1, 2]) + y = x + 1.0 + self.assertEqual(y.device.type, device_type) + self.assertHitsMisses(0, 2) + + torch.set_default_device("cpu") + x = torch.tensor([1, 2]) + y = x + 1.0 + self.assertEqual(y.device.type, "cpu") + self.assertHitsMisses(1, 2) + + def test_cache_inplace_op(self): + """ + Test that inplace ops served from the cache correctly reference the + input parameter. + """ + with FakeTensorMode(): + x = torch.randn(1, 2) + y = torch.randn(1, 2) + + FakeTensorMode.cache_clear() + self.assertHitsMisses(0, 0) + + z = x.add_(y) + self.assertHitsMisses(0, 1) + self.assertEqual(id(x), id(z)) + + w = x.add_(y) + self.assertHitsMisses(1, 1) + self.assertEqual(id(x), id(w)) + + def test_cache_view_op(self): + """ + Test that view ops are handled correctly when served from the cache. + """ + with FakeTensorMode(): + x1 = torch.ones(2, requires_grad=True).clone() + x2 = torch.ones(2, requires_grad=True).clone() + y2 = x2.view(-1) + + # Test operating on a non-view tensor, then the same operation + # on a view tensor. Assert that the view property is set correctly. + z1 = x1.mul_(2) + self.assertFalse(z1._is_view()) + + z2 = y2.mul_(2) + self.assertTrue(z2._is_view()) + + # Now the other way around: first operate on a view tensor, then + # the same operation on a non-view tensor. 
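+            # Served-from-cache results must still report the view-ness of the
+            # actual input, not of the call that first populated the cache.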
+ z2 = y2.mul_(2) + self.assertTrue(z2._is_view()) + + z1 = x1.mul_(2) + self.assertFalse(z1._is_view()) + + def test_cache_dispatch_key_set(self): + """ + Test that operations that change the dispatch key set bypass caching. + """ + with FakeTensorMode(): + FakeTensorMode.cache_clear() + self.assertBypasses("dispatch_key_set mismatch", 0) + + x = torch._efficientzerotensor(3) + self.assertTrue(x._is_zerotensor()) + self.assertBypasses("dispatch_key_set mismatch", 1) + + y = torch._efficientzerotensor(3) + self.assertTrue(y._is_zerotensor()) + self.assertBypasses("dispatch_key_set mismatch", 2) + + def test_fft_hfft2_issue145522(self): + with FakeTensorMode(): + s0 = 5 + s1 = 6 + s2 = 7 + s3 = 3 + s4 = 10 + s5 = 2 + x = torch.randn(s0, s1, s2) + out = torch.randn(s0, s3, s4) + kwargs = { + "s": (s3, s4), + "dim": (1, s5), + "norm": "ortho", + } + r = torch._C._fft.fft_hfft2(x, **kwargs, out=out) + self.assertEqual(r.shape, out.shape) + + def test_inference_mode(self): + """ + Test that caching handles inference mode correctly. + """ + with FakeTensorMode(): + x = torch.randn(4, 3) + y = torch.randn(4, 3) + + FakeTensorMode.cache_clear() + self.assertHitsMisses(0, 0) + + # Expect a miss when the inference mode is different + res1 = x + y + with torch.inference_mode(): + res2 = x + y + + self.assertHitsMisses(0, 2) + self.assertFalse(res1.is_inference()) + self.assertTrue(res2.is_inference()) + + # Second tries should see hits + res3 = x + y + + self.assertHitsMisses(1, 2) + self.assertFalse(res3.is_inference()) + self.assertEqual( + extract_tensor_metadata(res1), + extract_tensor_metadata(res3), + ) + + with torch.inference_mode(): + res4 = x + y + + self.assertHitsMisses(2, 2) + self.assertTrue(res4.is_inference()) + self.assertEqual( + extract_tensor_metadata(res2), + extract_tensor_metadata(res4), + ) + + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_wrapper_tensor_subclass_different_device(self): + class DifferentDeviceTensor(torch.Tensor): + @staticmethod + def __new__(cls, a): + kwargs = {} + kwargs["strides"] = a.stride() + kwargs["storage_offset"] = a.storage_offset() + kwargs["device"] = torch.device("cpu") + kwargs["layout"] = a.layout + kwargs["requires_grad"] = a.requires_grad + kwargs["dtype"] = a.dtype + out = torch.Tensor._make_wrapper_subclass(cls, a.size(), **kwargs) + return out + + def __init__(self, a): + self.inner_tensor = a + + def __repr__(self): + return f"DifferentDeviceTensor({repr(self.inner_tensor)})" + + def __tensor_flatten__(self): + return ["inner_tensor"], None + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + assert meta is None + return DifferentDeviceTensor(inner_tensors["inner_tensor"]) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + args = pytree.tree_map_only( + DifferentDeviceTensor, lambda x: x.inner_tensor, args + ) + kwargs = pytree.tree_map_only( + DifferentDeviceTensor, lambda x: x.inner_tensor, kwargs + ) + # Returns unwrapped tensor + return func(*args, **kwargs) + + a = torch.ones(2, 2, 768, device=device_type) + wrapped_a = DifferentDeviceTensor(a) + + # Outer Tensor is on cpu, inner is on cuda + self.assertTrue(wrapped_a.is_cpu) + self.assertFalse(wrapped_a.inner_tensor.is_cpu) + + with FakeTensorMode() as fake_mode: + fake_wrapped_a = fake_mode.from_tensor(wrapped_a) + + self.assertTrue(fake_wrapped_a.is_cpu) + assert isinstance(fake_wrapped_a, DifferentDeviceTensor) + 
self.assertFalse(fake_wrapped_a.inner_tensor.is_cpu) + + def test__upsample_bilinear2d_aa_backward_dynamic_shapes(self): + def f(x): + return torch.nn.functional.interpolate( + x, + size=[256, 256], + mode="bilinear", + align_corners=False, + antialias=True, + ) + + shape_env = ShapeEnv() + fake_m = FakeTensorMode(shape_env=shape_env) + x = fake_m.from_tensor( + torch.randn(1, 3, 2005, 1920, requires_grad=True), + symbolic_context=StatelessSymbolicContext( + dynamic_sizes=[ + DimDynamic.STATIC, + DimDynamic.STATIC, + DimDynamic.DYNAMIC, + DimDynamic.DYNAMIC, + ], + constraint_sizes=[None, None, None, None], + ), + ) + with fake_m, enable_python_dispatcher(): + out = f(x) + out.sum().backward() + self.assertEqual(x.shape, x.grad.shape) + + def test_from_buffer(self): + with FakeTensorMode(): + obj = [1, 2] + f = io.BytesIO() + pickle.Pickler(f).dump(obj) + storage = torch.UntypedStorage.from_buffer(f.getvalue(), dtype=torch.uint8) + + t = torch.ByteTensor(storage) + self.assertTrue(isinstance(t, FakeTensor)) + self.assertEqual(t.device, torch.device("cpu")) + + def test_meta_tensor_to_fake_cpu(self): + x = torch.randn(4, 4, device="meta") + with FakeTensorMode(allow_non_fake_inputs=True): + x_cpu = x.to(device="cpu") + self.assertTrue(isinstance(x_cpu, FakeTensor)) + self.assertEqual(x_cpu.device, torch.device("cpu")) + + def test_cache_tuple_outputs(self): + """ + Test to check that ops with tuple outputs work. + """ + with FakeTensorMode(): + x = torch.randn(6, 4) + y = torch.randn(6, 4) + + FakeTensorMode.cache_clear() + self.assertHitsMisses(0, 0) + + ref = torch.split(x, 2) + self.assertHitsMisses(0, 1) + + res = torch.split(y, 2) + self.assertHitsMisses(1, 1) + self.assertEqual(len(ref), len(res)) + for a, b in zip(ref, res): + self.assertEqual( + extract_tensor_metadata(a), + extract_tensor_metadata(b), + ) + + def test_cache_aten_index(self): + with FakeTensorMode(): + x = torch.randn(4, 4, 4) + idx_tensor1 = torch.tensor([0, 2, 3]) + idx_tensor2 = torch.tensor([0, 1, 2]) + + FakeTensorMode.cache_clear() + self.assertHitsMisses(0, 0) + + ref = torch.ops.aten.index(x, [None, idx_tensor1, idx_tensor2]) + self.assertHitsMisses(0, 3) + + res = torch.ops.aten.index(x, [None, idx_tensor1, idx_tensor2]) + self.assertHitsMisses(1, 3) + self.assertEqual(extract_tensor_metadata(ref), extract_tensor_metadata(res)) + + with FakeTensorMode(): + x = torch.randn(4, 4, 4) + idx_tensor1 = torch.tensor([True, True, False, True]) + self.assertRaises( + DynamicOutputShapeException, + lambda: torch.ops.aten.index(x, [None, idx_tensor1]), + ) + + idx_tensor1 = torch.tensor([1, -2, 3, -4], dtype=torch.int8) + self.assertRaises( + DynamicOutputShapeException, + lambda: torch.ops.aten.index(x, [None, idx_tensor1]), + ) + + @skipIfWindows( + msg="weird bug - cache may not be cleared after https://github.com/pytorch/pytorch/pull/154283" + ) + @skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching") + def test_invoke_subgraph(self): + """ + Tests invoke subgraph + """ + invoke_subgraph = torch._higher_order_ops.invoke_subgraph + + def run(): + def fn(x, y): + return (x + y * 2,) + + # Ensure there is no caching for non-Fx graph module inputs + with FakeTensorMode(): + x = torch.randn(6, 4) + y = torch.randn(6, 4) + + FakeTensorMode.cache_clear() + self.assertHitsMisses(0, 0) + + ref = invoke_subgraph(fn, "subgraph", x, y) + self.assertHitsMisses(0, 2) + self.assertBypasses("function argument", 1) + + res = invoke_subgraph(fn, "subgraph", x, y) + # The hits are from the ops inside fn + 
self.assertHitsMisses(2, 2) + self.assertBypasses("function argument", 2) + + res = invoke_subgraph(fn, "subgraph", x, y) + # The hits are from the ops inside fn + self.assertHitsMisses(4, 2) + self.assertBypasses("function argument", 3) + + # Get the mod as if its going through torch.compile + backend = torch._dynamo.testing.AotEagerAndRecordGraphs() + x = torch.randn(6, 4) + y = torch.randn(6, 4) + torch.compile(fn, backend=backend, fullgraph=True)(x, y) + self.assertEqual(len(backend.fw_graphs), 1) + mod = backend.fw_graphs[0] + + # Ensure that we see hits every time + with FakeTensorMode(): + x = torch.randn(6, 4) + y = torch.randn(6, 4) + + FakeTensorMode.cache_clear() + self.assertHitsMisses(0, 0) + + ref = invoke_subgraph(mod, "subgraph", x, y) + self.assertHitsMisses(0, 3) + + res = invoke_subgraph(mod, "subgraph", x, y) + # The hits are from re-running the subgraph + self.assertHitsMisses(1, 3) + + res = invoke_subgraph(mod, "subgraph", x, y) + # The hits are from re-running the subgraph + self.assertHitsMisses(2, 3) + + self.assertEqual(len(ref), len(res)) + self.assertEqual(len(ref), len(res)) + for a, b in zip(ref, res): + self.assertEqual( + extract_tensor_metadata(a), + extract_tensor_metadata(b), + ) + self.assertTrue(count_invoke_subgraph_keys() > 0) + + def count_invoke_subgraph_keys(): + invoke_subgraph_keys = 0 + for cache_key in FakeTensorMode.cache: + if isinstance(cache_key.key[0], torch._ops.HigherOrderOperator): + invoke_subgraph_keys += 1 + return invoke_subgraph_keys + + # Check that the graph gc clears the cache + run() + torch.compiler.reset() + gc.collect() + self.assertTrue(count_invoke_subgraph_keys() == 0) + + @skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching") + def test_invoke_subgraph_cacheable_inplace(self): + invoke_subgraph = torch._higher_order_ops.invoke_subgraph + + def fn(x, y): + # aten ops are used so that eager backend graph is suitable for fake + # tensor testing + cos = torch.ops.aten.cos.default(x) + # inplace-view - this should cause the whole invoke_subgraph to not + # being able to cache + t = torch.ops.aten.t_.default(cos) + mul = torch.ops.aten.mul.Tensor(t, y) + return (mul,) + + # Get the mod as if its going through torch.compile + backend = torch._dynamo.testing.AotEagerAndRecordGraphs() + x = torch.randn(4, 4) + y = torch.randn(4, 4) + torch.compile(fn, backend=backend, fullgraph=True)(x, y) + self.assertEqual(len(backend.graphs), 1) + mod = backend.graphs[0] + + # Ensure that invoke_subgraph result is still cached + with FakeTensorMode(): + x = torch.randn(4, 4) + y = torch.randn(4, 4) + + FakeTensorMode.cache_clear() + self.assertHitsMisses(0, 0) + + ref = invoke_subgraph(mod, "subgraph", x, y) + self.assertHitsMisses(0, 3) + + res = invoke_subgraph(mod, "subgraph", x, y) + # The hits are from the ops inside fn and not the subgraph + self.assertHitsMisses(1, 3) + + res = invoke_subgraph(mod, "subgraph", x, y) + # The hits are from the ops inside fn and not the subgraph + self.assertHitsMisses(2, 3) + + self.assertEqual(len(ref), len(res)) + self.assertEqual(len(ref), len(res)) + for a, b in zip(ref, res): + self.assertEqual( + extract_tensor_metadata(a), + extract_tensor_metadata(b), + ) + + @skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching") + def test_unbacked_output(self): + # The point of this test is to have an op which has no symbols as input + # but a symbol as an output and make sure that we skip caching it. 
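+        # repeat_interleave with a tensor `repeats` argument has a
+        # data-dependent output length, which is where the unbacked output
+        # symbol comes from.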
+ class LengthsGather(torch.nn.Module): + def forward( + self, + input: torch.Tensor, + lengths: torch.Tensor, + indices: torch.Tensor, + offsets: torch.Tensor, + ) -> torch.Tensor: + bias = torch.gather(offsets, 0, indices) + lengths_selected = torch.gather(lengths, 0, indices) + index = torch.repeat_interleave(bias, lengths_selected, dim=0) + return index + + input = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + lengths = torch.tensor([0, 2, 3, 1, 4]) + indices = torch.tensor([2, 3, 4, 6, 7, 8, 9]) + offsets = torch.cumsum(lengths, 0) + ep = torch.export.export( + LengthsGather(), (input, lengths, indices, offsets), strict=False + ) + + FakeTensorMode.cache_clear() + ep.run_decompositions({}) + self.assertBypasses("unrepresented symbol in output", 2) + + +class FakeTensorPreferDeviceType(TestCase): + @unittest.skipIf(not RUN_CUDA and not TEST_XPU, "requires cuda or xpu") + def test_fake_tensor_prefer_device_type(self): + """ + Test that fake_tensor_prefer_device_type configuration works correctly + for device mismatch scenarios. + """ + + # Create a custom operation that would normally cause device mismatch + def mixed_device_op(a, b): + # This simulates an operation where 'a' is on MTIA/CUDA but 'b' is created on CPU + cpu_tensor = torch.arange(a.shape[0], device="cpu") + return a + cpu_tensor.unsqueeze(-1) + + with FakeTensorMode(): + # Test default behavior (should raise error on device mismatch) + cuda_tensor = torch.randn(3, 4, device=device_type) + + # Without the config, this should raise a device mismatch error + with self.assertRaisesRegex( + RuntimeError, "Unhandled FakeTensor Device Propagation" + ): + mixed_device_op(cuda_tensor, None) + + # Test with prefer_device_type set to device_type + with torch._functorch.config.patch(fake_tensor_prefer_device_type=device_type): + with FakeTensorMode(): + cuda_tensor = torch.randn(3, 4, device=device_type) + + # This should now work and prefer the CUDA device + result = mixed_device_op(cuda_tensor, None) + + # The result should be on CUDA device (preferred device type) + self.assertEqual(result.device.type, device_type) + self.assertEqual(result.shape, (3, 4)) + self.assertTrue(isinstance(result, FakeTensor)) + + # Test that the configuration doesn't affect normal operations + with torch._functorch.config.patch(fake_tensor_prefer_device_type=device_type): + with FakeTensorMode(): + # Normal same-device operations should work as before + x = torch.randn(2, 3, device=device_type) + y = torch.randn(2, 3, device=device_type) + result = x + y + self.assertEqual(result.device.type, device_type) + + # CPU operations should still work + x_cpu = torch.randn(2, 3, device="cpu") + y_cpu = torch.randn(2, 3, device="cpu") + result_cpu = x_cpu + y_cpu + self.assertEqual(result_cpu.device.type, "cpu") + + # Test that the configuration is properly scoped + with FakeTensorMode(): + cuda_tensor = torch.randn(3, 4, device=device_type) + + # After exiting the config context, should raise error again + with self.assertRaisesRegex( + RuntimeError, "Unhandled FakeTensor Device Propagation" + ): + mixed_device_op(cuda_tensor, None) + + def test_fake_tensor_prefer_device_type_cpu_only(self): + """ + Test that fake_tensor_prefer_device_type works correctly when only CPU tensors are involved. 
+        """
+        with torch._functorch.config.patch(fake_tensor_prefer_device_type=device_type):
+            with FakeTensorMode():
+                # When all tensors are CPU, the result should still be CPU
+                x = torch.randn(2, 3, device="cpu")
+                y = torch.randn(2, 3, device="cpu")
+                result = x + y
+                self.assertEqual(result.device.type, "cpu")
+                self.assertTrue(isinstance(result, FakeTensor))
+
+
+if __name__ == "__main__":
+    run_tests()