
Commit d946c3b

Update on "[slimtensor] Add common_shims_slim with basic property getters"
Add SlimTensor-based implementations of basic property getter AOTI shim functions:

1. `aoti_torch_get_data_ptr()` - Returns a pointer to the tensor's data
2. `aoti_torch_get_sizes()` - Returns a pointer to the sizes array (SlimTensor stores int64_t directly)
3. `aoti_torch_get_strides()` - Returns a pointer to the strides array (SlimTensor stores int64_t directly)
4. `aoti_torch_get_dtype()` - Returns the scalar type as int32_t
5. `aoti_torch_get_dim()` - Returns the number of dimensions

Key design:

- Create a new common_shims_slim.h for working on the new API without impacting the current pipeline. common_shims_slim.{h,cpp} will replace the current common_shims.{h,cpp} once everything has been set up.
- Uses `#ifdef CUDA_AVAILABLE` conditional compilation to separate the implementation between the CUDA backend and the MPS backend, since SlimTensor does not have MPS support yet. The branch will be removed once SlimTensor supports MPS.
- Refactored into a header-only library so the caller's preprocessor flags determine which tensor type is used. This design supports both the CUDA backend (SlimTensor) and the MPS backend (ETensor) from a single library.

Differential Revision: [D90126254](https://our.internmc.facebook.com/intern/diff/D90126254/)

[ghstack-poisoned]
2 parents 895b86f + ecd3787 commit d946c3b
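For orientation, here is a minimal caller-side sketch of the shim surface this commit adds. The helper `inspect_tensor` is hypothetical and only illustrates the call pattern implied by the signatures above, not code from this commit:

```cpp
#include <cstdint>

#include <executorch/backends/aoti/common_shims_slim.h>

using namespace executorch::backends::aoti;
using executorch::runtime::Error;

// Hypothetical helper: query every basic property of a tensor through the
// shim ABI. Each getter null-checks its arguments and returns Error::Ok on
// success, so failures can simply be propagated.
Error inspect_tensor(Tensor* t) {
  void* data = nullptr;
  int64_t* sizes = nullptr;
  int64_t* strides = nullptr;
  int32_t dtype = 0;
  int64_t dim = 0;

  if (auto err = aoti_torch_get_data_ptr(t, &data); err != Error::Ok) {
    return err;
  }
  if (auto err = aoti_torch_get_sizes(t, &sizes); err != Error::Ok) {
    return err;
  }
  if (auto err = aoti_torch_get_strides(t, &strides); err != Error::Ok) {
    return err;
  }
  if (auto err = aoti_torch_get_dtype(t, &dtype); err != Error::Ok) {
    return err;
  }
  // sizes/strides point into the tensor's own int64_t arrays, so they stay
  // valid for the lifetime of the tensor and need no separate cleanup.
  return aoti_torch_get_dim(t, &dim);
}
```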

31 files changed: +630 -385 lines changed
backends/aoti/common_shims_slim.cpp

Lines changed: 61 additions & 0 deletions
```diff
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/aoti/common_shims_slim.h>
+
+namespace executorch {
+namespace backends {
+namespace aoti {
+
+// ============================================================
+// Basic Property Getters - Implementations
+// ============================================================
+
+AOTITorchError aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr) {
+  if (tensor == nullptr || ret_data_ptr == nullptr) {
+    return Error::InvalidArgument;
+  }
+  *ret_data_ptr = tensor->data_ptr();
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) {
+  if (tensor == nullptr || ret_sizes == nullptr) {
+    return Error::InvalidArgument;
+  }
+  *ret_sizes = const_cast<int64_t*>(tensor->sizes().data());
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) {
+  if (tensor == nullptr || ret_strides == nullptr) {
+    return Error::InvalidArgument;
+  }
+  *ret_strides = const_cast<int64_t*>(tensor->strides().data());
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) {
+  if (tensor == nullptr || ret_dtype == nullptr) {
+    return Error::InvalidArgument;
+  }
+  *ret_dtype = static_cast<int32_t>(tensor->dtype());
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim) {
+  if (tensor == nullptr || ret_dim == nullptr) {
+    return Error::InvalidArgument;
+  }
+  *ret_dim = static_cast<int64_t>(tensor->dim());
+  return Error::Ok;
+}
+
+} // namespace aoti
+} // namespace backends
+} // namespace executorch
```
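One design note not spelled out in the commit: `tensor->sizes()` and `tensor->strides()` expose const data, while the shim ABI hands back mutable `int64_t**` out-parameters, hence the `const_cast`. The returned pointers alias the tensor's own metadata arrays, so callers should treat them as read-only views. A hypothetical caller sketch (the helper name is illustrative, not from this commit):

```cpp
// Hypothetical helper: compute the element count of `t` from the aliased
// sizes array. aoti_torch_get_sizes() returns a pointer into the tensor's
// own metadata, so it is read-only by convention.
int64_t element_count(Tensor* t) {
  int64_t dim = 0;
  int64_t* sizes = nullptr;
  if (aoti_torch_get_dim(t, &dim) != Error::Ok ||
      aoti_torch_get_sizes(t, &sizes) != Error::Ok) {
    return -1; // signal failure
  }
  int64_t numel = 1;
  for (int64_t i = 0; i < dim; ++i) {
    numel *= sizes[i];
  }
  return numel;
}
```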

backends/aoti/common_shims_slim.h

Lines changed: 13 additions & 177 deletions
```diff
@@ -9,19 +9,9 @@
 #pragma once
 
 #include <executorch/backends/aoti/export.h>
+#include <executorch/backends/aoti/slim/core/SlimTensor.h>
 #include <executorch/runtime/core/error.h>
 #include <cstdint>
-#include <unordered_map>
-#include <vector>
-
-// Uses conditional compilation to separate the implementation between
-// CUDA backend (SlimTensor) and other backends like MPS (ETensor).
-// The caller determines which path is used by defining CUDA_AVAILABLE.
-#ifdef CUDA_AVAILABLE
-#include <executorch/backends/aoti/slim/core/SlimTensor.h>
-#else
-#include <executorch/runtime/core/exec_aten/exec_aten.h>
-#endif
 
 namespace executorch {
 namespace backends {
@@ -30,185 +20,31 @@ namespace aoti {
 // Common using declarations for ExecuTorch types
 using executorch::runtime::Error;
 
-// ============================================================
-// Tensor Type Definition - branched based on CUDA_AVAILABLE
-// ============================================================
-#ifdef CUDA_AVAILABLE
+// Tensor type definition using SlimTensor
 using Tensor = executorch::backends::aoti::slim::SlimTensor;
-#else
-using Tensor = executorch::runtime::etensor::Tensor;
-#endif
 
 // Common AOTI type aliases
 using AOTIRuntimeError = Error;
 using AOTITorchError = Error;
 
-#ifndef CUDA_AVAILABLE
-namespace internal {
-// Global storage for tensor metadata (ETensor path only)
-// SlimTensor stores sizes/strides directly in int64_t[] - no caching needed
-inline std::unordered_map<Tensor*, std::vector<int64_t>>& tensor_to_sizes() {
-  static std::unordered_map<Tensor*, std::vector<int64_t>> instance;
-  return instance;
-}
-inline std::unordered_map<Tensor*, std::vector<int64_t>>& tensor_to_strides() {
-  static std::unordered_map<Tensor*, std::vector<int64_t>> instance;
-  return instance;
-}
-} // namespace internal
-#endif
-
 // ============================================================
-// Basic Property Getters - Inline implementations
+// Basic Property Getters - Declarations
 // ============================================================
 
-inline AOTITorchError aoti_torch_get_data_ptr(
-    Tensor* tensor,
-    void** ret_data_ptr) {
-  if (tensor == nullptr) {
-    return Error::InvalidArgument;
-  }
-  if (ret_data_ptr == nullptr) {
-    return Error::InvalidArgument;
-  }
-
-#ifdef CUDA_AVAILABLE
-  *ret_data_ptr = tensor->data_ptr();
-#else
-  *ret_data_ptr = tensor->mutable_data_ptr();
-#endif
-  return Error::Ok;
-}
-
-inline AOTITorchError aoti_torch_get_sizes(
-    Tensor* tensor,
-    int64_t** ret_sizes) {
-  if (tensor == nullptr) {
-    return Error::InvalidArgument;
-  }
-  if (ret_sizes == nullptr) {
-    return Error::InvalidArgument;
-  }
-
-#ifdef CUDA_AVAILABLE
-  // SlimTensor stores sizes directly in int64_t[] - no caching needed
-  *ret_sizes = const_cast<int64_t*>(tensor->sizes().data());
-#else
-  auto it = internal::tensor_to_sizes().find(tensor);
-  bool needs_update = false;
-
-  if (it == internal::tensor_to_sizes().end()) {
-    needs_update = true;
-  } else {
-    // Validate cached metadata matches current tensor state
-    auto tensor_sizes = tensor->sizes();
-    needs_update = !std::equal(
-        it->second.begin(),
-        it->second.end(),
-        tensor_sizes.begin(),
-        tensor_sizes.end());
-  }
-
-  if (needs_update) {
-    std::vector<int64_t> sizes(tensor->dim());
-    auto tensor_sizes = tensor->sizes();
-    for (int i = 0; i < tensor->dim(); i++) {
-      sizes[i] = tensor_sizes[i];
-    }
-    it = internal::tensor_to_sizes()
-             .insert_or_assign(tensor, std::move(sizes))
-             .first;
-  }
-
-  // For 0D tensors, data() returns nullptr on empty vectors
-  if (it->second.empty()) {
-    static int64_t empty_sizes_placeholder = 0;
-    *ret_sizes = &empty_sizes_placeholder;
-  } else {
-    *ret_sizes = it->second.data();
-  }
-#endif
-  return Error::Ok;
-}
-
-inline AOTITorchError aoti_torch_get_strides(
-    Tensor* tensor,
-    int64_t** ret_strides) {
-  if (tensor == nullptr) {
-    return Error::InvalidArgument;
-  }
-  if (ret_strides == nullptr) {
-    return Error::InvalidArgument;
-  }
-
-#ifdef CUDA_AVAILABLE
-  // SlimTensor stores strides directly in int64_t[] - no caching needed
-  *ret_strides = const_cast<int64_t*>(tensor->strides().data());
-#else
-  auto it = internal::tensor_to_strides().find(tensor);
-  bool needs_update = false;
-
-  if (it == internal::tensor_to_strides().end()) {
-    needs_update = true;
-  } else {
-    // Validate cached metadata matches current tensor state
-    auto tensor_strides = tensor->strides();
-    needs_update = !std::equal(
-        it->second.begin(),
-        it->second.end(),
-        tensor_strides.begin(),
-        tensor_strides.end());
-  }
-
-  if (needs_update) {
-    std::vector<int64_t> strides(tensor->dim());
-    auto tensor_strides = tensor->strides();
-    for (int i = 0; i < tensor->dim(); i++) {
-      strides[i] = tensor_strides[i];
-    }
-    it = internal::tensor_to_strides()
-             .insert_or_assign(tensor, std::move(strides))
-             .first;
-  }
-
-  // For 0D tensors, data() returns nullptr on empty vectors
-  if (it->second.empty()) {
-    static int64_t empty_strides_placeholder = 0;
-    *ret_strides = &empty_strides_placeholder;
-  } else {
-    *ret_strides = it->second.data();
-  }
-#endif
-  return Error::Ok;
-}
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr);
 
-inline AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) {
-  if (tensor == nullptr) {
-    return Error::InvalidArgument;
-  }
-  if (ret_dtype == nullptr) {
-    return Error::InvalidArgument;
-  }
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes);
 
-#ifdef CUDA_AVAILABLE
-  *ret_dtype = static_cast<int32_t>(tensor->dtype());
-#else
-  *ret_dtype = static_cast<int32_t>(tensor->scalar_type());
-#endif
-  return Error::Ok;
-}
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides);
 
-inline AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim) {
-  if (tensor == nullptr) {
-    return Error::InvalidArgument;
-  }
-  if (ret_dim == nullptr) {
-    return Error::InvalidArgument;
-  }
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype);
 
-  *ret_dim = static_cast<int64_t>(tensor->dim());
-  return Error::Ok;
-}
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
 
 } // namespace aoti
 } // namespace backends
```
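The declarations above rely on the `AOTI_SHIM_EXPORT` macro from `executorch/backends/aoti/export.h`, which this diff does not show. As an assumption for readers unfamiliar with it, export macros of this kind usually reduce to a symbol-visibility attribute, roughly along these lines (a sketch only, not the actual contents of export.h):

```cpp
// Sketch only - the real definition lives in executorch/backends/aoti/export.h.
#ifndef AOTI_SHIM_EXPORT
#if defined(_WIN32)
#define AOTI_SHIM_EXPORT __declspec(dllexport)
#else
#define AOTI_SHIM_EXPORT __attribute__((visibility("default")))
#endif
#endif
```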

backends/aoti/slim/c10/cuda/Exception.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -19,12 +19,14 @@
 
 /// Checks a CUDA expression and aborts on error.
 /// @param EXPR The CUDA expression to check.
+#ifndef ET_CUDA_CHECK
 #define ET_CUDA_CHECK(EXPR)                                                 \
   do {                                                                      \
     const cudaError_t __err = EXPR;                                         \
     ET_CHECK_MSG(                                                           \
         __err == cudaSuccess, "CUDA error: %s", cudaGetErrorString(__err)); \
   } while (0)
+#endif
 
 /// Checks a CUDA expression and logs a warning on error (non-fatal).
 /// @param EXPR The CUDA expression to check.
```
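The new `#ifndef` guard makes the header safe to include alongside another definition of `ET_CUDA_CHECK`. For reference, a small usage sketch of the macro (the helper function is hypothetical; `cudaMemcpy` is the standard CUDA runtime call):

```cpp
#include <cuda_runtime.h>
#include <executorch/backends/aoti/slim/c10/cuda/Exception.h>

// Copy host data to a device buffer, aborting with a readable CUDA error
// message if the runtime call fails.
void copy_to_device(void* dst, const void* src, size_t nbytes) {
  ET_CUDA_CHECK(cudaMemcpy(dst, src, nbytes, cudaMemcpyHostToDevice));
}
```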

backends/aoti/targets.bzl

Lines changed: 6 additions & 6 deletions
```diff
@@ -87,20 +87,20 @@ def define_common_targets():
         ],
     )
 
-    # SlimTensor-based common shims (header-only library)
-    # The caller determines which tensor type is used by defining CUDA_AVAILABLE.
-    # - With CUDA_AVAILABLE=1: Uses SlimTensor
-    # - Without CUDA_AVAILABLE: Uses ETensor
+    # SlimTensor-based common shims library
+    # Uses SlimTensor for all tensor operations
     runtime.cxx_library(
         name = "common_shims_slim",
+        srcs = [
+            "common_shims_slim.cpp",
+        ],
         headers = [
             "common_shims_slim.h",
             "export.h",
         ],
         visibility = ["@EXECUTORCH_CLIENTS"],
-        deps = [
+        exported_deps = [
             "//executorch/runtime/core:core",
-            "//executorch/runtime/core/exec_aten:lib",
             "//executorch/backends/aoti/slim/core:slimtensor",
         ],
     )
```
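The switch from `deps` to `exported_deps` matters because `common_shims_slim.h` now includes `SlimTensor.h` in its public interface: in Buck, `exported_deps` propagates a dependency to everything that depends on this target, so clients get the SlimTensor headers transitively. A dependent translation unit then only needs the shim header (client code shown for illustration):

```cpp
// Client translation unit: only the shim header is included directly;
// SlimTensor.h arrives transitively through the exported dependency.
#include <executorch/backends/aoti/common_shims_slim.h>

using executorch::backends::aoti::Tensor; // SlimTensor under the hood
```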

backends/arm/_passes/decompose_add_sub_alpha_pass.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -1,4 +1,4 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -76,7 +76,11 @@ def call_operator(self, op, args, kwargs, meta, updated: bool | None = False):
         lhs, rhs = args
 
         alpha_full = super().call_operator(
-            full_op, ((1,), float(alpha)), {}, meta, updated=True
+            full_op,
+            ((1,), float(alpha)),
+            {"device": meta["val"].device},
+            meta,
+            updated=True,
         )
         scaled_rhs = super().call_operator(
             mul_op,
```