diff --git a/.github/workflows/build_zoom_backend.yml b/.github/workflows/build_zoom_backend.yml new file mode 100644 index 0000000000000..8550d088aeb63 --- /dev/null +++ b/.github/workflows/build_zoom_backend.yml @@ -0,0 +1,126 @@ +name: "Build PyTorch" + +on: + workflow_dispatch: + inputs: + force_debug_with_tmate: + type: boolean + description: 'Run the build with tmate session' + required: false + default: false + debug_with_tmate: + type: boolean + description: 'Run the build with a tmate session ONLY in case of failure' + required: false + default: false + pull_request: + push: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + build: + + strategy: + fail-fast: false + matrix: + include: + - name: "ubuntu-22.04" + runs-on: "mi300" + # container: "rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0" + # runs-on: "nod-ai-shared-cpubuilder-manylinux-x86_64" + + runs-on: ${{ matrix.runs-on }} + + name: ${{ matrix.name }} + + env: + CACHE_DIR: ${{ github.workspace }}/.container-cache + # either the PR number or `branch-N` where N always increments + CACHE_KEY: linux-build-test-cpp-asserts-manylinux-v2-${{ format('{0}-{1}', github.ref_name, github.run_number) }} + + defaults: + run: + shell: bash + + permissions: + id-token: write + contents: write + + container: + image: ${{ matrix.container }} + + steps: + - name: "Check out repository" + uses: actions/checkout@v4.2.2 + with: + submodules: recursive + + - name: Enable cache + uses: actions/cache/restore@v3 + with: + path: ${{ env.CACHE_DIR }} + key: ${{ env.CACHE_KEY }} + restore-keys: linux-build-test-cpp- + + - name: "Build PyTorch" + id: build + run: | + + export CCACHE_DIR="${{ env.CACHE_DIR }}" + export CMAKE_C_COMPILER_LAUNCHER=ccache + export CMAKE_CXX_COMPILER_LAUNCHER=ccache + export CCACHE_SLOPPINESS=include_file_ctime,include_file_mtime,time_macros + + python -m venv venv + source venv/bin/activate + pip install -r requirements.txt + chmod +x ./build.sh + ./build.sh + + - name: "Audit" + id: audit + run: | + + sudo apt install patchelf + python -m venv venv + source venv/bin/activate + pip install auditwheel + auditwheel repair -w dist --plat manylinux_2_39_x86_64 dist/torch* + + - name: Save cache + uses: actions/cache/save@v3 + if: ${{ !cancelled() }} + with: + path: ${{ env.CACHE_DIR }} + key: ${{ env.CACHE_KEY }} + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: ${{ matrix.name }}_artifact + path: dist + if-no-files-found: warn + + - name: Release current commit + uses: ncipollo/release-action@v1.12.0 + with: + artifacts: "dist/torch*.whl" + token: "${{ secrets.GITHUB_TOKEN }}" + tag: "latest" + name: "latest" + removeArtifacts: false + allowUpdates: true + replacesArtifacts: true + makeLatest: true + + - name: "Setup tmate session" + if: ${{ (failure() && inputs.debug_with_tmate) || inputs.force_debug_with_tmate }} + uses: mxschmitt/action-tmate@v3.18 + with: + limit-access-to-actor: true + install-dependencies: ${{ startsWith(matrix.runs-on, 'macos') || startsWith(matrix.runs-on, 'windows') }} \ No newline at end of file diff --git a/BUILD.bazel b/BUILD.bazel index 3f7e6327452c0..c30d8c3df9232 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -9,7 +9,7 @@ load("@pytorch//tools/config:defs.bzl", "if_cuda") load("@pytorch//:aten.bzl", "generate_aten", "intern_build_aten_ops") load(":build.bzl", "GENERATED_AUTOGRAD_CPP", "GENERATED_AUTOGRAD_PYTHON", "define_targets") 
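The workflow above publishes the audited wheel to the rolling `latest` release. As a quick way to confirm that such a wheel actually exposes the Zoom backend, a minimal smoke test might look like the sketch below; it is not part of this PR and assumes the wheel was built with `USE_ZOOM` (so `torch.version.hip` is populated, per `BuildingZoom.md` later in this diff) and reuses the PrivateUse1 rename stub documented there.

```python
# Hypothetical post-install smoke test for a USE_ZOOM wheel (not part of this PR).
import torch

# A USE_ZOOM build populates torch.version.hip (see BuildingZoom.md below).
print("HIP version reported by the wheel:", torch.version.hip)

# Expose the PrivateUse1 backend under the 'zoom' name, as documented below.
torch.utils.rename_privateuse1_backend("zoom")
torch.utils.generate_methods_for_privateuse1_backend()

# Allocate a small tensor on the zoom device and round-trip it through the CPU.
x = torch.arange(8, dtype=torch.float32, device="zoom")
assert x.cpu().sum().item() == 28.0
print("zoom tensor lives on:", x.device)
```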
load(":build_variables.bzl", "jit_core_sources", "lazy_tensor_ts_sources", "libtorch_core_sources", "libtorch_cuda_sources", "libtorch_distributed_sources", "libtorch_extra_sources", "libtorch_python_core_sources", "torch_cpp_srcs", "libtorch_python_cuda_sources", "libtorch_python_distributed_sources") -load(":ufunc_defs.bzl", "aten_ufunc_generated_cpu_kernel_sources", "aten_ufunc_generated_cpu_sources", "aten_ufunc_generated_cuda_sources") +load(":ufunc_defs.bzl", "aten_ufunc_generated_cpu_kernel_sources", "aten_ufunc_generated_cpu_sources", "aten_ufunc_generated_cuda_sources", "aten_ufunc_generated_zoom_sources") load("//:tools/bazel.bzl", "rules") define_targets(rules = rules) @@ -104,6 +104,12 @@ generated_cuda_cpp = [ "aten/src/ATen/RegisterSparseCsrCUDA.cpp", ] +generated_zoom_cpp = [ + "aten/src/ATen/ZoomFunctions.h", + "aten/src/ATen/ZoomFunctions_inl.h", + "aten/src/ATen/RegisterPrivateUse1.cpp", +] + generate_aten( name = "generated_aten_cpp", srcs = aten_generation_srcs, @@ -112,7 +118,8 @@ generate_aten( generated_cuda_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}") + aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}") + - aten_ufunc_generated_cuda_sources("aten/src/ATen/{}") + [ + aten_ufunc_generated_cuda_sources("aten/src/ATen/{}") + + aten_ufunc_generated_zoom_sources("aten/src/ATen/{}") + [ "aten/src/ATen/Declarations.yaml", ] ), diff --git a/BuildingZoom.md b/BuildingZoom.md new file mode 100644 index 0000000000000..66918e2316281 --- /dev/null +++ b/BuildingZoom.md @@ -0,0 +1,136 @@ +# Setup Python Env + +To start out, we just need to follow the normal procedure to build PyTorch from source. For convenience I've included these steps here: + +```bash +conda create -n nod-pytorch python==3.10 +conda activate nod-pytorch +conda install cmake ninja +pip install -r requirements.txt +export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} +python setup.py develop +``` + +# CMake Build + +Using the `USE_ZOOM` flag with CMake will enable building with HIP for ROCm without requiring any of the "HIPify" scripts in order to build. This will include HIP libraries and populate `torch.version.hip` appropriately. This flag is NOT yet entered into the `setup.py` script, so for now it needs to be added manually via `cmake` or `ccmake`. + +You'll need to set the `ROCM_PATH` and `HIP_ROOT_DIR` environment variables appropriately, by default on linux these should be `/opt/rocm/` and `/opt/rocm/hip` respectively. + +If you're running on Linux you can just use `build.sh` to build: +```bash +cd pytorch/ +source build.sh +``` + +Alternatively, if you want to manually setup your CMake build you can use the following commands: + +```bash +cd build/ +export PYTORCH_ROCM_ARCH=gfx90a +export ROCM_PATH=/opt/rocm +export HIP_ROOT_DIR=/opt/rocm/hip +cmake -DUSE_ZOOM=ON --build . 
--target install +``` + +# Running PyTorch with Zoom + +Programs using the zoom backend must be prefaced with this stub until we register a proper dispatch key in pytorch + +```python +import torch +import torch.zoom +torch.utils.rename_privateuse1_backend('zoom') +torch.utils.generate_methods_for_privateuse1_backend(unsupported_dtype=None) +``` + +# Installing Triton + +Since main Triton currently treats ROCm as if its masquerading as `torch.cuda`, we need a custom installation: + +```bash +git clone https://github.com/123epsilon/triton.git +cd triton/ +git checkout zoom +pip install pybind11 +pip install python/ +``` + +# Running LLama3 with Triton using LigerKernels and HuggingFace + +```bash +pip install liger-kernel +``` + +```python +# Run Llama 3 +import torch +from transformers import AutoTokenizer +from liger_kernel.transformers import AutoLigerKernelForCausalLM +from time import perf_counter as pf +torch.utils.rename_privateuse1_backend('zoom') + +# Set up the model and tokenizer +model_id = "meta-llama/Meta-Llama-3-8B" +tokenizer = AutoTokenizer.from_pretrained(model_id) +model = AutoLigerKernelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + device_map="zoom" +) + +# Function to generate text +def generate_text(prompt, max_length=30): + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + outputs = model.generate(**inputs, max_new_tokens=max_length) + return tokenizer.decode(outputs[0], skip_special_tokens=True) + +# Example usage +prompt = "Hey, how are you doing today?" +s = pf() +response = generate_text(prompt) +e = pf() +print(f"Prompt: {prompt}") +print(f"Response: {response}") + +print(f"{e-s} seconds") +``` + +```python +# Or run the instruct-tuned variant +import torch +import transformers +from liger_kernel.transformers import apply_liger_kernel_to_llama +torch.utils.rename_privateuse1_backend('zoom') + +apply_liger_kernel_to_llama() +model_id = "meta-llama/Meta-Llama-3-8B-Instruct" + +pipeline = transformers.pipeline( + "text-generation", + model=model_id, + model_kwargs={"torch_dtype": torch.bfloat16}, + device_map="zoom", +) + +messages = [ + {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"}, + {"role": "user", "content": "Who are you?"}, +] + +terminators = [ + pipeline.tokenizer.eos_token_id, + pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>") +] + +outputs = pipeline( + messages, + max_new_tokens=30, + eos_token_id=terminators, + do_sample=True, + temperature=0.6, + top_p=0.9, +) +print(outputs[0]["generated_text"][-1]) + +``` \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 3c6320e68d390..adfd8510e6ad0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -203,6 +203,7 @@ option(USE_CPP_CODE_COVERAGE "Compile C/C++ with code coverage flags" OFF) option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON) option(USE_ASAN "Use Address+Undefined Sanitizers" OFF) option(USE_TSAN "Use Thread Sanitizer" OFF) +option(USE_ZOOM "Use ZOOM HIP Backend" OFF) option(USE_CUDA "Use CUDA" ON) cmake_dependent_option( USE_XPU "Use XPU. Only available on Linux." 
ON @@ -231,12 +232,14 @@ option(USE_MAGMA "Use MAGMA" ON) option(USE_PYTORCH_METAL "Use Metal for PyTorch iOS build" OFF) option(USE_PYTORCH_METAL_EXPORT "Export Metal models on MacOSX desktop" OFF) option(USE_NATIVE_ARCH "Use -march=native" OFF) +option(ENABLE_ZOOM_BLAS "Use HIPBlas Kernels in the ZOOM backend" ON) +option(DISABLE_HIPBLASLT "Disable HIPBlasLt Kernels in the ZOOM backend" OFF) cmake_dependent_option( USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF) cmake_dependent_option( USE_NCCL "Use NCCL" ON - "USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) + "USE_CUDA OR USE_ROCM OR USE_ZOOM;UNIX;NOT APPLE" OFF) cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF) cmake_dependent_option( diff --git a/aten/CMakeLists.txt b/aten/CMakeLists.txt index bda6aea327062..f1753f50c32fd 100644 --- a/aten/CMakeLists.txt +++ b/aten/CMakeLists.txt @@ -30,10 +30,13 @@ set(ATen_CUDA_SRCS_W_SORT_BY_KEY) set(ATen_CUDA_TEST_SRCS) set(ATen_CUDA_INCLUDE) set(ATen_NVRTC_STUB_SRCS) +set(ATen_HIPRTC_STUB_SRCS) set(ATen_HIP_SRCS) +set(ATen_ZOOM_SRCS) set(ATen_HIP_SRCS_W_SORT_BY_KEY) set(ATen_HIP_TEST_SRCS) set(ATen_HIP_INCLUDE) +set(ATen_ZOOM_INCLUDE) set(ATen_MPS_SRCS) set(ATen_MPS_TEST_SRCS) set(ATen_XPU_SRCS) @@ -44,6 +47,7 @@ set(ATen_CPU_DEPENDENCY_LIBS) set(ATen_XPU_DEPENDENCY_LIBS) set(ATen_CUDA_DEPENDENCY_LIBS) set(ATen_HIP_DEPENDENCY_LIBS) +set(ATen_ZOOM_DEPENDENCY_LIBS) set(ATen_PUBLIC_CUDA_DEPENDENCY_LIBS) set(ATen_PUBLIC_HIP_DEPENDENCY_LIBS) set(ATEN_INSTALL_BIN_SUBDIR "bin" CACHE PATH "ATen install binary subdirectory") @@ -70,6 +74,17 @@ if(USE_ROCM) endif() endif() +if(USE_ZOOM) + include(LoadHIP) + if(NOT PYTORCH_FOUND_HIP) + message(WARNING "Could not load HIP, setting USE_ZOOM = OFF") + set(USE_ZOOM OFF) + else() + message(STATUS "Loaded HIP, Zoom Enabled") + endif() +endif() + + # Both CUDA and ROCM are enabled and found. Report an error. if(USE_CUDA AND USE_ROCM) message(FATAL_ERROR "Both CUDA and ROCm are enabled and found. PyTorch can only be built with either of them. 
Please turn one off by using either USE_CUDA=OFF or USE_ROCM=OFF.") @@ -116,12 +131,14 @@ set(ATen_CUDA_LINALG_SRCS ${ATen_CUDA_LINALG_SRCS} PARENT_SCOPE) set(ATen_CUDA_SRCS_W_SORT_BY_KEY ${ATen_CUDA_SRCS_W_SORT_BY_KEY} PARENT_SCOPE) set(ATen_CUDA_CU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_CU_SRCS_W_SORT_BY_KEY} PARENT_SCOPE) set(ATen_HIP_SRCS ${ATen_HIP_SRCS} PARENT_SCOPE) +set(ATen_ZOOM_SRCS ${ATen_ZOOM_SRCS} PARENT_SCOPE) set(ATen_MPS_SRCS ${ATen_MPS_SRCS} PARENT_SCOPE) set(ATen_MPS_TEST_SRCS ${ATen_MPS_TEST_SRCS} PARENT_SCOPE) set(ATen_HIP_SRCS_W_SORT_BY_KEY ${ATen_HIP_SRCS_W_SORT_BY_KEY} PARENT_SCOPE) set(ATen_XPU_SRCS ${ATen_XPU_SRCS} PARENT_SCOPE) set(ATen_XPU_TEST_SRCS ${ATen_XPU_TEST_SRCS} PARENT_SCOPE) set(ATen_NVRTC_STUB_SRCS ${ATen_NVRTC_STUB_SRCS} PARENT_SCOPE) +set(ATen_HIPRTC_STUB_SRCS ${ATen_HIPRTC_STUB_SRCS} PARENT_SCOPE) set(ATen_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS} PARENT_SCOPE) set(ATen_CUDA_TEST_SRCS ${ATen_CUDA_TEST_SRCS} PARENT_SCOPE) set(ATen_HIP_TEST_SRCS ${ATen_HIP_TEST_SRCS} PARENT_SCOPE) @@ -132,12 +149,14 @@ set(ATen_VEC_TEST_SRCS ${ATen_VEC_TEST_SRCS} PARENT_SCOPE) set(ATen_CPU_INCLUDE ${ATen_CPU_INCLUDE} PARENT_SCOPE) set(ATen_CUDA_INCLUDE ${ATen_CUDA_INCLUDE} PARENT_SCOPE) set(ATen_HIP_INCLUDE ${ATen_HIP_INCLUDE} PARENT_SCOPE) +set(ATen_ZOOM_INCLUDE ${ATen_ZOOM_INCLUDE} PARENT_SCOPE) set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE) set(ATen_THIRD_PARTY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE} PARENT_SCOPE) set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_XPU_DEPENDENCY_LIBS ${ATen_XPU_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE) +set(ATen_ZOOM_DEPENDENCY_LIBS ${ATen_ZOOM_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_CORE_TEST_SRCS ${ATen_CORE_TEST_SRCS} PARENT_SCOPE) set(FLASH_ATTENTION_CUDA_SOURCES ${FLASH_ATTENTION_CUDA_SOURCES} PARENT_SCOPE) set(MEM_EFF_ATTENTION_CUDA_SOURCES ${MEM_EFF_ATTENTION_CUDA_SOURCES} PARENT_SCOPE) diff --git a/aten/src/ATen/AccumulateType.cpp b/aten/src/ATen/AccumulateType.cpp index c4623cc08629c..55952a6c8ff91 100644 --- a/aten/src/ATen/AccumulateType.cpp +++ b/aten/src/ATen/AccumulateType.cpp @@ -2,17 +2,20 @@ namespace at { +// TODO(Arham): exchange keys c10::ScalarType toAccumulateType(c10::ScalarType type, c10::DeviceType device) { switch (type) { -#define DEFINE_CASE(scalar_t, TypeNum) \ - case ScalarType::TypeNum: \ - switch (device) { \ - case DeviceType::CUDA: \ - return CppTypeToScalarType>::value; \ - case DeviceType::MPS: \ - return CppTypeToScalarType>::value; \ - default: \ - return CppTypeToScalarType>::value; \ +#define DEFINE_CASE(scalar_t, TypeNum) \ + case ScalarType::TypeNum: \ + switch (device) { \ + case DeviceType::CUDA: \ + return CppTypeToScalarType>::value; \ + case DeviceType::PrivateUse1: \ + return CppTypeToScalarType>::value; \ + case DeviceType::MPS: \ + return CppTypeToScalarType>::value; \ + default: \ + return CppTypeToScalarType>::value; \ } AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(DEFINE_CASE) @@ -23,7 +26,12 @@ c10::ScalarType toAccumulateType(c10::ScalarType type, c10::DeviceType device) { } c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda) { - return is_cuda ? toAccumulateType(type, c10::DeviceType::CUDA) : toAccumulateType(type, c10::DeviceType::CPU); + #ifndef USE_ZOOM + return is_cuda ? 
toAccumulateType(type, c10::DeviceType::CUDA) : toAccumulateType(type, c10::DeviceType::CPU); + #else + // TODO(Arham): exchange keys + return is_cuda ? toAccumulateType(type, c10::DeviceType::PrivateUse1) : toAccumulateType(type, c10::DeviceType::CPU); + #endif } } diff --git a/aten/src/ATen/AccumulateType.h b/aten/src/ATen/AccumulateType.h index 0275ef099b03d..1cdd2423c050a 100644 --- a/aten/src/ATen/AccumulateType.h +++ b/aten/src/ATen/AccumulateType.h @@ -67,7 +67,12 @@ struct AccumulateType { template struct AccumulateType { - using type = typename AccumulateTypeDevice::type; + #ifndef USE_ZOOM + using type = typename AccumulateTypeDevice::type; + #else + // TODO(Arham): exchange keys + using type = typename AccumulateTypeDevice::type; + #endif }; template @@ -83,6 +88,8 @@ using acc_type = typename AccumulateType::type; }; #define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS) #define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA) +// TODO(Arham): exchange keys +#define ZOOM_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::PrivateUse1) #define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU) MPS_ACC_TYPE(BFloat16, float); @@ -126,6 +133,28 @@ CUDA_ACC_TYPE(c10::complex, c10::complex); CUDA_ACC_TYPE(c10::complex, c10::complex); CUDA_ACC_TYPE(c10::complex, c10::complex); +#if defined(__HIPCC__) +ZOOM_ACC_TYPE(half, float); +#endif +ZOOM_ACC_TYPE(BFloat16, float); +ZOOM_ACC_TYPE(Half, float); +ZOOM_ACC_TYPE(Float8_e5m2, float); +ZOOM_ACC_TYPE(Float8_e4m3fn, float); +ZOOM_ACC_TYPE(Float8_e5m2fnuz, float); +ZOOM_ACC_TYPE(Float8_e4m3fnuz, float); +ZOOM_ACC_TYPE(float, float); +ZOOM_ACC_TYPE(double, double); +ZOOM_ACC_TYPE(int8_t, int64_t); +ZOOM_ACC_TYPE(uint8_t, int64_t); +ZOOM_ACC_TYPE(char, int64_t); +ZOOM_ACC_TYPE(int16_t, int64_t); +ZOOM_ACC_TYPE(int32_t, int64_t); +ZOOM_ACC_TYPE(int64_t, int64_t); +ZOOM_ACC_TYPE(bool, bool); +ZOOM_ACC_TYPE(c10::complex, c10::complex); +ZOOM_ACC_TYPE(c10::complex, c10::complex); +ZOOM_ACC_TYPE(c10::complex, c10::complex); + CPU_ACC_TYPE(BFloat16, float); CPU_ACC_TYPE(Half, float); CPU_ACC_TYPE(Float8_e5m2, float); diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 9ec458fda45e4..38b94f40408b2 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -82,6 +82,12 @@ file(GLOB hip_nvrtc_stub_cpp "hip/nvrtc_stub/*.cpp") file(GLOB miopen_h "miopen/*.h") file(GLOB miopen_cpp "miopen/*.cpp") +file(GLOB zoom_h "zoom/*.h" "zoom/detail/*.h" "zoom/*.cuh" "zoom/detail/*.cuh" "zoom/tunable/*.h" "zoom/jit/*.cuh" "zoom/jit/*.h") +file(GLOB zoom_cpp "zoom/*.cpp" "zoom/detail/*.cpp" "zoom/tunable/*.cpp" "zoom/jit/*.cpp") +file(GLOB zoom_hip "zoom/*.cu" "zoom/detail/*.cu") +file(GLOB zoom_hiprtc_stub_h "zoom/hiprtc_stub/*.h") +file(GLOB zoom_hiprtc_stub_cpp "zoom/hiprtc_stub/*.cpp") + file(GLOB mkl_cpp "mkl/*.cpp") file(GLOB mkldnn_cpp "mkldnn/*.cpp") @@ -166,6 +172,13 @@ file(GLOB native_transformers_hip_cpp "native/transformers/hip/*.cpp") file(GLOB native_quantized_cudnn_hip_cpp "native/quantized/cudnn/hip/*.cpp") file(GLOB native_utils_cpp "native/utils/*.cpp") +file(GLOB native_zoom_hip "native/zoom/*.cu") +file(GLOB native_zoom_hip_h "native/zoom/*.cuh") +file(GLOB native_zoom_cpp "native/zoom/*.cpp") +file(GLOB native_zoom_linalg_cpp "native/zoom/linalg/*.cpp") +file(GLOB native_sparse_zoom_hip "native/sparse/zoom/*.cu") +file(GLOB native_sparse_zoom_cpp "native/sparse/zoom/*.cpp") + # flash_attention sources file(GLOB flash_attention_cuda_cu 
"native/transformers/cuda/flash_attn/*.cu") file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu") @@ -342,6 +355,26 @@ if(USE_ROCM) ) endif() +if(USE_ZOOM) + list(APPEND ATen_ZOOM_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/zoom) + list(APPEND ATen_ZOOM_SRCS + ${ATen_ZOOM_SRCS} + ${zoom_hip} + ${native_zoom_hip} + ${native_zoom_hip_h} + ${native_sparse_zoom_hip} + ) + list(APPEND all_zoom_cpp + ${native_sparse_zoom_cpp} + ${zoom_cpp} + ${native_zoom_cpp} + ${native_zoom_linalg_cpp} + ${zoom_generated_sources} + ${ATen_ZOOM_SRCS} + ${all_zoom_cpp} + ) +endif() + if(USE_XPU) list(APPEND ATen_XPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/xpu) list(APPEND ATen_XPU_SRCS ${xpu_cpp}) @@ -546,6 +579,7 @@ endif() # Include CPU paths for CUDA/HIP as well list(APPEND ATen_CUDA_INCLUDE ${ATen_CPU_INCLUDE}) list(APPEND ATen_HIP_INCLUDE ${ATen_CPU_INCLUDE}) +list(APPEND ATen_ZOOM_INCLUDE ${ATen_CPU_INCLUDE}) list(APPEND ATen_VULKAN_INCLUDE ${ATen_CPU_INCLUDE}) # We have two libraries: libATen_cpu.so and libATen_cuda.so, @@ -576,6 +610,12 @@ if(USE_ROCM) # list(APPEND ATen_HIP_DEPENDENCY_LIBS ATEN_CUDA_FILES_GEN_LIB) endif() +if(USE_ZOOM) + set(ATen_ZOOM_SRCS ${all_zoom_cpp}) + set(ATen_HIPRTC_STUB_SRCS ${zoom_hiprtc_stub_cpp}) + list(APPEND ATen_ZOOM_DEPENDENCY_LIBS ATEN_ZOOM_FILES_GEN_LIB) +endif() + set(ATEN_INCLUDE_DIR "${CMAKE_INSTALL_PREFIX}/${AT_INSTALL_INCLUDE_DIR}") configure_file(ATenConfig.cmake.in "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake") install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake" @@ -583,7 +623,7 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake" set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS} ${native_nested_h} ${ATen_TRANSFORMER_HEADERS}) if(NOT INTERN_BUILD_MOBILE) - list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_mps_h} ${native_utils_h} ${miopen_h}) + list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${zoom_h} ${xpu_h} ${mps_h} ${native_mps_h} ${native_utils_h} ${miopen_h}) # Metal if(USE_PYTORCH_METAL_EXPORT) # Add files needed from exporting metal models(optimized_for_mobile) @@ -611,7 +651,7 @@ foreach(HEADER ${INSTALL_HEADERS}) endforeach() # TODO: Install hip_generated_headers when we have it -foreach(HEADER ${generated_headers} ${cuda_generated_headers}) +foreach(HEADER ${generated_headers} ${cuda_generated_headers} ${zoom_generated_headers}) # NB: Assumed to be flat install(FILES ${HEADER} DESTINATION ${AT_INSTALL_INCLUDE_DIR}/ATen) endforeach() @@ -652,7 +692,10 @@ set(ATen_CUDA_LINALG_SRCS ${ATen_CUDA_LINALG_SRCS} PARENT_SCOPE) set(ATen_CUDA_SRCS_W_SORT_BY_KEY ${ATen_CUDA_SRCS_W_SORT_BY_KEY} PARENT_SCOPE) set(ATen_CUDA_CU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_CU_SRCS_W_SORT_BY_KEY} PARENT_SCOPE) set(ATen_NVRTC_STUB_SRCS ${ATen_NVRTC_STUB_SRCS} PARENT_SCOPE) +set(ATen_HIPRTC_STUB_SRCS ${ATen_HIPRTC_STUB_SRCS} PARENT_SCOPE) set(ATen_HIP_SRCS ${ATen_HIP_SRCS} PARENT_SCOPE) +set(ATen_ZOOM_SRCS ${ATen_ZOOM_SRCS} PARENT_SCOPE) +set(ATen_HIPRTC_STUB_SRCS ${ATen_HIPRTC_STUB_SRCS} PARENT_SCOPE) set(ATen_MPS_SRCS ${ATen_MPS_SRCS} PARENT_SCOPE) set(ATen_XPU_SRCS ${ATen_XPU_SRCS} PARENT_SCOPE) set(ATen_QUANTIZED_SRCS ${ATen_QUANTIZED_SRCS} PARENT_SCOPE) @@ -671,12 +714,14 @@ set(ATen_CPU_INCLUDE ${ATen_CPU_INCLUDE} PARENT_SCOPE) 
set(ATen_THIRD_PARTY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE} PARENT_SCOPE) set(ATen_CUDA_INCLUDE ${ATen_CUDA_INCLUDE} PARENT_SCOPE) set(ATen_HIP_INCLUDE ${ATen_HIP_INCLUDE} PARENT_SCOPE) +set(ATen_ZOOM_INCLUDE ${ATen_ZOOM_INCLUDE} PARENT_SCOPE) set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE) set(ATen_VULKAN_INCLUDE ${ATen_VULKAN_INCLUDE} PARENT_SCOPE) set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_XPU_DEPENDENCY_LIBS ${ATen_XPU_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE) +set(ATen_ZOOM_DEPENDENCY_LIBS ${ATen_ZOOM_DEPENDENCY_LIBS} PARENT_SCOPE) set(FLASH_ATTENTION_CUDA_SOURCES ${FLASH_ATTENTION_CUDA_SOURCES} PARENT_SCOPE) set(MEM_EFF_ATTENTION_CUDA_SOURCES ${MEM_EFF_ATTENTION_CUDA_SOURCES} PARENT_SCOPE) set(ATen_ATTENTION_KERNEL_SRCS ${ATen_ATTENTION_KERNEL_SRCS} PARENT_SCOPE) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 7fd191ef3f38c..20679ab7ff5af 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -171,6 +171,7 @@ void Context::alertCuBLASConfigNotDeterministic() const { return; } + #ifndef USE_ZOOM auto msg = c10::str( "Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or ", "`at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because ", @@ -180,6 +181,16 @@ void Context::alertCuBLASConfigNotDeterministic() const { cublas_config_var_name, "=", cublas_deterministic_configs[1], ". For more information, go to ", "https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility" ); + #else + auto msg = c10::str( + "Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or ", + "`at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because ", + "it uses hipBLAS and you have atomic operations enabled. To enable deterministic behavior in this ", + "case, you must set an environment variable before running your PyTorch application: ", + "ROCBLAS_DEFAULT_ATOMICS_MODE = 0. 
For more information, go to ", + "https://github.com/ROCm/rocBLAS/blob/develop/docs/how-to/what-is-rocblas.rst#bitwise-reproducibility" + ); + #endif if (deterministicAlgorithmsWarnOnly()) { TORCH_WARN(msg); diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index a922bcd5922fc..8a9d96eaae528 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -126,6 +127,9 @@ class TORCH_API Context { static bool hasCuBLASLt() { return detail::getCUDAHooks().hasCuBLASLt(); } + static bool hasZoom() { + return detail::getZoomHooks().hasROCM(); + } static bool hasHIP() { return detail::getHIPHooks().hasHIP(); } @@ -163,14 +167,18 @@ class TORCH_API Context { } void lazyInitPrivateUse1() { c10::call_once(thp_init, [&] { - if (isPrivateUse1HooksRegistered()) { - at::GetPrivateUse1HooksInterface()->initPrivateUse1(); - } + // if (isPrivateUse1HooksRegistered()) { + // at::GetPrivateUse1HooksInterface()->initPrivateUse1(); + // } + detail::getZoomHooks().initPrivateUse1(); }); } static const at::cuda::NVRTC& getNVRTC() { return detail::getCUDAHooks().nvrtc(); } + static const at::zoom::HIPRTC& getHIPRTC() { + return detail::getZoomHooks().hiprtc(); + } static bool setFlushDenormal(bool on); diff --git a/aten/src/ATen/EmptyTensor.cpp b/aten/src/ATen/EmptyTensor.cpp index 1eb5c070b547c..8891874b764cc 100644 --- a/aten/src/ATen/EmptyTensor.cpp +++ b/aten/src/ATen/EmptyTensor.cpp @@ -18,6 +18,8 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) { // To properly support this, see https://github.com/pytorch/pytorch/issues/14560 if (at::globalContext().hasCUDA()) { return at::detail::getCUDAHooks().getPinnedMemoryAllocator(); + } else if (at::globalContext().hasZoom()) { + return at::detail::getZoomHooks().getPinnedMemoryAllocator(); } else if (at::globalContext().hasXPU()) { return at::detail::getXPUHooks().getPinnedMemoryAllocator(); } else if(at::isPrivateUse1HooksRegistered()) { diff --git a/aten/src/ATen/TensorIndexing.cpp b/aten/src/ATen/TensorIndexing.cpp index bd50282b46ec6..128298522d48f 100644 --- a/aten/src/ATen/TensorIndexing.cpp +++ b/aten/src/ATen/TensorIndexing.cpp @@ -50,9 +50,10 @@ static inline void set_item(const Tensor& self, ArrayRef indices, c at::Device self_device = self.device(); // TODO: This qint special case looks very suspicious... + // TODO(Arham): exchange keys if (isQIntType(self.scalar_type())) { value = at::indexing::scalarToTensor(v, device(kCPU).dtype(kFloat), at::Device(kCPU)); - } else if (self_device.is_cuda()) { + } else if (self_device.is_cuda() || self_device.is_privateuseone()) { value = at::indexing::scalarToTensor(v, self.options(), at::Device(kCPU)); } else { value = at::indexing::scalarToTensor(v, self.options(), self_device); diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 2d01bdeca500b..8219fafb037b9 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -202,6 +202,44 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { TORCH_FN((&at::autocast::binary_cross_entropy_banned))); } +// TODO(Arham): exchange keys +TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) { + m.fallback(torch::CppFunction::makeFallthrough()); +} + +TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) { + // lower_precision_fp +#define _KERNEL_ZOOM_LOW_PRECISION_FP(...) \ + KERNEL_ZOOM(__VA_ARGS__, lower_precision_fp) + + AT_FORALL_LOWER_PRECISION_FP(_KERNEL_ZOOM_LOW_PRECISION_FP) + + // fp32 +#define _KERNEL_ZOOM_FP32(...) 
KERNEL_ZOOM(__VA_ARGS__, fp32) + + AT_FORALL_FP32(_KERNEL_ZOOM_FP32) + + // fp32_set_opt_dtype +#define _KERNEL_ZOOM_FP32_SET_OPT_DTYPE(...) \ + KERNEL_ZOOM(__VA_ARGS__, fp32_set_opt_dtype) + + AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_ZOOM_FP32_SET_OPT_DTYPE) + + // fp32_append_dtype + // The fp32_append_dtype wrapper overrides implicit promotion behavior. + // norm does not implicitly promote, but be aware when adding new ops to this policy. + AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE( + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_ZOOM) + + // promote +#define _KERNEL_ZOOM_PROMOTE(...) KERNEL_ZOOM(__VA_ARGS__, promote) + + AT_FORALL_PROMOTE(_KERNEL_ZOOM_PROMOTE) + + m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"), + TORCH_FN((&at::autocast::binary_cross_entropy_banned))); +} + TORCH_LIBRARY_IMPL(_, AutocastCPU, m) { m.fallback(torch::CppFunction::makeFallthrough()); } diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h index c36030db5b048..2f897715d03b6 100644 --- a/aten/src/ATen/autocast_mode.h +++ b/aten/src/ATen/autocast_mode.h @@ -708,6 +708,25 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions. REDISPATCH_SIGNATURE, \ POLICY) +// KERNEL_ZOOM/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_ZOOM +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastZOOM +// TODO(Arham): exchange keys +#define KERNEL_ZOOM(...) KERNEL(c10::DeviceType::PrivateUse1, __VA_ARGS__) + +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_ZOOM( \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + c10::DeviceType::PrivateUse1, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) + // KERNEL_XPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU // registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastXPU #define KERNEL_XPU(...) KERNEL(c10::DeviceType::XPU, __VA_ARGS__) diff --git a/aten/src/ATen/core/VariableFallbackKernel.cpp b/aten/src/ATen/core/VariableFallbackKernel.cpp index b801eb2fa5211..2bac27929b9e1 100644 --- a/aten/src/ATen/core/VariableFallbackKernel.cpp +++ b/aten/src/ATen/core/VariableFallbackKernel.cpp @@ -66,6 +66,11 @@ TORCH_LIBRARY_IMPL(_, AutogradCUDA, m) { m.fallback(AUTOGRAD_FALLBACK); } +// TODO(Arham): replace with zoom key +TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) { + m.fallback(AUTOGRAD_FALLBACK); +} + TORCH_LIBRARY_IMPL(_, AutogradXLA, m) { m.fallback(AUTOGRAD_FALLBACK); } diff --git a/aten/src/ATen/detail/ZoomHooksInterface.cpp b/aten/src/ATen/detail/ZoomHooksInterface.cpp new file mode 100644 index 0000000000000..f23de3c899c16 --- /dev/null +++ b/aten/src/ATen/detail/ZoomHooksInterface.cpp @@ -0,0 +1,48 @@ +#include + +#include + +#include + +namespace at { +namespace detail { + +// NB: We purposely leak the CUDA hooks object. This is because under some +// situations, we may need to reference the CUDA hooks while running destructors +// of objects which were constructed *prior* to the first invocation of +// getZoomHooks. The example which precipitated this change was the fused +// kernel cache in the JIT. The kernel cache is a global variable which caches +// both CPU and CUDA kernels; CUDA kernels must interact with CUDA hooks on +// destruction. Because the kernel cache handles CPU kernels too, it can be +// constructed before we initialize CUDA; if it contains CUDA kernels at program +// destruction time, you will destruct the CUDA kernels after CUDA hooks has +// been unloaded. 
In principle, we could have also fixed the kernel cache store +// CUDA kernels in a separate global variable, but this solution is much +// simpler. +// +// CUDAHooks doesn't actually contain any data, so leaking it is very benign; +// you're probably losing only a word (the vptr in the allocated object.) +static ZoomHooksInterface* zoom_hooks = nullptr; + +// init and register extension hooks +void initZoomHooks() { + static c10::once_flag once; + c10::call_once(once, [] { + zoom_hooks = PrivateUse1HooksRegistry()->Create("ZoomHooks", ZoomHooksArgs{}).release(); + if (!zoom_hooks) { + zoom_hooks = new ZoomHooksInterface(); + } + RegisterPrivateUse1HooksInterface(zoom_hooks); + }); +} + +const ZoomHooksInterface& getZoomHooks() { + initZoomHooks(); + return *zoom_hooks; +} + +} // namespace detail + +C10_DEFINE_REGISTRY(PrivateUse1HooksRegistry, ZoomHooksInterface, ZoomHooksArgs) + +} // namespace at \ No newline at end of file diff --git a/aten/src/ATen/detail/ZoomHooksInterface.h b/aten/src/ATen/detail/ZoomHooksInterface.h new file mode 100644 index 0000000000000..02bdd94ff1dad --- /dev/null +++ b/aten/src/ATen/detail/ZoomHooksInterface.h @@ -0,0 +1,139 @@ +#pragma once + +#include +#include +#include + +#include + +// Forward-declares at::Generator and at::zoom::NVRTC +namespace at { +struct Generator; +namespace zoom { +struct HIPRTC; +} // namespace zoom +} // namespace at + +// NB: Class must live in `at` due to limitations of Registry.h. +namespace at { + +// #ifdef _MSC_VER +// constexpr const char* ZOOM_HELP = +// "PyTorch splits its backend into two shared libraries: a CPU library " +// "and a CUDA library; this error has occurred because you are trying " +// "to use some CUDA functionality, but the CUDA library has not been " +// "loaded by the dynamic linker for some reason. The CUDA library MUST " +// "be loaded, EVEN IF you don't directly use any symbols from the CUDA library! " +// "One common culprit is a lack of -INCLUDE:?warp_size@cuda@at@@YAHXZ " +// "in your link arguments; many dynamic linkers will delete dynamic library " +// "dependencies if you don't depend on any of their symbols. You can check " +// "if this has occurred by using link on your binary to see if there is a " +// "dependency on *_cuda.dll library."; +// #else +constexpr const char* ZOOM_HELP = + "PyTorch splits its backend into two shared libraries: a CPU library " + "and a ZOOM library; this error has occurred because you are trying " + "to use some ZOOM functionality, but the ZOOM library has not been " + "loaded by the dynamic linker for some reason. The ZOOM library MUST " + "be loaded, EVEN IF you don't directly use any symbols from the ZOOM library! " + "One common culprit is a lack of -Wl,--no-as-needed in your link arguments; many " + "dynamic linkers will delete dynamic library dependencies if you don't " + "depend on any of their symbols. You can check if this has occurred by " + "using ldd on your binary to see if there is a dependency on *_cuda.so " + "library."; +// #endif + +// The ZoomHooksInterface is an omnibus interface for any ZOOM functionality +// which we may want to call into from CPU code (and thus must be dynamically +// dispatched, to allow for separate compilation of ZOOM code). How do I +// decide if a function should live in this class? There are two tests: +// +// 1. Does the *implementation* of this function require linking against +// ZOOM libraries? +// +// 2. Is this function *called* from non-ZOOM ATen code? 
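Once a concrete hooks object is registered through `PrivateUse1HooksRegistry` and picked up by `initZoomHooks()`, several of these entry points become reachable from ordinary Python. The sketch below is illustrative only: it assumes a `USE_ZOOM` build, the `'zoom'` rename from `BuildingZoom.md`, and that the backend supplies the allocator and RNG pieces these hooks advertise.

```python
import torch

torch.utils.rename_privateuse1_backend("zoom")

# GetCPUAllocatorMaybePinned() (see the EmptyTensor.cpp hunk above) falls back to
# getZoomHooks().getPinnedMemoryAllocator() when no CUDA build is present, so
# pinned host buffers in a zoom-only build are served by these hooks.
staging = torch.empty(1 << 20, pin_memory=True)

# getDefaultGenerator()/getDefaultZoomGenerator() back the per-device RNG state
# used by sampling ops dispatched to PrivateUse1 (e.g. rand, bernoulli).
noise = torch.rand(4, device="zoom")
print(staging.shape, noise.cpu())
```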
+// +// (2) should filter out many ostensible use-cases, since many times a ZOOM +// function provided by ATen is only really ever used by actual ZOOM code. +// +// TODO: Consider putting the stub definitions in another class, so that one +// never forgets to implement each virtual function in the real implementation +// in ZOOMHooks. This probably doesn't buy us much though. +struct TORCH_API ZoomHooksInterface : PrivateUse1HooksInterface { + // This should never actually be implemented, but it is used to + // squelch -Werror=non-virtual-dtor + virtual ~ZoomHooksInterface() override = default; + + // Initialize THCState and, transitively, the ZOOM state + virtual void initZoom() const { + TORCH_CHECK(false, "Cannot initialize ZOOM without torch_zoom library. ", ZOOM_HELP); + } + + virtual void initPrivateUse1() const override { + initZoom(); + } + + virtual const Generator& getDefaultZoomGenerator(C10_UNUSED DeviceIndex device_index = -1) const { + TORCH_CHECK(false, "Cannot get default ZOOM generator without torch_zoom library. ", ZOOM_HELP); + } + + virtual const Generator& getDefaultGenerator(DeviceIndex device_index) override { return getDefaultZoomGenerator(device_index); }; + + virtual Device getDeviceFromPtr(void* /*data*/) const override { + TORCH_CHECK(false, "Cannot get device of pointer on ZOOM without torch_zoom library. ", ZOOM_HELP); + } + + virtual bool isPinnedPtr(const void* /*data*/) const { + return false; + } + + virtual bool hasROCM() const { + return false; + } + + virtual const at::zoom::HIPRTC& hiprtc() const { + TORCH_CHECK(false, "HIPRTC requires Zoom. ", ZOOM_HELP); + } + + virtual bool hasPrimaryContext(DeviceIndex device_index) const override { + TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without torch_zoom library. ", ZOOM_HELP); + } + + virtual DeviceIndex current_device() const { + return -1; + } + + virtual Allocator* getPinnedMemoryAllocator() const override { + TORCH_CHECK(false, "Pinned memory requires ZOOM. ", ZOOM_HELP); + } + + virtual Allocator* getZoomDeviceAllocator() const { + TORCH_CHECK(false, "ZoomDeviceAllocator requires ZOOM. ", ZOOM_HELP); + } + + virtual std::string showConfig() const { + TORCH_CHECK(false, "Cannot query detailed ZOOM version without torch_zoom library. ", ZOOM_HELP); + } + + virtual int getNumGPUs() const { + return 0; + } + + virtual void deviceSynchronize(DeviceIndex /*device_index*/) const { + TORCH_CHECK(false, "Cannot synchronize ZOOM device without torch_zoom library. ", ZOOM_HELP); + } +}; + +// NB: dummy argument to suppress "ISO C++11 requires at least one argument +// for the "..." 
in a variadic macro" +struct TORCH_API ZoomHooksArgs {}; + +TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, ZoomHooksInterface, ZoomHooksArgs); +#define REGISTER_PRIVATEUSE1_HOOKS(clsname) \ + C10_REGISTER_CLASS(PrivateUse1HooksRegistry, clsname, clsname) + +namespace detail { +TORCH_API void initZoomHooks(); +TORCH_API const ZoomHooksInterface& getZoomHooks(); +} // namespace detail +} // namespace at \ No newline at end of file diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp index a0141f974923e..be525a961d9d6 100644 --- a/aten/src/ATen/native/Activation.cpp +++ b/aten/src/ATen/native/Activation.cpp @@ -787,6 +787,18 @@ Tensor log_sigmoid_backward_cuda(const Tensor& grad_output, const Tensor& input, return iter.output(); } +Tensor log_sigmoid_backward_zoom(const Tensor& grad_output, const Tensor& input, const Tensor& buffer) { + auto grad_input = at::empty_like(grad_output); + // NOTE: buffer is only used by CPU dispatch, we just ignore it here + auto iter = at::TensorIteratorConfig() + .add_output(grad_input) + .add_const_input(input) + .add_const_input(grad_output) + .build(); + log_sigmoid_backward_stub(kPrivateUse1, iter); + return iter.output(); +} + Tensor log_sigmoid_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& buffer) { auto grad_input = at::empty_like(grad_output); auto iter = at::TensorIteratorConfig() @@ -810,6 +822,17 @@ Tensor& log_sigmoid_backward_cuda_out(const Tensor& grad_output, const Tensor& i return grad_input; } +Tensor& log_sigmoid_backward_zoom_out(const Tensor& grad_output, const Tensor& input, + const Tensor& buffer, Tensor& grad_input) { +auto iter = TensorIteratorConfig() +.add_output(grad_input) +.add_const_input(input) +.add_const_input(grad_output) +.build(); +log_sigmoid_backward_stub(kPrivateUse1, iter); +return grad_input; +} + Tensor& log_sigmoid_backward_cpu_out(const Tensor& grad_output, const Tensor& input, const Tensor& buffer, diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index ecedc73579d66..f9ad9a4f725a0 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -598,7 +598,8 @@ struct ConvParams { // nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of // a depthwise multiplier) bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const { - return input.is_cuda() && + //TODO(Arham): exchange keys + return (input.is_cuda() || input.is_privateuseone()) && !transposed && (input.ndimension() == 4 || input.ndimension() == 5) && at::symint::size(input, 1) == groups && @@ -1254,7 +1255,7 @@ ConvBackend _select_conv_backend( !params.is_dilated()) { // fast path for grouped conv3d return ConvBackend::Slow3d; - } else if (input.device().is_cpu() || input.is_cuda()) { + } else if (input.device().is_cpu() || input.is_cuda() || input.is_privateuseone()) { // backends without support for groups if (params.transposed) { if (input.ndimension() == 4) { @@ -1277,7 +1278,7 @@ ConvBackend _select_conv_backend( return ConvBackend::Slow2d; } } - } else if (input.ndimension() == 5 && (input.is_cuda() || params.is_dilated())) { + } else if (input.ndimension() == 5 && (input.is_cuda() || input.is_privateuseone() || params.is_dilated())) { return ConvBackend::SlowDilated3d; } else if (input.ndimension() == 5) { /* dim == 5, CPU, non-dilated */ /* CPU implementation has specialized MM kernels @@ -1767,14 +1768,14 @@ std::tuple _convolution_double_backward( const std::option Tensor ggO; if 
(input.numel() != 0) { if (ggI.defined()) { - if (weight.is_cuda()) { + if (weight.is_cuda() || weight.is_privateuseone()) { weight = weight.contiguous(); } ggO = at::convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups); } if (ggW.defined()) { - if (ggW.is_cuda()) { + if (ggW.is_cuda() || ggW.is_privateuseone()) { ggW = ggW.contiguous(); } auto ggW_term = at::convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups); @@ -1826,7 +1827,7 @@ std::tuple _convolution_double_backward( const std::option if (input.numel() != 0) { if (groups == 1) { - if (gOt.is_cuda()) { + if (gOt.is_cuda() || gOt.is_privateuseone()) { gOt = gOt.contiguous(); } // Compute conv @@ -1841,7 +1842,7 @@ std::tuple _convolution_double_backward( const std::option for (const auto g : c10::irange(groups)) { auto ggIt_g = subvariable(ggIt, 0, groups, g); auto gOt_g = subvariable(gOt, 0, groups, g); - if (gOt_g.is_cuda()) { + if (gOt_g.is_cuda() || gOt_g.is_privateuseone()) { gOt_g = gOt_g.contiguous(); } @@ -1883,7 +1884,7 @@ std::tuple _convolution_double_backward( const std::option gi_conv_params.transposed = !params.transposed; if (params.transposed) { - if (gO.is_cuda()) { + if (gO.is_cuda() || gO.is_privateuseone()) { gO = gO.contiguous(); } gI = at::convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups); @@ -1917,7 +1918,7 @@ std::tuple _convolution_double_backward( const std::option } } - if (gO.is_cuda()) { + if (gO.is_cuda() || gO.is_privateuseone()) { gO = gO.contiguous(); } diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index c5f81e98906dd..416a607d5c262 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -130,7 +130,8 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) { // (e.g. XLA) may be supported by overriding copy_ and _copy_from. 
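To make the effect of these widened `is_cuda() || is_privateuseone()` checks concrete, a convolution on the renamed backend might be driven from Python as below. This is a sketch under stated assumptions: the zoom build must supply the underlying convolution kernels, and `'zoom'` is only a valid device string after the rename stub from `BuildingZoom.md`.

```python
import torch
import torch.nn.functional as F

torch.utils.rename_privateuse1_backend("zoom")

# _select_conv_backend() now treats PrivateUse1 tensors like CUDA ones when
# choosing among the slow/depthwise/dilated paths, so a plain conv2d call
# can be routed to the backend's kernels (assuming the build provides them).
x = torch.randn(1, 3, 32, 32, device="zoom")
w = torch.randn(8, 3, 3, 3, device="zoom")
y = F.conv2d(x, w, padding=1)

# copy_impl()/is_supported_device() accept kPrivateUse1 as well, so
# cross-device copies behave like the familiar CUDA<->CPU ones.
y_host = y.cpu()
y_copy = torch.empty_like(y).copy_(y)
print(y_host.shape, y_copy.device)
```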
bool is_supported_device(Device device) { DeviceType device_type = device.type(); - return device_type == kCPU || device_type == kCUDA || device_type == kHIP || device_type == kVulkan || device_type == kMetal || device_type == kMPS; + // TODO(Arham): exchange keys + return device_type == kPrivateUse1 || device_type == kCPU || device_type == kCUDA || device_type == kHIP || device_type == kVulkan || device_type == kMetal || device_type == kMPS; } } // namespace @@ -288,6 +289,9 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) } else if (iter.device_type(1) == kMPS) { device_type = kMPS; } + else if (iter.device_type(1) == kPrivateUse1) { + device_type = kPrivateUse1; + } // TODO: if we need to, we can also enable this path for quantized tensor if (device_type == kCPU && copy_transpose_valid(self, src) && !self.is_quantized()) { diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp index 942461c7612c1..06e8f4b4fc091 100644 --- a/aten/src/ATen/native/Distance.cpp +++ b/aten/src/ATen/native/Distance.cpp @@ -101,9 +101,10 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, std // See Note [cdist relies on cdist_impl redispatching] // Keep this condition in sync with the condition at the Note + // TODO(Arham): replace keys below if (!(p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25))))) { - TORCH_CHECK(device1 == kCPU || device1 == kCUDA, "cdist only supports CPU and CUDA devices, X1 got: ", device1); - TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "cdist only supports CPU and CUDA devices, X2 got: ", device2); + TORCH_CHECK(device1 == kCPU || device1 == kCUDA || device1 == kPrivateUse1, "cdist only supports CPU, CUDA, and HIP devices, X1 got: ", device1); + TORCH_CHECK(device2 == kCPU || device2 == kCUDA || device2 == kPrivateUse1, "cdist only supports CPU, CUDA, and HIP devices, X2 got: ", device2); } auto dim1 = x1.dim(); @@ -228,9 +229,10 @@ Tensor _cdist_backward(const Tensor& _grad, const Tensor& _x1, const Tensor& _x2 int64_t n = x1.size(-2); int64_t m = x1.size(-1); auto device1 = x1.device().type(); - TORCH_CHECK(device1 == kCPU || device1 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X1 got: ", device1); + //TODO(Arham): exchange keys below + TORCH_CHECK(device1 == kCPU || device1 == kCUDA || device1 == kPrivateUse1, "_cdist_backward only supports CPU, CUDA, and HIP devices, X1 got: ", device1); auto device2 = x2.device().type(); - TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X2 got: ", device2); + TORCH_CHECK(device2 == kCPU || device2 == kCUDA || device2 == kPrivateUse1, "_cdist_backward only supports CPU, CUDA, and HIP devices, X2 got: ", device2); Tensor grad_x1 = at::empty({batch_product, n, m}, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT); @@ -244,7 +246,8 @@ Tensor _cdist_backward(const Tensor& _grad, const Tensor& _x1, const Tensor& _x2 Tensor _pdist_forward(const Tensor& self, const double p) { TORCH_CHECK(self.is_contiguous(), "_pdist_forward requires contiguous input"); auto device = self.device().type(); - TORCH_CHECK(device == kCPU || device == kCUDA, "_pdist_forward only supports CPU and CUDA devices, got: ", device); + // TODO(Arham): exchange keys below + TORCH_CHECK(device == kCPU || device == kCUDA || device == kPrivateUse1, "_pdist_forward only supports CPU, CUDA, and HIP devices, got: ", device); Tensor result = at::empty({0}, self.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT); if 
(self.size(0) <= 1) { result.resize_({0}); @@ -265,7 +268,8 @@ Tensor _pdist_backward(const Tensor& grad, const Tensor& self, const double p, c TORCH_CHECK(self.is_contiguous(), "_pdist_backward requires self to be contiguous"); TORCH_CHECK(pdist.is_contiguous(), "_pdist_backward requires pdist to be contiguous"); auto device = self.device().type(); - TORCH_CHECK(device == kCPU || device == kCUDA, "_pdist_backward only supports CPU and CUDA devices, got: ", device); + // TODO(Arham): exchange keys below + TORCH_CHECK(device == kCPU || device == kCUDA || device == kPrivateUse1, "_pdist_backward only supports CPU, CUDA, and HIP devices, got: ", device); Tensor result = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); pdist_backward_stub(device, result, grad, self, p, pdist); return result; diff --git a/aten/src/ATen/native/Distributions.h b/aten/src/ATen/native/Distributions.h index 2c334157eba9f..5476777d72f8a 100644 --- a/aten/src/ATen/native/Distributions.h +++ b/aten/src/ATen/native/Distributions.h @@ -17,7 +17,11 @@ #define compat_abs c10::cuda::compat::abs #define compat_log1p c10::cuda::compat::log1p #elif defined(__HIPCC__) -#include + #ifdef USE_ZOOM + #include + #else + #include + #endif #define compat_exp c10::hip::compat::exp #define compat_ceil c10::hip::compat::ceil #define compat_floor c10::hip::compat::floor diff --git a/aten/src/ATen/native/SharedReduceOps.h b/aten/src/ATen/native/SharedReduceOps.h index 5b7167ee93dd2..9cdf5df112d71 100644 --- a/aten/src/ATen/native/SharedReduceOps.h +++ b/aten/src/ATen/native/SharedReduceOps.h @@ -11,8 +11,13 @@ #include #include #elif defined(__HIPCC__) -#include -#include + #ifdef USE_ZOOM + #include + #include + #else + #include + #include + #endif #endif #if defined(__CUDACC__) || defined(__HIPCC__) #include @@ -56,7 +61,11 @@ inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) { #include #define compat_pow c10::cuda::compat::pow #elif defined(__HIPCC__) -#include +#ifdef USE_ZOOM + #include + #else + #include + #endif #define compat_pow c10::hip::compat::pow #else #define compat_pow std::pow diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp index 3188479b931f3..fd2e8e282ad1d 100644 --- a/aten/src/ATen/native/SoftMax.cpp +++ b/aten/src/ATen/native/SoftMax.cpp @@ -452,7 +452,7 @@ static Tensor softmax(const Tensor& input_, const int64_t dim_) { Tensor softmax(const Tensor& input_, const int64_t dim_, std::optional dtype) { auto result = [&]() { NoNamesGuard guard; - if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float){ + if ((input_.is_cuda() || input_.is_privateuseone()) && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float){ return at::_softmax(input_, dim_, true); } else { Tensor converted = dtype.has_value() ? 
input_.toType(dtype.value()) : input_; @@ -469,7 +469,7 @@ Tensor& softmax_out( std::optional dtype, Tensor& output_) { Tensor output_temp; - if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && + if ((input_.is_cuda() || input_.is_privateuseone()) && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float) { if (!output_.is_contiguous()) { auto options = @@ -517,7 +517,7 @@ static Tensor log_softmax(const Tensor& input_, const int64_t dim_) { Tensor log_softmax(const Tensor& input_, const int64_t dim_, std::optional dtype) { auto result = [&]() { NoNamesGuard guard; - if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float){ + if ((input_.is_cuda() || input_.is_privateuseone()) && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float){ return at::_log_softmax(input_, dim_, true); } else { Tensor converted = dtype.has_value()? input_.toType(dtype.value()) : input_; @@ -534,7 +534,7 @@ Tensor& log_softmax_out( std::optional dtype, Tensor& output_) { Tensor output_temp; - if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && + if (((input_.is_cuda() || input_.is_privateuseone())) && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float) { if (!output_.is_contiguous()) { auto options = diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 395af8e5ef139..12c37abd37dec 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -660,7 +660,9 @@ Tensor & put_(Tensor & self, const Tensor& index, const Tensor & source, const b // See note [Writing Nondeterministic Operations] // Nondeterministic when index contains duplicate entries and we do not accumulate // If we accumulate on GPU, we use atomicGPUAdd, which is non-deterministic - if (!accumulate || (accumulate && self.device().type() == DeviceType::CUDA)) { + // TODO(Arham): replace PU1 with Zoom key + bool non_deterministic_device = self.device().type() == DeviceType::CUDA || self.device().type() == DeviceType::PrivateUse1; + if (!accumulate || (accumulate && non_deterministic_device)) { at::globalContext().alertNotDeterministic("put_"); } @@ -735,7 +737,9 @@ Tensor & _index_put_impl_(Tensor & self, const torch::List at::assert_no_overlap(self, *index); } } - if (self.device().type() == DeviceType::CUDA && (accumulate || globalContext().deterministicAlgorithms())) { + // TODO(Arham): replace PU1 with Zoom key + bool non_deterministic_device = ( self.device().type() == DeviceType::CUDA || self.device().type() == DeviceType::PrivateUse1 ); + if ( non_deterministic_device && (accumulate || globalContext().deterministicAlgorithms())) { TORCH_CHECK(value_.device() == self.device(), "expected device ", self.device(), " but got device ", value_.device(), " for value tensor"); index_put_with_sort_stub(self.device().type(), self, indices, value_, accumulate, unsafe); @@ -797,7 +801,8 @@ TORCH_IMPL_FUNC(index_copy_out) if (!result.is_same(self)) result.copy_(self); // See Note [Enabling Deterministic Operations] - if (result.is_cuda() && globalContext().deterministicAlgorithms()){ + // TODO(Arham): exchange keys + if ((result.is_cuda() || result.is_privateuseone()) && globalContext().deterministicAlgorithms()){ torch::List> indices; indices.reserve(dim + 1); for (const auto i: c10::irange(dim)) { @@ -1732,7 +1737,8 @@ void scatter_impl( if (index.numel() == 0) return; auto op = ReductionType::SUM; - bool deterministic 
= globalContext().deterministicAlgorithms() && self.device().type() == DeviceType::CUDA; + // TODO(Arham): replace PU1 with Zoom key + bool deterministic = globalContext().deterministicAlgorithms() && (self.device().type() == DeviceType::CUDA || self.device().type() == DeviceType::PrivateUse1); if (reduce.has_value()) { op = get_operator_enum(reduce.value(), use_new_options); @@ -1826,7 +1832,8 @@ TORCH_IMPL_FUNC(scatter_add) // See Note [Enabling Deterministic Operations] // Avoid gpuAtomicAdd for CUDA if deterministic mode is turned on - if (globalContext().deterministicAlgorithms() && self.device().type() == DeviceType::CUDA) { + // TODO(Arham): replace PU1 with Zoom key + if (globalContext().deterministicAlgorithms() && (self.device().type() == DeviceType::CUDA || self.device().type() == DeviceType::PrivateUse1)) { _scatter_via_index_put(self, dim, index, src, mut_out, /*accumulate*/true); } else { if (can_use_expanded_index_path(mut_out, dim, index, src, /*is_scatter_like*/true)) { diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 974ad302ca0c8..7233665684236 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -585,8 +585,9 @@ std::tuple mode(const Tensor& self, int64_t dim, bool keepdim) { std::tuple mode_out(const Tensor& self, int64_t dim, bool keepdim, Tensor& values, Tensor& indices) { - TORCH_CHECK(self.device().is_cpu() || self.is_cuda(), - "mode only supports CPU AND CUDA device type, got: ", self.device().type()); + // TODO(Arham): exchange keys + TORCH_CHECK(self.device().is_cpu() || self.is_cuda() || self.is_privateuseone(), + "mode only supports CPU, CUDA, and Zoom device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, "mode only supports strided layout, got: ", self.layout()); TORCH_CHECK(self.device() == values.device(), diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 10d8b1ad79cad..d4ad8a63126b3 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -169,11 +169,13 @@ dispatch: CPU: _assert_async_cpu CUDA: _assert_async_cuda + PrivateUse1: _assert_async_zoom - func: _assert_async.msg(Tensor self, str assert_msg) -> () dispatch: CPU: _assert_async_msg_cpu CUDA: _assert_async_msg_cuda + PrivateUse1: _assert_async_msg_zoom - func: _assert_scalar(Scalar self, str assert_msg) -> () dispatch: @@ -271,6 +273,7 @@ variants: function dispatch: CUDA: fused_dropout_cuda + PrivateUse1: fused_dropout_zoom tags: nondeterministic_seeded autogen: _fused_dropout.out @@ -278,6 +281,7 @@ variants: function dispatch: CUDA: masked_scale_cuda + PrivateUse1: masked_scale_zoom autogen: _masked_scale.out - func: native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor) @@ -285,6 +289,7 @@ dispatch: CPU: native_dropout_cpu CUDA: native_dropout_cuda + PrivateUse1: native_dropout_zoom NestedTensorCPU, NestedTensorCUDA: native_dropout_nested tags: [nondeterministic_seeded, core] autogen: native_dropout.out @@ -293,6 +298,7 @@ dispatch: CPU, NestedTensorCPU, NestedTensorCUDA: native_dropout_backward CUDA: native_dropout_backward_cuda + PrivateUse1: native_dropout_backward_zoom autogen: native_dropout_backward.out tags: pointwise @@ -354,7 +360,7 @@ - func: abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: abs_out + CPU, CUDA, PrivateUse1: abs_out MPS: abs_out_mps SparseCPU, SparseCUDA: abs_sparse_out SparseCsrCPU, SparseCsrCUDA: abs_sparse_csr_out @@ -399,26 +405,26 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: angle + CPU, CUDA, PrivateUse1: angle SparseCsrCPU, SparseCsrCUDA: angle_sparse_csr tags: pointwise - func: angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: angle_out + CPU, CUDA, PrivateUse1: angle_out SparseCsrCPU, SparseCsrCUDA: angle_sparse_csr_out tags: pointwise - func: view_as_real(Tensor(a) self) -> Tensor(a) variants: function dispatch: - CPU, CUDA, MPS, Meta: view_as_real + CPU, CUDA, PrivateUse1, MPS, Meta: view_as_real - func: view_as_complex(Tensor(a) self) -> Tensor(a) variants: function dispatch: - CPU, CUDA, MPS, Meta: view_as_complex + CPU, CUDA, PrivateUse1, MPS, Meta: view_as_complex - func: sgn(Tensor self) -> Tensor variants: function, method @@ -442,7 +448,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: sgn_out + CPU, CUDA, PrivateUse1: sgn_out MPS: sgn_out_mps SparseCPU, SparseCUDA: sgn_sparse_out SparseCsrCPU, SparseCsrCUDA: sgn_sparse_csr_out @@ -481,7 +487,7 @@ - func: conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: conj_physical_out + CPU, CUDA, PrivateUse1: conj_physical_out MPS: conj_physical_out_mps SparseCPU, SparseCUDA: conj_physical_out_sparse SparseCsrCPU, SparseCsrCUDA: conj_physical_sparse_csr_out @@ -522,7 +528,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: acos_out + CPU, CUDA, PrivateUse1: acos_out MPS: acos_out_mps tags: pointwise @@ -638,6 +644,7 @@ dispatch: CPU: addmv_out_cpu CUDA: addmv_out_cuda + PrivateUse1: addmv_out_hip MPS: addmv_out_mps SparseCsrCPU: addmv_out_sparse_compressed SparseCsrCUDA: addmv_out_sparse_compressed_cuda @@ -645,7 +652,7 @@ - func: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor variants: function, method dispatch: - CPU, CUDA: addr + CPU, CUDA, PrivateUse1: addr MPS: addr_mps CompositeExplicitAutograd: math_addr @@ -656,7 +663,7 @@ - func: addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: addr_out + CPU, CUDA, PrivateUse1: addr_out MPS: addr_out_mps CompositeExplicitAutograd: math_addr_out @@ -707,14 +714,14 @@ device_check: NoCheck # TensorIterator structured: True dispatch: - CPU, CUDA: all_out + CPU, CUDA, PrivateUse1: all_out MPS: all_out_mps - func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator structured: True dispatch: - CPU, CUDA: all_dims_out + CPU, CUDA, PrivateUse1: all_dims_out CompositeExplicitAutograd: all_dims_out_default cpp_no_default_args: ['dim'] @@ -750,14 +757,14 @@ device_check: NoCheck # TensorIterator structured: True dispatch: - CPU, CUDA: any_out + CPU, CUDA, PrivateUse1: any_out MPS: any_out_mps - func: any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
device_check: NoCheck # TensorIterator structured: True dispatch: - CPU, CUDA: any_dims_out + CPU, CUDA, PrivateUse1: any_dims_out CompositeExplicitAutograd: any_dims_out_default cpp_no_default_args: ['dim'] @@ -797,6 +804,7 @@ dispatch: CPU, Meta: arange_out CUDA: arange_cuda_out + PrivateUse1: arange_zoom_out MPS: arange_mps_out cpp_no_default_args: ['step'] @@ -816,7 +824,7 @@ - func: argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) structured: True dispatch: - CPU, CUDA: argmax_out + CPU, CUDA, PrivateUse1: argmax_out MPS: argmax_out_mps - func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor @@ -828,7 +836,7 @@ - func: argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) structured: True dispatch: - CPU, CUDA: argmin_out + CPU, CUDA, PrivateUse1: argmin_out MPS: argmin_out_mps - func: acosh(Tensor self) -> Tensor @@ -845,7 +853,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: acosh_out + CPU, CUDA, PrivateUse1: acosh_out MPS: acosh_out_mps tags: pointwise # arccosh, alias for acosh @@ -878,7 +886,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: asinh_out + CPU, CUDA, PrivateUse1: asinh_out MPS: asinh_out_mps SparseCPU, SparseCUDA: asinh_sparse_out SparseCsrCPU, SparseCsrCUDA: asinh_sparse_csr_out @@ -913,7 +921,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: atanh_out + CPU, CUDA, PrivateUse1: atanh_out MPS: atanh_out_mps SparseCPU, SparseCUDA: atanh_sparse_out SparseCsrCPU, SparseCsrCUDA: atanh_sparse_csr_out @@ -931,7 +939,7 @@ - func: as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) variants: function, method dispatch: - ZeroTensor, CPU, CUDA: as_strided_tensorimpl + ZeroTensor, CPU, CUDA, PrivateUse1: as_strided_tensorimpl Meta: as_strided_tensorimpl_meta_symint MPS: as_strided_tensorimpl_mps QuantizedCPU, QuantizedCUDA: as_strided_qtensorimpl @@ -971,7 +979,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: asin_out + CPU, CUDA, PrivateUse1: asin_out MPS: asin_out_mps SparseCPU, SparseCUDA: asin_sparse_out SparseCsrCPU, SparseCsrCUDA: asin_sparse_csr_out @@ -1009,7 +1017,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: atan_out + CPU, CUDA, PrivateUse1: atan_out MPS: atan_out_mps SparseCPU, SparseCUDA: atan_sparse_out SparseCsrCPU, SparseCsrCUDA: atan_sparse_csr_out @@ -1055,6 +1063,7 @@ dispatch: CPU: baddbmm_out_cpu CUDA: baddbmm_out_cuda + PrivateUse1: baddbmm_out_hip MPS: baddbmm_out_mps SparseCsrCUDA: baddbmm_out_sparse_csr_cuda @@ -1092,7 +1101,7 @@ variants: function tags: nondeterministic_seeded dispatch: - CPU, CUDA: bernoulli_out + CPU, CUDA, PrivateUse1: bernoulli_out MPS: bernoulli_out_mps - func: bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!) 
@@ -1100,7 +1109,7 @@ variants: method tags: nondeterministic_seeded dispatch: - CPU, CUDA: bernoulli_ + CPU, CUDA, PrivateUse1: bernoulli_ MPS: bernoulli_mps_ autogen: bernoulli.Tensor, bernoulli.Tensor_out @@ -1109,7 +1118,7 @@ variants: method tags: nondeterministic_seeded dispatch: - CPU, CUDA: bernoulli_ + CPU, CUDA, PrivateUse1: bernoulli_ MPS: bernoulli_mps_ autogen: bernoulli.float_out @@ -1134,6 +1143,7 @@ dispatch: CPU: binary_cross_entropy_cpu CUDA: binary_cross_entropy_cuda + PrivateUse1: binary_cross_entropy_zoom MPS: binary_cross_entropy_mps - func: binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) @@ -1143,6 +1153,7 @@ dispatch: CPU: binary_cross_entropy_out_cpu CUDA: binary_cross_entropy_out_cuda + PrivateUse1: binary_cross_entropy_out_zoom MPS: binary_cross_entropy_out_mps - func: binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor @@ -1151,6 +1162,7 @@ dispatch: CPU: binary_cross_entropy_backward_cpu CUDA: binary_cross_entropy_backward_cuda + PrivateUse1: binary_cross_entropy_backward_zoom MPS: binary_cross_entropy_backward_mps - func: binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) @@ -1159,6 +1171,7 @@ dispatch: CPU: binary_cross_entropy_backward_out_cpu CUDA: binary_cross_entropy_backward_out_cuda + PrivateUse1: binary_cross_entropy_backward_out_zoom MPS: binary_cross_entropy_backward_out_mps - func: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor @@ -1173,6 +1186,7 @@ dispatch: CPU: _bincount_cpu CUDA: _bincount_cuda + PrivateUse1: _bincount_zoom MPS: _bincount_mps tags: dynamic_output_shape autogen: bincount.out @@ -1194,7 +1208,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: bitwise_not_out + CPU, CUDA, PrivateUse1: bitwise_not_out MPS: bitwise_not_out_mps tags: pointwise @@ -1203,7 +1217,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA, MPS: copysign_out + CPU, CUDA, PrivateUse1, MPS: copysign_out tags: pointwise - func: copysign.Tensor(Tensor self, Tensor other) -> Tensor @@ -1259,7 +1273,7 @@ - func: logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: logical_not_out + CPU, CUDA, PrivateUse1: logical_not_out MPS: logical_not_out_mps tags: pointwise @@ -1280,7 +1294,7 @@ - func: logical_xor.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: logical_xor_out + CPU, CUDA, PrivateUse1: logical_xor_out MPS: logical_xor_out_mps tags: pointwise @@ -1301,7 +1315,7 @@ - func: logical_and.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: logical_and_out + CPU, CUDA, PrivateUse1: logical_and_out MPS: logical_and_out_mps tags: pointwise @@ -1322,7 +1336,7 @@ - func: logical_or.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: logical_or_out + CPU, CUDA, PrivateUse1: logical_or_out MPS: logical_or_out_mps tags: pointwise @@ -1352,6 +1366,7 @@ dispatch: CPU: bmm_out_cpu CUDA: bmm_out_cuda + PrivateUse1: bmm_out_hip MPS: bmm_out_mps SparseCPU: bmm_out_sparse_cpu SparseCUDA: bmm_out_sparse_cuda @@ -1386,6 +1401,7 @@ dispatch: CPU: cat_out_cpu CUDA: cat_out_cuda + PrivateUse1: cat_out_zoom MPS: cat_out_mps QuantizedCPU: cat_out_quantized_cpu @@ -1440,7 +1456,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: ceil_out + CPU, CUDA, PrivateUse1: ceil_out MPS: ceil_out_mps SparseCPU, SparseCUDA: ceil_sparse_out SparseCsrCPU, SparseCsrCUDA: ceil_sparse_csr_out @@ -1511,7 +1527,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_out + CPU, CUDA, PrivateUse1: clamp_out MPS: clamp_out_mps tags: pointwise @@ -1520,7 +1536,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_Tensor_out + CPU, CUDA, PrivateUse1: clamp_Tensor_out MPS: clamp_Tensor_out_mps tags: pointwise @@ -1551,7 +1567,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_max_out + CPU, CUDA, PrivateUse1: clamp_max_out MPS: clamp_max_out_mps tags: pointwise @@ -1560,7 +1576,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_max_Tensor_out + CPU, CUDA, PrivateUse1: clamp_max_Tensor_out MPS: clamp_max_Tensor_out_mps tags: pointwise @@ -1591,7 +1607,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_min_out + CPU, CUDA, PrivateUse1: clamp_min_out MPS: clamp_min_out_mps tags: pointwise @@ -1600,7 +1616,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_min_Tensor_out + CPU, CUDA, PrivateUse1: clamp_min_Tensor_out MPS: clamp_min_Tensor_out_mps tags: pointwise @@ -1640,7 +1656,7 @@ - func: complex.out(Tensor real, Tensor imag, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: complex_out + CPU, CUDA, PrivateUse1: complex_out MPS: complex_out_mps - func: polar(Tensor abs, Tensor angle) -> Tensor @@ -1650,7 +1666,7 @@ - func: polar.out(Tensor abs, Tensor angle, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: polar_out + CPU, CUDA, PrivateUse1: polar_out MPS: polar_out_mps - func: constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor @@ -1797,7 +1813,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: cos_out + CPU, CUDA, PrivateUse1: cos_out MPS: cos_out_mps tags: pointwise @@ -1818,7 +1834,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: cosh_out + CPU, CUDA, PrivateUse1: cosh_out MPS: cosh_out_mps tags: pointwise @@ -1933,6 +1949,7 @@ dispatch: CPU: cummax_helper_cpu CUDA: cummax_helper_cuda + PrivateUse1: cummax_helper_zoom - func: cummin(Tensor self, int dim) -> (Tensor values, Tensor indices) device_check: NoCheck # TensorIterator @@ -1957,6 +1974,7 @@ dispatch: CPU: cummin_helper_cpu CUDA: cummin_helper_cuda + PrivateUse1: cummin_helper_zoom - func: cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor variants: function @@ -1976,7 +1994,7 @@ structured: True device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: cumprod_out + CPU, CUDA, PrivateUse1: cumprod_out MPS: cumprod_out_mps - func: cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? 
dtype=None) -> Tensor @@ -2008,7 +2026,7 @@ structured: True device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: cumsum_out + CPU, CUDA, PrivateUse1: cumsum_out MPS: cumsum_out_mps - func: cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor @@ -2034,13 +2052,14 @@ dispatch: CPU: ctc_loss_cpu CUDA: ctc_loss_gpu + PrivateUse1: ctc_loss_gpu Meta: ctc_loss_meta autogen: _ctc_loss.out tags: dynamic_output_shape # the shape of second output is data dependent - func: _ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) dispatch: - CPU, CUDA: ctc_loss_tensor + CPU, CUDA, PrivateUse1: ctc_loss_tensor autogen: _ctc_loss.Tensor_out tags: dynamic_output_shape # the shape of second output is data dependent @@ -2048,11 +2067,12 @@ dispatch: CPU: ctc_loss_backward_cpu CUDA: ctc_loss_backward_gpu + PrivateUse1: ctc_loss_backward_gpu autogen: _ctc_loss_backward.out - func: _ctc_loss_backward.Tensor(Tensor grad, Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor dispatch: - CPU, CUDA: ctc_loss_backward_tensor + CPU, CUDA, PrivateUse1: ctc_loss_backward_tensor - func: diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor variants: function, method @@ -2137,7 +2157,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: div_out + CPU, CUDA, PrivateUse1: div_out MPS: div_out_mps SparseCPU, SparseCUDA: div_out_sparse_zerodim tags: pointwise @@ -2163,7 +2183,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: div_out_mode + CPU, CUDA, PrivateUse1: div_out_mode MPS: div_out_mode_mps SparseCPU, SparseCUDA: div_out_sparse_zerodim tags: pointwise @@ -2253,6 +2273,7 @@ dispatch: CPU: dot CUDA: dot_cuda + PrivateUse1: dot_hip MPS: dot_mps - func: dot.out(Tensor self, Tensor tensor, *, Tensor(a!) out) -> Tensor(a!) @@ -2264,6 +2285,7 @@ dispatch: CPU: vdot CUDA: vdot_cuda + PrivateUse1: vdot_hip - func: vdot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -2286,6 +2308,7 @@ dispatch: CPU: embedding_dense_backward_cpu CUDA: embedding_dense_backward_cuda + PrivateUse1: embedding_dense_backward_zoom MPS: embedding_dense_backward_mps autogen: embedding_dense_backward.out tags: core @@ -2293,7 +2316,8 @@ - func: embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!) 
dispatch: CPU: embedding_renorm_cpu_ - CUDA: embedding_renorm_cuda_ + CUDA: embedding_renorm_cuda_ + PrivateUse1: embedding_renorm_zoom_ autogen: embedding_renorm, embedding_renorm.out - func: embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor @@ -2312,6 +2336,7 @@ dispatch: CPU: _embedding_bag_forward_only_cpu CUDA: _embedding_bag_forward_only_cuda + PrivateUse1: _embedding_bag_forward_only_zoom autogen: _embedding_bag_forward_only.out - func: _rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor) @@ -2333,6 +2358,7 @@ dispatch: CPU: _embedding_bag_cpu CUDA: _embedding_bag_cuda + PrivateUse1: _embedding_bag_zoom autogen: _embedding_bag.out tags: core @@ -2348,12 +2374,14 @@ dispatch: CPU: _embedding_bag_dense_backward_cpu CUDA: _embedding_bag_dense_backward_cuda + PrivateUse1: _embedding_bag_dense_backward_zoom autogen: _embedding_bag_dense_backward.out - func: _embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1) -> Tensor dispatch: CPU: _embedding_bag_per_sample_weights_backward_cpu CUDA: _embedding_bag_per_sample_weights_backward_cuda + PrivateUse1: _embedding_bag_per_sample_weights_backward_zoom autogen: _embedding_bag_per_sample_weights_backward.out - func: empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor @@ -2367,6 +2395,7 @@ dispatch: CPU: empty_cpu CUDA: empty_cuda + PrivateUse1: empty_zoom MPS: empty_mps Meta: empty_meta_symint MkldnnCPU: empty_mkldnn @@ -2444,6 +2473,7 @@ Meta: resize__symint CPU: resize_ CUDA: resize_cuda_ + PrivateUse1: resize_zoom_ MPS: resize_mps_ QuantizedCPU: quantized_resize_cpu_ SparseCsrCPU, SparseCsrCUDA: resize_sparse_csr_ @@ -2485,6 +2515,7 @@ dispatch: CPU: empty_strided_cpu CUDA: empty_strided_cuda + PrivateUse1: empty_strided_zoom MPS: empty_strided_mps Meta: empty_strided_meta_symint QuantizedCPU, QuantizedCUDA: empty_strided_unknown_quantized @@ -2514,7 +2545,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: erf_out + CPU, CUDA, PrivateUse1: erf_out MPS: erf_out_mps SparseCPU, SparseCUDA: erf_sparse_out SparseCsrCPU, SparseCsrCUDA: erf_sparse_csr_out @@ -2537,7 +2568,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: erfc_out + CPU, CUDA, PrivateUse1: erfc_out tags: pointwise - func: exp(Tensor self) -> Tensor @@ -2557,7 +2588,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: exp_out + CPU, CUDA, PrivateUse1: exp_out MPS: exp_out_mps tags: pointwise @@ -2575,7 +2606,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: exp2_out + CPU, CUDA, PrivateUse1: exp2_out MPS: exp2_out_mps tags: pointwise @@ -2602,7 +2633,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: expm1_out + CPU, CUDA, PrivateUse1: expm1_out MPS: expm1_out_mps SparseCPU, SparseCUDA: expm1_sparse_out SparseCsrCPU, SparseCsrCUDA: expm1_sparse_csr_out @@ -2634,12 +2665,14 @@ dispatch: CPU, Meta: eye_out_cpu CUDA: eye_out_cuda + PrivateUse1: eye_out_zoom MPS: eye_out_mps - func: eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!)
dispatch: CPU, Meta: eye_out_cpu CUDA: eye_out_cuda + PrivateUse1: eye_out_zoom MPS: eye_out_mps - func: flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a) @@ -2679,7 +2712,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: fill_ + CPU, CUDA, PrivateUse1: fill_ MPS: fill_scalar_mps QuantizedCPU, QuantizedCUDA: fill_quantized_ Meta: fill_meta_ @@ -2691,7 +2724,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: fill_ + CPU, CUDA, PrivateUse1: fill_ MPS: fill_tensor_mps_ QuantizedCPU, QuantizedCUDA: fill_quantized_ Meta: fill_meta_ @@ -2721,7 +2754,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: floor_out + CPU, CUDA, PrivateUse1: floor_out MPS: floor_out_mps SparseCPU, SparseCUDA: floor_sparse_out SparseCsrCPU, SparseCsrCUDA: floor_sparse_csr_out @@ -2731,7 +2764,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: floor_divide + CPU, CUDA, PrivateUse1: floor_divide MPS: floor_divide_mps SparseCPU, SparseCUDA: floor_divide_sparse @@ -2739,14 +2772,14 @@ device_check: NoCheck # TensorIterator variants: method dispatch: - CPU, CUDA: floor_divide_ + CPU, CUDA, PrivateUse1: floor_divide_ MPS: floor_divide_mps_ SparseCPU, SparseCUDA: floor_divide_sparse_ - func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: floor_divide_out + CPU, CUDA, PrivateUse1: floor_divide_out MPS: floor_divide_out_mps SparseCPU, SparseCUDA: floor_divide_out_sparse_zerodim @@ -2786,7 +2819,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: frac_out + CPU, CUDA, PrivateUse1: frac_out MPS: frac_out_mps SparseCPU, SparseCUDA: frac_sparse_out SparseCsrCPU, SparseCsrCUDA: frac_sparse_csr_out @@ -2824,7 +2857,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: gcd_out + CPU, CUDA, PrivateUse1: gcd_out tags: pointwise - func: gcd(Tensor self, Tensor other) -> Tensor @@ -2840,7 +2873,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: lcm_out + CPU, CUDA, PrivateUse1: lcm_out tags: pointwise - func: lcm(Tensor self, Tensor other) -> Tensor @@ -2875,6 +2908,7 @@ dispatch: CPU, QuantizedCPU: grid_sampler_2d_cpu CUDA: grid_sampler_2d_cuda + PrivateUse1: grid_sampler_2d_zoom MPS: grid_sampler_2d_mps autogen: grid_sampler_2d.out tags: core @@ -2886,6 +2920,7 @@ dispatch: CPU: grid_sampler_2d_backward_cpu CUDA: grid_sampler_2d_backward_cuda + PrivateUse1: grid_sampler_2d_backward_zoom autogen: grid_sampler_2d_backward.out # See NOTE [ grid_sample CPU fallback ] @@ -2900,6 +2935,7 @@ dispatch: CPU: grid_sampler_3d_cpu CUDA: grid_sampler_3d_cuda + PrivateUse1: grid_sampler_3d_zoom autogen: grid_sampler_3d.out # `grid_sampler_3d_backward` takes in `output_mask` to optimize performance for @@ -2909,6 +2945,7 @@ dispatch: CPU: grid_sampler_3d_backward_cpu CUDA: grid_sampler_3d_backward_cuda + PrivateUse1: grid_sampler_3d_backward_zoom autogen: grid_sampler_3d_backward.out - func: hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -2962,14 +2999,14 @@ - func: native_group_norm(Tensor input, Tensor? weight, Tensor? 
bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) dispatch: - CPU, CUDA: native_group_norm + CPU, CUDA, PrivateUse1: native_group_norm CompositeExplicitAutograd: math_group_norm autogen: native_group_norm.out tags: core - func: native_group_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask) -> (Tensor, Tensor, Tensor) dispatch: - CPU, CUDA: native_group_norm_backward + CPU, CUDA, PrivateUse1: native_group_norm_backward autogen: native_group_norm_backward.out tags: core @@ -3024,6 +3061,7 @@ dispatch: CPU: _validate_compressed_sparse_indices_cpu CUDA: _validate_compressed_sparse_indices_cuda + PrivateUse1: _validate_compressed_sparse_indices_zoom - func: _cufft_get_plan_cache_size(DeviceIndex device_index) -> int @@ -3052,7 +3090,7 @@ precomputed: - indices -> DimVector sizes, DimVector strides dispatch: - CPU, CUDA, MPS: index_out + CPU, CUDA, MPS, PrivateUse1: index_out # Used by inductor to signal indexing without bounds checks # Note that we don't support boolean indexing, to avoid dynamic output shapes @@ -3067,7 +3105,7 @@ precomputed: - dim -> int dim dispatch: - CPU, CUDA: index_copy_out + CPU, CUDA, PrivateUse1: index_copy_out - func: index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) variants: method @@ -3112,7 +3150,7 @@ device_check: NoCheck # TensorIterator variants: function dispatch: - CPU, CUDA, MPS: _index_put_impl_ + CPU, CUDA, PrivateUse1, MPS: _index_put_impl_ QuantizedCPU: _index_put_impl_quantized_cpu_ QuantizedCUDA: _index_put_impl_quantized_cuda_ autogen: _index_put_impl, _index_put_impl.out @@ -3127,7 +3165,7 @@ variants: function structured: True dispatch: - CPU, CUDA: isin_Tensor_Tensor_out + CPU, CUDA, PrivateUse1: isin_Tensor_Tensor_out MPS: isin_Tensor_Tensor_out_mps - func: isin.Tensor_Tensor(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor @@ -3138,7 +3176,7 @@ variants: function structured: True dispatch: - CPU, CUDA: isin_Tensor_Scalar_out + CPU, CUDA, PrivateUse1: isin_Tensor_Scalar_out - func: isin.Tensor_Scalar(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False) -> Tensor variants: function @@ -3148,7 +3186,7 @@ variants: function structured: True dispatch: - CPU, CUDA: isin_Scalar_Tensor_out + CPU, CUDA, PrivateUse1: isin_Scalar_Tensor_out - func: isin.Scalar_Tensor(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor variants: function @@ -3159,7 +3197,7 @@ device_check: NoCheck device_guard: False dispatch: - CPU, CUDA, MPS: isnan + CPU, CUDA, MPS, PrivateUse1: isnan SparseCPU, SparseCUDA: isnan_sparse SparseCsrCPU, SparseCsrCUDA: isnan_sparse_csr autogen: isnan.out @@ -3241,6 +3279,7 @@ dispatch: CPU: kthvalue_out_cpu CUDA: kthvalue_out_cuda + PrivateUse1: kthvalue_out_zoom - func: kthvalue.dimname(Tensor self, int k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) variants: function, method @@ -3255,6 +3294,7 @@ dispatch: CPU: layer_norm_cpu CUDA: layer_norm_cuda + PrivateUse1: layer_norm_zoom MPS: layer_norm_mps CompositeExplicitAutograd: math_native_layer_norm NestedTensorCPU, NestedTensorCUDA: nested_layer_norm @@ -3265,6 +3305,7 @@ dispatch: CPU: layer_norm_backward_cpu CUDA: layer_norm_backward_cuda + PrivateUse1: layer_norm_backward_zoom MPS: layer_norm_backward_mps NestedTensorCPU, NestedTensorCUDA: layer_norm_backward_nested autogen: 
native_layer_norm_backward.out @@ -3288,7 +3329,7 @@ - func: nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: nan_to_num_out + CPU, CUDA, PrivateUse1: nan_to_num_out MPS: nan_to_num_out_mps SparseCPU, SparseCUDA: nan_to_num_sparse_out tags: pointwise @@ -3422,6 +3463,7 @@ dispatch: CPU, Meta: linspace_out CUDA: linspace_cuda_out + PrivateUse1: linspace_zoom_out MPS: linspace_out_mps - func: linspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!) @@ -3456,7 +3498,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: log_out + CPU, CUDA, PrivateUse1: log_out MPS: log_out_mps tags: pointwise @@ -3477,7 +3519,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: log10_out + CPU, CUDA, PrivateUse1: log10_out MPS: log10_out_mps tags: pointwise @@ -3504,7 +3546,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: log1p_out + CPU, CUDA, PrivateUse1: log1p_out MPS: log1p_out_mps SparseCPU, SparseCUDA: log1p_sparse_out SparseCsrCPU, SparseCsrCUDA: log1p_sparse_csr_out @@ -3527,7 +3569,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: log2_out + CPU, CUDA, PrivateUse1: log2_out MPS: log2_out_mps tags: pointwise @@ -3535,7 +3577,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: logaddexp_out + CPU, CUDA, PrivateUse1: logaddexp_out MPS: logaddexp_out_mps tags: pointwise @@ -3548,7 +3590,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: logaddexp2_out + CPU, CUDA, PrivateUse1: logaddexp2_out MPS: logaddexp2_out_mps tags: pointwise @@ -3597,7 +3639,7 @@ structured_inherits: TensorIteratorBase variants: function dispatch: - CPU, CUDA: xlogy_out + CPU, CUDA, PrivateUse1: xlogy_out MPS: xlogy_out_mps tags: pointwise @@ -3638,6 +3680,7 @@ dispatch: CPU, Meta: logspace_out CUDA: logspace_cuda_out + PrivateUse1: logspace_zoom_out - func: logspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) category_override: factory @@ -3675,6 +3718,7 @@ dispatch: CPU: log_softmax_cpu_out CUDA: log_softmax_cuda_out + PrivateUse1: log_softmax_zoom_out MPS: log_softmax_mps_out - func: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor @@ -3685,17 +3729,20 @@ dispatch: CPU: log_softmax_backward_cpu_out CUDA: log_softmax_backward_cuda_out + PrivateUse1: log_softmax_backward_zoom_out MPS: log_softmax_backward_mps_out - func: _logcumsumexp(Tensor self, int dim) -> Tensor dispatch: CPU: _logcumsumexp_cpu CUDA: _logcumsumexp_cuda + PrivateUse1: _logcumsumexp_zoom - func: _logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: _logcumsumexp_out_cpu CUDA: _logcumsumexp_out_cuda + PrivateUse1: _logcumsumexp_out_zoom - func: logcumsumexp(Tensor self, int dim) -> Tensor variants: function, method @@ -3783,16 +3830,16 @@ device_check: NoCheck # TensorIterator structured: True dispatch: - CPU, CUDA: aminmax_out + CPU, CUDA, PrivateUse1: aminmax_out MPS: aminmax_out_mps - func: _compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor dispatch: - CPU, CUDA: _compute_linear_combination + CPU, CUDA, PrivateUse1: _compute_linear_combination - func: _compute_linear_combination.out(Tensor input, Tensor coefficients, *, Tensor(a!) out) -> Tensor(a!) 
dispatch: - CPU, CUDA: _compute_linear_combination_out + CPU, CUDA, PrivateUse1: _compute_linear_combination_out - func: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) device_check: NoCheck # TensorIterator @@ -3808,7 +3855,7 @@ precomputed: - dim -> int dim dispatch: - CPU, CUDA: max_out + CPU, CUDA, PrivateUse1: max_out MPS: max_out_mps - func: max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) @@ -3833,7 +3880,7 @@ - func: amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) structured: True dispatch: - CPU, CUDA: amax_out + CPU, CUDA, PrivateUse1: amax_out MPS: amax_out_mps # Return: (Tensor output, Tensor indices) @@ -3917,7 +3964,7 @@ structured: True device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: mean_out + CPU, CUDA, PrivateUse1: mean_out MPS: mean_out_mps QuantizedCPU: mean_out_quantized_cpu @@ -3940,6 +3987,7 @@ dispatch: CPU: median_cpu CUDA: median_cuda + PrivateUse1: median_zoom MPS: median_mps autogen: median.out @@ -3952,6 +4000,7 @@ dispatch: CPU: median_out_cpu CUDA: median_out_cuda + PrivateUse1: median_out_zoom MPS: median_out_mps - func: median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) @@ -3964,6 +4013,7 @@ dispatch: CPU: nanmedian_cpu CUDA: nanmedian_cuda + PrivateUse1: nanmedian_zoom autogen: nanmedian.out - func: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) @@ -3975,6 +4025,7 @@ dispatch: CPU: nanmedian_out_cpu CUDA: nanmedian_out_cuda + PrivateUse1: nanmedian_out_zoom - func: nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) variants: function, method @@ -3995,7 +4046,7 @@ precomputed: - dim -> int dim dispatch: - CPU, CUDA: min_out + CPU, CUDA, PrivateUse1: min_out MPS: min_out_mps - func: min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) @@ -4013,7 +4064,7 @@ - func: amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) structured: True dispatch: - CPU, CUDA: amin_out + CPU, CUDA, PrivateUse1: amin_out MPS: amin_out_mps # TODO: Add this function to MPS dispatch key so that we avoid declaring it in @@ -4103,6 +4154,7 @@ dispatch: CPU: mm_out_cpu CUDA: mm_out_cuda + PrivateUse1: mm_out_hip MPS: mm_out_mps SparseCPU, SparseCUDA: _sparse_mm_out SparseCsrCPU, SparseCsrCUDA: _sparse_csr_mm_out @@ -4111,11 +4163,13 @@ dispatch: CPU: _int_mm_cpu CUDA: _int_mm_cuda + PrivateUse1: _int_mm_hip - func: _int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: _int_mm_out_cpu CUDA: _int_mm_out_cuda + PrivateUse1: _int_mm_out_hip - func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor dispatch: @@ -4148,7 +4202,7 @@ - func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) variants: function, method dispatch: - CPU, CUDA: mode + CPU, CUDA, PrivateUse1: mode - func: mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) dispatch: @@ -4187,7 +4241,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: mul_out + CPU, CUDA, PrivateUse1: mul_out MPS: mul_out_mps SparseCPU: mul_out_sparse_cpu SparseCUDA: mul_out_sparse_cuda @@ -4290,12 +4344,14 @@ dispatch: CPU: batch_norm_cpu CUDA: batch_norm_cuda + PrivateUse1: batch_norm_zoom MPS: batch_norm_mps MkldnnCPU: mkldnn_batch_norm - func: native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) dispatch: CUDA: batch_norm_cuda_out + PrivateUse1: batch_norm_zoom_out MPS: batch_norm_mps_out CPU: batch_norm_cpu_out @@ -4304,6 +4360,7 @@ dispatch: CPU: _batch_norm_legit_cpu CUDA: _batch_norm_legit_cuda + PrivateUse1: _batch_norm_legit_zoom MPS: _batch_norm_legit_mps MkldnnCPU: _mkldnn_batch_norm_legit autogen: _native_batch_norm_legit_functional @@ -4322,12 +4379,14 @@ dispatch: CPU: _batch_norm_legit_cpu_out CUDA: _batch_norm_legit_cuda_out + PrivateUse1: _batch_norm_legit_zoom_out MPS: _batch_norm_legit_mps_out - func: _native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) dispatch: CPU: _batch_norm_legit_no_stats_cpu CUDA: _batch_norm_legit_no_stats_cuda + PrivateUse1: _batch_norm_legit_no_stats_zoom MPS: _batch_norm_legit_no_stats_mps MkldnnCPU: _mkldnn_batch_norm_legit_no_stats tags: core @@ -4336,36 +4395,43 @@ dispatch: CPU: _batch_norm_legit_no_stats_cpu_out CUDA: _batch_norm_legit_no_stats_cuda_out + PrivateUse1: _batch_norm_legit_no_stats_zoom_out MPS: _batch_norm_legit_no_stats_mps_out - func: batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor) dispatch: CUDA: batch_norm_stats_cuda + PrivateUse1: batch_norm_stats_zoom autogen: batch_norm_stats.out - func: batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor dispatch: CUDA: batch_norm_elemt_cuda + PrivateUse1: batch_norm_elemt_zoom - func: batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) out) -> Tensor(a!) dispatch: CUDA: batch_norm_elemt_cuda_out + PrivateUse1: batch_norm_elemt_zoom_out # for backward compatibility - func: batch_norm_gather_stats(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count) -> (Tensor, Tensor) dispatch: CUDA: batch_norm_gather_stats_cuda + PrivateUse1: batch_norm_gather_stats_zoom autogen: batch_norm_gather_stats.out - func: batch_norm_gather_stats_with_counts(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts) -> (Tensor, Tensor) dispatch: CUDA: batch_norm_gather_stats_with_counts_cuda + PrivateUse1: batch_norm_gather_stats_with_counts_zoom autogen: batch_norm_gather_stats_with_counts.out - func: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? 
save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor) dispatch: CPU: batch_norm_backward_cpu CUDA: batch_norm_backward_cuda + PrivateUse1: batch_norm_backward_zoom MPS: batch_norm_backward_mps MkldnnCPU: mkldnn_batch_norm_backward autogen: native_batch_norm_backward.out @@ -4373,17 +4439,20 @@ - func: batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor) dispatch: CUDA: batch_norm_backward_reduce_cuda + PrivateUse1: batch_norm_backward_reduce_zoom autogen: batch_norm_backward_reduce.out - func: batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor dispatch: CUDA: batch_norm_backward_elemt_cuda + PrivateUse1: batch_norm_backward_elemt_zoom autogen: batch_norm_backward_elemt.out - func: batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum) -> (Tensor, Tensor) dispatch: CPU: batch_norm_update_stats_cpu CUDA: batch_norm_update_stats_cuda + PrivateUse1: batch_norm_update_stats_zoom autogen: batch_norm_update_stats.out - func: is_vulkan_available() -> bool @@ -4430,27 +4499,27 @@ - func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor dispatch: - CPU, CUDA: _cdist_forward + CPU, CUDA, PrivateUse1: _cdist_forward MPS: _cdist_forward_mps autogen: _cdist_forward.out tags: core - func: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor dispatch: - CPU, CUDA: _cdist_backward + CPU, CUDA, PrivateUse1: _cdist_backward autogen: _cdist_backward.out - func: pdist(Tensor self, float p=2) -> Tensor - func: _pdist_forward(Tensor self, float p=2) -> Tensor dispatch: - CPU, CUDA: _pdist_forward + CPU, CUDA, PrivateUse1: _pdist_forward autogen: _pdist_forward.out tags: core - func: _pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor dispatch: - CPU, CUDA: _pdist_backward + CPU, CUDA, PrivateUse1: _pdist_backward autogen: _pdist_backward.out - func: cosine_similarity(Tensor x1, Tensor x2, int dim=1, float eps=1e-08) -> Tensor @@ -4531,6 +4600,7 @@ variants: method dispatch: NestedTensorCUDA, CUDA: is_pinned_cuda + PrivateUse1: is_pinned_zoom MPS: is_pinned_mps CompositeExplicitAutograd: is_pinned_default @@ -4543,6 +4613,7 @@ - func: _pin_memory(Tensor self, Device? device=None) -> Tensor dispatch: CUDA: _pin_memory_cuda + PrivateUse1: _pin_memory_zoom MPS: _pin_memory_mps NestedTensorCUDA, NestedTensorCPU: _pin_memory_nested autogen: _pin_memory.out @@ -4760,6 +4831,7 @@ dispatch: CPU: randperm_out_cpu CUDA: randperm_out_cuda + PrivateUse1: randperm_out_zoom MPS: randperm_out_mps - func: range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor @@ -4778,6 +4850,7 @@ dispatch: CPU, Meta: range_out CUDA: range_cuda_out + PrivateUse1: range_zoom_out MPS: range_mps_out cpp_no_default_args: ['step'] @@ -4801,7 +4874,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: reciprocal_out + CPU, CUDA, PrivateUse1: reciprocal_out MPS: reciprocal_out_mps tags: pointwise @@ -4830,7 +4903,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: neg_out + CPU, CUDA, PrivateUse1: neg_out MPS: neg_out_mps SparseCPU, SparseCUDA: neg_out_sparse SparseCsrCPU, SparseCsrCUDA: neg_sparse_csr_out @@ -4858,6 +4931,7 @@ dispatch: CPU: repeat_interleave_cpu CUDA: repeat_interleave_cuda + PrivateUse1: repeat_interleave_zoom MPS: repeat_interleave_mps tags: dynamic_output_shape autogen: repeat_interleave.Tensor_out @@ -4893,7 +4967,7 @@ device_check: NoCheck device_guard: False dispatch: - CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA, ZeroTensor, MPS: _reshape_alias + CPU, CUDA, PrivateUse1, Meta, QuantizedCPU, QuantizedCUDA, ZeroTensor, MPS: _reshape_alias # We don't need to support mkldnn since this is handled explicitly by the reshape operator. - func: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor @@ -4935,7 +5009,7 @@ structured_inherits: TensorIteratorBase dispatch: CPU: round_out - CUDA: round_out + CUDA, PrivateUse1: round_out MPS: round_out_mps SparseCPU, SparseCUDA: round_sparse_out SparseCsrCPU, SparseCsrCUDA: round_sparse_csr_out @@ -4959,7 +5033,7 @@ structured_inherits: TensorIteratorBase dispatch: CPU: round_decimals_out - CUDA: round_decimals_out + CUDA, PrivateUse1: round_decimals_out tags: pointwise - func: rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor @@ -4974,7 +5048,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: relu + CPU, CUDA, PrivateUse1: relu MPS: relu_mps MkldnnCPU: mkldnn_relu QuantizedCPU: relu_quantized_cpu @@ -4988,7 +5062,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: relu_ + CPU, CUDA, PrivateUse1: relu_ MPS: relu_mps_ MkldnnCPU: mkldnn_relu_ QuantizedCPU: relu_quantized_cpu_ @@ -5011,14 +5085,14 @@ - func: _prelu_kernel(Tensor self, Tensor weight) -> Tensor dispatch: - CPU, CUDA: _prelu_kernel + CPU, CUDA, PrivateUse1: _prelu_kernel QuantizedCPU: _prelu_kernel_quantized_cpu MkldnnCPU: mkldnn_prelu MPS: prelu_mps - func: _prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) dispatch: - CPU, CUDA: _prelu_kernel_backward + CPU, CUDA, PrivateUse1: _prelu_kernel_backward MkldnnCPU: mkldnn_prelu_backward MPS: prelu_backward_mps @@ -5030,6 +5104,7 @@ dispatch: CPU: gelu_out_cpu CUDA: gelu_out_cuda + PrivateUse1: gelu_out_zoom MPS: gelu_out_mps - func: gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!) 
@@ -5058,6 +5133,7 @@ dispatch: CPU: gelu_backward_out_cpu CUDA: gelu_backward_out_cuda + PrivateUse1: gelu_backward_out_zoom MPS: gelu_backward_out_mps - func: gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor @@ -5079,7 +5155,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: hardshrink_out + CPU, CUDA, PrivateUse1: hardshrink_out - func: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor structured_delegate: hardshrink.out @@ -5090,7 +5166,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: hardshrink_backward_out + CPU, CUDA, PrivateUse1: hardshrink_backward_out - func: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor structured_delegate: hardshrink_backward.grad_input @@ -5113,7 +5189,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: rsqrt_out + CPU, CUDA, PrivateUse1: rsqrt_out MPS: rsqrt_out_mps tags: pointwise @@ -5183,7 +5259,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: silu_out + CPU, CUDA, PrivateUse1: silu_out MPS: silu_out_mps tags: pointwise @@ -5192,7 +5268,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: silu_backward_out + CPU, CUDA, PrivateUse1: silu_backward_out MPS: silu_backward_out_mps tags: pointwise @@ -5217,13 +5293,13 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: mish_out + CPU, CUDA, PrivateUse1: mish_out MPS: mish_out_mps - func: mish_backward(Tensor grad_output, Tensor self) -> Tensor python_module: nn dispatch: - CPU, CUDA: mish_backward + CPU, CUDA, PrivateUse1: mish_backward MPS: mish_backward_mps CompositeImplicitAutograd: math_mish_backward @@ -5249,26 +5325,26 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: sigmoid_out + CPU, CUDA, PrivateUse1: sigmoid_out MPS: sigmoid_out_mps tags: pointwise - func: logit(Tensor self, float? eps=None) -> Tensor variants: function, method dispatch: - CPU, CUDA: logit + CPU, CUDA, PrivateUse1: logit MPS: logit_mps tags: pointwise - func: logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!) variants: function, method dispatch: - CPU, CUDA: logit_ + CPU, CUDA, PrivateUse1: logit_ tags: pointwise - func: logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!) 
dispatch: - CPU, CUDA: logit_out + CPU, CUDA, PrivateUse1: logit_out MPS: logit_out_mps tags: pointwise @@ -5296,7 +5372,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: sin_out + CPU, CUDA, PrivateUse1: sin_out MPS: sin_out_mps SparseCsrCPU, SparseCsrCUDA: sin_sparse_csr_out SparseCPU, SparseCUDA: sin_sparse_out @@ -5316,7 +5392,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: sinc_out + CPU, CUDA, PrivateUse1: sinc_out tags: pointwise - func: sinh(Tensor self) -> Tensor @@ -5342,7 +5418,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: sinh_out + CPU, CUDA, PrivateUse1: sinh_out MPS: sinh_out_mps SparseCPU, SparseCUDA: sinh_sparse_out SparseCsrCPU, SparseCsrCUDA: sinh_sparse_csr_out @@ -5501,6 +5577,7 @@ dispatch: CPU: softmax_cpu_out CUDA: softmax_cuda_out + PrivateUse1: softmax_zoom_out MPS: softmax_mps_out - func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor @@ -5513,6 +5590,7 @@ dispatch: CPU: softmax_backward_cpu_out CUDA: softmax_backward_cuda_out + PrivateUse1: softmax_backward_zoom_out MPS: softmax_backward_mps_out - func: unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] @@ -5743,7 +5821,7 @@ structured: True device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: sum_out + CPU, CUDA, PrivateUse1: sum_out MPS: sum_out_mps - func: sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) @@ -5757,12 +5835,12 @@ - func: nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method dispatch: - CPU, CUDA: nansum + CPU, CUDA, PrivateUse1: nansum MPS: nansum_mps - func: nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: nansum_out + CPU, CUDA, PrivateUse1: nansum_out MPS: nansum_out_mps - func: sum_to_size(Tensor self, SymInt[] size) -> Tensor @@ -5795,7 +5873,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: sqrt_out + CPU, CUDA, PrivateUse1: sqrt_out MPS: sqrt_out_mps SparseCPU, SparseCUDA: sqrt_sparse_out SparseCsrCPU, SparseCsrCUDA: sqrt_sparse_csr_out @@ -5828,7 +5906,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: std + CPU, CUDA, PrivateUse1: std MPS: std_mps QuantizedCPU: std_quantized_cpu @@ -5846,7 +5924,7 @@ device_check: NoCheck # TensorIterator variants: function dispatch: - CPU, CUDA: std_mean + CPU, CUDA, PrivateUse1: std_mean MPS: std_mean_mps autogen: std_mean.correction_out @@ -5866,7 +5944,7 @@ - func: std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) 
device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: std_out + CPU, CUDA, PrivateUse1: std_out QuantizedCPU: std_out_quantized_cpu - func: std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor @@ -5890,7 +5968,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: prod + CPU, CUDA, PrivateUse1: prod MPS: prod_mps autogen: prod.out tags: core @@ -5905,7 +5983,7 @@ structured: True device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: prod_out + CPU, CUDA, PrivateUse1: prod_out MPS: prod_out_mps - func: prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor @@ -5953,7 +6031,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: tan_out + CPU, CUDA, PrivateUse1: tan_out MPS: tan_out_mps SparseCPU, SparseCUDA: tan_sparse_out SparseCsrCPU, SparseCsrCUDA: tan_sparse_csr_out @@ -5987,7 +6065,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: tanh_out + CPU, CUDA, PrivateUse1: tanh_out MPS: tanh_out_mps SparseCPU, SparseCUDA: tanh_sparse_out SparseCsrCPU, SparseCsrCUDA: tanh_sparse_csr_out @@ -6017,14 +6095,14 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: threshold_out + CPU, CUDA, PrivateUse1: threshold_out MPS: threshold_out_mps - func: threshold_backward.grad_input(Tensor grad_output, Tensor self, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!) structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: threshold_backward_out + CPU, CUDA, PrivateUse1: threshold_backward_out MPS: threshold_backward_out_mps SparseCPU, SparseCUDA: threshold_backward_sparse_out SparseCsrCPU, SparseCsrCUDA: threshold_backward_sparse_compressed_out @@ -6086,7 +6164,7 @@ - func: flip(Tensor self, int[] dims) -> Tensor variants: function, method dispatch: - CPU, QuantizedCPU, CUDA, QuantizedCUDA: flip + CPU, QuantizedCPU, CUDA, QuantizedCUDA, PrivateUse1: flip MPS: flip_mps autogen: flip.out tags: core @@ -6102,6 +6180,7 @@ dispatch: CPU, MPS: roll CUDA: roll_cuda + PrivateUse1: roll_zoom autogen: roll.out # default int[] value [0,1] should not add space after comma, since codegen parser uses ', ' to split args @@ -6268,7 +6347,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: trunc_out + CPU, CUDA, PrivateUse1: trunc_out MPS: trunc_out_mps SparseCPU, SparseCUDA: trunc_sparse_out SparseCsrCPU, SparseCsrCUDA: trunc_sparse_csr_out @@ -6294,6 +6373,7 @@ dispatch: CPU: _unique_cpu CUDA: _unique_cuda + PrivateUse1: _unique_zoom autogen: _unique.out - func: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) @@ -6301,6 +6381,7 @@ dispatch: CPU: unique_dim_cpu CUDA: unique_dim_cuda + PrivateUse1: unique_dim_zoom tags: dynamic_output_shape autogen: unique_dim.out @@ -6309,6 +6390,7 @@ dispatch: CPU: unique_consecutive_cpu CUDA: unique_consecutive_cuda + PrivateUse1: unique_consecutive_zoom MPS: unique_consecutive_mps tags: dynamic_output_shape autogen: unique_consecutive.out @@ -6318,6 +6400,7 @@ dispatch: CPU: unique_dim_consecutive_cpu CUDA: unique_dim_consecutive_cuda + PrivateUse1: unique_dim_consecutive_zoom MPS: unique_dim_consecutive_mps tags: dynamic_output_shape autogen: unique_dim_consecutive.out @@ -6331,6 +6414,7 @@ dispatch: CPU: _unique2_cpu CUDA: _unique2_cuda + PrivateUse1: _unique2_zoom MPS: _unique2_mps tags: 
dynamic_output_shape autogen: _unique2.out @@ -6376,7 +6460,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: var + CPU, CUDA, PrivateUse1: var MPS: var_mps tags: core @@ -6387,7 +6471,7 @@ - func: var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: var_out + CPU, CUDA, PrivateUse1: var_out - func: var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor device_check: NoCheck # TensorIterator @@ -6420,7 +6504,7 @@ device_check: NoCheck # TensorIterator variants: function dispatch: - CPU, CUDA: var_mean + CPU, CUDA, PrivateUse1: var_mean MPS: var_mean_mps autogen: var_mean.correction_out @@ -6442,13 +6526,13 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA, MPS: where + CPU, CUDA, MPS, PrivateUse1: where tags: [core, pointwise] - func: where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA, MPS: where_self_out + CPU, CUDA, MPS, PrivateUse1: where_self_out - func: where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor variants: function @@ -6476,6 +6560,7 @@ dispatch: CPU: weight_norm_cpu CUDA: weight_norm_cuda + PrivateUse1: weight_norm_zoom MPS: weight_norm_mps autogen: _weight_norm_interface.out @@ -6484,6 +6569,7 @@ dispatch: CPU: weight_norm_backward_cpu CUDA: weight_norm_backward_cuda + PrivateUse1: weight_norm_backward_zoom MPS: weight_norm_backward_mps autogen: _weight_norm_interface_backward.out @@ -6501,6 +6587,7 @@ dispatch: CPU: _efficientzerotensor CUDA: _efficientzerotensor_cuda + PrivateUse1: _efficientzerotensor_zoom MPS: _efficientzerotensor_mps Meta: _efficientzerotensor_meta_symint autogen: _efficientzerotensor.out @@ -6583,6 +6670,7 @@ dispatch: CPU: _batch_norm_with_update_cpu CUDA: _batch_norm_with_update_cuda + PrivateUse1: _batch_norm_with_update_zoom MPS: _batch_norm_with_update_mps MkldnnCPU: _batch_norm_with_update_mkldnn autogen: _batch_norm_with_update_functional @@ -6591,6 +6679,7 @@ dispatch: CPU: _batch_norm_with_update_cpu_out CUDA: _batch_norm_with_update_cuda_out + PrivateUse1: _batch_norm_with_update_zoom_out MPS: _batch_norm_with_update_mps_out - func: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) @@ -6602,6 +6691,7 @@ dispatch: CPU: _new_batch_norm_backward_cpu CUDA: _new_batch_norm_backward_cuda + PrivateUse1: _new_batch_norm_backward_zoom MPS: _new_batch_norm_backward_mps MkldnnCPU: _new_batch_norm_backward_mkldnn @@ -6722,7 +6812,7 @@ structured: True device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: norm_out + CPU, CUDA, PrivateUse1: norm_out MPS: norm_out_mps # These four redispatch in their implementation, so OK to be CompositeImplicitAutograd @@ -6748,7 +6838,7 @@ - func: frexp.Tensor_out(Tensor self, *, Tensor(a!) mantissa, Tensor(b!) exponent) -> (Tensor(a!) mantissa, Tensor(b!) 
exponent) dispatch: - CPU, CUDA: frexp_out + CPU, CUDA, PrivateUse1: frexp_out tags: pointwise # Deprecated (v.1.12) @@ -6811,7 +6901,7 @@ device_check: NoCheck # TensorIterator variants: method, function dispatch: - CPU, CUDA: zero_ + CPU, CUDA, PrivateUse1: zero_ MPS: zero_mps_ Meta: zero_meta_ SparseCPU, SparseCUDA, SparseMeta: zero_sparse_ @@ -6825,7 +6915,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: sub_out + CPU, CUDA, PrivateUse1: sub_out MPS: sub_out_mps SparseCPU, SparseCUDA: sub_out_sparse tags: pointwise @@ -6892,7 +6982,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: heaviside_out + CPU, CUDA, PrivateUse1: heaviside_out tags: pointwise - func: heaviside(Tensor self, Tensor values) -> Tensor @@ -6950,6 +7040,7 @@ dispatch: CPU: addmm_out_cpu CUDA: addmm_out_cuda + PrivateUse1: addmm_out_hip MPS: addmm_out_mps SparseCPU: addmm_out_sparse_dense_cpu SparseCUDA: addmm_out_sparse_dense_cuda @@ -6979,6 +7070,7 @@ dispatch: CPU: addmm_activation_out_cpu CUDA: addmm_activation_out_cuda + PrivateUse1: addmm_activation_out_hip - func: _addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor structured_delegate: _addmm_activation.out @@ -6988,11 +7080,13 @@ variants: function dispatch: CUDA: _scaled_mm_cuda + PrivateUse1: _scaled_mm_hip - func: _scaled_mm.out(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None, bool use_fast_accum=False, Tensor(a!) out, Tensor(b!) out_amax) -> (Tensor(a!), Tensor(b!)) variants: function dispatch: CUDA: _scaled_mm_out_cuda + PrivateUse1: _scaled_mm_out_hip # NOTE [ Sparse: autograd and API ] # @@ -7726,6 +7820,7 @@ dispatch: CPU: _local_scalar_dense_cpu CUDA: _local_scalar_dense_cuda + PrivateUse1: _local_scalar_dense_zoom MPS: _local_scalar_dense_mps variants: function @@ -7747,6 +7842,7 @@ - func: _thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor) dispatch: CUDA: _thnn_fused_lstm_cell_cuda + PrivateUse1: _thnn_fused_lstm_cell_zoom autogen: _thnn_fused_lstm_cell.out # NB: The composite version of this function below is a simple wrapper that duplicates some of the outputs @@ -7755,6 +7851,7 @@ - func: _thnn_fused_lstm_cell_backward_impl(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor) dispatch: CUDA: _thnn_fused_lstm_cell_backward_impl_cuda + PrivateUse1: _thnn_fused_lstm_cell_backward_impl_zoom autogen: _thnn_fused_lstm_cell_backward_impl.out - func: _thnn_fused_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) @@ -7764,11 +7861,13 @@ - func: _thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? 
hidden_bias=None) -> (Tensor, Tensor) dispatch: CUDA: _thnn_fused_gru_cell_cuda + PrivateUse1: _thnn_fused_gru_cell_zoom autogen: _thnn_fused_gru_cell.out - func: _thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) dispatch: CUDA: _thnn_fused_gru_cell_backward_cuda + PrivateUse1: _thnn_fused_gru_cell_backward_zoom autogen: _thnn_fused_gru_cell_backward.out - func: _thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) @@ -7851,7 +7950,7 @@ device_check: NoCheck device_guard: False dispatch: - CPU, CUDA, Meta, MPS: set_ + CPU, CUDA, Meta, MPS, PrivateUse1: set_ autogen: set.source_Storage, set.source_Storage_out tags: inplace_view @@ -7863,6 +7962,7 @@ CPU: set_storage_cpu_ Meta: set_storage_meta__symint CUDA: set_storage_cuda_ + PrivateUse1: set_storage_zoom_ MPS: set_storage_mps_ QuantizedCPU, QuantizedCUDA: set_storage_quantized_ autogen: set.source_Storage_storage_offset, set.source_Storage_storage_offset_out @@ -7881,7 +7981,7 @@ device_check: NoCheck device_guard: False dispatch: - CPU, CUDA, Meta, MPS: set_tensor_ + CPU, CUDA, Meta, MPS, PrivateUse1: set_tensor_ autogen: set.source_Tensor, set.source_Tensor_out tags: inplace_view @@ -7890,6 +7990,7 @@ dispatch: CPU: set_cpu_ CUDA: set_cuda_ + PrivateUse1: set_zoom_ Meta: set_meta_ MPS: set_mps_ autogen: set, set.out @@ -7926,7 +8027,7 @@ device_check: NoCheck device_guard: False dispatch: - CPU, CUDA, MPS: is_set_to + CPU, CUDA, PrivateUse1, MPS: is_set_to - func: masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -7934,6 +8035,7 @@ dispatch: CPU: masked_fill__cpu CUDA: masked_fill__cuda + PrivateUse1: masked_fill__zoom QuantizedCPU: masked_fill__quantized_cpu QuantizedCUDA: masked_fill__quantized_cuda MPS: masked_fill__mps @@ -7953,6 +8055,7 @@ dispatch: CPU: masked_fill__cpu CUDA: masked_fill__cuda + PrivateUse1: masked_fill__zoom QuantizedCPU: masked_fill__quantized_cpu QuantizedCUDA: masked_fill__quantized_cuda MPS: masked_fill__mps @@ -7969,6 +8072,7 @@ dispatch: CPU: masked_scatter__cpu CUDA: masked_scatter__cuda + PrivateUse1: masked_scatter__zoom MPS: masked_scatter__mps autogen: masked_scatter.out @@ -7984,12 +8088,14 @@ - func: _masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor dispatch: CUDA: masked_softmax_cuda + PrivateUse1: masked_softmax_zoom CPU: masked_softmax_cpu autogen: _masked_softmax.out - func: _masked_softmax_backward(Tensor grad_output, Tensor output, Tensor mask, int? dim=None) -> Tensor dispatch: CUDA: masked_softmax_backward_cuda + PrivateUse1: masked_softmax_backward_zoom CPU: masked_softmax_backward_cpu autogen: _masked_softmax_backward.out @@ -7998,7 +8104,7 @@ device_check: NoCheck device_guard: False dispatch: - ZeroTensor, Meta, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view + ZeroTensor, Meta, CPU, CUDA, PrivateUse1, QuantizedCPU, QuantizedCUDA, MPS: view MkldnnCPU: mkldnn_view NestedTensorCPU, NestedTensorCUDA: view_nested tags: core @@ -8019,7 +8125,7 @@ - func: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) 
variants: method dispatch: - CPU, CUDA: put_ + CPU, CUDA, PrivateUse1: put_ autogen: put.out - func: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor @@ -8035,6 +8141,7 @@ dispatch: CPU: index_add_cpu_out CUDA: index_add_cuda_out + PrivateUse1: index_add_zoom_out MPS: index_add_mps_out - func: index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor(a!) @@ -8056,6 +8163,7 @@ dispatch: CPU: index_reduce_cpu_out CUDA: index_reduce_cuda_out + PrivateUse1: index_reduce_zoom_out - func: index_reduce_(Tensor(a!) self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor(a!) structured_delegate: index_reduce.out @@ -8071,6 +8179,7 @@ dispatch: CPU: index_fill_ CUDA: index_fill_ + PrivateUse1: index_fill_ MPS: index_fill_mps_ autogen: index_fill.int_Scalar_out @@ -8084,7 +8193,7 @@ device_check: NoCheck # TensorIterator variants: method dispatch: - CPU, CUDA: index_fill_ + CPU, CUDA, PrivateUse1: index_fill_ MPS: index_fill_mps_ autogen: index_fill.int_Tensor_out @@ -8123,7 +8232,7 @@ structured: True variants: function dispatch: - CPU, CUDA: scatter_src_out + CPU, CUDA, PrivateUse1: scatter_src_out MPS: scatter_src_out_mps - func: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor @@ -8139,7 +8248,7 @@ structured: True variants: function dispatch: - CPU, CUDA: scatter_value_out + CPU, CUDA, PrivateUse1: scatter_value_out MPS: scatter_value_out_mps - func: scatter.reduce(Tensor self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor @@ -8154,7 +8263,7 @@ structured: True variants: function dispatch: - CPU, CUDA: scatter_reduce_out + CPU, CUDA, PrivateUse1: scatter_reduce_out MPS: scatter_reduce_out_mps - func: scatter.value_reduce(Tensor self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor @@ -8169,7 +8278,7 @@ structured: True variants: function dispatch: - CPU, CUDA: scatter_value_reduce_out + CPU, CUDA, PrivateUse1: scatter_value_reduce_out MPS: scatter_value_reduce_out_mps - func: scatter.dimname_src(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor @@ -8191,7 +8300,7 @@ structured: True variants: function dispatch: - CPU, CUDA: scatter_add + CPU, CUDA, PrivateUse1: scatter_add MPS: scatter_add_mps_out - func: scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor @@ -8210,7 +8319,7 @@ structured: True variants: function dispatch: - CPU, CUDA: scatter_reduce_two + CPU, CUDA, PrivateUse1: scatter_reduce_two - func: eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
structured_delegate: eq.Scalar_out @@ -8228,7 +8337,7 @@ structured_inherits: TensorIteratorBase variants: function dispatch: - CPU, CUDA: bitwise_and_out + CPU, CUDA, PrivateUse1: bitwise_and_out MPS: bitwise_and_out_mps tags: pointwise @@ -8295,7 +8404,7 @@ structured_inherits: TensorIteratorBase variants: function dispatch: - CPU, CUDA: bitwise_or_out + CPU, CUDA, PrivateUse1: bitwise_or_out MPS: bitwise_or_out_mps tags: pointwise @@ -8362,7 +8471,7 @@ structured_inherits: TensorIteratorBase variants: function dispatch: - CPU, CUDA: bitwise_xor_out + CPU, CUDA, PrivateUse1: bitwise_xor_out MPS: bitwise_xor_out_mps tags: pointwise @@ -8431,21 +8540,21 @@ device_check: NoCheck # TensorIterator variants: method, function dispatch: - CPU, CUDA: __lshift__ + CPU, CUDA, PrivateUse1: __lshift__ tags: pointwise - func: __lshift__.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function dispatch: - CPU, CUDA: __lshift__ + CPU, CUDA, PrivateUse1: __lshift__ tags: pointwise - func: __ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: - CPU, CUDA: __ilshift__ + CPU, CUDA, PrivateUse1: __ilshift__ autogen: __lshift__.Scalar_out tags: pointwise @@ -8453,7 +8562,7 @@ device_check: NoCheck # TensorIterator variants: method dispatch: - CPU, CUDA: __ilshift__ + CPU, CUDA, PrivateUse1: __ilshift__ autogen: __lshift__.Tensor_out tags: pointwise @@ -8474,7 +8583,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: bitwise_left_shift_out + CPU, CUDA, PrivateUse1: bitwise_left_shift_out tags: pointwise - func: bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor @@ -8510,28 +8619,28 @@ device_check: NoCheck # TensorIterator variants: method, function dispatch: - CPU, CUDA: __rshift__ + CPU, CUDA, PrivateUse1: __rshift__ tags: pointwise - func: __rshift__.Tensor(Tensor self, Tensor other) -> Tensor device_check: NoCheck # TensorIterator variants: method, function dispatch: - CPU, CUDA: __rshift__ + CPU, CUDA, PrivateUse1: __rshift__ tags: pointwise - func: __irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: - CPU, CUDA: __irshift__ + CPU, CUDA, PrivateUse1: __irshift__ autogen: __rshift__.Scalar_out - func: __irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) device_check: NoCheck # TensorIterator variants: method dispatch: - CPU, CUDA: __irshift__ + CPU, CUDA, PrivateUse1: __irshift__ autogen: __rshift__.Tensor_out - func: bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor @@ -8551,7 +8660,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: bitwise_right_shift_out + CPU, CUDA, PrivateUse1: bitwise_right_shift_out tags: pointwise - func: bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor @@ -8612,18 +8721,18 @@ - func: addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) variants: method dispatch: - CPU, CUDA: addbmm_ + CPU, CUDA, PrivateUse1: addbmm_ MPS: addbmm_mps_ - func: addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
dispatch: - CPU, CUDA: addbmm_out + CPU, CUDA, PrivateUse1: addbmm_out MPS: addbmm_out_mps - func: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor variants: method, function dispatch: - CPU, CUDA: addbmm + CPU, CUDA, PrivateUse1: addbmm MPS: addbmm_mps - func: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!) @@ -8631,7 +8740,7 @@ variants: method tags: nondeterministic_seeded dispatch: - CPU, CUDA: random_ + CPU, CUDA, PrivateUse1: random_ Meta: random_meta_ MPS: random_mps_ autogen: random.from, random.from_out @@ -8641,7 +8750,7 @@ tags: nondeterministic_seeded variants: method dispatch: - CPU, CUDA: random_ + CPU, CUDA, PrivateUse1: random_ Meta: random_meta_ MPS: random_mps_ autogen: random.to, random.to_out @@ -8651,7 +8760,7 @@ tags: nondeterministic_seeded variants: method dispatch: - CPU, CUDA: random_ + CPU, CUDA, PrivateUse1: random_ MPS: random_mps_ Meta: random_meta_ autogen: random, random.out @@ -8661,7 +8770,7 @@ tags: nondeterministic_seeded variants: method dispatch: - CPU, CUDA: uniform_ + CPU, CUDA, PrivateUse1: uniform_ MPS: uniform_mps_ Meta: uniform_meta_ autogen: uniform, uniform.out @@ -8671,7 +8780,7 @@ variants: method tags: nondeterministic_seeded dispatch: - CPU, CUDA: cauchy_ + CPU, CUDA, PrivateUse1: cauchy_ autogen: cauchy, cauchy.out - func: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!) @@ -8679,7 +8788,7 @@ tags: nondeterministic_seeded variants: method dispatch: - CPU, CUDA: log_normal_ + CPU, CUDA, PrivateUse1: log_normal_ autogen: log_normal, log_normal.out - func: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!) @@ -8687,7 +8796,7 @@ tags: nondeterministic_seeded variants: method dispatch: - CPU, CUDA: exponential_ + CPU, CUDA, PrivateUse1: exponential_ MPS: exponential_mps_ autogen: exponential, exponential.out @@ -8696,7 +8805,7 @@ tags: nondeterministic_seeded variants: method dispatch: - CPU, CUDA: geometric_ + CPU, CUDA, PrivateUse1: geometric_ # wrappers for TH functions autogen: geometric, geometric.out @@ -8716,6 +8825,7 @@ dispatch: CPU: triu_cpu CUDA: triu_cuda + PrivateUse1: triu_zoom MPS: triu_mps_out - func: triu(Tensor self, int diagonal=0) -> Tensor @@ -8727,6 +8837,7 @@ dispatch: CPU: tril_cpu CUDA: tril_cuda + PrivateUse1: tril_zoom MPS: tril_mps_out - func: tril(Tensor self, int diagonal=0) -> Tensor @@ -8750,6 +8861,7 @@ dispatch: CPU: trace_cpu CUDA: trace_cuda + PrivateUse1: trace_zoom MPS: trace_mps autogen: trace.out @@ -8765,7 +8877,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: ne_Scalar_out + CPU, CUDA, PrivateUse1: ne_Scalar_out MPS: ne_scalar_out_mps QuantizedCPU: ne_out_quantized_cpu tags: pointwise @@ -8783,7 +8895,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: ne_Tensor_out + CPU, CUDA, PrivateUse1: ne_Tensor_out MPS: ne_tensor_out_mps QuantizedCPU: ne_out_quantized_cpu tags: pointwise @@ -8828,7 +8940,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: eq_Scalar_out + CPU, CUDA, PrivateUse1: eq_Scalar_out MPS: eq_scalar_out_mps QuantizedCPU: eq_out_quantized_cpu tags: pointwise @@ -8847,7 +8959,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: eq_Tensor_out + CPU, CUDA, PrivateUse1: eq_Tensor_out MPS: 
eq_tensor_out_mps QuantizedCPU: eq_out_quantized_cpu tags: pointwise @@ -8865,7 +8977,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: ge_Scalar_out + CPU, CUDA, PrivateUse1: ge_Scalar_out MPS: ge_scalar_out_mps QuantizedCPU: ge_out_quantized_cpu tags: pointwise @@ -8884,7 +8996,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: ge_Tensor_out + CPU, CUDA, PrivateUse1: ge_Tensor_out MPS: ge_tensor_out_mps QuantizedCPU: ge_out_quantized_cpu tags: pointwise @@ -8929,7 +9041,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: le_Scalar_out + CPU, CUDA, PrivateUse1: le_Scalar_out MPS: le_scalar_out_mps QuantizedCPU: le_out_quantized_cpu tags: pointwise @@ -8947,7 +9059,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: le_Tensor_out + CPU, CUDA, PrivateUse1: le_Tensor_out MPS: le_tensor_out_mps QuantizedCPU: le_out_quantized_cpu tags: pointwise @@ -8992,7 +9104,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: gt_Scalar_out + CPU, CUDA, PrivateUse1: gt_Scalar_out MPS: gt_scalar_out_mps QuantizedCPU: gt_out_quantized_cpu tags: pointwise @@ -9011,7 +9123,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: gt_Tensor_out + CPU, CUDA, PrivateUse1: gt_Tensor_out MPS: gt_tensor_out_mps QuantizedCPU: gt_out_quantized_cpu tags: pointwise @@ -9056,7 +9168,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: lt_Scalar_out + CPU, CUDA, PrivateUse1: lt_Scalar_out MPS: lt_scalar_out_mps QuantizedCPU: lt_out_quantized_cpu tags: pointwise @@ -9074,7 +9186,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: lt_Tensor_out + CPU, CUDA, PrivateUse1: lt_Tensor_out MPS: lt_tensor_out_mps QuantizedCPU: lt_out_quantized_cpu tags: pointwise @@ -9116,12 +9228,12 @@ - func: take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: take_out + CPU, CUDA, PrivateUse1: take_out - func: take(Tensor self, Tensor index) -> Tensor variants: method, function dispatch: - CPU, CUDA: take + CPU, CUDA, PrivateUse1: take - func: take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) 
@@ -9132,6 +9244,7 @@ dispatch: CPU, QuantizedCPU: index_select_out_cpu_ CUDA, QuantizedCUDA: index_select_out_cuda + PrivateUse1: index_select_out_zoom MPS: index_select_out_mps - func: index_select(Tensor self, int dim, Tensor index) -> Tensor @@ -9141,6 +9254,7 @@ QuantizedCPU: index_select_quantized_cpu_ CUDA: index_select_cuda QuantizedCUDA: index_select_quantized_cuda + PrivateUse1: index_select_zoom SparseCPU: index_select_sparse_cpu SparseCUDA: index_select_sparse_cuda MPS: index_select_mps @@ -9162,6 +9276,7 @@ dispatch: CPU: masked_select_out_cpu CUDA: masked_select_out_cuda + PrivateUse1: masked_select_out_zoom MPS: masked_select_out_mps tags: dynamic_output_shape @@ -9170,6 +9285,7 @@ dispatch: CPU: masked_select_cpu CUDA: masked_select_cuda + PrivateUse1: masked_select_zoom MPS: masked_select_mps tags: dynamic_output_shape @@ -9182,6 +9298,7 @@ dispatch: CPU: nonzero_out_cpu CUDA: nonzero_out_cuda + PrivateUse1: nonzero_out_zoom MPS: nonzero_out_mps tags: dynamic_output_shape @@ -9190,6 +9307,7 @@ dispatch: CPU: nonzero_cpu CUDA: nonzero_cuda + PrivateUse1: nonzero_zoom MPS: nonzero_mps tags: [dynamic_output_shape, core] @@ -9212,7 +9330,7 @@ - func: gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!) structured: True dispatch: - CPU, CUDA: gather_out + CPU, CUDA, PrivateUse1: gather_out MPS: gather_out_mps - func: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor @@ -9237,7 +9355,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: addcmul_out + CPU, CUDA, PrivateUse1: addcmul_out MPS: addcmul_out_mps tags: pointwise @@ -9258,7 +9376,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: addcdiv_out + CPU, CUDA, PrivateUse1: addcdiv_out MPS: addcdiv_out_mps tags: pointwise @@ -9428,13 +9546,13 @@ - func: multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) tags: nondeterministic_seeded dispatch: - CPU, CUDA: multinomial_out + CPU, CUDA, PrivateUse1: multinomial_out MPS: multinomial_out_mps - func: multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? 
generator=None) -> Tensor variants: method, function dispatch: - CPU, CUDA: multinomial + CPU, CUDA, PrivateUse1: multinomial MPS: multinomial_mps tags: nondeterministic_seeded @@ -9443,7 +9561,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: lgamma_out + CPU, CUDA, PrivateUse1: lgamma_out MPS: lgamma_out_mps tags: pointwise @@ -9464,7 +9582,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: digamma_out + CPU, CUDA, PrivateUse1: digamma_out MPS: digamma_out_mps tags: pointwise @@ -9479,7 +9597,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: polygamma_out + CPU, CUDA, PrivateUse1: polygamma_out MPS: polygamma_out_mps tags: pointwise @@ -9519,7 +9637,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: erfinv_out + CPU, CUDA, PrivateUse1: erfinv_out MPS: erfinv_out_mps SparseCPU, SparseCUDA: erfinv_sparse_out SparseCsrCPU, SparseCsrCUDA: erfinv_sparse_csr_out @@ -9539,7 +9657,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: i0_out + CPU, CUDA, PrivateUse1: i0_out tags: pointwise - func: sign(Tensor self) -> Tensor @@ -9565,7 +9683,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: sign_out + CPU, CUDA, PrivateUse1: sign_out MPS: sign_out_mps SparseCPU, SparseCUDA: sign_sparse_out SparseCsrCPU, SparseCsrCUDA: sign_sparse_csr_out @@ -9585,6 +9703,7 @@ dispatch: CPU: signbit_out CUDA: signbit_out + PrivateUse1: signbit_out MPS: signbit_out_mps SparseCPU, SparseCUDA: signbit_sparse_out SparseCsrCPU, SparseCsrCUDA: signbit_sparse_csr_out @@ -9602,7 +9721,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: atan2_out + CPU, CUDA, PrivateUse1: atan2_out MPS: atan2_out_mps tags: [core, pointwise] @@ -9633,7 +9752,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: lerp_Scalar + CPU, CUDA, PrivateUse1: lerp_Scalar MPS: lerp_Scalar_mps tags: pointwise @@ -9642,7 +9761,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: lerp_Tensor + CPU, CUDA, PrivateUse1: lerp_Tensor MPS: lerp_Tensor_mps tags: pointwise @@ -9662,12 +9781,14 @@ dispatch: CPU, MPS: histogram_histc_out CUDA: _histc_out_cuda + PrivateUse1: _histc_out_zoom - func: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor variants: method, function dispatch: CPU, MPS: histogram_histc CUDA: _histc_cuda + PrivateUse1: _histc_zoom - func: histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) 
bin_edges) dispatch: @@ -9733,7 +9854,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: fmod_out + CPU, CUDA, PrivateUse1: fmod_out MPS: fmod_mps_out tags: pointwise @@ -9753,7 +9874,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: hypot_out + CPU, CUDA, PrivateUse1: hypot_out MPS: hypot_out_mps tags: pointwise @@ -9771,7 +9892,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: igamma_out + CPU, CUDA, PrivateUse1: igamma_out tags: pointwise - func: igamma(Tensor self, Tensor other) -> Tensor @@ -9788,7 +9909,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: igammac_out + CPU, CUDA, PrivateUse1: igammac_out tags: pointwise - func: igammac(Tensor self, Tensor other) -> Tensor @@ -9805,7 +9926,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA, MPS: nextafter_out + CPU, CUDA, PrivateUse1, MPS: nextafter_out tags: pointwise - func: nextafter(Tensor self, Tensor other) -> Tensor @@ -9840,7 +9961,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: remainder_out + CPU, CUDA, PrivateUse1: remainder_out MPS: remainder_out_mps tags: pointwise @@ -9868,14 +9989,14 @@ device_check: NoCheck # TensorIterator variants: method, function dispatch: - CPU, CUDA: min + CPU, CUDA, PrivateUse1: min MPS: min_mps QuantizedCPU: min_quantized_cpu - func: min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: min_unary_out + CPU, CUDA, PrivateUse1: min_unary_out QuantizedCPU: min_quantized_unary_out - func: fmin(Tensor self, Tensor other) -> Tensor @@ -9889,14 +10010,14 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA, MPS: fmin_out + CPU, CUDA, PrivateUse1, MPS: fmin_out tags: pointwise - func: max(Tensor self) -> Tensor device_check: NoCheck # TensorIterator variants: method, function dispatch: - CPU, CUDA: max + CPU, CUDA, PrivateUse1: max MPS: max_mps QuantizedCPU: max_quantized_cpu @@ -9911,7 +10032,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA, MPS: fmax_out + CPU, CUDA, PrivateUse1, MPS: fmax_out tags: pointwise - func: maximum(Tensor self, Tensor other) -> Tensor @@ -9925,7 +10046,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: maximum_out + CPU, CUDA, PrivateUse1: maximum_out MPS: maximum_out_mps tags: pointwise @@ -9943,7 +10064,7 @@ - func: max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: max_unary_out + CPU, CUDA, PrivateUse1: max_unary_out QuantizedCPU: max_quantized_unary_out - func: minimum(Tensor self, Tensor other) -> Tensor @@ -9957,7 +10078,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: minimum_out + CPU, CUDA, PrivateUse1: minimum_out MPS: minimum_out_mps tags: pointwise @@ -10000,7 +10121,7 @@ - func: sort.values_stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) structured: True dispatch: - CPU, CUDA: sort_stable_out + CPU, CUDA, PrivateUse1: sort_stable_out MPS: sort_stable_out_mps - func: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) @@ -10050,6 +10171,7 @@ dispatch: CPU: topk_out_cpu CUDA: topk_out_cuda + PrivateUse1: topk_out_zoom MPS: topk_out_mps - func: topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) @@ -10068,7 +10190,7 @@ device_check: NoCheck structured: True dispatch: - CPU, CUDA: all_all_out + CPU, CUDA, PrivateUse1: all_all_out MPS: all_all_out_mps - func: any(Tensor self) -> Tensor @@ -10083,14 +10205,14 @@ device_check: NoCheck structured: True dispatch: - CPU, CUDA: any_all_out + CPU, CUDA, PrivateUse1: any_all_out MPS: any_all_out_mps - func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator structured: True dispatch: - CPU, CUDA: renorm_out + CPU, CUDA, PrivateUse1: renorm_out MPS: renorm_out_mps - func: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor @@ -10108,13 +10230,13 @@ device_check: NoCheck device_guard: False dispatch: - CPU, CUDA, Meta, MPS: unfold + CPU, CUDA, PrivateUse1, Meta, MPS: unfold QuantizedCPU, QuantizedCUDA: unfold - func: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor variants: function dispatch: - CPU, CUDA: unfold_backward + CPU, CUDA, PrivateUse1: unfold_backward autogen: unfold_backward.out - func: equal(Tensor self, Tensor other) -> bool @@ -10123,6 +10245,7 @@ dispatch: CPU: cpu_equal CUDA: cuda_equal + PrivateUse1: zoom_equal MPS: mps_equal QuantizedCPU: equal_quantized_cpu @@ -10131,7 +10254,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: pow_Tensor_Tensor_out + CPU, CUDA, PrivateUse1: pow_Tensor_Tensor_out MPS: pow_tensor_tensor_out_mps tags: pointwise @@ -10159,7 +10282,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: pow_Tensor_Scalar_out + CPU, CUDA, PrivateUse1: pow_Tensor_Scalar_out SparseCPU, SparseCUDA: pow_out_sparse_scalar MPS: pow_tensor_scalar_out_mps tags: pointwise @@ -10217,7 +10340,7 @@ tags: nondeterministic_seeded variants: method dispatch: - CPU, CUDA: normal_ + CPU, CUDA, PrivateUse1: normal_ MPS: normal_mps_ Meta: normal_meta_ SparseCsrCPU, SparseCsrCUDA: normal_sparse_csr_ @@ -10236,41 +10359,41 @@ - func: normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) tags: nondeterministic_seeded dispatch: - CPU, CUDA: normal_out + CPU, CUDA, PrivateUse1: normal_out MPS: normal_mps_out Meta: normal_out_meta - func: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor dispatch: - CPU, CUDA: normal + CPU, CUDA, PrivateUse1: normal MPS: normal_mps Meta: normal_meta tags: nondeterministic_seeded - func: normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: normal_out + CPU, CUDA, PrivateUse1: normal_out Meta: normal_out_meta MPS: normal_mps_out tags: nondeterministic_seeded - func: normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor dispatch: - CPU, CUDA: normal + CPU, CUDA, PrivateUse1: normal MPS: normal_mps Meta: normal_meta tags: nondeterministic_seeded - func: normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) 
out) -> Tensor(a!) dispatch: - CPU, CUDA: normal_out + CPU, CUDA, PrivateUse1: normal_out Meta: normal_out_meta MPS: normal_mps_out tags: nondeterministic_seeded - func: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor dispatch: - CPU, CUDA: normal + CPU, CUDA, PrivateUse1: normal MPS: normal_mps Meta: normal_meta tags: nondeterministic_seeded @@ -10296,6 +10419,7 @@ variants: function dispatch: CUDA: _amp_foreach_non_finite_check_and_unscale_cuda_ + PrivateUse1: _amp_foreach_non_finite_check_and_unscale_zoom_ CPU: _amp_foreach_non_finite_check_and_unscale_cpu_ autogen: _amp_foreach_non_finite_check_and_unscale, _amp_foreach_non_finite_check_and_unscale.out @@ -10303,6 +10427,7 @@ variants: function dispatch: CUDA: _amp_update_scale_cuda_ + PrivateUse1: _amp_update_scale_zoom_ CPU: _amp_update_scale_cpu_ autogen: _amp_update_scale, _amp_update_scale.out @@ -10325,6 +10450,7 @@ dispatch: CPU: foreach_tensor_add_scalar_kernel_slow CUDA: foreach_tensor_add_scalar_kernel_cuda + PrivateUse1: foreach_tensor_add_scalar_kernel_zoom - func: _foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10332,6 +10458,7 @@ dispatch: CPU: foreach_tensor_add_scalar_kernel_slow_ CUDA: foreach_tensor_add_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_add_scalar_kernel_zoom_ autogen: _foreach_add.Scalar_out - func: _foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] @@ -10340,6 +10467,7 @@ dispatch: CPU: foreach_tensor_add_list_kernel_slow CUDA: foreach_tensor_add_list_kernel_cuda + PrivateUse1: foreach_tensor_add_list_kernel_zoom - func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10347,6 +10475,7 @@ dispatch: CPU: foreach_tensor_add_list_kernel_slow_ CUDA: foreach_tensor_add_list_kernel_cuda_ + PrivateUse1: foreach_tensor_add_list_kernel_zoom_ autogen: _foreach_add.List_out - func: _foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10355,6 +10484,7 @@ dispatch: CPU: foreach_tensor_add_scalarlist_kernel_slow CUDA: foreach_tensor_add_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_add_scalarlist_kernel_zoom - func: _foreach_add_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10362,6 +10492,7 @@ dispatch: CPU: foreach_tensor_add_scalarlist_kernel_slow_ CUDA: foreach_tensor_add_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_add_scalarlist_kernel_zoom_ autogen: _foreach_add.ScalarList_out - func: _foreach_add.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor[] @@ -10370,6 +10501,7 @@ dispatch: CPU: foreach_tensor_add_tensor_kernel_slow CUDA: foreach_tensor_add_tensor_kernel_cuda + PrivateUse1: foreach_tensor_add_tensor_kernel_zoom - func: _foreach_add_.Tensor(Tensor(a!)[] self, Tensor other, *, Scalar alpha=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10377,6 +10509,7 @@ dispatch: CPU: foreach_tensor_add_tensor_kernel_slow_ CUDA: foreach_tensor_add_tensor_kernel_cuda_ + PrivateUse1: foreach_tensor_add_tensor_kernel_zoom_ autogen: _foreach_add.Tensor_out - func: _foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -10385,6 +10518,7 @@ dispatch: CPU: 
foreach_tensor_sub_scalar_kernel_slow CUDA: foreach_tensor_sub_scalar_kernel_cuda + PrivateUse1: foreach_tensor_sub_scalar_kernel_zoom - func: _foreach_sub_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10392,6 +10526,7 @@ dispatch: CPU: foreach_tensor_sub_scalar_kernel_slow_ CUDA: foreach_tensor_sub_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_sub_scalar_kernel_zoom_ autogen: _foreach_sub.Scalar_out - func: _foreach_sub.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] @@ -10400,6 +10535,7 @@ dispatch: CPU: foreach_tensor_sub_list_kernel_slow CUDA: foreach_tensor_sub_list_kernel_cuda + PrivateUse1: foreach_tensor_sub_list_kernel_zoom - func: _foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10407,6 +10543,7 @@ dispatch: CPU: foreach_tensor_sub_list_kernel_slow_ CUDA: foreach_tensor_sub_list_kernel_cuda_ + PrivateUse1: foreach_tensor_sub_list_kernel_zoom_ autogen: _foreach_sub.List_out - func: _foreach_sub.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10415,6 +10552,7 @@ dispatch: CPU: foreach_tensor_sub_scalarlist_kernel_slow CUDA: foreach_tensor_sub_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_sub_scalarlist_kernel_zoom - func: _foreach_sub_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10422,6 +10560,7 @@ dispatch: CPU: foreach_tensor_sub_scalarlist_kernel_slow_ CUDA: foreach_tensor_sub_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_sub_scalarlist_kernel_zoom_ autogen: _foreach_sub.ScalarList_out - func: _foreach_mul.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -10430,6 +10569,7 @@ dispatch: CPU: foreach_tensor_mul_scalar_kernel_slow CUDA: foreach_tensor_mul_scalar_kernel_cuda + PrivateUse1: foreach_tensor_mul_scalar_kernel_zoom - func: _foreach_mul_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10437,6 +10577,7 @@ dispatch: CPU: foreach_tensor_mul_scalar_kernel_slow_ CUDA: foreach_tensor_mul_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_mul_scalar_kernel_zoom_ autogen: _foreach_mul.Scalar_out - func: _foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[] @@ -10445,6 +10586,7 @@ dispatch: CPU: foreach_tensor_mul_list_kernel_slow CUDA: foreach_tensor_mul_list_kernel_cuda + PrivateUse1: foreach_tensor_mul_list_kernel_zoom - func: _foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10452,6 +10594,7 @@ dispatch: CPU: foreach_tensor_mul_list_kernel_slow_ CUDA: foreach_tensor_mul_list_kernel_cuda_ + PrivateUse1: foreach_tensor_mul_list_kernel_zoom_ autogen: _foreach_mul.List_out - func: _foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10460,6 +10603,7 @@ dispatch: CPU: foreach_tensor_mul_scalarlist_kernel_slow CUDA: foreach_tensor_mul_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_mul_scalarlist_kernel_zoom - func: _foreach_mul_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10467,6 +10611,7 @@ dispatch: CPU: 
foreach_tensor_mul_scalarlist_kernel_slow_ CUDA: foreach_tensor_mul_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_mul_scalarlist_kernel_zoom_ autogen: _foreach_mul.ScalarList_out - func: _foreach_mul.Tensor(Tensor[] self, Tensor other) -> Tensor[] @@ -10475,6 +10620,7 @@ dispatch: CPU: foreach_tensor_mul_tensor_kernel_slow CUDA: foreach_tensor_mul_tensor_kernel_cuda + PrivateUse1: foreach_tensor_mul_tensor_kernel_zoom - func: _foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10482,6 +10628,7 @@ dispatch: CPU: foreach_tensor_mul_tensor_kernel_slow_ CUDA: foreach_tensor_mul_tensor_kernel_cuda_ + PrivateUse1: foreach_tensor_mul_tensor_kernel_zoom_ autogen: _foreach_mul.Tensor_out - func: _foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -10490,6 +10637,7 @@ dispatch: CPU: foreach_tensor_div_scalar_kernel_slow CUDA: foreach_tensor_div_scalar_kernel_cuda + PrivateUse1: foreach_tensor_div_scalar_kernel_zoom - func: _foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10497,6 +10645,7 @@ dispatch: CPU: foreach_tensor_div_scalar_kernel_slow_ CUDA: foreach_tensor_div_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_div_scalar_kernel_zoom_ autogen: _foreach_div.Scalar_out - func: _foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[] @@ -10505,6 +10654,7 @@ dispatch: CPU: foreach_tensor_div_list_kernel_slow CUDA: foreach_tensor_div_list_kernel_cuda + PrivateUse1: foreach_tensor_div_list_kernel_zoom - func: _foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10512,6 +10662,7 @@ dispatch: CPU: foreach_tensor_div_list_kernel_slow_ CUDA: foreach_tensor_div_list_kernel_cuda_ + PrivateUse1: foreach_tensor_div_list_kernel_zoom_ autogen: _foreach_div.List_out - func: _foreach_div.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10520,6 +10671,7 @@ dispatch: CPU: foreach_tensor_div_scalarlist_kernel_slow CUDA: foreach_tensor_div_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_div_scalarlist_kernel_zoom - func: _foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10527,6 +10679,7 @@ dispatch: CPU: foreach_tensor_div_scalarlist_kernel_slow_ CUDA: foreach_tensor_div_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_div_scalarlist_kernel_zoom_ autogen: _foreach_div.ScalarList_out - func: _foreach_div.Tensor(Tensor[] self, Tensor other) -> Tensor[] @@ -10535,6 +10688,7 @@ dispatch: CPU: foreach_tensor_div_tensor_kernel_slow CUDA: foreach_tensor_div_tensor_kernel_cuda + PrivateUse1: foreach_tensor_div_tensor_kernel_zoom - func: _foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10542,6 +10696,7 @@ dispatch: CPU: foreach_tensor_div_tensor_kernel_slow_ CUDA: foreach_tensor_div_tensor_kernel_cuda_ + PrivateUse1: foreach_tensor_div_tensor_kernel_zoom_ autogen: _foreach_div.Tensor_out - func: _foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -10550,6 +10705,7 @@ dispatch: CPU: foreach_tensor_clamp_max_scalar_kernel_slow CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda + PrivateUse1: 
foreach_tensor_clamp_max_scalar_kernel_zoom - func: _foreach_clamp_max_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10557,6 +10713,7 @@ dispatch: CPU: foreach_tensor_clamp_max_scalar_kernel_slow_ CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_max_scalar_kernel_zoom_ autogen: _foreach_clamp_max.Scalar_out - func: _foreach_clamp_max.List(Tensor[] self, Tensor[] other) -> Tensor[] @@ -10565,6 +10722,7 @@ dispatch: CPU: foreach_tensor_clamp_max_list_kernel_slow CUDA: foreach_tensor_clamp_max_list_kernel_cuda + PrivateUse1: foreach_tensor_clamp_max_list_kernel_zoom - func: _foreach_clamp_max_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10572,6 +10730,7 @@ dispatch: CPU: foreach_tensor_clamp_max_list_kernel_slow_ CUDA: foreach_tensor_clamp_max_list_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_max_list_kernel_zoom_ autogen: _foreach_clamp_max.List_out - func: _foreach_clamp_max.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10580,6 +10739,7 @@ dispatch: CPU: foreach_tensor_clamp_max_scalarlist_kernel_slow CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_clamp_max_scalarlist_kernel_zoom - func: _foreach_clamp_max_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10587,6 +10747,7 @@ dispatch: CPU: foreach_tensor_clamp_max_scalarlist_kernel_slow_ CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_max_scalarlist_kernel_zoom_ autogen: _foreach_clamp_max.ScalarList_out - func: _foreach_clamp_min.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -10595,6 +10756,7 @@ dispatch: CPU: foreach_tensor_clamp_min_scalar_kernel_slow CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda + PrivateUse1: foreach_tensor_clamp_min_scalar_kernel_zoom - func: _foreach_clamp_min_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10602,6 +10764,7 @@ dispatch: CPU: foreach_tensor_clamp_min_scalar_kernel_slow_ CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_min_scalar_kernel_zoom_ autogen: _foreach_clamp_min.Scalar_out - func: _foreach_clamp_min.List(Tensor[] self, Tensor[] other) -> Tensor[] @@ -10610,6 +10773,7 @@ dispatch: CPU: foreach_tensor_clamp_min_list_kernel_slow CUDA: foreach_tensor_clamp_min_list_kernel_cuda + PrivateUse1: foreach_tensor_clamp_min_list_kernel_zoom - func: _foreach_clamp_min_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10617,6 +10781,7 @@ dispatch: CPU: foreach_tensor_clamp_min_list_kernel_slow_ CUDA: foreach_tensor_clamp_min_list_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_min_list_kernel_zoom_ autogen: _foreach_clamp_min.List_out - func: _foreach_clamp_min.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10625,6 +10790,7 @@ dispatch: CPU: foreach_tensor_clamp_min_scalarlist_kernel_slow CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_clamp_min_scalarlist_kernel_zoom - func: _foreach_clamp_min_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: 
NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10632,6 +10798,7 @@ dispatch: CPU: foreach_tensor_clamp_min_scalarlist_kernel_slow_ CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_min_scalarlist_kernel_zoom_ autogen: _foreach_clamp_min.ScalarList_out # foreach_minimum/maximum dispatches to clamp_max/min @@ -10641,6 +10808,7 @@ dispatch: CPU: foreach_tensor_clamp_min_scalar_kernel_slow CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda + PrivateUse1: foreach_tensor_clamp_min_scalar_kernel_zoom - func: _foreach_maximum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10648,6 +10816,7 @@ dispatch: CPU: foreach_tensor_clamp_min_scalar_kernel_slow_ CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_min_scalar_kernel_zoom_ autogen: _foreach_maximum.Scalar_out # foreach_minimum/maximum dispatches to clamp_max/min @@ -10657,6 +10826,7 @@ dispatch: CPU: foreach_tensor_clamp_min_list_kernel_slow CUDA: foreach_tensor_clamp_min_list_kernel_cuda + PrivateUse1: foreach_tensor_clamp_min_list_kernel_zoom - func: _foreach_maximum_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10664,6 +10834,7 @@ dispatch: CPU: foreach_tensor_clamp_min_list_kernel_slow_ CUDA: foreach_tensor_clamp_min_list_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_min_list_kernel_zoom_ autogen: _foreach_maximum.List_out # foreach_minimum/maximum dispatches to clamp_max/min @@ -10673,6 +10844,7 @@ dispatch: CPU: foreach_tensor_clamp_min_scalarlist_kernel_slow CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_clamp_min_scalarlist_kernel_zoom - func: _foreach_maximum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10680,6 +10852,7 @@ dispatch: CPU: foreach_tensor_clamp_min_scalarlist_kernel_slow_ CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_min_scalarlist_kernel_zoom_ autogen: _foreach_maximum.ScalarList_out - func: _foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] @@ -10688,6 +10861,7 @@ dispatch: CPU: foreach_tensor_clamp_max_scalar_kernel_slow CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda + PrivateUse1: foreach_tensor_clamp_max_scalar_kernel_zoom - func: _foreach_minimum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10695,6 +10869,7 @@ dispatch: CPU: foreach_tensor_clamp_max_scalar_kernel_slow_ CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_max_scalar_kernel_zoom_ autogen: _foreach_minimum.Scalar_out - func: _foreach_minimum.List(Tensor[] self, Tensor[] other) -> Tensor[] @@ -10703,6 +10878,7 @@ dispatch: CPU: foreach_tensor_clamp_max_list_kernel_slow CUDA: foreach_tensor_clamp_max_list_kernel_cuda + PrivateUse1: foreach_tensor_clamp_max_list_kernel_zoom - func: _foreach_minimum_.List(Tensor(a!)[] self, Tensor[] other) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10710,6 +10886,7 @@ dispatch: CPU: foreach_tensor_clamp_max_list_kernel_slow_ CUDA: foreach_tensor_clamp_max_list_kernel_cuda_ + 
PrivateUse1: foreach_tensor_clamp_max_list_kernel_zoom_ autogen: _foreach_minimum.List_out - func: _foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] @@ -10718,6 +10895,7 @@ dispatch: CPU: foreach_tensor_clamp_max_scalarlist_kernel_slow CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_clamp_max_scalarlist_kernel_zoom - func: _foreach_minimum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10725,6 +10903,7 @@ dispatch: CPU: foreach_tensor_clamp_max_scalarlist_kernel_slow_ CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_clamp_max_scalarlist_kernel_zoom_ autogen: _foreach_minimum.ScalarList_out - func: _foreach_addcdiv.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] @@ -10733,6 +10912,7 @@ dispatch: CPU: foreach_tensor_addcdiv_scalar_slow CUDA: foreach_tensor_addcdiv_scalar_cuda + PrivateUse1: foreach_tensor_addcdiv_scalar_zoom - func: _foreach_addcdiv.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10740,6 +10920,7 @@ dispatch: CPU: foreach_tensor_addcdiv_scalarlist_slow CUDA: foreach_tensor_addcdiv_scalarlist_cuda + PrivateUse1: foreach_tensor_addcdiv_scalarlist_zoom - func: _foreach_addcdiv.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10747,6 +10928,7 @@ dispatch: CPU: foreach_tensor_addcdiv_tensor_slow CUDA: foreach_tensor_addcdiv_tensor_cuda + PrivateUse1: foreach_tensor_addcdiv_tensor_zoom - func: _foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10754,6 +10936,7 @@ dispatch: CPU: foreach_tensor_addcdiv_scalar_slow_ CUDA: foreach_tensor_addcdiv_scalar_cuda_ + PrivateUse1: foreach_tensor_addcdiv_scalar_zoom_ autogen: _foreach_addcdiv.Scalar_out - func: _foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () @@ -10762,6 +10945,7 @@ dispatch: CPU: foreach_tensor_addcdiv_scalarlist_slow_ CUDA: foreach_tensor_addcdiv_scalarlist_cuda_ + PrivateUse1: foreach_tensor_addcdiv_scalarlist_zoom_ autogen: _foreach_addcdiv.ScalarList_out - func: _foreach_addcdiv_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> () @@ -10770,6 +10954,7 @@ dispatch: CPU: foreach_tensor_addcdiv_tensor_slow_ CUDA: foreach_tensor_addcdiv_tensor_cuda_ + PrivateUse1: foreach_tensor_addcdiv_tensor_zoom_ autogen: _foreach_addcdiv.Tensor_out - func: _foreach_addcmul.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] @@ -10778,6 +10963,7 @@ dispatch: CPU: foreach_tensor_addcmul_scalar_slow CUDA: foreach_tensor_addcmul_scalar_cuda + PrivateUse1: foreach_tensor_addcmul_scalar_zoom - func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10785,6 +10971,7 @@ dispatch: CPU: foreach_tensor_addcmul_scalarlist_slow CUDA: foreach_tensor_addcmul_scalarlist_cuda + PrivateUse1: 
foreach_tensor_addcmul_scalarlist_zoom - func: _foreach_addcmul.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10792,6 +10979,7 @@ dispatch: CPU: foreach_tensor_addcmul_tensor_slow CUDA: foreach_tensor_addcmul_tensor_cuda + PrivateUse1: foreach_tensor_addcmul_tensor_zoom - func: _foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10799,6 +10987,7 @@ dispatch: CPU: foreach_tensor_addcmul_scalar_slow_ CUDA: foreach_tensor_addcmul_scalar_cuda_ + PrivateUse1: foreach_tensor_addcmul_scalar_zoom_ autogen: _foreach_addcmul.Scalar_out - func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () @@ -10807,6 +10996,7 @@ dispatch: CPU: foreach_tensor_addcmul_scalarlist_slow_ CUDA: foreach_tensor_addcmul_scalarlist_cuda_ + PrivateUse1: foreach_tensor_addcmul_scalarlist_zoom_ autogen: _foreach_addcmul.ScalarList_out - func: _foreach_addcmul_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> () @@ -10815,6 +11005,7 @@ dispatch: CPU: foreach_tensor_addcmul_tensor_slow_ CUDA: foreach_tensor_addcmul_tensor_cuda_ + PrivateUse1: foreach_tensor_addcmul_tensor_zoom_ autogen: _foreach_addcmul.Tensor_out - func: _foreach_abs(Tensor[] self) -> Tensor[] @@ -10823,6 +11014,7 @@ dispatch: CPU: foreach_tensor_abs_slow CUDA: foreach_tensor_abs_cuda + PrivateUse1: foreach_tensor_abs_zoom - func: _foreach_abs_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10830,6 +11022,7 @@ dispatch: CPU: foreach_tensor_abs_slow_ CUDA: foreach_tensor_abs_cuda_ + PrivateUse1: foreach_tensor_abs_zoom_ autogen: _foreach_abs.out - func: _foreach_acos(Tensor[] self) -> Tensor[] @@ -10838,6 +11031,7 @@ dispatch: CPU: foreach_tensor_acos_slow CUDA: foreach_tensor_acos_cuda + PrivateUse1: foreach_tensor_acos_zoom - func: _foreach_acos_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10845,6 +11039,7 @@ dispatch: CPU: foreach_tensor_acos_slow_ CUDA: foreach_tensor_acos_cuda_ + PrivateUse1: foreach_tensor_acos_zoom_ autogen: _foreach_acos.out - func: _foreach_asin(Tensor[] self) -> Tensor[] @@ -10853,6 +11048,7 @@ dispatch: CPU: foreach_tensor_asin_slow CUDA: foreach_tensor_asin_cuda + PrivateUse1: foreach_tensor_asin_zoom - func: _foreach_asin_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10860,6 +11056,7 @@ dispatch: CPU: foreach_tensor_asin_slow_ CUDA: foreach_tensor_asin_cuda_ + PrivateUse1: foreach_tensor_asin_zoom_ autogen: _foreach_asin.out - func: _foreach_atan(Tensor[] self) -> Tensor[] @@ -10868,6 +11065,7 @@ dispatch: CPU: foreach_tensor_atan_slow CUDA: foreach_tensor_atan_cuda + PrivateUse1: foreach_tensor_atan_zoom - func: _foreach_atan_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10875,6 +11073,7 @@ dispatch: CPU: foreach_tensor_atan_slow_ CUDA: foreach_tensor_atan_cuda_ + PrivateUse1: foreach_tensor_atan_zoom_ autogen: _foreach_atan.out - func: _foreach_ceil(Tensor[] self) -> Tensor[] @@ -10883,6 +11082,7 @@ dispatch: CPU: 
foreach_tensor_ceil_slow CUDA: foreach_tensor_ceil_cuda + PrivateUse1: foreach_tensor_ceil_zoom - func: _foreach_ceil_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10890,6 +11090,7 @@ dispatch: CPU: foreach_tensor_ceil_slow_ CUDA: foreach_tensor_ceil_cuda_ + PrivateUse1: foreach_tensor_ceil_zoom_ autogen: _foreach_ceil.out - func: _foreach_cos(Tensor[] self) -> Tensor[] @@ -10898,6 +11099,7 @@ dispatch: CPU: foreach_tensor_cos_slow CUDA: foreach_tensor_cos_cuda + PrivateUse1: foreach_tensor_cos_zoom - func: _foreach_cos_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10905,6 +11107,7 @@ dispatch: CPU: foreach_tensor_cos_slow_ CUDA: foreach_tensor_cos_cuda_ + PrivateUse1: foreach_tensor_cos_zoom_ autogen: _foreach_cos.out - func: _foreach_cosh(Tensor[] self) -> Tensor[] @@ -10913,6 +11116,7 @@ dispatch: CPU: foreach_tensor_cosh_slow CUDA: foreach_tensor_cosh_cuda + PrivateUse1: foreach_tensor_cosh_zoom - func: _foreach_cosh_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10920,6 +11124,7 @@ dispatch: CPU: foreach_tensor_cosh_slow_ CUDA: foreach_tensor_cosh_cuda_ + PrivateUse1: foreach_tensor_cosh_zoom_ autogen: _foreach_cosh.out - func: _foreach_erf(Tensor[] self) -> Tensor[] @@ -10928,6 +11133,7 @@ dispatch: CPU: foreach_tensor_erf_slow CUDA: foreach_tensor_erf_cuda + PrivateUse1: foreach_tensor_erf_zoom - func: _foreach_erf_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10935,6 +11141,7 @@ dispatch: CPU: foreach_tensor_erf_slow_ CUDA: foreach_tensor_erf_cuda_ + PrivateUse1: foreach_tensor_erf_zoom_ autogen: _foreach_erf.out - func: _foreach_erfc(Tensor[] self) -> Tensor[] @@ -10943,6 +11150,7 @@ dispatch: CPU: foreach_tensor_erfc_slow CUDA: foreach_tensor_erfc_cuda + PrivateUse1: foreach_tensor_erfc_zoom - func: _foreach_erfc_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10950,6 +11158,7 @@ dispatch: CPU: foreach_tensor_erfc_slow_ CUDA: foreach_tensor_erfc_cuda_ + PrivateUse1: foreach_tensor_erfc_zoom_ autogen: _foreach_erfc.out - func: _foreach_exp(Tensor[] self) -> Tensor[] @@ -10958,6 +11167,7 @@ dispatch: CPU: foreach_tensor_exp_slow CUDA: foreach_tensor_exp_cuda + PrivateUse1: foreach_tensor_exp_zoom - func: _foreach_exp_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10965,6 +11175,7 @@ dispatch: CPU: foreach_tensor_exp_slow_ CUDA: foreach_tensor_exp_cuda_ + PrivateUse1: foreach_tensor_exp_zoom_ autogen: _foreach_exp.out - func: _foreach_expm1(Tensor[] self) -> Tensor[] @@ -10973,6 +11184,7 @@ dispatch: CPU: foreach_tensor_expm1_slow CUDA: foreach_tensor_expm1_cuda + PrivateUse1: foreach_tensor_expm1_zoom - func: _foreach_expm1_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10980,6 +11192,7 @@ dispatch: CPU: foreach_tensor_expm1_slow_ CUDA: foreach_tensor_expm1_cuda_ + PrivateUse1: foreach_tensor_expm1_zoom_ autogen: _foreach_expm1.out - func: _foreach_floor(Tensor[] self) -> Tensor[] @@ -10988,6 +11201,7 @@ dispatch: CPU: foreach_tensor_floor_slow CUDA: foreach_tensor_floor_cuda + PrivateUse1: 
foreach_tensor_floor_zoom - func: _foreach_floor_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -10995,6 +11209,7 @@ dispatch: CPU: foreach_tensor_floor_slow_ CUDA: foreach_tensor_floor_cuda_ + PrivateUse1: foreach_tensor_floor_zoom_ autogen: _foreach_floor.out - func: _foreach_frac(Tensor[] self) -> Tensor[] @@ -11003,6 +11218,7 @@ dispatch: CPU: foreach_tensor_frac_slow CUDA: foreach_tensor_frac_cuda + PrivateUse1: foreach_tensor_frac_zoom - func: _foreach_frac_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11010,6 +11226,7 @@ dispatch: CPU: foreach_tensor_frac_slow_ CUDA: foreach_tensor_frac_cuda_ + PrivateUse1: foreach_tensor_frac_zoom_ autogen: _foreach_frac.out - func: _foreach_lerp.List(Tensor[] self, Tensor[] tensors1, Tensor[] weights) -> Tensor[] @@ -11018,6 +11235,7 @@ dispatch: CPU: foreach_tensor_ternary_lerp_slow CUDA: foreach_tensor_lerp_ternary_cuda + PrivateUse1: foreach_tensor_lerp_ternary_zoom autogen: _foreach_lerp.List_out - func: _foreach_lerp_.List(Tensor(a!)[] self, Tensor[] tensors1, Tensor[] weights) -> () @@ -11026,6 +11244,7 @@ dispatch: CPU: foreach_tensor_ternary_lerp_slow_ CUDA: foreach_tensor_lerp_ternary_cuda_ + PrivateUse1: foreach_tensor_lerp_ternary_zoom_ autogen: _foreach_lerp.List_out - func: _foreach_lerp.Scalar(Tensor[] self, Tensor[] tensors1, Scalar weight) -> Tensor[] @@ -11034,6 +11253,7 @@ dispatch: CPU: foreach_tensor_lerp_list_kernel_slow CUDA: foreach_tensor_lerp_list_cuda + PrivateUse1: foreach_tensor_lerp_list_zoom autogen: _foreach_lerp.Scalar_out - func: _foreach_lerp_.Scalar(Tensor(a!)[] self, Tensor[] tensors1, Scalar weight) -> () @@ -11042,6 +11262,7 @@ dispatch: CPU: foreach_tensor_lerp_list_kernel_slow_ CUDA: foreach_tensor_lerp_list_cuda_ + PrivateUse1: foreach_tensor_lerp_list_zoom_ autogen: _foreach_lerp.Scalar_out - func: _foreach_lgamma(Tensor[] self) -> Tensor[] @@ -11050,6 +11271,7 @@ dispatch: CPU: foreach_tensor_lgamma_slow CUDA: foreach_tensor_lgamma_cuda + PrivateUse1: foreach_tensor_lgamma_zoom - func: _foreach_lgamma_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11057,6 +11279,7 @@ dispatch: CPU: foreach_tensor_lgamma_slow_ CUDA: foreach_tensor_lgamma_cuda_ + PrivateUse1: foreach_tensor_lgamma_zoom_ autogen: _foreach_lgamma.out - func: _foreach_log(Tensor[] self) -> Tensor[] @@ -11065,6 +11288,7 @@ dispatch: CPU: foreach_tensor_log_slow CUDA: foreach_tensor_log_cuda + PrivateUse1: foreach_tensor_log_zoom - func: _foreach_log_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11072,6 +11296,7 @@ dispatch: CPU: foreach_tensor_log_slow_ CUDA: foreach_tensor_log_cuda_ + PrivateUse1: foreach_tensor_log_zoom_ autogen: _foreach_log.out - func: _foreach_log10(Tensor[] self) -> Tensor[] @@ -11080,6 +11305,7 @@ dispatch: CPU: foreach_tensor_log10_slow CUDA: foreach_tensor_log10_cuda + PrivateUse1: foreach_tensor_log10_zoom - func: _foreach_log10_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11087,6 +11313,7 @@ dispatch: CPU: foreach_tensor_log10_slow_ CUDA: foreach_tensor_log10_cuda_ + PrivateUse1: foreach_tensor_log10_zoom_ autogen: _foreach_log10.out - func: _foreach_log1p(Tensor[] self) -> Tensor[] @@ -11095,6 +11322,7 @@ 
dispatch: CPU: foreach_tensor_log1p_slow CUDA: foreach_tensor_log1p_cuda + PrivateUse1: foreach_tensor_log1p_zoom - func: _foreach_log1p_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11102,6 +11330,7 @@ dispatch: CPU: foreach_tensor_log1p_slow_ CUDA: foreach_tensor_log1p_cuda_ + PrivateUse1: foreach_tensor_log1p_zoom_ autogen: _foreach_log1p.out - func: _foreach_log2(Tensor[] self) -> Tensor[] @@ -11110,6 +11339,7 @@ dispatch: CPU: foreach_tensor_log2_slow CUDA: foreach_tensor_log2_cuda + PrivateUse1: foreach_tensor_log2_zoom - func: _foreach_log2_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11117,6 +11347,7 @@ dispatch: CPU: foreach_tensor_log2_slow_ CUDA: foreach_tensor_log2_cuda_ + PrivateUse1: foreach_tensor_log2_zoom_ autogen: _foreach_log2.out - func: _foreach_neg(Tensor[] self) -> Tensor[] @@ -11125,6 +11356,7 @@ dispatch: CPU: foreach_tensor_neg_slow CUDA: foreach_tensor_neg_cuda + PrivateUse1: foreach_tensor_neg_zoom - func: _foreach_neg_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11132,6 +11364,7 @@ dispatch: CPU: foreach_tensor_neg_slow_ CUDA: foreach_tensor_neg_cuda_ + PrivateUse1: foreach_tensor_neg_zoom_ autogen: _foreach_neg.out - func: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2) -> Tensor[] @@ -11140,6 +11373,7 @@ dispatch: CPU: foreach_tensor_norm_slow CUDA: foreach_tensor_norm_cuda + PrivateUse1: foreach_tensor_norm_zoom autogen: _foreach_norm.Scalar_out - func: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[] @@ -11148,6 +11382,7 @@ dispatch: CPU: foreach_tensor_pow_list_kernel_slow CUDA: foreach_tensor_pow_list_kernel_cuda + PrivateUse1: foreach_tensor_pow_list_kernel_zoom - func: _foreach_pow.Scalar(Tensor[] self, Scalar exponent) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11155,6 +11390,7 @@ dispatch: CPU: foreach_tensor_pow_scalar_kernel_slow CUDA: foreach_tensor_pow_scalar_kernel_cuda + PrivateUse1: foreach_tensor_pow_scalar_kernel_zoom - func: _foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11162,6 +11398,7 @@ dispatch: CPU: foreach_tensor_pow_scalarlist_kernel_slow CUDA: foreach_tensor_pow_scalarlist_kernel_cuda + PrivateUse1: foreach_tensor_pow_scalarlist_kernel_zoom - func: _foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[] device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11169,6 +11406,7 @@ dispatch: CPU: foreach_scalar_pow_list_kernel_slow CUDA: foreach_scalar_pow_list_kernel_cuda + PrivateUse1: foreach_scalar_pow_list_kernel_zoom - func: _foreach_pow_.List(Tensor(a!)[] self, Tensor[] exponent) -> () device_check: NoCheck @@ -11176,6 +11414,7 @@ dispatch: CPU: foreach_tensor_pow_list_kernel_slow_ CUDA: foreach_tensor_pow_list_kernel_cuda_ + PrivateUse1: foreach_tensor_pow_list_kernel_zoom_ autogen: _foreach_pow.List_out - func: _foreach_pow_.Scalar(Tensor(a!)[] self, Scalar exponent) -> () @@ -11184,6 +11423,7 @@ dispatch: CPU: foreach_tensor_pow_scalar_kernel_slow_ CUDA: foreach_tensor_pow_scalar_kernel_cuda_ + PrivateUse1: foreach_tensor_pow_scalar_kernel_zoom_ autogen: _foreach_pow.Scalar_out - func: 
_foreach_pow_.ScalarList(Tensor(a!)[] self, Scalar[] exponent) -> () @@ -11192,6 +11432,7 @@ dispatch: CPU: foreach_tensor_pow_scalarlist_kernel_slow_ CUDA: foreach_tensor_pow_scalarlist_kernel_cuda_ + PrivateUse1: foreach_tensor_pow_scalarlist_kernel_zoom_ autogen: _foreach_pow.ScalarList_out - func: _foreach_reciprocal(Tensor[] self) -> Tensor[] @@ -11200,6 +11441,7 @@ dispatch: CPU: foreach_tensor_reciprocal_slow CUDA: foreach_tensor_reciprocal_cuda + PrivateUse1: foreach_tensor_reciprocal_zoom - func: _foreach_reciprocal_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11207,6 +11449,7 @@ dispatch: CPU: foreach_tensor_reciprocal_slow_ CUDA: foreach_tensor_reciprocal_cuda_ + PrivateUse1: foreach_tensor_reciprocal_zoom_ autogen: _foreach_reciprocal.out - func: _foreach_round(Tensor[] self) -> Tensor[] @@ -11215,6 +11458,7 @@ dispatch: CPU: foreach_tensor_round_slow CUDA: foreach_tensor_round_cuda + PrivateUse1: foreach_tensor_round_zoom - func: _foreach_round_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11222,6 +11466,7 @@ dispatch: CPU: foreach_tensor_round_slow_ CUDA: foreach_tensor_round_cuda_ + PrivateUse1: foreach_tensor_round_zoom_ autogen: _foreach_round.out - func: _foreach_sigmoid(Tensor[] self) -> Tensor[] @@ -11230,6 +11475,7 @@ dispatch: CPU: foreach_tensor_sigmoid_slow CUDA: foreach_tensor_sigmoid_cuda + PrivateUse1: foreach_tensor_sigmoid_zoom - func: _foreach_sigmoid_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11237,6 +11483,7 @@ dispatch: CPU: foreach_tensor_sigmoid_slow_ CUDA: foreach_tensor_sigmoid_cuda_ + PrivateUse1: foreach_tensor_sigmoid_zoom_ autogen: _foreach_sigmoid.out - func: _foreach_sign(Tensor[] self) -> Tensor[] @@ -11245,6 +11492,7 @@ dispatch: CPU: foreach_tensor_sign_slow CUDA: foreach_tensor_sign_cuda + PrivateUse1: foreach_tensor_sign_zoom - func: _foreach_sign_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11252,6 +11500,7 @@ dispatch: CPU: foreach_tensor_sign_slow_ CUDA: foreach_tensor_sign_cuda_ + PrivateUse1: foreach_tensor_sign_zoom_ autogen: _foreach_sign.out - func: _foreach_sin(Tensor[] self) -> Tensor[] @@ -11260,6 +11509,7 @@ dispatch: CPU: foreach_tensor_sin_slow CUDA: foreach_tensor_sin_cuda + PrivateUse1: foreach_tensor_sin_zoom - func: _foreach_sin_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11267,6 +11517,7 @@ dispatch: CPU: foreach_tensor_sin_slow_ CUDA: foreach_tensor_sin_cuda_ + PrivateUse1: foreach_tensor_sin_zoom_ autogen: _foreach_sin.out - func: _foreach_sinh(Tensor[] self) -> Tensor[] @@ -11275,6 +11526,7 @@ dispatch: CPU: foreach_tensor_sinh_slow CUDA: foreach_tensor_sinh_cuda + PrivateUse1: foreach_tensor_sinh_zoom - func: _foreach_sinh_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11282,6 +11534,7 @@ dispatch: CPU: foreach_tensor_sinh_slow_ CUDA: foreach_tensor_sinh_cuda_ + PrivateUse1: foreach_tensor_sinh_zoom_ autogen: _foreach_sinh.out - func: _foreach_sqrt(Tensor[] self) -> Tensor[] @@ -11290,6 +11543,7 @@ dispatch: CPU: foreach_tensor_sqrt_slow CUDA: foreach_tensor_sqrt_cuda + PrivateUse1: foreach_tensor_sqrt_zoom - func: 
_foreach_sqrt_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11297,6 +11551,7 @@ dispatch: CPU: foreach_tensor_sqrt_slow_ CUDA: foreach_tensor_sqrt_cuda_ + PrivateUse1: foreach_tensor_sqrt_zoom_ autogen: _foreach_sqrt.out - func: _foreach_tan(Tensor[] self) -> Tensor[] @@ -11305,6 +11560,7 @@ dispatch: CPU: foreach_tensor_tan_slow CUDA: foreach_tensor_tan_cuda + PrivateUse1: foreach_tensor_tan_zoom - func: _foreach_tan_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11312,6 +11568,7 @@ dispatch: CPU: foreach_tensor_tan_slow_ CUDA: foreach_tensor_tan_cuda_ + PrivateUse1: foreach_tensor_tan_zoom_ autogen: _foreach_tan.out - func: _foreach_tanh(Tensor[] self) -> Tensor[] @@ -11320,6 +11577,7 @@ dispatch: CPU: foreach_tensor_tanh_slow CUDA: foreach_tensor_tanh_cuda + PrivateUse1: foreach_tensor_tanh_zoom - func: _foreach_tanh_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11327,6 +11585,7 @@ dispatch: CPU: foreach_tensor_tanh_slow_ CUDA: foreach_tensor_tanh_cuda_ + PrivateUse1: foreach_tensor_tanh_zoom_ autogen: _foreach_tanh.out - func: _foreach_trunc(Tensor[] self) -> Tensor[] @@ -11335,6 +11594,7 @@ dispatch: CPU: foreach_tensor_trunc_slow CUDA: foreach_tensor_trunc_cuda + PrivateUse1: foreach_tensor_trunc_zoom - func: _foreach_trunc_(Tensor(a!)[] self) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices @@ -11342,6 +11602,7 @@ dispatch: CPU: foreach_tensor_trunc_slow_ CUDA: foreach_tensor_trunc_cuda_ + PrivateUse1: foreach_tensor_trunc_zoom_ autogen: _foreach_trunc.out - func: _foreach_zero_(Tensor(a!)[] self) -> () @@ -11350,6 +11611,7 @@ dispatch: CPU: foreach_tensor_zero_slow_ CUDA: foreach_tensor_zero_cuda_ + PrivateUse1: foreach_tensor_zero_zoom_ autogen: _foreach_zero, _foreach_zero.out - func: _foreach_copy_(Tensor(a!)[] self, Tensor[] src, bool non_blocking=False) -> () @@ -11358,6 +11620,7 @@ dispatch: CPU: foreach_tensor_copy_list_kernel_slow_ CUDA: foreach_tensor_copy_list_kernel_cuda_ + PrivateUse1: foreach_tensor_copy_list_kernel_zoom_ autogen: _foreach_copy.out - func: _foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out @@ -11370,18 +11633,21 @@ dispatch: CPU: bucketize_cpu CUDA: bucketize_cuda + PrivateUse1: bucketize_zoom MPS: bucketize_mps - func: bucketize.Tensor_out(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: bucketize_out_cpu CUDA: bucketize_out_cuda + PrivateUse1: bucketize_out_zoom MPS: bucketize_out_mps - func: bucketize.Scalar(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor dispatch: CPU: bucketize_cpu CUDA: bucketize_cuda + PrivateUse1: bucketize_zoom MPS: bucketize_mps autogen: bucketize.Scalar_out @@ -11389,24 +11655,28 @@ dispatch: CPU: searchsorted_cpu CUDA: searchsorted_cuda + PrivateUse1: searchsorted_zoom MPS: searchsorted_mps - func: searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!) 
dispatch: CPU: searchsorted_out_cpu CUDA: searchsorted_out_cuda + PrivateUse1: searchsorted_out_zoom MPS: searchsorted_out_mps - func: searchsorted.Scalar(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor dispatch: CPU: searchsorted_cpu CUDA: searchsorted_cuda + PrivateUse1: searchsorted_zoom MPS: searchsorted_mps - func: searchsorted.Scalar_out(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: searchsorted_out_cpu CUDA: searchsorted_out_cuda + PrivateUse1: searchsorted_out_zoom MPS: searchsorted_out_mps - func: _convert_indices_from_coo_to_csr(Tensor self, int size, *, bool out_int32=False) -> Tensor @@ -11435,7 +11705,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: mse_loss_out + CPU, CUDA, PrivateUse1: mse_loss_out MPS: mse_loss_out_mps - func: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor @@ -11446,13 +11716,13 @@ - func: mse_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: - CPU, CUDA: mse_loss_backward_out + CPU, CUDA, PrivateUse1: mse_loss_backward_out MPS: mse_loss_backward_out_mps - func: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor python_module: nn dispatch: - CPU, CUDA: mse_loss_backward + CPU, CUDA, PrivateUse1: mse_loss_backward MPS: mse_loss_backward_mps - func: l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor @@ -11462,25 +11732,29 @@ python_module: nn dispatch: CPU: multi_margin_loss_cpu_out CUDA: multi_margin_loss_cuda_out + PrivateUse1: multi_margin_loss_zoom_out - func: multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor python_module: nn dispatch: CPU: multi_margin_loss_cpu CUDA: multi_margin_loss_cuda + PrivateUse1: multi_margin_loss_zoom - func: multi_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: CPU: multi_margin_loss_cpu_backward_out CUDA: multi_margin_loss_cuda_backward_out + PrivateUse1: multi_margin_loss_zoom_backward_out - func: multi_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean) -> Tensor python_module: nn dispatch: CPU: multi_margin_loss_cpu_backward CUDA: multi_margin_loss_cuda_backward + PrivateUse1: multi_margin_loss_zoom_backward - func: multilabel_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn @@ -11493,24 +11767,28 @@ dispatch: CPU: multilabel_margin_loss_forward_out_cpu CUDA: multilabel_margin_loss_forward_out_cuda + PrivateUse1: multilabel_margin_loss_forward_out_zoom - func: multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target) python_module: nn dispatch: CPU: multilabel_margin_loss_forward_cpu CUDA: multilabel_margin_loss_forward_cuda + PrivateUse1: multilabel_margin_loss_forward_zoom - func: multilabel_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: CPU: multilabel_margin_loss_backward_cpu_out CUDA: multilabel_margin_loss_backward_cuda_out + PrivateUse1: multilabel_margin_loss_backward_zoom_out - func: multilabel_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target) -> Tensor python_module: nn dispatch: CPU: multilabel_margin_loss_backward_cpu CUDA: multilabel_margin_loss_backward_cuda + PrivateUse1: multilabel_margin_loss_backward_zoom - func: nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -11531,6 +11809,7 @@ dispatch: CPU: nll_loss_forward_out_cpu CUDA: nll_loss_forward_out_cuda + PrivateUse1: nll_loss_forward_out_zoom MPS: nll_loss_forward_out_mps - func: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) @@ -11543,6 +11822,7 @@ dispatch: CPU: nll_loss_backward_out_cpu CUDA: nll_loss_backward_out_cuda + PrivateUse1: nll_loss_backward_out_zoom MPS: nll_loss_backward_out_mps - func: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor @@ -11562,6 +11842,7 @@ dispatch: CPU: nll_loss2d_forward_out_cpu CUDA: nll_loss2d_forward_out_cuda + PrivateUse1: nll_loss2d_forward_out_zoom MPS: nll_loss2d_forward_out_mps - func: nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) @@ -11569,6 +11850,7 @@ dispatch: CPU: nll_loss2d_forward_cpu CUDA: nll_loss2d_forward_cuda + PrivateUse1: nll_loss2d_forward_zoom MPS: nll_loss2d_forward_mps - func: nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) @@ -11576,6 +11858,7 @@ dispatch: CPU: nll_loss2d_backward_out_cpu CUDA: nll_loss2d_backward_out_cuda + PrivateUse1: nll_loss2d_backward_out_zoom MPS: nll_loss2d_backward_out_mps - func: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor @@ -11583,6 +11866,7 @@ dispatch: CPU: nll_loss2d_backward_cpu CUDA: nll_loss2d_backward_cuda + PrivateUse1: nll_loss2d_backward_zoom MPS: nll_loss2d_backward_mps - func: smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!)
@@ -11591,7 +11875,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: smooth_l1_loss_out + CPU, CUDA, PrivateUse1: smooth_l1_loss_out MPS: smooth_l1_loss_out_mps - func: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor @@ -11604,6 +11888,7 @@ dispatch: CPU: smooth_l1_loss_backward_out CUDA: smooth_l1_loss_backward_out + PrivateUse1: smooth_l1_loss_backward_out MPS: smooth_l1_loss_backward_out_mps - func: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor @@ -11614,19 +11899,19 @@ - func: huber_loss.out(Tensor self, Tensor target, int reduction=Mean, float delta=1.0, *, Tensor(a!) out) -> Tensor(a!) python_module: nn dispatch: - CPU, CUDA: huber_loss_out + CPU, CUDA, PrivateUse1: huber_loss_out MPS: huber_loss_out_mps - func: huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor python_module: nn dispatch: - CPU, CUDA: huber_loss + CPU, CUDA, PrivateUse1: huber_loss MPS: huber_loss_mps - func: huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: - CPU, CUDA: huber_loss_backward_out + CPU, CUDA, PrivateUse1: huber_loss_backward_out MPS: huber_loss_backward_out_mps - func: huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor @@ -11660,7 +11945,7 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: elu_out + CPU, CUDA, PrivateUse1: elu_out MPS: elu_out_mps - func: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor @@ -11673,7 +11958,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: elu_backward_out + CPU, CUDA, PrivateUse1: elu_backward_out MPS: elu_backward_out_mps - func: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor @@ -11690,7 +11975,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: glu_out + CPU, CUDA, PrivateUse1: glu_out MPS: glu_out_mps - func: glu(Tensor self, int dim=-1) -> Tensor @@ -11703,6 +11988,7 @@ dispatch: CPU: glu_backward_cpu_out CUDA: glu_backward_cuda_out + PrivateUse1: glu_backward_zoom_out MPS: glu_backward_mps_out - func: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor @@ -11710,18 +11996,19 @@ dispatch: CPU: glu_backward_cpu CUDA: glu_backward_cuda + PrivateUse1: glu_backward_zoom MPS: glu_backward_mps - func: glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor python_module: nn dispatch: - CPU, CUDA: glu_jvp + CPU, CUDA, PrivateUse1: glu_jvp autogen: glu_jvp.out - func: glu_backward_jvp(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim) -> Tensor python_module: nn dispatch: - CPU, CUDA: glu_backward_jvp + CPU, CUDA, PrivateUse1: glu_backward_jvp autogen: glu_backward_jvp.out - func: hardsigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
@@ -11730,7 +12017,7 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: hardsigmoid_out + CPU, CUDA, PrivateUse1: hardsigmoid_out MPS: hardsigmoid_out_mps QuantizedCPU: hardsigmoid_out_quantized_cpu @@ -11751,7 +12038,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: hardsigmoid_backward_out + CPU, CUDA, PrivateUse1: hardsigmoid_backward_out MPS: hardsigmoid_backward_out_mps - func: hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor @@ -11762,61 +12049,61 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA, MPS: hardtanh_out + CPU, CUDA, PrivateUse1, MPS: hardtanh_out QuantizedCPU: hardtanh_out_quantized_cpu - func: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA, MPS: hardtanh + CPU, CUDA, PrivateUse1, MPS: hardtanh QuantizedCPU: hardtanh_quantized_cpu tags: core - func: hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn dispatch: - CPU, CUDA: hardtanh_backward_out + CPU, CUDA, PrivateUse1: hardtanh_backward_out MPS: hardtanh_backward_out_mps - func: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor python_module: nn dispatch: - CPU, CUDA: hardtanh_backward + CPU, CUDA, PrivateUse1: hardtanh_backward MPS: hardtanh_backward_mps - func: hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> Tensor(a!) device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA, MPS: hardtanh_ + CPU, CUDA, PrivateUse1, MPS: hardtanh_ QuantizedCPU: hardtanh_quantized_cpu_ - func: hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: hardswish_out + CPU, CUDA, PrivateUse1: hardswish_out MPS: hardswish_out_mps - func: hardswish(Tensor self) -> Tensor device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: hardswish + CPU, CUDA, PrivateUse1: hardswish MPS: hardswish_mps - func: hardswish_(Tensor(a!) self) -> Tensor(a!) 
device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: hardswish_ + CPU, CUDA, PrivateUse1: hardswish_ MPS: hardswish_mps_ - func: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor python_module: nn dispatch: - CPU, CUDA: hardswish_backward + CPU, CUDA, PrivateUse1: hardswish_backward MPS: hardswish_backward_mps autogen: hardswish_backward.out @@ -11826,7 +12113,7 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: leaky_relu_out + CPU, CUDA, PrivateUse1: leaky_relu_out MPS: leaky_relu_out_mps QuantizedCPU: leaky_relu_out_quantized_cpu @@ -11843,7 +12130,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: leaky_relu_backward_out + CPU, CUDA, PrivateUse1: leaky_relu_backward_out MPS: leaky_relu_backward_out_mps - func: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor @@ -11871,6 +12158,7 @@ dispatch: CPU: log_sigmoid_forward_out_cpu CUDA: log_sigmoid_forward_out_cuda + PrivateUse1: log_sigmoid_forward_out_zoom MPS: log_sigmoid_forward_out_mps - func: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer) @@ -11879,6 +12167,7 @@ dispatch: CPU: log_sigmoid_forward_cpu CUDA: log_sigmoid_forward_cuda + PrivateUse1: log_sigmoid_forward_zoom MPS: log_sigmoid_forward_mps - func: log_sigmoid_backward.grad_input(Tensor grad_output, Tensor self, Tensor buffer, *, Tensor(a!) grad_input) -> Tensor(a!) @@ -11886,6 +12175,7 @@ dispatch: CPU: log_sigmoid_backward_cpu_out CUDA: log_sigmoid_backward_cuda_out + PrivateUse1: log_sigmoid_backward_zoom_out MPS: log_sigmoid_backward_mps_out - func: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor @@ -11893,6 +12183,7 @@ dispatch: CPU: log_sigmoid_backward_cpu CUDA: log_sigmoid_backward_cuda + PrivateUse1: log_sigmoid_backward_zoom MPS: log_sigmoid_backward_mps - func: rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) @@ -11901,12 +12192,14 @@ dispatch: CPU: rrelu_with_noise_out_cpu CUDA: rrelu_with_noise_out_cuda + PrivateUse1: rrelu_with_noise_out_zoom - func: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor python_module: nn dispatch: CPU: rrelu_with_noise_cpu CUDA: rrelu_with_noise_cuda + PrivateUse1: rrelu_with_noise_zoom tags: nondeterministic_seeded - func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor @@ -11921,6 +12214,7 @@ dispatch: CPU: rrelu_with_noise_cpu_ CUDA: rrelu_with_noise_cuda_ + PrivateUse1: rrelu_with_noise_zoom_ - func: softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!) 
structured: True @@ -11928,7 +12222,7 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: softplus_out + CPU, CUDA, PrivateUse1: softplus_out MPS: softplus_out_mps - func: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor @@ -11941,7 +12235,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: softplus_backward_out + CPU, CUDA, PrivateUse1: softplus_backward_out MPS: softplus_backward_out_mps - func: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor @@ -11967,7 +12261,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: softshrink_backward_out + CPU, CUDA, PrivateUse1: softshrink_backward_out MPS: softshrink_backward_out_mps - func: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor @@ -11979,6 +12273,7 @@ dispatch: CPU: adaptive_avg_pool2d_out_cpu CUDA: adaptive_avg_pool2d_out_cuda + PrivateUse1: adaptive_avg_pool2d_out_zoom MPS: adaptive_avg_pool2d_out_mps MkldnnCPU: mkldnn_adaptive_avg_pool2d_out_stub @@ -12004,6 +12299,7 @@ dispatch: CPU: adaptive_avg_pool2d_cpu CUDA: adaptive_avg_pool2d_cuda + PrivateUse1: adaptive_avg_pool2d_zoom MPS: adaptive_avg_pool2d_mps QuantizedCPU: adaptive_avg_pool2d_quantized_cpu QuantizedCUDA: adaptive_avg_pool2d_quantized_cuda @@ -12015,6 +12311,7 @@ dispatch: CPU: adaptive_avg_pool2d_backward_cpu CUDA: adaptive_avg_pool2d_backward_cuda + PrivateUse1: adaptive_avg_pool2d_backward_zoom MPS: adaptive_avg_pool2d_backward_mps autogen: _adaptive_avg_pool2d_backward.out tags: core @@ -12024,6 +12321,7 @@ dispatch: CPU: adaptive_avg_pool3d_out_cpu CUDA: adaptive_avg_pool3d_out_cuda + PrivateUse1: adaptive_avg_pool3d_out_zoom QuantizedCPU: adaptive_avg_pool3d_out_quantized_cpu - func: adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor @@ -12035,6 +12333,7 @@ dispatch: CPU: adaptive_avg_pool3d_cpu CUDA: adaptive_avg_pool3d_cuda + PrivateUse1: adaptive_avg_pool3d_zoom QuantizedCPU: adaptive_avg_pool3d_quantized_cpu autogen: _adaptive_avg_pool3d.out tags: core @@ -12044,12 +12343,14 @@ dispatch: CPU: adaptive_avg_pool3d_backward_out_cpu CUDA: adaptive_avg_pool3d_backward_out_cuda + PrivateUse1: adaptive_avg_pool3d_backward_out_zoom - func: _adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor python_module: nn dispatch: CPU: adaptive_avg_pool3d_backward_cpu CUDA: adaptive_avg_pool3d_backward_cuda + PrivateUse1: adaptive_avg_pool3d_backward_zoom autogen: _adaptive_avg_pool3d_backward.out # Return: (Tensor output, Tensor indices) @@ -12059,6 +12360,7 @@ dispatch: CPU: adaptive_max_pool2d_out_cpu CUDA: adaptive_max_pool2d_out_cuda + PrivateUse1: adaptive_max_pool2d_out_zoom MPS: adaptive_max_pool2d_out_mps # Return: (Tensor output, Tensor indices) @@ -12072,6 +12374,7 @@ dispatch: CPU: adaptive_max_pool2d_backward_out_cpu CUDA: adaptive_max_pool2d_backward_out_cuda + PrivateUse1: adaptive_max_pool2d_backward_out_zoom MPS: adaptive_max_pool2d_backward_out_mps - func: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor @@ -12085,6 +12388,7 @@ dispatch: CPU: adaptive_max_pool3d_out_cpu CUDA: adaptive_max_pool3d_out_cuda + PrivateUse1: adaptive_max_pool3d_out_zoom # Return: (Tensor output, Tensor indices) - func: adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor) @@ -12097,6 +12401,7 @@ dispatch: CPU: adaptive_max_pool3d_backward_out_cpu CUDA: adaptive_max_pool3d_backward_out_cuda + PrivateUse1: 
adaptive_max_pool3d_backward_out_zoom - func: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor python_module: nn @@ -12112,6 +12417,7 @@ dispatch: CPU: avg_pool2d_out_cpu CUDA: avg_pool2d_out_cuda + PrivateUse1: avg_pool2d_out_zoom MPS: avg_pool2d_out_mps MkldnnCPU: mkldnn_avg_pool2d_out @@ -12129,6 +12435,7 @@ dispatch: CPU: avg_pool2d_backward_out_cpu CUDA: avg_pool2d_backward_out_cuda + PrivateUse1: avg_pool2d_backward_out_zoom MPS: avg_pool2d_backward_out_mps MkldnnCPU: mkldnn_avg_pool2d_backward_out @@ -12145,6 +12452,7 @@ dispatch: CPU: avg_pool3d_out_cpu CUDA: avg_pool3d_out_cuda + PrivateUse1: avg_pool3d_out_zoom MkldnnCPU: mkldnn_avg_pool3d_out - func: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor @@ -12161,6 +12469,7 @@ dispatch: CPU: avg_pool3d_backward_out_cpu CUDA: avg_pool3d_backward_out_cuda + PrivateUse1: avg_pool3d_backward_out_zoom MkldnnCPU: mkldnn_avg_pool3d_backward_out - func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor @@ -12176,6 +12485,7 @@ dispatch: CPU: fractional_max_pool2d_out_cpu CUDA: fractional_max_pool2d_out_cuda + PrivateUse1: fractional_max_pool2d_out_zoom # Return: (Tensor output, Tensor indices) - func: fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor) @@ -12188,6 +12498,7 @@ dispatch: CPU: fractional_max_pool2d_backward_cpu CUDA: fractional_max_pool2d_backward_cuda + PrivateUse1: fractional_max_pool2d_backward_zoom - func: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor python_module: nn @@ -12204,6 +12515,7 @@ dispatch: CPU: fractional_max_pool3d_out_cpu CUDA: fractional_max_pool3d_out_cuda + PrivateUse1: fractional_max_pool3d_out_zoom # Return: (Tensor output, Tensor indices) - func: fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor) @@ -12215,12 +12527,14 @@ dispatch: CPU: fractional_max_pool3d_backward_out_cpu CUDA: fractional_max_pool3d_backward_out_cuda + PrivateUse1: fractional_max_pool3d_backward_out_zoom - func: fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor python_module: nn dispatch: CPU: fractional_max_pool3d_backward_cpu CUDA: fractional_max_pool3d_backward_cuda + PrivateUse1: fractional_max_pool3d_backward_zoom # Return: (Tensor output, Tensor indices) - func: max_pool2d_with_indices.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!)) @@ -12229,6 +12543,7 @@ dispatch: CPU: max_pool2d_with_indices_out_cpu CUDA: max_pool2d_with_indices_out_cuda + PrivateUse1: max_pool2d_with_indices_out_zoom MPS: max_pool2d_with_indices_out_mps # Return: (Tensor output, Tensor indices) @@ -12243,6 +12558,7 @@ dispatch: CPU: max_pool2d_with_indices_backward_out_cpu CUDA: max_pool2d_with_indices_backward_out_cuda + PrivateUse1: max_pool2d_with_indices_backward_out_zoom MPS: max_pool2d_with_indices_backward_out_mps - func: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor @@ -12256,6 +12572,7 @@ dispatch: CPU: max_pool3d_with_indices_out_cpu CUDA: max_pool3d_with_indices_out_cuda + PrivateUse1: max_pool3d_with_indices_out_zoom # Return: (Tensor output, Tensor indices) - func: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) @@ -12263,6 +12580,7 @@ dispatch: CPU: max_pool3d_with_indices_cpu CUDA: max_pool3d_with_indices_cuda + PrivateUse1: max_pool3d_with_indices_zoom tags: core - func: max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) @@ -12270,36 +12588,42 @@ dispatch: CPU: max_pool3d_with_indices_backward_out_cpu CUDA: max_pool3d_with_indices_backward_out_cuda + PrivateUse1: max_pool3d_with_indices_backward_out_zoom - func: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor python_module: nn dispatch: CPU: max_pool3d_with_indices_backward_cpu CUDA: max_pool3d_with_indices_backward_cuda + PrivateUse1: max_pool3d_with_indices_backward_zoom - func: max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) python_module: nn dispatch: CPU: max_unpooling2d_forward_out_cpu CUDA: max_unpooling2d_forward_out_cuda + PrivateUse1: max_unpooling2d_forward_out_zoom - func: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor python_module: nn dispatch: CPU: max_unpooling2d_forward_cpu CUDA: max_unpooling2d_forward_cuda + PrivateUse1: max_unpooling2d_forward_zoom - func: max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!) python_module: nn dispatch: CPU: max_unpooling3d_forward_out_cpu CUDA: max_unpooling3d_forward_out_cuda + PrivateUse1: max_unpooling3d_forward_out_zoom - func: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor python_module: nn dispatch: CPU: max_unpooling3d_forward_cpu CUDA: max_unpooling3d_forward_cuda + PrivateUse1: max_unpooling3d_forward_zoom - func: reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) 
python_module: nn @@ -12308,6 +12632,7 @@ CPU: reflection_pad1d_out_cpu QuantizedCPU: reflection_pad1d_out_quantized_cpu CUDA: reflection_pad1d_out_cuda + PrivateUse1: reflection_pad1d_out_zoom MPS: reflection_pad1d_out_mps - func: reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor @@ -12321,6 +12646,7 @@ dispatch: CPU: reflection_pad1d_backward_out_cpu CUDA: reflection_pad1d_backward_out_cuda + PrivateUse1: reflection_pad1d_backward_out_zoom MPS: reflection_pad1d_backward_out_mps - func: reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor @@ -12332,6 +12658,7 @@ dispatch: CPU, QuantizedCPU: reflection_pad2d_out_cpu CUDA: reflection_pad2d_out_cuda + PrivateUse1: reflection_pad2d_out_zoom MPS: reflection_pad2d_out_mps - func: reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor @@ -12340,6 +12667,7 @@ CPU: reflection_pad2d_cpu QuantizedCPU: reflection_pad2d_quantized_cpu CUDA: reflection_pad2d_cuda + PrivateUse1: reflection_pad2d_zoom MPS: reflection_pad2d_mps tags: core @@ -12348,6 +12676,7 @@ dispatch: CPU: reflection_pad2d_backward_out_cpu CUDA: reflection_pad2d_backward_out_cuda + PrivateUse1: reflection_pad2d_backward_out_zoom MPS: reflection_pad2d_backward_out_mps - func: reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor @@ -12355,6 +12684,7 @@ dispatch: CPU: reflection_pad2d_backward_cpu CUDA: reflection_pad2d_backward_cuda + PrivateUse1: reflection_pad2d_backward_zoom MPS: reflection_pad2d_backward_mps - func: reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) @@ -12363,6 +12693,7 @@ dispatch: CPU: reflection_pad3d_out_cpu CUDA: reflection_pad3d_out_cuda + PrivateUse1: reflection_pad3d_out_zoom MPS: reflection_pad3d_out_mps - func: reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor @@ -12376,6 +12707,7 @@ dispatch: CPU: reflection_pad3d_backward_out_cpu CUDA: reflection_pad3d_backward_out_cuda + PrivateUse1: reflection_pad3d_backward_out_zoom MPS: reflection_pad3d_backward_out_mps - func: reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor @@ -12388,6 +12720,7 @@ dispatch: CPU: replication_pad1d_out_cpu CUDA: replication_pad1d_out_cuda + PrivateUse1: replication_pad1d_out_zoom MPS: replication_pad1d_out_mps - func: replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor @@ -12400,6 +12733,7 @@ dispatch: CPU: replication_pad1d_backward_out_cpu CUDA: replication_pad1d_backward_out_cuda + PrivateUse1: replication_pad1d_backward_out_zoom MPS: replication_pad1d_backward_out_mps - func: replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor @@ -12412,6 +12746,7 @@ dispatch: CPU: replication_pad2d_out_cpu CUDA: replication_pad2d_out_cuda + PrivateUse1: replication_pad2d_out_zoom MPS: replication_pad2d_out_mps - func: replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor @@ -12424,6 +12759,7 @@ dispatch: CPU: replication_pad2d_backward_out_cpu CUDA: replication_pad2d_backward_out_cuda + PrivateUse1: replication_pad2d_backward_out_zoom MPS: replication_pad2d_backward_out_mps - func: replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor @@ -12431,6 +12767,7 @@ dispatch: CPU: replication_pad2d_backward_cpu CUDA: replication_pad2d_backward_cuda + PrivateUse1: replication_pad2d_backward_zoom MPS: replication_pad2d_backward_mps - func: replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) 
@@ -12439,6 +12776,7 @@ dispatch: CPU: replication_pad3d_out_cpu CUDA: replication_pad3d_out_cuda + PrivateUse1: replication_pad3d_out_zoom MPS: replication_pad3d_out_mps - func: replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor @@ -12452,6 +12790,7 @@ dispatch: CPU: replication_pad3d_backward_out_cpu CUDA: replication_pad3d_backward_out_cuda + PrivateUse1: replication_pad3d_backward_out_zoom MPS: replication_pad3d_backward_out_mps - func: replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor @@ -12459,6 +12798,7 @@ dispatch: CPU: replication_pad3d_backward_cpu CUDA: replication_pad3d_backward_cuda + PrivateUse1: replication_pad3d_backward_zoom MPS: replication_pad3d_backward_mps - func: _pad_circular(Tensor self, SymInt[] pad) -> Tensor @@ -12533,6 +12873,7 @@ dispatch: CPU: upsample_linear1d_out_cpu CUDA: upsample_linear1d_out_cuda + PrivateUse1: upsample_linear1d_out_zoom MPS: upsample_linear1d_out_mps - func: upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor @@ -12545,6 +12886,7 @@ dispatch: CPU: upsample_linear1d_backward_out_cpu CUDA: upsample_linear1d_backward_out_cuda + PrivateUse1: upsample_linear1d_backward_out_zoom MPS: upsample_linear1d_backward_out_mps - func: upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor @@ -12557,6 +12899,7 @@ dispatch: CPU: upsample_bilinear2d_out_cpu CUDA: upsample_bilinear2d_out_cuda + PrivateUse1: upsample_bilinear2d_out_zoom MPS: upsample_bilinear2d_out_mps - func: upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor @@ -12571,6 +12914,7 @@ dispatch: CPU: upsample_bilinear2d_backward_out_cpu CUDA: upsample_bilinear2d_backward_out_cuda + PrivateUse1: upsample_bilinear2d_backward_out_zoom MPS: upsample_bilinear2d_backward_out_mps - func: upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor @@ -12583,6 +12927,7 @@ dispatch: CPU: _upsample_bilinear2d_aa_out_cpu CUDA: _upsample_bilinear2d_aa_out_cuda + PrivateUse1: _upsample_bilinear2d_aa_out_zoom - func: _upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn @@ -12594,6 +12939,7 @@ dispatch: CPU: _upsample_bilinear2d_aa_backward_out_cpu CUDA: _upsample_bilinear2d_aa_backward_out_cuda + PrivateUse1: _upsample_bilinear2d_aa_backward_out_zoom - func: _upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn @@ -12605,6 +12951,7 @@ dispatch: CPU: upsample_bicubic2d_out_cpu CUDA: upsample_bicubic2d_out_cuda + PrivateUse1: upsample_bicubic2d_out_zoom - func: upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn @@ -12616,6 +12963,7 @@ dispatch: CPU: upsample_bicubic2d_backward_out_cpu CUDA: upsample_bicubic2d_backward_out_cuda + PrivateUse1: upsample_bicubic2d_backward_out_zoom - func: upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? 
scales_w=None) -> Tensor python_module: nn @@ -12627,6 +12975,7 @@ dispatch: CPU: _upsample_bicubic2d_aa_out_cpu CUDA: _upsample_bicubic2d_aa_out_cuda + PrivateUse1: _upsample_bicubic2d_aa_out_zoom - func: _upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn @@ -12638,6 +12987,7 @@ dispatch: CPU: _upsample_bicubic2d_aa_backward_out_cpu CUDA: _upsample_bicubic2d_aa_backward_out_cuda + PrivateUse1: _upsample_bicubic2d_aa_backward_out_zoom - func: _upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn @@ -12649,6 +12999,7 @@ dispatch: CPU: upsample_trilinear3d_out_cpu CUDA: upsample_trilinear3d_out_cuda + PrivateUse1: upsample_trilinear3d_out_zoom - func: upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn @@ -12660,6 +13011,7 @@ dispatch: CPU: upsample_trilinear3d_backward_out_cpu CUDA: upsample_trilinear3d_backward_out_cuda + PrivateUse1: upsample_trilinear3d_backward_out_zoom - func: upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn @@ -12671,6 +13023,7 @@ dispatch: CPU: upsample_nearest1d_out_cpu CUDA: upsample_nearest1d_out_cuda + PrivateUse1: upsample_nearest1d_out_zoom MPS: upsample_nearest1d_out_mps - func: _upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) @@ -12679,6 +13032,7 @@ dispatch: CPU: _upsample_nearest_exact1d_out_cpu CUDA: _upsample_nearest_exact1d_out_cuda + PrivateUse1: _upsample_nearest_exact1d_out_zoom MPS: _upsample_nearest_exact1d_out_mps - func: upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor @@ -12695,6 +13049,7 @@ dispatch: CPU: upsample_nearest1d_backward_out_cpu CUDA: upsample_nearest1d_backward_out_cuda + PrivateUse1: upsample_nearest1d_backward_out_zoom MPS: upsample_nearest1d_backward_out_mps - func: _upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) @@ -12703,6 +13058,7 @@ dispatch: CPU: _upsample_nearest_exact1d_backward_out_cpu CUDA: _upsample_nearest_exact1d_backward_out_cuda + PrivateUse1: _upsample_nearest_exact1d_backward_out_zoom MPS: _upsample_nearest_exact1d_backward_out_mps - func: upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor @@ -12719,6 +13075,7 @@ dispatch: CPU: upsample_nearest2d_out_cpu CUDA: upsample_nearest2d_out_cuda + PrivateUse1: upsample_nearest2d_out_zoom MPS: upsample_nearest2d_out_mps - func: _upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) @@ -12727,6 +13084,7 @@ dispatch: CPU: _upsample_nearest_exact2d_out_cpu CUDA: _upsample_nearest_exact2d_out_cuda + PrivateUse1: _upsample_nearest_exact2d_out_zoom MPS: _upsample_nearest_exact2d_out_mps - func: upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? 
scales_w=None) -> Tensor @@ -12747,6 +13105,7 @@ dispatch: CPU: upsample_nearest2d_backward_out_cpu CUDA: upsample_nearest2d_backward_out_cuda + PrivateUse1: upsample_nearest2d_backward_out_zoom MPS: upsample_nearest2d_backward_out_mps - func: _upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) @@ -12755,6 +13114,7 @@ dispatch: CPU: _upsample_nearest_exact2d_backward_out_cpu CUDA: _upsample_nearest_exact2d_backward_out_cuda + PrivateUse1: _upsample_nearest_exact2d_backward_out_zoom MPS: _upsample_nearest_exact2d_backward_out_mps - func: upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor @@ -12771,6 +13131,7 @@ dispatch: CPU: upsample_nearest3d_out_cpu CUDA: upsample_nearest3d_out_cuda + PrivateUse1: upsample_nearest3d_out_zoom - func: _upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -12778,6 +13139,7 @@ dispatch: CPU: _upsample_nearest_exact3d_out_cpu CUDA: _upsample_nearest_exact3d_out_cuda + PrivateUse1: _upsample_nearest_exact3d_out_zoom - func: upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn @@ -12797,6 +13159,7 @@ dispatch: CPU: upsample_nearest3d_backward_out_cpu CUDA: upsample_nearest3d_backward_out_cuda + PrivateUse1: upsample_nearest3d_backward_out_zoom - func: _upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) python_module: nn @@ -12804,6 +13167,7 @@ dispatch: CPU: _upsample_nearest_exact3d_backward_out_cpu CUDA: _upsample_nearest_exact3d_backward_out_cuda + PrivateUse1: _upsample_nearest_exact3d_backward_out_zoom - func: upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn @@ -12818,7 +13182,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: sigmoid_backward_out + CPU, CUDA, PrivateUse1: sigmoid_backward_out MPS: sigmoid_backward_out_mps tags: pointwise @@ -12832,7 +13196,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: logit_backward_out + CPU, CUDA, PrivateUse1: logit_backward_out MPS: logit_backward_out_mps tags: pointwise @@ -12846,7 +13210,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: tanh_backward_out + CPU, CUDA, PrivateUse1: tanh_backward_out MPS: tanh_backward_out_mps tags: pointwise @@ -12879,6 +13243,7 @@ dispatch: CPU: slow_conv_transpose2d_structured_cpu CUDA: slow_conv_transpose2d_structured_cuda + PrivateUse1: slow_conv_transpose2d_structured_zoom - func: slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? 
bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor python_module: nn @@ -12889,12 +13254,14 @@ dispatch: CPU: slow_conv_transpose3d_out_cpu CUDA: slow_conv_transpose3d_out_cuda + PrivateUse1: slow_conv_transpose3d_out_zoom - func: slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor python_module: nn dispatch: CPU: slow_conv_transpose3d_cpu CUDA: slow_conv_transpose3d_cuda + PrivateUse1: slow_conv_transpose3d_zoom - func: thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!) python_module: nn @@ -12907,41 +13274,48 @@ dispatch: CPU: slow_conv2d_forward_out_cpu CUDA: slow_conv2d_forward_out_cuda + PrivateUse1: slow_conv2d_forward_out_zoom - func: _slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding) -> Tensor python_module: nn dispatch: CPU: slow_conv2d_forward_cpu CUDA: slow_conv2d_forward_cuda + PrivateUse1: slow_conv2d_forward_zoom - func: _slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) python_module: nn dispatch: CPU: slow_conv2d_backward_out_cpu CUDA: slow_conv2d_backward_out_cuda + PrivateUse1: slow_conv2d_backward_out_zoom - func: _slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) python_module: nn dispatch: CPU: slow_conv2d_backward_cpu CUDA: slow_conv2d_backward_cuda + PrivateUse1: slow_conv2d_backward_zoom autogen: _slow_conv2d_backward.output_mask_out - func: _conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!) use_const_ref_for_mutable_tensors: True python_module: nn dispatch: CUDA: conv_depthwise2d_cuda_out + PrivateUse1: conv_depthwise2d_zoom_out - func: _conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor python_module: nn dispatch: CUDA: conv_depthwise2d_cuda + PrivateUse1: conv_depthwise2d_zoom - func: conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor python_module: nn dispatch: CUDA: conv_depthwise3d_cuda + PrivateUse1: conv_depthwise3d_zoom autogen: conv_depthwise3d.out - func: slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!) @@ -12965,6 +13339,7 @@ dispatch: CPU: slow_conv_dilated2d_cpu CUDA: slow_conv_dilated2d_cuda + PrivateUse1: slow_conv_dilated2d_zoom autogen: slow_conv_dilated2d.out - func: slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor?
bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor @@ -12972,6 +13347,7 @@ dispatch: CPU: slow_conv_dilated3d_cpu CUDA: slow_conv_dilated3d_cuda + PrivateUse1: slow_conv_dilated3d_zoom autogen: slow_conv_dilated3d.out - func: col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) @@ -12979,12 +13355,14 @@ dispatch: CPU: col2im_out_cpu CUDA: col2im_out_cuda + PrivateUse1: col2im_out_zoom - func: col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor python_module: nn dispatch: CPU: col2im_cpu CUDA: col2im_cuda + PrivateUse1: col2im_zoom tags: core - func: column_stack(Tensor[] tensors) -> Tensor @@ -12996,12 +13374,14 @@ dispatch: CPU: im2col_out_cpu CUDA: im2col_out_cuda + PrivateUse1: im2col_out_zoom - func: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor python_module: nn dispatch: CPU: im2col_cpu CUDA: im2col_cuda + PrivateUse1: im2col_zoom - func: isfinite(Tensor self) -> Tensor variants: function, method @@ -13024,6 +13404,7 @@ variants: method dispatch: CUDA: record_stream_cuda + PrivateUse1: record_stream_zoom - func: isposinf(Tensor self) -> Tensor variants: function, method @@ -13037,7 +13418,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: isposinf_out + CPU, CUDA, PrivateUse1: isposinf_out SparseCPU, SparseCUDA: isposinf_sparse_out SparseCsrCPU, SparseCsrCUDA: isposinf_sparse_csr_out tags: pointwise @@ -13054,7 +13435,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: isneginf_out + CPU, CUDA, PrivateUse1: isneginf_out SparseCPU, SparseCUDA: isneginf_sparse_out SparseCsrCPU, SparseCsrCUDA: isneginf_sparse_csr_out tags: pointwise @@ -13089,7 +13470,7 @@ python_module: special variants: function dispatch: - CPU, CUDA: special_entr_out + CPU, CUDA, PrivateUse1: special_entr_out tags: pointwise - func: special_ndtri(Tensor self) -> Tensor @@ -13104,7 +13485,7 @@ python_module: special variants: function dispatch: - CPU, CUDA: special_ndtri_out + CPU, CUDA, PrivateUse1: special_ndtri_out tags: pointwise - func: special_log_ndtr(Tensor self) -> Tensor @@ -13119,7 +13500,7 @@ python_module: special variants: function dispatch: - CPU, CUDA: special_log_ndtr_out + CPU, CUDA, PrivateUse1: special_log_ndtr_out tags: pointwise - func: special_expm1(Tensor self) -> Tensor @@ -13188,7 +13569,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: special_erfcx_out + CPU, CUDA, PrivateUse1: special_erfcx_out tags: pointwise - func: special_erfinv(Tensor self) -> Tensor @@ -13236,7 +13617,7 @@ python_module: special variants: function dispatch: - CPU, CUDA: special_xlog1py_out + CPU, CUDA, PrivateUse1: special_xlog1py_out tags: pointwise - func: special_xlog1py.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) @@ -13315,7 +13696,7 @@ python_module: special variants: function dispatch: - CPU, CUDA: special_zeta_out + CPU, CUDA, PrivateUse1: special_zeta_out tags: pointwise - func: special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
@@ -13353,7 +13734,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: special_i0e_out + CPU, CUDA, PrivateUse1: special_i0e_out tags: pointwise - func: special_i1(Tensor self) -> Tensor @@ -13367,7 +13748,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: special_i1_out + CPU, CUDA, PrivateUse1: special_i1_out tags: pointwise - func: special_i1e(Tensor self) -> Tensor @@ -13381,7 +13762,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: special_i1e_out + CPU, CUDA, PrivateUse1: special_i1e_out tags: pointwise - func: special_logit(Tensor self, float? eps=None) -> Tensor @@ -13773,7 +14154,7 @@ python_module: linalg structured: True dispatch: - CPU, CUDA, MPS: linalg_cross_out + CPU, CUDA, PrivateUse1, MPS: linalg_cross_out # linalg.lu_factor - func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots) @@ -14053,7 +14434,7 @@ python_module: linalg structured: True dispatch: - CPU, CUDA: linalg_vector_norm_out + CPU, CUDA, PrivateUse1: linalg_vector_norm_out MPS: linalg_vector_norm_out_mps - func: linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor @@ -14335,13 +14716,13 @@ - func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor variants: function dispatch: - CPU, CUDA: segment_reduce_kernel + CPU, CUDA, PrivateUse1: segment_reduce_kernel autogen: segment_reduce.out - func: _segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None) -> Tensor variants: function dispatch: - CPU, CUDA: _segment_reduce_backward_kernel + CPU, CUDA, PrivateUse1: _segment_reduce_backward_kernel autogen: _segment_reduce_backward.out - func: pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0) -> Tensor @@ -14772,7 +15153,7 @@ - func: special_airy_ai.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: special_airy_ai_out + CPU, CUDA, PrivateUse1: special_airy_ai_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -14787,7 +15168,7 @@ - func: special_bessel_j0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: special_bessel_j0_out + CPU, CUDA, PrivateUse1: special_bessel_j0_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -14802,7 +15183,7 @@ - func: special_bessel_j1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: special_bessel_j1_out + CPU, CUDA, PrivateUse1: special_bessel_j1_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -14817,7 +15198,7 @@ - func: special_bessel_y0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: special_bessel_y0_out + CPU, CUDA, PrivateUse1: special_bessel_y0_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -14832,7 +15213,7 @@ - func: special_bessel_y1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: special_bessel_y1_out + CPU, CUDA, PrivateUse1: special_bessel_y1_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -14865,7 +15246,7 @@ - func: special_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) 
device_check: NoCheck dispatch: - CPU, CUDA: special_chebyshev_polynomial_t_out + CPU, CUDA, PrivateUse1: special_chebyshev_polynomial_t_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -14914,7 +15295,7 @@ - func: special_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_chebyshev_polynomial_u_out + CPU, CUDA, PrivateUse1: special_chebyshev_polynomial_u_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -14963,7 +15344,7 @@ - func: special_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_chebyshev_polynomial_v_out + CPU, CUDA, PrivateUse1: special_chebyshev_polynomial_v_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15012,7 +15393,7 @@ - func: special_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_chebyshev_polynomial_w_out + CPU, CUDA, PrivateUse1: special_chebyshev_polynomial_w_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15061,7 +15442,7 @@ - func: special_hermite_polynomial_h.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_hermite_polynomial_h_out + CPU, CUDA, PrivateUse1: special_hermite_polynomial_h_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15110,7 +15491,7 @@ - func: special_hermite_polynomial_he.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_hermite_polynomial_he_out + CPU, CUDA, PrivateUse1: special_hermite_polynomial_he_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15159,7 +15540,7 @@ - func: special_laguerre_polynomial_l.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_laguerre_polynomial_l_out + CPU, CUDA, PrivateUse1: special_laguerre_polynomial_l_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15208,7 +15589,7 @@ - func: special_legendre_polynomial_p.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_legendre_polynomial_p_out + CPU, CUDA, PrivateUse1: special_legendre_polynomial_p_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15239,7 +15620,7 @@ - func: special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: special_modified_bessel_i0_out + CPU, CUDA, PrivateUse1: special_modified_bessel_i0_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15254,7 +15635,7 @@ - func: special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: special_modified_bessel_i1_out + CPU, CUDA, PrivateUse1: special_modified_bessel_i1_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15269,7 +15650,7 @@ - func: special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
dispatch: - CPU, CUDA: special_modified_bessel_k0_out + CPU, CUDA, PrivateUse1: special_modified_bessel_k0_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15284,7 +15665,7 @@ - func: special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: special_modified_bessel_k1_out + CPU, CUDA, PrivateUse1: special_modified_bessel_k1_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15299,7 +15680,7 @@ - func: special_scaled_modified_bessel_k0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: special_scaled_modified_bessel_k0_out + CPU, CUDA, PrivateUse1: special_scaled_modified_bessel_k0_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15314,7 +15695,7 @@ - func: special_scaled_modified_bessel_k1.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: special_scaled_modified_bessel_k1_out + CPU, CUDA, PrivateUse1: special_scaled_modified_bessel_k1_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15347,7 +15728,7 @@ - func: special_shifted_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_shifted_chebyshev_polynomial_t_out + CPU, CUDA, PrivateUse1: special_shifted_chebyshev_polynomial_t_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15396,7 +15777,7 @@ - func: special_shifted_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_shifted_chebyshev_polynomial_u_out + CPU, CUDA, PrivateUse1: special_shifted_chebyshev_polynomial_u_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15445,7 +15826,7 @@ - func: special_shifted_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_shifted_chebyshev_polynomial_v_out + CPU, CUDA, PrivateUse1: special_shifted_chebyshev_polynomial_v_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15494,7 +15875,7 @@ - func: special_shifted_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_shifted_chebyshev_polynomial_w_out + CPU, CUDA, PrivateUse1: special_shifted_chebyshev_polynomial_w_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15525,7 +15906,7 @@ - func: special_spherical_bessel_j0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: special_spherical_bessel_j0_out + CPU, CUDA, PrivateUse1: special_spherical_bessel_j0_out python_module: special structured_inherits: TensorIteratorBase structured: True @@ -15544,7 +15925,8 @@ variants: function dispatch: CPU: _fused_adam_kernel_cpu_ - CUDA: _fused_adam_kernel_cuda_ + CUDA: _fused_adam_kernel_cuda + PrivateUse1: _fused_adam_kernel_zoom_ autogen: _fused_adam, _fused_adam.out - func: _fused_adam_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> () @@ -15554,7 +15936,8 @@ variants: function dispatch: CPU: _fused_adam_kernel_cpu_ - CUDA: _fused_adam_kernel_cuda_ + CUDA: _fused_adam_kernel_cuda + PrivateUse1: _fused_adam_kernel_zoom_ autogen: _fused_adam.tensor_lr, _fused_adam.tensor_lr_out - func: _fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () @@ -15562,7 +15945,8 @@ variants: function dispatch: CPU: _fused_adamw_kernel_cpu_ - CUDA: _fused_adamw_kernel_cuda_ + CUDA: _fused_adamw_kernel_cuda + PrivateUse1: _fused_adamw_kernel_zoom_ autogen: _fused_adamw, _fused_adamw.out - func: _fused_adamw_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () @@ -15572,7 +15956,8 @@ variants: function dispatch: CPU: _fused_adamw_kernel_cpu_ - CUDA: _fused_adamw_kernel_cuda_ + CUDA: _fused_adamw_kernel_cuda + PrivateUse1: _fused_adamw_kernel_zoom_ autogen: _fused_adamw.tensor_lr, _fused_adamw.tensor_lr_out - func: _fused_sgd_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> () @@ -15580,7 +15965,8 @@ variants: function dispatch: CPU: _fused_sgd_kernel_cpu_ - CUDA: _fused_sgd_kernel_cuda_ + CUDA: _fused_sgd_kernel_cuda + PrivateUse1: _fused_sgd_kernel_zoom_ autogen: _fused_sgd, _fused_sgd.out - func: _fused_sgd_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> () @@ -15590,7 +15976,8 @@ variants: function dispatch: CPU: _fused_sgd_kernel_cpu_ - CUDA: _fused_sgd_kernel_cuda_ + CUDA: _fused_sgd_kernel_cuda + PrivateUse1: _fused_sgd_kernel_zoom_ autogen: _fused_sgd.tensor_lr, _fused_sgd.tensor_lr_out - func: _fused_adagrad_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor(d!)[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> () diff --git a/aten/src/ATen/native/zoom/AbsKernel.cu b/aten/src/ATen/native/zoom/AbsKernel.cu new file mode 100644 index 0000000000000..dd6dc56f646bf --- /dev/null +++ b/aten/src/ATen/native/zoom/AbsKernel.cu @@ -0,0 +1,42 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include + +namespace at::native { + + +CONSTEXPR_EXCEPT_WIN_CUDA constexpr char abs_name[] = "abs_kernel"; +void abs_kernel_zoom(TensorIteratorBase& iter) { + auto dtype = iter.dtype(); + static const auto abs_string = jiterator_stringify( + template T abs_kernel(T x) { return std::abs(x); }); + if (at::isComplexType(dtype)) { + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "abs_zoom", [&]() { + jitted_gpu_kernel< + /*name=*/abs_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, abs_string); + }); + } else { + AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, + ScalarType::BFloat16, + ScalarType::Bool, + iter.dtype(), + "abs_zoom", [&]() { + jitted_gpu_kernel< + /*name=*/abs_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, abs_string); + }); + } +} + +REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel_zoom); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Activation.cpp b/aten/src/ATen/native/zoom/Activation.cpp new file mode 100644 index 0000000000000..039585b1e7160 --- /dev/null +++ b/aten/src/ATen/native/zoom/Activation.cpp @@ -0,0 +1,108 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#endif + +namespace at::native { + +// ----------------------------------- +// glu backward +// ----------------------------------- + +Tensor& glu_backward_zoom_out(const Tensor& grad_output, const Tensor& input, + int64_t dim, Tensor& grad_input) { + TORCH_CHECK(input.dim() > 0, "glu does not support 0-dimensional tensors"); + auto wrap_dim = maybe_wrap_dim(dim, input.dim()); + auto input_sizes = input.sizes(); + const int64_t nIn = input_sizes[wrap_dim]; + TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ", + wrap_dim, " is size ", nIn); + + resize_output(grad_input, input_sizes); + + DimVector iter_shape(input_sizes); + const auto dim_size = nIn / 2; + iter_shape[wrap_dim] = dim_size; + TORCH_CHECK(grad_output.sizes() == IntArrayRef{iter_shape}); + + const auto iter = at::TensorIteratorConfig() + .add_output(grad_input) + .add_const_input(input) + .add_const_input(grad_output) + .resize_outputs(false) + .declare_static_shape(iter_shape) + .build(); + + if (iter.numel() == 0) { + return grad_input; + } + + const auto I_stride = input.strides()[wrap_dim] * dim_size; + const auto gI_stride = grad_input.strides()[wrap_dim] * dim_size; + + if (iter.can_use_32bit_indexing()) { + launch_glu_backward_kernel(iter, gI_stride, I_stride); + } else { + for (const auto& sub_iter: iter.with_32bit_indexing()) { + launch_glu_backward_kernel(sub_iter, gI_stride, I_stride); + } + } + return grad_input; +} + +Tensor glu_backward_zoom(const Tensor& grad_output, const Tensor& input, int64_t dim) { + auto grad_input = at::empty({0}, input.options()); + return glu_backward_zoom_out(grad_output, input, dim, grad_input); +} + +// ----------------------------------- +// log_sigmoid forward +// ----------------------------------- + +std::tuple 
log_sigmoid_forward_out_zoom(const Tensor& input, Tensor& result, Tensor& buffer) { + // NOTE: buffer is only used by CPU dispatch, we just ignore it here + auto iter = TensorIteratorConfig() + .add_output(result) + .add_const_input(input) + .build(); + launch_log_sigmoid_forward_kernel(iter); + return std::forward_as_tuple(result, buffer); +} + +std::tuple log_sigmoid_forward_zoom(const Tensor& input) { + auto result = at::empty_like(input); + auto buffer = at::empty({0}, input.options()); + log_sigmoid_forward_out_zoom(input, result, buffer); + return std::forward_as_tuple(result, buffer); +} + +TORCH_IMPL_FUNC(gelu_out_zoom) ( + const Tensor& /*self*/, c10::string_view approximate, const Tensor& /*result*/ +) { + GeluZoomKernelImpl(*this, get_gelutype_enum(approximate)); +} + +TORCH_IMPL_FUNC(gelu_backward_out_zoom) ( + const Tensor& /*grad*/, const Tensor& /*self*/, c10::string_view approximate, const Tensor& /*grad_input*/ +) { + GeluBackwardZoomKernelImpl(*this, get_gelutype_enum(approximate)); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/Activation.h b/aten/src/ATen/native/zoom/Activation.h new file mode 100644 index 0000000000000..309d316bd5fd7 --- /dev/null +++ b/aten/src/ATen/native/zoom/Activation.h @@ -0,0 +1,20 @@ +#pragma once +#include +#include + +namespace at { +struct TensorIteratorBase; +class TensorBase; +} + +namespace at { namespace native { + +void launch_glu_backward_kernel(const TensorIteratorBase& iter, + int64_t gI_stride, int64_t I_stride); + +void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter); + +void GeluZoomKernelImpl(TensorIteratorBase& it, GeluType approximate); +void GeluBackwardZoomKernelImpl(TensorIteratorBase& it, GeluType approximate); + +}} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationEluKernel.cu b/aten/src/ATen/native/zoom/ActivationEluKernel.cu new file mode 100644 index 0000000000000..e3f296a2a0ed8 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationEluKernel.cu @@ -0,0 +1,86 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void elu_kernel( + TensorIteratorBase& iter, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "elu_zoom", + [&]() { + using opmath_t = at::opmath_type; + auto negcoef = alpha.to() * scale.to(); + auto poscoef = scale.to(); + auto negiptcoef = input_scale.to(); + gpu_kernel( + iter, + [negcoef, poscoef, negiptcoef] GPU_LAMBDA(scalar_t a) -> scalar_t { + opmath_t aop = static_cast(a); + return aop > 0 ? aop * poscoef + : std::expm1(aop * negiptcoef) * negcoef; + }); + }); +} + +void elu_backward_kernel( + TensorIteratorBase& iter, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale, + bool is_result) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "elu_backward_zoom", + [&]() { + using opmath_t = at::opmath_type; + auto negcoef = alpha.to() * scale.to(); + auto poscoef = scale.to(); + auto negiptcoef = input_scale.to(); + gpu_kernel( + iter, + [negcoef, poscoef, negiptcoef, is_result] GPU_LAMBDA( + scalar_t a, scalar_t b) -> scalar_t { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + + if (is_result) { + return bop <= 0 ? 
aop * negiptcoef * (bop + negcoef) + : aop * poscoef; + } else { + return bop <= 0 + ? aop * negiptcoef * negcoef * std::exp(bop * negiptcoef) + : aop * poscoef; + } + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(elu_stub, &elu_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(elu_backward_stub, &elu_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationGeluKernel.cu b/aten/src/ATen/native/zoom/ActivationGeluKernel.cu new file mode 100644 index 0000000000000..7da8acc5b7ab1 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationGeluKernel.cu @@ -0,0 +1,88 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +void GeluZoomKernelImpl(TensorIteratorBase& it, GeluType approximate) { + if (approximate == GeluType::Tanh) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluZoomKernelImpl", [&]() { + gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5); + constexpr opmath_t kKappa = 0.044715; + auto x_cube = static_cast(x) * static_cast(x) * static_cast(x); + auto inner = kBeta * (static_cast(x) + kKappa * x_cube); + return opmath_t(0.5) * static_cast(x) * (opmath_t(1) + c10::hip::compat::tanh(inner)); + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluZoomKernelImpl", [&]() { + gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + constexpr opmath_t kAlpha = M_SQRT1_2; + return static_cast(x) * opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast(x) * kAlpha)); + }); + }); + } +} + +void GeluBackwardZoomKernelImpl(TensorIteratorBase& it, GeluType approximate) { + if (approximate == GeluType::Tanh) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, + it.dtype(), "GeluBackwardZoomKernelImpl", [&]() { + gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5); + constexpr opmath_t kKappa = 0.044715; + auto x_sq = static_cast(x) * static_cast(x); + auto x_cube = x_sq * static_cast(x); + auto inner = kBeta * (static_cast(x) + kKappa * x_cube); + auto tanh_inner = c10::hip::compat::tanh(inner); + + auto left = opmath_t(0.5) * static_cast(x); + auto right = opmath_t(1) + tanh_inner; + + auto left_derivative = opmath_t(0.5) * right; + + auto tanh_derivative = opmath_t(1) - tanh_inner * tanh_inner; + auto inner_derivative = kBeta * (opmath_t(1) + opmath_t(3) * kKappa * x_sq); + auto right_derivative = left * tanh_derivative * inner_derivative; + + return static_cast(dy) * (left_derivative + right_derivative); + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, + it.dtype(), "GeluBackwardZoomKernelImpl", [&]() { + gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + constexpr opmath_t kBeta = M_2_SQRTPI * M_SQRT1_2 * opmath_t(0.5); + constexpr opmath_t kAlpha = M_SQRT1_2; + const opmath_t cdf = + opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast(x) * kAlpha)); + const opmath_t pdf = + c10::hip::compat::exp( + opmath_t(-0.5) * static_cast(x) * static_cast(x)) * + kBeta; + return static_cast(dy) * (cdf + 
static_cast(x) * pdf); + }); + }); + } +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationGluKernel.cu b/aten/src/ATen/native/zoom/ActivationGluKernel.cu new file mode 100644 index 0000000000000..c98794cf016a0 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationGluKernel.cu @@ -0,0 +1,141 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +// ----------------------------------- +// glu forward +// ----------------------------------- +void glu_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.dtype(), "glu_zoom", [&]() { + using opmath_t = at::opmath_type; + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a_, scalar_t b_) -> scalar_t { + const opmath_t a = a_; + const opmath_t b = b_; + const opmath_t one = opmath_t(1); + const opmath_t sigmoid = one / (one + std::exp(-b)); + return a * sigmoid; + }); + }); +} + +// ----------------------------------- +// glu forward ad +// ----------------------------------- +void glu_jvp_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.dtype(), "glu_zoom", [&]() { + using opmath_t = at::opmath_type; + gpu_kernel( + iter, + [] GPU_LAMBDA( + scalar_t res_, scalar_t b_, scalar_t da_, scalar_t db_) + -> scalar_t { + const opmath_t res = res_; + const opmath_t b = b_; + const opmath_t da = da_; + const opmath_t db = db_; + const opmath_t one = opmath_t(1); + + const opmath_t sig_b = one / (one + std::exp(-b)); + return (da * sig_b + res * (db - sig_b * db)); + }); + }); +} + +// ----------------------------------- +// glu backward +// ----------------------------------- + +// Byte offsets don't require multiplication by sizeof(T), so are slightly +// cheaper. For fixed offsets, this removes all penalty from 64-bit indexing. +template +__device__ T* byte_offset(T* ptr, int64_t offset) { + using byte_ptr_t = typename std:: + conditional::value, const char*, char*>::type; + return reinterpret_cast(reinterpret_cast(ptr) + offset); +} + +template +__global__ void glu_backward_kernel( + int numel, + scalar_t* gI, + const scalar_t* I, + const scalar_t* gO, + OffsetCalc offset_calculator, + int64_t gI_byte_offset, + int64_t I_byte_offset) { + using opmath_t = at::opmath_type; + + const uint32_t linear_index = blockIdx.x * blockDim.x + threadIdx.x; + if (linear_index >= numel) { + return; + } + const auto offsets = offset_calculator.get(linear_index); + + // We explicitly iterate over the first half of the input tensor, and + // gI_byte_offset and I_byte_offset are the offsets to access the + // corresponding index in the second half of the tensor. 
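+  // Illustrative note: launch_glu_backward_kernel (below) passes
+  // gI_stride * sizeof(scalar_t) and I_stride * sizeof(scalar_t) for these byte
+  // offsets, so for a contiguous float input split along a last dimension of size
+  // 2*K we have I_stride == K and I_byte_offset == K * sizeof(float);
+  // byte_offset(I + offsets[1], I_byte_offset) then reads the gate element b that
+  // sits K elements after a.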
+ const opmath_t a = I[offsets[1]]; + const opmath_t b = *byte_offset(I + offsets[1], I_byte_offset); + const opmath_t gO_val = gO[offsets[2]]; + + const auto one = opmath_t(1); + const opmath_t sigmoid = one / (one + std::exp(-b)); + + auto* gA = gI + offsets[0]; + *gA = sigmoid * gO_val; + + auto* gB = byte_offset(gA, gI_byte_offset); + *gB = (one - sigmoid) * sigmoid * gO_val * a; +} + +void launch_glu_backward_kernel( + const TensorIteratorBase& iter, + int64_t gI_stride, + int64_t I_stride) { + const auto N = iter.numel(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + N > 0 && N <= std::numeric_limits::max()); + const auto offset_calculator = make_element_offset_calculator<3>(iter); + constexpr int64_t block_size = 256; + const int64_t grid = (N + block_size - 1) / block_size; + const auto stream = c10::zoom::getCurrentZoomStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.common_dtype(), "glu_backward_zoom", [&] { + auto gI = static_cast(iter.data_ptr(0)); + auto I = static_cast(iter.data_ptr(1)); + auto gO = static_cast(iter.data_ptr(2)); + glu_backward_kernel<<>>( + N, + gI, + I, + gO, + offset_calculator, + gI_stride * sizeof(scalar_t), + I_stride * sizeof(scalar_t)); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); +} + +REGISTER_PRIVATEUSE1_DISPATCH(glu_stub, &glu_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(glu_jvp_stub, &glu_jvp_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationHardshrinkKernel.cu b/aten/src/ATen/native/zoom/ActivationHardshrinkKernel.cu new file mode 100644 index 0000000000000..cb581dbc9d661 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationHardshrinkKernel.cu @@ -0,0 +1,39 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& value) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "hardshrink_zoom", + [&]() { + auto lambd = value.to(); + gpu_kernel(iter, [lambd] GPU_LAMBDA(scalar_t a) -> scalar_t { + return (a >= -lambd && a <= lambd) ? 
scalar_t(0) : a; + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(hardshrink_stub, &hardshrink_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationHardsigmoidKernel.cu b/aten/src/ATen/native/zoom/ActivationHardsigmoidKernel.cu new file mode 100644 index 0000000000000..3af90e876b6e8 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationHardsigmoidKernel.cu @@ -0,0 +1,74 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void hardsigmoid_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "hardsigmoid_zoom", + [&]() { + using opmath_t = at::opmath_type; + const opmath_t zero(0.0f); + const opmath_t one_sixth(1.0f / 6.0f); + const opmath_t three(3.0f); + const opmath_t six(6.0f); + gpu_kernel( + iter, + [zero, one_sixth, three, six] GPU_LAMBDA( + scalar_t self_val) -> scalar_t { + opmath_t x = static_cast(self_val); + return std::min(std::max(x + three, zero), six) * one_sixth; + }); + }); +} + +void hardsigmoid_backward_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "hardsigmoid_backward_zoom", + [&]() { + using opmath_t = at::opmath_type; + const opmath_t zero(0.0f); + const opmath_t three(3.0f); + const opmath_t neg_three(-3.0f); + const opmath_t one_sixth(1.0f / 6.0f); + gpu_kernel( + iter, + [zero, three, neg_three, one_sixth] GPU_LAMBDA( + scalar_t grad_val_, scalar_t self_val_) -> scalar_t { + opmath_t grad_val = static_cast(grad_val_); + opmath_t self_val = static_cast(self_val_); + return (self_val > neg_three && self_val < three) + ? 
grad_val * one_sixth + : zero; + }); + }); +} + +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(hardsigmoid_stub, &hardsigmoid_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(hardsigmoid_backward_stub, &hardsigmoid_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationHardswishKernel.cu b/aten/src/ATen/native/zoom/ActivationHardswishKernel.cu new file mode 100644 index 0000000000000..5b4704cbf85ab --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationHardswishKernel.cu @@ -0,0 +1,63 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void hardswish_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_zoom", [&]() { + using opmath_t = at::opmath_type; + const opmath_t zero(0.0f); + const opmath_t one_sixth(1.0f / 6.0f); + const opmath_t three(3.0f); + const opmath_t six(6.0f); + gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t { + opmath_t x = static_cast(self_val); + return x * std::min(std::max(x + three, zero), six) * one_sixth; + }); + }); +} + +void hardswish_backward_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_backward_zoom", [&]() { + using opmath_t = at::opmath_type; + const opmath_t zero(0.0f); + const opmath_t three(3.0f); + const opmath_t neg_three(-3.0f); + const opmath_t one_half(0.5f); + gpu_kernel( + iter, + [zero, three, neg_three, one_half]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t { + opmath_t grad_val = static_cast(grad_val_); + opmath_t self_val = static_cast(self_val_); + if (self_val < neg_three) { + return zero; + } else if (self_val <= three) { + return grad_val * ((self_val / three) + one_half); + } else { + return grad_val; + } + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(hardswish_stub, &hardswish_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(hardswish_backward_stub, &hardswish_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationHardtanhKernel.cu b/aten/src/ATen/native/zoom/ActivationHardtanhKernel.cu new file mode 100644 index 0000000000000..ecd11f23e87fa --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationHardtanhKernel.cu @@ -0,0 +1,45 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void hardtanh_backward_kernel( + TensorIterator& iter, + const Scalar& min, + const Scalar& max) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, + iter.dtype(), "hardtanh_backward_zoom", [&]() { + using opmath_t = at::opmath_type; + auto min_val = min.to(); + auto max_val = max.to(); + gpu_kernel( + iter, + [min_val, max_val] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + return (bop <= min_val) || (bop >= max_val) ? 
opmath_t(0) : aop; + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(hardtanh_backward_stub, &hardtanh_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationLeakyReluKernel.cu b/aten/src/ATen/native/zoom/ActivationLeakyReluKernel.cu new file mode 100644 index 0000000000000..94a9a8168c2b0 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationLeakyReluKernel.cu @@ -0,0 +1,62 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negval_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "leaky_relu_zoom", + [&]() { + using opmath_t = at::opmath_type; + auto negval = negval_.to(); + gpu_kernel(iter, [negval] GPU_LAMBDA(scalar_t a) -> scalar_t { + opmath_t aop = static_cast(a); + return aop > opmath_t(0) ? aop : aop * negval; + }); + }); +} + +void leaky_relu_backward_kernel( + TensorIteratorBase& iter, + const Scalar& negval_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "leaky_relu_backward_zoom", + [&]() { + using opmath_t = at::opmath_type; + auto negval = negval_.to(); + gpu_kernel( + iter, [negval] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + return aop > opmath_t(0) ? bop : bop * negval; + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(leaky_relu_stub, &leaky_relu_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(leaky_relu_backward_stub, &leaky_relu_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationLogSigmoidKernel.cu b/aten/src/ATen/native/zoom/ActivationLogSigmoidKernel.cu new file mode 100644 index 0000000000000..79bad5edc99db --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationLogSigmoidKernel.cu @@ -0,0 +1,64 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +// ----------------------------------- +// log_sigmoid forward +// ----------------------------------- + +void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.common_dtype(), "log_sigmoid_forward_zoom", [&] { + using opmath_t = at::opmath_type; + + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t in_) -> scalar_t { + const opmath_t in = in_; + const auto min = std::min(opmath_t(0), in); + const auto z = std::exp(-std::abs(in)); + return min - std::log1p(z); + }); + }); +} + +namespace { +// ----------------------------------- +// log_sigmoid backward +// ----------------------------------- +void log_sigmoid_backward_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.common_dtype(), "log_sigmoid_backward_zoom", [&] { + using opmath_t = at::opmath_type; + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t in_, scalar_t grad_out_) -> scalar_t { + const opmath_t in = in_; + const opmath_t grad_out = grad_out_; + + auto in_negative = in < opmath_t(0); + auto max_deriv = in_negative ? opmath_t(1) : opmath_t(0); + auto sign = in_negative ? 
opmath_t(1) : -opmath_t(1); + const auto z = std::exp(-std::abs(in)); + return grad_out * (max_deriv - sign * (z / (opmath_t(1) + z))); + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(log_sigmoid_backward_stub, &log_sigmoid_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationMishKernel.cu b/aten/src/ATen/native/zoom/ActivationMishKernel.cu new file mode 100644 index 0000000000000..75d69dd119185 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationMishKernel.cu @@ -0,0 +1,64 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void mish_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "mish_zoom", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + const opmath_t x_acc = static_cast(x); + return x_acc * + c10::hip::compat::tanh( + c10::hip::compat::log1p(c10::hip::compat::exp(x_acc))); + }); + }); +} + +void mish_backward_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "mish_backward_zoom", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + const opmath_t dy_acc = static_cast(dy); + const opmath_t x_acc = static_cast(x); + const opmath_t s_acc = + opmath_t(1) / (opmath_t(1) + c10::hip::compat::exp(-x_acc)); + const opmath_t t_acc = c10::hip::compat::tanh( + c10::hip::compat::log1p(c10::hip::compat::exp(x_acc))); + return dy_acc * + (t_acc + x_acc * s_acc * (opmath_t(1) - t_acc * t_acc)); + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(mish_stub, &mish_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(mish_backward_stub, &mish_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationPreluKernel.cu b/aten/src/ATen/native/zoom/ActivationPreluKernel.cu new file mode 100644 index 0000000000000..512cc7224c5c8 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationPreluKernel.cu @@ -0,0 +1,48 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +// ----------------------------------- +// prelu +// ----------------------------------- +void prelu_kernel(TensorIterator &iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "prelu_zoom", [&] { + gpu_kernel(iter, + [] GPU_LAMBDA (scalar_t input, scalar_t weight) -> scalar_t { + return (input > 0) ? input : weight * input; + }); + }); +} + +void prelu_backward_kernel(TensorIterator &iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "prelu_backward_zoom", [&] { + gpu_kernel_multiple_outputs(iter, + [] GPU_LAMBDA (scalar_t input, scalar_t weight, scalar_t grad) -> thrust::tuple { + auto mask = input > 0; + auto grad_input = mask ? grad : weight * grad; + auto grad_weight = mask ? 
scalar_t{0} : input * grad; + return {grad_input, grad_weight}; + }); + }); +} + +REGISTER_PRIVATEUSE1_DISPATCH(prelu_stub, &prelu_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(prelu_backward_stub, &prelu_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationSiluKernel.cu b/aten/src/ATen/native/zoom/ActivationSiluKernel.cu new file mode 100644 index 0000000000000..04f7d204a3a97 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationSiluKernel.cu @@ -0,0 +1,60 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void silu_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "silu_zoom", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + const opmath_t x_acc = static_cast(x); + return x_acc / (opmath_t(1) + ::exp(-x_acc)); + }); + }); +} + +void silu_backward_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "silu_backward_zoom", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + const opmath_t dy_acc = static_cast(dy); + const opmath_t x_acc = static_cast(x); + const opmath_t s_acc = + opmath_t(1) / (opmath_t(1) + c10::hip::compat::exp(-x_acc)); + return dy_acc * s_acc * (opmath_t(1) + x_acc * (opmath_t(1) - s_acc)); + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(silu_stub, &silu_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(silu_backward_stub, &silu_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationSoftplusKernel.cu b/aten/src/ATen/native/zoom/ActivationSoftplusKernel.cu new file mode 100644 index 0000000000000..ed3358d225af7 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationSoftplusKernel.cu @@ -0,0 +1,74 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void softplus_kernel( + TensorIteratorBase& iter, + const Scalar& beta_, + const Scalar& threshold_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "softplus_zoom", + [&]() { + using opmath_t = at::opmath_type; + auto beta = beta_.to(); + auto threshold = threshold_.to(); + gpu_kernel(iter, [beta, threshold] GPU_LAMBDA(scalar_t a) -> scalar_t { + opmath_t aop = static_cast(a); + return (aop * beta) > threshold + ? aop + : (::log1p(std::exp(aop * beta))) / beta; + }); + }); +} + +void softplus_backward_kernel( + TensorIteratorBase& iter, + const Scalar& beta_, + const Scalar& threshold_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "softplus_backward_zoom", + [&]() { + using opmath_t = at::opmath_type; + auto beta = beta_.to(); + auto threshold = threshold_.to(); + gpu_kernel( + iter, + [beta, threshold] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + opmath_t z = std::exp(bop * beta); + return (bop * beta) > threshold ? 
aop + : aop * z / (z + opmath_t(1.)); + }); + }); +} + +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(softplus_stub, &softplus_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(softplus_backward_stub, &softplus_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationSoftshrinkKernel.cu b/aten/src/ATen/native/zoom/ActivationSoftshrinkKernel.cu new file mode 100644 index 0000000000000..69e27e22b477f --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationSoftshrinkKernel.cu @@ -0,0 +1,58 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +void softshrink_kernel(TensorIteratorBase& iter, const Scalar& value) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "softshrink_zoom", + [&]() { + auto lambd = value.to(); + gpu_kernel(iter, [lambd] GPU_LAMBDA(scalar_t a) -> scalar_t { + return a > lambd ? a - lambd : (a < -lambd ? a + lambd : scalar_t(0)); + }); + }); +} + +void shrink_backward_kernel(TensorIteratorBase& iter, const Scalar& value) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "shrink_backward_zoom", + [&]() { + auto lambd = value.to(); + gpu_kernel( + iter, + [lambd] GPU_LAMBDA( + scalar_t grad_val, scalar_t self_val) -> scalar_t { + return (self_val >= -lambd && self_val <= lambd) ? scalar_t(0) + : grad_val; + }); + }); +} +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(softshrink_stub, &softshrink_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(shrink_backward_stub, &shrink_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ActivationThresholdKernel.cu b/aten/src/ATen/native/zoom/ActivationThresholdKernel.cu new file mode 100644 index 0000000000000..0d6a1c7e15f80 --- /dev/null +++ b/aten/src/ATen/native/zoom/ActivationThresholdKernel.cu @@ -0,0 +1,52 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +namespace { + +template +void threshold_kernel_impl( + TensorIteratorBase& iter, + scalar_t threshold, + scalar_t value) { + gpu_kernel_with_scalars( + iter, [=] GPU_LAMBDA(scalar_t x, scalar_t other) -> scalar_t { + return x <= threshold ? 
value : other; + }); +} + +static void threshold_kernel_zoom( + TensorIteratorBase& iter, + const Scalar& threshold, + const Scalar& value) { + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "threshold_zoom", + [&] { + threshold_kernel_impl( + iter, threshold.to(), value.to()); + }); +} + +} // namespace + +REGISTER_PRIVATEUSE1_DISPATCH(threshold_stub, &threshold_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/AdaptiveAveragePooling.cu b/aten/src/ATen/native/zoom/AdaptiveAveragePooling.cu new file mode 100644 index 0000000000000..a3a8fa6e33c98 --- /dev/null +++ b/aten/src/ATen/native/zoom/AdaptiveAveragePooling.cu @@ -0,0 +1,822 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + +#include + +#include +#include +#include + +#define START_IND(a,b,c) ((int64_t)((a / b) * c + ((a % b) * c) / b)) +#define END_IND(a,b,c) (1 + ((int64_t)(a + 1) * c - 1) / b) + +#define START_IND_INT(a,b,c) ((a * c) / b) +#define END_IND_INT(a,b,c) (((a + 1) * c + b - 1) / b) +// #define START_IND(a,b,c) a * c / b +// #define END_IND(a,b,c) (a + 1) * c / b + ((a + 1) * c % b > 0)?1:0 + +#define HIP_MAX_THREADS 1024 // this is safe, in reality 256 is our limit +#define BLOCK_STRIDE 2 // increasing block_stride to lower # of blocks launched + +namespace at::native { + +namespace { + + // 4d tensor B x D x H x W + // All kernels view batch dim B and feature dim D as collapsed. + + /* + * Description: + * this function adaptively average pools an input 4D tensor along dimensions 2 and 3 + * 4D input, 4D output + */ + template + __global__ void adaptive_average_pool(const scalar_t *input, scalar_t *output, + int isizeH, int isizeW, + int osizeH, int osizeW, + int64_t istrideD, int64_t istrideH, int64_t istrideW) + { + using opmath_t = at::opmath_type; + // iterators on output pixels + int oh, ow; + + // select input/output plane based on thread/block ID + int o_plane = blockIdx.x; + int i_plane = o_plane; + + output = output + o_plane*osizeH*osizeW; + input = input + i_plane*istrideD; + + int ostartH = blockDim.y*blockIdx.y + threadIdx.y; + int oendH = osizeH; + const int ostepH = blockDim.y*gridDim.y; + + int ostartW = threadIdx.x; + int oendW = osizeW; + const int ostepW = blockDim.x; + + // For all output pixels... 
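+    // Worked example: with isizeH = 5 and osizeH = 3, START_IND/END_IND give the
+    // input windows [0,2), [1,4) and [3,5) for oh = 0, 1, 2, so every input row is
+    // covered and neighbouring windows may overlap by one row.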
+ for(oh = ostartH; oh < oendH; oh += ostepH) { + + int istartH = START_IND(oh, osizeH, isizeH); + int iendH = END_IND(oh, osizeH, isizeH); + int kH = iendH - istartH; + + for(ow = ostartW; ow < oendW; ow += ostepW) { + + int istartW = START_IND(ow, osizeW, isizeW); + int iendW = END_IND(ow, osizeW, isizeW); + int kW = iendW - istartW; + + // Compute the average pooling over corresponding input pixels + const scalar_t *ptr_input = input + istartH*istrideH + istartW*istrideW; + scalar_t *ptr_output = output + oh*osizeW + ow; + opmath_t sum = static_cast(0); + int ih, iw; + for(ih = 0; ih < kH; ++ih) { + for(iw = 0; iw < kW; ++iw) { + scalar_t val = ptr_input[iw*istrideW]; + sum += val; + } + ptr_input += istrideH; // next input line + } + // Update output + *ptr_output = sum / kH / kW; + } + } + } + + /* + * Description: + * this function computes the gradInput from gradOutput + */ + template + __global__ void adaptive_average_gradinput( + T *gradInput, const T *gradOutput, + int isizeH, int isizeW, int osizeH, int osizeW + ) + { + // iterators on input pixels + int ih, iw; + + // select input/output plane based on thread/block ID + int i_plane = blockIdx.x; + int o_plane = i_plane; + + gradOutput = gradOutput + o_plane*osizeH*osizeW; + gradInput = gradInput + i_plane*isizeH*isizeW; + + int istartH = blockDim.y*blockIdx.y + threadIdx.y; + int iendH = isizeH; + int istepH = blockDim.y*gridDim.y; + + int istartW = threadIdx.x; + int iendW = isizeW; + int istepW = blockDim.x; + + // compute gradInput + for(ih = istartH; ih < iendH; ih += istepH) { + + int ostartH = START_IND(ih, isizeH, osizeH); + int oendH = END_IND(ih, isizeH, osizeH); + + for(iw = istartW; iw < iendW; iw += istepW) { + + int ostartW = START_IND(iw, isizeW, osizeW); + int oendW = END_IND(iw, isizeW, osizeW); + + // Compute the gradients over corresponding output pixels + T *ptr_gradInput = gradInput + ih*isizeW + iw; + + int oh, ow; + for(oh = ostartH; oh < oendH; ++oh) { + int kH = START_IND(oh, osizeH, isizeH) - END_IND(oh, osizeH, isizeH); + for(ow = ostartW; ow < oendW; ++ow) { + int kW = START_IND(ow, osizeW, isizeW) - END_IND(ow, osizeW, isizeW); + T grad_delta = gradOutput[ow + oh*osizeW] / kH / kW; + *ptr_gradInput += grad_delta; + } + } + } + } + } + + /* + * Description: + * this function computes the gradInput from gradOutput + * (uses atomic add) + */ + template + __global__ void atomic_adaptive_average_gradinput( + T *gradInput, const T *gradOutput, + int isizeH, int isizeW, int osizeH, int osizeW + ) + { + // iterators on output indices + int oh, ow; + + // select input/output plane based on thread/block ID + int o_plane = blockIdx.x; + int i_plane = o_plane; + + gradOutput = gradOutput + o_plane*osizeW*osizeH; + gradInput = gradInput + i_plane*isizeW*isizeH; + + int ostartH = blockDim.y*blockIdx.y + threadIdx.y; + int oendH = osizeH; + int ostepH = blockDim.y*gridDim.y; + + int ostartW = threadIdx.x; + int oendW = osizeW; + int ostepW = blockDim.x; + + // For all output pixels... 
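+    // Note: with the same example (isizeH = 5, osizeH = 3) the windows [0,2),
+    // [1,4) and [3,5) overlap at input rows 1 and 3, so gradients from different
+    // output pixels can target the same gradInput element; this is why the loop
+    // below uses gpuAtomicAddNoReturn instead of a plain store.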
+ for(oh = ostartH; oh < oendH; oh += ostepH) { + + int istartH = START_IND(oh, osizeH, isizeH); + int iendH = END_IND(oh, osizeH, isizeH); + int kH = iendH - istartH; + + for(ow = ostartW; ow < oendW; ow += ostepW) { + + int istartW = START_IND(ow, osizeW, isizeW); + int iendW = END_IND(ow, osizeW, isizeW); + int kW = iendW - istartW; + + // Compute the gradients for over corresponding input pixels + T *ptr_gradInput = gradInput + istartH*isizeW + istartW; + const T *ptr_gradOutput = gradOutput + oh*osizeW + ow; + T grad_delta = *ptr_gradOutput / kW / kH; + + int ih, iw; + for(ih = 0; ih < kH; ++ih) { + for(iw = 0; iw < kW; ++iw) { + // atomic add since different threads could update same variable + gpuAtomicAddNoReturn(&(ptr_gradInput[iw]), grad_delta); + } + ptr_gradInput += isizeW; // next input line + } + } + } + } + + /* + * Description: + * this function adaptively average pools an input 4D tensor along dimensions 2 and 3 + * NHWC layout for both input and output tensor + * 4D input, 4D output + */ + template + C10_LAUNCH_BOUNDS_1(HIP_MAX_THREADS) + __global__ void adaptive_average_pool_nhwc(const scalar_t* __restrict__ input, scalar_t* __restrict__ output, + int sizeB, int sizeC, + int isizeH, int isizeW, + int osizeH, int osizeW, + int kernel_stride_C, int kernel_size_C, + index_t istrideB, index_t istrideC, + index_t istrideH, index_t istrideW) + { + using opmath_t = at::opmath_type; + extern __shared__ int smem[]; + opmath_t *out_cached = reinterpret_cast(smem); + + // flattening cta for pre-computation & smem initialization; + int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); + int block_size = blockDim.x * blockDim.y * blockDim.z; + + // use shared memory to store temporary output value. This is simply to + // reduce register usage. + for (index_t i = thread_id; i < kernel_size_C*blockDim.x*blockDim.y*blockDim.z; i+= block_size) { + out_cached[i] = opmath_t(0.0); + } + + __syncthreads(); + + // each CTA handles a portion of a single slice on batch dimension; + int batch_id = blockIdx.x % sizeB; + int channel_id = blockIdx.x / sizeB; + int channel_offset = threadIdx.x + channel_id * blockDim.x; + + // each CTA handles a single slice on batch dimension; + // We use gridDim.x to handle striding on C as well. + output = output + batch_id * osizeH * osizeW * sizeC; + input = input + batch_id * istrideB; + + // split out_cached and exclusively it assigned to each thread; + out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C * blockDim.x]; + + // iterate on output H & W. + // Each CTA handles a consecutive H & W section (TILE); Do NOT stride CTA on + // tile so there's a better chance to hit L1 cache. + index_t oH = (osizeH + gridDim.z-1) / gridDim.z; + index_t oW = (osizeW + gridDim.y-1) / gridDim.y; + index_t ostartH = threadIdx.z + blockIdx.z*oH; + index_t oendH = ::min(ostartH+oH, osizeH); + index_t ostartW = threadIdx.y + blockIdx.y*oW; + index_t oendW = ::min(ostartW+oW, osizeW); + + // Stride for threads, each warp can reuse L1 as they go. So theoretically + // better chance to survive cache eviction. 
+ for (int oh = ostartH; oh < oendH; oh+=blockDim.z) { + int istartH = START_IND_INT(oh, osizeH, isizeH); + int iendH = END_IND_INT(oh, osizeH, isizeH); + for (int ow = ostartW; ow < oendW; ow+=blockDim.y) { + int istartW = START_IND_INT(ow, osizeW, isizeW); + int iendW = END_IND_INT(ow, osizeW, isizeW); + scalar_t factor = scalar_t(1.0) / ((iendH-istartH) * (iendW-istartW)); + + // loop on input: hierarchy h->w->c, use shared memory here hopefully + // would not stall global memory read; + for (index_t ih = istartH; ih < iendH; ih++) { + for (index_t iw = istartW; iw < iendW; iw++) { + int cached_index = threadIdx.x; + const scalar_t *ptr_input = input + ih*istrideH + iw*istrideW; + for (index_t c = channel_offset; + c < sizeC; + c += blockDim.x*kernel_stride_C) { + out_cached[cached_index] += ptr_input[c*istrideC]; + cached_index += blockDim.x; + } + } + } + scalar_t *ptr_output = output + (oh * osizeW + ow) * sizeC; + + int cached_index = threadIdx.x; + // write accumulated output to global memory; + for (index_t c = channel_offset; + c < sizeC; + c += blockDim.x*kernel_stride_C) { + // This causes numerical issueptr when unit test with NCHW kernel; + // switch to could verify the correctness; + // output[c] = out_cached[c] / (iendH-istartH) / (iendW-istartW); + ptr_output[c] = out_cached[cached_index] * factor; + out_cached[cached_index] = opmath_t(0.0); + cached_index += blockDim.x; + } + // no need to __syncthreads() since out_cached is not shared. + } + } + } + + /* + * Description: + * this function computes the gradInput from gradOutput + * NHWC layout for both input and output tensor + * 4D input, 4D output + */ + template + C10_LAUNCH_BOUNDS_1(HIP_MAX_THREADS) + __global__ void adaptive_average_gradinput_nhwc(scalar_t* __restrict__ gradInput, const scalar_t* __restrict__ gradOutput, + int sizeB, int sizeC, + int isizeH, int isizeW, + int osizeH, int osizeW, + int kernel_stride_C, int kernel_size_C, + index_t ostrideB, index_t ostrideC, + index_t ostrideH, index_t ostrideW) + { + extern __shared__ int smem[]; + index_t *ostartW_cached = smem; + index_t *oendW_cached = &ostartW_cached[isizeW]; + + // be careful with alignment, in case scalar_t is fp16, we want to assign + // int pointers first. + scalar_t *r_kW_cached = reinterpret_cast(&oendW_cached[isizeW]); + scalar_t *r_kH_cached = &r_kW_cached[osizeW]; + scalar_t *out_cached = &r_kH_cached[osizeH]; + + // flattening cta for pre-computation & smem initialization; + int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); + int block_size = blockDim.x * blockDim.y * blockDim.z; + + // Precompute output start/end index per input index on width dimension; + // Not doing this for height dimension, as that's our out-most loop. + for (index_t i = thread_id; i < isizeW; i+= block_size) { + ostartW_cached[i] = START_IND_INT(i, isizeW, osizeW); + oendW_cached[i] = END_IND_INT(i, isizeW, osizeW); + } + + // Precompute pooling height/weight factor for each output element; + // This is used to weight output gradient when accumulate them on input + // gradient. + // Technically we don't have to compute it for the whole `osizeH`, since + // each cta only covers a consecutive portion of the entire output. But it's + // not going to save us from code divergence, and shared memory save is not + // an issue neither, so just leave it as is for now. 
+ for (index_t i = thread_id; i < osizeH; i+= block_size) { + r_kH_cached[i] = scalar_t(1.0) / (END_IND_INT(i, osizeH, isizeH) - START_IND_INT(i, osizeH, isizeH)); + } + for (index_t i = thread_id; i < osizeW; i+= block_size) { + r_kW_cached[i] = scalar_t(1.0) / (END_IND_INT(i, osizeW, isizeW) - START_IND_INT(i, osizeW, isizeW)); + } + + // each CTA handles a portion of a single slice on batch dimension; + int batch_id = blockIdx.x % sizeB; + int channel_id = blockIdx.x / sizeB; + int channel_offset = threadIdx.x + channel_id * blockDim.x; + + // use shared memory to store temporary output value. This is simply to + // reduce register usage. + for (index_t i = thread_id; i < kernel_size_C*blockDim.x*blockDim.y*blockDim.z; i+= block_size) { + out_cached[i] = scalar_t(0.0); + } + + __syncthreads(); + + // each CTA handles a portion of a single slice on batch dimension; + // We use gridDim.x to handle striding on C as well. + gradInput = gradInput + batch_id * isizeH * isizeW * sizeC; + gradOutput = gradOutput + batch_id * ostrideB; + + // split out_cached and exclusively it assigned to each thread; + out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x * kernel_size_C]; + + // iterate on input H & W. + // Each CTA handles a consecutive H & W section (TILE); Do NOT stride CTA on + // tile so there's a better chance to hit L1 cache. + index_t iH = (isizeH + gridDim.z-1) / gridDim.z; + index_t iW = (isizeW + gridDim.y-1) / gridDim.y; + index_t istartH = threadIdx.z + blockIdx.z*iH; + index_t iendH = ::min(istartH+iH, isizeH); + index_t istartW = threadIdx.y + blockIdx.y*iW; + index_t iendW = ::min(istartW+iW, isizeW); + + // Stride for threads, each warp can reuse L1 as they go. So theoretically + // better chance to survive cache eviction. + for (index_t ih = istartH; ih < iendH; ih+=blockDim.z) { + index_t ostartH = START_IND_INT(ih, isizeH, osizeH); + index_t oendH = END_IND_INT(ih, isizeH, osizeH); + for (index_t iw = istartW; iw < iendW; iw+=blockDim.y) { + // loop on output: hierarchy h->w->c, so we could reuse weight factor f + // because it remains the same for given oh & ow + for(index_t oh = ostartH; oh < oendH; ++oh) { + for(index_t ow = ostartW_cached[iw]; ow < oendW_cached[iw]; ++ow) { + scalar_t f = r_kW_cached[ow] * r_kH_cached[oh]; + const scalar_t* ptr_gradOutput = gradOutput + oh*ostrideH + ow*ostrideW; + int cached_index = threadIdx.x; + for (index_t c = channel_offset; + c < sizeC; + c += blockDim.x*kernel_stride_C) { + out_cached[cached_index] += ptr_gradOutput[c*ostrideC] * f; + cached_index += blockDim.x; + } + } + } + scalar_t *ptr_gradInput = gradInput + (ih * isizeW + iw) * sizeC; + int cached_index = threadIdx.x; + // write accumulated gradIput to global memory; + for (index_t c = channel_offset; + c < sizeC; + c += blockDim.x*kernel_stride_C) { + ptr_gradInput[c] = out_cached[cached_index]; + out_cached[cached_index] = scalar_t(0.0); + cached_index += blockDim.x; + } + // no need to __syncthreads() since out_cached is not shared. 
+ } + } + } + + // 4d tensor B x D x H x W + + void adaptive_avg_pool2d_out_zoom_template( + Tensor& output, + const Tensor& input, + IntArrayRef output_size) + { + TensorArg input_arg{ input, "input", 1 }, + output_arg{ output, "output", 2 }; + checkAllSameGPU(__func__, {input_arg, output_arg}); + + TORCH_CHECK(output_size.size() == 2, "adaptive_avg_pool2d: output_size must be 2"); + int64_t ndim = input.dim(); + TORCH_CHECK((ndim == 3 || ndim == 4), + "adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got ", input.sizes()); + for (const auto i : {-2, -1}) { + TORCH_CHECK(input.size(i) > 0, + "adaptive_avg_pool2d(): Expected input to have non-zero size for non-batch dimensions, " + "but input has sizes ", input.sizes(), " with dimension ", i + ndim, " being " + "empty"); + } + + Tensor input_ = input; + switch (input.suggest_memory_format()) { + case at::MemoryFormat::ChannelsLast: { + // special case for tensor memory format in channels_last + TORCH_CHECK(input.ndimension() == 4, + "adaptive_avg_pool2d(): Expected 4D tensor, but got ", + input.sizes()); + + int sizeB = input_.size(0); + int sizeC = input_.size(1); + int isizeH = input_.size(2); + int isizeW = input_.size(3); + + int64_t istrideB = input_.stride(0); + int64_t istrideC = input_.stride(1); + int64_t istrideH = input_.stride(2); + int64_t istrideW = input_.stride(3); + + int osizeH = output_size[0]; + int osizeW = output_size[1]; + // preserve channels_last stride on output tensor; + if (!output.is_contiguous(at::MemoryFormat::ChannelsLast)) { + // TODO: modify this after resize_ added `memory_format` tag + output.resize_({sizeB, sizeC, osizeH, osizeW}).as_strided_({sizeB, sizeC, osizeH, osizeW}, {sizeC*osizeH*osizeW, 1, osizeW*sizeC, sizeC}); + } + + if (output.numel() == 0) { + return; + } + + const int max_threads = std::min( + at::zoom::getCurrentDeviceProperties()->maxThreadsPerBlock, HIP_MAX_THREADS); + int* maxThreadsDim = at::zoom::getCurrentDeviceProperties()->maxThreadsDim; + int* maxGridSize = at::zoom::getCurrentDeviceProperties()->maxGridSize; + size_t sharedMemPerBlock = at::zoom::getCurrentDeviceProperties()->sharedMemPerBlock; + + // Launch kernel on output tensor elements. Logic behind launch config: + // output tensor size NCHW, strides NHWC; + // Launch on: + // N -> grid.x + // H -> grid.z * block.z + // W -> grid.y * block.y + // C -> block.x + // encourage larger block_y & block_z for better cache hit while maintain + // reasonable block_x for coalesced memory access; + int block_x = std::min( + maxThreadsDim[0], std::min(lastPow2(sizeC), at::zoom::warp_size())); + int block_y = std::min( + maxThreadsDim[1], std::min(lastPow2(osizeW), max_threads / block_x)); + int block_z = std::min( + maxThreadsDim[2], std::min(lastPow2(osizeH), max_threads / block_x / block_y)); + block_x = std::min( + maxThreadsDim[0], std::min(lastPow2(sizeC), max_threads / block_y / block_z)); + const dim3 block(block_x, block_y, block_z); + int kernel_stride_C = ceil_div(sizeC, block_x * 4); + int kernel_size_C = ceil_div(sizeC, block_x * kernel_stride_C); + + // Do NOT clip grid_x, striding on Batch dimension is not in the kernel, + // although it could be easily implemented given current kernel. 
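+      // Example: sizeC = 256 with block_x = 64 gives kernel_stride_C = ceil_div(256, 256) = 1
+      // and kernel_size_C = ceil_div(256, 64) = 4, so each thread accumulates four
+      // channels in shared memory and grid_x below stays at sizeB.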
+ int grid_x = sizeB*kernel_stride_C; + // it's OK to clip grid_y & grid_z, as we block the two dimensions in the kernel; + int grid_y = std::min( + maxGridSize[1], ceil_div(osizeW, block_y*BLOCK_STRIDE)); + int grid_z = std::min( + maxGridSize[2], ceil_div(osizeH, block_z*BLOCK_STRIDE)); + const dim3 grid(grid_x, grid_y, grid_z); + + + // we are dealing with packed tensor here. max index is the same as numel. + // TODO: to really support input tensor large enought to go beyond int32, + // we will need to restrict out shared memory usage and adjust the launch + // config; + AT_ASSERT(input_.numel() < std::numeric_limits::max()); + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, + input_.scalar_type(), "adaptive_avg_pool2d_nhwc_zoom", [&] { + using opmath_t = at::opmath_type; + size_t shmem_size = (kernel_size_C * block_x * block_y * block_z) * sizeof(opmath_t); + AT_ASSERT(shmem_size <= sharedMemPerBlock); + adaptive_average_pool_nhwc<<>> ( + input_.const_data_ptr(), + output.mutable_data_ptr(), + sizeB, sizeC, isizeH, isizeW, osizeH, osizeW, + kernel_stride_C, kernel_size_C, + istrideB, istrideC, istrideH, istrideW); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + ); + break; + } + case at::MemoryFormat::Contiguous: { + int64_t grid_x = input.size(-3); + if (input.ndimension() == 4) { + input_ = input.contiguous(); + grid_x *= input_.size(-4); + } + int64_t sizeD = input_.size(-3); + int64_t isizeH = input_.size(-2); + int64_t isizeW = input_.size(-1); + + int64_t istrideD = input_.stride(-3); + int64_t istrideH = input_.stride(-2); + int64_t istrideW = input_.stride(-1); + + int64_t osizeH = output_size[0]; + int64_t osizeW = output_size[1]; + if (input.ndimension() == 4) { + output.resize_({input_.size(-4), sizeD, osizeH, osizeW}); + } else { + output.resize_({sizeD, osizeH, osizeW}); + } + if (output.numel() == 0) { + return; + } + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, + input_.scalar_type(), "adaptive_avg_pool2d_zoom", [&] { + const scalar_t *input_data = input_.const_data_ptr(); + scalar_t *output_data = output.mutable_data_ptr(); + + // cuda blocks & threads: + int blocksH = std::max((int)(16L / sizeD), 1); + dim3 blocks(grid_x, blocksH); + dim3 threads(32, 8); + + // run averagepool kernel + adaptive_average_pool <<>> ( + input_data, output_data, + isizeH, isizeW, osizeH, osizeW, + istrideD, istrideH, istrideW); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + ); + break; + } + default: + TORCH_CHECK( + false, + "Unsupported memory format. 
Supports only ChannelsLast, Contiguous"); + } + } + + void adaptive_avg_pool2d_backward_out_zoom_template( + Tensor& gradInput, + const Tensor& gradOutput_, + const Tensor& input) + { + TensorArg grad_input_arg{ gradInput, "gradInput", 1 }, + grad_output_arg{ gradOutput_, "gradOutput_", 2 }, + input_arg{ input, "input", 3 }; + + adaptive_pool_empty_output_check(gradOutput_, "adaptive_avg_pool2d_backward"); + + checkAllSameGPU(__func__, {grad_input_arg, grad_output_arg, input_arg}); + + switch (input.suggest_memory_format()) { + case at::MemoryFormat::ChannelsLast: { + // special case for tensor memory format in channels_last + TORCH_CHECK(input.ndimension() == 4, + "adaptive_avg_pool2d_backward_zoom(): Expected 4D tensor, but got ", input.ndimension()); + + int sizeB = input.size(0); + int sizeC = input.size(1); + int isizeH = input.size(2); + int isizeW = input.size(3); + + Tensor gradOutput = gradOutput_; + + int64_t ostrideB = gradOutput.stride(0); + int64_t ostrideC = gradOutput.stride(1); + int64_t ostrideH = gradOutput.stride(2); + int64_t ostrideW = gradOutput.stride(3); + + int osizeH = gradOutput.size(-2); + int osizeW = gradOutput.size(-1); + + // preserve channels_last stride on input tensor; + if (!gradInput.is_contiguous(at::MemoryFormat::ChannelsLast)) { + gradInput.as_strided_( + {sizeB, sizeC, isizeH, isizeW}, + {sizeC*isizeH*isizeW, 1, isizeW*sizeC, sizeC}); + } + + int max_threads = std::min( + at::zoom::getCurrentDeviceProperties()->maxThreadsPerBlock, HIP_MAX_THREADS); + int* maxThreadsDim = at::zoom::getCurrentDeviceProperties()->maxThreadsDim; + int* maxGridSize = at::zoom::getCurrentDeviceProperties()->maxGridSize; + size_t sharedMemPerBlock = at::zoom::getCurrentDeviceProperties()->sharedMemPerBlock; + + // Launch kernel on input tensor elements. Logic behind launch config: + // input tensor size NCHW, strides NHWC; + // Launch on: + // N(C) -> grid.x (striding on C to reduce sh_mem usage) + // H -> grid.z * block.z + // W -> grid.y * block.y + // C -> block.x + // encourage larger block_y & block_z for better cache hit while maintain + // reasonable block_x for coalesced memory access; + bool done = false; + do { + int block_x = std::max(std::min( + maxThreadsDim[0], std::min(lastPow2(sizeC), at::zoom::warp_size())), 1); + int block_y = std::max(std::min( + maxThreadsDim[1], std::min(lastPow2(isizeW), max_threads / block_x)), 1); + int block_z = std::max(std::min( + maxThreadsDim[2], std::min(lastPow2(isizeH), max_threads / block_x / block_y)), 1); + block_x = std::max(std::min( + maxThreadsDim[0], std::min(lastPow2(sizeC), max_threads / block_y / block_z)), 1); + const dim3 block(block_x, block_y, block_z); + int kernel_stride_C = ceil_div(sizeC, block_x * 4); + int kernel_size_C = ceil_div(sizeC, block_x * kernel_stride_C); + + // Do NOT clip grid_x, striding on Batch dimension is not in the kernel, + // although it could be easily implemented given current kernel. + int grid_x = sizeB*kernel_stride_C; + // it's OK to clip grid_y & grid_z, as we block the two dimensions in the kernel; + int grid_y = std::min( + maxGridSize[1], ceil_div(isizeW, block_y*BLOCK_STRIDE)); + int grid_z = std::min( + maxGridSize[2], ceil_div(isizeH, block_z*BLOCK_STRIDE)); + const dim3 grid(grid_x, grid_y, grid_z); + + // we are dealing with packed tensor here. max index is the same as numel. 
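+ // Illustrative example (added): a 4 x 512 x 1024 x 1024 fp32 gradient holds
+ // 2^31 = 2,147,483,648 elements, one past the int32 maximum of 2,147,483,647, so the
+ // 32-bit indexing used here would overflow; the assert below guards that case.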
+ // TODO: to really support input tensor large enought to go beyond int32, + // we will need to restrict out shared memory usage and adjust the launch + // config; + AT_ASSERT(input.numel() < std::numeric_limits::max()); + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, + input.scalar_type(), "adaptive_avg_pool2d_backward_nhwc_zoom", [&] { + size_t shmem_size = (kernel_size_C * block_x * block_y * block_z + osizeH + osizeW) * sizeof(scalar_t) + 2 * isizeW * sizeof(int32_t); + if (shmem_size <= sharedMemPerBlock) { + adaptive_average_gradinput_nhwc<<>> ( + gradInput.mutable_data_ptr(), + gradOutput.const_data_ptr(), + sizeB, sizeC, isizeH, isizeW, osizeH, osizeW, + kernel_stride_C, kernel_size_C, + ostrideB, ostrideC, ostrideH, ostrideW); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + done = true; + } else { + TORCH_WARN_ONCE("Requested shmem_size exceeds sharedMemPerBlock limit! Reducing max_threads..."); + max_threads /= 2; + } + } + ); + } while (!done && max_threads); + if (!done) { + TORCH_INTERNAL_ASSERT(false, "Couldn't reduce launch bounds to accomodate sharedMemPerBlock limit"); + } + break; + } + case at::MemoryFormat::Contiguous: { + bool atomic = true; // suboptimal, but without atomic it doesn't pass the tests + + Tensor gradOutput = gradOutput_.contiguous(); + + int64_t sizeD = input.size(-3); + int64_t isizeH = input.size(-2); + int64_t isizeW = input.size(-1); + + int64_t osizeH = gradOutput.size(-2); + int64_t osizeW = gradOutput.size(-1); + + int64_t grid_x = sizeD; + if (input.ndimension() == 4) grid_x *= input.size(-4); + + //bool atomic = (isizeW%osizeW != 0) || (isizeH%osizeH != 0); + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, + input.scalar_type(), "adaptive_avg_pool2d_backward_zoom", [&] { + const scalar_t *gradOutput_data = gradOutput.const_data_ptr(); + scalar_t *gradInput_data = gradInput.mutable_data_ptr(); + + // cuda blocks & threads: + int blocksH = std::max((int)(16L / sizeD), 1); + dim3 blocks(grid_x, blocksH); + dim3 threads(32, 8); + + if(atomic) + { + // run updateGradInput kernel, accumulate gradients atomically + atomic_adaptive_average_gradinput <<>> ( + gradInput_data, gradOutput_data, + isizeH, isizeW, osizeH, osizeW); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + else + { + // run updateGradInput kernel + adaptive_average_gradinput <<>> ( + gradInput_data, gradOutput_data, + isizeH, isizeW, osizeH, osizeW); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + } + ); + break; + } + default: + TORCH_CHECK( + false, + "Unsupported memory format. 
Supports only ChannelsLast, Contiguous"); + + } + } + +} // namespace + + Tensor& adaptive_avg_pool2d_out_zoom( + const Tensor& input, + IntArrayRef output_size, + Tensor& output) + { + adaptive_avg_pool2d_out_zoom_template( + output, input, output_size); + return output; + } + + Tensor adaptive_avg_pool2d_zoom( + at::Tensor const& input, + IntArrayRef output_size) + { + auto output = at::empty({0}, input.options()); + adaptive_avg_pool2d_out_zoom_template( + output, input, output_size); + return output; + } + + Tensor& adaptive_avg_pool2d_backward_out_zoom( + Tensor& gradInput, + const Tensor& gradOutput, + const Tensor& input) + { + // See Note [Writing Nondeterministic Operations] + // Nondeterministic because of atomicAdd usage + globalContext().alertNotDeterministic("adaptive_avg_pool2d_backward_out_zoom"); + gradInput.resize_as_(input); + if (gradInput.numel() != 0) { + adaptive_avg_pool2d_backward_out_zoom_template( + gradInput, gradOutput, input); + } + return gradInput; + } + + Tensor adaptive_avg_pool2d_backward_zoom( + const Tensor& gradOutput, + const Tensor& input) + { + // See Note [Writing Nondeterministic Operations] + // Nondeterministic because of atomicAdd usage + globalContext().alertNotDeterministic("adaptive_avg_pool2d_backward_zoom"); + auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + if (gradInput.numel() != 0) { + adaptive_avg_pool2d_backward_out_zoom_template( + gradInput, gradOutput, input); + } + return gradInput; + } + +} // namespace at::native + +#undef BLOCK_STRIDE +#undef HIP_MAX_THREADS +#undef START_IND +#undef END_IND diff --git a/aten/src/ATen/native/zoom/AdaptiveAveragePooling3d.cu b/aten/src/ATen/native/zoom/AdaptiveAveragePooling3d.cu new file mode 100644 index 0000000000000..2253c9a215144 --- /dev/null +++ b/aten/src/ATen/native/zoom/AdaptiveAveragePooling3d.cu @@ -0,0 +1,545 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + +#include + +#include +#include +#include + + +namespace at::native { + +namespace { + +__device__ inline int64_t start_index(int64_t a, int64_t b, int64_t c) { + return (a / b) * c + ((a % b) * c) / b; +} + +__device__ inline int64_t end_index(int64_t a, int64_t b, int64_t c) { + return 1 + ((a + 1) * c - 1) / b; +} + +// 5d tensor B x D x T x H x W +// All kernels view batch dim B and dim D as collapsed. + +/* + * Description: + * this function adaptively average pools an input 5D tensor along dimensions + * 2, 3, and 4 5D input, 5D output + * + * gridDim.y blocks work together on a single 2D output plane specified by + * (blockIdx.x + offsetZ). + */ +template +__global__ void adaptiveaveragepool( + const scalar_t *input, scalar_t *output, + int isizeT, int isizeH, int isizeW, + int osizeT, int osizeH, int osizeW, + int64_t istrideD, + int64_t istrideT, int64_t istrideH, int64_t istrideW, + int64_t offsetZ) { + // iterates on output pixels + int ot, oh, ow; + + // compute offsets based on thread/block ID + int ostartH = blockIdx.y * blockDim.y + threadIdx.y; + int oendH = osizeH; + int ostepH = gridDim.y * blockDim.y; + int ostartW = threadIdx.x; + int oendW = osizeW; + int ostepW = blockDim.x; + + // select output plane + int64_t o_plane = blockIdx.x + offsetZ; + ot = o_plane % osizeT; // output frame/time + int d = o_plane / osizeT; // slice/feature + + // input frame/time range is fixed. 
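+ // Worked example (illustrative): with isizeT = 10 and osizeT = 4, start_index/end_index
+ // yield the frame ranges [0,3), [2,5), [5,8), [7,10); bins may overlap whenever
+ // isizeT is not a multiple of osizeT.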
+ int istartT = start_index(ot, osizeT, isizeT); + int iendT = end_index(ot, osizeT, isizeT); + int kT = iendT - istartT; + + // input offset by slice/feature and earliest relevant frame/time + const scalar_t *input_dt = input + d*istrideD + istartT*istrideT; + // output offset by slice/feature and frame/time + scalar_t *output_dt = output + o_plane*osizeH*osizeW; + + // For all output pixels... + for (oh = ostartH; oh < oendH; oh += ostepH) { + int istartH = start_index(oh, osizeH, isizeH); + int iendH = end_index(oh, osizeH, isizeH); + int kH = iendH - istartH; + + for (ow = ostartW; ow < oendW; ow += ostepW) { + int istartW = start_index(ow, osizeW, isizeW); + int iendW = end_index(ow, osizeW, isizeW); + int kW = iendW - istartW; + + // Compute the average pooling from corresponding input pixels + const scalar_t *ptr_input = input_dt + istartH*istrideH + istartW*istrideW; + scalar_t *ptr_output = output_dt + oh*osizeW + ow; + accscalar_t sum = static_cast(0); + + int it, ih, iw; + for (it = 0; it < kT; ++it) { + for (ih = 0; ih < kH; ++ih) { + for (iw = 0; iw < kW; ++iw) { + scalar_t val = ptr_input[ih*istrideH + iw*istrideW]; + sum += static_cast(val); + } + } + ptr_input += istrideT; // next input frame + } + // Update output + const accscalar_t divide_factor = static_cast(kT * kH * kW); + *ptr_output = static_cast(sum / divide_factor); + } + } +} + +template +void adaptiveaveragepool_loop( + const scalar_t *input_data, scalar_t *output_data, + int64_t totalZ, + int isizeT, int isizeH, int isizeW, + int osizeT, int osizeH, int osizeW, + int64_t istrideD, int64_t istrideT, int64_t istrideH, int64_t istrideW) { + int64_t offsetZ = 0; + dim3 threads(32, 8); + // each H*W plane is processed by blocksH thread blocks + int blocksH = std::max((int)(16L / totalZ), 1); + while (totalZ > 0) { + dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH); + adaptiveaveragepool + <<>>( + input_data, output_data, + isizeT, isizeH, isizeW, + osizeT, osizeH, osizeW, + istrideD, + istrideT, istrideH, istrideW, + offsetZ); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + totalZ -= 65535; + offsetZ += 65535; + } +} + +/* + * Description: + * This function computes the gradInput from gradOutput. + * + * gridDim.y blocks work together on a single 2D output plane specified by + * (blockIdx.x + offsetZ). + */ +template +__global__ void adaptiveaveragegradinput( + scalar_t *gradInput, const scalar_t *gradOutput, + int isizeT, int isizeH, int isizeW, + int osizeT, int osizeH, int osizeW, + int64_t offsetZ) +{ + // iterators on input pixels + int it, ih, iw; + + // compute offsets based on thread/block ID + int istartH = blockIdx.y * blockDim.y + threadIdx.y; + int iendH = isizeH; + int istepH = gridDim.y * blockDim.y; + int istartW = threadIdx.x; + int iendW = isizeW; + int istepW = blockDim.x; + + // select input plane + int64_t i_plane = blockIdx.x + offsetZ; + it = i_plane % isizeT; // output frame/time + int d = i_plane / isizeT; // slice/feature + + // output frame/time range is fixed. + int ostartT = start_index(it, isizeT, osizeT); + int oendT = end_index(it, isizeT, osizeT); + + // gradInput offset by slice/feature and frame/time. + scalar_t *gradInput_dt = gradInput + i_plane*isizeH*isizeW; + // gradOutput offset by slice/feature and earliest relevant frame/time + const scalar_t *gradOutput_dt = gradOutput + (d*osizeT + ostartT)*osizeH*osizeW; + + // For all input pixels... 
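+ // Note (illustrative): every (ih, iw) pixel of this input plane is visited by exactly
+ // one thread, which accumulates gradOutput[ot][oh][ow] / (kT*kH*kW) over all output
+ // bins containing it, so the plain += below is race-free and no atomics are needed.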
+ for (ih = istartH; ih < iendH; ih += istepH) { + int ostartH = start_index(ih, isizeH, osizeH); + int oendH = end_index(ih, isizeH, osizeH); + + for (iw = istartW; iw < iendW; iw += istepW) { + int ostartW = start_index(iw, isizeW, osizeW); + int oendW = end_index(iw, isizeW, osizeW); + + // Compute the gradients from corresponding output pixels + scalar_t *ptr_gradInput = gradInput_dt + ih*isizeW + iw; + const scalar_t *ptr_gradOutput = gradOutput_dt; + + // for all relevant output pixels + int ot, oh, ow; + for (ot = ostartT; ot < oendT; ++ot) { + int kT = end_index(ot, osizeT, isizeT) - start_index(ot, osizeT, isizeT); + for (oh = ostartH; oh < oendH; ++oh) { + int kH = end_index(oh, osizeH, isizeH) - start_index(oh, osizeH, isizeH); + for (ow = ostartW; ow < oendW; ++ow) { + int kW = end_index(ow, osizeW, isizeW) - start_index(ow, osizeW, isizeW); + const accscalar_t divide_factor = kW * kH * kT; + accscalar_t grad_delta = static_cast(ptr_gradOutput[oh*osizeW + ow] / divide_factor); + *ptr_gradInput += static_cast(grad_delta); + } + } + ptr_gradOutput += osizeH*osizeW; // next output frame + } + } + } +} + +template +void adaptiveaveragegradinput_loop( + scalar_t *gradInput_data, const scalar_t *gradOutput_data, + int64_t totalZ, + int isizeT, int isizeH, int isizeW, + int osizeT, int osizeH, int osizeW) { + int64_t offsetZ = 0; + dim3 threads(32, 8); + // each H*W plane is processed by blocksH thread blocks + int blocksH = std::max((int)(16L / totalZ), 1); + while (totalZ > 0) { + dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH); + adaptiveaveragegradinput + <<>>( + gradInput_data, gradOutput_data, + isizeT, isizeH, isizeW, + osizeT, osizeH, osizeW, + offsetZ); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + totalZ -= 65535; + offsetZ += 65535; + } +} + +/* + * Description: + * This function computes the gradInput from gradOutput. + * + * gridDim.y blocks work together on a single 2D output plane specified by + * (blockIdx.x + offsetZ). + * + * (uses atomic add) + * + */ +template +__global__ void atomicadaptiveaveragegradinput( + scalar_t *gradInput, const scalar_t *gradOutput, + int isizeT, int isizeH, int isizeW, + int osizeT, int osizeH, int osizeW, + int64_t offsetZ) +{ + // iterators on output pixels + int ot, oh, ow; + + // compute offsets based on thread/block ID + int ostartH = blockIdx.y * blockDim.y + threadIdx.y; + int oendH = osizeH; + int ostepH = gridDim.y * blockDim.y; + int ostartW = threadIdx.x; + int oendW = osizeW; + int ostepW = blockDim.x; + + // select output plane + int64_t o_plane = blockIdx.x + offsetZ; + ot = o_plane % osizeT; // output frame/time + int d = o_plane / osizeT; // output slice/feature + + // input frame/time range is fixed. + int istartT = start_index(ot, osizeT, isizeT); + int iendT = end_index(ot, osizeT, isizeT); + int kT = iendT - istartT; + + // gradInput offset by slice/feature and earliest relevant frame/time + scalar_t *gradInput_nt = gradInput + (d*isizeT + istartT)*isizeH*isizeW; + // gradOutput offset by slice/feature and frame/time + const scalar_t *gradOutput_nt = gradOutput + o_plane*osizeH*osizeW; + + // For all output pixels... 
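+ // Note (illustrative): here threads iterate over output pixels and scatter
+ // gradOutput / (kT*kH*kW) into every input pixel of the corresponding bin. Because
+ // bins can overlap (see the [0,3), [2,5), ... example above), two threads may hit the
+ // same gradInput element, hence gpuAtomicAddNoReturn below.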
+ for (oh = ostartH; oh < oendH; oh += ostepH) { + int istartH = start_index(oh, osizeH, isizeH); + int iendH = end_index(oh, osizeH, isizeH); + int kH = iendH - istartH; + + for (ow = ostartW; ow < oendW; ow += ostepW) { + int istartW = start_index(ow, osizeW, isizeW); + int iendW = end_index(ow, osizeW, isizeW); + int kW = iendW - istartW; + + // Compute the gradients from corresponding input pixels + scalar_t *ptr_gradInput = gradInput_nt + istartH*isizeW + istartW; + const scalar_t *ptr_gradOutput = gradOutput_nt + oh*osizeW + ow; + scalar_t grad_delta = *ptr_gradOutput / kT / kH / kW; + + int it, ih, iw; + for (it = 0; it < kT; ++it) { + for (ih = 0; ih < kH; ++ih) { + for (iw = 0; iw < kW; ++iw) { + gpuAtomicAddNoReturn(&(ptr_gradInput[ih*isizeW + iw]), grad_delta); + } + } + ptr_gradInput += isizeH*isizeW; // next input frame + } + } + } +} + +template +void atomicadaptiveaveragegradinput_loop( + scalar_t* gradInput_data, const scalar_t* gradOutput_data, + int64_t totalZ, + int isizeT, int isizeH, int isizeW, + int osizeT, int osizeH, int osizeW) { + int64_t offsetZ = 0; + dim3 threads(32, 8); + int blocksH = std::max((int)(16L / totalZ), 1); + while (totalZ > 0) { + dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH); + atomicadaptiveaveragegradinput<<>>( + gradInput_data, gradOutput_data, + isizeT, isizeH, isizeW, + osizeT, osizeH, osizeW, + offsetZ); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + totalZ -= 65535; + offsetZ += 65535; + } +} + +// 5D tensor B x D x T x H x w + +void adaptive_avg_pool3d_out_zoom_template( + Tensor& output, + const Tensor& input_, + IntArrayRef& output_size) { + TensorArg output_arg{output, "output", 1}; + TensorArg input_arg{input_, "input_", 2}; + + checkAllSameGPU("adaptive_avg_pool3d_zoom", {output_arg, input_arg}); + + for (int64_t i = 1; i < input_.ndimension(); i++) { + TORCH_CHECK( + input_.size(i) > 0, + "adaptive_avg_pool3d_zoom(): Expected input to have non-zero size for non-batch dimensions, " + "but input has sizes ", input_.sizes(), + " with dimension ", i, " being empty"); + } + + TORCH_CHECK( + (input_.ndimension() == 4 || input_.ndimension() == 5), + "adaptive_avg_pool3d_zoom(): Expected 4D or 5D tensor, but got ", input_.sizes()); + + // the jit sometimes passes output_size.size() == 1 + TORCH_CHECK( + output_size.size() == 1 || output_size.size() == 3, + "adaptive_avg_pool3d: internal error: output_size.size() must be 1 or 3"); + + int64_t osizeT = output_size[0]; + int64_t osizeH = output_size[1]; + int64_t osizeW = output_size[2]; + + int64_t sizeD, isizeT, isizeH, isizeW; + int64_t istrideD, istrideT, istrideH, istrideW; + int64_t totalZ; + + const Tensor& input = input_.ndimension() == 4 ? 
input_ : input_.contiguous(); + + if (input.ndimension() == 4) { + sizeD = input.size(0); + isizeT = input.size(1); + isizeH = input.size(2); + isizeW = input.size(3); + + istrideD = input.stride(0); + istrideT = input.stride(1); + istrideH = input.stride(2); + istrideW = input.stride(3); + + output.resize_({sizeD, osizeT, osizeH, osizeW}); + + totalZ = sizeD * osizeT; + } else { + int64_t sizeB = input.size(0); + sizeD = input.size(1); + isizeT = input.size(2); + isizeH = input.size(3); + isizeW = input.size(4); + + istrideD = input.stride(1); + istrideT = input.stride(2); + istrideH = input.stride(3); + istrideW = input.stride(4); + + output.resize_({sizeB, sizeD, osizeT, osizeH, osizeW}); + + totalZ = sizeB * sizeD * osizeT; + } + + if (output.numel() == 0) { + return; + } + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, + input.scalar_type(), "adaptive_avg_pool3d_zoom", [&] { + using accscalar_t = at::acc_type; + const scalar_t* input_data = input.const_data_ptr(); + scalar_t* output_data = output.mutable_data_ptr(); + + adaptiveaveragepool_loop( + input_data, output_data, + totalZ, + isizeT, isizeH, isizeW, + osizeT, osizeH, osizeW, + istrideD, istrideT, istrideH, istrideW); + }); +} + +void adaptive_avg_pool3d_backward_out_zoom_template( + Tensor& gradInput, + const Tensor& gradOutput_, + const Tensor& input) { + TensorArg grad_input_arg{gradInput, "gradInput", 1}; + TensorArg grad_output_arg{gradOutput_, "gradOutput_", 2}; + TensorArg input_arg{input, "input", 3}; + + adaptive_pool_empty_output_check(gradOutput_, "adaptive_avg_pool3d_backward"); + + checkAllSameGPU( + "adaptive_avg_pool3d_out_zoom", + {grad_input_arg, grad_output_arg, input_arg}); + + const Tensor gradOutput = gradOutput_.contiguous(); + + gradInput.resize_as_(input); + if (gradInput.numel() == 0) { + return; + } + + gradInput.zero_(); + + int64_t sizeD, isizeT, isizeH, isizeW; + int64_t osizeT, osizeH, osizeW; + int64_t totalZ; + + if (input.ndimension() == 4) { + sizeD = input.size(0); + isizeT = input.size(1); + isizeH = input.size(2); + isizeW = input.size(3); + + osizeT = gradOutput.size(1); + osizeH = gradOutput.size(2); + osizeW = gradOutput.size(3); + } else { + sizeD = input.size(1); + isizeT = input.size(2); + isizeH = input.size(3); + isizeW = input.size(4); + + osizeT = gradOutput.size(2); + osizeH = gradOutput.size(3); + osizeW = gradOutput.size(4); + } + + bool atomic = (isizeW%osizeW != 0) || (isizeH%osizeH != 0) || (isizeT%osizeT != 0); + + if (input.ndimension() == 4) { + totalZ = atomic ? sizeD * osizeT : sizeD * isizeT; + } else { + int sizeB = input.size(0); + totalZ = atomic ? 
sizeB * sizeD * osizeT : sizeB * sizeD * isizeT; + } + + if (atomic) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, + input.scalar_type(), "adaptive_avg_pool3d_backward_zoom", [&] { + scalar_t* gradInput_data = gradInput.mutable_data_ptr(); + const scalar_t* gradOutput_data = gradOutput.const_data_ptr(); + + atomicadaptiveaveragegradinput_loop( + gradInput_data, gradOutput_data, + totalZ, + isizeT, isizeH, isizeW, + osizeT, osizeH, osizeW); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, + input.scalar_type(), "adaptive_avg_pool3d_backward_zoom", [&] { + using accscalar_t = at::acc_type; + + scalar_t* gradInput_data = gradInput.mutable_data_ptr(); + const scalar_t* gradOutput_data = gradOutput.const_data_ptr(); + + adaptiveaveragegradinput_loop( + gradInput_data, gradOutput_data, + totalZ, + isizeT, isizeH, isizeW, + osizeT, osizeH, osizeW); + }); + } +} + +} // namespace + +Tensor& adaptive_avg_pool3d_out_zoom(const Tensor& input, + IntArrayRef output_size, + Tensor& output) { + adaptive_avg_pool3d_out_zoom_template(output, input, output_size); + return output; +} + +Tensor adaptive_avg_pool3d_zoom( + const Tensor& input, + IntArrayRef output_size) { + auto output = at::empty({0}, input.options()); + adaptive_avg_pool3d_out_zoom_template(output, input, output_size); + return output; +} + +Tensor& adaptive_avg_pool3d_backward_out_zoom(const Tensor& gradOutput_, + const Tensor& input, + Tensor& gradInput) { + // See Note [Writing Nondeterministic Operations] + // Nondeterministic because of atomicAdd usage + globalContext().alertNotDeterministic("adaptive_avg_pool3d_backward_out_zoom"); + adaptive_avg_pool3d_backward_out_zoom_template(gradInput, gradOutput_, input); + return gradInput; +} + +Tensor adaptive_avg_pool3d_backward_zoom( + const Tensor& gradOutput_, + const Tensor& input) { + // See Note [Writing Nondeterministic Operations] + // Nondeterministic because of atomicAdd usage + globalContext().alertNotDeterministic("adaptive_avg_pool3d_backward_zoom"); + auto gradInput = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + adaptive_avg_pool3d_backward_out_zoom_template(gradInput, gradOutput_, input); + return gradInput; +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/AdaptiveMaxPooling2d.cu b/aten/src/ATen/native/zoom/AdaptiveMaxPooling2d.cu new file mode 100644 index 0000000000000..737c63f1f3083 --- /dev/null +++ b/aten/src/ATen/native/zoom/AdaptiveMaxPooling2d.cu @@ -0,0 +1,478 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + +#include +#include +#include + + +namespace at::native { + +namespace { + +__device__ inline int64_t start_index(int64_t a, int64_t b, int64_t c) { + return (a / b) * c + ((a % b) * c) / b; +} + +__device__ inline int64_t end_index(int64_t a, int64_t b, int64_t c) { + return 1 + ((a + 1) * c - 1) / b; +} + +// 4d tensor B x D x H x W + +/* + * Description: + * this function adaptively maxpools an input 4D tensor along dimensions 2 and 3 + * 4D input, 4D output, 4D argmax x and y + */ + template +__global__ void adaptivemaxpool(const T *input, T *output, int64_t *indices, + int isizeH, int isizeW, + int osizeH, int osizeW, + int64_t istrideD, int64_t istrideH, int64_t istrideW) +{ + // iterators + int oh, ow; + + // compute offsets based on thread/block ID + int o_plane = blockIdx.x; + int i_plane = o_plane; + + int 
ostartW = threadIdx.x; + int oendW = osizeW; + const int ostepW = blockDim.x; + + int ostartH = blockDim.y*blockIdx.y + threadIdx.y; + int oendH = osizeH; + const int ostepH = blockDim.y*gridDim.y; + // select input/output plane + output = output + o_plane*osizeH*osizeW; + input = input + i_plane*istrideD; + indices = indices + o_plane*osizeH*osizeW; + + // For all output pixels... + for(oh = ostartH; oh < oendH; oh += ostepH) { + + int istartH = start_index(oh, osizeH, isizeH); + int iendH = end_index(oh, osizeH, isizeH); + int kH = iendH - istartH; + + for(ow = ostartW; ow < oendW; ow += ostepW) { + int istartW = start_index(ow, osizeW, isizeW); + int iendW = end_index(ow, osizeW, isizeW); + + int kW = iendW - istartW; + + // Compute the mean of the input image... + const T *ptr_input = input + istartH*istrideH + istartW*istrideW; + T *ptr_output = output + oh*osizeW + ow; + int64_t *ptr_ind = indices + oh*osizeW + ow; + int argmax = istartH * isizeW + istartW; + T max = at::numeric_limits::lower_bound(); // -Infinity + int ih, iw; + for(ih = 0; ih < kH; ih++) { + for(iw = 0; iw < kW; iw++) { + T val = ptr_input[iw*istrideW]; + if ((val > max) || at::_isnan(val)) { + max = val; + argmax = (ih+istartH)*isizeW + iw+istartW; + } + } + ptr_input += istrideH; // next input line + } + // Update output and argmax + *ptr_output = max; + *ptr_ind = argmax; + } + } +} + +/* + * Description: + * this function computes the gradInput from weight and gradOutput + */ + template +__global__ void adaptivemaxgradinput(T *gradInput, const T *gradOutput, const int64_t *indices, + int isizeH, int isizeW, + int osizeH, int osizeW) +{ + // iterators + int oh, ow; + + // compute offsets based on thread/block ID + int o_plane = blockIdx.x; + int i_plane = o_plane; + //int k = blockIdx.x % sizeD; + + int ostartW = threadIdx.x; + int oendW = osizeW; + int ostepW = blockDim.x; + + int ostartH = blockDim.y*blockIdx.y + threadIdx.y; + int oendH = osizeH; + int ostepH = blockDim.y*gridDim.y; + + // select input/output plane + gradOutput = gradOutput + o_plane*osizeH*osizeW; + gradInput = gradInput + i_plane*isizeH*isizeW; + indices = indices + o_plane*osizeH*osizeW; + + // compute gradInput + for(oh = ostartH; oh < oendH; oh += ostepH) { + + for(ow = ostartW; ow < oendW; ow += ostepW) { + + const T *ptr_gradOutput = gradOutput + oh*osizeW + ow; + const int64_t *ptr_ind = indices + oh*osizeW + ow; + T z = *ptr_gradOutput; + + int argmax = (*ptr_ind); + + gradInput[argmax] += z; + } + } +} + +/* + * Description: + * this function computes the gradInput from weight and gradOutput + * when kH != dH or kW != dW (uses atomic add) + */ + template +__global__ void atomicadaptivemaxgradinput( + T *gradInput, const T *gradOutput, const int64_t *indices, + int isizeH, int isizeW, int osizeH, int osizeW +) +{ + // iterators + int oh, ow; + + // compute offsets based on thread/block ID + int o_plane = blockIdx.x; + int i_plane = o_plane; + + int ostartW = threadIdx.x; + int oendW = osizeW; + int ostepW = blockDim.x; + + int ostartH = blockDim.y*blockIdx.y + threadIdx.y; + int oendH = osizeH; + int ostepH = blockDim.y*gridDim.y; + + // select input/output plane + gradOutput = gradOutput + o_plane*osizeH*osizeW; + gradInput = gradInput + i_plane*isizeH*isizeW; + indices = indices + o_plane*osizeH*osizeW; + + // compute gradInput + for(oh = ostartH; oh < oendH; oh += ostepH) { + + for(ow = ostartW; ow < oendW; ow += ostepW) { + + const T *ptr_gradOutput = gradOutput + oh*osizeW + ow; + const int64_t *ptr_ind = indices + oh*osizeW + 
ow; + T z = *ptr_gradOutput; + + int argmax = (*ptr_ind); + + // atomic add since different threads could update same variable + gpuAtomicAddNoReturn(&(gradInput[argmax]), z); + } + } +} +} // namespace + +// 4d tensor B x D x H x W + +TORCH_IMPL_FUNC(adaptive_max_pool2d_out_zoom) +(const Tensor& input, +IntArrayRef output_size, +const Tensor& output, +const Tensor& indices) { + TensorArg output_arg{output, "output", 1}; + TensorArg indices_arg{indices, "indices", 2}; + TensorArg input_arg{input, "input", 3}; + + checkAllSameGPU( + __func__, {output_arg, indices_arg, input_arg}); + if (input.numel() == 0) { + return; + } + + int64_t osizeH = output_size[0]; + int64_t osizeW = output_size[1]; + + const at::Tensor output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options()); + const at::Tensor indices_c = indices.is_contiguous() ? indices : at::empty(indices.sizes(), indices.options()); + + if (input.ndimension() == 3) { + int64_t sizeD = input.size(0); + int64_t isizeH = input.size(1); + int64_t isizeW = input.size(2); + + int64_t istrideD = input.stride(0); + int64_t istrideH = input.stride(1); + int64_t istrideW = input.stride(2); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, input.scalar_type(), "adaptive_max_pool2d_zoom", [&] { + const scalar_t* input_data = input.const_data_ptr(); + scalar_t* output_data = output_c.mutable_data_ptr(); + int64_t* indices_data = indices_c.mutable_data_ptr(); + + // cuda blocks & threads: + int blocksH = (int)(16L / sizeD); + blocksH = blocksH < 1 ? 1 : blocksH; + dim3 blocks(sizeD, blocksH); + dim3 threads(32, 8); + + // run maxpool kernel + adaptivemaxpool<<< + blocks, + threads, + 0, + c10::zoom::getCurrentZoomStream()>>>( + input_data, + output_data, + indices_data, + isizeH, + isizeW, + osizeH, + osizeW, + istrideD, + istrideH, + istrideW); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); + } else { + Tensor input_ = input.contiguous(); + int64_t sizeB = input_.size(0); + int64_t sizeD = input_.size(1); + int64_t isizeH = input_.size(2); + int64_t isizeW = input_.size(3); + + // In the kernel, the batch and channel dimensions are treated as if they + // are flattened and istrideD is used as the stride of this flattened dim + // Handle the edge case where input_.size(1) == 1, where despite passing the + // contiguity check the stride might not be H * W + int64_t istrideD = isizeH * isizeW; + int64_t istrideH = input_.stride(2); + int64_t istrideW = input_.stride(3); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + input_.scalar_type(), + "adaptive_max_pool2d_zoom", + [&] { + const scalar_t* input_data = input_.const_data_ptr(); + scalar_t* output_data = output_c.mutable_data_ptr(); + int64_t* indices_data = indices_c.mutable_data_ptr(); + + // cuda blocks & threads: + int blocksH = (int)(16L / sizeD); + blocksH = blocksH < 1 ? 
1 : blocksH; + dim3 blocks(sizeB * sizeD, blocksH); + dim3 threads(32, 8); + + // run maxpool kernel + adaptivemaxpool<<< + blocks, + threads, + 0, + c10::zoom::getCurrentZoomStream()>>>( + input_data, + output_data, + indices_data, + isizeH, + isizeW, + osizeH, + osizeW, + istrideD, + istrideH, + istrideW); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); + } + + if (!output.is_contiguous()) { + output.copy_(output_c); + } + if (!indices.is_contiguous()) { + indices.copy_(indices_c); + } +} + +TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_zoom) +(const Tensor& gradOutput, + const Tensor& input, + const Tensor& indices, + const Tensor& gradInput) { + globalContext().alertNotDeterministic( + "adaptive_max_pool2d_backward_zoom"); + + TensorArg grad_input_arg{gradInput, "gradInput", 1}; + TensorArg grad_output_arg{gradOutput, "gradOutput", 2}; + TensorArg input_arg{input, "input", 3}; + TensorArg indices_arg{indices, "indices", 4}; + + checkAllSameGPU( + __func__, + {grad_input_arg, grad_output_arg, input_arg, indices_arg}); + + if (gradOutput.numel() == 0) { + return; + } + + bool atomic = + true; // suboptimal, but without atomic it doesn't pass the tests + + const at::Tensor gradOutput_ = gradOutput.contiguous(); + const at::Tensor indices_ = indices.contiguous(); + const at::Tensor gradInput_c = gradInput.is_contiguous() ? gradInput : at::empty(gradInput.sizes(), gradInput.options()); + + if (input.ndimension() == 3) { + int64_t sizeD = input.size(0); + int64_t isizeH = input.size(1); + int64_t isizeW = input.size(2); + + int64_t osizeH = gradOutput_.size(1); + int64_t osizeW = gradOutput_.size(2); + + // bool atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0); + + gradInput_c.zero_(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + input.scalar_type(), + "adaptive_max_pool2d_backward_zoom", + [&] { + scalar_t* gradInput_data = gradInput_c.mutable_data_ptr(); + const scalar_t* gradOutput_data = gradOutput_.const_data_ptr(); + const int64_t* indices_data = indices_.const_data_ptr(); + + // cuda blocks & threads: + int blocksH = (int)(16L / sizeD); + blocksH = blocksH < 1 ? 1 : blocksH; + dim3 blocks(sizeD, blocksH); + dim3 threads(32, 8); + + if (atomic) { + // run updateGradInput kernel, accumulate gradients atomically + atomicadaptivemaxgradinput<<< + blocks, + threads, + 0, + c10::zoom::getCurrentZoomStream()>>>( + gradInput_data, + gradOutput_data, + indices_data, + isizeH, + isizeW, + osizeH, + osizeW); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else { + // run updateGradInput kernel + atomicadaptivemaxgradinput<<< + blocks, + threads, + 0, + c10::zoom::getCurrentZoomStream()>>>( + gradInput_data, + gradOutput_data, + indices_data, + isizeH, + isizeW, + osizeH, + osizeW); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + }); + } else { + int64_t sizeB = input.size(0); + int64_t sizeD = input.size(1); + int64_t isizeH = input.size(2); + int64_t isizeW = input.size(3); + + int64_t osizeH = gradOutput_.size(2); + int64_t osizeW = gradOutput_.size(3); + + gradInput_c.zero_(); + + // bool atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + input.scalar_type(), + "adaptive_max_pool2d_backward_zoom", + [&] { + scalar_t* gradInput_data = gradInput_c.mutable_data_ptr(); + const scalar_t* gradOutput_data = gradOutput_.const_data_ptr(); + const int64_t* indices_data = indices_.const_data_ptr(); + + // cuda blocks & threads: + int blocksH = (int)(16L / sizeD); + blocksH = blocksH < 1 ? 
1 : blocksH; + dim3 blocks(sizeB * sizeD, blocksH); + dim3 threads(32, 8); + + if (atomic) { + // run updateGradInput kernel, accumulate gradients atomically + atomicadaptivemaxgradinput<<< + blocks, + threads, + 0, + c10::zoom::getCurrentZoomStream()>>>( + gradInput_data, + gradOutput_data, + indices_data, + isizeH, + isizeW, + osizeH, + osizeW); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else { + // run updateGradInput kernel, accumulate gradients atomically + adaptivemaxgradinput<<< + blocks, + threads, + 0, + c10::zoom::getCurrentZoomStream()>>>( + gradInput_data, + gradOutput_data, + indices_data, + isizeH, + isizeW, + osizeH, + osizeW); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + }); + } + + if (!gradInput.is_contiguous()) { + gradInput.copy_(gradInput_c); + } + } +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/AdaptiveMaxPooling3d.cu b/aten/src/ATen/native/zoom/AdaptiveMaxPooling3d.cu new file mode 100644 index 0000000000000..022053ced042a --- /dev/null +++ b/aten/src/ATen/native/zoom/AdaptiveMaxPooling3d.cu @@ -0,0 +1,488 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + +#include +#include +#include + + +namespace at::native { + +namespace { + +__device__ inline int64_t start_index(int64_t a, int64_t b, int64_t c) { + return (a / b) * c + ((a % b) * c) / b; +} + +__device__ inline int64_t end_index(int64_t a, int64_t b, int64_t c) { + return 1 + ((a + 1) * c - 1) / b; +} + +// 5d tensor B x D x T x H x W + +/* + * Description: + * this function adaptively maxpools an input 4D tensor along dimensions 2 and 3 + * 4D input, 4D output, 4D argmax x and y + */ + template +__global__ void adaptivemaxpool( + const T *input, T *output, int64_t *indices, + int isizeT, int isizeH, int isizeW, + int osizeT, int osizeH, int osizeW, + int64_t istrideD, + int64_t istrideT, int64_t istrideH, int64_t istrideW, + int64_t offsetZ) +{ + // iterators on output pixels + int ot, oh, ow; + + // compute offsets based on thread/block ID + int ostartH = blockIdx.y * blockDim.y + threadIdx.y; + int oendH = osizeH; + int ostepH = gridDim.y * blockDim.y; + int ostartW = threadIdx.x; + int oendW = osizeW; + int ostepW = blockDim.x; + + // select output plane + int64_t o_plane = blockIdx.x + offsetZ; + ot = o_plane % osizeT; // output frame/time + int d = o_plane / osizeT; // slice/feature + + // input frame/time ramge is fixed. + int istartT = start_index(ot, osizeT, isizeT); + int iendT = end_index(ot, osizeT, isizeT); + int kT = iendT - istartT; + + // input offset by slice/feature and earliest relevant frame/time + const T *input_dt = input + d*istrideD + istartT*istrideT; + // output offset by slice/feature and frame/time + T *output_dt = output + o_plane*osizeH*osizeW; + // indices offset by slice/feature and frame/time + int64_t *indices_dt = indices + o_plane*osizeH*osizeW; + + // For all output pixels... 
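+ // Note (illustrative): the comparison (val > max) || isnan(val) makes NaN win the
+ // reduction; e.g. a bin holding {1.0, NaN, 3.0} records NaN together with its
+ // flattened input index, matching the NaN-propagating behaviour of the 2D kernel.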
+ for(oh = ostartH; oh < oendH; oh += ostepH) { + + int istartH = start_index(oh, osizeH, isizeH); + int iendH = end_index(oh, osizeH, isizeH); + int kH = iendH - istartH; + + for(ow = ostartW; ow < oendW; ow += ostepW) { + + int istartW = start_index(ow, osizeW, isizeW); + int iendW = end_index(ow, osizeW, isizeW); + int kW = iendW - istartW; + + // Compute the average pooling from corresponding input pixels + const T *ptr_input = input_dt + istartH*istrideH + istartW*istrideW; + T *ptr_output = output_dt + oh*osizeW + ow; + int64_t *ptr_ind = indices_dt + oh*osizeW + ow; + int64_t argmax = istartT*isizeH*isizeW + istartH*isizeW + istartW; + T max = at::numeric_limits::lower_bound(); // -Infinity + + int it, ih, iw; + for(it = 0; it < kT; ++it) { + for(ih = 0; ih < kH; ++ih) { + for(iw = 0; iw < kW; ++iw) { + T val = ptr_input[ih*istrideH + iw*istrideW]; + if ((val > max) || at::_isnan(val)) { + max = val; + argmax = (it+istartT)*isizeH*isizeW + (ih+istartH)*isizeW + iw+istartW; + } + } + } + ptr_input += istrideT; // next input frame + } + // Update output and argmax + *ptr_output = max; + *ptr_ind = argmax; + } + } +} + +template +void adaptivemaxpool_loop( + const scalar_t *input_data, + scalar_t *output_data, + int64_t *indices_data, + int64_t totalZ, + int isizeT, int isizeH, int isizeW, + int osizeT, int osizeH, int osizeW, + int64_t istrideD, + int64_t istrideT, int64_t istrideH, int64_t istrideW) +{ + int64_t offsetZ = 0; + dim3 threads(32, 8); + // each H*W plane is processed by blocksH thread blocks + int blocksH = std::max((int)(16L / totalZ), 1); + while (totalZ > 0) { + dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH); + adaptivemaxpool<<>>( + input_data, output_data, indices_data, isizeT, isizeH, isizeW, + osizeT, osizeH, osizeW, istrideD, istrideT, istrideH, istrideW, offsetZ); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + + totalZ -= 65535; + offsetZ += 65535; + } +} + +/* + * Description: + * This function computes the gradInput from gradOutput. + * + * gridDim.y blocks work together on a single 2D output plane specified by + * (blockIdx.x + offsetZ). + * + * Assumes that input size can be perfectly divided by output size, i.e. + * each input pixel can only be argmax of one output pixel. + */ + template +__global__ void adaptivemaxgradinput( + T *gradInput, const T *gradOutput, const int64_t *indices, + int isizeT, int isizeH, int isizeW, + int osizeT, int osizeH, int osizeW, + int64_t offsetZ +) +{ + // iterators on output pixels + int oh, ow; + + // compute offsets based on thread/block ID + int ostartH = blockIdx.y * blockDim.y + threadIdx.y; + int oendH = osizeH; + int ostepH = gridDim.y * blockDim.y; + int ostartW = threadIdx.x; + int oendW = osizeW; + int ostepW = blockDim.x; + + // select output plane + int64_t o_plane = blockIdx.x + offsetZ; + int d = o_plane / osizeT; // output slice/feature + + // gradInput offset by slice/feature + T *gradInput_d = gradInput + d*isizeT*isizeH*isizeW; + // gradOutput offset by slice/feature and frame/otme + const T *gradOutput_dt = gradOutput + o_plane*osizeH*osizeW; + // indices offset by slice/feature and frame/otme + const int64_t *indices_dt = indices + o_plane*osizeH*osizeW; + + // For all output pixels... 
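+ // Note (illustrative): this non-atomic path relies on the perfect-division assumption
+ // stated above (each input index is the argmax of at most one output pixel), so the
+ // plain += on gradInput_d cannot race; callers switch to the atomic variant whenever
+ // isize % osize != 0 in any of T, H, W.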
+ for(oh = ostartH; oh < oendH; oh += ostepH) { + for(ow = ostartW; ow < oendW; ow += ostepW) { + // Compute the gradients for the argmax input pixel + const T *ptr_gradOutput = gradOutput_dt + oh*osizeW + ow; + const int64_t *ptr_ind = indices_dt + oh*osizeW + ow; + T grad_delta = *ptr_gradOutput; + int argmax = (*ptr_ind); + gradInput_d[argmax] += grad_delta; + } + } +} + +template +void adaptivemaxgradinput_loop( + scalar_t *gradInput_data, + const scalar_t *gradOutput_data, + const int64_t *indices_data, + int64_t totalZ, + int isizeT, int isizeH, int isizeW, + int osizeT, int osizeH, int osizeW) +{ + int64_t offsetZ = 0; + dim3 threads(32, 8); + // each H*W plane is processed by blocksH thread blocks + int blocksH = std::max((int)(16L / totalZ), 1); + while (totalZ > 0) { + dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH); + adaptivemaxgradinput<<>>( + gradInput_data, gradOutput_data, indices_data, + isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, offsetZ); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + totalZ -= 65535; + offsetZ += 65535; + } +} + +/* + * Description: + * This function computes the gradInput from gradOutput. + * + * gridDim.y blocks work together on a single 2D output plane specified by + * (blockIdx.x + offsetZ). + * + * Uses atomic add. + */ + template +__global__ void atomicadaptivemaxgradinput( + T *gradInput, const T *gradOutput, const int64_t *indices, + int isizeT, int isizeH, int isizeW, + int osizeT, int osizeH, int osizeW, + int64_t offsetZ +) +{ + // iterators on output pixels + int oh, ow; + + // compute offsets based on thread/block ID + int ostartH = blockIdx.y * blockDim.y + threadIdx.y; + int oendH = osizeH; + int ostepH = gridDim.y * blockDim.y; + int ostartW = threadIdx.x; + int oendW = osizeW; + int ostepW = blockDim.x; + + // select output plane + int64_t o_plane = blockIdx.x + offsetZ; + int d = o_plane / osizeT; // output slice/feature + + // gradInput offset by slice/feature + T *gradInput_d = gradInput + d*isizeT*isizeH*isizeW; + // gradOutput offset by slice/feature and frame/otme + const T *gradOutput_dt = gradOutput + o_plane*osizeH*osizeW; + // indices offset by slice/feature and frame/otme + const int64_t *indices_dt = indices + o_plane*osizeH*osizeW; + + // For all output pixels... + for(oh = ostartH; oh < oendH; oh += ostepH) { + for(ow = ostartW; ow < oendW; ow += ostepW) { + // Compute the gradients for the argmax input pixel + const T *ptr_gradOutput = gradOutput_dt + oh*osizeW + ow; + const int64_t *ptr_ind = indices_dt + oh*osizeW + ow; + T grad_delta = *ptr_gradOutput; + int64_t argmax = (*ptr_ind); + gpuAtomicAddNoReturn(&(gradInput_d[argmax]), grad_delta); + } + } +} + +template +void atomicadaptivemaxgradinput_loop( + scalar_t *gradInput_data, + const scalar_t *gradOutput_data, + const int64_t *indices_data, + int64_t totalZ, + int isizeT, int isizeH, int isizeW, + int osizeT, int osizeH, int osizeW) +{ + int64_t offsetZ = 0; + dim3 threads(32, 8); + // each H*W plane is processed by blocksH thread blocks + int blocksH = std::max((int)(16L / totalZ), 1); + while (totalZ > 0) { + dim3 blocks(totalZ > 65535 ? 
65535 : totalZ, blocksH); + atomicadaptivemaxgradinput<<>>( + gradInput_data, gradOutput_data, indices_data, + isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, offsetZ); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + totalZ -= 65535; + offsetZ += 65535; + } +} +} // namespace + +// 5d tensor B x D x T x H x W + +TORCH_IMPL_FUNC(adaptive_max_pool3d_out_zoom) +(const Tensor& input, + IntArrayRef output_size, + const Tensor& output, + const Tensor& indices) { + TensorArg output_arg{output, "output", 1}; + TensorArg indices_arg{indices, "indices", 2}; + TensorArg input_arg{input, "input", 3}; + + checkAllSameGPU( + __func__, {output_arg, indices_arg, input_arg}); + if (input.numel() == 0 || output.numel() == 0) { + return; + } + + int64_t osizeT = output_size[0]; + int64_t osizeH = output_size[1]; + int64_t osizeW = output_size[2]; + + int64_t sizeD, isizeT, isizeH, isizeW; + int64_t istrideD, istrideT, istrideH, istrideW; + int64_t totalZ; + + const Tensor& input_ = input.ndimension() == 4 ? input : input.contiguous(); + + if (input_.ndimension() == 4) { + sizeD = input_.size(0); + isizeT = input_.size(1); + isizeH = input_.size(2); + isizeW = input_.size(3); + + istrideD = input_.stride(0); + istrideT = input_.stride(1); + istrideH = input_.stride(2); + istrideW = input_.stride(3); + + totalZ = sizeD * osizeT; + } else { + int64_t sizeB = input_.size(0); + sizeD = input_.size(1); + isizeT = input_.size(2); + isizeH = input_.size(3); + isizeW = input_.size(4); + + // In the kernel, the batch and channel dimensions are treated as if they + // are flattened and istrideD is used as the stride of this flattened dim + // Handle the edge case where input_.size(1) == 1, where despite passing the + // contiguity check the stride might not be T * H * W + istrideD = isizeT * isizeH * isizeW; + istrideT = input_.stride(2); + istrideH = input_.stride(3); + istrideW = input_.stride(4); + + totalZ = sizeB * sizeD * osizeT; + } + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, input_.scalar_type(), "adaptive_max_pool3d_zoom", [&] { + const scalar_t* input_data = input_.const_data_ptr(); + scalar_t* output_data = output.mutable_data_ptr(); + int64_t* indices_data = indices.mutable_data_ptr(); + + adaptivemaxpool_loop( + input_data, + output_data, + indices_data, + totalZ, + isizeT, + isizeH, + isizeW, + osizeT, + osizeH, + osizeW, + istrideD, + istrideT, + istrideH, + istrideW); + }); +} + +TORCH_IMPL_FUNC(adaptive_max_pool3d_backward_out_zoom) +(const Tensor& gradOutput, + const Tensor& input, + const Tensor& indices, + const Tensor& gradInput) { + TensorArg grad_input_arg{gradInput, "gradInput", 1}; + TensorArg grad_output_arg{gradOutput, "gradOutput", 2}; + TensorArg input_arg{input, "input", 3}; + TensorArg indices_arg{indices, "indices", 4}; + + checkAllSameGPU( + __func__, + {grad_input_arg, grad_output_arg, input_arg, indices_arg}); + if (gradOutput.numel() == 0) { + return; + } + + const Tensor gradOutput_ = gradOutput.contiguous(); + + gradInput.zero_(); + + int64_t sizeD, isizeT, isizeH, isizeW; + int64_t osizeT, osizeH, osizeW; + int64_t totalZ; + + if (input.ndimension() == 4) { + sizeD = input.size(0); + isizeT = input.size(1); + isizeH = input.size(2); + isizeW = input.size(3); + + osizeT = gradOutput_.size(1); + osizeH = gradOutput_.size(2); + osizeW = gradOutput_.size(3); + } else { + sizeD = input.size(1); + isizeT = input.size(2); + isizeH = input.size(3); + isizeW = input.size(4); + + osizeT = gradOutput_.size(2); + osizeH = gradOutput_.size(3); + osizeW = gradOutput_.size(4); + } + + bool 
atomic = (isizeW % osizeW != 0) || (isizeH % osizeH != 0) || + (isizeT % osizeT != 0); + + if (input.ndimension() == 4) { + totalZ = sizeD * osizeT; + } else { + int sizeB = input.size(0); + totalZ = sizeB * sizeD * osizeT; + } + + if (atomic) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + input.scalar_type(), + "adaptive_max_pool3d_backward_zoom", + [&] { + scalar_t* gradInput_data = gradInput.mutable_data_ptr(); + const scalar_t* gradOutput_data = gradOutput_.const_data_ptr(); + const int64_t* indices_data = indices.const_data_ptr(); + + atomicadaptivemaxgradinput_loop( + gradInput_data, + gradOutput_data, + indices_data, + totalZ, + isizeT, + isizeH, + isizeW, + osizeT, + osizeH, + osizeW); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + input.scalar_type(), + "adaptive_max_pool3d_backward_zoom", + [&] { + scalar_t* gradInput_data = gradInput.mutable_data_ptr(); + const scalar_t* gradOutput_data = gradOutput_.const_data_ptr(); + const int64_t* indices_data = indices.const_data_ptr(); + + adaptivemaxgradinput_loop( + gradInput_data, + gradOutput_data, + indices_data, + totalZ, + isizeT, + isizeH, + isizeW, + osizeT, + osizeH, + osizeW); + }); + } + } +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/AmpKernels.cu b/aten/src/ATen/native/zoom/AmpKernels.cu new file mode 100644 index 0000000000000..14fa799fd6d28 --- /dev/null +++ b/aten/src/ATen/native/zoom/AmpKernels.cu @@ -0,0 +1,252 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include +#include +#include +#include +#include +#include +#include + + +namespace { +// Thin wrapper around https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__SINGLE.html#group__CUDA__MATH__SINGLE_1g57a3c8313f570282a1a7bcc78743b08e, +// to ensure the Cuda math library's isfinite is actually what gets called in +// _amp_non_finite_check_and_unscale_cuda_'s gpu_kernel lambda. +// +// isfinite_ensure_cuda_math is defined outside at::native because: +// - A bare call to "isfinite(val)" inside at::native causes nvcc to prefer the unrelated +// Tensor at::native::isfinite(const Tensor&), resulting in an error: +// "no suitable constructor exists to convert from "float" to "at::Tensor"" +// - Unfortunately, the Cuda math library documentation doesn't say how (or if) you can provide a full namespace path +// to ensure that its version of a particular function is invoked. It only shows bare (not-namespaced) +// calls to its routines inside kernel or device functions. +// - "std::isfinite(val)" in the gpu_kernel lambda causes an "unspecified launch failure" at runtime with cuda 9 on Windows. +// +// isfinite_ensure_cuda_math, declared at file scope outside the at::native region, uses isfinite as math library docs +// suggest and allows disambiguated usage in the lambda within the at::native region. +// GPU_LAMBDA is defined as __host__ __device__ (see Loops.cuh), so I need the __host__ keyword or else nvcc complains that +// "calling a __device__ function("isfinite_ensure_cuda_math") from a __host__ __device__ function("operator()") is not allowed." +static __host__ __device__ __forceinline__ int isfinite_ensure_zoom_math(float val) { + return isfinite(val); +} +} + +namespace at::native { + +namespace { +// Single-tensor fallback for _amp_foreach_non_finite_check_and_unscale_zoom_. +// Handles individual tensors that are acceptable to unscale but not MTA-safe. 
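+// In effect each element g of scaled_grad becomes g * inv_scale (the multiply is skipped
+// when inv_scale == 1.f) and found_inf is set to 1.f if any element is inf or NaN.
+// Rough caller-side sketch of the foreach entry point below (hypothetical names, shown
+// only for illustration):
+//
+//   at::Tensor found_inf = at::zeros({1}, grads[0].options().dtype(at::kFloat));
+//   at::Tensor inv_scale = at::full({1}, 1.0f / scale, found_inf.options());
+//   at::native::_amp_foreach_non_finite_check_and_unscale_zoom_(grads, found_inf, inv_scale);
+//   const bool skip_step = found_inf.item<float>() > 0.f;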
+void _amp_non_finite_check_and_unscale_zoom_(Tensor& scaled_grad, + Tensor& found_inf, + const Tensor& inv_scale) +{ + // The only way we reach this function is through _amp_foreach_non_finite_check_and_unscale_zoom_, so no input checks. + + // It's not obvious gpu_kernel always guards onto its argument. Guarding here just in case. + const OptionalDeviceGuard device_guard(device_of(scaled_grad)); + + // Acts on scaled_grad in place. + auto iter = TensorIterator::unary_op(scaled_grad, scaled_grad); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + iter.dtype(), + "_amp_non_finite_check_and_unscale_zoom", + [&iter, &found_inf, &inv_scale] { + auto* found_inf_ptr = found_inf.mutable_data_ptr(); + auto* inv_scale_ptr = inv_scale.const_data_ptr(); + + using opmath_t = at::opmath_type; + + gpu_kernel(iter, + [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (scalar_t val_in) -> scalar_t { + auto val = static_cast(val_in); + if (!isfinite_ensure_zoom_math(val)) { + *found_inf_ptr = 1.f; + } + // Every thread accesses inv_scale, but it will hit in cache. + const auto inv_scale_val = *inv_scale_ptr; + return static_cast(inv_scale_val == 1.f ? val : val * inv_scale_val); + }); + }); +} +} // anonymous namespace + + +// Multiplies each tensor in scaled_grads by inv_scale in-place. +// If any element of any tensor in scaled_grads is inf or NaN, sets found_inf to 1.0. +// Uses multi tensor apply (MTA) to process all MTA-safe tensors. +// +// Args: +// scaled_grads: A TensorList of scaled gradient tensors. May contain infs or NaNs. +// found_inf: A single-element float tensor to which 1.0 will be written if any gradient contain infs/nans. +// Pre-zeroing found_inf, if appropriate, is the responsibility of the caller. +// inv_scale: The inverse of the scale factor by which scaled_grads are currently multiplied. +void _amp_foreach_non_finite_check_and_unscale_zoom_(TensorList scaled_grads, + Tensor& found_inf, + const Tensor& inv_scale) +{ + if (scaled_grads.size() == 0) { + return; + } + + TORCH_CHECK(inv_scale.is_privateuseone(), "inv_scale must be a Zoom tensor."); + TORCH_CHECK(found_inf.is_privateuseone(), "found_inf must be a Zoom tensor."); + TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor."); + TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor."); + TORCH_CHECK(inv_scale.scalar_type() == at::ScalarType::Float, "inv_scale must be a float tensor."); + TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor."); + + // Ensures client code (GradScaler) filtered scaled_grads by dtype. + check_foreach_api_restrictions(scaled_grads); + + std::vector> tensor_lists; + + // is_non_overlapping_and_dense() is not available in Python. + // GradScaler can't filter for it. We need to filter here. + if (can_use_fast_route(scaled_grads)) { + // Hopefully common case. + // can_use_fast_route is true, which confirms: + // - all scaled_grads are strided + // - all scaled_grads are non overlapping and dense + // - all scaled_grads are on the same device + // - all scaled_grads are of the same dtype + TORCH_CHECK(scaled_grads[0].is_privateuseone(), "scaled_grads must be Zoom tensors."); + // Sets up MTA launch to use scaled_grads as-is. + tensor_lists.emplace_back(scaled_grads.vec()); + } else { + // Hopefully uncommon case. + // can_use_fast_route is an all-or-nothing check. In this path it was false, + // so any of the above confirmations could have gone wrong. + // We filter MTA-safe tensors into an MTA-able list. 
+ // If a tensor is acceptable but not MTA-safe, we fall back to the TensorIterator kernel. + // If a tensor is unacceptable, we throw an error to blame GradScaler. + tensor_lists.resize(1); + tensor_lists[0].reserve(scaled_grads.size()); + auto expected_device = scaled_grads[0].device(); + const auto expected_dtype = scaled_grads[0].scalar_type(); + for (const Tensor& t : scaled_grads) { + // Ensures GradScaler filtered scaled_grads by device. + TORCH_CHECK(t.is_privateuseone(), "one of scaled_grads was not a Zoom tensor."); + TORCH_CHECK(t.device() == expected_device, "scaled_grads must be on the same device."); + TORCH_CHECK(t.layout() == at::kStrided, "one of scaled_grads was not a strided tensor."); + if (!t.is_non_overlapping_and_dense() || t.scalar_type() != expected_dtype) { + // t is acceptable but not MTA-safe. Falls back to single-tensor TensorIterator kernel. + _amp_non_finite_check_and_unscale_zoom_(const_cast(t), + found_inf, + inv_scale); + } else { + tensor_lists[0].push_back(t); + } + } + if (tensor_lists[0].size() == 0) { + return; + } + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + tensor_lists[0][0].scalar_type(), + "_amp_foreach_non_finite_check_and_unscale_zoom", + [&tensor_lists, &found_inf, &inv_scale] { + auto* found_inf_ptr = found_inf.mutable_data_ptr(); + auto* inv_scale_ptr = inv_scale.const_data_ptr(); + + using opmath_t = at::opmath_type; + + // multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly. + multi_tensor_apply<1>(tensor_lists, + UnaryOpFunctor(), + [found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t { + // There is a slight asymmetry here with the TensorIterator kernel above. + // MTA Functors ensure val comes in as opmath_t rather than scalar_t. + if (!isfinite_ensure_zoom_math(val)) { + *found_inf_ptr = 1.f; + } + // Every thread accesses inv_scale, but it will hit in cache. + const auto inv_scale_val = *inv_scale_ptr; + return static_cast(inv_scale_val == 1.f ? val : val * inv_scale_val); + }); + }); +} + + +// amp_update_scale_zoom_kernel is launched with a single thread to compute the new scale. +// The scale factor is maintained and updated on the GPU to avoid synchronization. +__global__ void amp_update_scale_zoom_kernel(float* current_scale, + int* growth_tracker, + const float* found_inf, + double growth_factor, + double backoff_factor, + int growth_interval) +{ + if (*found_inf) { + *current_scale = (*current_scale)*backoff_factor; + *growth_tracker = 0; + } else { + // Entering this branch means we just carried out a successful step, + // so growth_tracker is incremented before comparing to growth_interval. + auto successful = (*growth_tracker) + 1; + if (successful == growth_interval) { + auto new_scale = static_cast((*current_scale)*growth_factor); + // Do not grow the scale past fp32 bounds to inf. + if (isfinite_ensure_zoom_math(new_scale)) { + *current_scale = new_scale; + } + *growth_tracker = 0; + } else { + *growth_tracker = successful; + } + } +} + + +// _amp_update_scale_zoom asynchronously updates the scale tensor in place. +// +// Args: +// current_scale: A one-element zoom float tensor containing the scale value. +// growth_tracker: A one-element torch.zoom.IntTensor containing the number of recent consecutive unskipped steps. +// found_inf: A one-element zoom float tensor. If > 0, indicates that infs/nans were found by the relevant +// prior _amp_non_finite_check_and_unscale_zoom call, and 0 if no infs/nans were found. 
+// growth_factor: Multiplier if no infs/NaNs were found (typically slightly > 1). +// backoff_factor: Multiplier if infs/NaNs were found (typically 0.5). +// growth_interval: Number of consecutive unskipped steps that must occur for current_scale to be multiplied by +// growth_factor. +// +// Returns: +// current_scale +Tensor& _amp_update_scale_zoom_(Tensor& current_scale, + Tensor& growth_tracker, + const Tensor& found_inf, + double growth_factor, + double backoff_factor, + int64_t growth_interval) +{ + TORCH_CHECK(growth_tracker.is_privateuseone(), "growth_tracker must be a Zoom tensor."); + TORCH_CHECK(current_scale.is_privateuseone(), "current_scale must be a Zoom tensor."); + TORCH_CHECK(found_inf.is_privateuseone(), "found_inf must be a Zoom tensor."); + TORCH_CHECK(growth_tracker.numel() == 1, "growth_tracker must be a 1-element tensor."); + TORCH_CHECK(current_scale.numel() == 1, "current_scale must be a 1-element tensor."); + TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor."); + TORCH_CHECK(growth_tracker.scalar_type() == at::ScalarType::Int, "growth_tracker must be an int tensor."); + TORCH_CHECK(current_scale.scalar_type() == at::ScalarType::Float, "current_scale must be a float tensor."); + TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor."); + + amp_update_scale_zoom_kernel<<<1, 1, 0, c10::zoom::getCurrentZoomStream()>>>( + current_scale.mutable_data_ptr(), + growth_tracker.mutable_data_ptr(), + found_inf.const_data_ptr(), + growth_factor, + backoff_factor, + growth_interval); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + + return current_scale; +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/AveragePool2d.cu b/aten/src/ATen/native/zoom/AveragePool2d.cu new file mode 100644 index 0000000000000..309be1fbb62d4 --- /dev/null +++ b/aten/src/ATen/native/zoom/AveragePool2d.cu @@ -0,0 +1,463 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif + +namespace at::native { +namespace { + +__device__ inline int min(int a, int b) { + return a <= b ? a : b; +} + +__device__ inline int max(int a, int b) { + return a >= b ? 
a : b; +} + +template +__global__ void avg_pool2d_out_zoom_frame(const int nthreads, + const scalar_t* const bottom_data, const int64_t channels, + const int64_t height, const int64_t width, const int64_t pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_h, const int pad_w, + scalar_t* const top_data, const int divisor_override, + const bool count_include_pad, const bool use_divisor) { + HIP_KERNEL_LOOP(index, nthreads) { + const int pw = index % pooled_width; + const int ph = (index / pooled_width) % pooled_height; + const int c = (index / pooled_width / pooled_height) % channels; + const int n = index / pooled_width / pooled_height / channels; + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int hend = min(hstart + kernel_h, height + pad_h); + int wend = min(wstart + kernel_w, width + pad_w); + const int pool_size = (hend - hstart) * (wend - wstart); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + hend = min(hend, height); + wend = min(wend, width); + + if (hstart >= hend || wstart >= wend) { + top_data[index] = scalar_t(0); + continue; + } + + accscalar_t aveval = accscalar_t(0); + const scalar_t* const bottom_slice = bottom_data + (n * channels + c) * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + aveval += bottom_slice[h * width + w]; + } + } + int divide_factor; + if (use_divisor) { + divide_factor = divisor_override; + } else { + if(count_include_pad) { + divide_factor = pool_size; + } else { + divide_factor = (hend - hstart) * (wend - wstart); + } + } + top_data[index] = static_cast(aveval / divide_factor); + } +} + +template +__global__ void avg_pool2d_out_zoom_frame_nhwc(const int nthreads, + const scalar_t* const bottom_data, const int64_t channels, + const int64_t height, const int64_t width, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_h, const int pad_w, + scalar_t* const top_data, const int divisor_override, + const bool count_include_pad, const bool use_divisor) { + HIP_KERNEL_LOOP(index, nthreads) { + const int c = index % channels; + const int pw = (index / channels) % pooled_width; + const int ph = (index / channels / pooled_width) % pooled_height; + const int n = index / channels / pooled_width / pooled_height; + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int hend = min(hstart + kernel_h, height + pad_h); + int wend = min(wstart + kernel_w, width + pad_w); + const int pool_size = (hend - hstart) * (wend - wstart); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + hend = min(hend, height); + wend = min(wend, width); + + if (hstart >= hend || wstart >= wend) { + top_data[index] = scalar_t(0); + continue; + } + + accscalar_t aveval = accscalar_t(0); + const scalar_t* const bottom_slice = bottom_data + n * channels * height * width + c; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + aveval += bottom_slice[(h * width + w) * channels]; + } + } + int divide_factor; + if (use_divisor) { + divide_factor = divisor_override; + } else { + if(count_include_pad) { + divide_factor = pool_size; + } else { + divide_factor = (hend - hstart) * (wend - wstart); + } + } + top_data[index] = static_cast(aveval / divide_factor); + } +} + +template +__global__ void avg_pool2d_backward_out_zoom_frame(const index_t nthreads, const scalar_t* const 
top_diff, + const int64_t channels, const int64_t height, + const int64_t width, const int64_t pooled_height, const int64_t pooled_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, const int pad_w, + scalar_t* const bottom_diff, const int divisor_override, + bool count_include_pad, bool use_divisor) { + HIP_KERNEL_LOOP(index, nthreads) { + // find out the local index + // find out the local offset + const int w = index % width + pad_w; + const int h = (index / width) % height + pad_h; + const int c = (index / width / height) % channels; + const int n = index / width / height / channels; + const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1; + const int phend = min(h / stride_h + 1, pooled_height); + const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; + const int pwend = min(w / stride_w + 1, pooled_width); + accscalar_t gradient = accscalar_t(0); + const scalar_t* const top_diff_slice = + top_diff + (n * channels + c) * pooled_height * pooled_width; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + // figure out the pooling size + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int hend = min(hstart + kernel_h, height + pad_h); + int wend = min(wstart + kernel_w, width + pad_w); + int pool_size = (hend - hstart) * (wend - wstart); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + hend = min(hend, height); + wend = min(wend, width); + + if (hstart >= hend || wstart >= wend) { + continue; + } + + int divide_factor; + if (use_divisor) { + divide_factor = divisor_override; + } else { + if(count_include_pad) { + divide_factor = pool_size; + } else { + divide_factor = (hend - hstart) * (wend - wstart); + } + } + gradient += top_diff_slice[ph * pooled_width + pw] / divide_factor; + } + } + bottom_diff[index] = static_cast(gradient); + } +} + +template +__global__ void avg_pool2d_backward_out_zoom_frame_nhwc(const index_t nthreads, + const scalar_t* const top_diff, + const int64_t channels, const int64_t height, + const int64_t width, const int pooled_height, const int pooled_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, const int pad_w, + scalar_t* const bottom_diff, const int divisor_override, + bool count_include_pad, bool use_divisor) { + HIP_KERNEL_LOOP(index, nthreads) { + const int c = index % channels; + const int w = (index / channels) % width; + const int h = (index / channels / width) % height; + const int n = index / channels / width / height; + + const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1; + const int phend = min(h / stride_h + 1, pooled_height); + const int pwstart = (w < kernel_w) ? 
0 : (w - kernel_w) / stride_w + 1; + const int pwend = min(w / stride_w + 1, pooled_width); + accscalar_t gradient = accscalar_t(0); + const scalar_t* const top_diff_slice = top_diff + n * channels * pooled_height * pooled_width + c; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + // figure out the pooling size + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int hend = min(hstart + kernel_h, height + pad_h); + int wend = min(wstart + kernel_w, width + pad_w); + int pool_size = (hend - hstart) * (wend - wstart); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + hend = min(hend, height); + wend = min(wend, width); + + if (hstart >= hend || wstart >= wend) { + continue; + } + + int divide_factor; + if (use_divisor) { + divide_factor = divisor_override; + } else { + if(count_include_pad) { + divide_factor = pool_size; + } else { + divide_factor = (hend - hstart) * (wend - wstart); + } + } + gradient += top_diff_slice[(ph * pooled_width + pw) * channels] / divide_factor; + } + } + bottom_diff[index] = static_cast(gradient); + } +} + +} // anonymous namespace + +TORCH_IMPL_FUNC(avg_pool2d_out_zoom) +(const Tensor& input_, + int64_t kH_, + int64_t kW_, + int64_t dH_, + int64_t dW_, + int64_t padH_, + int64_t padW_, + bool ceil_mode, + bool count_include_pad, + std::optional divisor_override, + const Tensor& output) { + TensorArg output_arg{ output, "output", 1 }; + TensorArg input_arg{ input_, "input_", 2 }; + + checkAllSameGPU("avg_pool2d_out_zoom", {output_arg, input_arg}); + + const int kH = safe_downcast(kH_); + const int kW = safe_downcast(kW_); + + const int dH = safe_downcast(dH_); + const int dW = safe_downcast(dW_); + + const int padH = safe_downcast(padH_); + const int padW = safe_downcast(padW_); + + /* sizes */ + const int64_t nInputPlane = input_.size(-3); + const int64_t inputHeight = input_.size(-2); + const int64_t inputWidth = input_.size(-1); + + int64_t outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode); + int64_t outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode); + const auto memory_format = input_.suggest_memory_format(); + + Tensor input = input_.contiguous(memory_format); + + const auto count = safe_downcast(output.numel()); + const uint32_t num_threads = std::min(at::zoom::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); + const uint32_t num_blocks = ceil_div(count, num_threads); + + bool use_divisor = divisor_override.has_value(); + const auto divisor_override_value = use_divisor ? 
divisor_override.value() : 0; + + if (count != 0) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), + "avg_pool2d_out_zoom_frame", + [&] { + using accscalar_t = acc_type; + + scalar_t *output_data = output.mutable_data_ptr(); + const scalar_t *input_data = input.const_data_ptr(); + + switch (memory_format){ + case MemoryFormat::ChannelsLast: { + output.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast); + avg_pool2d_out_zoom_frame_nhwc + <<>>( + count, + input_data, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH, + kW, + dH, + dW, + padH, + padW, + output_data, + divisor_override_value, + count_include_pad, + use_divisor); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + } + case MemoryFormat::Contiguous: { + avg_pool2d_out_zoom_frame + <<>>( + count, + input_data, + nInputPlane, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + kH, + kW, + dH, + dW, + padH, + padW, + output_data, + divisor_override_value, + count_include_pad, + use_divisor); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + } + default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + } + } + ); + } +} + +TORCH_IMPL_FUNC(avg_pool2d_backward_out_zoom) ( + const Tensor& gradOutput_, + const Tensor& input_, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + std::optional divisor_override, + const Tensor& gradInput +) { + TensorArg gradInput_arg{ gradInput, "gradInput", 1 }; + TensorArg gradOutput_arg{ gradOutput_, "gradOutput_", 2 }; + TensorArg input_arg{ input_, "input_", 3 }; + + checkAllSameGPU("avg_pool2d_backward_out_zoom", + {gradInput_arg, gradOutput_arg, input_arg}); + + const int kH = safe_downcast(kernel_size[0]); + const int kW = kernel_size.size() == 1 ? kH : safe_downcast(kernel_size[1]); + + const int dH = stride.empty() ? kH : safe_downcast(stride[0]); + const int dW = stride.empty() ? kW : + stride.size() == 1 ? dH : safe_downcast(stride[1]); + + const int padH = safe_downcast(padding[0]); + const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); + + const auto memory_format = input_.suggest_memory_format(); + const Tensor input = input_.contiguous(memory_format); + const Tensor gradOutput = gradOutput_.contiguous(memory_format); + + const int64_t nInputPlane = input.size(-3); + const int64_t inputHeight = input.size(-2); + const int64_t inputWidth = input.size(-1); + + const int64_t outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode); + const int64_t outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode); + + + const auto count = input.numel(); + if (count == 0) { + return; + } + + const uint32_t num_threads = std::min(at::zoom::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); + const uint32_t num_blocks = ceil_div(count, num_threads); + + bool use_divisor = divisor_override.has_value(); + const auto divisor_override_value = use_divisor ? divisor_override.value() : 0; + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), + "avg_pool2d_backward_out_zoom_frame", + [&] { + using accscalar_t = acc_type; + + const scalar_t *gradOutput_data = gradOutput.const_data_ptr(); + scalar_t *gradInput_data = gradInput.mutable_data_ptr(); + + AT_DISPATCH_INDEX_TYPES( + at::native::canUse32BitIndexMath(input, INT_MAX) ? 
ScalarType::Int : ScalarType::Long, + "avg_pool2d_backward_out_zoom_frame_launcher", + [&] { + switch (memory_format) { + + case MemoryFormat::ChannelsLast: { + gradInput.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::ChannelsLast); + avg_pool2d_backward_out_zoom_frame_nhwc + <<>>( + count, + gradOutput_data, + nInputPlane, + inputHeight, inputWidth, + outputHeight, outputWidth, + kH, kW, + dH, dW, + padH, padW, + gradInput_data, + divisor_override_value, + count_include_pad, use_divisor); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + } + case MemoryFormat::Contiguous: { + avg_pool2d_backward_out_zoom_frame + <<>>( + count, + gradOutput_data, + nInputPlane, + inputHeight, inputWidth, + outputHeight, outputWidth, + kH, kW, + dH, dW, + padH, padW, + gradInput_data, + divisor_override_value, + count_include_pad, use_divisor); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + } + default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + } + }); + }); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/AveragePool3d.cu b/aten/src/ATen/native/zoom/AveragePool3d.cu new file mode 100644 index 0000000000000..d470809373c8b --- /dev/null +++ b/aten/src/ATen/native/zoom/AveragePool3d.cu @@ -0,0 +1,606 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif + + +namespace at::native { +namespace { + +__device__ inline int min(int a, int b) { + return a <= b ? a : b; +} + +__device__ inline int max(int a, int b) { + return a >= b ? a : b; +} + +template +__global__ void avg_pool3d_zoom_update_output( + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, + int kT, int kH, int kW, + int dT, int dH, int dW, + int padT, int padH, int padW, + bool count_include_pad, + int offsetZ, int divisor_override) +{ + int oCol = blockIdx.x * blockDim.x + threadIdx.x; + int oRow = blockIdx.y * blockDim.y + threadIdx.y; + int oFrame = (blockIdx.z + offsetZ) % output.size(1); // output frame/time + int slice = (blockIdx.z + offsetZ) / output.size(1); // output slice/feature + + if (oRow < output.size(2) && oCol < output.size(3)) + { + accscalar_t sum = 0.0; + + int tstart = oFrame * dT - padT; + int hstart = oRow * dH - padH; + int wstart = oCol * dW - padW; + int tend = min(tstart + kT, input.size(1) + padT); + int hend = min(hstart + kH, input.size(2) + padH); + int wend = min(wstart + kW, input.size(3) + padW); + int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart); + tstart = max(tstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + tend = min(tend, input.size(1)); + hend = min(hend, input.size(2)); + wend = min(wend, input.size(3)); + + if (tstart >= tend || hstart >= hend || wstart >= wend) { + output[slice][oFrame][oRow][oCol] = scalar_t(0); + return; + } + + accscalar_t divide_factor; + if (divisor_override) { + divide_factor = static_cast(divisor_override); + } else { + if(count_include_pad) { + divide_factor = static_cast(pool_size); + } else { + divide_factor = static_cast((tend - tstart) * (hend - hstart) * (wend - wstart)); + } + } + + int ti, hi, wi; + for (ti = tstart; ti < tend; ++ti) + { + for (hi = hstart; hi < hend; ++hi) + { + for (wi = wstart; wi < wend; ++wi) + { + const scalar_t val = input[slice][ti][hi][wi]; + sum += val; + } + } + } + + output[slice][oFrame][oRow][oCol] = static_cast(sum / 
divide_factor); + } +} + +// Inner-most loop size (kW) passed as template parameter for +// performance reasons. +// +template +__global__ void avg_pool3d_zoom_update_output( + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, + int kT, int kH, + int dT, int dH, int dW, + int padT, int padH, int padW, + bool count_include_pad, + int offsetZ, int divisor_override) +{ + int oCol = blockIdx.x * blockDim.x + threadIdx.x; + int oRow = blockIdx.y * blockDim.y + threadIdx.y; + int oFrame = (blockIdx.z + offsetZ) % output.size(1); // output frame/time + int slice = (blockIdx.z + offsetZ) / output.size(1); // output slice/feature + + if (oRow < output.size(2) && oCol < output.size(3)) + { + accscalar_t sum = 0.0; + + int tstart = oFrame * dT - padT; + int hstart = oRow * dH - padH; + int wstart = oCol * dW - padW; + int tend = min(tstart + kT, input.size(1) + padT); + int hend = min(hstart + kH, input.size(2) + padH); + int wend = min(wstart + KERNEL_WIDTH, input.size(3) + padW); + int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart); + tstart = max(tstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + tend = min(tend, input.size(1)); + hend = min(hend, input.size(2)); + wend = min(wend, input.size(3)); + + if (tstart >= tend || hstart >= hend || wstart >= wend) { + output[slice][oFrame][oRow][oCol] = scalar_t(0); + return; + } + + accscalar_t divide_factor; + if (divisor_override) { + divide_factor = static_cast(divisor_override); + } else { + if(count_include_pad) { + divide_factor = static_cast(pool_size); + } else { + divide_factor = static_cast((tend - tstart) * (hend - hstart) * (wend - wstart)); + } + } + + int ti, hi, wi; + for (ti = tstart; ti < tend; ++ti) + { + for (hi = hstart; hi < hend; ++hi) + { + for (wi = wstart; wi < wend; ++wi) + { + const scalar_t val = input[slice][ti][hi][wi]; + sum += val; + } + } + } + + output[slice][oFrame][oRow][oCol] = static_cast(sum / divide_factor); + } +} + +template +__global__ void avg_pool3d_single_backward_out_frame_stride1( + PackedTensorAccessor64 gradOutput, + PackedTensorAccessor64 gradInput, + int kT, int kH, int kW, + accscalar_t normFactor, + int offsetZ) +{ + int iCol = blockIdx.x * blockDim.x + threadIdx.x; + int iRow = blockIdx.y * blockDim.y + threadIdx.y; + int iFrame = (blockIdx.z + offsetZ) % gradInput.size(1); // input frame/time + int slice = (blockIdx.z + offsetZ) / gradInput.size(1); // input slice/feature + + // guard against over-tiled threads + if (iRow < gradInput.size(2) && iCol < gradInput.size(3)) + { + accscalar_t sum = 0.0; + const scalar_t *gOut = &gradOutput[slice][max(0, iFrame - kT + 1)] + [max(0, iRow - kH + 1)][max(0, iCol - kW + 1)]; + int frameOffset = 0; + for (int oFrame = max(0, iFrame - kT + 1); + oFrame < min(iFrame + 1, gradOutput.size(1)); + ++oFrame) + { + int rowOffset = frameOffset; + for (int oRow = max(0, iRow - kH + 1); + oRow < min(iRow + 1, gradOutput.size(2)); + ++oRow) + { + int colOffset = rowOffset; + for (int oCol = max(0, iCol - kW + 1); + oCol < min(iCol + 1, gradOutput.size(3)); + ++oCol) + { + sum += gOut[colOffset]; + ++colOffset; + } + rowOffset += gradOutput.size(3); + } + frameOffset += gradOutput.size(2) * gradOutput.size(3); + } + gradInput[slice][iFrame][iRow][iCol] = static_cast(sum * normFactor); + } +} + +template +__global__ void avg_pool3d_zoom_update_grad_input_atomic( + PackedTensorAccessor64 gradOutput, + PackedTensorAccessor64 gradInput, + int kT, int kH, int kW, + int dT, int dH, int dW, + int padT, int padH, int padW, + bool 
count_include_pad, + int offsetZ, int divisor_override, const int gradInput_numel) +{ + int oCol = blockIdx.x * blockDim.x + threadIdx.x; + int oRow = blockIdx.y * blockDim.y + threadIdx.y; + int oFrame = (blockIdx.z + offsetZ) % gradOutput.size(1); // gradOutput frame/time + int slice = (blockIdx.z + offsetZ) / gradOutput.size(1); // gradOutput slice/feature + + // guard against over-tiled threads + if (oRow < gradOutput.size(2) && oCol < gradOutput.size(3)) + { + int tstart = oFrame * dT - padT; + int hstart = oRow * dH - padH; + int wstart = oCol * dW - padW; + int tend = min(tstart + kT, gradInput.size(1) + padT); + int hend = min(hstart + kH, gradInput.size(2) + padH); + int wend = min(wstart + kW, gradInput.size(3) + padW); + int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart); + tstart = max(tstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + tend = min(tend, gradInput.size(1)); + hend = min(hend, gradInput.size(2)); + wend = min(wend, gradInput.size(3)); + + accscalar_t divide_factor; + if (divisor_override) { + divide_factor = static_cast(divisor_override); + } else { + if(count_include_pad) { + divide_factor = static_cast(pool_size); + } else { + divide_factor = static_cast((tend - tstart) * (hend - hstart) * (wend - wstart)); + } + } + + scalar_t val = static_cast( + static_cast(gradOutput[slice][oFrame][oRow][oCol]) / divide_factor); + for (int iFrame = tstart; iFrame < tend; ++iFrame) + { + for (int iRow = hstart; iRow < hend; ++iRow) + { + for (int iCol = wstart; iCol < wend; ++iCol) + { + const int index = slice * gradInput.stride(0) + iFrame * gradInput.stride(1) + iRow * gradInput.stride(2) + iCol * gradInput.stride(3); + fastAtomicAdd(gradInput.data(), index, gradInput_numel, val, true); + } + } + } + } +} + +template +__global__ void avg_pool3d_zoom_update_grad_input( + PackedTensorAccessor64 gradOutput, + PackedTensorAccessor64 gradInput, + int kT, int kH, int kW, + int dT, int dH, int dW, + int padT, int padH, int padW, + bool count_include_pad, int offsetZ, int divisor_override) +{ + int oCol = blockIdx.x * blockDim.x + threadIdx.x; + int oRow = blockIdx.y * blockDim.y + threadIdx.y; + int oFrame = (blockIdx.z + offsetZ) % gradOutput.size(1); // gradOutput frame/time + int slice = (blockIdx.z + offsetZ) / gradOutput.size(1); // gradOutput slice/feature + + // guard against over-tiled threads + if (oRow < gradOutput.size(2) && oCol < gradOutput.size(3)) + { + int tstart = oFrame * dT - padT; + int hstart = oRow * dH - padH; + int wstart = oCol * dW - padW; + int tend = min(tstart + kT, gradInput.size(1) + padT); + int hend = min(hstart + kH, gradInput.size(2) + padH); + int wend = min(wstart + kW, gradInput.size(3) + padW); + int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart); + tstart = max(tstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + tend = min(tend, gradInput.size(1)); + hend = min(hend, gradInput.size(2)); + wend = min(wend, gradInput.size(3)); + + accscalar_t divide_factor; + if (divisor_override) { + divide_factor = static_cast(divisor_override); + } else { + if(count_include_pad) { + divide_factor = static_cast(pool_size); + } else { + divide_factor = static_cast((tend - tstart) * (hend - hstart) * (wend - wstart)); + } + } + + scalar_t val = static_cast( + static_cast(gradOutput[slice][oFrame][oRow][oCol]) / divide_factor); + for (int iFrame = tstart; iFrame < tend; ++iFrame) + { + for (int iRow = hstart; iRow < hend; ++iRow) + { + for (int iCol = wstart; iCol < wend; ++iCol) + { + 
gradInput[slice][iFrame][iRow][iCol] = val; + } + } + } + } +} + +} // anonymous namespace + +#define LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \ + avg_pool3d_zoom_update_output \ + <<>>( \ + work_input.packed_accessor64(), \ + work_output.packed_accessor64(), \ + kT, kH, \ + dT, dH, dW, \ + padT, padH, padW, \ + count_include_pad, \ + offsetZ, divisor); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); \ + break + + +TORCH_IMPL_FUNC(avg_pool3d_out_zoom) ( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + std::optional divisor_override, + const Tensor& output +) { + TensorArg output_arg{ output, "output", 1 }; + TensorArg input_arg{ input, "input", 2 }; + + checkAllSameGPU(__func__, {output_arg, input_arg}); + + const int kT = safe_downcast(kernel_size[0]); + const int kH = kernel_size.size() == 1 ? kT : safe_downcast(kernel_size[1]); + const int kW = kernel_size.size() == 1 ? kT : safe_downcast(kernel_size[2]); + + const int dT = stride.empty() ? kT : safe_downcast(stride[0]); + const int dH = stride.empty() ? kH : + stride.size() == 1 ? dT : safe_downcast(stride[1]); + const int dW = stride.empty() ? kW : + stride.size() == 1 ? dT : safe_downcast(stride[2]); + + const int padT = safe_downcast(padding[0]); + const int padH = padding.size() == 1 ? padT : safe_downcast(padding[1]); + const int padW = padding.size() == 1 ? padT : safe_downcast(padding[2]); + + // if divisor==0 then we will ignore it + int64_t divisor = 0; + if (divisor_override.has_value()) { + divisor = divisor_override.value(); + } + + const int64_t nbatch = input.ndimension() == 5 ? input.size(-5) : 1; + const int64_t nslices = input.size(-4); + const int64_t itime = input.size(-3); + const int64_t iheight = input.size(-2); + const int64_t iwidth = input.size(-1); + + const int64_t otime = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode); + const int64_t oheight = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode); + const int64_t owidth = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode); + + Tensor work_input = input.contiguous(); + Tensor work_output = output; + if (input.ndimension() == 5) { + // Collapse batch and feature dimensions. + work_input = work_input.reshape({nbatch * nslices, itime, iheight, iwidth}); + work_output = work_output.reshape({nbatch * nslices, otime, oheight, owidth}); + } + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, + input.scalar_type(), + "avg_pool3d_out_zoom", + [&] { + using accscalar_t = acc_type; + int64_t totalZ = otime * nslices * nbatch; + int64_t offsetZ = 0; + dim3 block(32, 8); + + while (totalZ > 0) { + dim3 grid(ceil_div(owidth, static_cast(block.x)), + ceil_div(oheight, static_cast(block.y)), + totalZ > 65535 ? 
65535 : totalZ); + + switch (kW) { + LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(1); + LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(2); + LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(3); + LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(4); + LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(5); + LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(6); + LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(7); + default: + avg_pool3d_zoom_update_output + <<>>( + work_input.packed_accessor64(), + work_output.packed_accessor64(), + kT, kH, kW, + dT, dH, dW, + padT, padH, padW, + count_include_pad, + offsetZ, divisor); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + } + + totalZ -= 65535; + offsetZ += 65535; + } + } + ); +} + +#undef LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH + + +TORCH_IMPL_FUNC(avg_pool3d_backward_out_zoom) ( + const Tensor& gradOutput, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + bool ceil_mode, + bool count_include_pad, + std::optional divisor_override, + const Tensor& gradInput +) { + // See Note [Writing Nondeterministic Operations] + // Nondeterministic because of atomicAdd usage + globalContext().alertNotDeterministic("avg_pool3d_backward_zoom"); + + TensorArg gradInput_arg{ gradInput, "gradInput", 1 }; + TensorArg gradOutput_arg{ gradOutput, "gradOutput", 2 }; + TensorArg input_arg{ input, "input", 3 }; + + checkAllSameGPU(__func__, + {gradInput_arg, gradOutput_arg, input_arg}); + + const int kT = safe_downcast(kernel_size[0]); + const int kH = kernel_size.size() == 1 ? kT : safe_downcast(kernel_size[1]); + const int kW = kernel_size.size() == 1 ? kT : safe_downcast(kernel_size[2]); + + const int dT = stride.empty() ? kT : safe_downcast(stride[0]); + const int dH = stride.empty() ? kH : + stride.size() == 1 ? dT : safe_downcast(stride[1]); + const int dW = stride.empty() ? kW : + stride.size() == 1 ? dT : safe_downcast(stride[2]); + + const int padT = safe_downcast(padding[0]); + const int padH = padding.size() == 1 ? padT : safe_downcast(padding[1]); + const int padW = padding.size() == 1 ? padT : safe_downcast(padding[2]); + + TORCH_CHECK((gradOutput.ndimension() == 4 || gradOutput.ndimension() == 5), + "non-empty 4D or 5D (batch mode) tensor expected for gradOutput"); + + // if divisor==0 then we will ignore it + int64_t divisor = 0; + if (divisor_override.has_value()) { + divisor = divisor_override.value(); + } + + gradInput.zero_(); + + const int64_t nbatch = input.ndimension() == 5 ? input.size(-5) : 1; + const int64_t nslices = input.size(-4); + const int64_t itime = input.size(-3); + const int64_t iheight = input.size(-2); + const int64_t iwidth = input.size(-1); + + const int64_t otime = gradOutput.size(-3); + const int64_t oheight = gradOutput.size(-2); + const int64_t owidth = gradOutput.size(-1); + + /* XXX shape check behavior from TH */ + const int64_t otime_for_shape_check = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode); + const int64_t oheight_for_shape_check = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode); + const int64_t owidth_for_chape_check = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode); + + const bool kernelsOverlap = (dT < kT) || (dH < kH) || (dW < kW); + + Tensor work_grad_input = gradInput; + Tensor work_grad_output = gradOutput.contiguous(); + + if (input.ndimension() == 5) { + // Collapse batch and feature dimensions. 
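+    // A 5-D NCTHW tensor is viewed as (N*C, T, H, W): each (batch, channel)
+    // pair becomes one "slice", and blockIdx.z + offsetZ enumerates
+    // slice * time positions across successive kernel launches.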
+ work_grad_input = work_grad_input.reshape({nbatch * nslices, itime, iheight, iwidth}); + work_grad_output = work_grad_output.reshape({nbatch * nslices, otime, oheight, owidth}); + } + + + // Optimizing for stride 1 is probably only of limited value, but this + // specialization yields 3x speedup over the gpuAtomicAdd implementation. + // Padding must be 0, otherwise, pool size may change. + if (dT == 1 && dH == 1 && dW == 1 && padT == 0 && padH == 0 && padW == 0) { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), + "avg_pool3d_backward_out_frame_stride1", + [&] { + using accscalar_t = acc_type; + int64_t totalZ = itime * nslices * nbatch; + int64_t offsetZ = 0; + dim3 block(32, 8); + + accscalar_t divide_factor; + if (divisor) { + divide_factor = static_cast(divisor); + } else { + divide_factor = static_cast(kT * kH * kW); + } + + while (totalZ > 0) { + dim3 grid(ceil_div(iwidth, static_cast(block.x)), + ceil_div(iheight, static_cast(block.y)), + totalZ > 65535 ? 65535 : totalZ); + + avg_pool3d_single_backward_out_frame_stride1 + <<>>( + work_grad_output.packed_accessor64(), + work_grad_input.packed_accessor64(), + kT, kH, kW, + 1.0f/divide_factor, + offsetZ); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + + totalZ -= 65535; + offsetZ += 65535; + } + } + ); + } + else { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), + "avg_pool3d_backward_out_frame", + [&] { + using accscalar_t = acc_type; + int64_t totalZ = otime * nslices * nbatch; + int64_t offsetZ = 0; + dim3 block(32, 8); + + while (totalZ > 0) { + dim3 grid(ceil_div(owidth, static_cast(block.x)), + ceil_div(oheight, static_cast(block.y)), + totalZ > 65535 ? 65535 : totalZ); + + if (kernelsOverlap) { + avg_pool3d_zoom_update_grad_input_atomic + <<>>( + work_grad_output.packed_accessor64(), + work_grad_input.packed_accessor64(), + kT, kH, kW, + dT, dH, dW, + padT, padH, padW, + count_include_pad, + offsetZ, divisor, work_grad_input.numel()); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + else { + avg_pool3d_zoom_update_grad_input + <<>>( + work_grad_output.packed_accessor64(), + work_grad_input.packed_accessor64(), + kT, kH, kW, + dT, dH, dW, + padT, padH, padW, + count_include_pad, + offsetZ, divisor); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + + totalZ -= 65535; + offsetZ += 65535; + } + } + ); + } +} + +} // at::native diff --git a/aten/src/ATen/native/zoom/BinaryBitwiseOpsKernels.cu b/aten/src/ATen/native/zoom/BinaryBitwiseOpsKernels.cu new file mode 100644 index 0000000000000..79f4bc80615f5 --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryBitwiseOpsKernels.cu @@ -0,0 +1,81 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. 
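+// Bitwise and/or/xor for the Zoom (PrivateUse1) backend. Each op is a small
+// functor handed to opmath_symmetric_gpu_kernel_with_scalars, with a bool
+// specialization so boolean tensors get logical semantics (a && b, a || b,
+// a != b) rather than the integer bit operators.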
+ +namespace at::native { + +template +struct BitwiseAndFunctor { + __device__ __forceinline__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a & b; + } +}; + +template<> +struct BitwiseAndFunctor { + __device__ __forceinline__ bool operator()(bool a, bool b) const { + return a && b; + } +}; + +void bitwise_and_kernel_zoom(TensorIteratorBase& iter) { + AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_and_zoom", [&]() { + BitwiseAndFunctor f; + opmath_symmetric_gpu_kernel_with_scalars(iter, f); + }); +} + +template +struct BitwiseOrFunctor { + __device__ __forceinline__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a | b; + } +}; + +template<> +struct BitwiseOrFunctor { + __device__ __forceinline__ bool operator()(bool a, bool b) const { + return a || b; + } +}; + +void bitwise_or_kernel_zoom(TensorIteratorBase& iter) { + AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_or_zoom", [&]() { + BitwiseOrFunctor f; + opmath_symmetric_gpu_kernel_with_scalars(iter, f); + }); +} + +template +struct BitwiseXorFunctor { + __device__ __forceinline__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a ^ b; + } +}; + +template<> +struct BitwiseXorFunctor { + __device__ __forceinline__ bool operator()(bool a, bool b) const { + return a != b; + } +}; + +void bitwise_xor_kernel_zoom(TensorIteratorBase& iter) { + AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_xor_zoom", [&]() { + BitwiseXorFunctor f; + opmath_symmetric_gpu_kernel_with_scalars(iter, f); + }); +} + +REGISTER_PRIVATEUSE1_DISPATCH(bitwise_and_stub, &bitwise_and_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(bitwise_or_stub, &bitwise_or_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel_zoom); + + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/BinaryDivFloorKernel.cu b/aten/src/ATen/native/zoom/BinaryDivFloorKernel.cu new file mode 100644 index 0000000000000..7ad48ce8c7cb1 --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryDivFloorKernel.cu @@ -0,0 +1,83 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native { +namespace binary_internal { + +void div_floor_kernel_zoom(TensorIteratorBase& iter) { + // See NOTE: [Floor Division in Python] + const auto dtype = iter.common_dtype(); + if (dtype == kByte) { + // In the special case of unsigned integer division, floor division is + // equivalent to truncation division (since the signs of the divisor and + // dividend are always the same) + return div_trunc_kernel_zoom(iter); + } else if (isIntegralType(dtype, /*includeBool*/ false)) { + AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_floor_zoom", [&]() { + gpu_kernel_with_scalars( + iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return c10::div_floor_integer(a, b); + }); + }); + } else if (iter.is_cpu_scalar(2)) { + // optimization for floating-point types: if the second operand is a CPU + // scalar, compute a * reciprocal(b). Note that this may lose one bit of + // precision compared to computing the division. 
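+    // Floor division rounds toward negative infinity (Python semantics): the
+    // lambda below starts from a truncation-style quotient and subtracts 1
+    // whenever the remainder and divisor have opposite signs, e.g. -7/2
+    // truncates to -3 but floors to -4. The (div - floordiv > 0.5) fixup
+    // nudges the quotient back up when a * inv_b landed just below the exact
+    // integer.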
+ AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, dtype, "div_floor_zoom", [&]() { + using accscalar_t = at::acc_type; + auto b = iter.scalar_value(2); + if (C10_UNLIKELY(b == 0)) { + return div_true_kernel_zoom(iter); + } + + auto inv_b = accscalar_t(1.0) / b; + iter.remove_operand(2); + gpu_kernel(iter, [b, inv_b] GPU_LAMBDA(scalar_t a) -> scalar_t { + auto mod = std::fmod(a, b); + auto div = (a - mod) * inv_b; + if ((mod != 0) && (b < 0) != (mod < 0)) { + div -= scalar_t(1); + } + + scalar_t floordiv; + if (div != 0) { + floordiv = std::floor(div); + if (div - floordiv > scalar_t(0.5)) { + floordiv += scalar_t(1.0); + } + } else { + floordiv = c10::hip::compat::copysign(scalar_t(0), a * inv_b); + } + return floordiv; + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, dtype, "div_floor_zoom", [&]() { + gpu_kernel_with_scalars( + iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return c10::div_floor_floating(a, b); + }); + }); + } +} +} // namespace binary_internal + +REGISTER_PRIVATEUSE1_DISPATCH(div_floor_stub, &binary_internal::div_floor_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/BinaryDivTrueKernel.cu b/aten/src/ATen/native/zoom/BinaryDivTrueKernel.cu new file mode 100644 index 0000000000000..09b92154633f6 --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryDivTrueKernel.cu @@ -0,0 +1,61 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native { +namespace binary_internal { + +CONSTEXPR_EXCEPT_WIN_CUDA char div_name[] = "div_kernel"; +void div_true_kernel_zoom(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (iter.common_dtype() == kComplexHalf) { + using scalar_t = c10::complex; +#if AT_USE_JITERATOR() + static const auto div_string = jiterator_stringify( + template T div_kernel(T a, T b) { return a / b; }); + opmath_jitted_gpu_kernel_with_scalars( + iter, div_string); +#else + using opmath_t = at::opmath_type; + opmath_gpu_kernel_with_scalars(iter, DivFunctor()); +#endif + return; + } + if (iter.is_cpu_scalar(2)) { + // optimization for floating-point types: if the second operand is a CPU + // scalar, compute a * reciprocal(b). Note that this may lose one bit of + // precision compared to computing the division. 
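+    // The BUnaryFunctor/MulFunctor composition below is equivalent to
+    //   gpu_kernel(iter, [inv_b] GPU_LAMBDA (scalar_t a) { return a * inv_b; });
+    // i.e. one host-side division to form inv_b, then a single multiply per
+    // element on the device.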
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + kHalf, kBFloat16, common_dtype, "div_true_zoom", [&]() { + using opmath_t = at::opmath_type; + auto inv_b = opmath_t(1.0) / iter.scalar_value(2); + iter.remove_operand(2); + gpu_kernel( + iter, + BUnaryFunctor>( + MulFunctor(), inv_b)); + }); + } else { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + kHalf, kBFloat16, common_dtype, "div_true_zoom", [&]() { + DivFunctor f; + gpu_kernel_with_scalars(iter, f); + }); + } +} +} // namespace binary_internal + +REGISTER_PRIVATEUSE1_DISPATCH(div_true_stub, &binary_internal::div_true_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/BinaryDivTruncKernel.cu b/aten/src/ATen/native/zoom/BinaryDivTruncKernel.cu new file mode 100644 index 0000000000000..bc1f9a851ae32 --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryDivTruncKernel.cu @@ -0,0 +1,53 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native { +namespace binary_internal { + +void div_trunc_kernel_zoom(TensorIteratorBase& iter) { + auto dtype = iter.common_dtype(); + if (isIntegralType(dtype, /*includeBool*/ false)) { + AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_trunc_zoom", [&]() { + gpu_kernel_with_scalars( + iter, + [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { return a / b; }); + }); + } else if (iter.is_cpu_scalar(2)) { + // optimization for floating-point types: if the second operand is a CPU + // scalar, compute a * reciprocal(b). Note that this may lose one bit of + // precision compared to computing the division. + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, dtype, "div_trunc_zoom", [&]() { + using accscalar_t = at::acc_type; + auto inv_b = accscalar_t(1.0) / iter.scalar_value(2); + iter.remove_operand(2); + gpu_kernel(iter, [inv_b] GPU_LAMBDA(scalar_t a) -> scalar_t { + return std::trunc(a * inv_b); + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, dtype, "div_trunc_zoom", [&]() { + gpu_kernel_with_scalars( + iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return std::trunc(a / b); + }); + }); + } +} +} // namespace binary_internal + +REGISTER_PRIVATEUSE1_DISPATCH(div_trunc_stub, &binary_internal::div_trunc_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/BinaryGeometricKernels.cu b/aten/src/ATen/native/zoom/BinaryGeometricKernels.cu new file mode 100644 index 0000000000000..ad16a2c2a6681 --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryGeometricKernels.cu @@ -0,0 +1,39 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. 
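+// atan2(a, b) and hypot(a, b) for the Zoom backend. Both dispatch over
+// float/double/half/bfloat16 and forward element-wise to the device math
+// library (::atan2, ::hypot) before registering against the PrivateUse1 stubs.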
+ +namespace at::native { + +void atan2_kernel_zoom(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, + iter.common_dtype(), "atan2_zoom", + [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return ::atan2(a, b); + }); + }); +} + +void hypot_kernel_zoom(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, + iter.common_dtype(), "hypot_zoom", + [&]() { + opmath_symmetric_gpu_kernel_with_scalars( + iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return ::hypot(a, b); + }); + }); +} + +REGISTER_PRIVATEUSE1_DISPATCH(atan2_stub, &atan2_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(hypot_stub, &hypot_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/BinaryInternal.h b/aten/src/ATen/native/zoom/BinaryInternal.h new file mode 100644 index 0000000000000..a42408c5207fa --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryInternal.h @@ -0,0 +1,48 @@ +// DON'T include this except from Binary*.cu files. It should not leak into +// headers. +#pragma once +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { +namespace native { +namespace binary_internal { + +template +struct DivFunctor { + __device__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a / b; + } +}; + +template +struct MulFunctor { + __device__ T operator()(T a, T b) const { + return a * b; + } +}; + +// Workaround for the error: '*' in boolean context, suggest '&&' instead +// [-Werror=int-in-bool-context] +template <> +struct MulFunctor { + __device__ bool operator()(bool a, bool b) const { + return a && b; + } +}; +void div_true_kernel_zoom(TensorIteratorBase& iter); +void div_trunc_kernel_zoom(TensorIteratorBase& iter); +} // namespace binary_internal +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/zoom/BinaryLogicalOpsKernels.cu b/aten/src/ATen/native/zoom/BinaryLogicalOpsKernels.cu new file mode 100644 index 0000000000000..5eb61fc112e8b --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryLogicalOpsKernels.cu @@ -0,0 +1,128 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. 
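+// logical_and / logical_or / logical_xor for the Zoom backend. All three
+// return bool regardless of input dtype; complex inputs use the jiterator
+// string kernels when AT_USE_JITERATOR() is enabled and fall back to plain
+// gpu_kernel lambdas otherwise.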
+ +namespace at::native { + +CONSTEXPR_EXCEPT_WIN_CUDA char logical_and_name[] = "logical_and_kernel"; +void logical_and_kernel_zoom(TensorIterator& iter) { + auto dtype = iter.common_dtype(); + if (at::isComplexType(dtype)) { +#if AT_USE_JITERATOR() + static const auto logical_and_string = jiterator_stringify( + template + bool logical_and_kernel(T a, T b) { + return a && b; + } + ); // logical_and_string + AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_and_zoom", [&]() { + jitted_gpu_kernel< + /*name=*/ logical_and_name, + /*return_dtype=*/ scalar_t, + /*common_dtype=*/ scalar_t, + /*arity=*/ 2>(iter, logical_and_string); + }); // logical_and_string +#else + AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_and_zoom", [&]() { + opmath_symmetric_gpu_kernel_with_scalars( + iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { + return a && b; + }); + }); +#endif + } else { + AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, ScalarType::BFloat16, + dtype, "logical_and_zoom", [&]() { + opmath_symmetric_gpu_kernel_with_scalars( + iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { + return a && b; + }); + }); + } +} + +CONSTEXPR_EXCEPT_WIN_CUDA char logical_or_name[] = "logical_or_kernel"; +void logical_or_kernel_zoom(TensorIterator& iter) { + auto dtype = iter.common_dtype(); + if (at::isComplexType(dtype)) { +#if AT_USE_JITERATOR() + static const auto logical_or_string = jiterator_stringify( + template + bool logical_or_kernel(T a, T b) { + return a || b; + } + ); // logical_or_string + AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_or_zoom", [&]() { + jitted_gpu_kernel< + /*name=*/ logical_or_name, + /*return_dtype=*/ scalar_t, + /*common_dtype=*/ scalar_t, + /*arity=*/ 2>(iter, logical_or_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_or_zoom", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { + return a || b; + }); + }); +#endif + } else { + AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, ScalarType::BFloat16, + dtype, "logical_or_zoom", [&]() { + opmath_symmetric_gpu_kernel_with_scalars( + iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { + return a || b; + }); + }); + } +} + +CONSTEXPR_EXCEPT_WIN_CUDA char logical_xor_name[] = "logical_xor_kernel"; +void logical_xor_kernel_zoom(TensorIterator& iter) { + auto dtype = iter.common_dtype(); + if (at::isComplexType(dtype)) { +#if AT_USE_JITERATOR() + static const auto logical_xor_string = jiterator_stringify( + template + bool logical_xor_kernel(T a, T b) { + return bool(a) != bool(b); + } + ); + AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_xor_zoom", [&]() { + jitted_gpu_kernel< + /*name=*/ logical_xor_name, + /*return_dtype=*/ scalar_t, + /*common_dtype=*/ scalar_t, + /*arity=*/ 2>(iter, logical_xor_string); + }); // logical_xor_string +#else + AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_xor_zoom", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { + return bool(a) != bool(b); + }); + }); +#endif + } else { + AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, ScalarType::BFloat16, + dtype, "logical_xor_zoom", [&]() { + opmath_symmetric_gpu_kernel_with_scalars( + iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { + return bool(a) != bool(b); + }); + }); + } +} + +REGISTER_PRIVATEUSE1_DISPATCH(logical_and_stub, &logical_and_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(logical_or_stub, &logical_or_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(logical_xor_stub, &logical_xor_kernel_zoom); + + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/BinaryMiscBackwardOpsKernels.cu 
b/aten/src/ATen/native/zoom/BinaryMiscBackwardOpsKernels.cu new file mode 100644 index 0000000000000..e7e2bea410b0b --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryMiscBackwardOpsKernels.cu @@ -0,0 +1,131 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include + +#include + +#include +#include +#include +#include +#include +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. + +namespace at::native { + +CONSTEXPR_EXCEPT_WIN_CUDA char sigmoid_backward_name[] = "sigmoid_backward"; +void sigmoid_backward_kernel_zoom(TensorIteratorBase& iter) { + auto dtype = iter.dtype(); + if(isComplexType(dtype)) { +#if AT_USE_JITERATOR() + static const auto sigmoid_backward_string = jiterator_stringify( + template + T sigmoid_backward(T a, T b) { + return a * std::conj((T{1.} - b) * b); + } + ); // sigmoid_backward_string + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "sigmoid_backward_zoom", [&]() { + jitted_gpu_kernel< + /*name=*/ sigmoid_backward_name, + /*return_dtype=*/ scalar_t, + /*common_dtype=*/ scalar_t, + /*arity=*/ 2>(iter, sigmoid_backward_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "sigmoid_backward_zoom", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + using comp_t = at::opmath_type; + const auto one = comp_t{1.}; + const auto comp_b = static_cast(b); + const auto comp_a = static_cast(a); + return static_cast(comp_a * std::conj((one - comp_b) * comp_b)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "sigmoid_backward_zoom", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return a * (scalar_t(1.) - b) * b; + }); + }); + } +} + +void logit_backward_kernel_zoom(TensorIteratorBase& iter, const Scalar& eps_scalar) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "logit_zoom", + [&]() { + using T_ACC = acc_type; + const T_ACC eps = eps_scalar.to(); + if (eps < T_ACC(0)) { + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + const T_ACC dy_acc = static_cast(dy); + const T_ACC x_acc = static_cast(x); + return (x_acc < T_ACC(0) || x_acc > T_ACC(1)) + ? std::numeric_limits::quiet_NaN() + : dy_acc / (x_acc * (T_ACC(1) - x_acc)); + }); + } else { + const T_ACC lo = eps; + const T_ACC hi = T_ACC(1) - eps; + gpu_kernel( + iter, [lo, hi] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + const T_ACC dy_acc = static_cast(dy); + const T_ACC x_acc = static_cast(x); + return (x_acc < lo || x_acc > hi) + ? 
T_ACC(0) + : dy_acc / (x_acc * (T_ACC(1) - x_acc)); + }); + } + }); +} + +CONSTEXPR_EXCEPT_WIN_CUDA char tanh_backward_name[] = "tanh_backward"; +void tanh_backward_kernel_zoom(TensorIteratorBase& iter) { + auto dtype = iter.dtype(); + if(isComplexType(dtype)) { +#if AT_USE_JITERATOR() + static const auto tanh_backward_string = jiterator_stringify( + template + T tanh_backward(T a, T b) { + return a * std::conj(T{1.} - b * b); + } + ); // tanh_backward_string + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "tanh_backward_complex_zoom", [&]() { + jitted_gpu_kernel< + /*name=*/ tanh_backward_name, + /*return_dtype=*/ scalar_t, + /*common_dtype=*/ scalar_t, + /*arity=*/ 2>(iter, tanh_backward_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "tanh_backward_complex_zoom", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + using comp_t = at::opmath_type; + const auto one = comp_t{1.}; + const auto comp_b = static_cast(b); + const auto comp_a = static_cast(a); + return static_cast(comp_a * std::conj(one - comp_b * comp_b)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "tanh_backward_zoom", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return a * (scalar_t{1.} - b * b); + }); + }); + } +} + +REGISTER_PRIVATEUSE1_DISPATCH(sigmoid_backward_stub, &sigmoid_backward_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(logit_backward_stub, &logit_backward_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(tanh_backward_stub, &tanh_backward_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/BinaryMiscOpsKernels.cu b/aten/src/ATen/native/zoom/BinaryMiscOpsKernels.cu new file mode 100644 index 0000000000000..fac17cfb29e47 --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryMiscOpsKernels.cu @@ -0,0 +1,81 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. + +namespace at::native { + +void smooth_l1_kernel_zoom(TensorIteratorBase& iter, double beta) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "smooth_l1_zoom", [&iter, beta]() { + scalar_t beta_val(beta); + gpu_kernel(iter, [beta_val] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { + auto z = ::abs(a - b); + return z < beta_val ? scalar_t(0.5) * z * z / beta_val : z - scalar_t(0.5) * beta_val; + }); + }); +} + +void huber_kernel_zoom(TensorIterator& iter, double delta) { + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "huber_zoom", [&iter, delta] { + scalar_t delta_val(delta); + gpu_kernel(iter, [delta_val] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { + auto z = ::abs(a - b); + return z < delta_val ? 
scalar_t(0.5) * z * z : delta_val * (z - scalar_t(0.5) * delta_val); + }); + }); +} + +void mse_kernel_zoom(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "mse_zoom", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + auto diff = a - b; + return diff * diff; + }); + }); +} + +void xlogy_kernel_zoom(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "xlogy_zoom", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t x, scalar_t y) -> scalar_t { + if (at::_isnan(y)){ + return NAN; + } + if (x == 0){ + return 0; + } + return x * std::log(y); + }); + }); +} + +void xlog1py_kernel_zoom(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "xlog1py_zoom", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t x, scalar_t y) -> scalar_t { + if (at::_isnan(y)){ + return NAN; + } + if (x == 0){ + return 0; + } + return x * std::log1p(y); + }); + }); +} + +REGISTER_PRIVATEUSE1_DISPATCH(smooth_l1_stub, &smooth_l1_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(huber_stub, &huber_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(mse_stub, &mse_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(xlogy_stub, &xlogy_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(xlog1py_stub, &xlog1py_kernel_zoom); + +// DO NOT ADD ANY NEW KERNELS HERE +// CUDA compilation times grow quickly. It's perfectly acceptable to have a file per kernel. + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/BinaryMulKernel.cu b/aten/src/ATen/native/zoom/BinaryMulKernel.cu new file mode 100644 index 0000000000000..dd42ba4d24880 --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryMulKernel.cu @@ -0,0 +1,48 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. 
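The loss kernels above (smooth L1, Huber, xlogy) each reduce to a small piece of element-wise math wrapped in dispatch and `gpu_kernel` machinery. Below is a minimal host-side sketch of those formulas, assuming plain doubles; the function names and the `main` driver are illustrative and not part of this patch.

```cpp
// Host-side reference for the element-wise loss math used by the kernels above.
// Illustrative only; not taken from the patch.
#include <cassert>
#include <cmath>

double smooth_l1(double a, double b, double beta) {
  // Quadratic when |a - b| < beta, otherwise linear with a 0.5*beta offset.
  double z = std::fabs(a - b);
  return z < beta ? 0.5 * z * z / beta : z - 0.5 * beta;
}

double huber(double a, double b, double delta) {
  // Quadratic inside the delta band, linear (scaled by delta) outside it.
  double z = std::fabs(a - b);
  return z < delta ? 0.5 * z * z : delta * (z - 0.5 * delta);
}

double xlogy(double x, double y) {
  // Matches the kernel's convention: NaN in y propagates, x == 0 gives 0.
  if (std::isnan(y)) return y;
  if (x == 0.0) return 0.0;
  return x * std::log(y);
}

int main() {
  assert(smooth_l1(1.0, 0.25, 1.0) == 0.5 * 0.75 * 0.75);  // quadratic branch
  assert(huber(3.0, 0.0, 1.0) == 2.5);                      // linear branch
  assert(xlogy(0.0, 0.0) == 0.0);                           // 0 * log(0) defined as 0
  return 0;
}
```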
+ +namespace at::native { + +CONSTEXPR_EXCEPT_WIN_CUDA char mul_name[] = "mul_kernel"; +void mul_kernel_zoom(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (common_dtype == kComplexHalf) { + using scalar_t = c10::complex; +#if AT_USE_JITERATOR() + static const auto mul_string = jiterator_stringify( + template T mul_kernel(T a, T b) { return a * b; }); + opmath_jitted_gpu_kernel_with_scalars( + iter, mul_string); +#else + using opmath_t = at::opmath_type; + opmath_symmetric_gpu_kernel_with_scalars( + iter, binary_internal::MulFunctor()); +#endif + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + kHalf, kBFloat16, kBool, iter.common_dtype(), "mul_zoom", [&]() { + using opmath_t = at::opmath_type; + opmath_symmetric_gpu_kernel_with_scalars( + iter, binary_internal::MulFunctor()); + }); + } +} + +REGISTER_PRIVATEUSE1_DISPATCH(mul_stub, &mul_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/BinaryRemainderKernel.cu b/aten/src/ATen/native/zoom/BinaryRemainderKernel.cu new file mode 100644 index 0000000000000..e290015f62502 --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryRemainderKernel.cu @@ -0,0 +1,61 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include + +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. + +namespace at::native { + +void remainder_kernel_zoom(TensorIteratorBase& iter) { + if (isIntegralType(iter.common_dtype(), /*includeBool*/ false)) { + AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "remainder_zoom", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + scalar_t r = a % b; + if (r != 0 && c10::signs_differ(r, b)) { + r += b; + } + return r; + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "remainder_zoom", [&]() { + gpu_kernel_with_scalars(iter, + []GPU_LAMBDA(scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { + auto mod = ::fmod(a, b); + if (mod != 0 && c10::signs_differ(b, mod)) { + mod += b; + } + return mod; + }); + }); + } +} + +void fmod_kernel_zoom(TensorIteratorBase& iter) { + if (isIntegralType(iter.common_dtype(), /*includeBool*/ false)) { + AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "fmod_zoom", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return a % b; + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "fmod_zoom", [&]() { + gpu_kernel_with_scalars(iter, + []GPU_LAMBDA(scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t { + return ::fmod(a, b); + }); + }); + } +} + +REGISTER_PRIVATEUSE1_DISPATCH(remainder_stub, &remainder_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(fmod_stub, &fmod_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/BinaryShiftOpsKernels.cu b/aten/src/ATen/native/zoom/BinaryShiftOpsKernels.cu new file mode 100644 index 0000000000000..2b9edb9cfda72 --- /dev/null +++ b/aten/src/ATen/native/zoom/BinaryShiftOpsKernels.cu @@ -0,0 +1,44 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. 
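The remainder and fmod kernels above implement two different modulo conventions: `remainder` folds the result onto the sign of the divisor (Python semantics), while `fmod` keeps the sign of the dividend (C semantics). A minimal host-side sketch of the difference, assuming plain doubles and using `std::signbit` in place of `c10::signs_differ`:

```cpp
// Illustrative only; the real kernels run the same logic per element on device.
#include <cmath>
#include <iostream>

double py_remainder(double a, double b) {
  double mod = std::fmod(a, b);
  if (mod != 0 && std::signbit(mod) != std::signbit(b)) {
    mod += b;  // fold back into the divisor's sign range
  }
  return mod;
}

int main() {
  std::cout << py_remainder(-7.0, 3.0) << "\n";  // 2  (sign follows the divisor)
  std::cout << std::fmod(-7.0, 3.0) << "\n";     // -1 (sign follows the dividend)
  return 0;
}
```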
+ +namespace at::native { + + +void lshift_kernel_zoom(TensorIteratorBase& iter) { + AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_zoom", [&]() { + gpu_kernel_with_scalars(iter, + []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT; + if ((static_cast>(b) < 0) || (b >= max_shift)) { + return 0; + } + return static_cast>(a) << b; + }); + }); +} + +void rshift_kernel_zoom(TensorIteratorBase& iter) { + AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_zoom", [&]() { + gpu_kernel_with_scalars(iter, + []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + // right shift value to retain sign bit for signed and no bits for unsigned + constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT - std::is_signed_v; + if ((static_cast>(b) < 0) || (b >= max_shift)) { + return a >> max_shift; + } + return a >> b; + }); + }); +} + +REGISTER_PRIVATEUSE1_DISPATCH(lshift_stub, &lshift_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(rshift_stub, &rshift_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/Blas.cpp b/aten/src/ATen/native/zoom/Blas.cpp new file mode 100644 index 0000000000000..107640de6bf4b --- /dev/null +++ b/aten/src/ATen/native/zoom/Blas.cpp @@ -0,0 +1,1209 @@ +#ifdef ENABLE_ZOOM_BLAS + +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + + +constexpr int64_t c_i64_grid_X_chunk = 1ULL << 28; +constexpr int64_t c_i64_grid_YZ_chunk + = int64_t((std::numeric_limits::max() & ~0xf)); // % 16 == 0 + +namespace at::native { + +namespace { + +// TODO: https://github.com/pytorch/pytorch/pull/59380#pullrequestreview-725310492 +c10::MaybeOwned inline resolve_conj_if_indicated(const Tensor& tensor, bool resolve_conj) { + if (resolve_conj && tensor.is_conj()) { + return c10::MaybeOwned::owned(tensor.resolve_conj()); + } else { + return c10::MaybeOwned::borrowed(tensor); + } +} + +c10::MaybeOwned inline prepare_matrix_for_hipblas(const Tensor& tensor, bool& transpose_tensor, bool transpose_result) { + if (tensor.is_non_overlapping_and_dense()) { // common case + transpose_tensor = tensor.is_contiguous(); + return resolve_conj_if_indicated(tensor, transpose_result ? 
transpose_tensor : !transpose_tensor); + } + IntArrayRef tensor_strides = tensor.strides(); + IntArrayRef tensor_sizes = tensor.sizes(); + if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max(1, tensor_sizes[0]))) { + transpose_tensor = false; + return resolve_conj_if_indicated(tensor, !transpose_result); + } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max(1, tensor_sizes[1]))) { + transpose_tensor = true; + return resolve_conj_if_indicated(tensor, transpose_result); + } else { + transpose_tensor = true; + return c10::MaybeOwned::owned(tensor.clone(at::MemoryFormat::Contiguous)); + } +} + +c10::MaybeOwned inline prepare_matrix_for_hipblas(const Tensor& tensor, bool& transpose_tensor) { + if (tensor.is_non_overlapping_and_dense()) { // common case + transpose_tensor = tensor.is_contiguous(); + return resolve_conj_if_indicated(tensor, true); + } + IntArrayRef tensor_strides = tensor.strides(); + IntArrayRef tensor_sizes = tensor.sizes(); + if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max(1, tensor_sizes[0]))) { + transpose_tensor = false; + return resolve_conj_if_indicated(tensor, true); + } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max(1, tensor_sizes[1]))) { + transpose_tensor = true; + return resolve_conj_if_indicated(tensor, true); + } else { + transpose_tensor = true; + return c10::MaybeOwned::owned(tensor.clone(at::MemoryFormat::Contiguous)); + } +} + +struct hipblasCommonArgs { + hipblasCommonArgs(const Tensor& mat1, const Tensor& mat2, Tensor& c) { + bool transpose_result, transpose_mat1, transpose_mat2; + result = prepare_matrix_for_hipblas(c, transpose_result); + mata = prepare_matrix_for_hipblas(transpose_result ? mat2 : mat1, transpose_mat1, transpose_result); + matb = prepare_matrix_for_hipblas(transpose_result ? mat1 : mat2, transpose_mat2, transpose_result); + auto mat1_sizes = mat1.sizes(); + auto mat2_sizes = mat2.sizes(); + if (transpose_result) { + transpose_mat1 = !transpose_mat1; + transpose_mat2 = !transpose_mat2; + mat1_sizes = mata->sizes(); + mat2_sizes = matb->sizes(); + } + + m = mat1_sizes[transpose_result ? 1 : 0]; + k = mat1_sizes[transpose_result ? 0 : 1]; + n = mat2_sizes[transpose_result ? 0 : 1]; + lda = mata->stride((transpose_mat1 == transpose_result) ? 1 : 0); + ldb = matb->stride((transpose_mat2 == transpose_result) ? 1 : 0); + result_ld = result->stride(transpose_result ? 0 : 1); + transa = transpose_mat1 ? mata->is_conj() ? 'c' : 't' : 'n'; + transb = transpose_mat2 ? matb->is_conj() ? 'c' : 't' : 'n'; + } + char transa, transb; + int64_t m, n, k; + int64_t lda, ldb, result_ld; + c10::MaybeOwned mata, matb, result; +}; +} // namespace + +c10::MaybeOwned prepare_batch_matrix_for_hipblas(const Tensor& tensor, bool& transpose_tensor, int64_t& ld_tensor, bool transpose_result, int64_t m, int64_t n) { + IntArrayRef tensor_strides = tensor.strides(); + c10::MaybeOwned tensor_; + int fast_dim = transpose_result ? 2 : 1; + int leading_dim = transpose_result ? 
1 : 2; + + if (tensor_strides[fast_dim] == 1 && + (tensor_strides[leading_dim] >= std::max(1, m))) { + transpose_tensor = false; + tensor_ = resolve_conj_if_indicated(tensor, true); + ld_tensor = tensor_->strides()[leading_dim]; + } else if ((tensor_strides[leading_dim] == 1) && + (tensor_strides[fast_dim] >= std::max(1, n))) { + transpose_tensor = true; + tensor_ = resolve_conj_if_indicated(tensor, false); + ld_tensor = tensor_->strides()[fast_dim]; + } else { + transpose_tensor = !transpose_result; + // gemm call requires leading dimension and stride parameters to be non-zero + bool is_stride_non_zero = tensor.strides()[1] != 0 && tensor.strides()[2] != 0; + if (tensor.is_contiguous() && is_stride_non_zero) { + tensor_ = resolve_conj_if_indicated(tensor, transpose_result); + } else { + tensor_ = c10::MaybeOwned::owned(tensor.clone(at::MemoryFormat::Contiguous)); + } + ld_tensor = tensor_->strides()[1]; + } + + return tensor_; +} + +namespace { + +enum class Activation { + None, + RELU, + GELU, +}; + +zoom::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) { + switch (a) { + case Activation::None: + return zoom::blas::GEMMAndBiasActivationEpilogue::None; + case Activation::RELU: + return zoom::blas::GEMMAndBiasActivationEpilogue::RELU; + case Activation::GELU: + return zoom::blas::GEMMAndBiasActivationEpilogue::GELU; + default: + TORCH_CHECK(false); + return zoom::blas::GEMMAndBiasActivationEpilogue::None; + } +} + +static bool getDisableAddmmHIPLt() { + #ifdef DISABLE_HIPBLASLT + return true; + #else + static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT"); + // if we enable tunable op, it'll take priority over just hipblaslt (heuristics) + // note the current tunable op is not the hipblaslt path (gemm_and_bias) + auto tuning_ctx = at::zoom::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { + return true; + } + // allow both CUDA and HIP env var names for ROCm builds + // also, current default for ROCm builds is disable by default + if (env_value == nullptr) { + env_value = std::getenv("DISABLE_ADDMM_HIP_LT"); + } + if (env_value != nullptr && strcmp(env_value, "0") == 0) { + return false; + } + return true; + #endif +} + + +static bool isSupportedHipLtROCmArch(int index) { + #ifdef DISABLE_HIPBLASLT + return false; + #else + hipDeviceProp_t* prop = at::zoom::getDeviceProperties(index); + std::string device_arch = prop->gcnArchName; + static const std::vector archs = {"gfx90a", "gfx940", "gfx941", "gfx942"}; + for (std::string arch : archs) { + size_t substring = device_arch.find(arch); + if (substring != std::string::npos) { + return true; + } + } + TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!"); + return false; + #endif +} + + +Tensor& addmm_out_hip_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None) { + // Make sure to keep addmm_hip below in sync with this code; it + // preflights a check to try to avoid actually needing to call + // expand(). 
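For reference, the epilogue this function ultimately computes is `result = beta * self + alpha * (mat1 @ mat2)`, with `self` expanded to the output shape before the GEMM. A minimal host-side sketch, assuming row-major doubles and ignoring dtype, conjugation, and layout handling (the helper name and `main` driver are illustrative, not part of the patch):

```cpp
#include <cassert>
#include <vector>

// result[i][j] = beta * self_row[j] + alpha * sum_p mat1[i][p] * mat2[p][j]
std::vector<double> addmm_ref(const std::vector<double>& self_row,  // length n, broadcast over rows
                              const std::vector<double>& mat1,      // m x k, row-major
                              const std::vector<double>& mat2,      // k x n, row-major
                              int m, int k, int n,
                              double beta, double alpha) {
  std::vector<double> out(static_cast<size_t>(m) * n, 0.0);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      double acc = 0.0;
      for (int p = 0; p < k; ++p) acc += mat1[i * k + p] * mat2[p * n + j];
      out[i * n + j] = beta * self_row[j] + alpha * acc;
    }
  }
  return out;
}

int main() {
  // 1x1 case: 2*3 + 1*(4*5) = 26
  assert(addmm_ref({3.0}, {4.0}, {5.0}, 1, 1, 1, 2.0, 1.0)[0] == 26.0);
  return 0;
}
```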
+ TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); + TORCH_CHECK( + mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() + ) + + TensorArg targs[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}}; + checkAllSameGPU(__func__, targs); + + IntArrayRef mat1_sizes = mat1.sizes(); + IntArrayRef mat2_sizes = mat2.sizes(); + IntArrayRef self__sizes; + bool useLtInterface = false; + static bool disable_addmm_hip_lt = getDisableAddmmHIPLt(); + at::ScalarType scalar_type = self.scalar_type(); + c10::MaybeOwned self_; + if (&result != &self) { + // Strangely, if mat2 has only 1 row or column, we get + // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic. + // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] + // is to use lt interface only when self is bias. + // for cuda 11.4, cublasLtMatmul is activated + // the last two conditions is to skip 16b transA and non-trans-B having + // leading dim >> rows when they are sliced from a large tensor + // see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul + // if (!disable_addmm_hip_lt) { + // useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 && + // result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] && + // self.is_contiguous() && result.is_contiguous() && + // isSupportedHipLtROCmArch(self.device().index()) && + // (scalar_type == at::ScalarType::Float || + // scalar_type == at::ScalarType::Half || + // scalar_type == at::ScalarType::BFloat16) && + + // mat2_sizes[0] > 1 && mat2_sizes[1] > 1 && + // mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 && + // mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 && + // // avoid leading dim >> rows bugs + // ((mat1.strides()[0] == 1 && mat1.strides()[1] == mat1_sizes[0]) || + // (mat1.strides()[1] == 1 && mat1.strides()[0] == mat1_sizes[1]) || + // (scalar_type != at::ScalarType::Half && + // scalar_type != at::ScalarType::BFloat16)) && + // ((mat2.strides()[0] == 1 && mat2.strides()[1] == mat2_sizes[0]) || + // (mat2.strides()[1] == 1 && mat2.strides()[0] == mat2_sizes[1]) || + // (scalar_type != at::ScalarType::Half && + // scalar_type != at::ScalarType::BFloat16)); + // } + // if (!useLtInterface) { + self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm"); + // } + self__sizes = self_->sizes(); + } else { + // useLtInterface = !disable_addmm_hip_lt && + // result.dim() == 2 && result.is_contiguous() && + // isSupportedHipLtROCmArch(self.device().index()) && + // (scalar_type == at::ScalarType::Float || + // scalar_type == at::ScalarType::Half || + // scalar_type == at::ScalarType::BFloat16); + + self_ = c10::MaybeOwned::borrowed(self); + self__sizes = self_->sizes(); + TORCH_CHECK(result.dim() == 2, "tensors must be 2-D"); + TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0"); + TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1"); + } + + if (&result != &self) { + at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]}); + if (beta.toComplexDouble() != 0.0 && !useLtInterface) { + at::native::copy_(result, *self_); + } + } + + + IntArrayRef result_sizes = result.sizes(); + if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) { + return result; + } + + hipblasCommonArgs args(mat1, mat2, result); + + if (mat1.numel() == 0) { + // By definition, when beta==0, values in self should be ignored. 
nans and infs + // should not propagate + if (beta.toComplexDouble() == 0.) { + return result.zero_(); + } + // TODO: We could squeeze some perf by calling at::zoom::mul_out here instead, to bypass the dispatcher. + // That requires some fixing some internal build dependencies though. + return at::mul_out( + result, + self.expand(result.sizes()), + at::native::scalar_tensor( + beta, + self.scalar_type(), + c10::nullopt /* layout */, + at::kCPU, + c10::nullopt /* pin_memory */)); + } + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj()); + + // if (useLtInterface) { + // AT_DISPATCH_FLOATING_TYPES_AND2( + // at::ScalarType::Half, + // at::ScalarType::BFloat16, + // scalar_type, + // "addmm_hip_lt", + // [&] { + // at::zoom::blas::gemm_and_bias( + // args.transa == 't', + // args.transb == 't', + // args.m, + // args.n, + // args.k, + // alpha.to>(), + // args.mata->const_data_ptr(), + // args.lda, + // args.matb->const_data_ptr(), + // args.ldb, + // // This condition is needed for mm case on ROCm for hipblasLt path. + // // Passing the bias ptr as null to avoid accuracy issues for mm case. + // (&result != &self) ? self.const_data_ptr() : nullptr, + // args.result->data_ptr(), + // args.result_ld, + // activation_to_gemm_and_blas_arg(activation) + // ); + // }); + // } else + // { + + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + scalar_type, + "addmm_hip", + [&] { + using opmath_t = at::opmath_type; + opmath_t alpha_val = alpha.to(); + opmath_t beta_val = beta.to(); + const scalar_t* mat1_ptr = args.mata->const_data_ptr(); + const scalar_t* mat2_ptr = args.matb->const_data_ptr(); + scalar_t* result_ptr = args.result->mutable_data_ptr(); + at::zoom::blas::gemm( + args.transa, + args.transb, + args.m, + args.n, + args.k, + alpha_val, + mat1_ptr, + args.lda, + mat2_ptr, + args.ldb, + beta_val, + result_ptr, + args.result_ld); + }); + + // AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + // at::ScalarType::Half, + // at::ScalarType::BFloat16, + // scalar_type, + // "addmm_hip", + // [&] { + // using opmath_t = at::opmath_type; + // opmath_t alpha_val = alpha.to(); + // opmath_t beta_val = beta.to(); + // const scalar_t* mat1_ptr = args.mata->const_data_ptr(); + // const scalar_t* mat2_ptr = args.matb->const_data_ptr(); + // scalar_t* result_ptr = args.result->mutable_data_ptr(); + + // static constexpr int GEMM_DIM_X = 32; + // static constexpr int GEMM_DIM_Y = 32; + + // // JIT kernel + // auto desc = at::zoom::jit::make_kernel_descriptor("gemm", gemm_code, /*nInputs=*/5, /*nOutputs=*/1); + // auto gemm_kernel = at::zoom::jit::zoom_generate_code(desc); + // at::zoom::jit::hiprtcFunction gemm_f = at::zoom::jit::jit_pwise_function(gemm_kernel, desc.name); + + // // chunked launch + // for(int64_t n_base = 0; n_base < args.n; n_base += c_i64_grid_YZ_chunk) + // { + // // don't need to block through M as it's 32 bit and can use full 32-bits in X-dim of grid + // int32_t nblock = int32_t(std::min(args.n - n_base, c_i64_grid_YZ_chunk)); + + // void* gemm_args[] = { + // args.transa, + // args.transb, + // args.m, + // args.n, + // args.k, + // alpha_val, + // mat1_ptr, + // args.lda, + // mat2_ptr, + // args.ldb, + // beta_val, + // result_ptr, + // args.result_ldN + // }; + + // at::zoom::jit::launch_jitted_pwise_function(gemm_f, gemm_args, gemm_grid, gemm_thread, smem); + + // } + // }); + switch (activation) { + case Activation::RELU: + at::relu_(const_cast(*args.result)); + break; + case Activation::GELU: + 
at::gelu_(const_cast(*args.result), "tanh"); + break; + default: break; + } + // } + + if (!result.is_same(*args.result)) { + result.copy_(*args.result); + } + return result; +} + +const Tensor& baddbmm_out_hip_impl(const Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) { + // handle pathological cases that blas may not like + if (result.numel() == 0) { + return result; + } else if (batch1.size(2) == 0) { + if (beta.to>() == 0.0) { + return result.zero_(); + } else { + return result.mul_(beta); + } + } + + bool transpose_result = false; + c10::MaybeOwned result_; + IntArrayRef result_strides = result.strides(); + IntArrayRef result_sizes = result.sizes(); + + if ((result_strides[1] == 1) && + ((result_sizes[2] == 1) || (result_strides[2] >= std::max(1, result_sizes[1])))) { + result_ = resolve_conj_if_indicated(result, true); + } else if ((result_strides[2] == 1) && + (result_sizes[1] == 1 || (result_strides[1] >= std::max(1, result_sizes[2])))) { + transpose_result = true; + result_ = resolve_conj_if_indicated(result, true); + } else { + result_ = c10::MaybeOwned::owned(result.transpose(1, 2).clone(at::MemoryFormat::Contiguous).transpose(1, 2)); + } + + int leading_dim = transpose_result ? 1 : 2; + + int64_t m = result_sizes[transpose_result ? 2 : 1]; + int64_t n = result_sizes[leading_dim]; + int64_t k = (transpose_result ? batch2 : batch1).sizes()[leading_dim]; + + int64_t lda, ldb, ldc; + bool transpose_batch1, transpose_batch2; + auto batch1_ = prepare_batch_matrix_for_hipblas(transpose_result ? batch2 : batch1, transpose_batch1, lda, transpose_result, m, k); + auto batch2_ = prepare_batch_matrix_for_hipblas(transpose_result ? batch1 : batch2, transpose_batch2, ldb, transpose_result, k, n); + + ldc = result_->strides()[leading_dim]; + int64_t num_batches = result_->sizes()[0]; + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj()); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "baddbmm_hip", [&] { + using opmath_t = at::opmath_type; + opmath_t alpha_val = alpha.to(); + opmath_t beta_val = beta.to(); + const scalar_t* batch1_ptr = batch1_->const_data_ptr(); + const scalar_t* batch2_ptr = batch2_->const_data_ptr(); + scalar_t* result_ptr = result_->mutable_data_ptr(); + const auto transa = transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n'; + const auto transb = transpose_batch2 ? batch2_->is_conj() ? 
'c' : 't' : 'n'; + // If batch is 1 call gemm rather than bgemm + if (num_batches == 1) { + at::zoom::blas::gemm( + transa, transb, + m, n, k, + alpha_val, + batch1_ptr, lda, + batch2_ptr, ldb, + beta_val, + result_ptr, ldc); + } else { + at::zoom::blas::bgemm( + transa, transb, + m, n, k, + alpha_val, + batch1_ptr, lda, batch1_->strides()[0], + batch2_ptr, ldb, batch2_->strides()[0], + beta_val, + result_ptr, ldc, result_->strides()[0], + num_batches + ); + } + }); + if (!result.is_same(*result_)) { + result.copy_(*result_); + } + return result; +} + +} // anonymous namespace + +TORCH_IMPL_FUNC(addmm_out_hip)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, const Tensor& result) { + addmm_out_hip_impl(const_cast(result), self, mat1, mat2, beta, alpha); +} + +TORCH_IMPL_FUNC(addmm_activation_out_hip)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, bool use_gelu, const Tensor& result) { + addmm_out_hip_impl(const_cast(result), self, mat1, mat2, beta, alpha, use_gelu ? Activation::GELU : Activation::RELU); +} + +TORCH_IMPL_FUNC(mm_out_hip)(const Tensor& self, const Tensor& mat2, const Tensor& result) { + addmm_out_hip_impl(const_cast(result), result, self, mat2, 0, 1); +} + +TORCH_IMPL_FUNC(baddbmm_out_hip)(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const Tensor& result) { + { + at::NoNamesGuard guard; + baddbmm_out_hip_impl(result, self, batch1, batch2, beta, alpha); + } +} + +TORCH_IMPL_FUNC(bmm_out_hip)(const Tensor& batch1, const Tensor& batch2, const Tensor &result) { + Scalar beta(0.0); + Scalar alpha(1.0); + { + NoNamesGuard guard; + baddbmm_out_hip_impl(result, result, batch1, batch2, beta, alpha); + } +} + + +namespace { + +inline void dot_check(const Tensor& self, const Tensor& other) { + TORCH_CHECK( + self.dim() == 1 && other.dim() == 1, + "1D tensors expected, but got ", + self.dim(), + "D and ", + other.dim(), + "D tensors"); + TORCH_CHECK( + self.scalar_type() == other.scalar_type(), + "dot : expected both vectors to have same dtype, but found ", + self.scalar_type(), + " and ", + other.scalar_type()); + TORCH_CHECK( + self.numel() == other.numel(), + "inconsistent tensor size, expected tensor [", + self.numel(), + "] and src [", + other.numel(), + "] to have the same number of elements, but got ", + self.numel(), + " and ", + other.numel(), + " elements respectively"); + TORCH_CHECK( + (self.numel() <= INT_MAX) && (self.stride(0) <= INT_MAX) && + (other.stride(0) <= INT_MAX), + "dot only supports n, incx, incy with the bound [val] <= %d", + INT_MAX); +} + +} // anonymous namespace + +// global sum reduce partial dot kernel results +std::string sum_reduce_blocks_code = R"( +#define HIP_ENABLE_PRINTF_DEBUG +extern "C" __global__ void reduce_blocks_kernel(scalar_t* block_results, scalar_t* out, int num_blocks) +{ + __shared__ scalar_t sdata[256]; + int tid = threadIdx.x; + + scalar_t sum = zero_init(); + for (int i = tid; i < num_blocks; i += blockDim.x) { + sum += block_results[i]; + } + + sdata[tid] = sum; + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + sdata[tid] += sdata[tid + s]; + } + __syncthreads(); + } + + if (tid == 0) { + out[0] = sdata[0]; + } +} +)"; + +// compute partial results for dot kernel in each warp +std::string dot_partial_code = R"( +#define HIP_ENABLE_PRINTF_DEBUG +template +__device__ T dot_mul(T x, T y) { + return x * y; +} + +// complex mul 
+template<> +__device__ hipFloatComplex dot_mul(hipFloatComplex x, hipFloatComplex y) { + return hipCmulf(x, y); +} + +template<> +__device__ hipDoubleComplex dot_mul(hipDoubleComplex x, hipDoubleComplex y) { + return hipCmul(x, y); +} + +extern "C" __global__ void dot_partial_kernel(scalar_t* a, scalar_t* b, scalar_t* block_results, int incx, int incy, int N) +{ + __shared__ scalar_t sdata[256]; + int tid = threadIdx.x; + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + scalar_t sum = zero_init(); + + // Grid stride loop + for (int i = gid; i < N; i += stride) { + sum += dot_mul(a[i * incx], b[i * incy]); + } + + // Store in shared memory + sdata[tid] = sum; + __syncthreads(); + + // Perform reduction in shared memory + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + sdata[tid] += sdata[tid + s]; + } + __syncthreads(); + } + + // Write result for this block to global memory + if (tid == 0) { + block_results[blockIdx.x] = sdata[0]; + } +} +)"; + +Tensor dot_hip(const Tensor& self, const Tensor& other) { + if (self.is_complex()) { + if (self.is_conj()) { + if (other.is_conj()) { + return (dot_hip(self.conj(), other.conj())).conj(); + } else { + return vdot_hip(self.conj(), other); + } + } else if (other.is_conj()) { + return vdot_hip(other.conj(), self); + } + } + + at::NoNamesGuard guard; + dot_check(self, other); + + int N = static_cast(self.numel()); + int incx = static_cast(self.stride(0)); + int incy = static_cast(other.stride(0)); + if (N == 1) { + incx = 1; + incy = 1; + } + + if (self._is_zerotensor() || other._is_zerotensor() || N == 0) { + return at::_efficientzerotensor({}, self.options()); + } + + return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + ScalarType::Half, ScalarType::BFloat16, + self.scalar_type(), "dot", + [&] { + Tensor result = at::empty({}, self.options()); + int THREADS_PER_BLOCK = num_threads(); + int BLOCKS = (N + block_work_size() - 1) / block_work_size(); + const int smem = sizeof(scalar_t)*256; + + auto self_ptr = self.data_ptr(); + auto other_ptr = other.data_ptr(); + auto res_ptr = result.data_ptr(); + auto* allocator = at::zoom::getZoomDeviceAllocator(); + scalar_t* partial_res_ptr = (scalar_t*) allocator->raw_allocate(sizeof(scalar_t)*BLOCKS); + void* dot_partial_args[] = {&self_ptr, &other_ptr, &partial_res_ptr, &incx, &incy, &N}; + + auto desc = at::zoom::jit::make_kernel_descriptor("dot_partial", dot_partial_code, /*nInputs=*/5, /*nOutputs=*/1); + auto dot_partial_kernel = at::zoom::jit::zoom_generate_code(desc); + at::zoom::jit::hiprtcFunction dot_partial_f = at::zoom::jit::jit_pwise_function(dot_partial_kernel, desc.name); + at::zoom::jit::launch_jitted_pwise_function(dot_partial_f, dot_partial_args, {BLOCKS, 1u, 1u}, {THREADS_PER_BLOCK, 1u, 1u}, smem); + + void* sum_reduce_blocks_args[] = {&partial_res_ptr, &res_ptr, &BLOCKS}; + auto reduce_desc = at::zoom::jit::make_kernel_descriptor("reduce_blocks", sum_reduce_blocks_code, /*nInputs=*/2, /*nOutputs=*/1); + auto sum_reduce_blocks_kernel = at::zoom::jit::zoom_generate_code(reduce_desc); + at::zoom::jit::hiprtcFunction sum_reduce_blocks_f = at::zoom::jit::jit_pwise_function(sum_reduce_blocks_kernel, reduce_desc.name); + at::zoom::jit::launch_jitted_pwise_function(sum_reduce_blocks_f, sum_reduce_blocks_args, {1u, 1u, 1u}, {THREADS_PER_BLOCK, 1u, 1u}, smem); + + allocator->raw_deallocate(partial_res_ptr); + + return result; + }); +} + +// compute partial results for dot kernel in each warp +std::string vdot_partial_code = R"( +#define 
HIP_ENABLE_PRINTF_DEBUG +template +__device__ T dot_mul(T x, T y) { + return x * y; +} + +// conjugate for complex dot product +template<> +__device__ hipFloatComplex dot_mul(hipFloatComplex x, hipFloatComplex y) { + return hipCmulf(hipConjf(x), y); +} + +template<> +__device__ hipDoubleComplex dot_mul(hipDoubleComplex x, hipDoubleComplex y) { + return hipCmul(hipConj(x), y); +} + +extern "C" __global__ void vdot_partial_kernel(scalar_t* a, scalar_t* b, scalar_t* block_results, int incx, int incy, int N) +{ + __shared__ scalar_t sdata[256]; + int tid = threadIdx.x; + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + scalar_t sum = zero_init(); + + // Grid stride loop + for (int i = gid; i < N; i += stride) { + sum += dot_mul(a[i * incx], b[i * incy]); + } + + // Store in shared memory + sdata[tid] = sum; + __syncthreads(); + + // Perform reduction in shared memory + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + sdata[tid] += sdata[tid + s]; + } + __syncthreads(); + } + + // Write result for this block to global memory + if (tid == 0) { + block_results[blockIdx.x] = sdata[0]; + } +} +)"; + + +Tensor vdot_hip(const Tensor& self, const Tensor& other) { + if (!self.is_complex()) { + return dot_hip(self, other); + } + + if (self.is_conj()) { + if (other.is_conj()) { + return vdot_hip(other.conj(), self.conj()); + } else { + return dot_hip(self.conj(), other); + } + } else if (other.is_conj()) { + return (dot_hip(self, other.conj())).conj(); + } + + at::NoNamesGuard guard; + dot_check(self, other); + + int N = static_cast(self.numel()); + int incx = static_cast(self.stride(0)); + int incy = static_cast(other.stride(0)); + if (N == 1) { + incx = 1; + incy = 1; + } + + + if (self._is_zerotensor() || other._is_zerotensor() || N == 0) { + return at::_efficientzerotensor({}, self.options()); + } + + return AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "vdot", [&] { + Tensor result = at::empty({}, self.options()); + int THREADS_PER_BLOCK = num_threads(); + int BLOCKS = (N + block_work_size() - 1) / block_work_size(); + const int smem = sizeof(scalar_t)*256; + + auto self_ptr = self.data_ptr(); + auto other_ptr = other.data_ptr(); + auto res_ptr = result.data_ptr(); + auto* allocator = at::zoom::getZoomDeviceAllocator(); + scalar_t* partial_res_ptr = (scalar_t*) allocator->raw_allocate(sizeof(scalar_t)*BLOCKS); + void* vdot_partial_args[] = {&self_ptr, &other_ptr, &partial_res_ptr, &incx, &incy, &N}; + + auto desc = at::zoom::jit::make_kernel_descriptor("vdot_partial", vdot_partial_code, /*nInputs=*/5, /*nOutputs=*/1); + auto vdot_partial_kernel = at::zoom::jit::zoom_generate_code(desc); + at::zoom::jit::hiprtcFunction vdot_partial_f = at::zoom::jit::jit_pwise_function(vdot_partial_kernel, desc.name); + at::zoom::jit::launch_jitted_pwise_function(vdot_partial_f, vdot_partial_args, {BLOCKS, 1u, 1u}, {THREADS_PER_BLOCK, 1u, 1u}, smem); + + void* sum_reduce_blocks_args[] = {&partial_res_ptr, &res_ptr, &BLOCKS}; + auto reduce_desc = at::zoom::jit::make_kernel_descriptor("reduce_blocks", sum_reduce_blocks_code, /*nInputs=*/2, /*nOutputs=*/1); + auto sum_reduce_blocks_kernel = at::zoom::jit::zoom_generate_code(reduce_desc); + at::zoom::jit::hiprtcFunction sum_reduce_blocks_f = at::zoom::jit::jit_pwise_function(sum_reduce_blocks_kernel, reduce_desc.name); + at::zoom::jit::launch_jitted_pwise_function(sum_reduce_blocks_f, sum_reduce_blocks_args, {1u, 1u, 1u}, {THREADS_PER_BLOCK, 1u, 1u}, smem); + + allocator->raw_deallocate(partial_res_ptr); + + 
return result; + }); +} + +TORCH_IMPL_FUNC(addmv_out_hip)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta_, const Scalar& alpha_, const Tensor& result) { + c10::MaybeOwned self_ = expand_size(self, {mat.size(0)}); + auto betaval = beta_.toComplexDouble(); + if (mat.numel() == 0) { + // shortcut for an empty matrix + // By definition, when beta==0, values in self should be ignored. nans and infs + // should not propagate + if (betaval == 0.0) { + result.zero_(); + } else { + at::mul_out( + const_cast(result), + self, + at::native::scalar_tensor( + beta_, self.scalar_type(), c10::nullopt /* layout */, at::kCPU, c10::nullopt /* pin_memory */)); + } + } else { + if (!result.is_same(*self_) && betaval != 0.0) { //if beta is 0, result contents will be zeroed later + at::native::copy_(const_cast(result), *self_); + } + if (result.numel() != 0) { + auto r_stride = result.stride(0); + auto vec_stride = vec.stride(0); + + // Check for contiguity of `vec` and update `vec_stride` accordingly + const auto vec_contiguous = vec_stride == 0 ? vec.contiguous() : vec; + // A vector can be contiguous and have a stride of zero if it has it is of length 1 + vec_stride = std::max(vec_contiguous.stride(0), 1LL); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, mat.scalar_type(), "addmv_impl_hip", [&] { + auto beta = beta_.to(); + auto alpha = alpha_.to(); + if (mat.stride(0) == 1 && mat.stride(1) >= std::max(1, mat.size(0))) { + at::zoom::blas::gemv('n', + mat.size(0), mat.size(1), alpha, mat.const_data_ptr(), mat.stride(1), vec_contiguous.const_data_ptr(), + vec_stride, beta, result.mutable_data_ptr(), r_stride); + } + else if (mat.stride(1) == 1 && mat.stride(0) >= std::max(1, mat.size(1))) { + at::zoom::blas::gemv('t', + mat.size(1), mat.size(0), alpha, mat.const_data_ptr(), mat.stride(0), + vec_contiguous.const_data_ptr(), vec_stride, beta, result.mutable_data_ptr(), r_stride); + } + else { + Tensor cmat = mat.contiguous(); + at::zoom::blas::gemv('t', + mat.size(1), mat.size(0), alpha, cmat.const_data_ptr(), cmat.stride(0), + vec_contiguous.const_data_ptr(), vec_stride, beta, result.mutable_data_ptr(), r_stride); + } + }); + } + } +} + + +Tensor& _int_mm_out_hip(const Tensor& self, const Tensor& mat2, Tensor& result) { + // NOTE: cuBLAS is currently broken for some combination of transposed inputs. 
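For reference, the int8 path below multiplies two int8 matrices while accumulating in int32, which is what keeps products larger than 127 representable. A minimal host-side sketch assuming row-major storage; the helper name is illustrative, and the real computation goes through `at::zoom::blas::int8_gemm`:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// C (int32, m x n) = A (int8, m x k) * B (int8, k x n), all row-major.
std::vector<int32_t> int8_gemm_ref(const std::vector<int8_t>& a,
                                   const std::vector<int8_t>& b,
                                   int m, int k, int n) {
  std::vector<int32_t> c(static_cast<size_t>(m) * n, 0);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      for (int p = 0; p < k; ++p)
        c[i * n + j] += static_cast<int32_t>(a[i * k + p]) *
                        static_cast<int32_t>(b[p * n + j]);
  return c;
}

int main() {
  // 1x2 times 2x1 -> 1x1; the result exceeds int8 range but fits in int32.
  std::vector<int8_t> a = {100, 100};
  std::vector<int8_t> b = {100, 100};
  assert(int8_gemm_ref(a, b, 1, 2, 1)[0] == 20000);
  return 0;
}
```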
+ TORCH_CHECK(self.dim() == 2, "Expected self to be of dimension 2 but got ", self.dim()); + TORCH_CHECK(mat2.dim() == 2, "Expected mat2 to be of dimension 2 but got ", mat2.dim()); + TORCH_CHECK(self.size(0) > 16, "self.size(0) needs to be greater than 16, but got ", self.size(0)); + TORCH_CHECK(self.size(1) > 0 && self.size(1) % 8 == 0, "self.size(1) needs to be greater than 0 and a multiple of 8, but got ", self.size(1)); + TORCH_CHECK(self.size(1) == mat2.size(0), "self.size(1) needs to match mat2.size(0) but got ", self.size(1), " and ", mat2.size(0)); + TORCH_CHECK(mat2.size(1) > 0 && mat2.size(1) % 8 == 0, "mat2.size(1) needs to be greater than 0 and a multiple of 8, but got ", mat2.size(1)); + + TORCH_CHECK(result.dtype() == at::kInt, "Expected result dtype to be of type kInt but got ", result.dtype()); + TORCH_CHECK(result.size(0) == self.size(0), "Expected result.size(0) to be ", self.size(0), " but got ", result.size(0)); + TORCH_CHECK(result.size(1) == mat2.size(1), "Expected result.size(1) to be ", mat2.size(1), " but got ", result.size(1)); + + TORCH_CHECK(result.dim() == 2, "Expected result to be of dimension 2 but got ", result.dim()); + + TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous."); + + hipblasCommonArgs args(self, mat2, result); + + at::zoom::blas::int8_gemm( + args.transa == 't', + args.transb == 't', + args.m, + args.n, + args.k, + args.mata->data_ptr(), + args.lda, + args.matb->data_ptr(), + args.ldb, + args.result->data_ptr(), + args.result_ld); + + if (!result.is_same(*args.result)) { + result.copy_(*args.result); + } + + + // holdover from cuda/hip backend + TORCH_CHECK(false, "_int_mm_out_hip not compiled for this platform."); + + return result; +} + +Tensor _int_mm_hip(const Tensor& self, const Tensor& mat2) { + Tensor result = at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt)); + return _int_mm_out_hip(self, mat2, result); +} + +static bool _scaled_mm_allowed_device() { + auto dprops = at::zoom::getCurrentDeviceProperties(); + std::string device_arch = dprops->gcnArchName; + static const std::vector archs = {"gfx940", "gfx941", "gfx942"}; + for (std::string arch : archs) { + size_t substring = device_arch.find(arch); + if (substring != std::string::npos) { + return true; + } + } + return false; + +} + +// Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax +// Scales are only applicable when matrices are of Float8 type and assumbed to be equal to 1.0 by default. +// If output matrix type is 16 or 32-bit type, neither scale_result is applied nor amax is computed. 
+// Known limitations: +// - Only works if mat1 is row-major and mat2 is column-major +// - Only works if matrices sizes are divisible by 32 +// +// Arguments: +// - `mat1`: the first operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2` +// - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2` +// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16` +// - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type +// - `scale_a`: a scalar tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type +// - `scale_b`: a scalar tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type +// - `scale_result`: a scalar tensor with the scale of the output, only set if the output is a float8 type +// - `use_fast_accum`: if true, enables fast float8 accumulation +// - `out`: a reference to the output tensor +// - `amax`: a reference to the amax tensor of the output, only needed if the output is a float8 type and will be updated inplace + +std::tuple +_scaled_mm_out_hip(const Tensor& mat1, const Tensor& mat2, + const std::optional& bias, + std::optional out_dtype, + const std::optional& scale_a, + const std::optional& scale_b, + const std::optional& scale_result, + bool use_fast_accum, + Tensor& out, Tensor& amax) { + // Check sizes + bool allowed_device = _scaled_mm_allowed_device(); + TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+"); + TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix"); + TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); + TORCH_CHECK( + mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", + mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); + TORCH_CHECK(!scale_a || (scale_a->numel() == 1 && scale_a->scalar_type() == kFloat), + "scale_a must be float scalar"); + TORCH_CHECK(!scale_b || (scale_b->numel() == 1 && scale_b->scalar_type() == kFloat), + "scale_b must be a float scalar"); + TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat), + "scale_result must be a float scalar"); + TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1], + " but got ", bias->numel()); + TORCH_CHECK( + mat1.sizes()[1] % 16 == 0, + "Expected trailing dimension of mat1 to be divisible by 16 ", + "but got mat1 shape: (", + mat1.sizes()[0], + "x", + mat1.sizes()[1], + "."); + TORCH_CHECK(mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0, "mat2 shape (", mat2.sizes()[0], "x", + mat2.sizes()[1], " must be divisible by 16"); + // Check types + TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type"); + TORCH_CHECK(amax.scalar_type() == kFloat, "amax must be a float scalar"); + TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type()); + TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type()); + // Type restrictions imposed by CuBLASLt as of CUDA-12.1 + TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2, + "Multiplication of two Float8_e5m2 matrices is not supported"); + if (bias) { + TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to 
Float32"); + TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half, + "Bias must be either Half or BFloat16, but got ", bias->scalar_type()); + TORCH_CHECK((out.scalar_type() != kFloat && out.scalar_type() != ScalarType::BFloat16) || + bias->scalar_type() == ScalarType::BFloat16, + "Bias must be BFloat16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type()); + TORCH_CHECK(out.scalar_type() != ScalarType::Half || bias->scalar_type() == ScalarType::Half, + "Bias must be Float16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type()); + } + { + auto bias_ = bias.value_or(Tensor()); + auto scale_a_ = scale_a.value_or(Tensor()); + auto scale_b_ = scale_b.value_or(Tensor()); + auto scale_result_ = scale_result.value_or(Tensor()); + TensorArg targs[]{{out, "out", 0}, {amax, "amax", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}, + {bias_, "bias", 4}, {scale_a_, "scale_a", 5}, {scale_b_, "scale_b", 6}, + {scale_result_, "scale_result", 7}}; + checkAllSameGPU(__func__, targs); + } + + IntArrayRef mat1_sizes = mat1.sizes(); + IntArrayRef mat2_sizes = mat2.sizes(); + at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); + at::native::resize_output(amax, {}); + + hipblasCommonArgs args(mat1, mat2, out); + const auto out_dtype_ = args.result->scalar_type(); + TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by hipBLASLt"); + auto tuning_ctx = at::zoom::tunable::getTuningContext(); + if (tuning_ctx->IsTunableOpEnabled()) { +#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \ + if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \ + if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \ + static at::zoom::tunable::ScaledGemmTunableOp< \ + at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \ + BLASOP_A, BLASOP_B> scaledgemm{}; \ + scaledgemm(¶ms); \ + } \ + else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ + static at::zoom::tunable::ScaledGemmTunableOp< \ + at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \ + BLASOP_A, BLASOP_B> scaledgemm{}; \ + scaledgemm(¶ms); \ + } \ + } \ + else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ + if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \ + static at::zoom::tunable::ScaledGemmTunableOp< \ + at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \ + BLASOP_A, BLASOP_B> scaledgemm{}; \ + scaledgemm(¶ms); \ + } \ + else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ + static at::zoom::tunable::ScaledGemmTunableOp< \ + at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \ + BLASOP_A, BLASOP_B> scaledgemm{}; \ + scaledgemm(¶ms); \ + } \ + } + AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] { + bool transa_ = ((args.transa != 'n') && (args.transa != 'N')); + bool transb_ = ((args.transb != 'n') && (args.transb != 'N')); + at::zoom::tunable::ScaledGemmParams params; + params.transa = args.transa; + params.transb = args.transb; + params.m = args.m; + params.n = args.n; + params.k = args.k; + params.a = args.mata->data_ptr(); + params.a_scale_ptr = scale_a ? scale_a->data_ptr() : nullptr; + params.lda = args.lda; + params.a_dtype = args.mata->scalar_type(); + params.b = args.matb->data_ptr(); + params.b_scale_ptr = scale_b ? scale_b->data_ptr() : nullptr; + params.ldb = args.ldb; + params.b_dtype = args.matb->scalar_type(); + params.bias_ptr = bias ? bias->data_ptr(): nullptr; + params.bias_dtype = bias ? 
bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_; + params.c = args.result->data_ptr(); + params.c_scale_ptr = scale_result ? scale_result->data_ptr() : nullptr; + params.ldc = args.result_ld; + params.c_dtype = out_dtype_; + params.amax_ptr = amax.data_ptr(); + params.use_fast_accum = use_fast_accum; + if (transa_ && transb_) { + TUNABLE_DISPATCH(at::zoom::tunable::BlasOp::T, at::zoom::tunable::BlasOp::T) + } + else if (transa_ && !transb_) { + TUNABLE_DISPATCH(at::zoom::tunable::BlasOp::T, at::zoom::tunable::BlasOp::N) + } + else if (!transa_ && transb_) { + TUNABLE_DISPATCH(at::zoom::tunable::BlasOp::N, at::zoom::tunable::BlasOp::T) + } + else if (!transa_ && !transb_) { + TUNABLE_DISPATCH(at::zoom::tunable::BlasOp::N, at::zoom::tunable::BlasOp::N) + } + else { + TORCH_CHECK(false, "unreachable"); + } + }), + kHalf, kBFloat16, kFloat8_e4m3fnuz, kFloat8_e5m2fnuz, AT_EXPAND(AT_FLOATING_TYPES)); +#undef TUNABLE_DISPATCH + } + else + { +#if ROCM_VERSION >= 60200 + // hipBlasLT requires scaleD to be set to something in order to use AMAX + auto dummy_options = TensorOptions().dtype(kFloat).device(kPrivateUse1); + auto dummy_scale = at::ones(1, dummy_options); +#endif + at::zoom::blas::scaled_gemm( + args.transa, + args.transb, + args.m, + args.n, + args.k, + args.mata->data_ptr(), + scale_a ? scale_a->data_ptr() : nullptr, + args.lda, + args.mata->scalar_type(), + args.matb->data_ptr(), + scale_b ? scale_b->data_ptr() : nullptr, + args.ldb, + args.matb->scalar_type(), + bias ? bias->data_ptr(): nullptr, + bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_, + args.result->data_ptr(), +#if ROCM_VERSION >= 60200 + scale_result ? scale_result->data_ptr() : dummy_scale.data_ptr(), +#else + scale_result ? 
scale_result->data_ptr() : nullptr, +#endif + args.result_ld, + out_dtype_, + amax.data_ptr(), + use_fast_accum); + } + +#if ROCM_VERSION >= 60000 && ROCM_VERSION < 60200 + // ROCm's hipBLASLt does not support amax before 6.2, so calculate separately + amax = at::max(at::abs(out.to(kFloat))); +#endif + + return {out, amax}; +} + +std::tuple +_scaled_mm_hip(const Tensor& mat_a, const Tensor& mat_b, + const std::optional& bias, + std::optional out_dtype, + const std::optional& scale_a, + const std::optional& scale_b, + const std::optional& scale_result, + bool use_fast_accum) { + const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type()); + Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_)); + Tensor amax = at::empty({0}, mat_a.options().dtype(ScalarType::Float)); + return _scaled_mm_out_hip(mat_a, mat_b, bias, out_dtype, scale_a, scale_b, scale_result, use_fast_accum, out, amax); +} + +} // namespace at::native + +#endif \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Bucketization.cu b/aten/src/ATen/native/zoom/Bucketization.cu new file mode 100644 index 0000000000000..95232a5f4d414 --- /dev/null +++ b/aten/src/ATen/native/zoom/Bucketization.cu @@ -0,0 +1,233 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + +namespace at::native { + +// Implement a numpy like searchsorted and a TF like bucketize function running on cuda +// See details in ATen/native/Bucketization.cpp + +namespace { + +template +__device__ int64_t lower_bound(const input_t *data_ss, int64_t start, int64_t end, const input_t val, const int64_t *data_sort) { + // sorter gives relative ordering for ND tensors, so we need to save and add the non-updated start as an offset + // i.e. the second row of a 3x3 tensors starts at element 3 but sorter's second row only contains 0, 1, or 2 + const int64_t orig_start = start; + while (start < end) { + const int64_t mid = start + ((end - start) >> 1); + const input_t mid_val = data_sort ? data_ss[orig_start + data_sort[mid]] : data_ss[mid]; + if (!(mid_val >= val)) { + start = mid + 1; + } + else { + end = mid; + } + } + return start; +} + +template +__device__ int64_t upper_bound(const input_t *data_ss, int64_t start, int64_t end, const input_t val, const int64_t *data_sort) { + // sorter gives relative ordering for ND tensors, so we need to save and add the non-updated start as an offset + // i.e. the second row of a 3x3 tensors starts at element 3 but sorter's second row only contains 0, 1, or 2 + const int64_t orig_start = start; + while (start < end) { + const int64_t mid = start + ((end - start) >> 1); + const input_t mid_val = data_sort ? data_ss[orig_start + data_sort[mid]] : data_ss[mid]; + if (!(mid_val > val)) { + start = mid + 1; + } + else { + end = mid; + } + } + return start; +} + +template +__global__ void searchsorted_zoom_kernel( + output_t *data_out, + const input_t *data_in, + const input_t *data_bd, + const int64_t *data_sort, + int64_t idim_in, + int64_t idim_bd, + int64_t numel_in, + bool right, + bool is_1d_boundaries) { + + for (int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < numel_in; tid += blockDim.x * gridDim.x) { + // If boundaries tensor is 1d, we always search the entire boundary tensor + int64_t start_bd = is_1d_boundaries ? 0 : tid / idim_in * idim_bd; + int64_t end_bd = start_bd + idim_bd; + + int64_t pos = !right ? 
+ lower_bound(data_bd, start_bd, end_bd, data_in[tid], data_sort) - start_bd : + upper_bound(data_bd, start_bd, end_bd, data_in[tid], data_sort) - start_bd; + + // type conversion might happen here + data_out[tid] = pos; + } +} + +template +void searchsorted_zoom_contiguous(Tensor& result, const Tensor& input, const Tensor& boundaries, const bool& right, const Tensor& sorter) { + int64_t numel_in = input.numel(); + bool is_scalar_input = input.dim() == 0 && numel_in == 1; + // inner most dim size of input and boundaries + int64_t idim_in = is_scalar_input ? 1 : input.sizes().back(); + int64_t idim_bd = boundaries.sizes().back(); + + const input_t *data_in = input.const_data_ptr(); + const input_t *data_bd = boundaries.const_data_ptr(); + const int64_t *data_sort = sorter.defined() ? sorter.const_data_ptr() : nullptr; + output_t *data_out = result.mutable_data_ptr(); + + int64_t maxThread = at::zoom::getCurrentDeviceProperties()->maxThreadsPerBlock; + int64_t maxGrid = 1024; + dim3 block = dim3(std::min(maxThread, numel_in)); + dim3 grid = dim3(std::min(maxGrid, ceil_div(numel_in, block.x))); + c10::zoom::ZoomStream stream = c10::zoom::getCurrentZoomStream(); + + searchsorted_zoom_kernel<<>>( + data_out, data_in, data_bd, data_sort, idim_in, idim_bd, numel_in, right, boundaries.dim() == 1); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); +} + +void dispatch( + Tensor& result, + const Tensor& input, + const Tensor& boundaries, + bool out_int32, + bool right, + const Tensor& sorter) { + if (!out_int32) { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "searchsorted_out_zoom", [&] { + searchsorted_zoom_contiguous(result, input, boundaries, right, sorter); + }); + } + else { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "searchsorted_out_zoom", [&] { + searchsorted_zoom_contiguous(result, input, boundaries, right, sorter); + }); + } +} + +} + +Tensor& searchsorted_out_zoom( + const Tensor& sorted_sequence, + const Tensor& self, + bool out_int32, + bool right, + const std::optional side_opt, + const std::optional& sorter_opt, + Tensor& result) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned sorter_maybe_owned = at::borrow_from_optional_tensor(sorter_opt); + const Tensor& sorter = *sorter_maybe_owned; + searchsorted_pre_check(sorted_sequence, self, result, out_int32, right, side_opt, sorter); + resize_output(result, self.sizes()); + + // we have two inputs to set right, pre_check checks that they aren't set to opposites + bool is_right = (side_opt && *side_opt == "right") || right; + if (self.numel() == 0) { + return result; + } + + // for non-contiguous result tensors, we write the output to a contiguous copy so we can later copy back, maintaining the original result tensor + Tensor out = result; + if (!result.is_contiguous()) { + out = result.contiguous(); + } + if (sorted_sequence.is_contiguous() && self.is_contiguous() && sorted_sequence.dtype() == self.dtype() && sorter.is_contiguous()) { + dispatch(out, self, sorted_sequence, out_int32, is_right, sorter); + } + else { + Tensor trimmed_input; + Tensor trimmed_boundaries; + Tensor trimmed_sorter; + searchsorted_maybe_trim_input_tensors(trimmed_input, trimmed_boundaries, trimmed_sorter, self, sorted_sequence, sorter); + const Tensor& final_input = trimmed_input.defined() ? trimmed_input : self; + const Tensor& final_boundaries = trimmed_boundaries.defined() ? 
trimmed_boundaries : sorted_sequence; + const Tensor& final_sorter = trimmed_sorter.defined() ? trimmed_sorter : sorter; + dispatch(out, final_input, final_boundaries, out_int32, is_right, final_sorter); + } + + // if result is non-contiguous, we wrote the answer to a copied version, so we copy back to the original result tensor + if (!result.is_contiguous()) { + result.copy_(out); + } + return result; +} + +Tensor& searchsorted_out_zoom( + const Tensor& sorted_sequence, + const Scalar& self, + bool out_int32, + bool right, + const std::optional side_opt, + const std::optional& sorter_opt, + Tensor& result) { + const Tensor& scalar_tensor = searchsorted_scalar_tensor(self, sorted_sequence.device()); + return searchsorted_out_zoom(sorted_sequence, scalar_tensor, out_int32, right, side_opt, sorter_opt, result); +} + +Tensor searchsorted_zoom( + const Tensor& sorted_sequence, + const Tensor& self, + bool out_int32, + bool right, + const std::optional side_opt, + const std::optional& sorter) { + ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long; + c10::TensorOptions options = TensorOptions().device(self.options().device()).dtype(scalar_type); + Tensor result = at::empty({0}, options, MemoryFormat::Contiguous); + at::native::searchsorted_out_zoom(sorted_sequence, self, out_int32, right, side_opt, sorter, result); + return result; +} + +Tensor searchsorted_zoom( + const Tensor& sorted_sequence, + const Scalar& self, + bool out_int32, + bool right, + const std::optional side_opt, + const std::optional& sorter) { + const Tensor& scalar_tensor = searchsorted_scalar_tensor(self, sorted_sequence.device()); + return searchsorted_zoom(sorted_sequence, scalar_tensor, out_int32, right, side_opt, sorter); +} + +Tensor& bucketize_out_zoom(const Tensor& self, const Tensor& boundaries, bool out_int32, bool right, Tensor& result) { + TORCH_CHECK(boundaries.dim() == 1, "boundaries tensor must be 1 dimension, but got dim(", boundaries.dim(), ")"); + at::native::searchsorted_out_zoom(boundaries, self, out_int32, right, nullopt, nullopt, result); + return result; +} + +Tensor bucketize_zoom(const Tensor& self, const Tensor& boundaries, bool out_int32, bool right) { + ScalarType scalar_type = out_int32 ? 
ScalarType::Int : ScalarType::Long; + c10::TensorOptions options = TensorOptions().device(self.options().device()).dtype(scalar_type); + Tensor result = at::empty({0}, options, MemoryFormat::Contiguous); + at::native::bucketize_out_zoom(self, boundaries, out_int32, right, result); + return result; +} + +Tensor bucketize_zoom(const Scalar& self, const Tensor& boundaries, bool out_int32, bool right) { + return bucketize_zoom(searchsorted_scalar_tensor(self, boundaries.device()), boundaries, out_int32, right); +} + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Col2Im.cu b/aten/src/ATen/native/zoom/Col2Im.cu new file mode 100644 index 0000000000000..2b9ed73079c87 --- /dev/null +++ b/aten/src/ATen/native/zoom/Col2Im.cu @@ -0,0 +1,171 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + +namespace at::native { +namespace { + +void col2im_out_zoom_template( + Tensor& output, + const Tensor& input_, + IntArrayRef output_size, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + TensorArg input_arg{input_, "input", 1}; + TensorArg output_arg{output, "output", 2}; + checkAllSameGPU(__func__, {input_arg, output_arg}); + + TORCH_CHECK( + output_size.size() == 2, + "It is expected output_size equals to 2, but got size ", + output_size.size()); + + TORCH_CHECK( + kernel_size.size() == 2, + "It is expected kernel_size equals to 2, but got size ", + kernel_size.size()); + + TORCH_CHECK( + dilation.size() == 2, + "It is expected dilation equals to 2, but got size ", + dilation.size()); + + TORCH_CHECK( + padding.size() == 2, + "It is expected padding equals to 2, but got size ", + padding.size()); + + TORCH_CHECK( + stride.size() == 2, + "It is expected stride equals to 2, but got size ", + stride.size()); + + int64_t output_height = output_size[0]; + int64_t output_width = output_size[1]; + int64_t kernel_height = kernel_size[0]; + int64_t kernel_width = kernel_size[1]; + int64_t dilation_height = dilation[0]; + int64_t dilation_width = dilation[1]; + int64_t pad_height = padding[0]; + int64_t pad_width = padding[1]; + int64_t stride_height = stride[0]; + int64_t stride_width = stride[1]; + + col2im_shape_check( + input_, + Tensor(), + output_height, + output_width, + kernel_height, + kernel_width, + dilation_height, + dilation_width, + pad_height, + pad_width, + stride_height, + stride_width); + + Tensor input = input_.contiguous(); + + bool batched_input = true; + if (input.dim() == 2) { + // Force batch + batched_input = false; + input = input.view({1, input.size(0), input.size(1)}); + } + + int64_t batch_size = input.size(0); + int64_t n_input_plane = input.size(1); + int64_t n_output_plane = n_input_plane / (kernel_width * kernel_height); + int64_t input_batch_stride = input.stride(0); + + output.resize_({batch_size, n_output_plane, output_height, output_width}); + int64_t output_batch_stride = output.stride(0); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, + input.scalar_type(), "col2im_out_zoom", [&] { + int64_t height_col = (output_height + 2 * pad_height - + (dilation_height * (kernel_height - 1) + 1)) / + stride_height + + 1; + int64_t width_col = (output_width + 2 * pad_width - + (dilation_width * (kernel_width - 1) + 1)) / + stride_width + + 1; + + col2im_batched( + 
c10::zoom::getCurrentZoomStream(), + input.const_data_ptr(), + input_batch_stride, + batch_size, + n_output_plane, + output_height, + output_width, + height_col, + width_col, + kernel_height, + kernel_width, + pad_height, + pad_width, + stride_height, + stride_width, + dilation_height, + dilation_width, + output.mutable_data_ptr(), + output_batch_stride); + + if (!batched_input) { + output.resize_({n_output_plane, output_height, output_width}); + } + }); +} + +} // namespace + +Tensor& col2im_out_zoom(const Tensor& input, + IntArrayRef output_size, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride, + Tensor& output) { + col2im_out_zoom_template( + output, input, output_size, kernel_size, dilation, padding, stride); + return output; +} + +Tensor col2im_zoom( + const Tensor& input, + IntArrayRef output_size, + IntArrayRef kernel_size, + IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride) { + Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + col2im_out_zoom_template( + output, input, output_size, kernel_size, dilation, padding, stride); + return output; +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/CompareEQKernel.cu b/aten/src/ATen/native/zoom/CompareEQKernel.cu new file mode 100644 index 0000000000000..b8869c0dc86b3 --- /dev/null +++ b/aten/src/ATen/native/zoom/CompareEQKernel.cu @@ -0,0 +1,50 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include + + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. + +namespace at::native { namespace { + +enum class EqOpType {EQ, NE}; + +template +struct CompareEqFunctor{ + CompareEqFunctor(EqOpType op): op_(op) {} + const EqOpType op_; + __device__ __forceinline__ bool operator() (scalar_t a, scalar_t b) const { + if (op_ == EqOpType::EQ) { + return a == b; + } else { //NE + return a != b; + } + + } + }; +} + +C10_NOINLINE void compare_eq_ne_kernel(TensorIteratorBase &iter, EqOpType op) { + AT_DISPATCH_V2(iter.common_dtype(), "compare_eq_ne_zoom", AT_WRAP([&]() { + opmath_symmetric_gpu_kernel_with_scalars( + iter, CompareEqFunctor(op)); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +} + +void eq_kernel_zoom(TensorIteratorBase& iter) { + compare_eq_ne_kernel(iter, EqOpType::EQ); +} + +void ne_kernel_zoom(TensorIteratorBase& iter) { + compare_eq_ne_kernel(iter, EqOpType::NE); +} + +REGISTER_PRIVATEUSE1_DISPATCH(eq_stub, &eq_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(ne_stub, &ne_kernel_zoom); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/CompareKernels.cu b/aten/src/ATen/native/zoom/CompareKernels.cu new file mode 100644 index 0000000000000..7975d449d1959 --- /dev/null +++ b/aten/src/ATen/native/zoom/CompareKernels.cu @@ -0,0 +1,103 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include + + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. 
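+// The comparison kernels below share one CompareFunctor parameterized by
+// OpType. When a CPU scalar appears as the first operand, `reflect` swaps the
+// comparison so a single scalar-on-the-right kernel covers both cases.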
+ +namespace at::native { namespace { + +enum class OpType {GE, GT, LE, LT}; + +template +struct CompareFunctor{ + constexpr CompareFunctor(OpType op): op_(op) {}; + OpType op_; + __device__ __forceinline__ bool operator() (scalar_t a, scalar_t b) const { + if (op_ == OpType::GE) { + return a >= b; + } else if (op_ == OpType::GT) { + return a > b; + } else if (op_ == OpType::LE) { + return a <= b; + } else { //LT + return a < b; + } + } +}; + +// Reflects the comparison operator, so reflect(op)(a, b) == op(b, a) +OpType reflect(OpType x) { + switch (x) { + case OpType::GE: return OpType::LE; + case OpType::GT: return OpType::LT; + case OpType::LE: return OpType::GE; + case OpType::LT: return OpType::GT; + } + TORCH_INTERNAL_ASSERT(false, "Invalid OpType"); +} + +} // namespace (anonymous) + +template +void compare_scalar_kernel(TensorIteratorBase &iter, OpType op, scalar_t rhs) { + CompareFunctor f(op); + gpu_kernel(iter, [=] GPU_LAMBDA (scalar_t lhs) -> bool { + return f(lhs, rhs); + }); +} + +template +void compare_kernel_impl(TensorIteratorBase &iter, OpType op) { + // If either input is a cpu scalar, perform the equivalent comparison + // where the scalar is on the right hand side. This saves us from + // generating two otherwise identical kernels with mirrored + // arguments. + if (iter.is_cpu_scalar(1)) { + const scalar_t lhs = iter.scalar_value(1); + iter.remove_operand(1); + const DeviceGuard device_guard(iter.device(1)); + compare_scalar_kernel(iter, reflect(op), lhs); + } else if (iter.is_cpu_scalar(2)) { + const scalar_t rhs = iter.scalar_value(2); + iter.remove_operand(2); + compare_scalar_kernel(iter, op, rhs); + } else { + CompareFunctor f(op); + gpu_kernel(iter, f); + } +} + +C10_NOINLINE void compare_kernel_with_scalars(TensorIteratorBase &iter, OpType op) { + AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "compare_zoom", [&]() { + compare_kernel_impl(iter, op); + }); +} + + +void ge_kernel_zoom(TensorIteratorBase& iter) { + compare_kernel_with_scalars(iter, OpType::GE); +} + +void gt_kernel_zoom(TensorIteratorBase& iter) { + compare_kernel_with_scalars(iter, OpType::GT); +} + +void le_kernel_zoom(TensorIteratorBase& iter) { + compare_kernel_with_scalars(iter, OpType::LE); +} + +void lt_kernel_zoom(TensorIteratorBase& iter) { + compare_kernel_with_scalars(iter, OpType::LT); +} + +REGISTER_PRIVATEUSE1_DISPATCH(ge_stub, &ge_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(gt_stub, >_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(le_stub, &le_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(lt_stub, <_kernel_zoom); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/ComplexKernel.cu b/aten/src/ATen/native/zoom/ComplexKernel.cu new file mode 100644 index 0000000000000..c00c15b49a03d --- /dev/null +++ b/aten/src/ATen/native/zoom/ComplexKernel.cu @@ -0,0 +1,36 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. 
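+// complex(a, b) packs the two real inputs into a c10::complex value, while
+// polar(a, b) treats them as magnitude and angle: a * (cos(b) + i*sin(b)).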
+ +namespace at::native { +namespace { + +void complex_kernel_zoom(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.input_dtype(0), "complex_zoom", [&]() { + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> c10::complex { + return c10::complex(a, b); + }); + }); +} + +void polar_kernel_zoom(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES(iter.input_dtype(0), "polar_zoom", [&]() { + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> c10::complex { + return c10::complex(a * std::cos(b), a * std::sin(b)); + }); + }); +} + +} // anonymous namespace + +REGISTER_PRIVATEUSE1_DISPATCH(complex_stub, &complex_kernel_zoom); +REGISTER_PRIVATEUSE1_DISPATCH(polar_stub, &polar_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/CompositeRandomAccessor.h b/aten/src/ATen/native/zoom/CompositeRandomAccessor.h new file mode 100644 index 0000000000000..d47a7fa776f1b --- /dev/null +++ b/aten/src/ATen/native/zoom/CompositeRandomAccessor.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +namespace at { namespace native { + +struct TupleInfoCPU { + template + using tuple = thrust::tuple; + + template + static constexpr auto tie(Types&... args) noexcept { + return thrust::tie(args...); + } +}; + +template +using CompositeRandomAccessorCPU = + CompositeRandomAccessor; + +template +void swap( + references_holder rh1, + references_holder rh2 +) { + return thrust::swap(rh1.data(), rh2.data()); +} + +template +auto get(references_holder rh) -> decltype(thrust::get(rh.data())) { + return thrust::get(rh.data()); +} + +}} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ConvolutionMM2d.cu b/aten/src/ATen/native/zoom/ConvolutionMM2d.cu new file mode 100644 index 0000000000000..c2e165d4ac9dc --- /dev/null +++ b/aten/src/ATen/native/zoom/ConvolutionMM2d.cu @@ -0,0 +1,502 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + +namespace at::native { +namespace { + +void slow_conv2d_shape_check( + const Tensor& input, const Tensor& grad_output, + const Tensor& weight, const Tensor& bias, + int64_t kH, int64_t kW, + int64_t dH, int64_t dW, + int64_t padH, int64_t padW, + bool weight_nullable) { + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: ", kH, " kW: ", kW); + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: ", dH, " dW: ", dW); + + TORCH_CHECK(weight_nullable || weight.defined(), + "weight tensor is expected to be non-nullable"); + TORCH_CHECK(!weight.defined() || + ((weight.numel() > 0) && (weight.dim() == 2)), + "non-empty 2D weight tensor expected, but got: ", weight.sizes()); + TORCH_CHECK(!bias.defined() || (bias.dim() == 1 && bias.sizes()[0] == weight.sizes()[0]), + "Expected bias to have shape [", weight.sizes()[0], "] but got ", bias.sizes()); + + const auto in_sizes = input.sizes(); + constexpr int ndim = 4; + constexpr int dimf = 1; + constexpr int dimh = 2; + constexpr int dimw = 3; + TORCH_CHECK(in_sizes.size() == ndim, "Expected 4D input tensor, but got ", in_sizes); + + // Allow for empty batch size but not other dimensions + const bool valid_empty = c10::multiply_integers(in_sizes.slice(1)) != 0; + TORCH_CHECK(valid_empty, "non-empty input tensor expected but got: ", in_sizes); + + int64_t inputHeight = in_sizes[dimh]; + int64_t inputWidth = in_sizes[dimw]; + + 
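+  // The "exact" extents below are the zero-padded input sizes; each output
+  // spatial extent is then (exact - kernel) / stride + 1, rounding down.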
int64_t exactInputHeight = inputHeight + 2 * padH; + int64_t exactInputWidth = inputWidth + 2 * padW; + + TORCH_CHECK(exactInputHeight >= kH && exactInputWidth >= kW, + "Calculated padded input size per channel: ", + IntArrayRef{exactInputHeight, exactInputWidth}, + ". Kernel size: ", IntArrayRef{kH, kW}, + ". Kernel size can't be greater than actual input size"); + + // NOTE: can't use conv_output_size if the weight isn't defined + auto outputHeight = div_rtn(exactInputHeight - kH, dH) + 1; + auto outputWidth = div_rtn(exactInputWidth - kW, dW) + 1; + + TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1, + "Given input size per channel: ", + IntArrayRef{inputHeight, inputWidth}, + ". Calculated output size per channel: ", + IntArrayRef{outputHeight, outputWidth}, + ". Output size is too small"); + + if (weight.defined()) { + const auto w_sizes = weight.sizes(); + int64_t nInputPlane = w_sizes[1]; + if (w_sizes.size() == 2) { + nInputPlane /= (kH * kW); + } + TORCH_CHECK(in_sizes[dimf] == nInputPlane, + "Expected input dim ", dimf, " to have size ", nInputPlane, + " but got ", in_sizes[dimf]); + } + + if (grad_output.defined()) { + const auto gO_sizes = grad_output.sizes(); + TORCH_CHECK(gO_sizes.size() == ndim, + "Expected grad_output to have ", ndim, + " dimensions but got shape", gO_sizes); + + if (weight.defined()) { + const auto w_sizes = weight.sizes(); + TORCH_CHECK(gO_sizes[dimf] == w_sizes[0], + "Expected dim ", dimf, " to have size ", w_sizes[0], + " but got ", gO_sizes[dimf]); + } else if (bias.defined()) { + const auto b_sizes = bias.sizes(); + int64_t nOutputPlane = b_sizes.size() == 0 ? 1 : b_sizes[0]; + TORCH_CHECK(gO_sizes[dimf] == nOutputPlane, + "Expected grad_output dim ", dimf, " to have size ", + nOutputPlane, " but got ", gO_sizes[dimf]); + } + TORCH_CHECK(gO_sizes[dimh] == outputHeight, + "Expected grad_output dim ", dimh, " to have size ", + outputHeight, " but got ", gO_sizes[dimh]); + TORCH_CHECK(gO_sizes[dimw] == outputWidth, + "Expected grad_output dim ", dimw, " to have size ", + outputWidth, " but got ", gO_sizes[dimw]); + } +} + +Tensor new_view_weight_MM2d(const Tensor& weight_) { + auto weight = weight_.expect_contiguous(); + const auto w_sizes = weight->sizes(); + TORCH_CHECK(w_sizes.size() == 4); + int64_t s1 = w_sizes[0]; + int64_t s2 = c10::multiply_integers(w_sizes.slice(1)); + return weight->view({s1, s2}); +} + +void slow_conv2d_forward( + const Tensor &input, + const Tensor &output, + const Tensor &weight_, + const Tensor &bias, + int64_t kH, int64_t kW, + int64_t dH, int64_t dW, + int64_t padH, int64_t padW) { + auto weight = new_view_weight_MM2d(weight_); + slow_conv2d_shape_check( + input, {}, weight, bias, kH, kW, dH, dW, padH, padW, /*weight_nullable*/false); + + constexpr int dimf = 1; + constexpr int dimh = 2; + constexpr int dimw = 3; + + auto in_sizes = input.sizes(); + int64_t batchSize = in_sizes[0]; + int64_t nInputPlane = in_sizes[dimf]; + int64_t inputHeight = in_sizes[dimh]; + int64_t inputWidth = in_sizes[dimw]; + int64_t nOutputPlane = weight.sizes()[0]; + int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1; + int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1; + + // Resize output + resize_output(output, {batchSize, nOutputPlane, outputHeight, outputWidth}); + + // Create temporary columns + at::Tensor columns; + + const bool requires_columns = ( + kW != 1 || kH != 1 || dW != 1 || dH != 1 || padH != 0 || padW != 0); + + if (requires_columns) { + columns = at::empty({nInputPlane * kW * kH, outputHeight * outputWidth}, 
input.options()); + } + + if (bias.defined()) { + TORCH_CHECK(bias.scalar_type() == input.scalar_type(), + "Expected bias to have type ", input.scalar_type(), + " but got ", bias.scalar_type()); + output.copy_(bias.view({-1, 1, 1})); + } else { + output.zero_(); + } + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), + "slow_conv2d_zoom", [&] { + // For each elt in batch, do: + for (int elt = 0; elt < batchSize; elt ++) { + // Matrix multiply per output: + auto input_n = input.select(0, elt); + auto output_n = output.select(0, elt); + + if (requires_columns) { + // Extract columns: + at::native::im2col( + c10::zoom::getCurrentZoomStream(), + input_n.const_data_ptr(), + nInputPlane, inputHeight, inputWidth, + outputHeight, outputWidth, + kH, kW, padH, padW, dH, dW, + 1, 1, + columns.mutable_data_ptr() + ); + } + + // M,N,K are dims of matrix A and B + // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) + int64_t m = nOutputPlane; + int64_t n = outputHeight * outputWidth; + int64_t k = nInputPlane*kH*kW; + + // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) + auto gemm_in_ptr = requires_columns ? + columns.const_data_ptr() : + input_n.const_data_ptr(); + at::zoom::blas::gemm( + 'n', 'n', + n, m, k, + scalar_t(1), + gemm_in_ptr, n, + weight.const_data_ptr(), k, + scalar_t(1), + output_n.mutable_data_ptr(), n + ); + } + }); +} + +void slow_conv2d_backward( + const Tensor &input, + const Tensor &grad_output, + const Tensor &grad_input, + const Tensor &weight_, + const Tensor &grad_columns, + int kH, int kW, + int dH, int dW, + int padH, int padW) { + Tensor weight = new_view_weight_MM2d(weight_); + slow_conv2d_shape_check(input, grad_output, weight, {}, + kH, kW, dH, dW, padH, padW, /*weight_nullable=*/false); + + // Params + auto weight_sizes = weight.sizes(); + int nInputPlane = weight_sizes[1]/(kW*kH); + int nOutputPlane = weight_sizes[0]; + + TORCH_INTERNAL_ASSERT(grad_output.is_contiguous()); + + auto input_sizes = input.sizes(); + int64_t inputWidth = input_sizes[3]; + int64_t inputHeight = input_sizes[2]; + auto output_sizes = grad_output.sizes(); + int64_t outputWidth = output_sizes[3]; + int64_t outputHeight = output_sizes[2]; + + // Batch size + input planes + int64_t batchSize = input_sizes[0]; + + // Resize output + resize_output(grad_input, input_sizes); + TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous"); + + // Resize temporary columns + resize_output(grad_columns, {nInputPlane*kW*kH, outputHeight*outputWidth}); + TORCH_CHECK(grad_columns.is_contiguous(), "grad_columns must be contiguous"); + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), + "slow_conv2d_backward_zoom", [&] { + // For each elt in batch, do: + for (int elt = 0; elt < batchSize; elt ++) { + // Matrix multiply per sample: + auto grad_input_n = grad_input.select(0, elt); + auto grad_output_n = grad_output.select(0, elt); + + // M,N,K are dims of matrix A and B + // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) + int64_t m = nInputPlane*kW*kH; + int64_t n = grad_columns.sizes()[1]; + int64_t k = nOutputPlane; + + // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) + at::zoom::blas::gemm( + 'n', 't', + n, m, k, + scalar_t(1), + grad_output_n.const_data_ptr(), n, + weight.const_data_ptr(), m, + scalar_t(0), + grad_columns.mutable_data_ptr(), n + ); + + // Unpack columns back into input: + using acc_t = at::acc_type; + at::native::col2im( + 
c10::zoom::getCurrentZoomStream(), + grad_columns.const_data_ptr(), + nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, kH, kW, padH, padW, dH, dW, + 1, 1, grad_input_n.mutable_data_ptr() + ); + } + }); +} + +void slow_conv2d_grad_weight( + const Tensor &input, + const Tensor &grad_output, + const Tensor &grad_weight_, + const Tensor &columns, + int64_t kH, int64_t kW, + int64_t dH, int64_t dW, + int64_t padH, int64_t padW) { + TORCH_CHECK(grad_weight_.is_contiguous(), "grad_weight needs to be contiguous"); + auto grad_weight = new_view_weight_MM2d(grad_weight_); + slow_conv2d_shape_check(input, grad_output, grad_weight, {}, + kH, kW, dH, dW, padH, padW, /*weight_nullable=*/true); + + // Params + TORCH_INTERNAL_ASSERT(input.is_contiguous()); + TORCH_INTERNAL_ASSERT(grad_output.is_contiguous()); + + auto input_sizes = input.sizes(); + int64_t nInputPlane = input_sizes[1]; + int64_t nOutputPlane = grad_output.sizes()[1]; + + int64_t inputWidth = input_sizes[3]; + int64_t inputHeight = input_sizes[2]; + int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1; + int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1; + + // Batch size + input planes + int64_t batchSize = input_sizes[0]; + + // Resize temporary columns + resize_output(columns, {nInputPlane * kH * kW, outputHeight * outputWidth}); + + const bool requires_columns = ( + kW != 1 || kH != 1 || dW != 1 || dH != 1 || padH != 0 || padW != 0); + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), + "slow_conv2d_grad_weight_zoom", [&] { + // For each elt in batch, do: + for (int elt = 0; elt < batchSize; elt ++) { + // Matrix multiply per output: + auto grad_output_n = grad_output.select(0, elt); + + // Matrix multiply per output: + auto input_n = input.select(0, elt); + + if (requires_columns) { + // Extract columns: + at::native::im2col( + c10::zoom::getCurrentZoomStream(), + input_n.const_data_ptr(), + nInputPlane, inputHeight, inputWidth, + outputHeight, outputWidth, + kH, kW, padH, padW, dH, dW, + 1, 1, + columns.mutable_data_ptr() + ); + } + + // M,N,K are dims of matrix A and B + // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) + int64_t m = nOutputPlane; + int64_t n = nInputPlane*kW*kH; + int64_t k = columns.sizes()[1]; + + // Do GEMM (note: this is a bit confusing because gemm assumes column-major matrices) + auto gemm_in_ptr = requires_columns ? 
+ columns.const_data_ptr() : + input_n.const_data_ptr(); + at::zoom::blas::gemm( + 't', 'n', + n, m, k, + scalar_t(1), + gemm_in_ptr, k, + grad_output_n.const_data_ptr(), k, + scalar_t(1), + grad_weight.mutable_data_ptr(), n + ); + } + }); +} + +} // namespace (anonymous) + + +Tensor& slow_conv2d_forward_out_zoom( + const Tensor &self_, + const Tensor &weight_, + IntArrayRef kernel_size, + const std::optional &bias_, + IntArrayRef stride, + IntArrayRef padding, + Tensor &output) { + TORCH_CHECK(kernel_size.size() == 2); + TORCH_CHECK(stride.size() == 2); + TORCH_CHECK(padding.size() == 2); + + auto self = self_.expect_contiguous(); + auto weight = weight_.expect_contiguous(); + auto bias = [&] { + if (bias_.has_value() && bias_->defined()) { + return bias_->expect_contiguous(); + } + return MaybeOwned::owned(std::in_place); + }(); + + slow_conv2d_forward( + *self, + output, + *weight, + *bias, + kernel_size[0], kernel_size[1], + stride[0], stride[1], + padding[0], padding[1] + ); + return output; +} + +Tensor slow_conv2d_forward_zoom( + const Tensor &self, + const Tensor &weight, + IntArrayRef kernel_size, + const std::optional &bias, + IntArrayRef stride, + IntArrayRef padding) { + auto output = at::empty({0}, self.options()); + return slow_conv2d_forward_out_zoom( + self, weight, kernel_size, bias, stride, padding, output); +} + +std::tuple slow_conv2d_backward_out_zoom( + const Tensor& grad_output_, + const Tensor& self_, + const Tensor& weight_, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + Tensor& grad_input, + Tensor& grad_weight, + Tensor& grad_bias) { + auto grad_output = grad_output_.expect_contiguous(); + + Tensor columns = at::empty({0}, self_.options()); + if (grad_input.defined()) { + resize_output(grad_input, self_.sizes()); + auto weight = weight_.expect_contiguous(); + + slow_conv2d_backward( + self_, *grad_output, + grad_input, *weight, + columns, + kernel_size[0], kernel_size[1], + stride[0], stride[1], + padding[0], padding[1]); + } + if (grad_bias.defined()) { + at::sum_out(grad_bias, *grad_output, IntArrayRef{0, 2, 3}); + } + if (grad_weight.defined()) { + resize_output(grad_weight, weight_.sizes()); + grad_weight.zero_(); + auto self = self_.expect_contiguous(); + slow_conv2d_grad_weight( + *self, + *grad_output, + grad_weight, + columns, + kernel_size[0], kernel_size[1], + stride[0], stride[1], + padding[0], padding[1] + ); + } + return std::tuple{ + grad_input, grad_weight, grad_bias}; +} + +std::tuple slow_conv2d_backward_zoom( + const Tensor& grad_output, + const Tensor& self, + const Tensor& weight, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + std::array output_mask) { + Tensor grad_input; + Tensor grad_weight; + Tensor grad_bias; + + if (output_mask[0]) { + grad_input = at::empty({0}, grad_output.options()); + } + + if (output_mask[1]) { + grad_weight = at::empty({0}, grad_output.options()); + } + + if (output_mask[2]) { + grad_bias = at::empty({0}, grad_output.options()); + } + + return native::slow_conv2d_backward_out_zoom( + grad_output, + self, + weight, + kernel_size, + stride, + padding, + grad_input, + grad_weight, + grad_bias); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/Copy.cu b/aten/src/ATen/native/zoom/Copy.cu new file mode 100644 index 0000000000000..50b04fc10f92c --- /dev/null +++ b/aten/src/ATen/native/zoom/Copy.cu @@ -0,0 +1,400 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include +#include + +namespace at::native { + +// forward decl, defined below +void direct_copy_kernel_zoom(TensorIteratorBase &iter); + +// forward decl, defined in UnarySignKernels.cu +void neg_kernel_zoom(TensorIteratorBase& iter); + +// forward decl, defined in UnaryComplexKernels.cu +void conj_kernel_zoom(TensorIteratorBase& iter); + +void float8_copy_kernel_zoom(TensorIteratorBase &iter) { + ScalarType dtype = iter.dtype(0); + ScalarType other_dtype = iter.dtype(1); + if (dtype == kFloat8_e4m3fn) { + switch (other_dtype) { + case kFloat: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) { + return Float8_e4m3fn(value); + }); + break; + case kHalf: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) { + return Float8_e4m3fn(value); + }); + break; + case kBFloat16: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) { + return Float8_e4m3fn(value); + }); + break; + default: + gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fn x) { return x; }); + break; + } + } else if (dtype == kFloat8_e5m2) { + switch (other_dtype) { + case kFloat: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) { +#ifdef AT_USE_NV_CVT_INTRINSICS + const auto x = __nv_cvt_float_to_fp8(value, __NV_NOSAT, __NV_E5M2); + return Float8_e5m2(x, Float8_e5m2::from_bits()); +#else + return Float8_e5m2(value); +#endif + }); + break; + case kHalf: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) { +#ifdef AT_USE_NV_CVT_INTRINSICS + const auto x = __nv_cvt_halfraw_to_fp8(static_cast<__half>(value), __NV_NOSAT, __NV_E5M2); + return Float8_e5m2(x, Float8_e5m2::from_bits()); +#else + return Float8_e5m2(value); +#endif + }); + break; + case kBFloat16: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) { +#ifdef AT_USE_NV_CVT_INTRINSICS + const auto x = __nv_cvt_bfloat16raw_to_fp8(static_cast<__nv_bfloat16>(value), __NV_NOSAT, __NV_E5M2); + return Float8_e5m2(x, Float8_e5m2::from_bits()); +#else + return Float8_e5m2(value); +#endif + }); + break; + default: + gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2 x) { return x; }); + break; + } + } else if (dtype == kFloat8_e4m3fnuz) { + switch (other_dtype) { + case kFloat: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) { + return Float8_e4m3fnuz(value); + }); + break; + case kHalf: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) { + return Float8_e4m3fnuz(value); + }); + break; + case kBFloat16: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) { + return Float8_e4m3fnuz(value); + }); + break; + default: + gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fnuz x) { return x; }); + break; + } + } else if (dtype == kFloat8_e5m2fnuz) { + switch (other_dtype) { + case kFloat: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) { + return Float8_e5m2fnuz(value); + }); + break; + case kHalf: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) { + return Float8_e5m2fnuz(value); + }); + break; + case kBFloat16: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) { + return Float8_e5m2fnuz(value); + }); + break; + default: + gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2fnuz x) { return x; }); + break; + } + } else { + TORCH_CHECK(false, "This supposed ot be called only for Float8 types"); + } +} + +// TODO: We probably can use the opaque type trick to avoid creating duplicate +// kernels for equivalent bit lengths +void direct_copy_kernel_zoom(TensorIteratorBase &iter) { + ScalarType dtype = iter.dtype(0); + if (isQIntType(dtype)) { + AT_DISPATCH_QINT_TYPES(dtype, 
"copy_", [&] { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; }); + }); + } else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn || dtype == kFloat8_e5m2fnuz || dtype == kFloat8_e4m3fnuz) { + float8_copy_kernel_zoom(iter); + } else if (isBitsType(dtype)) { + TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting " + "bits types to different bits types. Source dtype is ", iter.dtype(1), "target dtype is ", dtype); + AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] { + gpu_kernel_nocast(iter, [] GPU_LAMBDA(scalar_t x) { return x; }); + }); + } else { + AT_DISPATCH_V2( + dtype, "copy_", AT_WRAP([&] { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; }); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kHalf, kBool, kBFloat16, kComplexHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + } +} + +void neg_conj_kernel_zoom(TensorIteratorBase &iter) { + AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "neg_conj_zoom", [&] { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return -std::conj(x); }); + }); +} + +using namespace at::zoom; + +// device-to-device copy, does type conversion +void copy_device_to_device(TensorIterator& iter, + bool non_blocking, + bool p2p_enabled) { + int64_t numel = iter.numel(); + + // We can memcpy the memory if both tensors have the same type AND both + // tensors are contiguous after dimension coalescing and reordering. + bool same_type = iter.dtype(0) == iter.dtype(1); + bool same_conj = iter.tensor(0).is_conj() == iter.tensor(1).is_conj(); + bool same_neg = iter.tensor(0).is_neg() == iter.tensor(1).is_neg(); + bool memcpy_eligible = same_type && same_conj && same_neg && iter.is_contiguous(); + + Device dst_device = iter.device(0); + Device src_device = iter.device(1); + + c10::zoom::ZoomGuard device_guard(src_device); + + // We always perform the copy on the source device, using the current stream + // on the source device, and we fully synchronize on both src and dst's + // current streams for completion of the copy. We have to explicitly do this + // for non-contig copies. This mimics the behavior of cross-device + // hipMemcpyAsync on the default stream. + c10::zoom::ZoomStream copy_stream = c10::zoom::getCurrentZoomStream(src_device.index()); + if (src_device != dst_device) { + // This is a cross-device copy on the src current stream and dst current + // stream. We perform a two-way barrier between both devices' streams + // before the copy. This ensures that any write-after-write and + // write-after-read dependencies on the destination side are handled, so + // that no one is operating on the dst memory when we perform the copy. + // src waits on dst barrier (src already waits on src) + ZoomEvent dst_ready; + device_guard.set_device(dst_device); + dst_ready.record(c10::zoom::getCurrentZoomStream(dst_device.index())); + + device_guard.set_device(src_device); + dst_ready.block(copy_stream); + } + + if (memcpy_eligible) { + void *dst = iter.data_ptr(0); + void *src = iter.data_ptr(1); + size_t size = numel * iter.element_size(0); + if (src != dst || src_device != dst_device) { + // Due to bizarre cuda driver intricacies, copies of + // hipMallocAsynced memory between devices that aren't + // peer-to-peer-capable need "hipMemcpyPeerAsync". 
+ // So we let the allocator implement the correct call + // (either hipMemcpyAsync or hipMemcpyPeerAsync) + C10_ZOOM_CHECK(c10::zoom::ZoomCachingAllocator::memcpyAsync( + dst, dst_device.index(), + src, src_device.index(), + size, copy_stream, p2p_enabled)); + } + } else { + if (same_neg) { + if (!same_conj) { + conj_kernel_zoom(iter); + } else { + direct_copy_kernel_zoom(iter); + } + } else { + if (!same_conj) { + neg_conj_kernel_zoom(iter); + } else { + neg_kernel_zoom(iter); + } + } + } + + if (src_device != dst_device) { + // dst waits on src barrier (dst already waits on dst). We cannot + // operate on dst's copy until the copy is complete. + + // Still on src_device, record stream event + ZoomEvent src_ready; + src_ready.record(copy_stream); + + device_guard.set_device(dst_device); + src_ready.block(c10::zoom::getCurrentZoomStream(dst_device.index())); + } + + C10_ZOOM_CHECK(hipGetLastError()); +} + +static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) { + Device dst_device = iter.device(0); + Device src_device = iter.device(1); + + if (dst_device == src_device) { + // We never require temporaries for copies on the same GPU. + TORCH_INTERNAL_ASSERT(dst_device.is_privateuseone() && src_device.is_privateuseone()); + return false; + } + + bool same_dtype = iter.dtype(0) == iter.dtype(1); + if (same_dtype && iter.is_contiguous()) { + // Contiguous same-dtype copies can always use hipMemcpyAsync + return false; + } else if (dst_device.is_privateuseone() && src_device.is_privateuseone()) { + // Copies between GPUs can use the copy kernel if P2P is supported + return !p2p_enabled; + } else { + // The remaining cases require temporaries. For example, this includes + // non-contiguous copies between CPU and GPU. + return true; + } +} + +static bool maybe_enable_p2p_access(Device dst_device, Device src_device) { + if (dst_device.is_cpu() || src_device.is_cpu()) { + return false; + } + return at::zoom::get_p2p_access(src_device.index(), dst_device.index()); +} + +static void copy_kernel_zoom(TensorIterator& iter, bool non_blocking) { + TORCH_CHECK(iter.ntensors() == 2); + + Device dst_device = iter.device(0); + Device src_device = iter.device(1); + + // Enable p2p access between devices. (No-op if it involves the CPU) + bool p2p_enabled = maybe_enable_p2p_access(dst_device, src_device); + + if (copy_requires_temporaries(iter, p2p_enabled)) { + // NB: this involves recursive calls to copy. Be careful that those copies + // don't require temporaries or you will cause an infinite recursion! + auto& dst = iter.tensor(0); + Tensor dst_contig; + Tensor src_contig; + + // If non_blocking is true - type conversions are performed on the GPU + // For blocking transfers conversions are performed on CPU to avoid allocating + // extra GPU memory + // for GPU-GPU transfers conversions are performed on the source device + auto conversion_device = non_blocking ? DeviceType::PrivateUse1 : kCPU; + if (iter.device_type(1) == conversion_device) { + dst_contig = dst.is_contiguous() ? dst : at::empty_like(dst, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + src_contig = iter.tensor(1).to(iter.dtype(0)).expand_as(dst).contiguous(); + } else { + bool same_type = iter.dtype(0) == iter.dtype(1); + dst_contig = (dst.is_contiguous() && same_type) ? 
dst : at::empty_like(dst, iter.dtype(1), LEGACY_CONTIGUOUS_MEMORY_FORMAT); + src_contig = iter.tensor(1).expand_as(dst).contiguous(); + } + + // propagate the correct conjugate bit + dst_contig._set_conj(dst.is_conj()); + src_contig._set_conj(iter.tensor(1).is_conj()); + + dst_contig._set_neg(dst.is_neg()); + src_contig._set_neg(iter.tensor(1).is_neg()); + + // perform a same-dtype copy on contiguous tensors + TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes())); + TORCH_INTERNAL_ASSERT(dst_contig.scalar_type() == src_contig.scalar_type()); + dst_contig.copy_(src_contig, non_blocking); + + // if necessary, copy back into dst + if (!dst_contig.is_same(dst)) { + TORCH_INTERNAL_ASSERT(dst_contig.device() == dst.device()); + dst.copy_(dst_contig, non_blocking); + } + return; + } + + // Copy on GPU (or between GPUs) + if (dst_device.is_privateuseone() && src_device.is_privateuseone()) { + copy_device_to_device(iter, non_blocking, p2p_enabled); + return; + } + + // Copy between CPU and GPU + c10::zoom::OptionalZoomGuard device_guard; + hipMemcpyKind kind; + if (dst_device.is_privateuseone() && src_device.is_cpu()) { + device_guard.set_device(dst_device); + kind = hipMemcpyHostToDevice; + } else if (dst_device.is_cpu() && src_device.is_privateuseone()) { + device_guard.set_device(src_device); + kind = hipMemcpyDeviceToHost; + } else { + TORCH_INTERNAL_ASSERT(false, "unsupported devices in GPU copy_()"); + } + + void* dst = iter.data_ptr(0); + void* src = iter.data_ptr(1); + int64_t nbytes = iter.numel() * iter.element_size(0); + c10::zoom::ZoomStream stream = c10::zoom::getCurrentZoomStream(); + + if (non_blocking) { + C10_ZOOM_CHECK(hipMemcpyAsync(dst, src, nbytes, kind, stream)); + // we use both the storage context and the tensor data pointer as the key + // for the caching host allocator. This allows us to better attribute the + // events to the original tensor allocation correctly. The cases we seek to + // handle are: + + // 1: a user can pass a pinned memory tensor with an alternative + // context, for example if allocating memory directly from the pinned memory + // allocator and constructing a tensor with torch::from_blob. + + // 2: a user can pass a tensor with a different base pointer to the original + // allocation (via slicing). + const auto& dst_tensor = iter.tensor(0); + const auto& src_tensor = iter.tensor(1); + const auto& host_tensor = (dst_device == kCPU ? dst_tensor : src_tensor); + auto* ptr = (dst_device == kCPU ? dst : src); + auto* ctx = host_tensor.storage().data_ptr().get_context(); + // TODO: warn on the return value. 
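+      // Record the copy on `stream` against this host allocation so the
+      // caching host allocator does not recycle the pinned block while the
+      // async copy is still in flight.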
+ CachingHostAllocator_recordEvent(ptr, ctx, stream); + + } else { + c10::zoom::memcpy_and_sync(dst, src, nbytes, kind, stream); + } + + if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) { + iter.tensor(0).conj_physical_(); + } + if (iter.tensor(0).is_neg() != iter.tensor(1).is_neg()) { + iter.tensor(0).neg_(); + } +} + + REGISTER_PRIVATEUSE1_DISPATCH(copy_stub, ©_kernel_zoom); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Copy.h b/aten/src/ATen/native/zoom/Copy.h new file mode 100644 index 0000000000000..d7a7243b36dfd --- /dev/null +++ b/aten/src/ATen/native/zoom/Copy.h @@ -0,0 +1,11 @@ +#pragma once + +namespace at { +struct TensorIteratorBase; + + namespace native { + + void direct_copy_kernel_zoom(TensorIteratorBase &iter); + + } +} \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/CopysignKernel.cu b/aten/src/ATen/native/zoom/CopysignKernel.cu new file mode 100644 index 0000000000000..d34dbc1ee9487 --- /dev/null +++ b/aten/src/ATen/native/zoom/CopysignKernel.cu @@ -0,0 +1,27 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include + +#include +#include +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. + +namespace at::native { + +void copysign_kernel_zoom(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_zoom", [&]() { + gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return c10::hip::compat::copysign(a, b); + }); + }); +} + +REGISTER_PRIVATEUSE1_DISPATCH(copysign_stub, ©sign_kernel_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/CrossKernel.cu b/aten/src/ATen/native/zoom/CrossKernel.cu new file mode 100644 index 0000000000000..459766ccace6e --- /dev/null +++ b/aten/src/ATen/native/zoom/CrossKernel.cu @@ -0,0 +1,92 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include + +namespace at::native { + +template +__global__ void cross_kernel( + int numel, T* out, const T* x1, const T* x2, OffsetCalc offset_calculator, + StrideType ostride, StrideType x1stride, StrideType x2stride) { + HIP_KERNEL_LOOP(i, numel) { + const auto offsets = offset_calculator.get(i); + auto* out_row = out + offsets[0]; + const auto* x1_row = x1 + offsets[1]; + const auto* x2_row = x2 + offsets[2]; + + const T val0 = (x1_row[1 * x1stride] * x2_row[2 * x2stride] - + x1_row[2 * x1stride] * x2_row[1 * x2stride]); + + const T val1 = (x1_row[2 * x1stride] * x2_row[0 * x2stride] - + x1_row[0 * x1stride] * x2_row[2 * x2stride]); + + const T val2 = (x1_row[0 * x1stride] * x2_row[1 * x2stride] - + x1_row[1 * x1stride] * x2_row[0 * x2stride]); + + + out_row[0 * ostride] = val0; + out_row[1 * ostride] = val1; + out_row[2 * ostride] = val2; + } +} + +void launch_cross_kernel(const TensorIteratorBase& iter, int64_t ostride, + int64_t x1stride, int64_t x2stride) { + const auto N = iter.numel(); + auto offset_calculator = make_element_offset_calculator<3>(iter); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(N > 0 && N <= std::numeric_limits::max()); + int64_t grid = (N + num_threads() - 1) / num_threads(); + auto stream = c10::zoom::getCurrentZoomStream(); + + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.common_dtype(), "cross_zoom", [&] { + auto out = static_cast(iter.data_ptr(0)); + auto x1 = static_cast(iter.data_ptr(1)); + auto x2 = static_cast(iter.data_ptr(2)); + constexpr 
int64_t int_max = std::numeric_limits::max(); + if (ostride * 2 > int_max || x1stride * 2 > int_max || x2stride * 2 > int_max) { + cross_kernel<<>>( + N, out, x1, x2, offset_calculator, ostride, x1stride, x2stride); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else { + cross_kernel<<>>( + N, out, x1, x2, offset_calculator, + static_cast(ostride), + static_cast(x1stride), + static_cast(x2stride)); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + }); +} + +void cross_impl(const Tensor& result, const Tensor& x1, const Tensor& x2, int64_t dim) { + const int64_t ostride = result.stride(dim); + const int64_t x1stride = x1.stride(dim); + const int64_t x2stride = x2.stride(dim); + + auto iter = TensorIteratorConfig() + .add_output(result) + .add_const_input(x1) + .add_const_input(x2) + .resize_outputs(false) + .declare_static_shape(result.sizes(), /*squash_dims=*/dim) + .build(); + + if (iter.numel() == 0) { + return; + } + + if (iter.can_use_32bit_indexing()) { + launch_cross_kernel(iter, ostride, x1stride, x2stride); + } else { + for (auto&& sub_iter: iter.with_32bit_indexing()) { + launch_cross_kernel(sub_iter, ostride, x1stride, x2stride); + } + } +} + +REGISTER_PRIVATEUSE1_DISPATCH(cross_stub, &cross_impl); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/CumminmaxKernel.cu b/aten/src/ATen/native/zoom/CumminmaxKernel.cu new file mode 100644 index 0000000000000..5c3e3a6aa211f --- /dev/null +++ b/aten/src/ATen/native/zoom/CumminmaxKernel.cu @@ -0,0 +1,29 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include + +#include +#include + +#include +#include + +namespace at::native { + +void launch_cummax_zoom_kernel(const TensorBase& self, const TensorBase& values, const TensorBase& indices, int64_t dim) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, + self.scalar_type(), "cummax_zoom", [&]() { + scalar_t init = self.is_floating_point() ? (-1*std::numeric_limits::infinity()) : std::numeric_limits::lowest(); + scan_dim_with_indices(self, values, indices, dim, init, std::greater_equal()); + }); +} + +void launch_cummin_zoom_kernel(const TensorBase& self, const TensorBase& values, const TensorBase& indices, int64_t dim) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, + self.scalar_type(), "cummin_zoom", [&]() { + scalar_t init = self.is_floating_point() ? 
std::numeric_limits::infinity() : std::numeric_limits::max(); + scan_dim_with_indices(self, values, indices, dim, init, std::less_equal()); + }); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/CumprodKernel.cu b/aten/src/ATen/native/zoom/CumprodKernel.cu new file mode 100644 index 0000000000000..eaa48e306d479 --- /dev/null +++ b/aten/src/ATen/native/zoom/CumprodKernel.cu @@ -0,0 +1,23 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include + +#include +#include + +namespace at::native { + +void launch_cumprod_zoom_kernel(const TensorBase& result, const TensorBase& self, int64_t dim) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + ScalarType::Half, ScalarType::BFloat16, self.scalar_type(), "cumprod_zoom", [&]() { + scalar_t init = 1; + scan_dim( + self, + result, + dim, + init, + std::multiplies()); + }); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/CumsumKernel.cu b/aten/src/ATen/native/zoom/CumsumKernel.cu new file mode 100644 index 0000000000000..41808fb8fae8a --- /dev/null +++ b/aten/src/ATen/native/zoom/CumsumKernel.cu @@ -0,0 +1,25 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include + +#include +#include + +namespace at::native { + +void launch_cumsum_zoom_kernel(const TensorBase& result, const TensorBase& self, int64_t dim) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + ScalarType::Half, ScalarType::BFloat16, + self.scalar_type(), "cumsum_zoom", + [&]() { + scalar_t init = 0; + scan_dim( + self, + result, + dim, + init, + std::plus()); + }); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/DepthwiseConv2d.cu b/aten/src/ATen/native/zoom/DepthwiseConv2d.cu new file mode 100644 index 0000000000000..1999c0f346017 --- /dev/null +++ b/aten/src/ATen/native/zoom/DepthwiseConv2d.cu @@ -0,0 +1,732 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif + +namespace at::native { +namespace { +using at::zoom::detail::HIP_NUM_THREADS; +using at::zoom::detail::GET_BLOCKS; + +template class PtrTraits = DefaultPtrTraits> +PackedTensorAccessor32 dummy_packed_accessor32() { + std::array zeros{}; + return {nullptr, zeros.data(), zeros.data()}; +} + +template +__global__ void +conv_depthwise2d_forward_kernel_generic( + const PackedTensorAccessor32 input, + PackedTensorAccessor32 output, + const PackedTensorAccessor32 weight, + const PackedTensorAccessor32 bias, + bool biasEnabled, + index_t totalElements, + const int outputChannels, + const int depthwiseMultiplier, + const int inputWidth, const int inputHeight, + const int outputWidth, const int outputHeight, + const int kernelWidth, const int kernelHeight, + const int strideWidth, const int strideHeight, + const int padWidth, const int padHeight, + const int dilationWidth, const int dilationHeight) { + using acc_t = at::acc_type; + + HIP_KERNEL_LOOP_TYPE(linearIndex, totalElements, index_t) { + //calculate n,c,h,w indices, replacing modulos by divide and multiply add, + //result is same as would be in the code below + //const int n = linearIndex / batchStride; //batchStride = outputChannels * outputHeight * outputWidth + //const int c = (linearIndex / channelStride) % outputChannels; //channelStride = outputHeight * outputWidth + //const int h = (linearIndex / outputWidth) % outputHeight; + //const int w = linearIndex % outputWidth; + + int indtmp1 = linearIndex/outputWidth; + const int w = 
linearIndex - indtmp1 * outputWidth; + int indtmp2 = indtmp1/outputHeight; + const int h = indtmp1 - indtmp2 * outputHeight; + indtmp1 = indtmp2; + indtmp2 = indtmp1/outputChannels; + const int c = indtmp1 - indtmp2 * outputChannels; + const int n = indtmp2; + + int inputChannel = c; + int inputChannels = outputChannels; + if (depthwiseMultiplier !=1) { + inputChannel /= depthwiseMultiplier; + inputChannels /= depthwiseMultiplier; + } + + int weightOffset = c * kernelHeight * kernelWidth; + + // By precisely computing the filtering boundaries, we avoid repeating several + // expensive edge condition checks for every fetched item. If the input element is + // resident in L1, then the extra branches and comparisons would have been + // comparable in terms of cycles with the actual data fetch. Therefore computing + // boundaries ahead of the loop showed significant performance boost. + + int kHmin = 0, kHmax = kernelHeight, kWmin = 0, kWmax = kernelWidth; + + // Top + int h_in_min = -padHeight + h * strideHeight; + if (h_in_min < 0) { + kHmin = -h_in_min / dilationHeight; + if ((-h_in_min) % dilationHeight > 0) { + kHmin++; + } + } + + // Bottom + int h_in_max = h_in_min + (kernelHeight - 1) * dilationHeight - inputHeight + 1; + if (h_in_max >= 0) { + kHmax = kernelHeight - h_in_max / dilationHeight; + if (h_in_max % dilationHeight > 0) { + kHmax--; + } + } + + // Left + int w_in_min = -padWidth + w * strideWidth; + if (w_in_min < 0) { + kWmin = -w_in_min / dilationWidth; + if ((-w_in_min) % dilationWidth > 0) { + kWmin++; + } + } + + // Right + int w_in_max = w_in_min + (kernelWidth - 1) * dilationWidth - inputWidth + 1; + if (w_in_max >= 0) { + kWmax = kernelWidth - w_in_max / dilationWidth; + if (w_in_max % dilationWidth > 0) { + kWmax--; + } + } + + acc_t value = biasEnabled ? static_cast(bias.data()[c]) : acc_t(0); + const index_t offset0 = (n * inputChannels + inputChannel) * inputHeight * inputWidth; + + for (int kH = kHmin; kH < kHmax; ++kH) { + const int h_in = -padHeight + h * strideHeight + kH * dilationHeight; + for (int kW = kWmin; kW < kWmax; ++kW) { + const int w_in = -padWidth + w * strideWidth + kW * dilationWidth; + const index_t offset = offset0 + h_in * inputWidth + w_in; + value += (static_cast(weight.data()[weightOffset + kH * kernelWidth + kW]) * + static_cast(input.data()[offset])); + } + } + output.data()[linearIndex] = static_cast(value); + } +} + +template +__global__ void +conv_depthwise2d_forward_kernel( + const PackedTensorAccessor32 input, + PackedTensorAccessor32 output, + const PackedTensorAccessor32 weight, + const PackedTensorAccessor32 bias, + bool biasEnabled, + index_t totalElements, + const int outputChannels, + const int depthwiseMultiplier, + const int inputWidth, const int inputHeight, + const int outputWidth, const int outputHeight, + const int kernelWidth, const int kernelHeight, + const int strideWidth, const int strideHeight, + const int padWidth, const int padHeight, + const int dilationWidth, const int dilationHeight) { + using acc_t = at::acc_type; + const int KW_LIMIT = (kSize != 0) ? kSize : kernelWidth; + const int KH_LIMIT = (kSize != 0) ? 
kSize : kernelHeight; + + HIP_KERNEL_LOOP_TYPE(linearIndex, totalElements, index_t) { + //calculate n,c,h,w indices, replacing modulos by divide and multiply add, + //result is same as would be in the code below + //const int n = linearIndex / batchStride; //batchStride = outputChannels * outputHeight * outputWidth + //const int c = (linearIndex / channelStride) % outputChannels; //channelStride = outputHeight * outputWidth + //const int h = (linearIndex / outputWidth) % outputHeight; + //const int w = linearIndex % outputWidth; + + int indtmp1 = linearIndex/outputWidth; + const int w = linearIndex - indtmp1 * outputWidth; + int indtmp2 = indtmp1/outputHeight; + const int h = indtmp1 - indtmp2 * outputHeight; + indtmp1 = indtmp2; + indtmp2 = indtmp1/outputChannels; + const int c = indtmp1 - indtmp2 * outputChannels; + const int n = indtmp2; + + int inputChannel = c; + int inputChannels = outputChannels; + if (depthwiseMultiplier !=1) { + inputChannel /= depthwiseMultiplier; + inputChannels /= depthwiseMultiplier; + } + + int weightOffset = c * kernelHeight * kernelWidth; + + acc_t value = biasEnabled ? static_cast(bias.data()[c]) : acc_t(0); + const index_t offset0 = (n * inputChannels + inputChannel) * inputHeight * inputWidth; + for (int kH = 0; kH < KH_LIMIT; ++kH) { + for (int kW = 0; kW < KW_LIMIT; ++kW) { + const int h_in = -padHeight + h * strideHeight + kH * dilationHeight; + const int w_in = -padWidth + w * strideWidth + kW * dilationWidth; + + if ((h_in >= 0) && (h_in < inputHeight) && (w_in >= 0) && (w_in < inputWidth)) { + const index_t offset = offset0 + h_in * inputWidth + w_in; + value += (static_cast(weight.data()[weightOffset]) * + static_cast(input.data()[offset])); + } + ++weightOffset; + } + } + output.data()[linearIndex] = static_cast(value); + } +} + +template +__global__ void conv_depthwise2d_backward_kernel( + const PackedTensorAccessor32 grad_output, + PackedTensorAccessor32 grad_input, + const PackedTensorAccessor32 weight, + index_t totalElements, + const int inputChannels, + const int depthwiseMultiplier, + const int outputChannels, + const int inputWidth, const int inputHeight, + const int outputWidth, const int outputHeight, + const int kernelWidth, const int kernelHeight, + const int strideWidth, const int strideHeight, + const int padWidth, const int padHeight, + const int dilationWidth, const int dilationHeight) { + using acc_t = at::acc_type; + const int KW_LIMIT = (kSize != 0) ? kSize : kernelWidth; + const int KH_LIMIT = (kSize != 0) ? kSize : kernelHeight; + const int strideW = (stride != 0) ? stride : strideWidth; + const int strideH = (stride != 0) ? 
stride : strideHeight; + + HIP_KERNEL_LOOP_TYPE(linearIndex, totalElements, index_t) { + int indtmp1 = linearIndex/inputWidth; + const int w = linearIndex - indtmp1 * inputWidth; + int indtmp2 = indtmp1/inputHeight; + const int h = indtmp1 - indtmp2 * inputHeight; + indtmp1 = indtmp2; + indtmp2 = indtmp1/inputChannels; + const int c = indtmp1 - indtmp2 * inputChannels; + const int n = indtmp2; + + acc_t value(0); + + for (int multiplier = 0; multiplier < depthwiseMultiplier; ++multiplier) { + int och = (c * depthwiseMultiplier) + multiplier; + int weightOffset = och * kernelHeight * kernelWidth; + for (int kh = 0; kh < KH_LIMIT; ++kh) { + #pragma unroll + for (int kw = 0; kw < KW_LIMIT; ++kw) { + int h_out = h + padHeight - kh * dilationHeight; + int w_out = w + padWidth - kw * dilationWidth; + if ((h_out % strideH == 0) && (w_out % strideW == 0)) { + h_out = h_out / strideH; + w_out = w_out / strideW; + + if ((h_out >= 0) && (h_out < outputHeight) + && (w_out >= 0) && (w_out < outputWidth)) { + + const int offset = ((n * outputChannels + och) * outputHeight + h_out) + * outputWidth + w_out; + value += (static_cast(weight.data()[weightOffset]) * + static_cast(grad_output.data()[offset])); + } + } + ++weightOffset; + } + } + } + grad_input.data()[linearIndex] = static_cast(value); + } +} + + +template +__global__ void conv_depthwise2d_grad_weight_kernel( + const PackedTensorAccessor32 grad_output, + const PackedTensorAccessor32 input, + PackedTensorAccessor32 grad_weight, + const int batchSize, + const int inputChannels, + const int kernelChannels, + const int depthwiseMultiplier, + const int inputWidth, const int inputHeight, + const int outputWidth, const int outputHeight, + const int kernelWidth, const int kernelHeight, + const int strideWidth, const int strideHeight, + const int padWidth, const int padHeight, + const int dilationWidth, const int dilationHeight) { + using acc_t = at::acc_type; + const int channelStride = kernelWidth * kernelHeight; + + // Each Block is responsible for accumulating over a permutation of + // (channels x kH x kW), use blockIdx to determine which one + int bidx = blockIdx.x; + int kW = bidx % kernelWidth; + int kH = (bidx / kernelWidth) % kernelHeight; + int ch = (bidx / channelStride); + + // Need to calculate which input channel is associated with this filter + // channel + int inputCh = ch / depthwiseMultiplier; + + acc_t grad(0); + + const int laneId = threadIdx.x % C10_WARP_SIZE; + const int batch = threadIdx.x / C10_WARP_SIZE; + const int nwarps = blockDim.x / C10_WARP_SIZE; + const int imageElements = outputWidth * outputHeight; + // Use warp per item. In the original kernel, a threadblock was used to sum over NHW. + // Here, we use a warp to sum values over HW dimension, and if batchSize is larger than the + // number of warps, a warp would loop over remaining batch items (e.g. if there are 8 warps, + // warp 0 would go over 0-8-16 etc image, warp 1 over 1-9-17 etc). Later in blockReduce, + // all the warps will be reduced anyway, thus the full reduction will be over NHW, like it + // should be. That allows to get rid of one modulo operation inside the loop (because n/batchIdx + // now does not have to be computed through modulo, you are just looping over it), and + // bring a nice speed-up. 
+ for (int batchIdx = batch; batchIdx < batchSize; batchIdx += nwarps){ + // Warp-stride loop over elements in a batch item + for (index_t idx = laneId; idx < imageElements; idx += C10_WARP_SIZE) { + // Need to calculate the following: batch position, and offset into the grad_output + // in height, and width. We can intuit the corresponding position in the input from + // the other parameters we have + int go_w_offset = idx % outputWidth; + int go_h_offset = (idx / outputWidth); + + int i_w_offset = (go_w_offset * strideWidth) + (kW * dilationWidth) - padWidth; + int i_h_offset = (go_h_offset * strideHeight) + (kH * dilationHeight) - padHeight; + + if (i_w_offset >= 0 && i_h_offset >= 0 && i_w_offset < inputWidth && i_h_offset < inputHeight) { + int inputOffset = ((batchIdx * inputChannels + inputCh) * inputHeight + i_h_offset) * inputWidth + i_w_offset; + int outputOffset = ((batchIdx * kernelChannels + ch) * outputHeight ) * outputWidth + idx; + grad += (static_cast(input.data()[inputOffset]) * + static_cast(grad_output.data()[outputOffset])); + } + } + } + + // At this point each thread in the block has a local gradient, which we need to + // accumulate prior to writing the global value + extern __shared__ char smem[]; + acc_t* buf = reinterpret_cast(smem); + acc_t tval = zoom_utils::BlockReduceSum(grad, buf); + + // After reduction, first thread in the block has the gradient, so its responsible + // for writing it to grad_weight + if (threadIdx.x == 0) { + int weightOffset = kW + (kernelWidth * kH) + (kernelWidth * kernelHeight * ch); + grad_weight.data()[weightOffset] = static_cast(tval); + } +} + +void conv_depthwise2d_forward_out( + const Tensor &input, + const Tensor &output, + const Tensor &weight, + const Tensor &bias, + const int kW, const int kH, + const int dW, const int dH, + const int padW, const int padH, + const int dilationW, const int dilationH) { + // Only handle 4D Input Tensors for now + TORCH_CHECK(input.numel() > 0 && input.dim() == 4); + TORCH_CHECK(weight.numel() > 0 && weight.dim() == 4); + TORCH_CHECK(output.is_contiguous()); + + auto in_sizes = input.sizes(); + auto w_sizes = weight.sizes(); + + // We assume that the input and weight Tensors are shaped properly by + // the caller, so we verify that here to some extent + + // Weight Tensor is shape (output_channels, 1, kH, kW) + TORCH_CHECK(w_sizes[1] == 1); + + // Input Tensor is shape (N, input_channels, H, W) + // We verify that the # of output_channels is a multiple of input_channels + TORCH_CHECK(w_sizes[0] % in_sizes[1] == 0); + + // Bias has same # of channels as output + const bool has_bias = bias.defined(); + TORCH_CHECK(!has_bias || (bias.dim() <= 1 && bias.numel() == w_sizes[0])); + + // Following the behavior of other THCUNN functions, we shape the output + // Tensor ourselves + int64_t height = in_sizes[2]; + int64_t width = in_sizes[3]; + int64_t outputChannels = w_sizes[0]; + auto out_sizes = conv_output_size(in_sizes, weight.sizes(), {padH, padW}, {dH, dW}, + {dilationH, dilationW}); + const auto outputWidth = out_sizes[3]; + const auto outputHeight = out_sizes[2]; + + resize_output(output, out_sizes); + + int64_t inputChannels = in_sizes[1]; + int64_t depthwiseMultiplier = outputChannels / inputChannels; + + // One thread per output value + TORCH_CHECK(canUse32BitIndexMath(input) && canUse32BitIndexMath(output)); + int32_t n = output.numel(); + int blocks = GET_BLOCKS(n); + dim3 grid(blocks); + dim3 block(HIP_NUM_THREADS); + const auto stream = c10::zoom::getCurrentZoomStream(); + + 
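+  // Note (editorial): the dispatch below instantiates the kernel for the common 1x1, 3x3 and
+  // 5x5 filter sizes so that KW_LIMIT/KH_LIMIT become compile-time constants and the inner
+  // loops can be fully unrolled; every other filter size falls through to the generic variant,
+  // which reads the kernel extents at runtime.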
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), + "conv_depthwise2d_forward_zoom", [&] { + // Create PackedTensorAccessor + // Kernel currently relies upon all the Tensors to be contiguous, but we made + // them contiguous above + const auto input_a = input.packed_accessor32(); + const auto weight_a = weight.packed_accessor32(); + const auto output_a = output.packed_accessor32(); + const auto bias_a = has_bias ? + bias.packed_accessor32() : + dummy_packed_accessor32(); + if (kW == 5 && kH == 5) { + conv_depthwise2d_forward_kernel<5> <<>>( + input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier, + width, height, outputWidth, outputHeight, + kW, kH, dW, dH, padW, padH, dilationW, dilationH); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else if (kW == 3 && kH == 3) { + conv_depthwise2d_forward_kernel<3> <<>>( + input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier, + width, height, outputWidth, outputHeight, + kW, kH, dW, dH, padW, padH, dilationW, dilationH); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else if (kW == 1 && kH == 1) { + conv_depthwise2d_forward_kernel<1> <<>>( + input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier, + width, height, outputWidth, outputHeight, + kW, kH, dW, dH, padW, padH, dilationW, dilationH); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else { + conv_depthwise2d_forward_kernel_generic<<>>( + input_a, output_a, weight_a, bias_a, has_bias, n, outputChannels, depthwiseMultiplier, + width, height, outputWidth, outputHeight, + kW, kH, dW, dH, padW, padH, dilationW, dilationH); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + }); +} + +void conv_depthwise2d_backward_out( + const Tensor &input, + const Tensor &grad_output, + const Tensor &grad_input, + const Tensor &weight, + const int kW, const int kH, + const int dW, const int dH, + const int padW, const int padH, + const int dilationW, const int dilationH) { + // Only handle 4D Input Tensors for now + TORCH_CHECK(input.numel() > 0 && input.dim() == 4); + TORCH_CHECK(weight.numel() > 0 && weight.dim() == 4); + TORCH_CHECK(grad_output.numel() > 0 && grad_output.dim() == 4); + + // Minimal shape checking, as above + // Same # of elements in batch + TORCH_CHECK(input.sizes()[0] == grad_output.sizes()[0]); + // Same # of filters as outputChannels + TORCH_CHECK(weight.sizes()[0] == grad_output.sizes()[1]); + + // Resize Grainput_a + auto in_sizes = input.sizes(); + resize_output(grad_input, in_sizes); + + int inputChannels = in_sizes[1]; + int height = in_sizes[2]; + int width = in_sizes[3]; + + auto gO_sizes = grad_output.sizes(); + int outputChannels = gO_sizes[1]; + int outputHeight = gO_sizes[2]; + int outputWidth = gO_sizes[3]; + + int depthwiseMultiplier = outputChannels / inputChannels; + + // Kernel currently relies upon all the Tensors to be contiguous + TORCH_CHECK(grad_output.is_contiguous()); + TORCH_CHECK(weight.is_contiguous()); + TORCH_CHECK(grad_input.is_contiguous()); + + // One thread per grainput_a value + TORCH_CHECK(canUse32BitIndexMath(grad_input) && + canUse32BitIndexMath(grad_output)); + int32_t n = grad_input.numel(); + int blocks = GET_BLOCKS(n); + dim3 grid(blocks); + dim3 block(HIP_NUM_THREADS); + const auto stream = c10::zoom::getCurrentZoomStream(); + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, grad_output.scalar_type(), + "conv_depthwise2d_backward_zoom", [&] { + auto grad_output_a = grad_output.packed_accessor32(); + auto grad_input_a = grad_input.packed_accessor32(); + auto weight_a = 
weight.packed_accessor32<const scalar_t, 4, DefaultPtrTraits>();
+
+    if (kW == 3 && kH == 3) {
+      if (dW == 1 && dH == 1){
+        conv_depthwise2d_backward_kernel<3, 1><<<grid, block, 0, stream>>>(
+            grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
+            height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+        C10_ZOOM_KERNEL_LAUNCH_CHECK();
+      } else if (dW == 2 && dH == 2) {
+        conv_depthwise2d_backward_kernel<3, 2><<<grid, block, 0, stream>>>(
+            grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
+            height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+        C10_ZOOM_KERNEL_LAUNCH_CHECK();
+      } else {
+        conv_depthwise2d_backward_kernel<3, 0><<<grid, block, 0, stream>>>(
+            grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
+            height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+        C10_ZOOM_KERNEL_LAUNCH_CHECK();
+      }
+    } else if (kW == 1 && kH == 1) {
+      if (dW == 1 && dH == 1){
+        conv_depthwise2d_backward_kernel<1, 1><<<grid, block, 0, stream>>>(
+            grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
+            height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+        C10_ZOOM_KERNEL_LAUNCH_CHECK();
+      } else if (dW == 2 && dH == 2) {
+        conv_depthwise2d_backward_kernel<1, 2><<<grid, block, 0, stream>>>(
+            grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
+            height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+        C10_ZOOM_KERNEL_LAUNCH_CHECK();
+      } else {
+        conv_depthwise2d_backward_kernel<1, 0><<<grid, block, 0, stream>>>(
+            grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
+            height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+        C10_ZOOM_KERNEL_LAUNCH_CHECK();
+      }
+    } else if (dW == 1 && dH == 1) {
+      conv_depthwise2d_backward_kernel<0, 1><<<grid, block, 0, stream>>>(
+          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
+          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+      C10_ZOOM_KERNEL_LAUNCH_CHECK();
+    } else if (dW == 2 && dH == 2) {
+      conv_depthwise2d_backward_kernel<0, 2><<<grid, block, 0, stream>>>(
+          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
+          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+      C10_ZOOM_KERNEL_LAUNCH_CHECK();
+    } else {
+      conv_depthwise2d_backward_kernel<0, 0><<<grid, block, 0, stream>>>(
+          grad_output_a, grad_input_a, weight_a, n, inputChannels, depthwiseMultiplier, outputChannels, width,
+          height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+      C10_ZOOM_KERNEL_LAUNCH_CHECK();
+    }
+  });
+}
+
+// Crude benchmarks suggest 256 is better than 512 and 1024
+// TODO: Autotune/use better heuristics, improve speed more.
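+// Worked example (illustrative): with the 64-lane wavefront reported by at::zoom::warp_size()
+// on the AMD GPUs this backend targets, a batch of 3 requests min(3 * 64, 256) = 192 threads,
+// i.e. one warp per batch item; a batch of 4 or more saturates the 256-thread cap, and warps
+// then loop over the remaining batch items inside conv_depthwise2d_grad_weight_kernel.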
+int getGradParamsNumThreads(int batchSize) { + //warp per item in a batch, up to a maximum + constexpr int MAX_BLOCK_SIZE = 256; + return std::min(batchSize * at::zoom::warp_size(), MAX_BLOCK_SIZE); +} + +void conv_depthwise2d_grad_weight_out( + const Tensor &input, + const Tensor &grad_output, + const Tensor &grad_weight, + const int kW, const int kH, + const int dW, const int dH, + const int padW, const int padH, + const int dilationW, const int dilationH) { + // Only handle 4D Input Tensors for now + TORCH_CHECK(input.numel() > 0 && input.dim() == 4); + TORCH_CHECK(grad_output.numel() > 0 && grad_output.dim() == 4); + + // Minimal shape checking as above + // Same # of elements in batch + TORCH_CHECK(input.sizes()[0] == grad_output.sizes()[0]); + + auto in_sizes = input.sizes(); + int batchSize = in_sizes[0]; + int inputChannels = in_sizes[1]; + int height = in_sizes[2]; + int width = in_sizes[3]; + + auto gO_sizes = grad_output.sizes(); + int outputChannels = gO_sizes[1]; + int outputHeight = gO_sizes[2]; + int outputWidth = gO_sizes[3]; + + int depthwiseMultiplier = outputChannels / inputChannels; + + resize_output(grad_weight, {outputChannels, 1, kH, kW}); + + // Kernel currently relies upon all the Tensors to be contiguous + TORCH_CHECK(grad_output.is_contiguous()); + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(grad_weight.is_contiguous()); + + // We parallelize so that each block computes a single value in grad_weight + TORCH_CHECK(canUse32BitIndexMath(input) && + canUse32BitIndexMath(grad_output)); + int blocks = outputChannels * kH * kW; + + // Make sure we have enough threads to perform the reduction, and use this number + // to create the shared memory size for the reduction + dim3 grid(blocks); + dim3 block(getGradParamsNumThreads(batchSize)); + const auto stream = c10::zoom::getCurrentZoomStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, grad_output.scalar_type(), + "conv_depthwise2d_grad_weight_zoom", [&] { + const auto grad_output_a = grad_output.packed_accessor32(); + const auto input_a = input.packed_accessor32(); + const auto grad_weight_a = grad_weight.packed_accessor32(); + using acc_t = at::acc_type; + int warp_size = at::zoom::warp_size(); + TORCH_INTERNAL_ASSERT(block.x % warp_size == 0); + int smem = (block.x / warp_size) * sizeof(acc_t); + conv_depthwise2d_grad_weight_kernel<<>>( + grad_output_a, input_a, grad_weight_a, batchSize, inputChannels, outputChannels, depthwiseMultiplier, + width, height, outputWidth, outputHeight, kW, kH, dW, dH, padW, padH, dilationW, dilationH); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); +} + +} // namespace (anonymous) + +const Tensor& conv_depthwise2d_zoom_out( + const Tensor &input_, + const Tensor &weight_, + IntArrayRef kernel_size, + const std::optional &bias_opt, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + const Tensor &out) { + TORCH_CHECK(kernel_size.size() == 2); + TORCH_CHECK(stride.size() == 2); + TORCH_CHECK(padding.size() == 2); + TORCH_CHECK(dilation.size() == 2); + + auto input = input_.expect_contiguous(); + auto weight = weight_.expect_contiguous(); + auto bias = [&] { + if (bias_opt.has_value() && bias_opt->defined()) { + return bias_opt->expect_contiguous(); + } + return c10::MaybeOwned::owned(std::in_place); + }(); + + conv_depthwise2d_forward_out( + *input, + out, + *weight, + *bias, + kernel_size[1], kernel_size[0], + stride[1], stride[0], + padding[1], padding[0], + dilation[1], dilation[0]); + return out; +} + +Tensor conv_depthwise2d_zoom( + const Tensor &input, + 
const Tensor &weight, + IntArrayRef kernel_size, + const std::optional &bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation) { + auto out = at::empty({0}, input.options()); + return conv_depthwise2d_zoom_out(input, weight, kernel_size, bias, + stride, padding, dilation, out); +} + +std::tuple conv_depthwise2d_backward_zoom_out( + const Tensor & grad_output_, + const Tensor & self_, + const Tensor & weight_, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + Tensor & grad_input, + Tensor & grad_weight) { + auto grad_output = grad_output_.expect_contiguous(); + + if (grad_weight.defined()) { + auto self = self_.expect_contiguous(); + conv_depthwise2d_grad_weight_out( + *self, *grad_output, grad_weight, + kernel_size[1], kernel_size[0], + stride[1], stride[0], + padding[1], padding[0], + dilation[1], dilation[0]); + } + + if (grad_input.defined()) { + auto weight = weight_.expect_contiguous(); + conv_depthwise2d_backward_out( + self_, *grad_output, grad_input, *weight, + kernel_size[1], kernel_size[0], + stride[1], stride[0], + padding[1], padding[0], + dilation[1], dilation[0]); + } + return std::forward_as_tuple(grad_input, grad_weight); +} + +std::tuple conv_depthwise2d_backward_zoom( + const Tensor& grad_output, + const Tensor& self, + const Tensor& weight, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + std::array output_mask) { + Tensor grad_input; + Tensor grad_weight; + + if (output_mask[0]) { + grad_input = at::empty({0}, grad_output.options()); + } + + if (output_mask[1]) { + grad_weight = at::empty({0}, grad_output.options()); + } + return conv_depthwise2d_backward_zoom_out( + grad_output, + self, + weight, + kernel_size, + stride, + padding, + dilation, + grad_input, + grad_weight); +} + +REGISTER_PRIVATEUSE1_DISPATCH(conv_depthwise2d_backward_stub, &conv_depthwise2d_backward_zoom); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/DepthwiseConv3d.cu b/aten/src/ATen/native/zoom/DepthwiseConv3d.cu new file mode 100644 index 0000000000000..3d2cf1bb6cfd8 --- /dev/null +++ b/aten/src/ATen/native/zoom/DepthwiseConv3d.cu @@ -0,0 +1,706 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif + +#include +#include +#include + +namespace at::native { +namespace { + +template +__global__ void conv_depthwise3d_zoom_kernel( + const PackedTensorAccessor32 input, + PackedTensorAccessor32 output, + const PackedTensorAccessor32 kernel, + const scalar_t* bias, + int strideT, int strideH, int strideW, + int paddingT, int paddingH, int paddingW, + int dilationT_, int dilationH_, int dilationW_) +{ + const int kT = kKnownKernelT > 0 ? kKnownKernelT : kernel.size(2); + const int kH = kKnownKernelH > 0 ? kKnownKernelH : kernel.size(3); + const int kW = kKnownKernelW > 0 ? kKnownKernelW : kernel.size(4); + const int oC = output.size(1); + const int oT = output.size(2); + const int oH = output.size(3); + const int oW = output.size(4); + const int iC = input.size(1); + const int iT = input.size(2); + const int iH = input.size(3); + const int iW = input.size(4); + const int channel_multiplier = oC / iC; + const int dilationT = kKnownDilationT > 0 ? kKnownDilationT : dilationT_; + const int dilationH = kKnownDilationH > 0 ? kKnownDilationH : dilationH_; + const int dilationW = kKnownDilationW > 0 ? 
kKnownDilationW : dilationW_; + const int num_output = output.size(0) * output.stride(0); + + HIP_KERNEL_LOOP(index, num_output) { + const int out_col = index % oW; + const int out_row = (index / oW) % oH; + const int out_frame = (index / oW / oH) % oT; + const int out_channel = (index / oW / oH / oT) % oC; + const int batch = index / oW / oH / oT / oC; + + const int in_channel = out_channel / channel_multiplier; + + const int in_col_start = out_col * strideW - paddingW; + const int in_row_start = out_row * strideH - paddingH; + const int in_frame_start = out_frame * strideT - paddingT; + + accscalar_t sum = 0; + const scalar_t *kernel_ptr = kernel[out_channel].data(); + const scalar_t *input_ptr = + &input[batch][in_channel][in_frame_start][in_row_start][in_col_start]; + for (int k_frame = 0; k_frame < kT; ++k_frame) { + const int in_frame = in_frame_start + k_frame * dilationT; + for (int k_row = 0; k_row < kH; ++k_row) { + const int in_row = in_row_start + k_row * dilationH; + for (int k_col = 0; k_col < kW; ++k_col) { + const accscalar_t op1 = *(kernel_ptr++); + const int in_col = in_col_start + k_col * dilationW; + if (in_frame >= 0 && in_row >= 0 && in_col >= 0 && + in_frame < iT && in_row < iH && in_col < iW) { + sum += op1 * *(input_ptr); + } + input_ptr += dilationW; + } + input_ptr += iW * dilationH - kW * dilationW; + } + input_ptr += iW * (iH * dilationT - kH * dilationH); + } + if (bias != NULL) { + sum += bias[out_channel]; + } + + output[batch][out_channel][out_frame][out_row][out_col] = sum; + } +} + +template +__global__ void +conv_depthwise3d_zoom_backward_input_kernel( + const PackedTensorAccessor32 grad_output, + PackedTensorAccessor32 grad_input, + const PackedTensorAccessor32 kernel, + int strideT_, int strideH_, int strideW_, + int paddingT, int paddingH, int paddingW, + int dilationT_, int dilationH_, int dilationW_) { + const int kT = kKnownKernelT > 0 ? kKnownKernelT : kernel.size(2); + const int kH = kKnownKernelH > 0 ? kKnownKernelH : kernel.size(3); + const int kW = kKnownKernelW > 0 ? kKnownKernelW : kernel.size(4); + const int oC = grad_output.size(1); + const int oT = grad_output.size(2); + const int oH = grad_output.size(3); + const int oW = grad_output.size(4); + const int iC = grad_input.size(1); + const int iT = grad_input.size(2); + const int iH = grad_input.size(3); + const int iW = grad_input.size(4); + const int channel_multiplier = oC / iC; + const int dilationT = kKnownDilationT > 0 ? kKnownDilationT : dilationT_; + const int dilationH = kKnownDilationH > 0 ? kKnownDilationH : dilationH_; + const int dilationW = kKnownDilationW > 0 ? kKnownDilationW : dilationW_; + const int strideT = kKnownStrideT > 0 ? kKnownStrideT : strideT_; + const int strideH = kKnownStrideH > 0 ? kKnownStrideH : strideH_; + const int strideW = kKnownStrideW > 0 ? 
kKnownStrideW : strideW_; + const int num_input = grad_input.size(0) * grad_input.stride(0); + + HIP_KERNEL_LOOP(index, num_input) { + const int in_col = index % iW; + const int in_row = (index / iW) % iH; + const int in_frame = (index / iW / iH) % iT; + const int in_channel = (index / iW / iH / iT) % iC; + const int batch = index / iW / iH / iT / iC; + + const int out_col_end = in_col + paddingW; + const int out_row_end = in_row + paddingH; + const int out_frame_end = in_frame + paddingT; + + const scalar_t* kernel_ptr = kernel[in_channel * channel_multiplier].data(); + accscalar_t sum = 0; + + for (int k_chn = in_channel * channel_multiplier; + k_chn < (in_channel + 1) * channel_multiplier; + ++k_chn) { + const scalar_t* gout_ptr = grad_output[batch][k_chn].data(); + + for (int k_frame = 0; k_frame < kT; ++k_frame) { + const int out_frame_raw = out_frame_end - k_frame * dilationT; + const int out_frame = out_frame_raw / strideT; + for (int k_row = 0; k_row < kH; ++k_row) { + const int out_row_raw = out_row_end - k_row * dilationH; + const int out_row = out_row_raw / strideH; + for (int k_col = 0; k_col < kW; ++k_col) { + const accscalar_t op1 = *(kernel_ptr++); + const int out_col_raw = out_col_end - k_col * dilationW; + const int out_col = out_col_raw / strideW; + + const int out_offs = (out_frame * oH + out_row) * oW + out_col; + + accscalar_t op2 = (accscalar_t)0; + if (out_col >= 0 && out_row >= 0 && out_frame >= 0 && + out_col < oW && out_row < oH && out_frame < oT) { + op2 = *(gout_ptr + out_offs); + } + if (out_frame * strideT == out_frame_raw && + out_row * strideH == out_row_raw && + out_col * strideW == out_col_raw) { + sum += op1 * op2; + } + } + } + } + } + + grad_input[batch][in_channel][in_frame][in_row][in_col] = sum; + } +} + +template +__global__ void +conv_depthwise3d_zoom_backward_weight_kernel( + const PackedTensorAccessor32 grad_output, + const PackedTensorAccessor32 input, + PackedTensorAccessor32 grad_kernel, + int strideT, int strideH_, int strideW_, + int paddingT, int paddingH, int paddingW, + int dilationT, int dilationH, int dilationW) { + const int kC = grad_kernel.size(0); + const int kT = grad_kernel.size(2); + const int kH = grad_kernel.size(3); + const int kW = grad_kernel.size(4); + + const int strideH = kKnownStrideH > 0 ? kKnownStrideH : strideH_; + const int strideW = kKnownStrideW > 0 ? 
kKnownStrideW : strideW_; + + const int k_col = blockIdx.x % kW; + const int k_row = (blockIdx.x / kW) % kH; + const int k_frame = (blockIdx.x / kW / kH) % kT; + const int k_channel = blockIdx.x / kW / kH / kT; + scalar_t *result = &grad_kernel[k_channel][0][k_frame][k_row][k_col]; + + const int oT = grad_output.size(2); + const int oH = grad_output.size(3); + const int oW = grad_output.size(4); + const int iT = input.size(2); + const int iH = input.size(3); + const int iW = input.size(4); + const int channel_multiplier = grad_output.size(1) / input.size(1); + const int in_channel = k_channel / channel_multiplier; + + extern __shared__ int sdata_raw[]; + scalar_t* sdata = reinterpret_cast(sdata_raw); + + if (k_channel >= kC) { + return; + } + + const int laneid = threadIdx.x % C10_WARP_SIZE; + const int warpid = threadIdx.x / C10_WARP_SIZE; + const int nwarps = blockDim.x / C10_WARP_SIZE; + + accscalar_t grad = 0; + int batch = warpid / oT; + int gout_frame = warpid - batch * oT; + for (int outer_pos = warpid; outer_pos < input.size(0) * oT; + outer_pos += nwarps, gout_frame += nwarps) { + while (gout_frame >= oT) { gout_frame -= oT; batch ++; } + + const int in_frame = (gout_frame * strideT) + (k_frame * dilationT) - paddingT; + + if (in_frame < 0 || in_frame >= iT) { + continue; + } + + const scalar_t* gout_ptr = grad_output[batch][k_channel][gout_frame].data() + laneid; + const scalar_t* input_ptr = input[batch][in_channel][in_frame].data(); + + int gout_row = laneid / oW; + int gout_col = laneid - gout_row * oW; + + for (; gout_row < oH; ) { + const accscalar_t op1 = *(gout_ptr); + gout_ptr += C10_WARP_SIZE; + + const int in_col = (gout_col * strideW) + (k_col * dilationW) - paddingW; + const int in_row = (gout_row * strideH) + (k_row * dilationH) - paddingH; + const int in_pos = in_row * iW + in_col; + + accscalar_t op2 = (accscalar_t)0; + if (in_col >= 0 && in_col < iW && in_row >= 0 && in_row < iH) { + op2 = *(input_ptr + in_pos); + } + + gout_col += C10_WARP_SIZE; + while (gout_col >= oW) { + gout_col -= oW; gout_row ++; + } + + grad += op1 * op2; + } + } + + sdata[threadIdx.x] = grad; + __syncthreads(); + + ZOOM_KERNEL_ASSERT(__popc(blockDim.x) == 1); +#pragma unroll + for (int i = blockDim.x / 2; i >= 1; i >>= 1) { + if (threadIdx.x < i) { + sdata[threadIdx.x] += sdata[threadIdx.x + i]; + } + __syncthreads(); + } + + if (threadIdx.x == 0) { + *result = sdata[0]; + } +} + +template +void conv_depthwise_shape_check( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + const Tensor& grad_output, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation) { + TORCH_CHECK(kernel_size.size() == dim, + "kernel size length should be ", dim, ", but got ", kernel_size.size()); + TORCH_CHECK(stride.size() == dim, + "stride length should be ", dim, ", but got ", stride.size()); + TORCH_CHECK(padding.size() == dim, + "padding length should be ", dim, ", but got ", padding.size()); + TORCH_CHECK(dilation.size() == dim, + "dilation length should be ", dim, ", but got ", dilation.size()); + + TORCH_CHECK(weight.defined(), + "Weight must be defined."); + TORCH_CHECK(input.dim() == dim + 1 || input.dim() == dim + 2, + "Input dimension should be ", + dim + 1, "D or ", dim + 2, "D, got ", + input.dim(), "D"); + TORCH_CHECK(weight.dim() == dim + 2, + "Weight dimension should be ", dim + 2, "D, got ", weight.dim(), "D"); + TORCH_CHECK(weight.size(1) == 1, + "Depthwise weight should have in_channels=1, got ", weight.size(1)); + 
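+  // Illustrative note: for the 3-D depthwise case the weight is laid out as
+  // (out_channels, 1, kT, kH, kW) with out_channels = in_channels * channel_multiplier, e.g.
+  // in_channels = 8 and a multiplier of 2 gives a (16, 1, kT, kH, kW) weight; the divisibility
+  // check below enforces exactly that relationship.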
TORCH_CHECK(weight.size(0) % input.size(-dim - 1) == 0, + "Depthwise out channels should be a multiple of in channels, got ", + weight.size(0), " and ", input.size(-dim - 1)); + for (int i = 0; i < dim; ++i) { + TORCH_CHECK(weight.size(i + 2) == kernel_size[i], + "kernel size and weight size mismatch, got ", + kernel_size, " and ", weight.sizes()); + TORCH_CHECK(stride[i] >= 1, + "stride should be at least 1, got ", stride); + TORCH_CHECK(padding[i] >= 0, + "padding should be non-negative, got ", padding); + TORCH_CHECK(dilation[i] >= 1, + "dilation should be at least 1, got ", dilation); + } + + if (bias.defined()) { + TORCH_CHECK(bias.dim() == 1, + "Bias should be 1D tensor, got ", bias.dim(), "D"); + TORCH_CHECK(bias.size(0) == weight.size(0), + "Bias length should be equal to out_channels, got ", + bias.size(0), " and ", weight.size(0)); + } + + if (grad_output.defined()) { + auto expected_output_size = conv_output_size(input.sizes(), weight.sizes(), + padding, stride, dilation); + TORCH_CHECK(static_cast(grad_output.dim()) == expected_output_size.size(), + "Expect grad_output to be ", + expected_output_size.size(), "D, got ", + grad_output.dim(), "D."); + for (int i = 0; i < grad_output.dim(); ++i) { + TORCH_CHECK(grad_output.size(i) == expected_output_size[i], + "Expect grad_output to be of same shape as output, got ", + grad_output.size(i), " and ", expected_output_size[i], + " at dimension ", i); + } + } +} + +} + +#define NODEF_OR_EQUAL(x, y) ((y) < 0 || (x) == (y)) +#define NODEF_OR_EQUAL_3(x, y1, y2, y3) \ + (NODEF_OR_EQUAL(x[0], y1) && \ + NODEF_OR_EQUAL(x[1], y2) && \ + NODEF_OR_EQUAL(x[2], y3)) + +#define DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION(kt, kh, kw, dilt, dilh, dilw) \ + if (NODEF_OR_EQUAL_3(kernel_size, (kt), (kh), (kw)) && \ + NODEF_OR_EQUAL_3(dilation, (dilt), (dilh), (dilw))) { \ + using accscalar_t = acc_type; \ + conv_depthwise3d_zoom_kernel \ + \ + <<>>( \ + input_.packed_accessor32(), \ + output_.packed_accessor32(), \ + weight_.packed_accessor32(), \ + bias_ptr, \ + stride[0], stride[1], stride[2], \ + padding[0], padding[1], padding[2], \ + dilation[0], dilation[1], dilation[2]); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); \ + } else + +#define DWCONV3D_FORWARD_DISPATCH_OTHERS \ + { \ + using accscalar_t = acc_type; \ + conv_depthwise3d_zoom_kernel \ + \ + <<>>( \ + input_.packed_accessor32(), \ + output_.packed_accessor32(), \ + weight_.packed_accessor32(), \ + bias_ptr, \ + stride[0], stride[1], stride[2], \ + padding[0], padding[1], padding[2], \ + dilation[0], dilation[1], dilation[2]); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); \ + } + +Tensor conv_depthwise3d_zoom( + const Tensor& input, + const Tensor& weight, + IntArrayRef kernel_size, const std::optional& bias_opt, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + + TORCH_CHECK(input.device() == weight.device(), "expects input and weight tensors to be on the same device."); + if (bias.defined()) { + TORCH_CHECK(input.device() == bias.device(), "expects input and bias tensors to be on the same device."); + } + + conv_depthwise_shape_check<3>(input, weight, bias, Tensor() /* undefined */, + kernel_size, stride, padding, dilation); + + Tensor input_ = input.contiguous(); + + if (input.dim() == 4 /* no batch */) { + input_ = input.unsqueeze(0); + } + + auto output_size = conv_output_size(input_.sizes(), 
weight.sizes(), + padding, stride, dilation); + for (size_t i = 0; i < output_size.size(); ++i) { + TORCH_CHECK(output_size[i] > 0, + "Output size should be positive, got ", output_size[i], " at dim ", i); + } + Tensor output = at::empty(output_size, input.options()); + Tensor output_ = output; + Tensor weight_ = weight.contiguous(); + Tensor bias_ = bias.defined() ? bias.contiguous() : bias; + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + input.scalar_type(), + "conv_depthwise3d", + [&]{ + int64_t num_outputs = output_.numel(); + int64_t block = 256; + int64_t grid = std::min((num_outputs - 1) / block + 1, (int64_t)65536); + int64_t smem = 0; + + const scalar_t* bias_ptr = + bias_.defined() ? bias_.const_data_ptr() : NULL; + + // Range check to avoid overflow in zoom kernels. + TORCH_CHECK(input_.numel() <= std::numeric_limits::max(), + "Input tensor is too large."); + TORCH_CHECK(output_.numel() <= std::numeric_limits::max(), + "Output tensor is too large."); + TORCH_CHECK(weight_.numel() <= std::numeric_limits::max(), + "Weight tensor is too large."); + for (int i = 0; i < 3; ++i) { + TORCH_CHECK(padding[i] * 2 + input.size(i + 2) <= std::numeric_limits::max(), + "Padded input tensor is too large."); + } + + DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION(3, 3, 3, 1, 1, 1) + DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION(-1, -1, -1, 1, 1, 1) + DWCONV3D_FORWARD_DISPATCH_OTHERS + } + ); + + return output; +} + +#undef DWCONV3D_FORWARD_DISPATCH_SPECIALIZATION +#undef DWCONV3D_FORWARD_DISPATCH_OTHERS + +#define DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION( \ + kt, kh, kw, dilt, dilh, dilw, dt, dh, dw) \ + if (NODEF_OR_EQUAL_3(kernel_size, (kt), (kh), (kw)) && \ + NODEF_OR_EQUAL_3(dilation, (dilt), (dilh), (dilw)) && \ + NODEF_OR_EQUAL_3(stride, (dt), (dh), (dw))) { \ + using accscalar_t = acc_type; \ + conv_depthwise3d_zoom_backward_input_kernel \ + \ + <<>>( \ + grad_output_.packed_accessor32(), \ + grad_input_.packed_accessor32(), \ + weight_.packed_accessor32(), \ + stride[0], stride[1], stride[2], \ + padding[0], padding[1], padding[2], \ + dilation[0], dilation[1], dilation[2]); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); \ + } else + +#define DWCONV3D_BACKWARD_INPUT_DISPATCH_OTHERS \ + { \ + using accscalar_t = acc_type; \ + conv_depthwise3d_zoom_backward_input_kernel \ + \ + <<>>( \ + grad_output_.packed_accessor32(), \ + grad_input_.packed_accessor32(), \ + weight_.packed_accessor32(), \ + stride[0], stride[1], stride[2], \ + padding[0], padding[1], padding[2], \ + dilation[0], dilation[1], dilation[2]); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); \ + } + +#define DWCONV3D_BACKWARD_WEIGHT_DISPATCH_SPECIALIZATION(dh, dw) \ + if (NODEF_OR_EQUAL_3(stride, -1, (dh), (dw))) { \ + using accscalar_t = acc_type; \ + conv_depthwise3d_zoom_backward_weight_kernel \ + \ + <<>>( \ + grad_output_.packed_accessor32(), \ + input_.packed_accessor32(), \ + grad_weight.packed_accessor32(), \ + stride[0], stride[1], stride[2], \ + padding[0], padding[1], padding[2], \ + dilation[0], dilation[1], dilation[2]); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); \ + } else + +#define DWCONV3D_BACKWARD_WEIGHT_DISPATCH_OTHERS \ + { \ + using accscalar_t = acc_type; \ + conv_depthwise3d_zoom_backward_weight_kernel \ + \ + <<>>( \ + grad_output_.packed_accessor32(), \ + input_.packed_accessor32(), \ + grad_weight.packed_accessor32(), \ + stride[0], stride[1], stride[2], \ + padding[0], padding[1], padding[2], \ + dilation[0], dilation[1], dilation[2]); \ + C10_ZOOM_KERNEL_LAUNCH_CHECK(); \ + } + +std::tuple 
_depthwise_3d_backward_zoom_out( + Tensor& grad_input, + Tensor& grad_weight, + Tensor& grad_bias, + const Tensor& grad_output, + const Tensor& input, + const Tensor& weight, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + const std::array output_mask) +{ + + TORCH_CHECK(grad_output.device() == input.device() && + input.device() == weight.device(), + "expects input, weight and grad_output to be on the same device."); + conv_depthwise_shape_check<3>( + input, weight, Tensor() /* undefined */, grad_output, + kernel_size, stride, padding, dilation); + + const Tensor grad_output_ = grad_output.contiguous(); + + Tensor grad_input_ = + (output_mask[0] ? grad_input + : Tensor()); + + if (output_mask[0]) { + const Tensor weight_ = weight.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + grad_output.scalar_type(), + "conv_depthwise3d", + [&] { + int64_t num_inputs = grad_input_.numel(); + int64_t block = 256; + int64_t grid = std::min((num_inputs - 1) / block + 1, (int64_t)65536); + + // Range check to avoid overflow in zoom kernels. + TORCH_CHECK(grad_input_.numel() <= std::numeric_limits::max(), + "Input tensor is too large."); + TORCH_CHECK(grad_output_.numel() <= std::numeric_limits::max(), + "Output tensor is too large."); + TORCH_CHECK(weight_.numel() <= std::numeric_limits::max(), + "Weight tensor is too large."); + for (int i = 0; i < 3; ++i) { + TORCH_CHECK(padding[i] * 2 + input.size(i + 2) <= std::numeric_limits::max(), + "Padded input tensor is too large."); + } + + DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION( + 3, 3, 3, 1, 1, 1, 1, 1, 1) + DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION( + 3, 3, 3, 1, 1, 1, -1, -1, -1) + DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION( + 3, 3, 3, -1, -1, -1, 1, 1, 1) + DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION( + 3, 3, 3, -1, -1, -1, -1, -1, -1) + DWCONV3D_BACKWARD_INPUT_DISPATCH_OTHERS + } + ); + } + + if (output_mask[1]) { + const Tensor input_ = input.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + grad_output.scalar_type(), + "conv_depthwise3d", + [&] { + int64_t grid = grad_weight.numel(); + int64_t block = 256; + int64_t smem = sizeof(scalar_t) * block; + + const int64_t int_max = std::numeric_limits::max(); + TORCH_CHECK(grad_input_.numel() <= int_max, + "Input tensor is too large."); + TORCH_CHECK(grad_output_.numel() <= int_max, + "Output tensor is too large."); + TORCH_CHECK(weight.numel() <= int_max, + "Weight tensor is too large."); + for (int i = 0; i < 3; ++i) { + TORCH_CHECK(padding[i] * 2 + input.size(i + 2) <= int_max, + "Padded input tensor is too large."); + } + int64_t warp_size = at::zoom::warp_size(); + TORCH_CHECK(grad_output_.size(0) * grad_output_.size(2) < int_max - block / warp_size && + grad_output_.size(3) <= int_max - warp_size && + grad_output_.size(4) <= int_max - warp_size, + "Output size is too large."); + + DWCONV3D_BACKWARD_WEIGHT_DISPATCH_SPECIALIZATION(1, 1) + DWCONV3D_BACKWARD_WEIGHT_DISPATCH_SPECIALIZATION(2, 2) + DWCONV3D_BACKWARD_WEIGHT_DISPATCH_OTHERS + } + ); + } + + if (output_mask[2]) { + grad_bias = grad_output.sum({0, 2, 3, 4}); + } + + return std::tie(grad_input, grad_weight, grad_bias); + +} + + +std::tuple conv_depthwise3d_backward_zoom_out(const Tensor& grad_output, + const Tensor& input, + const Tensor& weight, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + Tensor& grad_input, + Tensor& grad_weight, + Tensor& grad_bias) { + if 
(grad_weight.defined()) { + grad_weight.resize_(weight.sizes()); + grad_weight.zero_(); + } + + return _depthwise_3d_backward_zoom_out( + grad_input, + grad_weight, + grad_bias, + grad_output, + input, + weight, + kernel_size, + stride, + padding, + dilation, + {true,true,true}); +} + +std::tuple conv_depthwise3d_backward_zoom( + const Tensor& grad_output, + const Tensor& input, + const Tensor& weight, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + const std::array output_mask) { + + auto options = grad_output.options(); + Tensor grad_input = + (output_mask[0] ? at::empty(input.sizes(), options) : Tensor()); + Tensor grad_weight = + (output_mask[1] ? at::empty(weight.sizes(), options) : Tensor()); + Tensor grad_bias; /* undefined temporarily */ + + return _depthwise_3d_backward_zoom_out( + grad_input, + grad_weight, + grad_bias, + grad_output, + input, + weight, + kernel_size, + stride, + padding, + dilation, + output_mask + ); + +} + +REGISTER_PRIVATEUSE1_DISPATCH(conv_depthwise3d_backward_stub, &conv_depthwise3d_backward_zoom); + +#undef DWCONV3D_BACKWARD_INPUT_DISPATCH_SPECIALIZATION +#undef DWCONV3D_BACKWARD_INPUT_DISPATCH_OTHERS + +#undef NODEF_OR_EQUAL_3 +#undef NODEF_OR_EQUAL + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/DeviceSqrt.cuh b/aten/src/ATen/native/zoom/DeviceSqrt.cuh new file mode 100644 index 0000000000000..d5833a9882fd8 --- /dev/null +++ b/aten/src/ATen/native/zoom/DeviceSqrt.cuh @@ -0,0 +1,18 @@ +#pragma once + +namespace at { namespace native { +// take these out when ROCm implements std:: math functions +#include +template +static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val); + +template <> +__forceinline__ __device__ float device_sqrt(float val) { + return ::sqrtf(val); +} + +template <> +__forceinline__ __device__ double device_sqrt(double val) { + return ::sqrt(val); +} +}} \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/DilatedMaxPool2d.cu b/aten/src/ATen/native/zoom/DilatedMaxPool2d.cu new file mode 100644 index 0000000000000..5484d357e7ba3 --- /dev/null +++ b/aten/src/ATen/native/zoom/DilatedMaxPool2d.cu @@ -0,0 +1,563 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + +namespace at::native { +namespace { + +__device__ inline int min(int a, int b) { + return a <= b ? a : b; +} + +#define HIP_MAX_THREADS 1024 // this is safe, in reality 256 is our limit + +#define BLOCK_STRIDE 2 // increasing block_stride to lower # of blocks launched + +static __device__ inline int p_start(int size, int pad, int kernel, int dilation, int stride) { + return (size + pad < ((kernel - 1) * dilation + 1)) ? 
0 : (size + pad - ((kernel - 1) * dilation + 1)) / stride + 1; +} + +static __device__ inline int p_end(int size, int pad, int pooled_size, int stride) { + return min((size + pad) / stride + 1, pooled_size); +} + +// kernels borrowed from Caffe +template +__global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom_data, + const int64_t channels, const int64_t height, + const int64_t width, const int pooled_height, const int pooled_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, scalar_t* top_data, + int64_t* top_mask) { + HIP_KERNEL_LOOP(index, nthreads) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + int hstart = ph * stride_h - pad_h; + int wstart = pw * stride_w - pad_w; + int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height); + int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width); + while(hstart < 0) + hstart += dilation_h; + while(wstart < 0) + wstart += dilation_w; + accscalar_t maxval = at::numeric_limits::lower_bound(); // -Infinity + int maxidx = hstart * width + wstart; + const scalar_t* btm_data = bottom_data + (n * channels + c) * height * width; + for (int h = hstart; h < hend; h += dilation_h) { + for (int w = wstart; w < wend; w += dilation_w) { + scalar_t val = btm_data[h * width + w]; + if ((static_cast(val) > maxval) || at::_isnan(val)) { + maxidx = h * width + w; + maxval = static_cast(val); + } + } + } + top_data[index] = static_cast(maxval); + top_mask[index] = maxidx; + } +} + +template +C10_LAUNCH_BOUNDS_1(HIP_MAX_THREADS) +__global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nbatch, + const int64_t channels, const int64_t height, + const int64_t width, const int pooled_height, const int pooled_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int in_stride_n, const int in_stride_c, + const int in_stride_h, const int in_stride_w, + const int kernel_stride_C, const int kernel_size_C, + scalar_t* top_data, int64_t* top_mask) { + extern __shared__ int smem[]; + int *out_mask_cached = smem; + scalar_t *out_cached = reinterpret_cast(&out_mask_cached[kernel_size_C*blockDim.x*blockDim.y*blockDim.z]); + + // flattening cta for pre-computation & smem initialization; + int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); + int block_size = blockDim.x * blockDim.y * blockDim.z; + + // use shared memory to store temporary output value. This is simply to + // reduce register usage. 
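+  // Layout note (editorial): the dynamic shared-memory block holds an int mask region followed
+  // by a scalar_t value region; every (threadIdx.y, threadIdx.z) row of the CTA owns
+  // kernel_size_C * blockDim.x consecutive slots, and out_cached/out_mask_cached are re-based
+  // onto that row below so each thread only ever touches its own channel slots.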
+ for (int i = thread_id; i < kernel_size_C*blockDim.x*blockDim.y*blockDim.z; i+= block_size) { + out_cached[i] = at::numeric_limits::lower_bound(); + out_mask_cached[i] = 0; + } + + __syncthreads(); + + int batch_id = blockIdx.x % nbatch; + int channel_id = blockIdx.x / nbatch; + int channel_offset = threadIdx.x + channel_id * blockDim.x; + + top_data = top_data + batch_id * pooled_height * pooled_width * channels; + top_mask = top_mask + batch_id * pooled_height * pooled_width * channels; + bottom_data = bottom_data + batch_id * in_stride_n; + + out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x]; + out_mask_cached = &out_mask_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x]; + + int oH = (pooled_height + gridDim.z-1) / gridDim.z; + int oW = (pooled_width + gridDim.y-1) / gridDim.y; + int ostartH = threadIdx.z + blockIdx.z*oH; + int oendH = ::min(ostartH+oH, pooled_height); + int ostartW = threadIdx.y + blockIdx.y*oW; + int oendW = ::min(ostartW+oW, pooled_width); + + for (int oh = ostartH; oh < oendH; oh+=blockDim.z) { + int hstart = oh * stride_h - pad_h; + int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height); + for (int ow = ostartW; ow < oendW; ow+=blockDim.y) { + int wstart = ow * stride_w - pad_w; + int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width); + while(hstart < 0) + hstart += dilation_h; + while(wstart < 0) + wstart += dilation_w; + for (int ih = hstart; ih < hend; ih += dilation_h) { + for (int iw = wstart; iw < wend; iw += dilation_w) { + int cached_index = threadIdx.x; + const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w; + for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) { + scalar_t val = ptr_input[c*in_stride_c]; + if ((static_cast(val) > out_cached[cached_index]) || at::_isnan(val)) { + out_cached[cached_index] = static_cast(val); + out_mask_cached[cached_index] = ih * width + iw; + } + cached_index += blockDim.x; + } + } + } + scalar_t *ptr_output_data = top_data + (oh * pooled_width + ow) * channels; + int64_t *ptr_output_mask = top_mask + (oh * pooled_width + ow) * channels; + + int cached_index = threadIdx.x; + for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) { + ptr_output_data[c] = out_cached[cached_index]; + ptr_output_mask[c] = out_mask_cached[cached_index]; + out_cached[cached_index] = at::numeric_limits::lower_bound(); + out_mask_cached[cached_index] = 0; + cached_index += blockDim.x; + } + } + } +} + + +static const int BLOCK_THREADS = 256; + +template +C10_LAUNCH_BOUNDS_2(BLOCK_THREADS, 4) +__global__ void max_pool_backward_nchw(const scalar_t* top_diff, + const int64_t* top_mask, const int num, const int64_t channels, + const int64_t height, const int64_t width, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + scalar_t* bottom_diff) { + HIP_KERNEL_LOOP(index, height*width) { + int h = index / width; + int w = index - h * width; + int phstart = p_start(h, pad_h, kernel_h, dilation_h, stride_h); + int phend = p_end(h, pad_h, pooled_height, stride_h); + int pwstart = p_start(w, pad_w, kernel_w, dilation_w, stride_w); + int pwend = p_end(w, pad_w, pooled_width, stride_w); + for (int n = blockIdx.y; n < num; n += gridDim.y) { + for (int c = blockIdx.z; c < channels; c+= gridDim.z) { + accscalar_t gradient = accscalar_t(0); + int 
offset = (n * channels + c) * pooled_height * pooled_width; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + if (top_mask[ph * pooled_width + pw + offset] == h * width + w) { + gradient += static_cast(top_diff[ph * pooled_width + pw + offset]); + } + } + } + bottom_diff[(n*channels+c)*height*width+index] = static_cast(gradient); + } + } + } +} + +template +C10_LAUNCH_BOUNDS_1(HIP_MAX_THREADS) +__global__ void max_pool_backward_nhwc(const scalar_t* top_diff, + const int64_t* top_mask, const int nbatch, const int64_t channels, + const int64_t height, const int64_t width, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int out_stride_c, const int out_stride_h, const int out_stride_w, + const int kernel_stride_C, const int kernel_size_C, + scalar_t* bottom_diff) { + extern __shared__ int smem[]; + accscalar_t *out_cached = reinterpret_cast(smem); + + int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); + int block_size = blockDim.x * blockDim.y * blockDim.z; + + int batch_id = blockIdx.x % nbatch; + int channel_id = blockIdx.x / nbatch; + int channel_offset = threadIdx.x + channel_id * blockDim.x; + + for (int i = thread_id; i < kernel_size_C*blockDim.x*blockDim.y*blockDim.z; i+= block_size) { + out_cached[i] = accscalar_t(0.0); + } + + __syncthreads(); + + out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x]; + + bottom_diff = bottom_diff + batch_id * height * width * channels; + top_mask = top_mask + batch_id * pooled_height * pooled_width * channels; + top_diff = top_diff + batch_id * pooled_height * pooled_width * channels; + + int iH = (height + gridDim.z-1) / gridDim.z; + int iW = (width + gridDim.y-1) / gridDim.y; + int istartH = threadIdx.z + blockIdx.z*iH; + int iendH = ::min(static_cast(istartH)+iH, height); + int istartW = threadIdx.y + blockIdx.y*iW; + int iendW = ::min(static_cast(istartW)+iW, width); + + for (int ih = istartH; ih < iendH; ih+=blockDim.z) { + int phstart = p_start(ih, pad_h, kernel_h, dilation_h, stride_h); + int phend = p_end(ih, pad_h, pooled_height, stride_h); + for (int iw = istartW; iw < iendW; iw+=blockDim.y) { + int pwstart = p_start(iw, pad_w, kernel_w, dilation_w, stride_w); + int pwend = p_end(iw, pad_w, pooled_width, stride_w); + int index_shift = ih * width + iw; + if ((phstart + 1 != phend) || (pwstart + 1 != pwend)) { + for(int oh = phstart; oh < phend; ++oh) { + for(int ow = pwstart; ow < pwend; ++ow) { + int cached_index = threadIdx.x; + const int64_t* ptr_top_mask = top_mask + oh * out_stride_h + ow * out_stride_w; + for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) { + if (ptr_top_mask[c*out_stride_c] == index_shift) { + out_cached[cached_index] += + static_cast(top_diff[oh * out_stride_h + ow * out_stride_w + c*out_stride_c]); + } + cached_index += blockDim.x; + } + } + } + scalar_t *ptr_bottom_diff = bottom_diff + index_shift * channels; + int cached_index = threadIdx.x; + for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) { + ptr_bottom_diff[c] = static_cast(out_cached[cached_index]); + out_cached[cached_index] = accscalar_t(0.0); + cached_index += blockDim.x; + } + } else { + const int64_t* ptr_top_mask = top_mask + phstart * out_stride_h + pwstart * out_stride_w; + scalar_t *ptr_bottom_diff = bottom_diff + 
index_shift * channels; + int cached_index = threadIdx.x; + for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) { + if (ptr_top_mask[c*out_stride_c] == index_shift) { + ptr_bottom_diff[c] = + static_cast(top_diff[phstart * out_stride_h + pwstart * out_stride_w + c*out_stride_c]); + } + cached_index += blockDim.x; + } + } + } + } +} + +} // namespace + +TORCH_IMPL_FUNC(max_pool2d_with_indices_out_zoom) +(const Tensor& input_, +IntArrayRef kernel_size, +IntArrayRef stride, +IntArrayRef padding, +IntArrayRef dilation, +bool ceil_mode, +const Tensor& output, +const Tensor& indices) { + NoNamesGuard guard; + + TensorArg output_arg{ output, "output", 1 }; + TensorArg indices_arg{ indices, "indices", 2 }; + TensorArg input_arg{ input_, "input_", 3 }; + + checkAllSameGPU(__func__, {output_arg, indices_arg, input_arg}); + if (output.numel() == 0) { + return; + } + + const int kH = safe_downcast(kernel_size[0]); + const int kW = kernel_size.size() == 1 ? kH : safe_downcast(kernel_size[1]); + + const int dH = stride.empty() ? kH : safe_downcast(stride[0]); + const int dW = stride.empty() ? kW : + stride.size() == 1 ? dH : safe_downcast(stride[1]); + + const int padH = safe_downcast(padding[0]); + const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); + + const int dilationH = safe_downcast(dilation[0]); + const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast(dilation[1]); + + const auto memory_format = input_.suggest_memory_format(); + + const int64_t nbatch = input_.ndimension() == 4 ? input_.size(-4) : 1; + const int64_t nInputPlane = input_.size(-3); + const int64_t inputHeight = input_.size(-2); + const int64_t inputWidth = input_.size(-1); + + const int64_t outputHeight = output.size(-2); + const int64_t outputWidth = output.size(-1); + + Tensor input = input_.contiguous(memory_format); + + const int64_t in_stride_n = input_.ndimension() == 4 ? 
input.stride(-4) : 0; + const int64_t in_stride_c = input.stride(-3); + const int64_t in_stride_h = input.stride(-2); + const int64_t in_stride_w = input.stride(-1); + + const int count = safe_downcast(output.numel()); + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), + "max_pool2d_with_indices_out_zoom_frame", + [&] { + using accscalar_t = acc_type; + + scalar_t *output_data = output.mutable_data_ptr(); + const scalar_t *input_data = input.const_data_ptr(); + int64_t *indices_data = indices.mutable_data_ptr(); + + switch (memory_format) { + case MemoryFormat::ChannelsLast: { + const int max_threads = std::min( + at::zoom::getCurrentDeviceProperties()->maxThreadsPerBlock, HIP_MAX_THREADS); + int* maxThreadsDim = at::zoom::getCurrentDeviceProperties()->maxThreadsDim; + int block_x = std::min( + maxThreadsDim[0], std::min(lastPow2(nInputPlane), at::zoom::warp_size())); + int block_y = std::min( + maxThreadsDim[1], std::min(lastPow2(outputWidth), max_threads / block_x)); + int block_z = std::min( + maxThreadsDim[2], std::min(lastPow2(outputHeight), max_threads / block_x / block_y)); + block_x = std::min( + maxThreadsDim[0], std::min(lastPow2(nInputPlane), max_threads / block_y / block_z)); + const dim3 block(block_x, block_y, block_z); + + int kernel_stride_C = ceil_div( + safe_downcast(nInputPlane), block_x * 4); + int kernel_size_C = ceil_div( + safe_downcast(nInputPlane), block_x * kernel_stride_C); + + int grid_x = nbatch*kernel_stride_C; + int grid_y = std::min( + at::zoom::getCurrentDeviceProperties()->maxGridSize[1], + ceil_div(safe_downcast(outputWidth), block_y*BLOCK_STRIDE)); + int grid_z = std::min( + at::zoom::getCurrentDeviceProperties()->maxGridSize[2], + ceil_div(safe_downcast(outputHeight), block_z*BLOCK_STRIDE)); + const dim3 grid(grid_x, grid_y, grid_z); + + size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t)); + AT_ASSERT(shmem_size <= at::zoom::getCurrentDeviceProperties()->sharedMemPerBlock); + + max_pool_forward_nhwc + <<>>( + input_data, nbatch, + nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, + kH, kW, dH, dW, padH, padW, dilationH, dilationW, + in_stride_n, in_stride_c, + in_stride_h, in_stride_w, + kernel_stride_C, kernel_size_C, + output_data, indices_data); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + } + case MemoryFormat::Contiguous: { + const int num_threads = std::min(at::zoom::getCurrentDeviceProperties()->maxThreadsPerBlock, + BLOCK_THREADS); + max_pool_forward_nchw + <<>>( + count, input_data, + nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, + kH, kW, dH, dW, padH, padW, dilationH, dilationW, + output_data, indices_data); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + } + default: TORCH_CHECK(false, "Unsupported memory format. 
Supports only ChannelsLast, Contiguous"); + } + } + ); +} + +TORCH_IMPL_FUNC(max_pool2d_with_indices_backward_out_zoom) +(const Tensor& gradOutput_, +const Tensor& input_, +IntArrayRef kernel_size, +IntArrayRef stride, +IntArrayRef padding, +IntArrayRef dilation, +bool ceil_mode, +const Tensor& indices_, +const Tensor& gradInput) { + NoNamesGuard guard; + + TensorArg gradInput_arg{ gradInput, "gradInput", 1 }; + TensorArg gradOutput_arg{ gradOutput_, "gradOutput_", 2 }; + TensorArg input_arg{ input_, "input_", 3 }; + TensorArg indices_arg{ indices_, "indices", 4 }; + + checkAllSameGPU(__func__, + {gradInput_arg, gradOutput_arg, input_arg, indices_arg}); + if (gradOutput_.numel() == 0) { + return; + } + + const int kH = safe_downcast(kernel_size[0]); + const int kW = kernel_size.size() == 1 ? kH : safe_downcast(kernel_size[1]); + + const int dH = stride.empty() ? kH : safe_downcast(stride[0]); + const int dW = stride.empty() ? kW : + stride.size() == 1 ? dH : safe_downcast(stride[1]); + + const int padH = safe_downcast(padding[0]); + const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); + + const int dilationH = safe_downcast(dilation[0]); + const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast(dilation[1]); + + const auto memory_format = input_.suggest_memory_format(); + + const Tensor input = input_.contiguous(memory_format); + + const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1; + const int64_t nInputPlane = input.size(-3); + const int64_t inputHeight = input.size(-2); + const int64_t inputWidth = input.size(-1); + + const int64_t in_stride_n = input.ndimension() == 4 ? input.stride(-4) : 0; + const int64_t in_stride_c = input.stride(-3); + const int64_t in_stride_h = input.stride(-2); + const int64_t in_stride_w = input.stride(-1); + + const Tensor gradOutput = gradOutput_.contiguous(memory_format); + + const int64_t outputHeight = gradOutput.size(-2); + const int64_t outputWidth = gradOutput.size(-1); + + const int64_t out_stride_c = gradOutput.stride(-3); + const int64_t out_stride_h = gradOutput.stride(-2); + const int64_t out_stride_w = gradOutput.stride(-1); + + const Tensor indices = indices_.contiguous(memory_format); + + gradInput.zero_(); + + int64_t count = input.numel(); + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), + "max_pool2d_with_indices_out_zoom_frame", + [&] { + using accscalar_t = acc_type; + + const scalar_t *gradOutput_data = gradOutput.const_data_ptr(); + scalar_t *gradInput_data = gradInput.mutable_data_ptr(); + const int64_t *indices_data = indices.const_data_ptr(); + + switch (memory_format) { + case MemoryFormat::ChannelsLast: { + const int max_threads = std::min(at::zoom::getCurrentDeviceProperties()->maxThreadsPerBlock, HIP_MAX_THREADS); + int* maxThreadsDim = at::zoom::getCurrentDeviceProperties()->maxThreadsDim; + int block_x = std::min( + maxThreadsDim[0], std::min(lastPow2(nInputPlane), at::zoom::warp_size())); + int block_y = std::min( + maxThreadsDim[1], std::min(lastPow2(inputWidth), max_threads / block_x)); + int block_z = std::min( + maxThreadsDim[2], std::min(lastPow2(inputHeight), max_threads / block_x / block_y)); + block_x = std::min( + maxThreadsDim[0], std::min(lastPow2(nInputPlane), max_threads / block_y / block_z)); + const dim3 block(block_x, block_y, block_z); + + int kernel_stride_C = ceil_div( + safe_downcast(nInputPlane), block_x * 4); + int kernel_size_C = ceil_div( + safe_downcast(nInputPlane), block_x * kernel_stride_C); + + int grid_x = 
nbatch*kernel_stride_C; + int grid_y = std::min( + at::zoom::getCurrentDeviceProperties()->maxGridSize[1], + ceil_div(safe_downcast(inputWidth), block_y*BLOCK_STRIDE)); + int grid_z = std::min( + at::zoom::getCurrentDeviceProperties()->maxGridSize[2], + ceil_div(safe_downcast(inputHeight), block_z*BLOCK_STRIDE)); + const dim3 grid(grid_x, grid_y, grid_z); + + size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * sizeof(accscalar_t); + AT_ASSERT(shmem_size <= at::zoom::getCurrentDeviceProperties()->sharedMemPerBlock); + + // The backward kernel is launched on input instead output. + // If it is launched on output layer, atomic_add would not provide much benefit on FP16. + // Please check comments at https://github.com/pytorch/pytorch/pull/34519. + max_pool_backward_nhwc + <<>>( + gradOutput_data, + indices_data, + nbatch, + nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, + kH, kW, dH, dW, padH, padW, dilationH, dilationW, + out_stride_c, out_stride_h, out_stride_w, + kernel_stride_C, kernel_size_C, + gradInput_data); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + } + case MemoryFormat::Contiguous: { + int imgcount = inputWidth * inputHeight; + dim3 grid; + const int blocks = (imgcount + BLOCK_THREADS - 1) / BLOCK_THREADS; + grid.x = blocks; + grid.y = nbatch; + uint64_t maxGridY = at::zoom::getCurrentDeviceProperties()->maxGridSize[1]; + if (maxGridY < grid.y) grid.y = maxGridY; + grid.z = nInputPlane; + uint64_t maxGridZ = at::zoom::getCurrentDeviceProperties()->maxGridSize[2]; + if (maxGridZ < grid.z) grid.z = maxGridZ; + + max_pool_backward_nchw + <<>>( + gradOutput_data, + indices_data, + nbatch, + nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, + kH, kW, dH, dW, padH, padW, dilationH, dilationW, + gradInput_data); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + } + default: TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); + } + } + ); +} + +} // at::native diff --git a/aten/src/ATen/native/zoom/DilatedMaxPool3d.cu b/aten/src/ATen/native/zoom/DilatedMaxPool3d.cu new file mode 100644 index 0000000000000..615d86cc1c643 --- /dev/null +++ b/aten/src/ATen/native/zoom/DilatedMaxPool3d.cu @@ -0,0 +1,652 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#endif + +namespace at::native { +namespace { + +__device__ inline int min(int a, int b) { + return a <= b ? 
a : b; +} + +template +__global__ static void max_pool3d_with_indices_single_out_frame( + const scalar_t* inputData, + scalar_t* outputData, + int64_t* indicesData, + int features, + int itime, int iheight, int iwidth, + int obatch, int otime, int oheight, int owidth, + int kT, int kH, int kW, + int dT, int dH, int dW, + int pT, int pH, int pW, + int dilationT, int dilationH, int dilationW, + int offsetZ, + bool channels_last) +{ + int oColumn = blockIdx.x * blockDim.x + threadIdx.x; + int oRow = blockIdx.y * blockDim.y + threadIdx.y; + int oFrame = 0; + // used only for channels-first indexing + int64_t slice = 0; + // used only for channels-last indexing + int batch = 0; + int channel = 0; + if (!channels_last) { + // indexing order: batch, channel, time + oFrame = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) % otime; // output frame/time + slice = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) / otime; // output slice/feature + } else { + // indexing order: batch, time, channel + channel = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) % features; // output feature (channel) + slice = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) / features; // output slice (batch + time) + batch = slice / otime; + oFrame = slice % otime; + } + + // For int64_t data type, see https://github.com/pytorch/pytorch/issues/52822 + if (oRow < oheight && oColumn < owidth && oFrame < otime && channel < features && batch < obatch) + { + int tStart = oFrame * dT - pT; + int hStart = oRow * dH - pH; + int wStart = oColumn * dW - pW; + int tEnd = min(tStart + (kT - 1) * dilationT + 1, itime); + int hEnd = min(hStart + (kH - 1) * dilationH + 1, iheight); + int wEnd = min(wStart + (kW - 1) * dilationW + 1, iwidth); + + while(tStart < 0) + tStart += dilationT; + while(hStart < 0) + hStart += dilationH; + while(wStart < 0) + wStart += dilationW; + + // maxIndex remains in "channels-first"/contiguous + int64_t maxIndex = tStart * iheight * iwidth + hStart * iwidth + wStart; + + if (!channels_last) { + inputData += (int64_t) slice * itime * iheight * iwidth; + } else { + inputData += ((int64_t) batch * itime * iheight * iwidth * features) + channel; + } + + scalar_t max = at::numeric_limits::lower_bound(); // -Infinity + + for (int t = tStart; t < tEnd; t += dilationT) + { + for (int h = hStart; h < hEnd; h += dilationH) + { + for (int w = wStart; w < wEnd; w += dilationW) + { + scalar_t val; + int index = t * iheight * iwidth + h * iwidth + w; + if (!channels_last) { + val = inputData[index]; + } else { + int64_t index_channels_last = index*features; + val = inputData[index_channels_last]; + } + + if ((max < val) || at::_isnan(val)) + { + max = val; + maxIndex = index; + } + } + } + } + + int64_t out_index; + if (!channels_last) { + out_index = (int64_t) slice*otime*oheight*owidth + oFrame*oheight*owidth + oRow*owidth + oColumn; + } else { + out_index = ((int64_t) batch*otime*oheight*owidth + oFrame*oheight*owidth + oRow*owidth + oColumn)*features + channel; + } + outputData[out_index] = max; + indicesData[out_index] = maxIndex; + } +} + +template +void max_pool3d_with_indices_out_frame( + const scalar_t* input_data, + const Tensor& output, + const Tensor& indices, + int features, + int64_t totalZ, + int itime, int iheight, int iwidth, + int obatch, int otime, int oheight, int owidth, + int kT, int kH, int kW, + int dT, int dH, int dW, + int pT, int pH, int pW, + int dilationT, int dilationH, int dilationW, + bool channels_last) +{ + int offsetZ = 0; + int threadX = 32; + int threadY = 8; + int threadZ = 
1; + int stepZ = 65535; + if (channels_last) { + threadX = 2; + threadY = 4; + threadZ = 64; + } + dim3 block(threadX, threadY, threadZ); + + while (totalZ > 0) { + dim3 grid(ceil_div(owidth, static_cast(block.x)), + ceil_div(oheight, static_cast(block.y)), + totalZ > stepZ*threadZ ? stepZ : ceil_div(totalZ, static_cast(threadZ))); + + max_pool3d_with_indices_single_out_frame + <<>>( + input_data, + output.mutable_data_ptr(), + indices.mutable_data_ptr(), + features, + itime, iheight, iwidth, + obatch, otime, oheight, owidth, + kT, kH, kW, + dT, dH, dW, + pT, pH, pW, + dilationT, dilationH, dilationW, + offsetZ, channels_last); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + + totalZ -= threadZ*stepZ; + offsetZ += threadZ*stepZ; + } +} + +#undef UPDATE_OUTPUT_KERNEL_WIDTH + +template +__global__ static void max_pool3d_with_indices_backward_single_out_frame( + scalar_t *gradInputData, + const scalar_t *gradOutputData, + const int64_t *indicesData, + int features, + int itime, int iheight, int iwidth, + int obatch, int otime, int oheight, int owidth, + int offsetZ, + bool channels_last) +{ + int oColumn = blockIdx.x * blockDim.x + threadIdx.x; + int oRow = blockIdx.y * blockDim.y + threadIdx.y; + + int oFrame = 0; + // used only for channels-first indexing + int64_t slice = 0; + // used only for channels-last indexing + int batch = 0; + int channel = 0; + if (!channels_last) { + // indexing order: batch, channel, time + oFrame = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) % otime; // output frame/time + slice = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) / otime; // output slice/feature + } else { + // indexing order: batch, time, channel + channel = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) % features; // output feature (channel) + slice = (blockIdx.z * blockDim.z + threadIdx.z + offsetZ) / features; // output slice (batch + time) + batch = slice / otime; + oFrame = slice % otime; + } + + if (oRow < oheight && oColumn < owidth && oFrame < otime && batch < obatch && channel < features) + { + int64_t out_index; + if (!channels_last) { + out_index = (int64_t) slice*otime*oheight*owidth + oFrame*oheight*owidth + oRow*owidth + oColumn; + } else { + out_index = ((int64_t) batch*otime*oheight*owidth + oFrame*oheight*owidth + oRow*owidth + oColumn)*features + channel; + } + int64_t maxIndex = indicesData[out_index]; + if (maxIndex != -1) { + if (!channels_last) { + gpuAtomicAddNoReturn(&gradInputData[(int64_t) slice * itime * iheight * iwidth + maxIndex], + gradOutputData[out_index]); + } else { + gpuAtomicAddNoReturn(&gradInputData[((int64_t) batch * itime * iheight * iwidth + maxIndex) * features + channel], + gradOutputData[out_index]); + } + } + } +} + +template +void max_pool3d_with_indices_backward_out_frame( + scalar_t *gradInputData, + const Tensor& gradOutput, + const Tensor& indices, + int features, + int64_t totalZ, + int itime, int iheight, int iwidth, + int obatch, int otime, int oheight, int owidth, + bool channels_last) +{ + int offsetZ = 0; + int threadX = 32; + int threadY = 8; + int threadZ = 1; + int stepZ = 65535; + if (channels_last) { + threadX = 2; + threadY = 4; + threadZ = 64; + } + dim3 block(threadX, threadY, threadZ); + + while (totalZ > 0) { + dim3 grid(ceil_div(owidth, static_cast(block.x)), + ceil_div(oheight, static_cast(block.y)), + totalZ > stepZ*threadZ ? 
stepZ : ceil_div(totalZ, static_cast(block.z))); + + max_pool3d_with_indices_backward_single_out_frame + <<>>( + gradInputData, + gradOutput.const_data_ptr(), + indices.const_data_ptr(), + features, + itime, iheight, iwidth, + obatch, otime, oheight, owidth, + offsetZ, + channels_last); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + + totalZ -= threadZ*stepZ; + offsetZ += threadZ*stepZ; + } +} + +void max_pool3d_with_indices_out_zoom_template( + Tensor& output, + Tensor& indices, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode) +{ + TensorArg output_arg{ output, "output", 1 }; + TensorArg indices_arg{ indices, "indices", 2 }; + TensorArg input_arg{ input, "input", 3 }; + + checkAllSameGPU(__func__, + {output_arg, indices_arg, input_arg}); + + // #20866, #22032: Guarantee this for the official C++ API? + TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3, + "max_pool3d: kernel_size must either be a single int, or a tuple of three ints") + const int kT = safe_downcast(kernel_size[0]); + const int kH = kernel_size.size() == 1 ? kT : safe_downcast(kernel_size[1]); + const int kW = kernel_size.size() == 1 ? kT : safe_downcast(kernel_size[2]); + + TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 3, + "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints") + const int dT = stride.empty() ? kT : safe_downcast(stride[0]); + const int dH = stride.empty() ? kH : + stride.size() == 1 ? dT : safe_downcast(stride[1]); + const int dW = stride.empty() ? kW : + stride.size() == 1 ? dT : safe_downcast(stride[2]); + + TORCH_CHECK(padding.size() == 1 || padding.size() == 3, + "max_pool3d: padding must either be a single int, or a tuple of three ints"); + const int pT = safe_downcast(padding[0]); + const int pH = padding.size() == 1 ? pT : safe_downcast(padding[1]); + const int pW = padding.size() == 1 ? pT : safe_downcast(padding[2]); + + TORCH_CHECK(dilation.size() == 1 || dilation.size() == 3, + "max_pool3d: dilation must be either a single int, or a tuple of three ints"); + const int dilationT = safe_downcast(dilation[0]); + const int dilationH = dilation.size() == 1 ? dilationT : safe_downcast(dilation[1]); + const int dilationW = dilation.size() == 1 ? dilationT : safe_downcast(dilation[2]); + + const int64_t nbatch = input.ndimension() == 5 ? 
input.size(-5) : 1; + const int64_t nslices = input.size(-4); + const int64_t itime = input.size(-3); + const int64_t iheight = input.size(-2); + const int64_t iwidth = input.size(-1); + + const int64_t otime = pooling_output_shape(itime, kT, pT, dT, dilationT, ceil_mode); + const int64_t oheight = pooling_output_shape(iheight, kH, pH, dH, dilationH, ceil_mode); + const int64_t owidth = pooling_output_shape(iwidth, kW, pW, dW, dilationW, ceil_mode); + + pool3d_shape_check( + input, + nslices, + kT, kH, kW, + dT, dH, dW, + pT, pH, pW, + dilationT, dilationH, dilationW, + itime, iheight, iwidth, + otime, oheight, owidth, + "max_pool3d_with_indices_out_zoom_template()"); + + bool channels_last = input.ndimension() == 5 && input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d; + Tensor _input = input; + if (input.ndimension() == 4) { + Tensor input_channels_last_check = input.unsqueeze(0); + // work around buggy behavior of suggest_memory_format here where + // suggested format of unsqueezed tensor is contiguous while it is + // really only contiguous in ChannelsLast3d + channels_last = (!input_channels_last_check.is_contiguous()) && + input_channels_last_check.is_contiguous(at::MemoryFormat::ChannelsLast3d); + if (!channels_last) { + output.resize_({ nslices, otime, oheight, owidth}); + indices.resize_({nslices, otime, oheight, owidth}); + } else { + _input = input_channels_last_check; + output.resize_({1, nslices, otime, oheight, owidth}, at::MemoryFormat::ChannelsLast3d); + indices.resize_({1, nslices, otime, oheight, owidth}, at::MemoryFormat::ChannelsLast3d); + output = output.squeeze(0); + indices = indices.squeeze(0); + } + } else { + if (!channels_last) { + output.resize_({nbatch, nslices, otime, oheight, owidth}); + indices.resize_({nbatch, nslices, otime, oheight, owidth}); + } else { + output.resize_({nbatch, nslices, otime, oheight, owidth}, at::MemoryFormat::ChannelsLast3d); + indices.resize_({nbatch, nslices, otime, oheight, owidth}, at::MemoryFormat::ChannelsLast3d); + } + } + + if (input.numel() == 0) { + return; + } + + Tensor work_input; + Tensor work_output = output; + if (!channels_last) { + work_input = input.contiguous(); + } else { + work_input = _input.contiguous(at::MemoryFormat::ChannelsLast3d); + } + Tensor work_indices = indices; + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, + input.scalar_type(), + "max_pool3d_with_indices_out_frame", + [&]{ + const scalar_t *input_data = work_input.const_data_ptr(); + const int64_t totalZ = otime * nslices * nbatch; + + max_pool3d_with_indices_out_frame( + input_data, work_output, work_indices, + nslices, // features + totalZ, + itime, iheight, iwidth, + nbatch, otime, oheight, owidth, + kT, kH, kW, + dT, dH, dW, + pT, pH, pW, + dilationT, dilationH, dilationW, channels_last); + } + ); +} + +void max_pool3d_with_indices_backward_out_zoom_template( + Tensor& gradInput, + const Tensor& gradOutput, + const Tensor& input, + const Tensor& indices, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode) +{ + TensorArg gradInput_arg{ gradInput, "gradInput", 1 }; + TensorArg gradOutput_arg{ gradOutput, "gradOutput", 2 }; + TensorArg input_arg{ input, "input", 3 }; + TensorArg indices_arg{ indices, "indices", 4 }; + + checkAllSameGPU(__func__, + {gradInput_arg, gradOutput_arg, input_arg, indices_arg}); + + // #20866, #22032: Guarantee this for the official C++ API? 
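+  // As in the forward template above, a single int broadcasts to all three spatial
+  // dims: kernel_size = {2} gives kT = kH = kW = 2, kernel_size = {2, 3, 3} gives
+  // (kT, kH, kW) = (2, 3, 3); an omitted stride defaults to the kernel size.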
+ TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 3, + "max_pool3d: kernel_size must either be a single int, or a tuple of three ints") + const int kT = safe_downcast(kernel_size[0]); + const int kH = kernel_size.size() == 1 ? kT : safe_downcast(kernel_size[1]); + const int kW = kernel_size.size() == 1 ? kT : safe_downcast(kernel_size[2]); + + TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 3, + "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints") + const int dT = stride.empty() ? kT : safe_downcast(stride[0]); + const int dH = stride.empty() ? kH : + stride.size() == 1 ? dT : safe_downcast(stride[1]); + const int dW = stride.empty() ? kW : + stride.size() == 1 ? dT : safe_downcast(stride[2]); + + TORCH_CHECK(padding.size() == 1 || padding.size() == 3, + "max_pool3d: padding must either be a single int, or a tuple of three ints"); + const int pT = safe_downcast(padding[0]); + const int pH = padding.size() == 1 ? pT : safe_downcast(padding[1]); + const int pW = padding.size() == 1 ? pT : safe_downcast(padding[2]); + + TORCH_CHECK(dilation.size() == 1 || dilation.size() == 3, + "max_pool3d: dilation must be either a single int, or a tuple of three ints"); + const int dilationT = safe_downcast(dilation[0]); + const int dilationH = dilation.size() == 1 ? dilationT : safe_downcast(dilation[1]); + const int dilationW = dilation.size() == 1 ? dilationT : safe_downcast(dilation[2]); + + TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5), + "max_pool2d_with_indices_backward_out_zoom_template(): ", + "Expected 4D or 5D input tensor, but got ", input.sizes()); + + TORCH_CHECK((gradOutput.ndimension() == 4 || gradOutput.ndimension() == 5), + "max_pool2d_with_indices_backward_out_zoom_template(): ", + "Expected 4D or 5D gradOutput tensor, but got ", gradOutput.sizes()); + + // Resize and initialize result tensor. + bool channels_last = input.ndimension() == 5 && input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d; + Tensor _input = input; + if (input.ndimension() == 4) { + Tensor input_channels_last_check = input.unsqueeze(0); + // work around buggy behavior of suggest_memory_format here where + // suggested format of unsqueezed tensor is contiguous while it is + // really only contiguous in ChannelsLast3d + channels_last = (!input_channels_last_check.is_contiguous()) && + input_channels_last_check.is_contiguous(at::MemoryFormat::ChannelsLast3d); + if (channels_last) { + _input = input_channels_last_check; + } + } + if (!channels_last) { + gradInput.resize_as_(input); + } else { + gradInput.resize_as_(_input, at::MemoryFormat::ChannelsLast3d); + } + gradInput.zero_(); + + const int64_t nbatch = input.ndimension() == 5 ? 
input.size(-5) : 1; + const int64_t nslices = input.size(-4); + + const int64_t otime = gradOutput.size(-3); + const int64_t oheight = gradOutput.size(-2); + const int64_t owidth = gradOutput.size(-1); + + const int64_t itime = gradInput.size(-3); + const int64_t iheight = gradInput.size(-2); + const int64_t iwidth = gradInput.size(-1); + + max_pool3d_backward_shape_check( + input, + gradOutput, + indices, + nslices, + kT, kH, kW, + dT, dH, dW, + pT, pH, pW, + dilationT, dilationH, dilationW, + itime, iheight, iwidth, + otime, oheight, owidth, + "max_pool3d_with_indices_backward_out_zoom_template()"); + + if (gradOutput.numel() == 0) { + return; + } + + Tensor work_grad_input = gradInput; + Tensor work_grad_output; + Tensor work_indices; + if (!channels_last) { + work_grad_output = gradOutput.contiguous(); + work_indices = indices.contiguous(); + } else { + if (input.ndimension() == 4) { + work_grad_output = gradOutput.unsqueeze(0).contiguous(at::MemoryFormat::ChannelsLast3d); + work_indices = indices.unsqueeze(0).contiguous(at::MemoryFormat::ChannelsLast3d); + } else { + work_grad_output = gradOutput.contiguous(at::MemoryFormat::ChannelsLast3d); + work_indices = indices.contiguous(at::MemoryFormat::ChannelsLast3d); + } + } + + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), + "max_pool3d_with_indices_backward_out_frame", + [&] { + const int64_t totalZ = otime * nslices * nbatch; + scalar_t *grad_input_data = work_grad_input.mutable_data_ptr(); + + max_pool3d_with_indices_backward_out_frame( + grad_input_data, work_grad_output, work_indices, + nslices, + totalZ, + itime, iheight, iwidth, + nbatch, otime, oheight, owidth, + channels_last); + } + ); +} + +} // namespace + +std::tuple max_pool3d_with_indices_out_zoom(const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode, + Tensor& output, + Tensor& indices) +{ + max_pool3d_with_indices_out_zoom_template( + output, + indices, + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode); + return std::tuple(output, indices); +} + +std::tuple max_pool3d_with_indices_zoom( + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode) +{ + NoNamesGuard guard; + + Tensor output = at::empty({0}, input.options()); + Tensor indices = at::empty({0}, input.options().dtype(kLong)); + max_pool3d_with_indices_out_zoom_template( + output, + indices, + input, + kernel_size, + stride, + padding, + dilation, + ceil_mode); + + guard.reset(); + namedinference::propagate_names(output, input); + namedinference::propagate_names(indices, input); + + return std::tuple(output, indices); +} + +Tensor& max_pool3d_with_indices_backward_out_zoom(const Tensor& gradOutput, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + bool ceil_mode, + const Tensor& indices, + Tensor& gradInput) +{ + // See Note [Writing Nondeterministic Operations] + // Nondeterministic because of atomicAdd usage + globalContext().alertNotDeterministic("max_pool3d_with_indices_backward_out_zoom"); + max_pool3d_with_indices_backward_out_zoom_template( + gradInput, + gradOutput, + input, + indices, + kernel_size, + stride, + padding, + dilation, + ceil_mode); + return gradInput; +} + +Tensor max_pool3d_with_indices_backward_zoom( + const Tensor& gradOutput, + const Tensor& input, + IntArrayRef kernel_size, + IntArrayRef stride, + IntArrayRef padding, 
+ IntArrayRef dilation, + bool ceil_mode, + const Tensor& indices) +{ + // See Note [Writing Nondeterministic Operations] + // Nondeterministic because of atomicAdd usage + globalContext().alertNotDeterministic("max_pool3d_with_indices_backward_zoom"); + auto gradInput = at::empty(input.sizes(), input.options()); + max_pool3d_with_indices_backward_out_zoom_template( + gradInput, + gradOutput, + input, + indices, + kernel_size, + stride, + padding, + dilation, + ceil_mode); + return gradInput; +} + +} // at::native diff --git a/aten/src/ATen/native/zoom/DistanceKernel.cu b/aten/src/ATen/native/zoom/DistanceKernel.cu new file mode 100644 index 0000000000000..248b8a431f7cb --- /dev/null +++ b/aten/src/ATen/native/zoom/DistanceKernel.cu @@ -0,0 +1,365 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + +#include + +namespace at::native { + +namespace { + +constexpr int kHIPNumThreads = 256; + +template +struct dists { + + static __forceinline__ __device__ scalar_t sign(scalar_t val) { + return (0 < val) - (val < 0); + } + + // Zero norm + struct zero { + static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t /*p*/) { agg += diff != 0.0; } + static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t /*p*/) { return agg; } + static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { update += other; } + }; + + // One norm + struct one { + static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t /*p*/) { agg += diff; } + static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t /*p*/) { return agg; } + static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { update += other; } + static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t /*dist*/, const scalar_t /*p*/) { return grad * sign(diff); } + }; + + // Special case backward when p is less than two + struct lt_two { + static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) { + return (dist == 0.0 || (diff == 0.0 && p < 1)) ? 0 : (sign(diff) * std::pow(std::abs(diff), p - 1) * grad / std::pow(dist, p - 1)); + } + }; + + // Two norm + struct two { + static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t /*p*/) { agg += diff * diff; } + static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t /*p*/) { return device_sqrt(agg); } + static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { update += other; } + static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t /*p*/) { return dist == 0.0 ? 
0 : grad * diff / dist; } + }; + + // General p norm + struct p { + static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t p) { agg += std::pow(diff, p); } + static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t p) { return std::pow(agg, static_cast(1) / p); } + static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { update += other; } + static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) { return dist == 0.0 ? 0 : diff * std::pow(std::abs(diff), p - 2) * grad / std::pow(dist, p - 1); } + }; + + // Inf norm + struct inf { + static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t /*p*/) { if (diff > agg) { agg = diff; } } + static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t /*p*/) { return agg; } + static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { if (other > update) { update = other; } } + static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t /*p*/) { return grad * sign(diff) * (std::abs(diff) == dist); } + }; + +}; + +template +struct DistReduceOp { + __forceinline__ __device__ scalar_t combine(scalar_t a, scalar_t b) const { + F::agg(a, b); + return a; + } + + __forceinline__ __device__ scalar_t warp_shfl_down(scalar_t data, int offset) const { + return WARP_SHFL_DOWN(data, offset); + } +}; + +template +__global__ static void pdist_kernel_zoom_impl(scalar_t * result, const scalar_t * self, const int64_t n, const int64_t m, const scalar_t p, + const double n2, const double n2_squared_minus_1) { + const int64_t k = blockIdx.x; + const int stride = blockDim.x; + + // The -1 accounts for floating point truncation issues + int64_t i = static_cast((n2 - device_sqrt(n2_squared_minus_1 - 2 * k))); + int64_t j = k - n * i + i * (i + 1) / 2 + i + 1; + + const scalar_t * const start = self + i * m; + const scalar_t * const end = start + m; + const scalar_t * a = start + threadIdx.x; + const scalar_t * b = self + j * m + threadIdx.x; + scalar_t agg = 0.0; + for (; a < end; a += stride, b += stride) { + F::inc(agg, std::abs(*a - *b), p); + } + + __shared__ scalar_t agg_smem[kHIPNumThreads]; + scalar_t agg_init{0.0}; + agg = zoom_utils::BlockReduce(agg, DistReduceOp{}, agg_init, agg_smem); + if (threadIdx.x == 0) { + result[k] = F::finish(agg, p); + } +} + +template +__global__ static void cdist_backward_kernel_zoom_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * x1, const scalar_t * x2, const scalar_t * dist, + const scalar_t p, const int64_t r1, const int64_t r2, const int64_t m, const int64_t count, const int64_t r_size, const int64_t l1_size, const int64_t l2_size) { + const int y = (blockIdx.y * gridDim.z + blockIdx.z) * blockDim.y + threadIdx.y; + const int init = blockIdx.x * blockDim.x + threadIdx.x; + if (y >= count || init >= m) { + return; + } + const int l = y / r_size; + const int k = y % r_size; + const int stride = blockDim.x * gridDim.x; + const int l_size = r_size * m; + + int64_t i = k / r2; + int64_t j = k % r2; + + const scalar_t grad_k = grad[y]; + const scalar_t dist_k = dist[y]; + + const scalar_t * const start = x1 + l * l1_size + i * m; + const scalar_t * const end = start + m; + const scalar_t * self_i = start + init; + const scalar_t * self_j = x2 + l * l2_size + j * m + init; + + scalar_t * buff_i = 
buffer + l * l_size + (r1 * j + i) * m + init; + + for (; self_i < end; self_i += stride, self_j += stride, buff_i += stride) { + const scalar_t res = F::backward(*self_i - *self_j, grad_k, dist_k, p); + *buff_i = res; + } +} + +template +__global__ static void pdist_backward_kernel_zoom_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * self, const scalar_t * dist, int64_t gs, const int64_t n, const int64_t m, const int64_t combs, const scalar_t p, + const double n2, const double n2_squared_minus_1) { + const int64_t k = blockIdx.x * blockDim.x + threadIdx.x; + const int init = blockIdx.y * blockDim.y + threadIdx.y; + const int stride = blockDim.y * gridDim.y; + + if (k >= combs) { + return; + } + + // The -1 accounts for floating point truncation issues + int64_t i = static_cast((n2 - device_sqrt(n2_squared_minus_1 - 2 * k))); + int64_t j = k - n * i + i * (i + 1) / 2 + i + 1; + int64_t ib = j - i - 1; + int64_t jb = n - 2 - i; + + const scalar_t grad_k = grad[k * gs]; + const scalar_t dist_k = dist[k]; + + const scalar_t * const start = self + i * m; + const scalar_t * const end = start + m; + const scalar_t * self_i = start + init; + const scalar_t * self_j = self + j * m + init; + scalar_t * buff_i = buffer + (ib * n + i) * m + init; + scalar_t * buff_j = buffer + (jb * n + j) * m + init; + for (; self_i < end; self_i += stride, self_j += stride, buff_i += stride, buff_j += stride) { + const scalar_t res = F::backward(*self_i - *self_j, grad_k, dist_k, p); + *buff_i = res; + *buff_j = -res; + } +} + +template +__global__ static void cdist_kernel_zoom_impl(scalar_t * result, const scalar_t * x1, const scalar_t * x2, + const scalar_t p, const int64_t r2, const int64_t m, const int64_t r_size, const int64_t l1_size, const int64_t l2_size) { + const int64_t l = blockIdx.x / r_size; + const int64_t k = blockIdx.x % r_size; + const int64_t i = k / r2; + const int64_t j = k % r2; + const int stride = blockDim.x; + + const scalar_t * const start = x1 + l * l1_size + i * m; + const scalar_t * const end = start + m; + const scalar_t * a = start + threadIdx.x; + const scalar_t * b = x2 + l * l2_size + j * m + threadIdx.x; + + scalar_t agg = 0.0; + for (; a < end; a += stride, b += stride) { + F::inc(agg, std::abs(*a - *b), p); + } + __shared__ scalar_t agg_smem[kHIPNumThreads]; + scalar_t agg_init{0.0}; + agg = zoom_utils::BlockReduce(agg, DistReduceOp{}, agg_init, agg_smem); + if (threadIdx.x == 0) { + result[blockIdx.x] = F::finish(agg, p); + } +} + +void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, double p) { + const int64_t r1 = x1.size(-2); + const int64_t r2 = x2.size(-2); + const int64_t m = x1.size(-1); + const int64_t r_size = r1 * r2; + const int64_t l1_size = r1 * m; + const int64_t l2_size = r2 * m; + const dim3 grid(result.numel()); + const dim3 block(kHIPNumThreads); + + AT_DISPATCH_FLOATING_TYPES(x1.scalar_type(), "cdist_zoom", [&] { + auto impl_fptr = cdist_kernel_zoom_impl::p>; + if (p == 0.0) { + impl_fptr = cdist_kernel_zoom_impl::zero>; + } else if (p == 1.0) { + impl_fptr = cdist_kernel_zoom_impl::one>; + } else if (p == 2.0) { + impl_fptr = cdist_kernel_zoom_impl::two>; + } else if (std::isinf(p)) { + impl_fptr = cdist_kernel_zoom_impl::inf>; + } + impl_fptr<<>>(result.mutable_data_ptr(), x1.const_data_ptr(), x2.const_data_ptr(), p, r2, m, r_size, l1_size, l2_size); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); +} + +void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) { + const dim3 grid(result.numel()); + const dim3 
block(kHIPNumThreads); + int64_t n = self.size(0); + int64_t m = self.size(1); + // https://github.com/pytorch/pytorch/issues/15511 demonstrated we need to do + // some math in fp64 -- this is just minimizing the amount of fp64 math we do on the device. + const double n2 = n - .5; + const double n2_squared_minus_1 = n2 * n2 - 1; + + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_zoom", [&] { + auto impl_fptr = pdist_kernel_zoom_impl::p>; + if (p == 0.0) { + impl_fptr = pdist_kernel_zoom_impl::zero>; + } else if (p == 1.0) { + impl_fptr = pdist_kernel_zoom_impl::one>; + } else if (p == 2.0) { + impl_fptr = pdist_kernel_zoom_impl::two>; + } else if (std::isinf(p)) { + impl_fptr = pdist_kernel_zoom_impl::inf>; + } + impl_fptr<<>>(result.mutable_data_ptr(), self.const_data_ptr(), n, m, p, n2, n2_squared_minus_1); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); +} + +void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) { + if (p == 0.0 || grad.numel() == 0 || self.numel() == 0) { + result.fill_(0); + return; + } + + const int64_t n = result.size(0); + int64_t m = self.size(1); + const int block_x = 16; + // NB: be careful with changing block_y; as it's currently written, grid_y is limited to be 2^16. + // block_y of 64 gives us max pdist dim1 of 2**24 + const int block_y = 64; + const int grid_x = (dist.numel() + block_x - 1) / block_x; + const int grid_y = (m + block_y * 8 - 1) / (block_y * 8); + const dim3 grid(grid_x, grid_y); + const dim3 block(block_x, block_y); + // https://github.com/pytorch/pytorch/issues/15511 demonstrated we need to do + // some math in fp64 -- this is just minimizing the amount of fp64 math we do on the device. + const double n2 = n - .5; + const double n2_squared_minus_1 = n2 * n2 - 1; + + Tensor buffer = at::empty({n - 1, result.size(0), result.size(1)}, result.options()); + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_zoom_backward", [&] { + auto impl_fptr = pdist_backward_kernel_zoom_impl::p>; + if (p == 1.0) { + impl_fptr = pdist_backward_kernel_zoom_impl::one>; + } else if (p < 2.0) { + impl_fptr = pdist_backward_kernel_zoom_impl::lt_two>; + } else if (p == 2.0) { + impl_fptr = pdist_backward_kernel_zoom_impl::two>; + } else if (std::isinf(p)) { + impl_fptr = pdist_backward_kernel_zoom_impl::inf>; + } + impl_fptr<<>>(buffer.mutable_data_ptr(), grad.const_data_ptr(), self.const_data_ptr(), dist.const_data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); + + at::sum_out(result, buffer, 0); +} + +void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& x1, const Tensor& x2, const double p, const Tensor& dist) { + if (p == 0.0 || grad.numel() == 0 || x1.numel() == 0 || x2.numel() == 0) { + result.fill_(0); + return; + } + + const int64_t r1 = x1.size(-2); + const int64_t r2 = x2.size(-2); + const int64_t m = x1.size(-1); + // Just like we do in the CPU code, assume that result is always batched + int64_t batch = result.size(0); + const int block_x = 64; + const int block_y = 16; + const int grid_x = (m + block_x * 8 - 1) / (block_x * 8); + + const int64_t count = dist.numel(); + const int64_t grid_temp = (count + block_y - 1) / block_y; + + const int grid_y = (grid_temp - 1) / 65535 + 1; + const int grid_z = (grid_temp - 1) / grid_y + 1; + + const dim3 grid(grid_x, grid_y, grid_z); + const dim3 block(block_x, block_y); + + const int64_t r_size = r1 * r2; + const int64_t l1_size = r1 * m; + const 
int64_t l2_size = r2 * m; + //current implementation supports only gradient that can be collapsed to 1D. However, to avoid checking this assumption, + //we call grad.contiguous() before backward, so stride is guaranteed to be 1 + + Tensor buffer = at::empty({batch, r2, r1, m}, result.options()); + AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist_zoom_backward", [&] { + auto impl_fptr = cdist_backward_kernel_zoom_impl::p>; + if (p == 1.0) { + impl_fptr = cdist_backward_kernel_zoom_impl::one>; + } else if (p < 2.0) { + impl_fptr = cdist_backward_kernel_zoom_impl::lt_two>; + } else if (p == 2.0) { + impl_fptr = cdist_backward_kernel_zoom_impl::two>; + } else if (std::isinf(p)) { + impl_fptr = cdist_backward_kernel_zoom_impl::inf>; + } + impl_fptr<<>>(buffer.mutable_data_ptr(), + grad.const_data_ptr(), x1.const_data_ptr(), x2.const_data_ptr(), dist.const_data_ptr(), + p, r1, r2, m, count, r_size, l1_size, l2_size); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); + + at::sum_out(result, buffer, 1); + +} + + +} // anonymous namespace + +REGISTER_PRIVATEUSE1_DISPATCH(pdist_forward_stub, &pdist_forward_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(pdist_backward_stub, &pdist_backward_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(cdist_stub, &cdist_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(cdist_backward_stub, &cdist_backward_kernel_impl); + +} // at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/DistributionBernoulli.cu b/aten/src/ATen/native/zoom/DistributionBernoulli.cu new file mode 100644 index 0000000000000..1e2dc0ada7939 --- /dev/null +++ b/aten/src/ATen/native/zoom/DistributionBernoulli.cu @@ -0,0 +1,40 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace at::native { + +void bernoulli_tensor_kernel(const TensorBase &self, const TensorBase &p_, std::optional gen_) { + auto generator = get_generator_or_default(gen_, zoom::detail::getDefaultZoomGenerator()); + at::native::templates::zoom::bernoulli_kernel(self, p_, generator); +} + +void bernoulli_scalar_kernel(const TensorBase &self, double p, std::optional gen) { + auto iter = TensorIterator::borrowing_nullary_op(self); + auto generator = get_generator_or_default(gen, zoom::detail::getDefaultZoomGenerator()); + at::native::templates::zoom::bernoulli_kernel(iter, p, generator); +} + +REGISTER_PRIVATEUSE1_DISPATCH(bernoulli_tensor_stub, &bernoulli_tensor_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(bernoulli_scalar_stub, &bernoulli_scalar_kernel); + + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/DistributionCauchyKernel.cu b/aten/src/ATen/native/zoom/DistributionCauchyKernel.cu new file mode 100644 index 0000000000000..729878c244cf7 --- /dev/null +++ b/aten/src/ATen/native/zoom/DistributionCauchyKernel.cu @@ -0,0 +1,15 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include + +namespace at::native { + +void cauchy_kernel(TensorIteratorBase& iter, double median, double sigma, std::optional gen) { + auto generator = get_generator_or_default(gen, zoom::detail::getDefaultZoomGenerator()); + at::native::templates::zoom::cauchy_kernel(iter, median, sigma, generator); +} + +REGISTER_PRIVATEUSE1_DISPATCH(cauchy_stub, &cauchy_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/DistributionExponentialKernel.cu b/aten/src/ATen/native/zoom/DistributionExponentialKernel.cu new file mode 100644 index 
0000000000000..2dd9cece28699 --- /dev/null +++ b/aten/src/ATen/native/zoom/DistributionExponentialKernel.cu @@ -0,0 +1,16 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include + + +namespace at::native { + +void exponential_kernel(TensorIteratorBase& iter, double lambda, std::optional gen) { + auto generator = get_generator_or_default(gen, zoom::detail::getDefaultZoomGenerator()); + at::native::templates::zoom::exponential_kernel(iter, lambda, generator); +} + +REGISTER_PRIVATEUSE1_DISPATCH(exponential_stub, &exponential_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/DistributionGeometricKernel.cu b/aten/src/ATen/native/zoom/DistributionGeometricKernel.cu new file mode 100644 index 0000000000000..cd8a883cf0a38 --- /dev/null +++ b/aten/src/ATen/native/zoom/DistributionGeometricKernel.cu @@ -0,0 +1,15 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include + +namespace at::native { + +void geometric_kernel(TensorIteratorBase& iter, double p_, std::optional gen) { + auto generator = get_generator_or_default(gen, zoom::detail::getDefaultZoomGenerator()); + at::native::templates::zoom::geometric_kernel(iter, p_, generator); +} + +REGISTER_PRIVATEUSE1_DISPATCH(geometric_stub, &geometric_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/DistributionLogNormalKernel.cu b/aten/src/ATen/native/zoom/DistributionLogNormalKernel.cu new file mode 100644 index 0000000000000..dd57bc450f5cd --- /dev/null +++ b/aten/src/ATen/native/zoom/DistributionLogNormalKernel.cu @@ -0,0 +1,15 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include + +namespace at::native { + +void log_normal_kernel(TensorIteratorBase& iter, double mean, double std, std::optional gen) { + auto generator = get_generator_or_default(gen, zoom::detail::getDefaultZoomGenerator()); + at::native::templates::zoom::log_normal_kernel(iter, mean, std, generator); +} + +REGISTER_PRIVATEUSE1_DISPATCH(log_normal_stub, &log_normal_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/DistributionNormal.cu b/aten/src/ATen/native/zoom/DistributionNormal.cu new file mode 100644 index 0000000000000..1eee03731df11 --- /dev/null +++ b/aten/src/ATen/native/zoom/DistributionNormal.cu @@ -0,0 +1,15 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include + +namespace at::native { + +void normal_kernel(const TensorBase &self, double mean, double std, std::optional gen) { + auto generator = get_generator_or_default(gen, zoom::detail::getDefaultZoomGenerator()); + at::native::templates::zoom::normal_kernel(self, mean, std, generator); +} + +REGISTER_PRIVATEUSE1_DISPATCH(normal_stub, &normal_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/DistributionRandomKernel.cu b/aten/src/ATen/native/zoom/DistributionRandomKernel.cu new file mode 100644 index 0000000000000..7e8aa20d652ba --- /dev/null +++ b/aten/src/ATen/native/zoom/DistributionRandomKernel.cu @@ -0,0 +1,27 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include + +namespace at::native { + +void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional gen_) { + auto gen = get_generator_or_default(gen_, zoom::detail::getDefaultZoomGenerator()); + at::native::templates::zoom::random_from_to_kernel(iter, range, base, gen); +} + +void random_full_64_bits_range_kernel(TensorIteratorBase& iter, std::optional gen_) { + auto gen = get_generator_or_default(gen_, zoom::detail::getDefaultZoomGenerator()); + 
at::native::templates::zoom::random_full_64_bits_range_kernel(iter, gen); +} + +void random_kernel(TensorIteratorBase& iter, std::optional gen_) { + auto gen = get_generator_or_default(gen_, zoom::detail::getDefaultZoomGenerator()); + at::native::templates::zoom::random_kernel(iter, gen); +} + +REGISTER_PRIVATEUSE1_DISPATCH(random_from_to_stub, &random_from_to_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(random_stub, &random_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(random_full_64_bits_range_stub, &random_full_64_bits_range_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/DistributionTemplates.h b/aten/src/ATen/native/zoom/DistributionTemplates.h new file mode 100644 index 0000000000000..584d90f924776 --- /dev/null +++ b/aten/src/ATen/native/zoom/DistributionTemplates.h @@ -0,0 +1,672 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +//#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +// launch bounds used for kernels utilizing TensorIterator +const uint32_t block_size_bound = 256; +const uint32_t grid_size_bound = 4; +// number of randoms given by distributions like hiprand_uniform4, hiprand_uniform2_double +// used in calculating philox offset. +const uint32_t hiprand4_engine_calls = 4; + +// utility function that calculates proper philox_offset +// for distributions utilizing TensorIterator. For distributions using +// TensorIterator, we are using a grid-stride loop with each +// thread yielding one element per thread. For the edge of the grid-stride +// loop, if the tensor size is large, the unroll loop will kick in and the float4 +// from hiprand4 will start getting utilized (for common tensor sizes, we end up +// using rand.x from each thread). Hence, the philox_offset is +// (number of elements per thread * number of engine calls), which makes +// sure that philox offset increment is not less than the number of randoms used +// in each thread. 
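+// Worked example (illustrative numbers, not a specific device): with
+// numel = 1048576, block_size = 256, unroll = 4 and grid.x clamped to 1024,
+// one grid-stride sweep covers 256 * 1024 * 4 = 1048576 elements, so
+// counter_offset = ((numel - 1) / (256 * 1024 * 4) + 1) * 4 = 4, matching the
+// four randoms a single hiprand4/hiprand_uniform4 call hands each thread.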
+std::tuple calc_execution_policy(int64_t total_elements) { + const uint64_t numel = static_cast(total_elements); + const uint32_t block_size = block_size_bound; + const uint32_t unroll = hiprand4_engine_calls; + dim3 dim_block(block_size); + dim3 grid((numel + block_size - 1) / block_size); + uint32_t blocks_per_sm = at::zoom::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size; + grid.x = std::min( + static_cast(at::zoom::getCurrentDeviceProperties()->multiProcessorCount) * blocks_per_sm, + grid.x); + //number of times random will be generated per thread, to offset philox counter in thc random state + uint64_t counter_offset = ((numel - 1) / (block_size * grid.x * unroll) + 1) + * hiprand4_engine_calls; + return std::make_tuple(counter_offset, grid, dim_block); +} + +// grid stride loop kernel for distributions +template +C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound) +__global__ void distribution_elementwise_grid_stride_kernel(int numel, + PhiloxHIPState philox_args, + const dist_t dist_func, + const transform_t transform_func) { + auto seeds = at::zoom::philox::unpack(philox_args); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + hiprandStatePhilox4_32_10_t state; + hiprand_init(std::get<0>(seeds), + idx, + std::get<1>(seeds), + &state); + + int rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) * + blockDim.x * gridDim.x * unroll_factor; + for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) { + auto rand = dist_func(&state); + #pragma unroll + for (int ii = 0; ii < unroll_factor; ii++) { + int li = linear_index + blockDim.x * gridDim.x * ii; + if (li < numel) { + transform_func(li, static_cast((&rand.x)[ii])); + } + } + __syncthreads(); + } +} + +/** + * distribution_nullary_kernel is analogous to gpu_kernel in + * ATen/native/cuda/Loops.cuh. Like gpu_kernel, it uses + * TensorIterator to launch a kernel. However, the differences are + * - it launches a grid-stride loop based kernel. The kernel is not + * generic like elementwise_kernel in Loops.cuh and is specialized + * for the distribution kernels here. + * - For big size tensors, we can launch multiple kernels recursively + * (i.e. if (!iter.can_use_32bit_indexing())) and hence, the philox + * offset calculation is done in this function. + * + * FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh + * to have grid-stride loop kernel and then use that to launch our distribution + * kernels? Note that we need a grid-stride loop kernel because, we found by testing + * that it achieves peak effective bandwidth. 
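+ *
+ * Typical call shape (a sketch of how the helpers further below use it; the
+ * transform lambda is whatever per-distribution mapping the caller needs):
+ *   distribution_nullary_kernel<scalar_t, accscalar_t, hiprand4_engine_calls>(iter, gen,
+ *     [] __device__ (hiprandStatePhilox4_32_10_t* state) { return hiprand_uniform4(state); },
+ *     transform_func);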
+ */ +template +void distribution_nullary_kernel(at::TensorIteratorBase& iter, + RNG gen, + const dist_t& dist_func, + const transform_t transform_func) { + static_assert(unroll_factor >= 1, "unroll_factor must be >= 1."); + int64_t numel = iter.numel(); + if (numel == 0) { + return; + } + + auto execution_policy = calc_execution_policy(numel); + auto counter_offset = std::get<0>(execution_policy); + auto grid = std::get<1>(execution_policy); + auto block = std::get<2>(execution_policy); + PhiloxHIPState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_hip_state(counter_offset); + } + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + distribution_nullary_kernel(sub_iter, + gen, dist_func, transform_func); + } + return; + } + + char* out_data = (char*)iter.data_ptr(0); + + auto stream = c10::zoom::getCurrentZoomStream(); + if (iter.is_trivial_1d()) { + auto strides = iter.get_inner_strides(); + int stride0 = strides[0]; + distribution_elementwise_grid_stride_kernel<<>>( + numel, + rng_engine_inputs, + dist_func, + [=]__device__(int idx, accscalar_t rand) { + scalar_t* out = (scalar_t*)&out_data[stride0 * idx]; + *out = transform_func(rand); + } + ); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else { + auto offset_calc = make_offset_calculator<1>(iter); + distribution_elementwise_grid_stride_kernel<<>>( + numel, + rng_engine_inputs, + dist_func, + [=]__device__(int idx, accscalar_t rand) { + auto offsets = offset_calc.get(idx); + scalar_t* out = (scalar_t*)&out_data[offsets[0]]; + *out = transform_func(rand); + } + ); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } +} + +// Binary kernel +template +__global__ void distribution_binary_elementwise_kernel( + int numel, + func_t f, + PhiloxHIPState philox_args, + typename function_traits::result_type *output_data, + const typename function_traits::template arg<1>::type *input_data_1, + const typename function_traits::template arg<2>::type *input_data_2, + inp_offset_calc_t inp_calc, + out_offset_calc_t out_calc) { + auto seeds = at::zoom::philox::unpack(philox_args); + + using input_t_1 = typename function_traits::template arg<1>::type; + using input_t_2 = typename function_traits::template arg<2>::type; + + input_t_1 inputs_1[thread_work_size()]; + input_t_2 inputs_2[thread_work_size()]; + + int base_index = block_work_size() * blockIdx.x; + int remaining = std::min(numel - base_index, block_work_size()); + + hiprandStatePhilox4_32_10_t state; + hiprand_init(std::get<0>(seeds), + blockIdx.x * blockDim.x + threadIdx.x, + std::get<1>(seeds), + &state); + + // load data into registers + int thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + break; + } + int input_idx = thread_idx + base_index; + auto offsets = inp_calc.get(input_idx); + inputs_1[i] = input_data_1[offsets[0]]; + inputs_2[i] = input_data_2[offsets[1]]; + + thread_idx += num_threads(); + } + + // compute and store + thread_idx = threadIdx.x; + #pragma unroll + for (int i = 0; i < thread_work_size(); i++) { + if (thread_idx >= remaining) { + break; + } + int input_idx = thread_idx + base_index; + auto offsets = out_calc.get(input_idx); + output_data[offsets[0]] = f(state, inputs_1[i], inputs_2[i]); + thread_idx += num_threads(); + } +} + +template +void distribution_binary_kernel(TensorIteratorBase &iter, PhiloxHIPState philox_args, const func_t &f) { + 
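+  // f is invoked once per output element as f(state, a, b), where state is the
+  // per-thread Philox state and a/b are the two loaded operands; the
+  // static_assert below pins down that first-argument type.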
static_assert(std::is_same::template arg<0>::type, hiprandStatePhilox4_32_10_t&>::value, "the first argument of functor must be hiprandStatePhilox4_32_10_t"); + using input_t_1 = typename function_traits::template arg<1>::type; + using input_t_2 = typename function_traits::template arg<2>::type; + using output_t = typename function_traits::result_type; + + if (!iter.can_use_32bit_indexing()) { + for (auto& sub_iter : iter.with_32bit_indexing()) { + distribution_binary_kernel(sub_iter, philox_args, f); + } + return; + } + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing()); + + int64_t numel = iter.numel(); + if (numel == 0) { + return; + } + + output_t *output_data = static_cast(iter.data_ptr(0)); + const input_t_1 *input_data_1 = static_cast(iter.data_ptr(1)); + const input_t_2 *input_data_2 = static_cast(iter.data_ptr(2)); + + int64_t grid = (numel + block_work_size() - 1) / block_work_size(); + auto stream = c10::zoom::getCurrentZoomStream(); + + if (iter.is_contiguous()) { + distribution_binary_elementwise_kernel<<>>( + numel, f, philox_args, output_data, input_data_1, input_data_2, + TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>()); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else { + distribution_binary_elementwise_kernel<<>>( + numel, f, philox_args, output_data, input_data_1, input_data_2, + make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter)); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } +} + +} // namespace +}} // namespace at::native + + +namespace at { +namespace native { +namespace templates { +namespace zoom { + +// ==================================================== Random ======================================================== + +template +void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) { + AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_zoom", AT_WRAP([&] { + if (( + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && range >= 1ULL << 32) + { + // define lambda to mod with range and add base + auto random_func = [range, base] __device__ (uint64_t rand) { + return transformation::uniform_int_from_to(rand, range, base); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = hiprand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + auto random_func = [range, base] __device__ (uint32_t rand) { + return transformation::uniform_int_from_to(rand, range, base); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) { + return hiprand4(state); + }, + random_func); + } + }), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +} + +// This is the special kernel to handle single specific case: +// from(inclusive) = std::numeric_limits::lowest() +// to(exclusive) = None (= std::numeric_limits::max() + 1) +template +void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_zoom", [&] { + if (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) { + auto random_func = [] __device__ (uint64_t rand) { + return transformation::uniform_int_full_range(rand); + }; + distribution_nullary_kernel(iter, + gen, + 
[] __device__ (hiprandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = hiprand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + TORCH_CHECK(false, "random_full_64_bits_range_kernel_zoom handles only int64, double, float and bfloat16"); + } + }); +} + +template +struct RandomFromToKernel { + void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional gen) { + random_from_to_kernel(iter, range, base, check_generator(gen)); + } + void operator()(TensorIteratorBase& iter, std::optional gen) { + random_full_64_bits_range_kernel(iter, check_generator(gen)); + } +}; + +template +void random_kernel(TensorIteratorBase& iter, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_zoom", [&] { + if (std::is_same::value || std::is_same::value) { + auto random_func = [] __device__ (uint64_t rand) { + return transformation::uniform_int(rand); + }; + distribution_nullary_kernel(iter, gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) -> ulonglong2 { + ulonglong2 ret; + uint4 rand_val = hiprand4(state); + ret.x = (static_cast(rand_val.x) << 32) | rand_val.y; + ret.y = (static_cast(rand_val.z) << 32) | rand_val.w; + return ret; + }, + random_func); + } else { + auto random_func = [] __device__ (uint32_t rand) { + return transformation::uniform_int(rand); + }; + distribution_nullary_kernel(iter, + gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) { + return hiprand4(state); + }, + random_func); + } + }); +} + +template +struct RandomKernel { + void operator()(TensorIteratorBase& iter, RNG gen) { + random_kernel(iter, gen); + } +}; + +// ==================================================================================================================== + +template +void uniform_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) { + if (std::is_same::value) { + distribution_nullary_kernel(iter, + gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) { return hiprand_uniform2_double(state); }, + transform); + } else { + distribution_nullary_kernel(iter, + gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) { return hiprand_uniform4(state); }, + transform); + } +} + +template +void normal_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) { + if (std::is_same::value) { + distribution_nullary_kernel(iter, + gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) { return hiprand_normal2_double(state); }, + transform); + } else { + distribution_nullary_kernel(iter, + gen, + [] __device__ (hiprandStatePhilox4_32_10_t* state) { return hiprand_normal4(state); }, + transform); + } +} + +// ==================================================== Normal ======================================================== + +template +void normal_kernel(const TensorBase &self, double mean_, double std_, RNG gen) { + auto iter = TensorIterator::borrowing_nullary_op(self); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_kernel_zoom", [&] { + using accscalar_t = at::acc_type; + auto mean = static_cast(mean_); + auto std = static_cast(std_); + // define lambda to multiply std and add mean + auto normal_func = [mean, std] __device__ (accscalar_t rand) { + return static_cast(transformation::normal(rand, mean, std)); + }; + normal_and_transform(iter, gen, 
normal_func); + }); +} + +template +struct NormalKernel { + void operator()(const TensorBase &self, double mean, double std, std::optional gen) { + normal_kernel(self, mean, std, check_generator(gen)); + } +}; + +// ==================================================== Uniform ======================================================== + +template +void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_zoom", [&] { + auto from = static_cast(from_); + auto to = static_cast(to_); + using opmath_t = at::opmath_type; + auto range = static_cast(to-from); + // define lambda to reverse bounds, multiply 'range' and add 'from_' + auto uniform_func = [range, from, to] __device__ (opmath_t rand) { + // Compute output value before reversing the bounds + // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/96947 + auto value = static_cast(rand * range + from); + // reverse the bounds of hiprand4 from (0, 1] to [0, 1) + // Note that this method is from legacy THCTensorRandom and is likely to give + // you more 0-s, since, the probability of gettings 1-s is higher than 0-s and + // by reversing the bounds, we are flipping the probabilities of 1-s and 0-s. + // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706 + auto reverse_bound_value = value == to ? from : value; + return reverse_bound_value; + }; + uniform_and_transform(iter, gen, uniform_func); + }); +} + +template +struct UniformKernel { + void operator()(TensorIteratorBase& iter, double from, double to, std::optional gen) { + uniform_kernel(iter, from, to, check_generator(gen)); + } +}; + +// ================================================== LogNormal ======================================================= + +template +void log_normal_kernel(TensorIteratorBase& iter, double mean_, double std_, RNG gen) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_zoom", [&] { + using accscalar_t = at::acc_type; + auto mean = static_cast(mean_); + auto std = static_cast(std_); + // define lambda for log_normal transformation + auto log_normal_func = [mean, std] __device__ (accscalar_t rand) { + return static_cast(transformation::log_normal(transformation::normal(rand, mean, std))); + }; + normal_and_transform(iter, gen, log_normal_func); + }); +} + +template +struct LogNormalKernel { + void operator()(TensorIteratorBase& iter, double mean, double std, std::optional gen) { + log_normal_kernel(iter, mean, std, check_generator(gen)); + } +}; + +// =================================================== Geometric ====================================================== + +template +void geometric_kernel(TensorIteratorBase& iter, double p, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_zoom", [&] { + using accscalar_t = at::DiscreteDistributionType::type; + // define lambda for geometric transformation + auto geometric_func = [p] __device__ (accscalar_t rand) { + return static_cast(transformation::geometric(rand, p)); + }; + uniform_and_transform(iter, gen, geometric_func); + }); +} + +template +struct GeometricKernel { + void operator()(TensorIteratorBase& iter, double p, std::optional gen) { + geometric_kernel(iter, p, check_generator(gen)); + } +}; + +// ================================================== Exponential 
===================================================== + +template +void exponential_kernel(TensorIteratorBase& iter, double lambda_, RNG gen) { + TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype()); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_zoom", [&] { + using accscalar_t = at::acc_type; + auto lambda = static_cast(lambda_); + // define lambda for exponential transformation + auto exponential_func = [lambda] __device__ (accscalar_t rand) { + return static_cast(transformation::exponential(rand, lambda)); + }; + uniform_and_transform(iter, gen, exponential_func); + }); +} + +template +struct ExponentialKernel { + void operator()(TensorIteratorBase& iter, double lambda, std::optional gen) { + exponential_kernel(iter, lambda, check_generator(gen)); + } +}; + +// ==================================================== Cauchy ======================================================== + +template +void cauchy_kernel(TensorIteratorBase& iter, double median_, double sigma_, RNG gen) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_zoom", [&] { + using accscalar_t = at::acc_type; + auto median = static_cast(median_); + auto sigma = static_cast(sigma_); + // define lambda for cauchy transformation + auto cauchy_func = [median, sigma] __device__ (accscalar_t rand) { + return static_cast(transformation::cauchy(rand, median, sigma)); + }; + uniform_and_transform(iter, gen, cauchy_func); + }); +} + +template +struct CauchyKernel { + void operator()(TensorIteratorBase& iter, double median, double sigma, std::optional gen) { + cauchy_kernel(iter, median, sigma, check_generator(gen)); + } +}; + +// ==================================================== Bernoulli ===================================================== + +template +void bernoulli_tensor_zoom_kernel( + const TensorBase &ret, const at::TensorBase &p, + PhiloxHIPState philox_args) { + auto functor = [philox_args] __device__( + int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4, + const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) { + auto seeds = at::zoom::philox::unpack(philox_args); + hiprandStatePhilox4_32_10_t state; + hiprand_init(std::get<0>(seeds), + blockIdx.x * blockDim.x + threadIdx.x, + std::get<1>(seeds), + &state); + + // See Note [Register spilling in curand call for CUDA < 10] + float4 rand = hiprand_uniform4(&state); + switch (n) { + case 4: { + ZOOM_KERNEL_ASSERT(0 <= p4 && p4 <= 1); + v4 = static_cast(rand.w <= p4); + // fallthrough + } + case 3: { + ZOOM_KERNEL_ASSERT(0 <= p3 && p3 <= 1); + v3 = static_cast(rand.z <= p3); + // fallthrough + } + case 2: { + ZOOM_KERNEL_ASSERT(0 <= p2 && p2 <= 1); + v2 = static_cast(rand.y <= p2); + // fallthrough + } + case 1: { + ZOOM_KERNEL_ASSERT(0 <= p1 && p1 <= 1); + v1 = static_cast(rand.x <= p1); + } + } + }; + // The template argument `4` below indicates that we want to operate on four + // element at each time. See NOTE [ CUDA_tensor_applyN helpers ] for details. 
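The functor above maps the four uniform draws from `hiprand_uniform4` onto four Bernoulli samples by comparing each lane against its probability, and the switch fallthrough lets a partial tail of fewer than four elements be handled by the same code path. Below is a minimal host-side C++ sketch of that per-lane logic, with `std::mt19937` standing in for the device Philox state; `bernoulli4` and its signature are illustrative only and not part of this patch.

```cpp
// Host-side sketch of the 4-way Bernoulli functor above (illustrative only;
// the real kernel consumes hiprand_uniform4 output on-device).
#include <array>
#include <cstdio>
#include <random>

// Emulates one functor invocation for n <= 4 valid lanes.
void bernoulli4(int n, std::array<float, 4>& v, const std::array<float, 4>& p,
                std::mt19937& rng) {
    std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
    std::array<float, 4> rand = {uniform(rng), uniform(rng), uniform(rng), uniform(rng)};
    // Mirrors the switch-with-fallthrough: lane k is only written when n >= k+1.
    switch (n) {
        case 4: v[3] = static_cast<float>(rand[3] <= p[3]); [[fallthrough]];
        case 3: v[2] = static_cast<float>(rand[2] <= p[2]); [[fallthrough]];
        case 2: v[1] = static_cast<float>(rand[1] <= p[1]); [[fallthrough]];
        case 1: v[0] = static_cast<float>(rand[0] <= p[0]);
    }
}

int main() {
    std::mt19937 rng(0);
    std::array<float, 4> v{}, p{0.1f, 0.5f, 0.9f, 1.0f};
    bernoulli4(4, v, p, rng);
    std::printf("%g %g %g %g\n", v[0], v[1], v[2], v[3]);
}
```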
+ at::zoom::Zoom_tensor_apply2(ret, p, functor); +} + +template +void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG gen) { + PhiloxHIPState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_hip_state(10); + } + TORCH_CHECK(at::isFloatingType(p_.scalar_type()), "expected probabilities tensor to have floating type, got ", p_.scalar_type()); + // cast probabilities tensor to double for double `self` tensor, and to `float` for everything else + const auto p_type = self.dtype() == at::kDouble ? at::kDouble : at::kFloat; + auto p_zoom = p_.to(TensorOptions().device(self.device()).dtype(p_type)); + auto p = expand_inplace(self, p_zoom); + AT_DISPATCH_ALL_TYPES_AND3( + at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_zoom_self_", [&] { + if (std::is_same::value) { + return bernoulli_tensor_zoom_kernel(self, *p, rng_engine_inputs); + } else { + return bernoulli_tensor_zoom_kernel(self, *p, rng_engine_inputs); + } + }); +} + +template +void bernoulli_kernel(TensorIteratorBase& iter, double p, RNG gen) { + AT_DISPATCH_ALL_TYPES_AND3( + at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_zoom_", [&] { + using accscalar_t = at::DiscreteDistributionType::type; + // define lambda for bernoulli transformation + auto bernoulli_func = [p] __device__ (accscalar_t rand) { + return static_cast(transformation::bernoulli(rand, p)); + }; + uniform_and_transform(iter, gen, bernoulli_func); + }); +} + +template +struct BernoulliKernel { + void operator()(TensorIteratorBase& iter, double p, std::optional gen) { + bernoulli_kernel(iter, p, check_generator(gen)); + } + void operator()(const TensorBase &self, const TensorBase &p_, std::optional gen) { + bernoulli_kernel(self, p_, check_generator(gen)); + } +}; + +}}}} \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/DistributionUniform.cu b/aten/src/ATen/native/zoom/DistributionUniform.cu new file mode 100644 index 0000000000000..25ed5e7b8b114 --- /dev/null +++ b/aten/src/ATen/native/zoom/DistributionUniform.cu @@ -0,0 +1,15 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include + +namespace at::native { + +void uniform_kernel(TensorIteratorBase& iter, double from, double to, std::optional gen) { + auto generator = get_generator_or_default(gen, zoom::detail::getDefaultZoomGenerator()); + templates::zoom::uniform_kernel(iter, from, to, generator); +} + +REGISTER_PRIVATEUSE1_DISPATCH(uniform_stub, &uniform_kernel); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Distributions.cpp b/aten/src/ATen/native/zoom/Distributions.cpp new file mode 100644 index 0000000000000..077d4d41b6afa --- /dev/null +++ b/aten/src/ATen/native/zoom/Distributions.cpp @@ -0,0 +1,84 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#endif + +namespace at::native { + +Tensor _s_poisson_zoom(const Tensor& lambda, std::optional gen_) { + auto gen = get_generator_or_default(gen_, zoom::detail::getDefaultZoomGenerator()); + Tensor ret = at::empty(lambda.sizes(), lambda.options()); + launch_poisson_zoom_kernel(ret, lambda, gen); + return ret; +} + +Tensor _s_binomial_zoom(const Tensor& count, const Tensor& prob, std::optional gen_) { + auto 
gen = get_generator_or_default(gen_, zoom::detail::getDefaultZoomGenerator()); + Tensor ret = at::empty(count.sizes(), count.options()); + at::TensorIterator iter = at::TensorIteratorConfig() + .add_output(ret) + .add_input(count) + .add_input(prob) + .build(); + launch_binomial_zoom_kernel(iter, gen); + return ret; +} + +Tensor _s_gamma_zoom(const Tensor& alpha, std::optional gen_) { + auto gen = get_generator_or_default(gen_, zoom::detail::getDefaultZoomGenerator()); + Tensor ret = at::empty(alpha.sizes(), alpha.options()); + launch_gamma_kernel(ret, alpha, gen); + return ret; +} + +Tensor _s_dirichlet_zoom(const Tensor& alpha, std::optional gen_) { + auto gen = get_generator_or_default(gen_, zoom::detail::getDefaultZoomGenerator()); + Tensor ret = at::empty(alpha.sizes(), alpha.options()); + launch_gamma_kernel(ret, alpha, gen); + auto gamma_sum = ret.sum(/*dim=*/-1, /*keepdim=*/true); + at::TensorIterator iter = at::TensorIteratorConfig() + .add_output(ret) + .add_input(ret) + .add_input(gamma_sum) + .build(); + launch_dirichlet_kernel(iter); + return ret; +} + +Tensor _standard_gamma_grad_zoom(const Tensor& self, const Tensor& output) { + Tensor ret = at::empty(self.sizes(), self.options()); + TensorIterator iter = at::TensorIteratorConfig() + .add_output(ret) + .add_input(self) + .add_input(output) + .build(); + launch_standard_gamma_grad_kernel(iter); + return ret; +} + +Tensor _dirichlet_grad_zoom(const Tensor& x, const Tensor& alpha, const Tensor& total) { + Tensor ret = at::empty(x.sizes(), x.options()); + TensorIterator iter = at::TensorIteratorConfig() + .add_output(ret) + .add_input(x) + .add_input(alpha) + .add_input(total) + .build(); + launch_dirichlet_grad_kernel(iter); + return ret; +} + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Distributions.h b/aten/src/ATen/native/zoom/Distributions.h new file mode 100644 index 0000000000000..c395fd26a795a --- /dev/null +++ b/aten/src/ATen/native/zoom/Distributions.h @@ -0,0 +1,25 @@ +#pragma once + +namespace at { +struct ZoomGeneratorImpl; +struct TensorIteratorBase; +class TensorBase; + +namespace native { + +void launch_poisson_zoom_kernel( + const TensorBase &ret, const TensorBase &lambda, ZoomGeneratorImpl *gen); + +void launch_gamma_kernel( + const TensorBase &ret, const TensorBase &alpha, ZoomGeneratorImpl *gen); + +void launch_binomial_zoom_kernel( + TensorIteratorBase &iter, ZoomGeneratorImpl *gen); + +void launch_dirichlet_kernel(TensorIteratorBase &iter); + +void launch_standard_gamma_grad_kernel(TensorIteratorBase &iter); + +void launch_dirichlet_grad_kernel(TensorIteratorBase &iter); + +}} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/DistributionsKernels.cu b/aten/src/ATen/native/zoom/DistributionsKernels.cu new file mode 100644 index 0000000000000..14a1ac3f82660 --- /dev/null +++ b/aten/src/ATen/native/zoom/DistributionsKernels.cu @@ -0,0 +1,204 @@ +// #define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include + +#include +#include +//#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +/** + * Note [Register spilling in curand call for CUDA < 10] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * For CUDA < 10, hiprandStatePhilox4_32_10_t engine achieves poor performance (60% SOL bandwidth) + * when called to generate one random number at a time. 
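The `_s_dirichlet_zoom` wrapper above builds a Dirichlet sample in two steps: draw one Gamma(alpha_i, 1) variate per concentration parameter with the gamma kernel, then divide each draw by their sum via `launch_dirichlet_kernel`. The host-side sketch below shows that construction with the standard library; `sample_dirichlet` is an illustrative name, and the device kernel additionally clamps the ratio away from exactly 0 and 1.

```cpp
// Host-side sketch of the Dirichlet construction used above: draw one Gamma(alpha_i, 1)
// per concentration parameter, then normalize by their sum (illustrative only).
#include <cstdio>
#include <random>
#include <vector>

std::vector<double> sample_dirichlet(const std::vector<double>& alpha, std::mt19937& rng) {
    std::vector<double> draw(alpha.size());
    double sum = 0.0;
    for (size_t i = 0; i < alpha.size(); ++i) {
        std::gamma_distribution<double> gamma(alpha[i], 1.0);
        draw[i] = gamma(rng);
        sum += draw[i];
    }
    for (double& x : draw) x /= sum;  // the device kernel also clamps to (min, 1 - eps)
    return draw;
}

int main() {
    std::mt19937 rng(0);
    auto d = sample_dirichlet({1.0, 2.0, 3.0}, rng);
    std::printf("%f %f %f\n", d[0], d[1], d[2]);  // components sum to 1
}
```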
This is because the line + * unsigned ret = (&state->output.x)[state->STATE++]; + * in + * QUALIFIERS unsigned int curand(hiprandStatePhilox4_32_10_t *state) + * in curand_kernel.h dynamically indexes into state.output, preventing the compiler from ever + * storing state.output in registers. + * + * CUDA 10 fixed this problem. However, for backwards compatibility, in the following kernels + * we are using curand distributions that utilize hiprand4 call. hiprand4 call doesn't have the + * register spilling problem. + */ + +namespace { + +template +void poisson_zoom_kernel( + const at::TensorBase &ret, + const at::TensorBase &lambda, + at::PhiloxHIPState philox_args) { + auto functor = [philox_args] __device__( + scalar_t & ret_val, const scalar_t& lambda) { + ZOOM_KERNEL_ASSERT(lambda >= 0 && "invalid Poisson rate, expected rate to be non-negative"); + auto seeds = at::zoom::philox::unpack(philox_args); + hiprandStatePhilox4_32_10_t state; + hiprand_init(std::get<0>(seeds), + blockIdx.x * blockDim.x + threadIdx.x, + std::get<1>(seeds), + &state); + ret_val = static_cast(hiprand_poisson(&state, lambda)); + }; + at::zoom::Zoom_tensor_apply2(ret, lambda, functor); +} + +struct hiprand_uniform_wrapper { + hiprandStatePhilox4_32_10_t &state; + __device__ hiprand_uniform_wrapper(hiprandStatePhilox4_32_10_t &state): state(state) {} + __device__ float operator()() { + + uint32_t val = hiprand(&state); //need just bits + constexpr auto MASK = static_cast((static_cast(1) << std::numeric_limits::digits) - 1); + constexpr auto DIVISOR = static_cast(1) / (static_cast(1) << std::numeric_limits::digits); + return (val & MASK) * DIVISOR; + } +}; + +template +void binomial_zoom_kernel( + at::TensorIteratorBase &iter, + at::PhiloxHIPState philox_args) { + using accscalar_t = at::acc_type; + + at::native::distribution_binary_kernel(iter, philox_args, + [] GPU_LAMBDA (hiprandStatePhilox4_32_10_t& state, scalar_t count, scalar_t prob) { + auto uniform_lambda = hiprand_uniform_wrapper(state); + BaseSampler standard_uniform(uniform_lambda); + auto sample = sample_binomial(count, prob, standard_uniform); + return static_cast(sample); + } + ); +} + +template +void gamma_zoom_kernel( + const at::TensorBase &ret, + const at::TensorBase &alpha, + at::PhiloxHIPState philox_args) { + using accscalar_t = at::acc_type; + auto functor = [philox_args] __device__( + scalar_t & ret_val, const scalar_t& alpha) { + auto seeds = at::zoom::philox::unpack(philox_args); + hiprandStatePhilox4_32_10_t state; + hiprand_init(std::get<0>(seeds), + blockIdx.x * blockDim.x + threadIdx.x, + std::get<1>(seeds), + &state); + + auto uniform_lambda = [&state] __device__ () { + return hiprand_uniform(&state); + }; + BaseSampler standard_uniform(uniform_lambda); + + auto normal_lambda = [&state] __device__ () { + return hiprand_normal(&state); + }; + BaseSampler standard_normal(normal_lambda); + auto sample = sample_gamma(alpha, standard_uniform, standard_normal); + auto min_value = std::numeric_limits::min(); + ret_val = (min_value > sample) ? 
min_value : sample; + }; + at::zoom::Zoom_tensor_apply2(ret, alpha, functor); +} + +} // namespace + +namespace at::native { + +void launch_dirichlet_kernel(at::TensorIteratorBase &iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, + iter.input_dtype(), "dirichlet_zoom", [&] { + at::native::gpu_kernel( + iter, + [] GPU_LAMBDA (scalar_t gamma, scalar_t gamma_sum) { + auto ret_val = gamma / gamma_sum; + auto min_value = std::numeric_limits::min(); + auto max_value = 1 - std::numeric_limits::epsilon(); + ret_val = (min_value > ret_val) ? min_value : ret_val; + ret_val = (max_value < ret_val) ? max_value : ret_val; + return ret_val; + }); + }); +} + +void launch_poisson_zoom_kernel( + const TensorBase &ret, const TensorBase &lambda, ZoomGeneratorImpl *gen) { + PhiloxHIPState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_hip_state(20); + } + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "poisson_zoom", [&] { + poisson_zoom_kernel(ret, lambda, rng_engine_inputs); + }); +} + +void launch_binomial_zoom_kernel( + TensorIteratorBase &iter, ZoomGeneratorImpl *gen) { + PhiloxHIPState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_hip_state(42); + } + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "binomial_zoom", [&] { + binomial_zoom_kernel(iter, rng_engine_inputs); + }); +} + +void launch_gamma_kernel( + const TensorBase &ret, const TensorBase &alpha, ZoomGeneratorImpl *gen) { + PhiloxHIPState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_hip_state(10); + } + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "gamma_zoom", [&] { + gamma_zoom_kernel(ret, alpha, rng_engine_inputs); + }); +} + +void launch_standard_gamma_grad_kernel(TensorIteratorBase &iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "_standard_gamma_grad_zoom", [&] { + using accscalar_t = at::acc_type; + gpu_kernel(iter, + [] GPU_LAMBDA (scalar_t self_val, scalar_t output_val) { + return standard_gamma_grad_one(self_val, output_val); + }); + }); +} + +void launch_dirichlet_grad_kernel(TensorIteratorBase &iter) { + AT_DISPATCH_FLOATING_TYPES(iter.input_dtype(), "_dirichlet_grad_zoom", [&] { + using accscalar_t = at::acc_type; + at::native::gpu_kernel(iter, + [] GPU_LAMBDA (scalar_t x_val, scalar_t alpha_val, scalar_t total_val) -> scalar_t { + return dirichlet_grad_one(x_val, alpha_val, total_val); + }); + }); +} + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/Dropout.cu b/aten/src/ATen/native/zoom/Dropout.cu new file mode 100644 index 0000000000000..a4dc5cf223d4f --- /dev/null +++ b/aten/src/ATen/native/zoom/Dropout.cu @@ -0,0 +1,412 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#endif + +namespace at::native { + +namespace { + +// philox generates 128 bits of randomness at a time. 
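Before the kernels below, it may help to see the inverted-dropout arithmetic they implement: each element is kept with probability p (the keep probability; `native_dropout_zoom` further down passes `1 - p`), and kept elements are scaled by `1/p` so the expected value is unchanged. The following is a small host-side sketch under those assumptions; `inverted_dropout` is an illustrative name and not part of the patch.

```cpp
// Minimal host-side sketch of the inverted-dropout math used by these kernels
// (illustrative only; the device code does this 4 elements at a time from float4 draws).
#include <cstdio>
#include <random>
#include <vector>

void inverted_dropout(std::vector<float>& x, std::vector<unsigned char>& mask,
                      float p /* keep probability */, std::mt19937& rng) {
    std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
    const float scale = 1.0f / p;
    mask.resize(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        mask[i] = uniform(rng) < p;     // same comparison as `rand.x < p` in the kernel
        x[i] = x[i] * mask[i] * scale;  // dropped lanes become 0, kept lanes are rescaled
    }
}

int main() {
    std::mt19937 rng(0);
    std::vector<float> x(8, 1.0f);
    std::vector<unsigned char> mask;
    inverted_dropout(x, mask, 0.75f, rng);
    for (float v : x) std::printf("%g ", v);
    std::printf("\n");
}
```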
Kernel uses this explicitly by putting suitably transformed result into float4 +// for all members of float4 to be consumed UNROLL has to be 4. Don't change! +// Note: VEC <= 4 (and in most real-world cases will be 4), so same logic applies. +const int UNROLL = 4; + +template < + typename scalar_t, + typename accscalar_t, + typename IndexType, + int ADims, + int VEC, + typename mask_t> +C10_LAUNCH_BOUNDS_2(256, 4) +__global__ void +fused_dropout_kernel_vec(at::zoom::detail::TensorInfo a, + at::zoom::detail::TensorInfo b, + at::zoom::detail::TensorInfo c, + IndexType totalElements, accscalar_t p, + PhiloxHIPState philox_args) { + // make sure we don't break assumption that we can't have > 4 elements / thread + static_assert(VEC <= 4, "Value of VEC must be in [2, 4]"); + + using LoadT = memory::aligned_vector; + using MaskLoadT = memory::aligned_vector; + + auto seeds = at::zoom::philox::unpack(philox_args); + IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; + hiprandStatePhilox4_32_10_t state; + hiprand_init(std::get<0>(seeds), + idx, + std::get<1>(seeds), + &state); + + // Helps align the total number of times hiprand_uniform4 is called by each thread for the same totalElements + // in the vec=2 and vec=4 cases. + bool gridxvec_loop_state = 0; + accscalar_t scale = 1.0 / p; + + float4 rand; + + // Note: Vectorized loads means we'll stride each thread by an additional VEC factor, as we'll load VEC elements at a time + for (IndexType linearIndex = idx * VEC; + linearIndex < totalElements; + linearIndex += gridDim.x * blockDim.x * VEC) { + // local storage + scalar_t src[VEC]; + // We'll use this to actually cause vectorized loads later + LoadT *value = reinterpret_cast(&src); + + //curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything + // Note: need a new set of random values per 4 elements -- we'll handle VEC elements in this thread, so need ceil(VEC / 4) + // sets of rand. + if ((VEC == 4) || (gridxvec_loop_state == 0)) { + rand = hiprand_uniform4(&state); + } else { + // sets up the last two values we generated last iteration to be used this iteration. + rand.x = rand.z; + rand.y = rand.w; + gridxvec_loop_state ^= 1; + } + + rand.x = rand.x < p; + rand.y = rand.y < p; + if (VEC == 4) { + rand.z = rand.z < p; + rand.w = rand.w < p; + } + + // Note: We explicitly check for is_contiguous() before launching the vectorized kernel + // and replace IndexToOffset call with linearIndex to allow vectorization of NHWC (or other) + // ordering. 
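The branch above reuses `rand.z`/`rand.w` on alternating iterations when `VEC == 2`, so a single `hiprand_uniform4` call still supplies one uniform per processed element. The sketch below emulates that bookkeeping on the host in simplified form, with `reuse_tail` playing the role of `gridxvec_loop_state`; it is illustrative only.

```cpp
// Simplified host-side emulation of the random-number reuse above: one 4-wide draw feeds
// either a single VEC=4 iteration or two consecutive VEC=2 iterations (illustrative only).
#include <array>
#include <cstdio>
#include <random>

int main() {
    std::mt19937 rng(0);
    std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
    const int VEC = 2;            // try 2 or 4
    std::array<float, 4> rand{};  // stands in for the float4 from hiprand_uniform4
    bool reuse_tail = false;      // stands in for gridxvec_loop_state
    for (int iter = 0; iter < 4; ++iter) {
        if (VEC == 4 || !reuse_tail) {
            for (float& r : rand) r = uniform(rng);  // fresh draw of four uniforms
        } else {
            rand[0] = rand[2];                       // consume the saved second half
            rand[1] = rand[3];
        }
        if (VEC == 2) reuse_tail = !reuse_tail;
        for (int lane = 0; lane < VEC; ++lane)
            std::printf("iter %d lane %d uses %.3f\n", iter, lane, rand[lane]);
    }
}
```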
+ // Single vectorized load + *value = *reinterpret_cast(&a.data[linearIndex]); + + scalar_t r[VEC]; + mask_t mask[VEC]; + + // Perform the actual computation + #pragma unroll + for (int ii = 0; ii < VEC; ii++) { + r[ii] = src[ii]*(&rand.x)[ii]*scale; + mask[ii] = (mask_t)(&rand.x)[ii]; + } + // Vectorized writes for both mask & result + *(reinterpret_cast(&b.data[linearIndex])) = *reinterpret_cast(&r[0]); + *(reinterpret_cast(&c.data[linearIndex])) = *reinterpret_cast(&mask[0]); + + __syncthreads(); + } +} + +template < + typename scalar_t, + typename accscalar_t, + typename IndexType, + int ADims, + int BDims = ADims, + typename mask_t> +C10_LAUNCH_BOUNDS_2(256, 4) +__global__ void +fused_dropout_kernel(zoom::detail::TensorInfo a, + zoom::detail::TensorInfo b, + zoom::detail::TensorInfo c, + IndexType totalElements, accscalar_t p, + PhiloxHIPState philox_args) { + auto seeds = at::zoom::philox::unpack(philox_args); + IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; + hiprandStatePhilox4_32_10_t state; + hiprand_init(std::get<0>(seeds), + idx, + std::get<1>(seeds), + &state); + accscalar_t scale = 1.0 / p; + + IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * + blockDim.x * gridDim.x * UNROLL; + for (IndexType linearIndex = idx; + linearIndex < rounded_size; + linearIndex += gridDim.x * blockDim.x*UNROLL) { +//curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything + float4 rand = hiprand_uniform4(&state); + scalar_t src[UNROLL]; + rand.x = rand.x < p; + rand.y = rand.y < p; + rand.z = rand.z < p; + rand.w = rand.w < p; + for (int ii = 0; ii < UNROLL; ii++) { + IndexType li = linearIndex + blockDim.x * gridDim.x * ii; + if (li < totalElements) { + // Convert `linearIndex` into an offset of `a` + const IndexType aOffset = + zoom::detail::IndexToOffset::get(li, a); + src[ii] = a.data[aOffset]; + } + } + for (int ii = 0; ii < UNROLL; ii++) { + IndexType li = linearIndex + blockDim.x * gridDim.x * ii; + if (li < totalElements) { + // Convert `linearIndex` into an offset of `b` + const IndexType bOffset = + zoom::detail::IndexToOffset::get(li, b); + b.data[bOffset] = src[ii]*(&rand.x)[ii]*scale; + c.data[bOffset] = (mask_t)(&rand.x)[ii]; + } + } + __syncthreads(); + } +} + +template +void masked_scale_kernel(at::Tensor& ret, const at::Tensor& src, const at::Tensor& mask, accscalar_t scale){ + auto iter = at::TensorIteratorConfig() + .check_all_same_dtype(false) + .add_output(ret) + .add_const_input(src) + .add_const_input(mask) + .build(); + + at::native::gpu_kernel( + iter, + [=]GPU_LAMBDA(const scalar_t src_val, const mask_t mask_val) -> scalar_t { + return (float)mask_val * src_val * scale; + }); +} + +template +int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) { + int vec_size = 4; + // get the vector size + if (!self.is_non_overlapping_and_dense() || !ret.is_non_overlapping_and_dense() || !mask.is_non_overlapping_and_dense()) { + vec_size = 1; + } else { + vec_size = memory::can_vectorize_up_to((const char*)self.const_data_ptr()); + } + + // check that we'd have no remainders - prefer a smaller vector size with no remainders over a larger vector and remainder. + bool can_vectorize = true; + do { + can_vectorize = self.numel() % vec_size == 0 && ret.numel() % vec_size == 0 && mask.numel() % vec_size == 0; + if (!can_vectorize) vec_size /= 2; + } while (vec_size > 1 && !can_vectorize); + return can_vectorize ? 
vec_size : 1; +} + +template +inline void launcher( + const Tensor& self, + Tensor& ret, + Tensor& mask, + double p, + const int64_t nelem, + const PhiloxHIPState rng_engine_inputs, + dim3 grid, + dim3 dim_block) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + self.scalar_type(), + "fused_dropout", + [&] { + using accscalar_t = acc_type; + accscalar_t pa = (accscalar_t)(p); + auto self_info = + zoom::detail::getTensorInfo(self); + auto ret_info = + zoom::detail::getTensorInfo(ret); + auto mask_info = + zoom::detail::getTensorInfo(mask); + self_info.collapseDims(); + ret_info.collapseDims(); + mask_info.collapseDims(); // ret and mask are collapsed to 1d + // contiguous tensor + + int vec_size = get_vector_size(self, ret, mask); + + if (vec_size > 1) { + switch (vec_size) { + case 4: + fused_dropout_kernel_vec< + scalar_t, + accscalar_t, + index_type, + 1, + 4> + <<>>( + self_info, + ret_info, + mask_info, + nelem, + pa, + rng_engine_inputs); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + case 2: + fused_dropout_kernel_vec< + scalar_t, + accscalar_t, + index_type, + 1, + 2> + <<>>( + self_info, + ret_info, + mask_info, + nelem, + pa, + rng_engine_inputs); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + } + } else { + switch (self_info.dims) { + case 1: + fused_dropout_kernel + <<>>( + self_info, + ret_info, + mask_info, + nelem, + pa, + rng_engine_inputs); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + break; + default: + if (!self.is_contiguous() && ret.is_contiguous() && + mask.is_contiguous()) { + fused_dropout_kernel + <<>>( + self_info, + ret_info, + mask_info, + nelem, + pa, + rng_engine_inputs); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else { + fused_dropout_kernel + <<>>( + self_info, + ret_info, + mask_info, + nelem, + pa, + rng_engine_inputs); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + } + } + }); +} + +} //anonymous namespace + +template +std::tuple +dropout_zoom(ZoomGeneratorImpl* gen, const Tensor& self, double p){ + Tensor mask = at::empty_like(self, self.options().dtype(c10::CppTypeToScalarType::value)); + const int64_t nelem = self.numel(); + // empty tensors should not get here, but just in case, avoid FPE + // non-training shot-cut + if (nelem==0) return std::tuple(self.clone(), mask); + + Tensor ret = at::empty_like(self); + const int64_t block_size = 256; + unsigned int blocks_per_sm = at::zoom::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size; + dim3 dim_block(block_size); + dim3 grid((nelem + block_size -1)/block_size); + grid.x = std::min((unsigned int)at::zoom::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x); +//number of times random will be generated per thread, to offset philox counter in thc random state + int64_t counter_offset = ((nelem - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL; + PhiloxHIPState rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_hip_state(counter_offset); + } + if (zoom::detail::canUse32BitIndexMath(self)){ + launcher( + self, ret, mask, p, nelem, rng_engine_inputs, grid, dim_block); + } else { + launcher( + self, ret, mask, p, nelem, rng_engine_inputs, grid, dim_block); + } + return std::tuple(ret, mask); +} + +std::tuple +native_dropout_zoom(const Tensor& self, double p, std::optional train){ + // short-cut for train == false + if (train.has_value() && !train.value()) { + return std::make_tuple(self.clone(), at::ones_like(self, 
self.options().dtype(c10::CppTypeToScalarType::value))); + } + // short-cut + if (p == 1) { + // native_dropout_cuda is in derivatives.yaml, so we don't need to add data + // dependency from output to input for autograd + auto ret = at::zeros_like(self); + auto mask = at::zeros_like(self, self.options().dtype(c10::CppTypeToScalarType::value)); + return std::tuple(ret, mask); + } + + auto gen = get_generator_or_default(c10::nullopt, zoom::detail::getDefaultZoomGenerator()); + double p1m = 1. - p; + return dropout_zoom(gen, self, p1m); +} + +// TODO: _fused_dropout_cuda is to be removed, see PR #63937 +std::tuple +fused_dropout_zoom(const Tensor& self, double p, std::optional gen_){ + auto gen = get_generator_or_default(gen_, zoom::detail::getDefaultZoomGenerator()); + return dropout_zoom(gen, self, p); +} + +template +Tensor dropout_backward_zoom(const Tensor& grad, const Tensor& mask, double scale){ + Tensor ret = at::empty_like(grad, grad.suggest_memory_format()); + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "masked_scale", [&] { + using accscalar_t = acc_type; + masked_scale_kernel(ret, grad, mask, (accscalar_t)scale); + }); + return ret; +} + +Tensor native_dropout_backward_zoom(const Tensor& grad, const Tensor& mask, double scale){ + TORCH_CHECK(mask.scalar_type() == at::ScalarType::Bool, "Mask should be Bool Scalar Type", mask.scalar_type()); + return dropout_backward_zoom(grad, mask, scale); +} + +// TODO: masked_scale_cuda is to be removed, see PR #63937 +Tensor masked_scale_zoom(const Tensor& self, const Tensor& mask, double scale){ + TORCH_CHECK(mask.scalar_type() == at::ScalarType::Byte, "mask should be torch.uint8 dtype"); + return dropout_backward_zoom(self, mask, scale); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/Embedding.cu b/aten/src/ATen/native/zoom/Embedding.cu new file mode 100644 index 0000000000000..413f1bdfe7035 --- /dev/null +++ b/aten/src/ATen/native/zoom/Embedding.cu @@ -0,0 +1,383 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +#if CUB_SUPPORTS_SCAN_BY_KEY() +#include +#endif + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#endif + +namespace at::native { + +namespace { + +static const int BLOCKDIMY = 16; + +template + +__global__ void embedding_backward_feature_kernel + (const index_t* indices, + const scalar_t* __restrict__ grad, + scalar_t* __restrict__ grad_weight, + int n, // OK to pass as int, we don't expect 2 billion+ samples in one shot + int64_t stride, + int padding_idx) +{ + extern __shared__ char buf[]; + accscalar_t* smem = (accscalar_t*)buf; + accscalar_t* my_s = smem + C10_WARP_SIZE*threadIdx.y; + int* indices_batch = (int*)(buf + sizeof(accscalar_t)*C10_WARP_SIZE*blockDim.y); + + const int s = (int)stride; // OK to make int, we don't expect 2 billion+ embedding row size + + const int f = threadIdx.x + blockIdx.x*blockDim.x; // feature_dim + + for(int batch_start = 0; batch_start < n; batch_start += blockDim.x*blockDim.y) + { + // Entire block cooperates to load a batch of 1024 indices to process + int tid = threadIdx.x + threadIdx.y*blockDim.x; + if(batch_start + tid < n) + indices_batch[tid] = (int)indices[batch_start + tid]; + + int batch_end = batch_start + blockDim.x*blockDim.y < n ? 
+ batch_start + blockDim.x*blockDim.y : n; + + // Loop over the batch of <= 1024 loaded indices in chunks of blockDim.y = 32 + for(int chunk_start = batch_start; chunk_start < batch_end; chunk_start += blockDim.y) + { + // This does double duty: it makes sure indices_batch is ready, and it makes sure match-group + // leaders are done with their accumulates before other warps start loading again. + __syncthreads(); + + int n_this_chunk = (batch_end - chunk_start) < blockDim.y ? + (batch_end - chunk_start) : blockDim.y; + + int src_row = chunk_start + threadIdx.y; + int dst_row = indices_batch[src_row - batch_start]; // This warp's target row in grad_weight + + // All warps load their smem segments with incoming grad data + if(src_row < n && f < s && dst_row != padding_idx) + my_s[threadIdx.x] = static_cast(grad[src_row*stride + f]); + + __syncthreads(); + + // To ensure determinism, we can't just have each warp add its grad data to its dst_row. + // We need to check if any other warps pulled grad data targeting dst_row. + // If so, we elect the first warp in each matching group as the leader. + // Each leader warp serializes the accumulates targeting dst_row in shared memory, + // then finishes by adding the accumulated buffer to dst_row in grad_weight. + if(dst_row != padding_idx && src_row < n) // Per-warp exit condition, safe with ballot_sync + { + int match_found_this_thread = 0; + if(threadIdx.x < n_this_chunk) + match_found_this_thread = (dst_row == indices_batch[chunk_start - batch_start + threadIdx.x]); + + unsigned long long int matchmask = WARP_BALLOT(match_found_this_thread); + int first_remaining_peer = __ffsll(matchmask) - 1; + + if(threadIdx.y == first_remaining_peer) // Nominate lowest-indexed warp as the leader + { + matchmask ^= (1 << first_remaining_peer); + while(matchmask) + { + first_remaining_peer = __ffsll(matchmask) - 1; + my_s[threadIdx.x] += smem[threadIdx.x + C10_WARP_SIZE*first_remaining_peer]; + matchmask ^= (1 << first_remaining_peer); + } + if(f < s) + grad_weight[dst_row*stride + f] += static_cast(my_s[threadIdx.x]); + } + } + } + } +} + + +template +__global__ void embedding_backward_kernel( + index_t* input, index_t* indices, scalar_t* grad_output, scalar_t* grad_weight, + index_t* count, int64_t numel, int64_t stride, int padding_idx) { + + using accscalar_t = acc_type; + int idx = blockIdx.x * 4 + threadIdx.y; + + // Each warp is responsible for an input into the LookupTable. + // If the preceding input has the same as this input, then the warp + // exits immediately. The warp also processes subsequent inputs with the + // same value. + // + // Input Warp + // 1 + // 1 ( exits without doing any work) + // 5 + // 8 + + // Number of values processed by each thread (grain size) + const int SZ = 4; + + if (idx < numel + && (idx == 0 || input[idx] != input[idx - 1]) + && input[idx] != padding_idx) { + do { + const int start_feature = threadIdx.x + blockIdx.y * blockDim.x * SZ; + const int weight_row = ((int) input[idx]) * stride; + const int grad_row = ((int) indices[idx]) * stride; + const accscalar_t scale = count ? 
(accscalar_t)1.0 / count[idx] : 1.0; + + accscalar_t gradient[SZ]; + accscalar_t weight[SZ]; + + #pragma unroll + for (int ii = 0; ii < SZ; ii++) { + int feature_dim = start_feature + ii * C10_WARP_SIZE; + if (feature_dim < stride) { + gradient[ii] = static_cast(grad_output[grad_row + feature_dim]); + weight[ii] = static_cast(grad_weight[weight_row + feature_dim]); + } + } + + #pragma unroll + for (int ii = 0; ii < SZ; ii++) { + weight[ii] += gradient[ii] * scale; + } + + #pragma unroll + for (int ii = 0; ii < SZ; ii++) { + int feature_dim = start_feature + ii * C10_WARP_SIZE; + if (feature_dim < stride) { + grad_weight[weight_row + feature_dim] = static_cast(weight[ii]); + } + } + + idx++; + } while (idx < numel && input[idx] == input[idx - 1]); + } +} + +/* Calculate norms of the rows of weight_ptr given by idx_ptr and capture them in norms */ +template +__global__ void renorm_kernel( + scalar_t* weights, index_t* indices, accscalar_t max_norm, + accscalar_t norm_type, int64_t dim, + int64_t weights_stride0, int64_t weights_stride1, + const int64_t *num_unique_indices) { + if (blockIdx.x >= *num_unique_indices) { + return; + } + + // Some casting hacks since dynamic shared memory and templates don't work together: + extern __shared__ unsigned char smem[]; + auto sdata = reinterpret_cast(smem); + + int tid = threadIdx.x; + int base_index = indices[blockIdx.x] * weights_stride0; + + accscalar_t v = 0; + for (int i = tid; i < dim; i += blockDim.x) { + auto x = static_cast(weights[base_index + i * weights_stride1]); + if (norm_type == 1) { + v += std::abs(x); + } else if (norm_type == 2) { + v += x * x; + } else { + v += std::pow(x, norm_type); + } + } + + v = zoom_utils::BlockReduceSum(v, sdata); + + if (tid == 0) { + sdata[0] = std::pow(v, static_cast(1.0 / norm_type)); + } + __syncthreads(); + + // now we renormalize the blocks that need it + if (sdata[0] > max_norm) { + auto factor = static_cast(max_norm / (sdata[0] + 1e-7)); + for (int i = tid; i < dim; i += blockDim.x) { + weights[base_index + i * weights_stride1] *= factor; + } + } +} + +} // anonymous namespace + +#if !CUB_SUPPORTS_SCAN_BY_KEY() +template +void embedding_dense_backward_zoom_scan(Tensor &sorted_indices, Tensor &count); +#endif + +Tensor embedding_dense_backward_zoom(const Tensor & grad_, const Tensor & indices_, + int64_t num_weights, int64_t padding_idx, + bool scale_grad_by_freq) { + auto grad_arg = TensorArg(grad_, "grad", 1); + auto indices_arg = TensorArg(indices_, "indices", 1); + checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt}); + checkSameGPU("embedding_backward", grad_arg, indices_arg); + + auto indices = indices_.contiguous(); + + auto num_indices = indices.numel(); + auto grad = grad_.contiguous().view({num_indices, grad_.size(-1)}); + hipStream_t stream = c10::zoom::getCurrentZoomStream(); + + if (num_indices <= 3072 && !scale_grad_by_freq) { + auto indices_contig = indices.contiguous(); + auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options()); + int64_t stride = grad_weight.stride(0); + int warp_size = at::zoom::warp_size(); + dim3 grid(ceil_div(stride, (int64_t)warp_size)); + dim3 block(warp_size, BLOCKDIMY); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, + grad.scalar_type(), + "embedding_backward", + [&] + { + using accscalar_t = acc_type; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_zoom", [&] () { + embedding_backward_feature_kernel + <<>> + (indices_contig.const_data_ptr(), + 
grad.const_data_ptr(), + grad_weight.mutable_data_ptr(), + static_cast(num_indices), + static_cast(stride), + static_cast(padding_idx)); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); + }); + return grad_weight; + } + + auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + Tensor count; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_zoom", [&] () { + auto range = at::arange(num_indices, indices.options()); + int64_t nbits = zoom::hipcub::get_num_bits(num_weights); + zoom::hipcub::radix_sort_pairs( + indices.const_data_ptr(), sorted_indices.mutable_data_ptr(), + range.const_data_ptr(), orig_indices.mutable_data_ptr(), + num_indices, false/*, 0, nbits*/); + }); + + if (scale_grad_by_freq) { + count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); +#if CUB_SUPPORTS_SCAN_BY_KEY() + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_zoom", [&] () { + hipStream_t stream = c10::zoom::getCurrentZoomStream(); + + // Compute an increasing sequence per unique item in sortedIndices: + // sorted: 2 5 5 5 7 7 8 9 9 + // count: 1 1 2 3 1 2 1 1 2 + auto sorted_data = sorted_indices.const_data_ptr(); + auto count_data = count.mutable_data_ptr(); + zoom::hipcub::inclusive_sum_by_key( + sorted_data, + at_zoom_detail::hipcub::ConstantInputIterator(1), + count_data, + num_indices + ); + + // Take the maximum of each count per unique key in reverse: + // sorted: 2 5 5 5 7 7 8 9 9 + // count: 1 3 3 3 2 2 1 2 2 + zoom::hipcub::inclusive_scan_by_key( + thrust::make_reverse_iterator(sorted_data + num_indices), + thrust::make_reverse_iterator(static_cast(count_data) + num_indices), + thrust::make_reverse_iterator(count_data + num_indices), + at_zoom_detail::hipcub::Max(), + num_indices + ); + }); +#else + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_zoom", [&] () { + embedding_dense_backward_zoom_scan(sorted_indices, count); + }); +#endif + } + + return embedding_backward_zoom_kernel(grad, orig_indices, + sorted_indices, count, num_weights, padding_idx); +} + +Tensor & embedding_renorm_zoom_(Tensor & self, const Tensor & indices, + double max_norm, double norm_type) { + auto self_arg = TensorArg(self, "self", 1); + auto indices_arg = TensorArg(indices, "indices", 1); + checkDim("embedding_renorm_", self_arg, 2); + checkSameGPU("embedding_renorm", self_arg, indices_arg); + + hipStream_t stream = c10::zoom::getCurrentZoomStream(); + + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_zoom_", [&] () { + + auto num_indices = indices.numel(); + auto indices_contig = std::get<0>(indices.sort()).contiguous(); + auto unique_indices = at::empty(indices.numel(), indices.options()); + auto num_unique_indices = at::empty({}, indices.options().dtype(kLong)); + + zoom::hipcub::unique( + indices_contig.const_data_ptr(), + unique_indices.mutable_data_ptr(), + num_unique_indices.mutable_data_ptr(), + num_indices + ); + + int warp_size = at::zoom::warp_size(); + TORCH_INTERNAL_ASSERT(num_threads() % warp_size == 0 && + num_threads() <= zoom_utils::kHIPBlockReduceMaxThreads, + "BlockReduceSum requires all warps be active"); + const int64_t *num_unique_indices_ptr = num_unique_indices.const_data_ptr(); + dim3 grid = unique_indices.numel(); + dim3 block = num_threads(); + int dim = self.stride(0); + + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_renorm_zoom_", [&] { + 
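`renorm_kernel`, launched just below, reduces each selected embedding row to its p-norm and rescales the row in place when that norm exceeds `max_norm`, using the same `1e-7` guard shown in the kernel. A host-side sketch of that per-row computation follows; `renorm_row` is an illustrative stand-in that handles one row at a time instead of one thread block per row.

```cpp
// Host-side sketch of the row renormalization performed by renorm_kernel (illustrative only).
#include <cmath>
#include <cstdio>
#include <vector>

void renorm_row(std::vector<double>& row, double max_norm, double norm_type) {
    double v = 0.0;
    for (double x : row) {                 // mirrors the branch structure in the kernel
        if (norm_type == 1.0)      v += std::abs(x);
        else if (norm_type == 2.0) v += x * x;
        else                       v += std::pow(x, norm_type);
    }
    const double norm = std::pow(v, 1.0 / norm_type);
    if (norm > max_norm) {
        const double factor = max_norm / (norm + 1e-7);
        for (double& x : row) x *= factor;
    }
}

int main() {
    std::vector<double> row{3.0, 4.0};       // L2 norm is 5
    renorm_row(row, /*max_norm=*/1.0, /*norm_type=*/2.0);
    std::printf("%f %f\n", row[0], row[1]);  // roughly 0.6 0.8
}
```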
using accscalar_t = acc_type; + renorm_kernel<<>>( + self.mutable_data_ptr(), + unique_indices.const_data_ptr(), + static_cast(max_norm), + static_cast(norm_type), + dim, self.stride(0), self.stride(1), + num_unique_indices_ptr); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); + }); + return self; +} + + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/EmbeddingBackwardKernel.cu b/aten/src/ATen/native/zoom/EmbeddingBackwardKernel.cu new file mode 100644 index 0000000000000..8187318bd6156 --- /dev/null +++ b/aten/src/ATen/native/zoom/EmbeddingBackwardKernel.cu @@ -0,0 +1,365 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#if CUB_SUPPORTS_UNIQUE_BY_KEY() +#include +#endif + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + +namespace at::native { + +namespace { + +/* This code computes the sum of the weights in two-steps: + 1) Each GPU warp sums `NROWS_PER_THREAD` number of row given by `indeces` + 2) Each partial-sum from 1) are summed and scatter into `grad_weight` + + Notice, `NROWS_PER_THREAD` impacts the Achieved Occupancy of the + kernel execution. If it is high, the size of the thread blocks will be + too small to achieve good occupancy. Similarly, a very low value will + make the size of the thread blocks in the final sum in step 2) too small. +*/ +constexpr int NROWS_PER_THREAD = 10; + +// Fast ceil division (no overflow checking) +__host__ __device__ __forceinline__ +int64_t ceil_div(int64_t x, int64_t y) { + return (x + y - 1) / y; +} + +template +__global__ +void krn_partials_per_segment(index_t *ret, const index_t *segment_offsets, + const int64_t *num_of_segments_ptr, int64_t numel) { + int64_t num_of_segments = *num_of_segments_ptr; + const int id = blockIdx.x * blockDim.x + threadIdx.x; + if(id < num_of_segments) { + const int64_t idx_start = segment_offsets[id]; + const int64_t idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1]; + const int64_t size = idx_end - idx_start; + ret[id] = ceil_div(size, NROWS_PER_THREAD); + } +} + +template +__global__ +void krn_partial_segment_offset( + index_t *ret, + const index_t *partials_per_segment, + const index_t *partials_per_segment_offset, + const index_t *segment_offsets, + const int64_t *num_of_segments_ptr) { + int64_t num_of_segments = *num_of_segments_ptr; + const int id = blockIdx.x * blockDim.x + threadIdx.x; + if(id < num_of_segments) { + index_t idx = partials_per_segment_offset[id]; + const index_t num_partials = partials_per_segment[id]; + const index_t segment_offset = segment_offsets[id]; + for (int64_t i=0; i +__global__ void compute_grad_weight_bags( + const index_t *indices, const scalar_t *gradOutput, + const index_t *offset2bag, const index_t *count, ptrdiff_t numel, + int64_t stride, int mode_mean, const index_t *bag_size, + const scalar_t* per_sample_weights, int64_t per_sample_weights_stride, + const index_t* segment_offsets, const int64_t *num_of_segments_ptr, + acc_type *grad_weight_per_segment, + const int64_t stride_warped) { + + int64_t num_of_segments = *num_of_segments_ptr; + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + const int id = gid / stride_warped; + const int startFeature = gid % stride_warped; + if (startFeature >= stride) { + return; + } + if (id >= num_of_segments) { + return; + } + const int idx_begin = segment_offsets[id]; + const int idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1]; + + acc_type weight = 0; + for (int 
idx=idx_begin; idx < idx_end; ++idx) { + const int origRow = indices[idx]; + const int seq_number = offset2bag[origRow]; + const int gradOutputRow = seq_number * stride; + + acc_type scale = count ? 1.0 / count[idx] : 1.0; + if (per_sample_weights) { + scale *= per_sample_weights[origRow * per_sample_weights_stride]; + } + + acc_type gradient = gradOutput[gradOutputRow + startFeature]; + if (mode_mean) { + gradient /= bag_size[seq_number]; + } + weight += gradient * scale; + } + grad_weight_per_segment[id * stride + startFeature] = weight; +} + +template +__global__ void compute_grad_weight( + const index_t *indices, + const scalar_t *gradOutput, + const index_t *count, + ptrdiff_t numel, + int64_t stride, + const index_t* segment_offsets, + const int64_t *num_of_segments_ptr, + acc_type *grad_weight_per_segment, + const int64_t stride_warped) { + + int64_t num_of_segments = *num_of_segments_ptr; + using accscalar_t = acc_type; + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + const int id = gid / stride_warped; + const int startFeature = gid % stride_warped; + if (startFeature >= stride) { + return; + } + if (id >= num_of_segments) { + return; + } + const int idx_begin = segment_offsets[id]; + const int idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1]; + + accscalar_t weight = 0; + for (int idx=idx_begin; idx < idx_end; ++idx) { + const index_t target_row = indices[idx]; + const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0; + weight += gradOutput[target_row * stride + startFeature] * scale; + } + grad_weight_per_segment[id * stride + startFeature] = weight; +} + +// This kernel assumes that all input tensors are contiguous. +template +__global__ void sum_and_scatter( + const index_t *input, scalar_t *gradWeight, int64_t stride, + const index_t* segment_offsets, const int64_t *num_of_segments_ptr, + const acc_type *grad_weight_per_segment, + const index_t *segment_sizes_offsets, const int64_t *num_of_partial_segments_ptr, + const int64_t padding_idx, + const int64_t stride_warped) { + + int64_t num_of_segments = *num_of_segments_ptr; + int64_t num_of_partial_segments = *num_of_partial_segments_ptr; + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + const int id = gid / stride_warped; + const int startFeature = gid % stride_warped; + if (startFeature >= stride) { + return; + } + if (id >= num_of_segments) { + return; + } + + const int idx_begin = segment_sizes_offsets[id]; + const int idx_end = (id == num_of_segments-1)?num_of_partial_segments:segment_sizes_offsets[id+1]; + acc_type weight = 0; + for (int idx=idx_begin; idx < idx_end; ++idx) { + weight += grad_weight_per_segment[idx*stride + startFeature]; + } + int64_t target_row = input[segment_offsets[id]]; + if (target_row != padding_idx) { + gradWeight[target_row * stride + startFeature] = weight; + } +} + +template +__global__ void compute_num_of_partial_segments(const index_t *partials_per_segment, const index_t *partials_per_segment_offset, const int64_t *num_of_segments_ptr, int64_t *output) { + int64_t num_of_segments = *num_of_segments_ptr; + *output = partials_per_segment[num_of_segments-1] + + partials_per_segment_offset[num_of_segments-1]; +} + +#if !CUB_SUPPORTS_UNIQUE_BY_KEY() +__global__ void write_num_of_segments_for_legacy_thrust_path(int64_t *num_of_segments_ptr, int64_t num_of_segments) { + *num_of_segments_ptr = num_of_segments; +} +#endif + +} // anon namespace + +#if !CUB_SUPPORTS_UNIQUE_BY_KEY() +template +int64_t embedding_backward_zoom_kernel_unique_by_key(const 
Tensor &sorted_indices, Tensor &segment_offsets); +#endif + +Tensor embedding_backward_zoom_kernel( + const Tensor &grad, + const Tensor &orig_indices, + const Tensor &sorted_indices, + const Tensor &count, + int64_t num_weights, + int padding_idx, + bool mode_mean, + const Tensor &offset2bag, + const Tensor &bag_size, + const Tensor &per_sample_weights) { + + auto stream = c10::zoom::getCurrentZoomStream(); + const ptrdiff_t numel = sorted_indices.numel(); + + auto grad_weight = at::zeros({num_weights, grad.size(-1)}, grad.options()); + const int64_t stride = grad_weight.stride(0); + + // Compute the number of segments and their start position so that we do not have to + // spawn a warp per index. In this context, a segment is a number of rows that should + // be summarized. + // Unit: index in `sorted_indices` and `orig_indices` + auto segment_offsets = at::empty({numel}, orig_indices.options()); + auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong)); + int64_t *num_of_segments_ptr = num_of_segments_tensor.mutable_data_ptr(); +#if !CUB_SUPPORTS_UNIQUE_BY_KEY() + AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_zoom_kernel", [&] () { + int64_t num_of_segments = embedding_backward_zoom_kernel_unique_by_key(sorted_indices, segment_offsets); + write_num_of_segments_for_legacy_thrust_path<<<1, 1, 0, c10::zoom::getCurrentZoomStream()>>>(num_of_segments_ptr, num_of_segments); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); +#else + AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_zoom_kernel", [&] () { + zoom::hipcub::unique_by_key( + sorted_indices.const_data_ptr(), thrust::make_counting_iterator(0), + nullptr, segment_offsets.mutable_data_ptr(), + num_of_segments_ptr, sorted_indices.numel()); + }); +#endif + + int64_t max_segments = std::min(numel, num_weights); + + AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_zoom_kernel", [&] () { + // We split the segments up into sizes of `NROWS_PER_THREAD` + // Compute the number partial-segments per segment (some partial-segments + // may not be the full `NROWS_PER_THREAD` number of rows) + auto partials_per_segment = at::empty({max_segments}, orig_indices.options()); + { + krn_partials_per_segment<<>> ( + partials_per_segment.mutable_data_ptr(), + segment_offsets.const_data_ptr(), + num_of_segments_ptr, + numel); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + + // In order to compute `partial_segment_offset`, which is the start index + // of each partial-segment in `sorted_indices`, we need to compute the + // start position of each _segment_ in `partial_segment_offset`. 
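To make the bookkeeping above concrete: segments are runs of equal values in `sorted_indices`, each segment is split into partial segments of at most `NROWS_PER_THREAD` rows, and an exclusive prefix sum over the per-segment chunk counts gives the start of each segment's chunks. The host-side sketch below walks through those steps on a small example; it is illustrative only and mirrors what `krn_partials_per_segment` followed by the exclusive sum compute on-device.

```cpp
// Host-side sketch of the segment bookkeeping above (illustrative only).
#include <cstdio>
#include <vector>

int main() {
    constexpr int NROWS_PER_THREAD = 10;
    std::vector<int> sorted = {2, 5, 5, 5, 7, 7, 8, 9, 9};

    std::vector<int> segment_offsets;  // start of each run of equal indices
    for (size_t i = 0; i < sorted.size(); ++i)
        if (i == 0 || sorted[i] != sorted[i - 1]) segment_offsets.push_back((int)i);

    std::vector<int> partials_per_segment(segment_offsets.size());
    for (size_t s = 0; s < segment_offsets.size(); ++s) {
        int begin = segment_offsets[s];
        int end = (s + 1 < segment_offsets.size()) ? segment_offsets[s + 1] : (int)sorted.size();
        partials_per_segment[s] = (end - begin + NROWS_PER_THREAD - 1) / NROWS_PER_THREAD;
    }

    std::vector<int> partials_offset(partials_per_segment.size());
    int running = 0;  // exclusive prefix sum, the host analogue of the exclusive_sum call
    for (size_t s = 0; s < partials_per_segment.size(); ++s) {
        partials_offset[s] = running;
        running += partials_per_segment[s];
    }
    std::printf("segments: %zu, partial segments: %d\n", segment_offsets.size(), running);
}
```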
+ // Unit: index in `partial_segment_offset` + auto partials_per_segment_offset = at::empty({max_segments}, orig_indices.options()); + zoom::hipcub::exclusive_sum( + partials_per_segment.const_data_ptr(), + partials_per_segment_offset.mutable_data_ptr(), + max_segments); + + // The total number of partial-segments is the sum of `partials_per_segment_offset` + auto num_of_partial_segments_tensor = at::empty({}, grad.options().dtype(kLong)); + int64_t *num_of_partial_segments_ptr = num_of_partial_segments_tensor.mutable_data_ptr(); + compute_num_of_partial_segments<<<1, 1, 0, c10::zoom::getCurrentZoomStream()>>>( + partials_per_segment.const_data_ptr(), + partials_per_segment_offset.const_data_ptr(), + num_of_segments_ptr, num_of_partial_segments_ptr); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + + auto max_partial_segment = numel / NROWS_PER_THREAD + max_segments; + + // Now we can compute the start position of each partial-segment + // Unit: index in `sorted_indices` and `orig_indices` + auto partial_segment_offset = at::empty({max_partial_segment}, orig_indices.options()); + { + krn_partial_segment_offset<<>> ( + partial_segment_offset.mutable_data_ptr(), + partials_per_segment.const_data_ptr(), + partials_per_segment_offset.const_data_ptr(), + segment_offsets.const_data_ptr(), + num_of_segments_ptr); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + + const int warp_size = at::zoom::warp_size(); + const int stride_warped = ceil_div(stride, warp_size)*warp_size; + const int block = std::min(stride_warped, MAX_BLOCK_SIZE); + const int grid = ceil_div(max_partial_segment*stride_warped, block); + + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, + grad.scalar_type(), "embedding_bag_backward_zoom_compute_grad_weight", [&] { + // For numerical stability, the dtype of `grad_weight_per_segment` + // should match `acc_type` + using partial_weight_t = acc_type; + TensorOptions op; + if(grad.dtype() == at::kHalf || grad.dtype() == at::kBFloat16) { + op = grad.options().dtype(at::kFloat); + } else { + op = grad.options(); + } + auto grad_weight_per_segment = at::empty({max_partial_segment, stride}, op); + // Compute the sum of each partial-segment and handle bags + if (offset2bag.defined()) { + compute_grad_weight_bags<<>>( + orig_indices.const_data_ptr(), + grad.const_data_ptr(), + offset2bag.const_data_ptr(), + count.defined() ? count.const_data_ptr() : nullptr, numel, stride, + mode_mean, bag_size.const_data_ptr(), + per_sample_weights.defined() ? per_sample_weights.const_data_ptr() : NULL, + per_sample_weights.defined() ? per_sample_weights.stride(0) : 0, + partial_segment_offset.const_data_ptr(), + num_of_partial_segments_ptr, grad_weight_per_segment.mutable_data_ptr(), + stride_warped); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } else { + compute_grad_weight<<>>( + orig_indices.const_data_ptr(), + grad.const_data_ptr(), + count.defined() ? count.const_data_ptr() : nullptr, + numel, stride, + partial_segment_offset.const_data_ptr(), + num_of_partial_segments_ptr, + grad_weight_per_segment.mutable_data_ptr(), + stride_warped); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + } + + // Finally, we sum all the partial-sums and scatter them + // into `grad_weight`. 
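The final `sum_and_scatter` launch below adds up each segment's partial sums and writes the total into the destination row of `grad_weight`, skipping `padding_idx`. A reduced host-side sketch of that step (one feature column, hard-coded partial sums) is shown here for orientation; the names and data are illustrative and not taken from the patch.

```cpp
// Host-side sketch of the sum-and-scatter step (illustrative only; 1-D "rows" for brevity).
#include <cstdio>
#include <vector>

int main() {
    const long padding_idx = -1;
    // One entry per segment: the destination row and the partial sums produced for it.
    std::vector<long> segment_rows = {2, 5, 7};
    std::vector<std::vector<double>> partial_sums = {{0.5}, {1.0, 0.25}, {2.0}};

    std::vector<double> grad_weight(10, 0.0);  // num_weights = 10, feature dim = 1
    for (size_t s = 0; s < segment_rows.size(); ++s) {
        double total = 0.0;
        for (double p : partial_sums[s]) total += p;  // reduce this segment's partial sums
        if (segment_rows[s] != padding_idx)
            grad_weight[segment_rows[s]] = total;     // scatter into the embedding row
    }
    std::printf("grad_weight[5] = %f\n", grad_weight[5]);  // 1.25
}
```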
+ const int grid2 = ceil_div(max_segments*stride_warped, block); + sum_and_scatter<<>>( + sorted_indices.const_data_ptr(), + grad_weight.mutable_data_ptr(), + stride, + segment_offsets.const_data_ptr(), + num_of_segments_ptr, grad_weight_per_segment.const_data_ptr(), + partials_per_segment_offset.const_data_ptr(), + num_of_partial_segments_ptr, + padding_idx, + stride_warped); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); + }); + return grad_weight; +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/EmbeddingBackwardKernel.cuh b/aten/src/ATen/native/zoom/EmbeddingBackwardKernel.cuh new file mode 100644 index 0000000000000..35f14ef997eab --- /dev/null +++ b/aten/src/ATen/native/zoom/EmbeddingBackwardKernel.cuh @@ -0,0 +1,22 @@ +#pragma once +#include +#include +#include +#include + +namespace at { +namespace native { + +Tensor embedding_backward_zoom_kernel( + const Tensor &grad, + const Tensor &orig_indices, + const Tensor &sorted_indices, + const Tensor &count, + int64_t num_weights, + int padding_idx = -1, + bool mode_mean = false, + const Tensor &offset2bag = Tensor(), + const Tensor &bag_size = Tensor(), + const Tensor &per_sample_weights = Tensor()); + +}} diff --git a/aten/src/ATen/native/zoom/EmbeddingBag.cu b/aten/src/ATen/native/zoom/EmbeddingBag.cu new file mode 100644 index 0000000000000..3ed55a07c39fe --- /dev/null +++ b/aten/src/ATen/native/zoom/EmbeddingBag.cu @@ -0,0 +1,560 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +#include +#include +#include +#include +#include + +#include + +#if CUB_SUPPORTS_SCAN_BY_KEY() +#include +#endif + +namespace at::native { + +#if !CUB_SUPPORTS_SCAN_BY_KEY() +template +void embedding_dense_backward_zoom_scan(Tensor &sorted_indices, Tensor &count); +#endif + +namespace { + +constexpr int MODE_SUM = 0; +constexpr int MODE_MEAN = 1; +constexpr int MODE_MAX = 2; + +std::pair promoteIndicesAndOffsets( + const Tensor& indices, + const Tensor& offsets) { + const auto commonType = + promoteTypes(offsets.scalar_type(), indices.scalar_type()); + return { + indices.scalar_type() == commonType ? indices + : indices.toType(commonType), + offsets.scalar_type() == commonType ? offsets + : offsets.toType(commonType)}; +} + +// This kernel assumes that all input tensors except `weight` and +// per_sample_weights are contiguous. +template +__global__ void EmbeddingBag_updateOutputKernel_max( + const index_t *input, const index_t *offsets, const scalar_t *weight, scalar_t *output, + index_t *offset2bag, int64_t numIndices, int64_t numBags, + int64_t featureSize, int64_t weight_stride0, int64_t weight_stride1, + index_t *bag_size, index_t *max_indices, + index_t padding_idx, int64_t numRows) { + + // the strategy here is that each bag x feature is handled by a single thread + + int64_t chunksPerBag = ceil_div(featureSize, (int64_t)blockDim.x); + int64_t numChunks = numBags * chunksPerBag; + int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y; + int64_t chunkStride = gridDim.x * blockDim.y; + + for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) { + int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x; + if (featureDim < featureSize) { + int64_t bag = chunk / chunksPerBag; + const scalar_t *weightFeat = weight + featureDim * weight_stride1; + int64_t begin = bag == 0 ? 
0 : offsets[bag]; // forces first offset to be 0 instead of asserting on it + int64_t end = (bag < numBags - 1) ? (offsets[bag + 1]) : numIndices; + ZOOM_KERNEL_ASSERT(end >= begin); + scalar_t weightFeatMax = 0; + int64_t bag_size_ = 0; + int64_t maxWord = -1; + for (int64_t emb = begin; emb < end; emb++) { + bool pad = (input[emb] == padding_idx); + ZOOM_KERNEL_ASSERT(input[emb] < numRows); + const int64_t weightRow = input[emb] * weight_stride0; + scalar_t weightValue = weightFeat[weightRow]; + if (bag_size_ == 0 || weightValue > weightFeatMax) { + weightFeatMax = pad ? weightFeatMax : weightValue; + maxWord = pad ? maxWord : input[emb]; + } + bag_size_ += pad ? 0 : 1; + + if (featureDim == 0) { + offset2bag[emb] = bag; + } + } + bag_size[bag] = bag_size_; + max_indices[bag * featureSize + featureDim] = maxWord; + output[bag * featureSize + featureDim] = weightFeatMax; + } + } +} + +// This kernel assumes that all input tensors except `weight` and +// per_sample_weights are contiguous. +template +__global__ void EmbeddingBag_updateOutputKernel_sum_mean( + const index_t *input, const index_t *offsets, const scalar_t *weight, scalar_t *output, + index_t *offset2bag, int64_t numIndices, int64_t numBags, + int64_t featureSize, int64_t weight_stride0, int64_t weight_stride1, + int mode, index_t *bag_size, + const scalar_t* per_sample_weights, int64_t per_sample_weights_stride, + index_t padding_idx, int64_t numRows) { + + // the strategy here is that each bag x feature is handled by a single thread + + using accscalar_t = acc_type; + int64_t chunksPerBag = ceil_div(featureSize, (int64_t)blockDim.x); + int64_t numChunks = numBags * chunksPerBag; + int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y; + int64_t chunkStride = gridDim.x * blockDim.y; + + for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) { + int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x; + if (featureDim < featureSize) { + int64_t bag = chunk / chunksPerBag; + const scalar_t *weightFeat = weight + featureDim * weight_stride1; + int64_t begin = bag == 0 ? 0 : offsets[bag]; // forces first offset to be 0 instead of asserting on it + int64_t end = (bag < numBags - 1) ? (offsets[bag + 1]) : numIndices; + ZOOM_KERNEL_ASSERT(end >= begin); + accscalar_t weightFeatSum = 0; + int64_t bag_size_ = 0; + for (int64_t emb = begin; emb < end; emb++) { + bool pad = (input[emb] == padding_idx); + ZOOM_KERNEL_ASSERT(input[emb] < numRows); + const int64_t weightRow = input[emb] * weight_stride0; + scalar_t weightValue = weightFeat[weightRow]; + weightValue = pad ? static_cast(0) : weightValue; + if (per_sample_weights) { + accscalar_t scaleWeightBy = static_cast( + per_sample_weights[emb * per_sample_weights_stride]); + weightFeatSum += scaleWeightBy * static_cast(weightValue); + } else { + weightFeatSum += static_cast(weightValue); + } + bag_size_ += pad ? 
0 : 1; + + if (featureDim == 0) { + offset2bag[emb] = bag; + } + } + if (mode == MODE_MEAN) { + if (bag_size_ != 0) { + weightFeatSum = weightFeatSum / static_cast(bag_size_); + } + } + bag_size[bag] = bag_size_; + output[bag * featureSize + featureDim] = static_cast(weightFeatSum); + } + } +} + +Tensor embedding_bag_backward_zoom_sum_avg( + const Tensor &grad, + const Tensor &indices_, + const Tensor &offset2bag, + const Tensor &bag_size, + int64_t num_weights, + bool scale_grad_by_freq, int64_t mode, + const Tensor& per_sample_weights, + int64_t padding_idx) { + auto indices = indices_.contiguous(); + + ptrdiff_t num_indices = indices.numel(); + + if (num_indices == 0) { + // all empty bags + return at::zeros({num_weights, grad.size(1)}, grad.options()); + } + + auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + Tensor count; + + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_zoom_sum_avg", [&] () { + auto range = at::arange(num_indices, indices.options()); + // int64_t nbits = zoom::hipcub::get_num_bits(num_weights); + zoom::hipcub::radix_sort_pairs( + indices.const_data_ptr(), sorted_indices.mutable_data_ptr(), + range.const_data_ptr(), orig_indices.mutable_data_ptr(), + num_indices, false/*, 0, nbits*/); + }); + + if (scale_grad_by_freq) { + count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); +#if CUB_SUPPORTS_SCAN_BY_KEY() + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_zoom_sum_avg", [&] () { + hipStream_t stream = c10::zoom::getCurrentZoomStream(); + + // Compute an increasing sequence per unique item in sortedIndices: + // sorted: 2 5 5 5 7 7 8 9 9 + // count: 1 1 2 3 1 2 1 1 2 + auto sorted_data = sorted_indices.const_data_ptr(); + auto count_data = count.mutable_data_ptr(); + zoom::hipcub::inclusive_sum_by_key( + sorted_data, + at_zoom_detail::hipcub::ConstantInputIterator(1), + count_data, + num_indices + ); + + // Take the maximum of each count per unique key in reverse: + // sorted: 2 5 5 5 7 7 8 9 9 + // count: 1 3 3 3 2 2 1 2 2 + zoom::hipcub::inclusive_scan_by_key( + thrust::make_reverse_iterator(sorted_data + num_indices), + thrust::make_reverse_iterator(count_data + num_indices), + thrust::make_reverse_iterator(count_data + num_indices), + at_zoom_detail::hipcub::Max(), + num_indices + ); + }); +#else + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_zoom_sum_avg", [&] () { + embedding_dense_backward_zoom_scan(sorted_indices, count); + }); +#endif + } + return embedding_backward_zoom_kernel(grad, orig_indices, sorted_indices, + count, num_weights, padding_idx, mode == MODE_MEAN, offset2bag, + bag_size, per_sample_weights); +} + +template +__global__ void EmbeddingBag_accGradParametersKernel_max( + const index_t *max_indices, const scalar_t *gradOutput, + scalar_t *gradWeight, int64_t stride, int64_t numBags, + index_t padding_idx, const index_t numel) { + + using accscalar_t = acc_type; + + int64_t chunksPerBag = ceil_div(stride, (int64_t)blockDim.x); + int64_t numChunks = numBags * chunksPerBag; + int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y; + int64_t chunkStride = gridDim.x * blockDim.y; + + for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) { + int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x; + if (featureDim < stride) { + int64_t bag = chunk / chunksPerBag; + + index_t word_idx = max_indices[bag * stride + 
featureDim];
+      if (word_idx >= 0 && word_idx != padding_idx) {
+        // If bag is empty, we have max_indices[idx] set to -1 in forward.
+        fastAtomicAdd(
+            gradWeight, static_cast<index_t>(word_idx * stride + featureDim),
+            numel, gradOutput[bag * stride + featureDim], true);
+      }
+    }
+  }
+}
+
+Tensor embedding_bag_backward_zoom_max(const Tensor &grad,
+                                       const Tensor &max_indices,
+                                       int64_t num_weights,
+                                       int64_t padding_idx) {
+  // See Note [Writing Nondeterministic Operations]
+  // Nondeterministic because of atomicAdd usage
+  globalContext().alertNotDeterministic("embedding_bag_backward_zoom_max");
+
+  auto grad_weight = at::zeros({num_weights, grad.size(1)}, grad.options());
+
+  int64_t stride = grad_weight.stride(0);
+
+  int64_t numBags = grad.size(0);
+
+  hipStream_t stream = c10::zoom::getCurrentZoomStream();
+
+  dim3 block = dim3(64, 4);
+  int grid = 1024;
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad.scalar_type(), "embedding_bag_backward_zoom_max", [&] {
+        AT_DISPATCH_INDEX_TYPES(max_indices.scalar_type(), "embedding_bag_backward_zoom_max", [&] () {
+          EmbeddingBag_accGradParametersKernel_max<
+              scalar_t, index_t><<<grid, block, 0, stream>>>(
+              max_indices.const_data_ptr<index_t>(), grad.const_data_ptr<scalar_t>(),
+              grad_weight.mutable_data_ptr<scalar_t>(), stride, numBags,
+              padding_idx, grad_weight.numel());
+          C10_ZOOM_KERNEL_LAUNCH_CHECK();
+        });
+      });
+
+  return grad_weight;
+}
+}
+
+// Assumes all input tensors are contiguous.
+// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
+std::tuple<Tensor, Tensor, Tensor, Tensor>
+_embedding_bag_forward_only_zoom(const Tensor &weight, const Tensor &indices,
+                   const Tensor &offsets, const bool scale_grad_by_freq,
+                   const int64_t mode, bool sparse, const std::optional<Tensor>& per_sample_weights_opt,
+                   bool include_last_offset, int64_t padding_idx) {
+  // See [Note: hacky wrapper removal for optional tensor]
+  c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
+  const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
+
+  return _embedding_bag_zoom(
+      weight,
+      indices,
+      offsets,
+      scale_grad_by_freq,
+      mode,
+      sparse,
+      per_sample_weights,
+      include_last_offset,
+      padding_idx);
+}
+
+// Assumes all input tensors are contiguous.
+// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
+std::tuple<Tensor, Tensor, Tensor, Tensor>
+_embedding_bag_zoom(const Tensor &weight, const Tensor &indices_,
+                   const Tensor &offsets_, const bool scale_grad_by_freq,
+                   const int64_t mode, bool sparse, const std::optional<Tensor>& per_sample_weights_opt,
+                   bool include_last_offset, int64_t padding_idx) {
+  TORCH_CHECK(indices_.dim() == 1 || indices_.dim() == 2,
+      "input has to be a 1D or 2D Tensor, but got Tensor of dimension ",
+      indices_.dim());
+  if (indices_.dim() == 1) {
+    TORCH_CHECK(offsets_.dim() == 1,
+        "offsets has to be a 1D Tensor, but got Tensor of dimension ",
+        offsets_.dim());
+  }
+  TORCH_CHECK(weight.dim() == 2,
+      "weight has to be a 2D Tensor, but got Tensor of dimension ",
+      weight.dim());
+  // See [Note: hacky wrapper removal for optional tensor]
+  c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
+  const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
+
+  Tensor indices, offsets;
+  std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_);
+  auto indices_arg = TensorArg(indices, "indices", 1);
+  checkScalarTypes("embedding_bag_zoom", indices_arg, {kLong, kInt});
+  auto offsets_arg = TensorArg(offsets, "offsets", 1);
+  checkScalarTypes("embedding_bag_zoom", offsets_arg, {kLong, kInt});
+  checkSameType("embedding_bag_zoom", indices_arg, offsets_arg);
+  auto weight_arg = TensorArg(weight, "weight", 1);
+  checkSameGPU("embedding_bag_zoom", weight_arg, indices_arg);
+  checkSameGPU("embedding_bag_zoom", weight_arg, offsets_arg);
+
+  int64_t numIndices = indices.size(0);
+  int64_t numBags = offsets.size(0);
+  if (include_last_offset) {
+    // Check https://github.com/pytorch/pytorch/issues/29019
+    // We plan to add one more element in offsets, which is equal to the size of
+    // indices. Currently for cuda devices, we still use the legacy
+    // implementation even when this flag is enabled.
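+    // Example (illustrative): with 6 indices and offsets = [0, 2, 4, 6],
+    // include_last_offset=true means the trailing 6 only marks the end of the
+    // last bag, so the effective number of bags is offsets.size(0) - 1 = 3.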
+    TORCH_CHECK(
+        numBags >= 1, "include_last_offset: numBags should be at least 1");
+    numBags -= 1;
+  }
+  int64_t featureSize = weight.size(1);
+
+  auto bag_size = at::empty(offsets.sizes(), indices.options());
+  auto offset2bag =
+      at::empty({indices.size(0)}, indices.options()); // offset2bag = [0 0 0 0 0]
+
+  hipStream_t stream = c10::zoom::getCurrentZoomStream();
+
+  auto output = at::empty({numBags, featureSize}, weight.options());
+
+  Tensor max_indices;
+
+  if (mode == MODE_MAX) {
+    max_indices = at::empty({numBags, featureSize}, indices.options());
+  } else {
+    // No need to allocate if we aren't doing a backwards pass
+    max_indices = at::empty({0}, indices.options());
+  }
+
+  dim3 block = dim3(64, 4);
+
+  int grid = 1024;
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(), "embedding_bag_zoom", [&] {
+    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_zoom", [&] () {
+      if (mode == MODE_MAX) {
+        EmbeddingBag_updateOutputKernel_max<scalar_t, index_t><<<grid, block, 0, stream>>>(
+            indices.const_data_ptr<index_t>(), offsets.const_data_ptr<index_t>(),
+            weight.const_data_ptr<scalar_t>(), output.mutable_data_ptr<scalar_t>(),
+            offset2bag.mutable_data_ptr<index_t>(), numIndices, numBags, featureSize,
+            weight.stride(0), weight.stride(1), bag_size.mutable_data_ptr<index_t>(),
+            max_indices.mutable_data_ptr<index_t>(),
+            padding_idx, weight.size(0));
+        C10_ZOOM_KERNEL_LAUNCH_CHECK();
+      } else {
+        EmbeddingBag_updateOutputKernel_sum_mean<scalar_t, index_t><<<grid, block, 0, stream>>>(
+            indices.const_data_ptr<index_t>(), offsets.const_data_ptr<index_t>(),
+            weight.const_data_ptr<scalar_t>(), output.mutable_data_ptr<scalar_t>(),
+            offset2bag.mutable_data_ptr<index_t>(), numIndices, numBags, featureSize,
+            weight.stride(0), weight.stride(1), mode, bag_size.mutable_data_ptr<index_t>(),
+            per_sample_weights.defined() ? per_sample_weights.const_data_ptr<scalar_t>() : NULL,
+            per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
+            padding_idx, weight.size(0));
+        C10_ZOOM_KERNEL_LAUNCH_CHECK();
+      }
+    });
+  });
+
+  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, offset2bag, bag_size, max_indices);
+}
+
+Tensor _embedding_bag_dense_backward_zoom(const Tensor &grad_, const Tensor &indices,
+                                   const Tensor &offset2bag,
+                                   const Tensor &bag_size_,
+                                   const Tensor &max_indices,
+                                   int64_t num_weights,
+                                   bool scale_grad_by_freq, int64_t mode, const std::optional<Tensor>& per_sample_weights_opt,
+                                   int64_t padding_idx) {
+  // See [Note: hacky wrapper removal for optional tensor]
+  c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
+  const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;
+
+  // indices, offsets and offset2bag are assumed to have the correct dtypes and
+  // to be contiguous here due to the checks in _embedding_bag_backward in
+  // EmbeddingBag.cpp.
+  // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml
+  // for more details.
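+  // Dispatch sketch (illustrative): MODE_SUM and MODE_MEAN reuse the sort-based
+  // segmented-reduction path in embedding_bag_backward_zoom_sum_avg above, while
+  // MODE_MAX routes each grad[bag][feature] entirely to the row recorded in
+  // max_indices, i.e. grad_weight[max_indices[bag][feature]][feature] += grad[bag][feature].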
+ + Tensor grad = grad_.contiguous(); + auto indices_arg = TensorArg(indices, "indices", 1); + auto grad_arg = TensorArg(grad, "grad", 1); + checkSameGPU("embedding_bag_zoom", grad_arg, indices_arg); + + + switch (mode) { + case MODE_SUM: + case MODE_MEAN: + if (mode == MODE_MEAN) + AT_ASSERT(!per_sample_weights.defined()); + return embedding_bag_backward_zoom_sum_avg(grad, indices, offset2bag, + bag_size_, num_weights, scale_grad_by_freq, mode, + per_sample_weights, padding_idx); + + case MODE_MAX: + AT_ASSERT(!per_sample_weights.defined()); + return embedding_bag_backward_zoom_max(grad, max_indices, num_weights, + padding_idx); + + default: + AT_ERROR( + "Unknown mode for embedding_bag_backward_zoom ", mode); + } +} + +template +__global__ static void _embedding_bag_per_sample_weights_backward_kernel( + const scalar_t* grad, int64_t grad_stride0, int64_t grad_stride1, + const scalar_t* weight, int64_t weight_stride0, int64_t weight_stride1, + const index_t* indices, // contiguous + const index_t* offset2bag, // contiguous + int64_t num_samples, + int64_t embedding_features, + scalar_t* output, + index_t padding_idx) { + using accscalar_t = acc_type; + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + const int warp = idx / C10_WARP_SIZE; + const int thread_in_warp = idx % C10_WARP_SIZE; + const int num_warps = blockDim.x * gridDim.x / C10_WARP_SIZE; + + // Each warp is responsible for the accumulation of one sample. + // This involves doing one dot product between grad[bag_idx] and weight[embedding_idx]. + for (int sample_idx = warp; sample_idx < num_samples; sample_idx += num_warps) { + accscalar_t result = 0.; + const int bag_idx = (int)offset2bag[sample_idx]; + const int embedding_idx = (int)indices[sample_idx]; + if (embedding_idx != padding_idx) { + for (int feature_idx = thread_in_warp; feature_idx < embedding_features; + feature_idx += C10_WARP_SIZE) { + result += + grad[grad_stride0 * bag_idx + grad_stride1 * feature_idx] * + weight[weight_stride0 * embedding_idx + weight_stride1 * feature_idx]; + } + } + result = zoom_utils::WarpReduceSum(result); + if (thread_in_warp == 0) { + output[sample_idx] = result; + } + } +} + +Tensor _embedding_bag_per_sample_weights_backward_zoom( + const Tensor& grad, + const Tensor& weight, // NB: embedding table, not per_sample_weights + const Tensor& indices_, + const Tensor& offsets_, + const Tensor& offset2bag, + int64_t mode, + int64_t padding_idx) { + TORCH_CHECK( + mode == MODE_SUM, + "embedding_bag_backward: per_sample_weights only supported for mode='sum'"); + + AT_ASSERT(grad.dim() == 2); + auto embedding_features = grad.size(1); + + Tensor indices, offsets; + std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_); + AT_ASSERT(indices.dim() == 1); + auto num_samples = indices.size(0); + + AT_ASSERT(weight.dim() == 2); + AT_ASSERT(weight.size(1) == embedding_features); + + const int threads_per_block = 512; + const int warps_per_block = threads_per_block / at::zoom::warp_size(); + + dim3 block(threads_per_block); + dim3 grid((num_samples + warps_per_block - 1) / warps_per_block); + + auto output = at::empty({num_samples}, grad.options()); + + // Early return when there is no samples in the batch. 
This saves unnecessary kernel + // launch, but also prevents hipGetLastError() to complain about invalid launch args + if (num_samples == 0) { + return output; + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_zoom", [&]() { + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_zoom", [&]() { + _embedding_bag_per_sample_weights_backward_kernel + <<>>( + grad.const_data_ptr(), grad.stride(0), grad.stride(1), + weight.const_data_ptr(), weight.stride(0), weight.stride(1), + indices.const_data_ptr(), + offset2bag.const_data_ptr(), + num_samples, + embedding_features, + output.mutable_data_ptr(), + padding_idx); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); + } + ); + return output; +} + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/Equal.cpp b/aten/src/ATen/native/zoom/Equal.cpp new file mode 100644 index 0000000000000..00f6acf51d0b6 --- /dev/null +++ b/aten/src/ATen/native/zoom/Equal.cpp @@ -0,0 +1,49 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#include +#else +#include +#include +#endif + +namespace at::native { + +bool zoom_equal(const Tensor& self, const Tensor &src) { + if (!at::namedinference::are_names_equal( + self.unsafeGetTensorImpl(), src.unsafeGetTensorImpl())) { + return false; + } + at::NoNamesGuard guard; + TORCH_CHECK(self.device() == src.device(), "Cannot compare two tensors on " + "different devices. Got: ", self.device(), " and ", src.device()); + if (self.sizes() != src.sizes()) { + return false; + } + if (self.numel() == 0) { + return true; + } + + // This is the same optimization done in the cpu_equal. Since the flags like neg/conj should be already handled outside the + // cuda_equal, it should be safe to have the following fast path by + // ensuring the storage and strides exactly the same. + if (self.is_alias_of(src) + && self.storage_offset() == src.storage_offset() + && self.dtype() == src.dtype() + && self.is_contiguous() == src.is_contiguous() + && self.strides().equals(src.strides()) + // Extra checks to ensure the safety in case cuda_equal is directly called in C++. 
+ && self.layout() == src.layout() + && self.is_neg() == src.is_neg() + && self.is_conj() == src.is_conj()) { + return true; + } + + return at::eq(self, src).all().item().to(); +} + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/FillKernel.cu b/aten/src/ATen/native/zoom/FillKernel.cu new file mode 100644 index 0000000000000..24c0a00c54726 --- /dev/null +++ b/aten/src/ATen/native/zoom/FillKernel.cu @@ -0,0 +1,30 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +template +struct FillFunctor { + FillFunctor(scalar_t v): value(v) {} + __device__ __forceinline__ scalar_t operator() () const { + return value; + } + private: + scalar_t value; +}; + +void fill_kernel_zoom(TensorIterator& iter, const Scalar& value) { + AT_DISPATCH_V2(iter.dtype(), "fill_zoom", AT_WRAP([&]() { + gpu_kernel(iter, FillFunctor(value.to())); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kBool, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); +} + +REGISTER_PRIVATEUSE1_DISPATCH(fill_stub, &fill_kernel_zoom); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/FlattenIndicesKernel.cu b/aten/src/ATen/native/zoom/FlattenIndicesKernel.cu new file mode 100644 index 0000000000000..65bcb764d538c --- /dev/null +++ b/aten/src/ATen/native/zoom/FlattenIndicesKernel.cu @@ -0,0 +1,28 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include + +namespace at::native { + +namespace { + +template +struct HIPKernelLauncher { + static void launch(TensorIteratorBase& iter, const func_t& f) { + gpu_kernel(iter, f); + } +}; + +Tensor flatten_indices_zoom_kernel(const Tensor& indices, IntArrayRef size) { + return _flatten_indices(indices, size); +} + +} + +REGISTER_PRIVATEUSE1_DISPATCH(flatten_indices_stub, &flatten_indices_zoom_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/ForeachBinaryOpList.cu b/aten/src/ATen/native/zoom/ForeachBinaryOpList.cu new file mode 100644 index 0000000000000..02e2c4d4fe942 --- /dev/null +++ b/aten/src/ATen/native/zoom/ForeachBinaryOpList.cu @@ -0,0 +1,295 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#endif + +namespace at::native { + +template class Op> +std::vector foreach_tensor_list_op( + TensorList tensors1, + TensorList tensors2, + const Scalar& alpha = 1) { + std::vector> tensor_lists; + std::vector vec_res; + vec_res.reserve(tensors1.size()); + for (const auto& t : tensors1) { + vec_res.emplace_back(at::native::empty_like(t)); + } + + tensor_lists.emplace_back(tensors1.vec()); + tensor_lists.emplace_back(tensors2.vec()); + tensor_lists.emplace_back(std::move(vec_res)); + + using opmath_t = at::opmath_type; + multi_tensor_apply<3>( + tensor_lists, + BinaryOpListAlphaFunctor< + T, + /* depth */ 3, + /* r_args_depth */ 2, + /* res_arg_index */ 2>(), + Op(), + alpha.to()); + + return tensor_lists[2]; +} + +template class Op> +void foreach_tensor_list_op_( + TensorList tensors1, + TensorList tensors2, + const Scalar& alpha = 1) { + std::vector> tensor_lists; + tensor_lists.emplace_back(tensors1.vec()); + tensor_lists.emplace_back(tensors2.vec()); + + using opmath_t = at::opmath_type; + multi_tensor_apply<2>( + 
tensor_lists,
+      BinaryOpListAlphaFunctor<
+          T,
+          /* depth */ 2,
+          /* r_args_depth */ 2,
+          /* res_arg_index */ 0>(),
+      Op<opmath_t>(),
+      alpha.to<opmath_t>());
+  increment_version(tensors1);
+}
+
+template