Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions backends/aoti/slim/core/SlimTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,33 @@ class SlimTensor {
set_sizes_and_strides(sizes, makeArrayRef(contig_strides));
}

/**
* Returns a copy of this tensor.
*
* @return A new SlimTensor with same content.
*/
SlimTensor clone() const {
return _clone_impl(
this->sizes(), this->strides(), this->dtype(), this->device());
}

/**
* Returns a contiguous copy of this tensor.
* If the tensor is already contiguous, returns a copy with independent
* storage.
*
* @return A new contiguous SlimTensor.
*/
SlimTensor clone_contiguous() const {
std::vector<int64_t> contig_strides =
compute_contiguous_strides(this->sizes());
return _clone_impl(
this->sizes(),
makeArrayRef(contig_strides),
this->dtype(),
this->device());
}

// =========================================================================
// View Operations
// =========================================================================
Expand Down Expand Up @@ -364,6 +391,39 @@ class SlimTensor {
makeArrayRef(sizes), makeArrayRef(strides), storage_offset);
}

/**
* Returns a new tensor with dimensions permuted according to dims.
* The returned tensor shares the same underlying storage.
*
* @param dims The permutation of dimensions.
* @return A new SlimTensor with permuted dimensions.
*/
inline SlimTensor permute(IntArrayRef dims) const;

/**
* Overload for initializer lists.
*/
inline SlimTensor permute(std::initializer_list<int64_t> dims) const {
return permute(makeArrayRef(dims));
}

/**
* Returns a tensor with the same data and number of elements as this tensor,
* but with the specified shape. If possible, returns a view; otherwise
* creates a contiguous copy.
*
* @param shape The target shape (may contain one -1 for inference).
* @return A new SlimTensor with the specified shape.
*/
inline SlimTensor reshape(IntArrayRef shape) const;

/**
* Overload for initializer lists.
*/
inline SlimTensor reshape(std::initializer_list<int64_t> shape) const {
return reshape(makeArrayRef(shape));
}

// =========================================================================
// Copy Operation
// =========================================================================
Expand Down Expand Up @@ -445,6 +505,18 @@ class SlimTensor {
}

private:
SlimTensor _clone_impl(
c10::IntArrayRef sizes,
c10::IntArrayRef strides,
c10::ScalarType dtype,
const c10::Device& device) const {
Storage storage = new_storage(sizes, strides, dtype, device);
SlimTensor result =
SlimTensor(std::move(storage), sizes, strides, dtype, 0);
result.copy_(*this);
return result;
}

void refresh_numel() {
numel_ = compute_numel(sizes_and_strides_.sizes_arrayref());
}
Expand Down
58 changes: 58 additions & 0 deletions backends/aoti/slim/core/SlimTensorView-incl.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,62 @@ inline SlimTensor& SlimTensor::as_strided_(
return *this;
}

inline SlimTensor SlimTensor::permute(IntArrayRef dims) const {
const size_t ndim = this->dim();
ET_CHECK_MSG(
ndim == dims.size(),
"permute: dims length (%zu) must equal tensor.dim() (%zu)",
dims.size(),
ndim);

IntArrayRef old_sizes = this->sizes();
IntArrayRef old_strides = this->strides();
std::vector<int64_t> new_sizes(ndim);
std::vector<int64_t> new_strides(ndim);
std::vector<bool> seen_dims(ndim, false);

for (size_t i = 0; i < ndim; i++) {
int64_t d = c10::maybe_wrap_dim(dims[i], ndim);
ET_CHECK_MSG(!seen_dims[d], "permute: duplicate dims are not allowed");
seen_dims[d] = true;
new_sizes[i] = old_sizes[d];
new_strides[i] = old_strides[d];
}

SlimTensor result = *this;
result.as_strided_(
makeArrayRef(new_sizes),
makeArrayRef(new_strides),
this->storage_offset());
return result;
}

inline SlimTensor SlimTensor::reshape(IntArrayRef proposed_shape) const {
std::vector<int64_t> final_shape_vec =
infer_size(proposed_shape, static_cast<int64_t>(this->numel()));

// compute_stride returns the proper strides to use if this
// reshape can be just a view.
std::optional<std::vector<int64_t>> new_strides_opt = compute_stride(
this->sizes(), this->strides(), makeArrayRef(final_shape_vec));

// Create a view if possible
if (new_strides_opt.has_value()) {
SlimTensor result = *this;
result.as_strided_(
makeArrayRef(final_shape_vec),
makeArrayRef(new_strides_opt.value()),
this->storage_offset());
return result;
}

// If a view is not possible, create a contiguous clone and reshape that
SlimTensor contiguous_clone = this->clone_contiguous();
// After cloning, the tensor is already contiguous. We just need to update
// its metadata to reflect the new shape. This is effectively a view of
// the new contiguous clone.
contiguous_clone.set_sizes_contiguous(makeArrayRef(final_shape_vec));
return contiguous_clone;
}

} // namespace executorch::backends::aoti::slim
23 changes: 13 additions & 10 deletions backends/aoti/slim/core/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,6 @@ def get_backend_mode():

def define_common_targets():
"""Define test targets for SlimTensor core module."""
runtime.cxx_test(
name = "test_slimtensor_dtypes",
srcs = [
"test_slimtensor_dtypes.cpp",
],
deps = [
"//executorch/backends/aoti/slim/factory:empty",
],
)

# Backend mode specific tests
for backend_mode in get_backend_mode():
backend_suffix = "_" + backend_mode if backend_mode == "cuda" else ""
Expand Down Expand Up @@ -77,3 +67,16 @@ def define_common_targets():
],
**backend_kwargs
)


runtime.cxx_test(
name = "test_permute_reshape" + backend_suffix,
srcs = [
"test_permute_reshape.cpp",
],
deps = [
"//executorch/backends/aoti/slim/core:slimtensor",
"//executorch/backends/aoti/slim/factory:empty",
],
**backend_kwargs
)
Loading
Loading