diff --git a/.github/workflows/_linux_ut.yml b/.github/workflows/_linux_ut.yml
index dc10647124..1156d04396 100644
--- a/.github/workflows/_linux_ut.yml
+++ b/.github/workflows/_linux_ut.yml
@@ -164,7 +164,7 @@ jobs:
           # get distributed known issues
           gh --repo intel/torch-xpu-ops issue view $UT_SKIP_ISSUE --json body -q .body |sed -E '/^(#|$)/d' > Known_issue.log.tmp
           # get skipped known issues
-          count=$(gh api --paginate "repos/${{ github.repository }}/issues?labels=skipped" --jq 'length')
+          count=$(gh api "search/issues?q=repo:$GITHUB_REPOSITORY+label:skipped+state:open" --jq '.total_count')
           if [ "$count" -gt 0 ]; then
             echo -e "$count issues with skipped label found"
             gh api --paginate "repos/${{ github.repository }}/issues?labels=skipped" \
diff --git a/src/ATen/native/transformers/Attention.cpp b/src/ATen/native/transformers/Attention.cpp
index f72dcb2d8b..08b45fda83 100644
--- a/src/ATen/native/transformers/Attention.cpp
+++ b/src/ATen/native/transformers/Attention.cpp
@@ -13,11 +13,14 @@
 #include
 #include
 #include
+#include
+#include
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include
 #include
 #else
+#include
 #include
 #include
 #include
@@ -304,5 +307,59 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_xpu(
   return std::make_tuple(std::move(proj), std::move(qkt));
 }
 
+/**
+ * Get the mask for dropout. Only used for testing; not much
+ * attention is paid to performance.
+ */
+at::Tensor& _fill_mem_eff_dropout_mask_(
+    Tensor& self,
+    double dropout_p,
+    const int64_t seed,
+    const int64_t offset) {
+  auto mask = std::get<1>(xpu::dropout_kernel(self, dropout_p, true));
+  self.copy_(mask);
+  return self;
+}
+
+/**
+ * Fallback implementation of efficient attention.
+ */
+std::tuple<Tensor, Tensor, Tensor, Tensor>
+_scaled_dot_product_efficient_attention_xpu(
+    const Tensor& query,
+    const Tensor& key,
+    const Tensor& value,
+    const std::optional<Tensor>& attn_bias,
+    bool compute_log_sumexp,
+    double dropout_p,
+    bool is_causal,
+    std::optional<double> scale) {
+  // Used for tracking usage statistics
+  C10_LOG_API_USAGE_ONCE("torch.sdpa.mem_efficient_attention");
+  constexpr int64_t MAX_BATCH_SIZE = (1LL << 16) - 1;
+  int64_t batch_size = query.size(0);
+
+  if (batch_size > MAX_BATCH_SIZE) {
+    TORCH_CHECK(
+        dropout_p == 0.0,
+        "Efficient attention cannot produce valid seed and offset outputs when "
+        "the batch size exceeds (",
+        MAX_BATCH_SIZE,
+        ").");
+  }
+  auto res = at::_scaled_dot_product_attention_math(
+      query,
+      key,
+      value,
+      attn_bias,
+      dropout_p,
+      is_causal,
+      std::nullopt, /*dropout_mask*/
+      scale,
+      true);
+  return std::make_tuple(
+      std::get<0>(res), std::get<1>(res), Tensor(), Tensor());
+}
+
 } // namespace native
 } // namespace at
diff --git a/src/ATen/native/xpu/sycl/Reduce.h b/src/ATen/native/xpu/sycl/Reduce.h
index 90d314353b..2d4ecd303c 100644
--- a/src/ATen/native/xpu/sycl/Reduce.h
+++ b/src/ATen/native/xpu/sycl/Reduce.h
@@ -1161,8 +1161,24 @@ inline void gpu_reduce_kernel(
   using traits = function_traits<decltype(&ops_t::reduce)>;
   using arg_t = typename traits::template arg<0>::type;
+
+  // at::Half/at::ComplexHalf overflow easily as their range is very small.
+  // So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we
+  // set can_accumulate_in_output to False.
+  static constexpr bool is_inp_out_type_half_or_chalf =
+      (std::is_same_v<at::Half, scalar_t> &&
+       std::is_same_v<at::Half, out_scalar_t>) ||
+      (std::is_same_v<c10::complex<at::Half>, scalar_t> &&
+       std::is_same_v<c10::complex<at::Half>, out_scalar_t>);
+  // at::BFloat16 has lower precision and can lead to rounding errors.
+  // So when scalar_t and out_scalar_t are at::BFloat16, we
+  // set can_accumulate_in_output to False.
+  static constexpr bool is_inp_out_type_bfloat16 =
+      (std::is_same_v<at::BFloat16, scalar_t> &&
+       std::is_same_v<at::BFloat16, out_scalar_t>);
   static constexpr bool can_accumulate_in_output =
-      std::is_convertible<arg_t, out_scalar_t>::value;
+      std::is_convertible_v<arg_t, out_scalar_t> &&
+      !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16);
 
   bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
   std::unique_ptr<AccumulationBuffer> owned_buf_ptr;
diff --git a/src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp b/src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp
index c86203bf4a..fe534488ff 100644
--- a/src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp
+++ b/src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp
@@ -58,15 +58,6 @@ struct SumFunctor {
   }
 };
 
-template <>
-struct SumFunctor<c10::complex<at::Half>> {
-  using scalar_t = c10::complex<at::Half>;
-  using acc_t = at::opmath_type<scalar_t>;
-  inline acc_t operator()(acc_t a, acc_t b) const {
-    return a + b;
-  }
-};
-
 template <
     typename scalar_t,
     typename acc_t = scalar_t,
@@ -78,6 +69,16 @@ struct sum_functor {
   }
 };
 
+template <>
+struct sum_functor<c10::complex<at::Half>> {
+  void operator()(TensorIterator& iter) {
+    using scalar_t = c10::complex<at::Half>;
+    using acc_t = at::opmath_type<scalar_t>;
+    gpu_reduce_kernel<scalar_t, scalar_t>(
+        iter, func_wrapper<scalar_t>(SumFunctor<acc_t>()));
+  }
+};
+
 void sum_kernel(TensorIterator& iter) {
   auto general_dispatcher = [](TensorIterator& iter) {
     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
diff --git a/test/xpu/run_test_with_only.py b/test/xpu/run_test_with_only.py
deleted file mode 100644
index af9230c02a..0000000000
--- a/test/xpu/run_test_with_only.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# Copyright 2020-2025 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-
-import os
-import sys
-
-# Cases in the file is too slow to run all suites on CPU. So add white list.
- - -def launch_test(test_case, skip_list=None, exe_list=None): - os.environ["PYTORCH_TEST_WITH_SLOW"] = "1" - if skip_list is not None: - skip_options = ' -k "not ' + skip_list[0] - for skip_case in skip_list[1:]: - skip_option = " and not " + skip_case - skip_options += skip_option - skip_options += '"' - test_command = ( - "pytest --junit-xml=./op_ut_with_only.xml " + test_case + skip_options - ) - return os.system(test_command) - elif exe_list is not None: - exe_options = ' -k "' + exe_list[0] - for exe_case in exe_list[1:]: - exe_option = " or " + exe_case - exe_options += exe_option - exe_options += '"' - test_command = ( - "pytest --junit-xml=./op_ut_with_only.xml " + test_case + exe_options - ) - return os.system(test_command) - else: - test_command = "pytest --junit-xml=./op_ut_with_only.xml " + test_case - return os.system(test_command) - - -res = 0 - -# test_decomp -# full skip_list is in Issue #470 -execute_list = ( - "test_comprehensive_nn_functional_cross_entropy_xpu", - "test_comprehensive_nn_functional_nll_loss_xpu_bfloat16", - "test_comprehensive_nn_functional_nll_loss_xpu_float32", - "test_comprehensive_nn_functional_nll_loss_xpu_float64", - "bincount", -) -skip_list = ( - "test_comprehensive_baddbmm_xpu_float64", - "test_comprehensive_logspace_tensor_overload_xpu_int16", - "test_comprehensive_logspace_tensor_overload_xpu_int32", - "test_comprehensive_logspace_tensor_overload_xpu_int64", - "test_comprehensive_logspace_xpu_int16", - "test_comprehensive_logspace_xpu_int32", - "test_comprehensive_logspace_xpu_int64", - "test_comprehensive_nn_functional_conv_transpose2d_xpu_bfloat16", - "test_comprehensive_nn_functional_conv_transpose2d_xpu_complex128", - "test_comprehensive_nn_functional_conv_transpose2d_xpu_complex32", - "test_comprehensive_nn_functional_conv_transpose2d_xpu_complex64", - "test_comprehensive_nn_functional_conv_transpose2d_xpu_float16", - "test_comprehensive_nn_functional_conv_transpose2d_xpu_float32", - "test_comprehensive_nn_functional_conv_transpose2d_xpu_float64", - "test_comprehensive_nn_functional_conv_transpose3d_xpu_bfloat16", - "test_comprehensive_nn_functional_conv_transpose3d_xpu_complex128", - "test_comprehensive_nn_functional_conv_transpose3d_xpu_complex32", - "test_comprehensive_nn_functional_conv_transpose3d_xpu_complex64", - "test_comprehensive_nn_functional_conv_transpose3d_xpu_float16", - "test_comprehensive_nn_functional_conv_transpose3d_xpu_float32", - "test_comprehensive_nn_functional_conv_transpose3d_xpu_float64", - "test_comprehensive_nn_functional_instance_norm_xpu_float64", - "test_comprehensive_nn_functional_nll_loss_xpu_float16", - "test_comprehensive_nn_functional_pad_reflect_xpu_bfloat16", - "test_comprehensive_torch_ops_aten__flash_attention_forward_xpu_float16", - "test_comprehensive_vdot_xpu_complex128", - "test_comprehensive_vdot_xpu_complex64", - "test_quick_addmm_xpu_float64", - "test_quick_baddbmm_xpu_float64", - "test_quick_core_backward_baddbmm_xpu_float64", - "test_quick_core_backward_mv_xpu_float64", - "test_quick_logspace_tensor_overload_xpu_int16", - "test_quick_logspace_tensor_overload_xpu_int32", - "test_quick_logspace_tensor_overload_xpu_int64", - "test_quick_logspace_xpu_int16", - "test_quick_logspace_xpu_int32", - "test_quick_logspace_xpu_int64", - "test_quick_vdot_xpu_complex128", - "test_quick_vdot_xpu_complex64", - "test_exponential_non_inf_xpu", - "test_aten_core_operators", - "test_has_decomposition", - "test_comprehensive_diff_xpu_complex128", - "test_comprehensive_ormqr_xpu_complex128", - 
"test_quick_var_mean_xpu_float64", - "test_comprehensive_diff_xpu_complex64", - "test_comprehensive_ormqr_xpu_complex64", - "test_quick_mean_xpu_complex128", - "test_comprehensive_grid_sampler_2d_xpu_bfloat16", -) -# res += launch_test("test_decomp_xpu.py", exe_list=execute_list) -res += launch_test("test_decomp.py", skip_list=skip_list) - -if os.name == "nt": - sys.exit(res) -else: - exit_code = os.WEXITSTATUS(res) - sys.exit(exit_code) diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py index 7b939e18c3..d8b864af1e 100644 --- a/test/xpu/run_test_with_skip.py +++ b/test/xpu/run_test_with_skip.py @@ -21,13 +21,6 @@ default="selected", help="Test cases scope", ) -# Add skip-cases parameter to import window skip dictionary -parser.add_argument( - "--skip-cases", - action="store_true", - default=False, - help="Use window skip dictionary for test cases", -) args = parser.parse_args() @@ -38,18 +31,24 @@ def should_skip_entire_file(skip_list): return any(item.endswith(".py::") for item in skip_list) -# Import window skip dictionary if skip-cases is True -if args.skip_cases: +platform = sys.platform +print(f"Running test on the platform: {platform}") +# Import window skip dictionary if Platform is Windows +if platform.startswith("win"): try: # Import the window skip dictionary module - from window_skip_dict import skip_dict as window_skip_dict + from windows_skip_dict import skip_dict as window_skip_dict # Merge the window skip dictionary with the default one using intelligent strategy merged_skip_dict = {} # First, copy all keys from default skip_dict for key in skip_dict: - merged_skip_dict[key] = skip_dict[key].copy() if skip_dict[key] else [] + merged_skip_dict[key] = ( + list(skip_dict[key]) + if isinstance(skip_dict[key], tuple) + else (skip_dict[key].copy() if skip_dict[key] else []) + ) # Then merge with window_skip_dict using intelligent strategy for key in window_skip_dict: @@ -61,9 +60,12 @@ def should_skip_entire_file(skip_list): # Intelligent merge strategy: if should_skip_entire_file(window_skip_list): # If Windows wants to skip entire file, use ONLY Windows skip list - merged_skip_dict[key] = window_skip_list + window_skip_entire_file_list = [ + item.replace("::", "") for item in window_skip_list + ] + merged_skip_dict[key] = window_skip_entire_file_list print( - f"Windows entire file skip detected for {key}, using: {window_skip_list}" + f"Windows entire file skip detected for {key}, using: {window_skip_entire_file_list}" ) else: # Otherwise, merge both lists and remove duplicates @@ -112,6 +114,10 @@ def should_skip_entire_file(skip_list): skip_list = None # For "selected" case, use the skip_list as is + # If skip_list is empty, set it to None + if skip_list is not None and len(skip_list) == 0: + skip_list = None + print(f"Running test case: {key}") if skip_list: print(f"Skip list: {skip_list}") diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 5736bd47d9..52bec312d0 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -745,7 +745,7 @@ # CUDA specific case "test_cufft_plan_cache_xpu_float64", ), - "test_decomp.py": ( + "test_decomp_xpu.py": ( # AssertionError: Tensor-likes are not close! ; Exception: Tensor-likes are not close! 
"test_comprehensive_baddbmm_xpu_float64", "test_comprehensive_logspace_tensor_overload_xpu_int16", diff --git a/test/xpu/test_decomp.py b/test/xpu/test_decomp.py deleted file mode 100644 index 1694e33d16..0000000000 --- a/test/xpu/test_decomp.py +++ /dev/null @@ -1,1371 +0,0 @@ -# Copyright 2020-2025 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 - -# Owner(s): ["module: decompositions"] - -import functools -import itertools -import re -import unittest -from collections import defaultdict -from functools import partial - -import torch._inductor.decomposition -import torch.autograd -from torch import Tensor -from torch._decomp import core_aten_decompositions, decomposition_table -from torch._dispatch.python import enable_python_dispatcher -from torch._export.utils import _is_cia_op -from torch._ops import DispatchKey -from torch.testing import make_tensor -from torch.testing._internal.common_cuda import SM70OrLater, tf32_off -from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, - onlyCPU, - onlyNativeDeviceTypes, - ops, -) -from torch.testing._internal.common_methods_invocations import ( - op_db, - skip, - skipOps, - xfail, -) -from torch.testing._internal.common_modules import module_db, modules -from torch.testing._internal.common_utils import ( - is_iterable_of_tensors, - run_tests, - skipIfCrossRef, - skipIfTorchDynamo, - suppress_warnings, - TEST_WITH_ASAN, - TEST_WITH_SLOW, - TestCase, - unMarkDynamoStrictTest, -) -from torch.utils import _pytree as pytree -from torch.utils._python_dispatch import TorchDispatchMode -from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten - -device_type = ( - acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" -) - -aten = torch.ops.aten - - -# TODO: this isn't going to work with non-aten namespaces -def overload_to_aten_name(op): - return op._schema.name.split("::")[1] - - -# All operators that can have decomp tests -decomposition_names = { - overload_to_aten_name(k) - for k in decomposition_table - if isinstance(k, torch._ops.OpOverload) -} -core_decomposition_names = { - overload_to_aten_name(k) - for k in core_aten_decompositions() - if isinstance(k, torch._ops.OpOverload) and not _is_cia_op(k) -} -_decomp_test_ops = [ - op - for op in op_db - if op.aten_name in decomposition_names - or op.aten_backward_name in decomposition_names -] -_decomp_test_ops_core_autograd = [ - op - for op in op_db - if op.aten_name in core_decomposition_names and op.supports_autograd -] -_sdpa_op_info = [op for op in op_db if "scaled_dot_product_attention" in op.aten_name] - - -def diff_arg(arg, requires_grad=True): - def is_differentiable_arg(arg): - if requires_grad: - return arg.requires_grad - else: - return arg.is_floating_point() or arg.is_complex() - - if is_iterable_of_tensors(arg): - if all(is_differentiable_arg(a) for a in arg): - return True - if all(not is_differentiable_arg(a) for a in arg): - return False - raise RuntimeError("NYI: The test runner can't handle this") - return isinstance(arg, Tensor) and is_differentiable_arg(arg) - - -# Version of autograd.grad with some differences: -# - pytree inputs is allowed (but leaves of the pytree have to all -# be tensors) -# - if an input is not used as part of derivatives, we will return a -# zero-filled tensor for the result -def 
_autograd_grad( - outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True -): - inputs, inputs_spec = tree_flatten(inputs) - diff_inputs = tuple(inp for inp in inputs if inp.requires_grad) - if grad_outputs is None: - diff_outputs = tuple(out for out in outputs if out.requires_grad) - else: - diff_grad_outputs = [ - (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad - ] - if len(diff_grad_outputs) == 0: - diff_outputs, grad_outputs = (), () - else: - diff_outputs, grad_outputs = zip(*diff_grad_outputs) - grad_inputs = torch.autograd.grad( - diff_outputs, - diff_inputs, - grad_outputs, - retain_graph=retain_graph, - create_graph=create_graph, - allow_unused=True, - ) - result = [] - grad_inputs_iter = iter(grad_inputs) - for inp in inputs: - if inp.requires_grad: - grad_input = next(grad_inputs_iter) - if grad_input is None: - result.append(torch.zeros_like(inp)) - else: - result.append(grad_input) - else: - result.append(torch.zeros_like(inp)) - return tree_unflatten(result, inputs_spec) - - -def _as_tuple(val): - if isinstance(val, tuple): - return val - return (val,) - - -def ref_vjp_no_create(f, *primals): - result = f(*primals) - - def wrapped(cotangents): - return _autograd_grad( - _as_tuple(result), - primals, - _as_tuple(cotangents), - create_graph=False, - retain_graph=True, - ) - - return result, wrapped - - -dtype_precisions = { - torch.float16: (0.001, 1e-5), - torch.bfloat16: (0.016, 1e-4), - torch.float32: (1.3e-6, 1e-5), - torch.float64: (1e-7, 1e-7), - torch.complex32: (0.001, 1e-5), - torch.complex64: (1.3e-6, 1e-5), - torch.complex128: (1e-7, 1e-7), -} -# Returns the "default" rtol and atol for comparing scalars or -# tensors of the given dtypes. - - -def _getDefaultRtolAndAtol(dtype0, dtype1): - rtol = max( - dtype_precisions.get(dtype0, (0, 0))[0], dtype_precisions.get(dtype1, (0, 0))[0] - ) - atol = max( - dtype_precisions.get(dtype0, (0, 0))[1], dtype_precisions.get(dtype1, (0, 0))[1] - ) - return rtol, atol - - -def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs): - assert orig.dtype == decomp.dtype, f"{i} Operation: {op}" - if orig.numel() == 0 or decomp.numel() == 0: - assert orig.numel() == decomp.numel() - return - assert orig.shape == decomp.shape, f"{i} Operation: {op}" - tol_table = { - (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5, - (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5, - (torch.float16, torch.ops.aten.native_layer_norm_backward.default): 1e-3, - (torch.bfloat16, torch.ops.aten.native_layer_norm_backward.default): 2e-2, - (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5, - (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5, - (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.default): 1e-5, - (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5, - (torch.float16, torch.ops.aten._native_batch_norm_legit.default): 1e-5, - (torch.float16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5, - (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-4, - (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-4, - (torch.bfloat16, torch.ops.aten.var_mean.correction): 5e-7, - (torch.float16, torch.ops.aten.var_mean.correction): 5e-7, - (torch.bfloat16, torch.ops.aten.var_mean.dim): 5e-7, - (torch.float16, torch.ops.aten.var_mean.dim): 5e-7, - (torch.float16, torch.ops.aten.nll_loss_forward.default): 1e-2, - (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1, - 
(torch.float16, torch.ops.aten.nll_loss2d_forward.default): 1e-2, - (torch.float16, torch.ops.aten.nll_loss2d_backward.default): 1e-4, - (torch.bfloat16, torch.ops.aten.nll_loss2d_forward.default): 2e-1, - (torch.float16, torch.ops.aten.hardswish.default): 2e-7, - (torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7, - (torch.float16, torch.ops.aten.multi_margin_loss.default): 3e-2, - (torch.bfloat16, torch.ops.aten.multi_margin_loss.default): 5e-2, - (torch.float16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2, - (torch.bfloat16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2, - (torch.float16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3, - (torch.bfloat16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3, - (torch.float16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3, - (torch.bfloat16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3, - (torch.float16, torch.ops.aten.reflection_pad3d_backward.default): 5e-3, - (torch.bfloat16, torch.ops.aten.reflection_pad3d_backward.default): 5e-2, - # see https://github.com/pytorch/pytorch/pull/96264 - (torch.float16, torch.ops.aten.mv.default): 1e-5, - (torch.bfloat16, torch.ops.aten.mv.default): 1e-5, - (torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5, - (torch.float16, torch.ops.aten._softmax_backward_data.default): 3e-7, - # XPU specific - ( - torch.float16, - torch.ops.aten._batch_norm_with_update.default, - ): 2e-7, # adjust tolerance for xpu - ( - torch.bfloat16, - torch.ops.aten._batch_norm_with_update.default, - ): 2e-7, # adjust tolerance for xpu - } - if ref.is_floating_point(): - orig_diff = (orig - ref).abs().max() - decomp_diff = (decomp - ref).abs().max() - atol = tol_table.get((test_dtype, op), 1e-7) - if decomp_diff > orig_diff + atol: - raise RuntimeError( - f"Difference from float64 is larger with decomposition {op.__name__}" - f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n" - f"atol = {atol}\n" - f"args = {args}\n" - f"kwargs = {kwargs}" - ) - else: - test_case.assertEqual( - orig, decomp, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}" - ) - - -def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs): - test_case.assertEqual( - orig.dtype, - decomp.dtype, - f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}", - ) - # Before adding an entry to this table, make sure your decomposition is right :) - tol_table = { - # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161 - (torch.float32, torch.ops.aten.native_layer_norm.default): (1e-3, 1e-3), - (torch.float32, torch.ops.aten.native_layer_norm_backward.default): ( - 1e-3, - 1e-3, - ), - (torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6), - # This exceeds default tolerances only on CPU, on CUDA it's fine - (torch.float32, torch.ops.aten.grid_sampler_2d.default): (7e-6, 3e-5), - # Exceeds tolerances on CUDA, likely due to fma - (torch.float32, torch.ops.aten.mv.default): (1e-5, 3e-5), - (torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5), - (torch.float64, torch.ops.aten.upsample_bicubic2d.vec): (1e-5, 5e-4), - (torch.float64, torch.ops.aten.upsample_bicubic2d.default): (1e-5, 5e-4), - # The decomposition is TOO correct. It computes everything in int64, so sometimes - # there's an off-by-one error. 
See - # https://github.com/pytorch/pytorch/issues/81996 - # https://github.com/pytorch/pytorch/issues/82230 - (torch.int8, torch.ops.aten.linspace.default): (0, 1), - (torch.uint8, torch.ops.aten.linspace.default): (0, 1), - (torch.int16, torch.ops.aten.linspace.default): (0, 1), - (torch.int32, torch.ops.aten.linspace.default): (0, 1), - (torch.int64, torch.ops.aten.linspace.default): (0, 1), - (torch.int8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), - (torch.uint8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), - (torch.int16, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), - (torch.int32, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), - (torch.int64, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), - (torch.int8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), - (torch.uint8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), - (torch.int16, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), - (torch.int32, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), - (torch.int64, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), - (torch.int8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), - (torch.uint8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), - (torch.int16, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), - (torch.int32, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), - (torch.int64, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), - } - if (decomp.dtype, op) in tol_table: - rtol, atol = tol_table[(decomp.dtype, op)] - else: - rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype) - test_case.assertEqual( - orig, - decomp, - rtol=rtol, - atol=atol, - msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}", - ) - - -# Given f, returns an f' such that: -# - f' takes only positional arguments -# - All arguments to f' are floating-point Tensors -# - All outputs of f' are floating-point Tensors -def normalize_op_input_output2( - f, args, kwargs, output_process_fn_grad=None, requires_grad=True -): - flat_args, args_spec = tree_flatten(args) - diff_argnums = tuple( - i - for i, arg in enumerate(flat_args) - if diff_arg(arg, requires_grad=requires_grad) - ) - assert len(diff_argnums) > 0 - primals = tuple(flat_args[i] for i in diff_argnums) - - @functools.wraps(f) - def wrapped(*primals): - _args = list(flat_args) - for num, arg in zip(diff_argnums, primals): - _args[num] = arg - _args = tree_unflatten(_args, args_spec) - result = f(*_args, **kwargs) - if output_process_fn_grad is not None: - result = output_process_fn_grad(result) - if isinstance(result, tuple): - # TODO We should check that the integer outputs also agree - result = tuple( - r - for r in result - if isinstance(r, Tensor) and (r.is_floating_point() or r.is_complex()) - ) - assert len(result) > 0 - return result - - return wrapped, primals - - -# NB: This also upcasts dtype arguments -# TODO: handle complex correctly -def upcast_tensor(x, dtype=torch.float32): - if isinstance(x, Tensor) and x.dtype.is_floating_point: - return x.to(dtype=dtype) - elif isinstance(x, torch.dtype) and x in [ - torch.float16, - torch.bfloat16, - torch.float, - ]: - return dtype - else: - return x - - -def normalize_op_input_output(f, sample, requires_grad=True): - args = tuple([sample.input] + list(sample.args)) - return normalize_op_input_output2( - f, - args, - sample.kwargs, - sample.output_process_fn_grad, - requires_grad=requires_grad, - ) - - -CROSS_REF_EXCLUDE_SET = { - # CUBLAS_STATUS_NOT_SUPPORTED when calling - # `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k, - # (void*)&falpha, a, CUDA_R_16BF, (int)lda, 
stridea, b, CUDA_R_16BF, - # (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec, - # (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)` - ("cuda", torch.bfloat16, "nn.functional.bilinear"), - # randomness - (None, None, "special.ndtr"), # aten.special_ndtr was not decomposed - (None, None, "new_empty"), - (None, None, "empty_like"), - (None, None, "empty"), - # AssertionError: False is not true : aten.item was not decomposed, saw calls for: aten._local_scalar_dense.default. - (None, None, "item"), - # It's the only in-place op without an out-of-place equivalent in the Python API - # Its OpInfo wrongly registers it as `torch.zero_(x.clone())`. - (None, None, "zero_"), - # No idea what's going on here - # In the recursive test logsumexp.default fails with args = (torch.tensor(-math.inf), []) - # in the test, but it seems to pass when tested locally and in the logsumexp test - (None, torch.float32, "masked.logsumexp"), - (None, torch.float64, "masked.logsumexp"), - # exp_vml_cpu not implemented for Half - (torch.cpu, torch.float16, "signal.windows.exponential"), - (torch.cpu, torch.float16, "signal.windows.gaussian"), - # sin_vml_cpu not implemented for Half - (torch.cpu, torch.float16, "signal.windows.cosine"), - # CompositeAutogradImplicit - # See https://github.com/pytorch/pytorch/issues/81669 - (None, None, "nn.functional.relu6"), - # This decomp runs before autograd. - (None, None, "nn.functional.rrelu"), - (None, None, "meshgrid"), - # Decomposition registered as Autograd - (None, None, "nn.functional.hardshrink"), - (None, None, "nn.functional.softshrink"), - # diag was not decomposed (it just registers a decomp for diag_out, torch.diag is CompImplicit) - (None, None, "diag"), - # _softmax_backward_data's CPU kernel for bfloat16 always return the grad_input as float32 - ("cpu", torch.bfloat16, "_softmax_backward_data"), - (None, None, "norm"), - # native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise) - (None, None, "native_batch_norm"), - (None, None, "_upsample_bilinear2d_aa"), - (None, None, "empty_strided"), # aten.empty_strided was not decomposed - ( - None, - None, - "bernoulli", - ), # bernoulli is a function of randomness, so couldn't do cross-reference. - # XPU specific exclude cases - # ("xpu", None, "some_xpu_specific_op"), -} - -CROSS_REF_BACKWARD_EXCLUDE_SET = { - # Decomposed backward formula is not as precise - ("cpu", torch.bfloat16, "nn.functional.hardswish"), - ("cuda", torch.float16, "nn.functional.cross_entropy"), - ( - None, - None, - "bernoulli", - ), # bernoulli is a function of randomness, so couldn't do cross-reference. 
- # XPU specific backward exclude cases - # ("xpu", torch.float16, "nn.functional.some_op"), -} - -all_decomposed = set() -all_called = defaultdict(int) - -# Helpful snippet for testing coverage -""" -import atexit -def check_coverage(): - print("missing coverage:") - print("\n".join(map(str, decomposition_table.keys() - all_decomposed))) -atexit.register(check_coverage) -""" - -# Helpful snippet for Horace to create his google sheet :) -""" -import atexit -def dump_ops(): - with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g: - for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__): - f.write(f'{op.__name__}\n') - g.write(f'{count}\n') - with open('run_decompositions.txt', 'w') as f: - for op in sorted([i.__name__ for i in all_decomposed]): - f.write(f'{op}\n') - -atexit.register(dump_ops) -""" - - -def any_unsupported(args, kwargs): - def test_unsupported(t): - if type(t) is torch.Tensor or type(t) is torch.nn.Parameter: - # These are all things that we haven't coded decompositions - # to handle correctly. Maybe they should. - return any( - [ - t.is_sparse_csr, - t.is_sparse, - t.is_mkldnn, - t.is_quantized, - t.is_nested, - torch._is_functional_tensor(t), - ] - ) - elif torch.overrides.is_tensor_like(t): - # Decompositions will generally change the behavior of Tensor-like - # subclasses, so bypass tests in this case too - return True - else: - return False - - flat_args = pytree.arg_tree_leaves(*args, **kwargs) - return any(test_unsupported(x) for x in flat_args) - - -core_backward_failures = { - skip("_softmax_backward_data"), # slow: fails with --timeout=360 secs - xfail("addcdiv"), - skip("addcmul"), # slow: fails with --timeout=360 secs - skip("deg2rad"), # slow: fails with --timeout=360 secs - skip("diag_embed"), # slow: fails with --timeout=360 secs - skip("frac"), # slow: fails with --timeout=360 secs - skip("grid_sampler_2d"), # slow: fails with --timeout=360 secs - xfail("lerp"), - skip("logaddexp"), # slow: fails with --timeout=360 secs - skip("native_dropout_backward"), # slow: fails with --timeout=360 secs - xfail("nn.functional.binary_cross_entropy_with_logits"), - skip("nn.functional.glu"), # slow: fails with --timeout=360 secs - xfail("nn.functional.hardshrink"), - xfail("nn.functional.softshrink"), - skip("nn.functional.unfold"), # slow: fails with --timeout=360 secs - xfail("norm"), - xfail("norm", "fro"), - xfail("norm", "inf"), - xfail("norm", "nuc"), - skip("rad2deg"), # slow: fails with --timeout=360 secs - skip("renorm"), # slow: fails with --timeout=360 secs - skip("rot90"), # slow: fails with --timeout=360 secs - skip("rsub"), # slow: fails with --timeout=360 secs - skip("sgn"), # slow: fails with --timeout=360 secs - skip("special.xlog1py"), # slow: fails with --timeout=360 secs - xfail("stack"), - skip("tril"), # slow: fails with --timeout=360 secs - skip("triu"), # slow: fails with --timeout=360 secs - skip("unfold_copy"), # slow: fails with --timeout=360 secs - skip("xlogy"), # slow: fails with --timeout=360 secs - xfail("zero_"), -} -if not TEST_WITH_SLOW: - core_backward_failures.update( - { - skip("addr"), # slow: takes 46 sec on A100 - skip("baddbmm"), # slow: takes 800+ sec on A100 - skip("clamp_min"), # slow: takes 800 sec on A100 - skip("clamp_max"), # slow: takes 800 sec on A100 - skip("logit"), # slow: takes 44 sec on A100 - skip("nn.functional.hardswish"), # slow: takes 60 sec on A100 - skip("std_mean"), # slow: takes 170 sec on A100 - skip("split", variant_name="list_args"), # slow: takes 118 sec on A100 - 
skip("transpose"), # slow: takes 50 sec on A100 - skip("unbind"), # slow: takes 70 sec on A100 - skip("unsafe_split"), # slow: takes 49 sec on A100 - } - ) - -comprehensive_failures = { - xfail( - "nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,) - ), # off by one error - xfail( - "nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,) - ), # off by one error - xfail( - "nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,) - ), # off by one error -} - - -@unMarkDynamoStrictTest -class TestDecomp(TestCase): - longMessage = True - - # NB: This actually overlaps with test_comprehensive, but it only - # runs on things that are definitely decomposed so it's a lot faster - # to run - @onlyNativeDeviceTypes - @skipIfCrossRef - @suppress_warnings - @ops(_decomp_test_ops) - def test_quick(self, device, dtype, op): - self.do_cross_ref(device, dtype, op, run_all=False) - - @skipOps("TestDecomp", "test_quick_core_backward", core_backward_failures) - @onlyNativeDeviceTypes - @skipIfCrossRef - @suppress_warnings - @ops(_decomp_test_ops_core_autograd, allowed_dtypes=(torch.float64,)) - def test_quick_core_backward(self, device, dtype, op): - test_keys = [ - (torch.device(device).type, dtype, op.name), - (None, dtype, op.name), - (None, None, op.name), - ] - if any(key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys): - self.skipTest(f"{op.name} in {dtype} not supported") - for sample_input in op.sample_inputs(device, dtype, requires_grad=True): - aten_name = op.decomp_aten_name or op.aten_name - args = [sample_input.input] + list(sample_input.args) - kwargs = sample_input.kwargs - func = partial(op.get_op(), **kwargs) - with ( - self.DecompCrossRefMode( - self, self.precision, self.rel_tol, dtype, run_all=False - ) as mode, - enable_python_dispatcher(), - ): - torch.autograd.gradcheck(func, args) - self.check_decomposed(aten_name, mode) - - @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") - @onlyNativeDeviceTypes - @skipIfCrossRef - @skipOps("TestDecomp", "test_comprehensive", comprehensive_failures) - @suppress_warnings - @ops(op_db) - def test_comprehensive(self, device, dtype, op): - self.do_cross_ref(device, dtype, op, run_all=True) - - def test_uniform(self, device): - size = (2, 3, 4, 5) - dtype = torch.float32 - x = make_tensor(size, dtype=dtype, device=device) - low = 0.3 - high = 0.9 - - torch.manual_seed(123) - ref = torch.ops.aten.uniform(x, low, high) - torch.manual_seed(123) - res = torch._decomp.decompositions.uniform(x, low=low, high=high) - self.assertEqual(ref, res) - - def test_bernoulli_default(self, device): - p = 0.3 - p_t = p * torch.ones(5, 5) - torch.manual_seed(123) - ref = torch.ops.aten.bernoulli.default(p_t) - torch.manual_seed(123) - res = torch._decomp.decompositions.bernoulli(p_t) - ref_p = ref.sum() / torch.prod(torch.tensor(ref.size())) - res_p = res.sum() / torch.prod(torch.tensor(res.size())) - self.assertEqual(ref_p, res_p, atol=0.06 * p, rtol=0.06) - - def test_broadcasting_index_copy(self, device): - x = torch.zeros([1, 10], device=device) - xs = torch.ones([2, 10], device=device) - - def index_copy(xs, x): - torch._decomp.decompositions.index_copy_( - xs, 0, torch.tensor(0).to(device), x - ) - - index_copy(xs, x) - - xs_two = torch.ones([2, 10], device=device) - xs_two[0] = x - - self.assertEqual(xs, xs_two) - - def test_cat_single_input(self, device): - decomp_table = torch._inductor.decomposition.select_decomp_table() - cat_inductor = decomp_table[torch.ops.aten.cat.default] - - inp = torch.rand([2048, 2048], device=device) - 
inps = [inp for _ in range(10)] - - for dim in (-1, 0, 1): - self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim)) - - @suppress_warnings - @tf32_off() - # only tests RNNs since we have py dispsatcher decomps for them - @modules( - filter( - lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), - module_db, - ) - ) - def test_rnn_decomp_module(self, device, dtype, module_info, training): - module_cls = module_info.module_cls - module_inputs = module_info.module_inputs_func( - module_info, - device=device, - dtype=dtype, - requires_grad=True, - training=training, - ) - for module_input in module_inputs: - if module_input.forward_input is None: - continue - args, kwargs = ( - module_input.constructor_input.args, - module_input.constructor_input.kwargs, - ) - m = module_cls(*args, **kwargs) - m.to(device).to(dtype) - - args, kwargs = ( - module_input.forward_input.args, - module_input.forward_input.kwargs, - ) - with ( - self.DecompCrossRefMode( - self, self.precision, self.rel_tol, dtype, run_all=True - ), - enable_python_dispatcher(), - ): - decomp_out = m(*args, **kwargs) - - non_decomp_out = m(*args, **kwargs) - # without this check, incorrect decomps at the python dispatcher level can still pass because - # they're checking aten decomps at the torch_dispatch level - self.assertEqual(decomp_out, non_decomp_out) - - def test_batch_norm_unflatten_weight_bias(self, device): - # https://github.com/pytorch/pytorch/issues/100970 - shape = (1, 3, 2, 2) - input = torch.randn(shape, device=device) - weight = torch.randn((3, 1, 1, 1), device=device) - bias = torch.randn(3, device=device) - mean = torch.randn(3, device=device) - var = torch.randn(3, device=device) - res = torch._decomp.decompositions.native_batch_norm( - input, weight, bias, mean, var, False, 1, 1e-05 - ) - self.assertEqual(shape, res[0].shape) - - def test_arange_graph(self, device): - from torch.fx.experimental.proxy_tensor import make_fx - - def func(x, start): - le = x.shape[-1] - if start is None: - a = torch.arange(le, dtype=torch.float32, device=x.device) - else: - a = torch.arange(start, le, dtype=torch.float32, device=x.device) - return a - - pattern = r", device = device\(.+\), requires_grad = False" - - cfunc = make_fx(func, decomposition_table=decomposition_table) - fx_g = cfunc(torch.rand(10, device=device), None) - fx_g_code = fx_g.code.strip() - # Remove device and requires_grad - fx_g_code = re.sub(pattern, "", fx_g_code) - self.assertExpectedInline( - fx_g_code, - """\ -def forward(self, x_1, start_1): - iota = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64) - mul = torch.ops.prims.mul.default(iota, 1); iota = None - add = torch.ops.prims.add.default(mul, 0); mul = None - convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None - return convert_element_type""", - ) - - fx_g = cfunc(torch.rand(10, device=device), 1) - fx_g_code = fx_g.code.strip() - # Remove device and requires_grad - fx_g_code = re.sub(pattern, "", fx_g_code) - self.assertExpectedInline( - fx_g_code, - """\ -def forward(self, x_1, start_1): - iota = torch.ops.prims.iota.default(9, start = 0, step = 1, dtype = torch.int64) - mul = torch.ops.prims.mul.default(iota, 1); iota = None - add = torch.ops.prims.add.default(mul, 1); mul = None - convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None - return convert_element_type""", - ) - - def test_masked_fill(self, device): - from torch.fx.experimental.proxy_tensor 
import make_fx - - if torch.device(device).type not in [ - "xpu", - "cuda", - torch._C._get_privateuse1_backend_name(), - ]: - self.skipTest("only runs on XPU, CUDA and PrivateUse1.") - - def func(scores, mask, value): - return scores.masked_fill(mask, value) - - scores_t = torch.tensor([1, 2, 3, 4], device=device) - mask_t = torch.tensor([True, True, True, True], device=device) - value_t = torch.tensor(0, dtype=scores_t.dtype) - cfunc = make_fx(func, decomposition_table=decomposition_table) - fx_g = cfunc(scores_t, mask_t, value_t) - self.assertExpectedInline( - fx_g.code.strip(), - """\ -def forward(self, scores_1, mask_1, value_1): - where = torch.ops.prims.where.default(mask_1, value_1, scores_1); mask_1 = value_1 = scores_1 = None - return where""", - ) - - class DecompCrossRefMode(TorchDispatchMode): - def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all): - self.test_case = test_case - self.saved_precision = saved_precision - self.saved_rel_tol = saved_rel_tol - self.test_dtype = dtype - self.run_all = run_all - - # We check the correctness of each decomposition right after running it. - # So, when we encounter a decomposition, we run the function normally, and - # then run the decomposition, and ensure they're identical. - self.called = set() - self.decomposed = set() - - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - self.test_case.precision = self.saved_precision - self.test_case.rel_tol = self.saved_rel_tol - - self.called.add(func) - all_called[func] += 1 - - # Stuff we shouldn't bother testing - # (TODO: remove detach from the decomp table?) - # N.b. Testing in-place ops would need dedicated logic - in_place = func.name()[-1] == "_" - ignored_ops = [ - torch.ops.aten.detach.default, - # non-deterministic ops - torch.ops.aten.empty.memory_format, - torch.ops.aten.empty_like.default, - torch.ops.aten.new_empty.default, - torch.ops.aten.empty_strided.default, - torch.ops.aten.new_empty_strided.default, - torch.ops.aten.randn.default, - torch.ops.aten.native_dropout.default, - ] - if ( - func not in decomposition_table - or func in ignored_ops - or torch.Tag.nondeterministic_seeded in func.tags - or any_unsupported(args, kwargs) - or in_place - ): - return func(*args, **kwargs) - - self.decomposed.add(func) - all_decomposed.add(func) - - # We take 2 main strategies for verifying correctness/numerical stability of decompositions - # The first one is simply tolerance checking between decomp_out and pytorch_out - # However, for fp16/bf16 and reductions, this becomes very - # finicky, as there are not many guarantees we can make. - # So, for fp16/bf16, we instead compare the difference of - # {decomp_out, pytorch_out_64} and {pytorch_out, - # pytorch_out_64}. In other words, we compare how far the - # decomposition and pytorch are from the "ground truth" (i.e. - # fp64). 
If the decomposition results in more error, we error - - # We also decompose the decomposition recursively for - # further coverage, as some paths not be exercised directly by - # OpInfos (sadly) but just by other ops - - decomposition = decomposition_table[func] - - do_relative_check = self.test_dtype in [torch.float16, torch.bfloat16] - if self.run_all: - # Execute recursively via DFS, to find the root of a possible error first - with self: - decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs)) - else: - decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs)) - - # At this stage we should not be decomposing an in-place op - # We'd like to have decompositions that decompose out-of-place ops into out-of-place ops - # because decompositions are run after functionalisation and we would not like them to - # de-functionalise the graph, as that would break AoTAutograd - # We run the real function *after* the decomposition to make sure that the - # decomposition does not modify any of the inputs in-place. If it does - # real_out should be different than decom_out so we should catch this - real_out_unflat = func(*args, **kwargs) - real_out = pytree.tree_leaves(real_out_unflat) - - assert len(real_out) == len(decomp_out) - - if do_relative_check: - device_arg = kwargs.get("device", None) - - def upcast(x): - if (isinstance(x, Tensor) and x.device.type == "mps") or ( - device_arg and torch.device(device_arg).type == "mps" - ): - return upcast_tensor(x, dtype=torch.float32) - else: - return upcast_tensor(x, dtype=torch.float64) - - real_out_double, _ = tree_flatten( - func(*tree_map(upcast, args), **tree_map(upcast, kwargs)) - ) - for i, (orig, decomp, ref) in enumerate( - zip(real_out, decomp_out, real_out_double) - ): - if not isinstance(orig, torch.Tensor): - assert type(orig) == type(decomp) - assert orig == decomp - continue - op_assert_ref( - self.test_case, - func, - self.test_dtype, - i, - orig, - decomp, - ref, - args, - kwargs, - ) - else: - for orig, decomp in zip(real_out, decomp_out): - if not isinstance(orig, torch.Tensor): - assert type(orig) == type(decomp) - assert orig == decomp - continue - op_assert_equal( - self.test_case, - func, - self.test_dtype, - orig, - decomp, - args, - kwargs, - ) - - return real_out_unflat - - def check_decomposed(self, aten_name, mode): - self.assertTrue( - any(overload_to_aten_name(c) == aten_name for c in mode.decomposed), - msg=( - f"aten.{aten_name} was not decomposed, saw calls for: " - f"{', '.join(map(str, list(mode.called)))}. If your op is " - f"CompositeImplicitAutograd you should skip this test " - f"by updating CROSS_REF_EXCLUDE_SET." - ), - ) - - @skipIfTorchDynamo("Test does not work with TorchDynamo") - def do_cross_ref(self, device, dtype, op, *, run_all): - test_keys = [ - (torch.device(device).type, dtype, op.name), - (None, dtype, op.name), - (None, None, op.name), - ] - if any(key in CROSS_REF_EXCLUDE_SET for key in test_keys): - self.skipTest(f"{op.name} in {dtype} not supported") - - skip_decomp_vjp = any( - key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys - ) - - requires_grad = ( - op.supports_autograd - and dtype in op.supported_backward_dtypes(torch.device(device).type) - # TODO: OpInfo really ought to error out for this case, but it's - # not exercised in test_ops_gradients atm. 
The problem is not - # complex32 per-se (which is supported by data movement only ops) - # but that when we do backwards we expect other ops like add to work - and not dtype == torch.complex32 - ) - samples = op.sample_inputs(device, dtype, requires_grad=requires_grad) - - aten_name = op.decomp_aten_name or op.aten_name - - func = op.get_op() - - def run_without_python_dispatcher(mode): - return any( - isinstance(op, torch._ops.OpOverload) - and op.has_kernel_for_dispatch_key( - DispatchKey.CompositeImplicitAutograd - ) - for op in mode.decomposed.union([func]) - ) - - for sample_input in samples: - if requires_grad: - fn, primals = normalize_op_input_output(func, sample_input) - primals = tree_map( - lambda x: x if isinstance(x, torch.Tensor) else x, primals - ) - - # Once https://github.com/pytorch/pytorch/pull/75965/ I can - # store the called list on the mode object instance and no - # explicit clearing is necessary as I will create a fresh mode - # for each region - with ( - self.DecompCrossRefMode( - self, self.precision, self.rel_tol, dtype, run_all - ) as mode, - enable_python_dispatcher(), - ): - decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals) - if run_without_python_dispatcher(mode): - # without this check, incorrect decomps at the python dispatcher level can still pass because - # they're checking aten decomps at the torch_dispatch level. - with self.DecompCrossRefMode( - self, self.precision, self.rel_tol, dtype, run_all - ) as mode: - decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals) - if aten_name in decomposition_names: - self.check_decomposed(aten_name, mode) - - if not skip_decomp_vjp and ( - op.aten_backward_name in decomposition_names or run_all - ): - cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out) - - with ( - self.DecompCrossRefMode( - self, self.precision, self.rel_tol, dtype, run_all - ) as mode, - enable_python_dispatcher(), - ): - decomp_vjp_fn(cotangents) - if run_without_python_dispatcher(mode): - # without this check, incorrect decomps at the python dispatcher level can still pass because - # they're checking aten decomps at the torch_dispatch level. - with self.DecompCrossRefMode( - self, self.precision, self.rel_tol, dtype, run_all - ) as mode: - decomp_vjp_fn(cotangents) - if not run_all: - self.check_decomposed(op.aten_backward_name, mode) - - elif aten_name in decomposition_names or run_all: - args = [sample_input.input] + list(sample_input.args) - kwargs = sample_input.kwargs - # A failure here might be because the decomposition for the op is wrong or because a - # decomposition used by the particular op is wrong. - with ( - self.DecompCrossRefMode( - self, self.precision, self.rel_tol, dtype, run_all - ) as mode, - enable_python_dispatcher(), - ): - func(*args, **kwargs) - - if run_without_python_dispatcher(mode): - # without this check, incorrect decomps at the python dispatcher level can still pass because - # they're checking aten decomps at the torch_dispatch level. 
- with self.DecompCrossRefMode( - self, self.precision, self.rel_tol, dtype, run_all - ) as mode: - func(*args, **kwargs) - - if not run_all: - self.check_decomposed(aten_name, mode) - else: - assert op.supports_autograd - self.skipTest( - "only backwards is decomposed, but dtype doesn't support AD" - ) - - -instantiate_device_type_tests(TestDecomp, globals(), only_for="xpu", allow_xpu=True) - - -class DecompOneOffTests(TestCase): - @onlyNativeDeviceTypes - @skipIfCrossRef - def test_contiguous_softmax(self, device): - size = (2, 4, 3, 3) - stride = (9, 18, 3, 1) - dtype = torch.float32 - - x = torch.randn(size, dtype=dtype, device=device) - x = torch.as_strided(x, size, stride) - - ref = torch.ops.aten._softmax(x, -1, False) - res = torch._decomp.decompositions._softmax(x, -1, False) - self.assertEqual(ref.stride(), res.stride()) - - @onlyNativeDeviceTypes - @skipIfCrossRef - def test_contiguous_log_softmax(self, device): - size = (2, 4, 3, 3) - stride = (9, 18, 3, 1) - - dtype = torch.float32 - x = torch.randn(size, dtype=dtype, device=device) - x = torch.as_strided(x, size, stride) - - ref = torch.ops.aten._log_softmax(x, -1, False) - res = torch._decomp.decompositions._log_softmax(x, -1, False) - self.assertEqual(ref.stride(), res.stride()) - - def test_exponential_non_inf(self, device): - inp = torch.empty((4, 400, 256), device=device) - - with torch._dynamo.utils.preserve_rng_state(): - exp_ref = inp.exponential_() - exp = torch._refs.exponential(inp) - - self.assertEqual(exp, exp_ref) - self.assertFalse(exp.isinf().any()) - - @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") - @skipIfCrossRef - def test_amp_batch_norm_backward(self): - device = device_type - grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device) - x = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device) - weight = torch.randn((2,), dtype=torch.float32, device=device) - rmean = torch.randn((2,), dtype=torch.float32, device=device) - rvar = torch.randn((2,), dtype=torch.float32, device=device) - mean = torch.randn((0,), dtype=torch.float32, device=device) - - ref = torch.ops.aten.native_batch_norm_backward( - grad_out, - x, - weight, - rmean, - rvar, - mean, - mean, - False, - 1e-05, - [True, True, True], - ) - res = torch._decomp.decompositions.native_batch_norm_backward( - grad_out, - x, - weight, - rmean, - rvar, - mean, - mean, - False, - 1e-05, - [True, True, True], - ) - for a, b in zip(ref, res): - self.assertEqual(a.stride(), b.stride()) - self.assertEqual(a.dtype, b.dtype) - - @onlyNativeDeviceTypes - @skipIfCrossRef - def test_elu_backward(self, device): - size = (2, 4, 3, 3) - dtype = torch.float32 - grad_out = torch.randn(size, dtype=dtype, device=device) - out = torch.randn(size, dtype=dtype, device=device) - - ref = torch.ops.aten.elu_backward(grad_out, 1.0, 1, 1, True, out) - res = torch._decomp.decompositions.elu_backward(grad_out, 1.0, 1, 1, True, out) - self.assertEqual(ref, res) - - @onlyNativeDeviceTypes - @skipIfCrossRef - def test_threshold_backward_dtype(self, device): - grad = torch.randint(10, (4,), device=device) - input_tensor = torch.randint(10, (4,), device=device) - - ref = torch.ops.aten.threshold_backward(grad, input_tensor, 1) - res = torch._decomp.decompositions.threshold_backward(grad, input_tensor, 1) - self.assertEqual(ref.dtype, res.dtype) - - @onlyNativeDeviceTypes - @skipIfCrossRef - def test_weight_norm_interface(self, device): - g = torch.randn((3, 10, 10), device=device) - v = torch.randn((1, 1, 10), device=device) - - ref = 
torch.ops.aten._weight_norm_interface(g, v, 2) - res = torch._decomp.decompositions._weight_norm_interface(g, v, 2) - self.assertTrue(torch.allclose(ref[0], res[0])) - self.assertTrue(torch.allclose(ref[1], res[1])) - - inp = torch.rand([30, 10], device=device) - inp2 = torch.rand([30, 1], device=device) - - self.assertEqual( - torch.ops.aten._weight_norm_interface(inp, inp2), - torch._decomp.decompositions._weight_norm_interface(inp, inp2), - ) - - @onlyCPU - @skipIfCrossRef - @skipOps( - "DecompOneOffTests", - "test_sdpa", - [ - xfail( - "nn.functional.scaled_dot_product_attention", - dtypes=[torch.half], - ), - ], - ) - @ops(_sdpa_op_info) - def test_sdpa(self, device, dtype, op): - # SDPA doesn't support float16, this is aligned with aten/src/ATen/native/transformers/attention.cpp. If we - # add support for float16 over there we should update this test as well. - query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) - key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) - value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) - masks = [None, torch.ones((1, 1, 100, 100), device=device, dtype=torch.bool)] - - atol, rtol = dtype_precisions[dtype] - - for mask in masks: - is_causal = mask is None - decomposed_res = ( - torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu( - query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask - ) - ) - actual_res = decomposed_res[0] - # Output has form (N, H, L, E), but should be continuous on (L, N, H, E) - # in order for subsequent view(L * N, H * E) to be valid. - # So permute(2, 0, 1, 3) before checking that tensor is contiguous - self.assertTrue(actual_res.permute(2, 0, 1, 3).is_contiguous()) - - eager_res = op( - query_layer, - key_layer, - value_layer, - attn_mask=mask, - dropout_p=0.0, - is_causal=is_causal, - ) - - self.assertTrue(torch.allclose(actual_res, eager_res, atol=atol, rtol=rtol)) - - @onlyCPU - def test_native_layer_norm_cpu_decomp(self, device): - def f(x, w, b): - return torch.ops.aten.native_layer_norm.default(x, [1, 2, 3], w, b, eps=0.5) - - x = torch.randn(1, 2, 3, dtype=torch.bfloat16, device="cpu") - w = torch.randn(1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu") - b = torch.randn(1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu") - out_ref = f(x, w, b) - - from torch._subclasses.fake_tensor import FakeTensorMode - - with enable_python_dispatcher(), FakeTensorMode(): - x = torch.randn(1, 2, 3, dtype=torch.bfloat16, device="cpu") - w = torch.randn( - 1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu" - ) - b = torch.randn( - 1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu" - ) - out = f(x, w, b) - - for o_ref, o in zip(out_ref, out): - self.assertEqual(o_ref.dtype, o.dtype) - - @unittest.skipIf(not SM70OrLater, "triton") - def test_rms_norm_decomp_accelerator(self, device): - @torch.compile - def rms_norm_sinh(a, b, c): - output = torch.nn.functional.rms_norm(a, b, c) - return torch.sinh(output) - - normalized_shape_arg = (3, 3, 3) - input_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) - weight_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) - - def forward_pass_fn(): - return rms_norm_sinh(input_tensor, normalized_shape_arg, weight_tensor) - - model_output, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code( - forward_pass_fn - ) - - # check RMSNorm was fused with sinh - self.assertTrue( - "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in 
generated_codes[0] - ) - self.assertTrue( - "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1] - ) - - -instantiate_device_type_tests( - DecompOneOffTests, globals(), only_for="xpu", allow_xpu=True -) - - -class HasDecompTest(TestCase): - def setUp(self): - super().setUp() - self.maxDiff = None - - @staticmethod - def _can_appear_in_trace(op: torch._ops.OpOverload) -> bool: - has_tensor_arg = any( - "Tensor" in str(a.type) - for a in itertools.chain(op._schema.arguments, op._schema.returns) - ) - if not has_tensor_arg: - return False - - try: - # CompositeImplicitAutograd ops are transparent to the tracer, so don't need decompositions - return not _is_cia_op(op) - except RuntimeError as e: - # has_key fails for some jit-registered ops, which shouldn't be - # relevant here anyway - if "does not exist" in str(e): - return False - raise - - def test_has_decomposition(self): - def all_aten_overloads(): - for name in torch._C._dispatch_get_all_op_names(): - if not name.startswith("aten::"): - continue - - name = name[6:] - if "." in name: - packet_name, overload_name = name.split(".") - else: - packet_name, overload_name = name, "default" - - packet = getattr(aten, packet_name) - assert isinstance(packet, torch._ops.OpOverloadPacket) - op = getattr(packet, overload_name) - yield op - - # This is for operators that are only registered in some CI - # configurations, so would cause the test to fail - allow_list = {aten.get_gradients.default} - - overloads_wanting_decomp = { - op for op in all_aten_overloads() if self._can_appear_in_trace(op) - } - ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys() - ops_missing_decomp -= allow_list - self.assertExpected( - "".join(sorted(op.name() + "\n" for op in ops_missing_decomp)) - ) - - def test_aten_core_operators(self): - # If a decomposition isn't included in the core decompositions, - # then it must decompose a core ATen operator. - # - # See NOTE [Core ATen Ops] - # - # If this test fails then either: - # - Add the decomposition to torch._decomp.core_aten_decompositions, - # if decomposition should be used by inductor (not a core operator). - # - Run this test again with EXPECTTEST_ACCEPT=1 to update the list of - # core ATen operators (and inductor will not use the decomposition). - - # Some decompositions are registered for CompositeImplicitAutograd - # operators, which never appear in AOTAutograd's graph so are never used. 
-        useful_decomps = {
-            op
-            for op in decomposition_table.keys()
-            if isinstance(op, torch._ops.OpOverload) and self._can_appear_in_trace(op)
-        }
-        core_decomps = torch._decomp.core_aten_decompositions().keys()
-        core_aten_ops = useful_decomps - core_decomps
-        self.assertExpected("".join(sorted(op.name() + "\n" for op in core_aten_ops)))
-
-
-if __name__ == "__main__":
-    run_tests()
diff --git a/test/xpu/test_decomp_xpu.py b/test/xpu/test_decomp_xpu.py
index 5c091f6f34..dd60368eef 100644
--- a/test/xpu/test_decomp_xpu.py
+++ b/test/xpu/test_decomp_xpu.py
@@ -6,23 +6,193 @@
 #
 # http://www.apache.org/licenses/LICENSE-2.0
-# Owner(s): ["module: intel"]
+# Owner(s): ["module: decompositions"]
-import torch
-from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_utils import run_tests
-try:
-    from xpu_test_utils import XPUPatchForImport
-except Exception as e:
-    from .xpu_test_utils import XPUPatchForImport
+
+import functools
+import itertools
+import re
+import unittest
+from collections import defaultdict
+from functools import partial
+
+import torch._inductor.decomposition
+import torch.autograd
+from torch import Tensor
+from torch._decomp import core_aten_decompositions, decomposition_table
+from torch._dispatch.python import enable_python_dispatcher
+from torch._export.utils import _is_cia_op
+from torch._ops import DispatchKey
+from torch.testing import make_tensor
+from torch.testing._internal.common_cuda import SM70OrLater, tf32_off
+from torch.testing._internal.common_device_type import (
+    instantiate_device_type_tests,
+    onlyCPU,
+    onlyNativeDeviceTypes,
+    ops,
+)
+from torch.testing._internal.common_methods_invocations import (
+    op_db,
+    skip,
+    skipOps,
+    xfail,
+)
+from torch.testing._internal.common_modules import module_db, modules
+from torch.testing._internal.common_utils import (
+    is_iterable_of_tensors,
+    run_tests,
+    skipIfCrossRef,
+    skipIfTorchDynamo,
+    suppress_warnings,
+    TEST_WITH_ASAN,
+    TEST_WITH_SLOW,
+    TestCase,
+    unMarkDynamoStrictTest,
+)
+from torch.utils import _pytree as pytree
+from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
+
+device_type = (
+    acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
+)
+
+aten = torch.ops.aten
+
+
+# TODO: this isn't going to work with non-aten namespaces
+def overload_to_aten_name(op):
+    return op._schema.name.split("::")[1]
+
+
+# All operators that can have decomp tests
+decomposition_names = {
+    overload_to_aten_name(k)
+    for k in decomposition_table
+    if isinstance(k, torch._ops.OpOverload)
+}
+core_decomposition_names = {
+    overload_to_aten_name(k)
+    for k in core_aten_decompositions()
+    if isinstance(k, torch._ops.OpOverload) and not _is_cia_op(k)
+}
+_decomp_test_ops = [
+    op
+    for op in op_db
+    if op.aten_name in decomposition_names
+    or op.aten_backward_name in decomposition_names
+]
+_decomp_test_ops_core_autograd = [
+    op
+    for op in op_db
+    if op.aten_name in core_decomposition_names and op.supports_autograd
+]
+_sdpa_op_info = [op for op in op_db if "scaled_dot_product_attention" in op.aten_name]
+
+
+def diff_arg(arg, requires_grad=True):
+    def is_differentiable_arg(arg):
+        if requires_grad:
+            return arg.requires_grad
+        else:
+            return arg.is_floating_point() or arg.is_complex()
+
+    if is_iterable_of_tensors(arg):
+        if all(is_differentiable_arg(a) for a in arg):
+            return True
+        if all(not is_differentiable_arg(a) for a in arg):
+            return False
+        raise RuntimeError("NYI: The test runner can't handle this")
+    return isinstance(arg, Tensor) and is_differentiable_arg(arg)
+
+
+# Version of autograd.grad with some differences:
+#   - pytree inputs is allowed (but leaves of the pytree have to all
+#     be tensors)
+#   - if an input is not used as part of derivatives, we will return a
+#     zero-filled tensor for the result
+def _autograd_grad(
+    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
+):
+    inputs, inputs_spec = tree_flatten(inputs)
+    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
+    if grad_outputs is None:
+        diff_outputs = tuple(out for out in outputs if out.requires_grad)
+    else:
+        diff_grad_outputs = [
+            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
+        ]
+        if len(diff_grad_outputs) == 0:
+            diff_outputs, grad_outputs = (), ()
+        else:
+            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
+    grad_inputs = torch.autograd.grad(
+        diff_outputs,
+        diff_inputs,
+        grad_outputs,
+        retain_graph=retain_graph,
+        create_graph=create_graph,
+        allow_unused=True,
+    )
+    result = []
+    grad_inputs_iter = iter(grad_inputs)
+    for inp in inputs:
+        if inp.requires_grad:
+            grad_input = next(grad_inputs_iter)
+            if grad_input is None:
+                result.append(torch.zeros_like(inp))
+            else:
+                result.append(grad_input)
+        else:
+            result.append(torch.zeros_like(inp))
+    return tree_unflatten(result, inputs_spec)
+
+
+def _as_tuple(val):
+    if isinstance(val, tuple):
+        return val
+    return (val,)
+
+
+def ref_vjp_no_create(f, *primals):
+    result = f(*primals)
+
+    def wrapped(cotangents):
+        return _autograd_grad(
+            _as_tuple(result),
+            primals,
+            _as_tuple(cotangents),
+            create_graph=False,
+            retain_graph=True,
+        )
+
+    return result, wrapped
-with XPUPatchForImport(False):
-    import test_decomp
-    from test_decomp import DecompOneOffTests, TestDecomp
+dtype_precisions = {
+    torch.float16: (0.001, 1e-5),
+    torch.bfloat16: (0.016, 1e-4),
+    torch.float32: (1.3e-6, 1e-5),
+    torch.float64: (1e-7, 1e-7),
+    torch.complex32: (0.001, 1e-5),
+    torch.complex64: (1.3e-6, 1e-5),
+    torch.complex128: (1e-7, 1e-7),
+}
+# Returns the "default" rtol and atol for comparing scalars or
+# tensors of the given dtypes.
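# [Editor's illustrative aside, not part of this patch] A minimal sketch of how
# the dtype_precisions table above is typically consumed: a mixed-dtype
# comparison uses the element-wise max of both dtypes' (rtol, atol) rows, which
# is what the _getDefaultRtolAndAtol helper added below computes. The tensors
# `a` (float16) and `b` (float32) in the final line are hypothetical.
example_rtol = max(dtype_precisions[torch.float16][0], dtype_precisions[torch.float32][0])  # 0.001
example_atol = max(dtype_precisions[torch.float16][1], dtype_precisions[torch.float32][1])  # 1e-5
# e.g. torch.testing.assert_close(a.float(), b, rtol=example_rtol, atol=example_atol)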
-def _op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs): + +def _getDefaultRtolAndAtol(dtype0, dtype1): + rtol = max( + dtype_precisions.get(dtype0, (0, 0))[0], dtype_precisions.get(dtype1, (0, 0))[0] + ) + atol = max( + dtype_precisions.get(dtype0, (0, 0))[1], dtype_precisions.get(dtype1, (0, 0))[1] + ) + return rtol, atol + + +def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs): assert orig.dtype == decomp.dtype, f"{i} Operation: {op}" if orig.numel() == 0 or decomp.numel() == 0: assert orig.numel() == decomp.numel() @@ -66,14 +236,16 @@ def _op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs (torch.float16, torch.ops.aten.mv.default): 1e-5, (torch.bfloat16, torch.ops.aten.mv.default): 1e-5, (torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5, + (torch.float16, torch.ops.aten._softmax_backward_data.default): 3e-7, + # XPU specific ( torch.float16, torch.ops.aten._batch_norm_with_update.default, - ): 2e-7, # adjust tolerance for xpu, so hook this func + ): 2e-7, # adjust tolerance for xpu ( torch.bfloat16, torch.ops.aten._batch_norm_with_update.default, - ): 2e-7, # adjust tolerance for xpu, so hook this func + ): 2e-7, # adjust tolerance for xpu } if ref.is_floating_point(): orig_diff = (orig - ref).abs().max() @@ -93,13 +265,1108 @@ def _op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs ) -test_decomp.op_assert_ref = _op_assert_ref +def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs): + test_case.assertEqual( + orig.dtype, + decomp.dtype, + f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}", + ) + # Before adding an entry to this table, make sure your decomposition is right :) + tol_table = { + # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161 + (torch.float32, torch.ops.aten.native_layer_norm.default): (1e-3, 1e-3), + (torch.float32, torch.ops.aten.native_layer_norm_backward.default): ( + 1e-3, + 1e-3, + ), + (torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6), + # This exceeds default tolerances only on CPU, on CUDA it's fine + (torch.float32, torch.ops.aten.grid_sampler_2d.default): (7e-6, 3e-5), + # Exceeds tolerances on CUDA, likely due to fma + (torch.float32, torch.ops.aten.mv.default): (1e-5, 3e-5), + (torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5), + (torch.float64, torch.ops.aten.upsample_bicubic2d.vec): (1e-5, 5e-4), + (torch.float64, torch.ops.aten.upsample_bicubic2d.default): (1e-5, 5e-4), + # The decomposition is TOO correct. It computes everything in int64, so sometimes + # there's an off-by-one error. 
See + # https://github.com/pytorch/pytorch/issues/81996 + # https://github.com/pytorch/pytorch/issues/82230 + (torch.int8, torch.ops.aten.linspace.default): (0, 1), + (torch.uint8, torch.ops.aten.linspace.default): (0, 1), + (torch.int16, torch.ops.aten.linspace.default): (0, 1), + (torch.int32, torch.ops.aten.linspace.default): (0, 1), + (torch.int64, torch.ops.aten.linspace.default): (0, 1), + (torch.int8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), + (torch.uint8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), + (torch.int16, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), + (torch.int32, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), + (torch.int64, torch.ops.aten.linspace.Tensor_Tensor): (0, 1), + (torch.int8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), + (torch.uint8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), + (torch.int16, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), + (torch.int32, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), + (torch.int64, torch.ops.aten.linspace.Tensor_Scalar): (0, 1), + (torch.int8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), + (torch.uint8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), + (torch.int16, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), + (torch.int32, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), + (torch.int64, torch.ops.aten.linspace.Scalar_Tensor): (0, 1), + } + if (decomp.dtype, op) in tol_table: + rtol, atol = tol_table[(decomp.dtype, op)] + else: + rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype) + test_case.assertEqual( + orig, + decomp, + rtol=rtol, + atol=atol, + msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}", + ) + + +# Given f, returns an f' such that: +# - f' takes only positional arguments +# - All arguments to f' are floating-point Tensors +# - All outputs of f' are floating-point Tensors +def normalize_op_input_output2( + f, args, kwargs, output_process_fn_grad=None, requires_grad=True +): + flat_args, args_spec = tree_flatten(args) + diff_argnums = tuple( + i + for i, arg in enumerate(flat_args) + if diff_arg(arg, requires_grad=requires_grad) + ) + assert len(diff_argnums) > 0 + primals = tuple(flat_args[i] for i in diff_argnums) + + @functools.wraps(f) + def wrapped(*primals): + _args = list(flat_args) + for num, arg in zip(diff_argnums, primals): + _args[num] = arg + _args = tree_unflatten(_args, args_spec) + result = f(*_args, **kwargs) + if output_process_fn_grad is not None: + result = output_process_fn_grad(result) + if isinstance(result, tuple): + # TODO We should check that the integer outputs also agree + result = tuple( + r + for r in result + if isinstance(r, Tensor) and (r.is_floating_point() or r.is_complex()) + ) + assert len(result) > 0 + return result + + return wrapped, primals + + +# NB: This also upcasts dtype arguments +# TODO: handle complex correctly +def upcast_tensor(x, dtype=torch.float32): + if isinstance(x, Tensor) and x.dtype.is_floating_point: + return x.to(dtype=dtype) + elif isinstance(x, torch.dtype) and x in [ + torch.float16, + torch.bfloat16, + torch.float, + ]: + return dtype + else: + return x + + +def normalize_op_input_output(f, sample, requires_grad=True): + args = tuple([sample.input] + list(sample.args)) + return normalize_op_input_output2( + f, + args, + sample.kwargs, + sample.output_process_fn_grad, + requires_grad=requires_grad, + ) + + +CROSS_REF_EXCLUDE_SET = { + # CUBLAS_STATUS_NOT_SUPPORTED when calling + # `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k, + # (void*)&falpha, a, CUDA_R_16BF, (int)lda, 
stridea, b, CUDA_R_16BF, + # (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec, + # (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)` + ("cuda", torch.bfloat16, "nn.functional.bilinear"), + # randomness + (None, None, "special.ndtr"), # aten.special_ndtr was not decomposed + (None, None, "new_empty"), + (None, None, "empty_like"), + (None, None, "empty"), + # AssertionError: False is not true : aten.item was not decomposed, saw calls for: aten._local_scalar_dense.default. + (None, None, "item"), + # It's the only in-place op without an out-of-place equivalent in the Python API + # Its OpInfo wrongly registers it as `torch.zero_(x.clone())`. + (None, None, "zero_"), + # No idea what's going on here + # In the recursive test logsumexp.default fails with args = (torch.tensor(-math.inf), []) + # in the test, but it seems to pass when tested locally and in the logsumexp test + (None, torch.float32, "masked.logsumexp"), + (None, torch.float64, "masked.logsumexp"), + # exp_vml_cpu not implemented for Half + (torch.cpu, torch.float16, "signal.windows.exponential"), + (torch.cpu, torch.float16, "signal.windows.gaussian"), + # sin_vml_cpu not implemented for Half + (torch.cpu, torch.float16, "signal.windows.cosine"), + # CompositeAutogradImplicit + # See https://github.com/pytorch/pytorch/issues/81669 + (None, None, "nn.functional.relu6"), + # This decomp runs before autograd. + (None, None, "nn.functional.rrelu"), + (None, None, "meshgrid"), + # Decomposition registered as Autograd + (None, None, "nn.functional.hardshrink"), + (None, None, "nn.functional.softshrink"), + # diag was not decomposed (it just registers a decomp for diag_out, torch.diag is CompImplicit) + (None, None, "diag"), + # _softmax_backward_data's CPU kernel for bfloat16 always return the grad_input as float32 + ("cpu", torch.bfloat16, "_softmax_backward_data"), + (None, None, "norm"), + # native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise) + (None, None, "native_batch_norm"), + (None, None, "_upsample_bilinear2d_aa"), + (None, None, "empty_strided"), # aten.empty_strided was not decomposed + ( + None, + None, + "bernoulli", + ), # bernoulli is a function of randomness, so couldn't do cross-reference. + # XPU specific exclude cases + # ("xpu", None, "some_xpu_specific_op"), +} + +CROSS_REF_BACKWARD_EXCLUDE_SET = { + # Decomposed backward formula is not as precise + ("cpu", torch.bfloat16, "nn.functional.hardswish"), + ("cuda", torch.float16, "nn.functional.cross_entropy"), + ( + None, + None, + "bernoulli", + ), # bernoulli is a function of randomness, so couldn't do cross-reference. 
+ # XPU specific backward exclude cases + # ("xpu", torch.float16, "nn.functional.some_op"), +} + +all_decomposed = set() +all_called = defaultdict(int) + +# Helpful snippet for testing coverage +""" +import atexit +def check_coverage(): + print("missing coverage:") + print("\n".join(map(str, decomposition_table.keys() - all_decomposed))) +atexit.register(check_coverage) +""" + +# Helpful snippet for Horace to create his google sheet :) +""" +import atexit +def dump_ops(): + with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g: + for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__): + f.write(f'{op.__name__}\n') + g.write(f'{count}\n') + with open('run_decompositions.txt', 'w') as f: + for op in sorted([i.__name__ for i in all_decomposed]): + f.write(f'{op}\n') + +atexit.register(dump_ops) +""" + + +def any_unsupported(args, kwargs): + def test_unsupported(t): + if type(t) is torch.Tensor or type(t) is torch.nn.Parameter: + # These are all things that we haven't coded decompositions + # to handle correctly. Maybe they should. + return any( + [ + t.is_sparse_csr, + t.is_sparse, + t.is_mkldnn, + t.is_quantized, + t.is_nested, + torch._is_functional_tensor(t), + ] + ) + elif torch.overrides.is_tensor_like(t): + # Decompositions will generally change the behavior of Tensor-like + # subclasses, so bypass tests in this case too + return True + else: + return False + + flat_args = pytree.arg_tree_leaves(*args, **kwargs) + return any(test_unsupported(x) for x in flat_args) + + +core_backward_failures = { + skip("_softmax_backward_data"), # slow: fails with --timeout=360 secs + xfail("addcdiv"), + skip("addcmul"), # slow: fails with --timeout=360 secs + skip("deg2rad"), # slow: fails with --timeout=360 secs + skip("diag_embed"), # slow: fails with --timeout=360 secs + skip("frac"), # slow: fails with --timeout=360 secs + skip("grid_sampler_2d"), # slow: fails with --timeout=360 secs + xfail("lerp"), + skip("logaddexp"), # slow: fails with --timeout=360 secs + skip("native_dropout_backward"), # slow: fails with --timeout=360 secs + xfail("nn.functional.binary_cross_entropy_with_logits"), + skip("nn.functional.glu"), # slow: fails with --timeout=360 secs + xfail("nn.functional.hardshrink"), + xfail("nn.functional.softshrink"), + skip("nn.functional.unfold"), # slow: fails with --timeout=360 secs + xfail("norm"), + xfail("norm", "fro"), + xfail("norm", "inf"), + xfail("norm", "nuc"), + skip("rad2deg"), # slow: fails with --timeout=360 secs + skip("renorm"), # slow: fails with --timeout=360 secs + skip("rot90"), # slow: fails with --timeout=360 secs + skip("rsub"), # slow: fails with --timeout=360 secs + skip("sgn"), # slow: fails with --timeout=360 secs + skip("special.xlog1py"), # slow: fails with --timeout=360 secs + xfail("stack"), + skip("tril"), # slow: fails with --timeout=360 secs + skip("triu"), # slow: fails with --timeout=360 secs + skip("unfold_copy"), # slow: fails with --timeout=360 secs + skip("xlogy"), # slow: fails with --timeout=360 secs + xfail("zero_"), +} +if not TEST_WITH_SLOW: + core_backward_failures.update( + { + skip("addr"), # slow: takes 46 sec on A100 + skip("baddbmm"), # slow: takes 800+ sec on A100 + skip("clamp_min"), # slow: takes 800 sec on A100 + skip("clamp_max"), # slow: takes 800 sec on A100 + skip("logit"), # slow: takes 44 sec on A100 + skip("nn.functional.hardswish"), # slow: takes 60 sec on A100 + skip("std_mean"), # slow: takes 170 sec on A100 + skip("split", variant_name="list_args"), # slow: takes 118 sec on A100 + 
skip("transpose"), # slow: takes 50 sec on A100 + skip("unbind"), # slow: takes 70 sec on A100 + skip("unsafe_split"), # slow: takes 49 sec on A100 + } + ) + +comprehensive_failures = { + xfail( + "nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,) + ), # off by one error + xfail( + "nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,) + ), # off by one error + xfail( + "nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,) + ), # off by one error +} + + +@unMarkDynamoStrictTest +class TestDecomp(TestCase): + longMessage = True + + # NB: This actually overlaps with test_comprehensive, but it only + # runs on things that are definitely decomposed so it's a lot faster + # to run + @onlyNativeDeviceTypes + @skipIfCrossRef + @suppress_warnings + @ops(_decomp_test_ops) + def test_quick(self, device, dtype, op): + self.do_cross_ref(device, dtype, op, run_all=False) + + @skipOps("TestDecomp", "test_quick_core_backward", core_backward_failures) + @onlyNativeDeviceTypes + @skipIfCrossRef + @suppress_warnings + @ops(_decomp_test_ops_core_autograd, allowed_dtypes=(torch.float64,)) + def test_quick_core_backward(self, device, dtype, op): + test_keys = [ + (torch.device(device).type, dtype, op.name), + (None, dtype, op.name), + (None, None, op.name), + ] + if any(key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys): + self.skipTest(f"{op.name} in {dtype} not supported") + for sample_input in op.sample_inputs(device, dtype, requires_grad=True): + aten_name = op.decomp_aten_name or op.aten_name + args = [sample_input.input] + list(sample_input.args) + kwargs = sample_input.kwargs + func = partial(op.get_op(), **kwargs) + with ( + self.DecompCrossRefMode( + self, self.precision, self.rel_tol, dtype, run_all=False + ) as mode, + enable_python_dispatcher(), + ): + torch.autograd.gradcheck(func, args) + self.check_decomposed(aten_name, mode) + + @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") + @onlyNativeDeviceTypes + @skipIfCrossRef + @skipOps("TestDecomp", "test_comprehensive", comprehensive_failures) + @suppress_warnings + @ops(op_db) + def test_comprehensive(self, device, dtype, op): + self.do_cross_ref(device, dtype, op, run_all=True) + + def test_uniform(self, device): + size = (2, 3, 4, 5) + dtype = torch.float32 + x = make_tensor(size, dtype=dtype, device=device) + low = 0.3 + high = 0.9 + + torch.manual_seed(123) + ref = torch.ops.aten.uniform(x, low, high) + torch.manual_seed(123) + res = torch._decomp.decompositions.uniform(x, low=low, high=high) + self.assertEqual(ref, res) + + def test_bernoulli_default(self, device): + p = 0.3 + p_t = p * torch.ones(5, 5) + torch.manual_seed(123) + ref = torch.ops.aten.bernoulli.default(p_t) + torch.manual_seed(123) + res = torch._decomp.decompositions.bernoulli(p_t) + ref_p = ref.sum() / torch.prod(torch.tensor(ref.size())) + res_p = res.sum() / torch.prod(torch.tensor(res.size())) + self.assertEqual(ref_p, res_p, atol=0.06 * p, rtol=0.06) + + def test_broadcasting_index_copy(self, device): + x = torch.zeros([1, 10], device=device) + xs = torch.ones([2, 10], device=device) + + def index_copy(xs, x): + torch._decomp.decompositions.index_copy_( + xs, 0, torch.tensor(0).to(device), x + ) + + index_copy(xs, x) + + xs_two = torch.ones([2, 10], device=device) + xs_two[0] = x + + self.assertEqual(xs, xs_two) + + def test_cat_single_input(self, device): + decomp_table = torch._inductor.decomposition.select_decomp_table() + cat_inductor = decomp_table[torch.ops.aten.cat.default] + + inp = torch.rand([2048, 2048], device=device) + 
inps = [inp for _ in range(10)] + + for dim in (-1, 0, 1): + self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim)) + + @suppress_warnings + @tf32_off() + # only tests RNNs since we have py dispsatcher decomps for them + @modules( + filter( + lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), + module_db, + ) + ) + def test_rnn_decomp_module(self, device, dtype, module_info, training): + module_cls = module_info.module_cls + module_inputs = module_info.module_inputs_func( + module_info, + device=device, + dtype=dtype, + requires_grad=True, + training=training, + ) + for module_input in module_inputs: + if module_input.forward_input is None: + continue + args, kwargs = ( + module_input.constructor_input.args, + module_input.constructor_input.kwargs, + ) + m = module_cls(*args, **kwargs) + m.to(device).to(dtype) + + args, kwargs = ( + module_input.forward_input.args, + module_input.forward_input.kwargs, + ) + with ( + self.DecompCrossRefMode( + self, self.precision, self.rel_tol, dtype, run_all=True + ), + enable_python_dispatcher(), + ): + decomp_out = m(*args, **kwargs) + + non_decomp_out = m(*args, **kwargs) + # without this check, incorrect decomps at the python dispatcher level can still pass because + # they're checking aten decomps at the torch_dispatch level + self.assertEqual(decomp_out, non_decomp_out) + + def test_batch_norm_unflatten_weight_bias(self, device): + # https://github.com/pytorch/pytorch/issues/100970 + shape = (1, 3, 2, 2) + input = torch.randn(shape, device=device) + weight = torch.randn((3, 1, 1, 1), device=device) + bias = torch.randn(3, device=device) + mean = torch.randn(3, device=device) + var = torch.randn(3, device=device) + res = torch._decomp.decompositions.native_batch_norm( + input, weight, bias, mean, var, False, 1, 1e-05 + ) + self.assertEqual(shape, res[0].shape) + + def test_arange_graph(self, device): + from torch.fx.experimental.proxy_tensor import make_fx + + def func(x, start): + le = x.shape[-1] + if start is None: + a = torch.arange(le, dtype=torch.float32, device=x.device) + else: + a = torch.arange(start, le, dtype=torch.float32, device=x.device) + return a + + pattern = r", device = device\(.+\), requires_grad = False" + + cfunc = make_fx(func, decomposition_table=decomposition_table) + fx_g = cfunc(torch.rand(10, device=device), None) + fx_g_code = fx_g.code.strip() + # Remove device and requires_grad + fx_g_code = re.sub(pattern, "", fx_g_code) + self.assertExpectedInline( + fx_g_code, + """\ +def forward(self, x_1, start_1): + iota = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64) + mul = torch.ops.prims.mul.default(iota, 1); iota = None + add = torch.ops.prims.add.default(mul, 0); mul = None + convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None + return convert_element_type""", + ) + + fx_g = cfunc(torch.rand(10, device=device), 1) + fx_g_code = fx_g.code.strip() + # Remove device and requires_grad + fx_g_code = re.sub(pattern, "", fx_g_code) + self.assertExpectedInline( + fx_g_code, + """\ +def forward(self, x_1, start_1): + iota = torch.ops.prims.iota.default(9, start = 0, step = 1, dtype = torch.int64) + mul = torch.ops.prims.mul.default(iota, 1); iota = None + add = torch.ops.prims.add.default(mul, 1); mul = None + convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None + return convert_element_type""", + ) + + def test_masked_fill(self, device): + from torch.fx.experimental.proxy_tensor 
import make_fx + + if torch.device(device).type not in [ + "xpu", + "cuda", + torch._C._get_privateuse1_backend_name(), + ]: + self.skipTest("only runs on XPU, CUDA and PrivateUse1.") + + def func(scores, mask, value): + return scores.masked_fill(mask, value) + + scores_t = torch.tensor([1, 2, 3, 4], device=device) + mask_t = torch.tensor([True, True, True, True], device=device) + value_t = torch.tensor(0, dtype=scores_t.dtype) + cfunc = make_fx(func, decomposition_table=decomposition_table) + fx_g = cfunc(scores_t, mask_t, value_t) + self.assertExpectedInline( + fx_g.code.strip(), + """\ +def forward(self, scores_1, mask_1, value_1): + where = torch.ops.prims.where.default(mask_1, value_1, scores_1); mask_1 = value_1 = scores_1 = None + return where""", + ) + + class DecompCrossRefMode(TorchDispatchMode): + def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all): + self.test_case = test_case + self.saved_precision = saved_precision + self.saved_rel_tol = saved_rel_tol + self.test_dtype = dtype + self.run_all = run_all + + # We check the correctness of each decomposition right after running it. + # So, when we encounter a decomposition, we run the function normally, and + # then run the decomposition, and ensure they're identical. + self.called = set() + self.decomposed = set() + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + self.test_case.precision = self.saved_precision + self.test_case.rel_tol = self.saved_rel_tol + + self.called.add(func) + all_called[func] += 1 + + # Stuff we shouldn't bother testing + # (TODO: remove detach from the decomp table?) + # N.b. Testing in-place ops would need dedicated logic + in_place = func.name()[-1] == "_" + ignored_ops = [ + torch.ops.aten.detach.default, + # non-deterministic ops + torch.ops.aten.empty.memory_format, + torch.ops.aten.empty_like.default, + torch.ops.aten.new_empty.default, + torch.ops.aten.empty_strided.default, + torch.ops.aten.new_empty_strided.default, + torch.ops.aten.randn.default, + torch.ops.aten.native_dropout.default, + ] + if ( + func not in decomposition_table + or func in ignored_ops + or torch.Tag.nondeterministic_seeded in func.tags + or any_unsupported(args, kwargs) + or in_place + ): + return func(*args, **kwargs) + + self.decomposed.add(func) + all_decomposed.add(func) + + # We take 2 main strategies for verifying correctness/numerical stability of decompositions + # The first one is simply tolerance checking between decomp_out and pytorch_out + # However, for fp16/bf16 and reductions, this becomes very + # finicky, as there are not many guarantees we can make. + # So, for fp16/bf16, we instead compare the difference of + # {decomp_out, pytorch_out_64} and {pytorch_out, + # pytorch_out_64}. In other words, we compare how far the + # decomposition and pytorch are from the "ground truth" (i.e. + # fp64). 
If the decomposition results in more error, we error + + # We also decompose the decomposition recursively for + # further coverage, as some paths not be exercised directly by + # OpInfos (sadly) but just by other ops + + decomposition = decomposition_table[func] + + do_relative_check = self.test_dtype in [torch.float16, torch.bfloat16] + if self.run_all: + # Execute recursively via DFS, to find the root of a possible error first + with self: + decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs)) + else: + decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs)) + + # At this stage we should not be decomposing an in-place op + # We'd like to have decompositions that decompose out-of-place ops into out-of-place ops + # because decompositions are run after functionalisation and we would not like them to + # de-functionalise the graph, as that would break AoTAutograd + # We run the real function *after* the decomposition to make sure that the + # decomposition does not modify any of the inputs in-place. If it does + # real_out should be different than decom_out so we should catch this + real_out_unflat = func(*args, **kwargs) + real_out = pytree.tree_leaves(real_out_unflat) + + assert len(real_out) == len(decomp_out) + + if do_relative_check: + device_arg = kwargs.get("device", None) + + def upcast(x): + if (isinstance(x, Tensor) and x.device.type == "mps") or ( + device_arg and torch.device(device_arg).type == "mps" + ): + return upcast_tensor(x, dtype=torch.float32) + else: + return upcast_tensor(x, dtype=torch.float64) + + real_out_double, _ = tree_flatten( + func(*tree_map(upcast, args), **tree_map(upcast, kwargs)) + ) + for i, (orig, decomp, ref) in enumerate( + zip(real_out, decomp_out, real_out_double) + ): + if not isinstance(orig, torch.Tensor): + assert type(orig) == type(decomp) + assert orig == decomp + continue + op_assert_ref( + self.test_case, + func, + self.test_dtype, + i, + orig, + decomp, + ref, + args, + kwargs, + ) + else: + for orig, decomp in zip(real_out, decomp_out): + if not isinstance(orig, torch.Tensor): + assert type(orig) == type(decomp) + assert orig == decomp + continue + op_assert_equal( + self.test_case, + func, + self.test_dtype, + orig, + decomp, + args, + kwargs, + ) + + return real_out_unflat + + def check_decomposed(self, aten_name, mode): + self.assertTrue( + any(overload_to_aten_name(c) == aten_name for c in mode.decomposed), + msg=( + f"aten.{aten_name} was not decomposed, saw calls for: " + f"{', '.join(map(str, list(mode.called)))}. If your op is " + f"CompositeImplicitAutograd you should skip this test " + f"by updating CROSS_REF_EXCLUDE_SET." + ), + ) + + @skipIfTorchDynamo("Test does not work with TorchDynamo") + def do_cross_ref(self, device, dtype, op, *, run_all): + test_keys = [ + (torch.device(device).type, dtype, op.name), + (None, dtype, op.name), + (None, None, op.name), + ] + if any(key in CROSS_REF_EXCLUDE_SET for key in test_keys): + self.skipTest(f"{op.name} in {dtype} not supported") + + skip_decomp_vjp = any( + key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys + ) + + requires_grad = ( + op.supports_autograd + and dtype in op.supported_backward_dtypes(torch.device(device).type) + # TODO: OpInfo really ought to error out for this case, but it's + # not exercised in test_ops_gradients atm. 
The problem is not + # complex32 per-se (which is supported by data movement only ops) + # but that when we do backwards we expect other ops like add to work + and not dtype == torch.complex32 + ) + samples = op.sample_inputs(device, dtype, requires_grad=requires_grad) + + aten_name = op.decomp_aten_name or op.aten_name + + func = op.get_op() + + def run_without_python_dispatcher(mode): + return any( + isinstance(op, torch._ops.OpOverload) + and op.has_kernel_for_dispatch_key( + DispatchKey.CompositeImplicitAutograd + ) + for op in mode.decomposed.union([func]) + ) + + for sample_input in samples: + if requires_grad: + fn, primals = normalize_op_input_output(func, sample_input) + primals = tree_map( + lambda x: x if isinstance(x, torch.Tensor) else x, primals + ) + + # Once https://github.com/pytorch/pytorch/pull/75965/ I can + # store the called list on the mode object instance and no + # explicit clearing is necessary as I will create a fresh mode + # for each region + with ( + self.DecompCrossRefMode( + self, self.precision, self.rel_tol, dtype, run_all + ) as mode, + enable_python_dispatcher(), + ): + decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals) + if run_without_python_dispatcher(mode): + # without this check, incorrect decomps at the python dispatcher level can still pass because + # they're checking aten decomps at the torch_dispatch level. + with self.DecompCrossRefMode( + self, self.precision, self.rel_tol, dtype, run_all + ) as mode: + decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals) + if aten_name in decomposition_names: + self.check_decomposed(aten_name, mode) + + if not skip_decomp_vjp and ( + op.aten_backward_name in decomposition_names or run_all + ): + cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out) + + with ( + self.DecompCrossRefMode( + self, self.precision, self.rel_tol, dtype, run_all + ) as mode, + enable_python_dispatcher(), + ): + decomp_vjp_fn(cotangents) + if run_without_python_dispatcher(mode): + # without this check, incorrect decomps at the python dispatcher level can still pass because + # they're checking aten decomps at the torch_dispatch level. + with self.DecompCrossRefMode( + self, self.precision, self.rel_tol, dtype, run_all + ) as mode: + decomp_vjp_fn(cotangents) + if not run_all: + self.check_decomposed(op.aten_backward_name, mode) + + elif aten_name in decomposition_names or run_all: + args = [sample_input.input] + list(sample_input.args) + kwargs = sample_input.kwargs + # A failure here might be because the decomposition for the op is wrong or because a + # decomposition used by the particular op is wrong. + with ( + self.DecompCrossRefMode( + self, self.precision, self.rel_tol, dtype, run_all + ) as mode, + enable_python_dispatcher(), + ): + func(*args, **kwargs) + + if run_without_python_dispatcher(mode): + # without this check, incorrect decomps at the python dispatcher level can still pass because + # they're checking aten decomps at the torch_dispatch level. 
+ with self.DecompCrossRefMode( + self, self.precision, self.rel_tol, dtype, run_all + ) as mode: + func(*args, **kwargs) + + if not run_all: + self.check_decomposed(aten_name, mode) + else: + assert op.supports_autograd + self.skipTest( + "only backwards is decomposed, but dtype doesn't support AD" + ) + instantiate_device_type_tests(TestDecomp, globals(), only_for="xpu", allow_xpu=True) + + +class DecompOneOffTests(TestCase): + @onlyNativeDeviceTypes + @skipIfCrossRef + def test_contiguous_softmax(self, device): + size = (2, 4, 3, 3) + stride = (9, 18, 3, 1) + dtype = torch.float32 + + x = torch.randn(size, dtype=dtype, device=device) + x = torch.as_strided(x, size, stride) + + ref = torch.ops.aten._softmax(x, -1, False) + res = torch._decomp.decompositions._softmax(x, -1, False) + self.assertEqual(ref.stride(), res.stride()) + + @onlyNativeDeviceTypes + @skipIfCrossRef + def test_contiguous_log_softmax(self, device): + size = (2, 4, 3, 3) + stride = (9, 18, 3, 1) + + dtype = torch.float32 + x = torch.randn(size, dtype=dtype, device=device) + x = torch.as_strided(x, size, stride) + + ref = torch.ops.aten._log_softmax(x, -1, False) + res = torch._decomp.decompositions._log_softmax(x, -1, False) + self.assertEqual(ref.stride(), res.stride()) + + def test_exponential_non_inf(self, device): + inp = torch.empty((4, 400, 256), device=device) + + with torch._dynamo.utils.preserve_rng_state(): + exp_ref = inp.exponential_() + exp = torch._refs.exponential(inp) + + self.assertEqual(exp, exp_ref) + self.assertFalse(exp.isinf().any()) + + @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") + @skipIfCrossRef + def test_amp_batch_norm_backward(self): + device = device_type + grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device) + x = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device) + weight = torch.randn((2,), dtype=torch.float32, device=device) + rmean = torch.randn((2,), dtype=torch.float32, device=device) + rvar = torch.randn((2,), dtype=torch.float32, device=device) + mean = torch.randn((0,), dtype=torch.float32, device=device) + + ref = torch.ops.aten.native_batch_norm_backward( + grad_out, + x, + weight, + rmean, + rvar, + mean, + mean, + False, + 1e-05, + [True, True, True], + ) + res = torch._decomp.decompositions.native_batch_norm_backward( + grad_out, + x, + weight, + rmean, + rvar, + mean, + mean, + False, + 1e-05, + [True, True, True], + ) + for a, b in zip(ref, res): + self.assertEqual(a.stride(), b.stride()) + self.assertEqual(a.dtype, b.dtype) + + @onlyNativeDeviceTypes + @skipIfCrossRef + def test_elu_backward(self, device): + size = (2, 4, 3, 3) + dtype = torch.float32 + grad_out = torch.randn(size, dtype=dtype, device=device) + out = torch.randn(size, dtype=dtype, device=device) + + ref = torch.ops.aten.elu_backward(grad_out, 1.0, 1, 1, True, out) + res = torch._decomp.decompositions.elu_backward(grad_out, 1.0, 1, 1, True, out) + self.assertEqual(ref, res) + + @onlyNativeDeviceTypes + @skipIfCrossRef + def test_threshold_backward_dtype(self, device): + grad = torch.randint(10, (4,), device=device) + input_tensor = torch.randint(10, (4,), device=device) + + ref = torch.ops.aten.threshold_backward(grad, input_tensor, 1) + res = torch._decomp.decompositions.threshold_backward(grad, input_tensor, 1) + self.assertEqual(ref.dtype, res.dtype) + + @onlyNativeDeviceTypes + @skipIfCrossRef + def test_weight_norm_interface(self, device): + g = torch.randn((3, 10, 10), device=device) + v = torch.randn((1, 1, 10), device=device) + + ref = 
torch.ops.aten._weight_norm_interface(g, v, 2) + res = torch._decomp.decompositions._weight_norm_interface(g, v, 2) + self.assertTrue(torch.allclose(ref[0], res[0])) + self.assertTrue(torch.allclose(ref[1], res[1])) + + inp = torch.rand([30, 10], device=device) + inp2 = torch.rand([30, 1], device=device) + + self.assertEqual( + torch.ops.aten._weight_norm_interface(inp, inp2), + torch._decomp.decompositions._weight_norm_interface(inp, inp2), + ) + + @onlyCPU + @skipIfCrossRef + @skipOps( + "DecompOneOffTests", + "test_sdpa", + [ + xfail( + "nn.functional.scaled_dot_product_attention", + dtypes=[torch.half], + ), + ], + ) + @ops(_sdpa_op_info) + def test_sdpa(self, device, dtype, op): + # SDPA doesn't support float16, this is aligned with aten/src/ATen/native/transformers/attention.cpp. If we + # add support for float16 over there we should update this test as well. + query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) + key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) + value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype) + masks = [None, torch.ones((1, 1, 100, 100), device=device, dtype=torch.bool)] + + atol, rtol = dtype_precisions[dtype] + + for mask in masks: + is_causal = mask is None + decomposed_res = ( + torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu( + query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask + ) + ) + actual_res = decomposed_res[0] + # Output has form (N, H, L, E), but should be continuous on (L, N, H, E) + # in order for subsequent view(L * N, H * E) to be valid. + # So permute(2, 0, 1, 3) before checking that tensor is contiguous + self.assertTrue(actual_res.permute(2, 0, 1, 3).is_contiguous()) + + eager_res = op( + query_layer, + key_layer, + value_layer, + attn_mask=mask, + dropout_p=0.0, + is_causal=is_causal, + ) + + self.assertTrue(torch.allclose(actual_res, eager_res, atol=atol, rtol=rtol)) + + @onlyCPU + def test_native_layer_norm_cpu_decomp(self, device): + def f(x, w, b): + return torch.ops.aten.native_layer_norm.default(x, [1, 2, 3], w, b, eps=0.5) + + x = torch.randn(1, 2, 3, dtype=torch.bfloat16, device="cpu") + w = torch.randn(1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu") + b = torch.randn(1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu") + out_ref = f(x, w, b) + + from torch._subclasses.fake_tensor import FakeTensorMode + + with enable_python_dispatcher(), FakeTensorMode(): + x = torch.randn(1, 2, 3, dtype=torch.bfloat16, device="cpu") + w = torch.randn( + 1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu" + ) + b = torch.randn( + 1, 2, 3, dtype=torch.bfloat16, requires_grad=True, device="cpu" + ) + out = f(x, w, b) + + for o_ref, o in zip(out_ref, out): + self.assertEqual(o_ref.dtype, o.dtype) + + @unittest.skipIf(not SM70OrLater, "triton") + def test_rms_norm_decomp_accelerator(self, device): + @torch.compile + def rms_norm_sinh(a, b, c): + output = torch.nn.functional.rms_norm(a, b, c) + return torch.sinh(output) + + normalized_shape_arg = (3, 3, 3) + input_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) + weight_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) + + def forward_pass_fn(): + return rms_norm_sinh(input_tensor, normalized_shape_arg, weight_tensor) + + model_output, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code( + forward_pass_fn + ) + + # check RMSNorm was fused with sinh + self.assertTrue( + "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in 
generated_codes[0] + ) + self.assertTrue( + "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1] + ) + + instantiate_device_type_tests( DecompOneOffTests, globals(), only_for="xpu", allow_xpu=True ) +class HasDecompTest(TestCase): + def setUp(self): + super().setUp() + self.maxDiff = None + + @staticmethod + def _can_appear_in_trace(op: torch._ops.OpOverload) -> bool: + has_tensor_arg = any( + "Tensor" in str(a.type) + for a in itertools.chain(op._schema.arguments, op._schema.returns) + ) + if not has_tensor_arg: + return False + + try: + # CompositeImplicitAutograd ops are transparent to the tracer, so don't need decompositions + return not _is_cia_op(op) + except RuntimeError as e: + # has_key fails for some jit-registered ops, which shouldn't be + # relevant here anyway + if "does not exist" in str(e): + return False + raise + + def test_has_decomposition(self): + def all_aten_overloads(): + for name in torch._C._dispatch_get_all_op_names(): + if not name.startswith("aten::"): + continue + + name = name[6:] + if "." in name: + packet_name, overload_name = name.split(".") + else: + packet_name, overload_name = name, "default" + + packet = getattr(aten, packet_name) + assert isinstance(packet, torch._ops.OpOverloadPacket) + op = getattr(packet, overload_name) + yield op + + # This is for operators that are only registered in some CI + # configurations, so would cause the test to fail + allow_list = {aten.get_gradients.default} + + overloads_wanting_decomp = { + op for op in all_aten_overloads() if self._can_appear_in_trace(op) + } + ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys() + ops_missing_decomp -= allow_list + self.assertExpected( + "".join(sorted(op.name() + "\n" for op in ops_missing_decomp)) + ) + + def test_aten_core_operators(self): + # If a decomposition isn't included in the core decompositions, + # then it must decompose a core ATen operator. + # + # See NOTE [Core ATen Ops] + # + # If this test fails then either: + # - Add the decomposition to torch._decomp.core_aten_decompositions, + # if decomposition should be used by inductor (not a core operator). + # - Run this test again with EXPECTTEST_ACCEPT=1 to update the list of + # core ATen operators (and inductor will not use the decomposition). + + # Some decompositions are registered for CompositeImplicitAutograd + # operators, which never appear in AOTAutograd's graph so are never used. 
+        useful_decomps = {
+            op
+            for op in decomposition_table.keys()
+            if isinstance(op, torch._ops.OpOverload) and self._can_appear_in_trace(op)
+        }
+        core_decomps = torch._decomp.core_aten_decompositions().keys()
+        core_aten_ops = useful_decomps - core_decomps
+        self.assertExpected("".join(sorted(op.name() + "\n" for op in core_aten_ops)))
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/test/xpu/windows_skip_cases.py b/test/xpu/windows_skip_dict.py
similarity index 90%
rename from test/xpu/windows_skip_cases.py
rename to test/xpu/windows_skip_dict.py
index c9fadd550c..531056cd8c 100644
--- a/test/xpu/windows_skip_cases.py
+++ b/test/xpu/windows_skip_dict.py
@@ -13,8 +13,8 @@
 skip_dict = {
     # Windows: Skip entire files using *.py:: pattern
-    "test_decomp": [
-        "test_decomp.py::",  # Skip entire file on Windows
+    "test_decomp_xpu.py": [
+        "test_decomp_xpu.py::",  # Skip entire file on Windows
     ],
     # Files where Windows only needs to skip specific tests (will merge with Linux defaults)
     # "test_linalg": [
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index 7f863ebc4f..9c9188f151 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -1105,7 +1105,7 @@ def copy_tests(
 def launch_test(test_case, skip_list=None, exe_list=None):
     os.environ["PYTORCH_TEST_WITH_SLOW"] = "1"
     module_name = test_case.replace(".py", "").replace("/", ".").replace("\\", ".")
-    if skip_list is not None:
+    if skip_list and len(skip_list) > 0:
         skip_options = ' -k "not ' + skip_list[0]
         for skip_case in skip_list[1:]:
             skip_option = " and not " + skip_case
@@ -1115,7 +1115,7 @@ def launch_test(test_case, skip_list=None, exe_list=None):
             f"pytest --junit-xml=./op_ut_with_skip.{module_name}.xml " + test_case
         )
         test_command += skip_options
-    elif exe_list is not None:
+    elif exe_list and len(exe_list) > 0:
         exe_options = ' -k "' + exe_list[0]
         for exe_case in exe_list[1:]:
             exe_option = " or " + exe_case
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index a3281791de..11bb54535f 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -7922,6 +7922,17 @@
     SparseCsrXPU: angle_sparse_csr_out
   tags: pointwise
 
+- func: _fill_mem_eff_dropout_mask_(Tensor(a!) self, float dropout_p, int seed, int offset) -> Tensor(a!)
+  variants: function
+  dispatch:
+    XPU: _fill_mem_eff_dropout_mask_
+  tags: nondeterministic_seeded
+
+- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
+  dispatch:
+    XPU: _scaled_dot_product_efficient_attention_xpu
+  tags: nondeterministic_seeded
+
 - func: special_airy_ai(Tensor x) -> Tensor
   python_module: special
   structured_delegate: special_airy_ai.out
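Editor's note: the sketch below is illustrative only and not part of the patch. Assuming an XPU-enabled PyTorch build where torch.xpu.is_available() returns True, it shows one way the two operator registrations added to native_functions.yaml above could be exercised from Python; tensor shapes, dtypes, and the seed/offset values are arbitrary.

import torch
import torch.nn.functional as F

if torch.xpu.is_available():
    q, k, v = (
        torch.randn(2, 4, 128, 64, device="xpu", dtype=torch.float16) for _ in range(3)
    )

    # High-level entry point; with the registration above, the efficient-attention
    # path becomes a candidate SDPA backend on XPU.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Direct call to the registered overload:
    # (query, key, value, attn_bias, compute_log_sumexp, dropout_p, is_causal)
    out2, logsumexp, philox_seed, philox_offset = (
        torch.ops.aten._scaled_dot_product_efficient_attention(
            q, k, v, None, False, 0.0, True
        )
    )

    # The dropout-mask helper is intended for testing only.
    mask = torch.empty(2, 4, 128, 128, device="xpu")
    torch.ops.aten._fill_mem_eff_dropout_mask_(mask, 0.5, 42, 0)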