Skip to content

Commit 726dcea

Browse files
[slimtensor] Add storage and device property getters to common_shims_slim (#16991)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #16455 by @Gasoonjia ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/97/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/97/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/96/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/97/orig Differential Revision: [D90126251](https://our.internmc.facebook.com/intern/diff/D90126251/) @diff-train-skip-merge --------- Co-authored-by: gasoonjia <[email protected]> Co-authored-by: Gasoonjia <[email protected]>
1 parent 944a436 commit 726dcea

File tree

3 files changed

+189
-0
lines changed

3 files changed

+189
-0
lines changed

backends/aoti/common_shims_slim.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,48 @@ int32_t aoti_torch_layout_strided() {
6464
return 0;
6565
}
6666

67+
// ============================================================
68+
// Storage & Device Property Getters - Implementations
69+
// ============================================================
70+
71+
AOTITorchError aoti_torch_get_storage_offset(
72+
Tensor* tensor,
73+
int64_t* ret_storage_offset) {
74+
if (tensor == nullptr || ret_storage_offset == nullptr) {
75+
return Error::InvalidArgument;
76+
}
77+
*ret_storage_offset = tensor->storage_offset();
78+
return Error::Ok;
79+
}
80+
81+
AOTITorchError aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size) {
82+
if (tensor == nullptr || ret_size == nullptr) {
83+
return Error::InvalidArgument;
84+
}
85+
*ret_size = static_cast<int64_t>(tensor->storage()->nbytes());
86+
return Error::Ok;
87+
}
88+
89+
AOTITorchError aoti_torch_get_device_type(
90+
Tensor* tensor,
91+
int32_t* ret_device_type) {
92+
if (tensor == nullptr || ret_device_type == nullptr) {
93+
return Error::InvalidArgument;
94+
}
95+
*ret_device_type = static_cast<int32_t>(tensor->device_type());
96+
return Error::Ok;
97+
}
98+
99+
AOTITorchError aoti_torch_get_device_index(
100+
Tensor* tensor,
101+
int32_t* ret_device_index) {
102+
if (tensor == nullptr || ret_device_index == nullptr) {
103+
return Error::InvalidArgument;
104+
}
105+
*ret_device_index = static_cast<int32_t>(tensor->device_index());
106+
return Error::Ok;
107+
}
108+
67109
} // extern "C"
68110

69111
} // namespace aoti

backends/aoti/common_shims_slim.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,22 @@ aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
5050

5151
AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();
5252

53+
// ============================================================
54+
// Storage & Device Property Getters - Declarations
55+
// ============================================================
56+
57+
AOTI_SHIM_EXPORT AOTITorchError
58+
aoti_torch_get_storage_offset(Tensor* tensor, int64_t* ret_storage_offset);
59+
60+
AOTI_SHIM_EXPORT AOTITorchError
61+
aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size);
62+
63+
AOTI_SHIM_EXPORT AOTITorchError
64+
aoti_torch_get_device_type(Tensor* tensor, int32_t* ret_device_type);
65+
66+
AOTI_SHIM_EXPORT AOTITorchError
67+
aoti_torch_get_device_index(Tensor* tensor, int32_t* ret_device_index);
68+
5369
} // extern "C"
5470

5571
} // namespace aoti

backends/aoti/tests/test_common_shims_slim.cpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,93 @@ void runGetDimTest(slim_c10::DeviceType device_type) {
289289
}
290290
}
291291

292+
// ============================================================================
293+
// Storage & Device Property Tests
294+
// ============================================================================
295+
296+
void runGetStorageOffsetTest(slim_c10::DeviceType device_type) {
297+
std::vector<int64_t> sizes = {2, 3};
298+
std::vector<int64_t> strides = calculateContiguousStrides(sizes);
299+
slim_c10::Device device(device_type, 0);
300+
301+
Tensor* tensor = new Tensor(slim::empty_strided(
302+
slim::makeArrayRef(sizes),
303+
slim::makeArrayRef(strides),
304+
slim_c10::ScalarType::Float,
305+
device));
306+
307+
int64_t ret_storage_offset = -1;
308+
AOTITorchError error =
309+
aoti_torch_get_storage_offset(tensor, &ret_storage_offset);
310+
311+
EXPECT_EQ(error, Error::Ok);
312+
// Default storage offset for newly created tensor is 0
313+
EXPECT_EQ(ret_storage_offset, 0);
314+
315+
delete tensor;
316+
}
317+
318+
void runGetStorageSizeTest(slim_c10::DeviceType device_type) {
319+
std::vector<int64_t> sizes = {2, 3};
320+
std::vector<int64_t> strides = calculateContiguousStrides(sizes);
321+
slim_c10::Device device(device_type, 0);
322+
323+
Tensor* tensor = new Tensor(slim::empty_strided(
324+
slim::makeArrayRef(sizes),
325+
slim::makeArrayRef(strides),
326+
slim_c10::ScalarType::Float,
327+
device));
328+
329+
int64_t ret_size = -1;
330+
AOTITorchError error = aoti_torch_get_storage_size(tensor, &ret_size);
331+
332+
EXPECT_EQ(error, Error::Ok);
333+
// 2 * 3 * sizeof(float) = 6 * 4 = 24 bytes
334+
EXPECT_EQ(ret_size, 24);
335+
336+
delete tensor;
337+
}
338+
339+
void runGetDeviceTypeTest(slim_c10::DeviceType device_type) {
340+
std::vector<int64_t> sizes = {2, 3};
341+
std::vector<int64_t> strides = calculateContiguousStrides(sizes);
342+
slim_c10::Device device(device_type, 0);
343+
344+
Tensor* tensor = new Tensor(slim::empty_strided(
345+
slim::makeArrayRef(sizes),
346+
slim::makeArrayRef(strides),
347+
slim_c10::ScalarType::Float,
348+
device));
349+
350+
int32_t ret_device_type = -1;
351+
AOTITorchError error = aoti_torch_get_device_type(tensor, &ret_device_type);
352+
353+
EXPECT_EQ(error, Error::Ok);
354+
EXPECT_EQ(ret_device_type, static_cast<int32_t>(device_type));
355+
356+
delete tensor;
357+
}
358+
359+
void runGetDeviceIndexTest(slim_c10::DeviceType device_type) {
360+
std::vector<int64_t> sizes = {2, 3};
361+
std::vector<int64_t> strides = calculateContiguousStrides(sizes);
362+
slim_c10::Device device(device_type, 0);
363+
364+
Tensor* tensor = new Tensor(slim::empty_strided(
365+
slim::makeArrayRef(sizes),
366+
slim::makeArrayRef(strides),
367+
slim_c10::ScalarType::Float,
368+
device));
369+
370+
int32_t ret_device_index = -1;
371+
AOTITorchError error = aoti_torch_get_device_index(tensor, &ret_device_index);
372+
373+
EXPECT_EQ(error, Error::Ok);
374+
EXPECT_EQ(ret_device_index, 0);
375+
376+
delete tensor;
377+
}
378+
292379
// ============================================================================
293380
// CPU Tests
294381
// ============================================================================
@@ -313,6 +400,22 @@ TEST_F(CommonShimsSlimTest, GetDim_CPU) {
313400
runGetDimTest(slim_c10::DeviceType::CPU);
314401
}
315402

403+
TEST_F(CommonShimsSlimTest, GetStorageOffset_CPU) {
404+
runGetStorageOffsetTest(slim_c10::DeviceType::CPU);
405+
}
406+
407+
TEST_F(CommonShimsSlimTest, GetStorageSize_CPU) {
408+
runGetStorageSizeTest(slim_c10::DeviceType::CPU);
409+
}
410+
411+
TEST_F(CommonShimsSlimTest, GetDeviceType_CPU) {
412+
runGetDeviceTypeTest(slim_c10::DeviceType::CPU);
413+
}
414+
415+
TEST_F(CommonShimsSlimTest, GetDeviceIndex_CPU) {
416+
runGetDeviceIndexTest(slim_c10::DeviceType::CPU);
417+
}
418+
316419
// ============================================================================
317420
// CUDA Tests
318421
// ============================================================================
@@ -352,6 +455,34 @@ TEST_F(CommonShimsSlimTest, GetDim_CUDA) {
352455
}
353456
runGetDimTest(slim_c10::DeviceType::CUDA);
354457
}
458+
459+
TEST_F(CommonShimsSlimTest, GetStorageOffset_CUDA) {
460+
if (!isCudaAvailable()) {
461+
GTEST_SKIP() << "CUDA not available";
462+
}
463+
runGetStorageOffsetTest(slim_c10::DeviceType::CUDA);
464+
}
465+
466+
TEST_F(CommonShimsSlimTest, GetStorageSize_CUDA) {
467+
if (!isCudaAvailable()) {
468+
GTEST_SKIP() << "CUDA not available";
469+
}
470+
runGetStorageSizeTest(slim_c10::DeviceType::CUDA);
471+
}
472+
473+
TEST_F(CommonShimsSlimTest, GetDeviceType_CUDA) {
474+
if (!isCudaAvailable()) {
475+
GTEST_SKIP() << "CUDA not available";
476+
}
477+
runGetDeviceTypeTest(slim_c10::DeviceType::CUDA);
478+
}
479+
480+
TEST_F(CommonShimsSlimTest, GetDeviceIndex_CUDA) {
481+
if (!isCudaAvailable()) {
482+
GTEST_SKIP() << "CUDA not available";
483+
}
484+
runGetDeviceIndexTest(slim_c10::DeviceType::CUDA);
485+
}
355486
#endif
356487

357488
// ============================================================================

0 commit comments

Comments
 (0)