|
7 | 7 | */
|
8 | 8 |
|
9 | 9 | #include <executorch/kernels/portable/cpu/pattern/pattern.h>
|
| 10 | +#include <executorch/kernels/portable/cpu/util/elementwise_util.h> |
10 | 11 | #include <executorch/runtime/kernel/kernel_includes.h>
|
11 | 12 | #include <cmath>
|
12 | 13 |
|
13 | 14 | namespace torch {
|
14 | 15 | namespace executor {
|
15 | 16 | namespace native {
|
16 | 17 |
|
// Overload set so the elementwise kernel below can invoke `math::expm1`
// uniformly, whether the argument is a plain scalar or a SIMD vector.
namespace math {
using std::expm1;
#ifdef ET_USE_PYTORCH_HEADERS
// Vectorized overload. Floating-point lanes use the native vector expm1;
// all other element types are widened to float first. This mirrors how
// ATen treats this op (and similar ones in aten/src/ATen/native/UnaryOps.cpp),
// whose TensorIterator is built with build_borrowing_unary_float_op.
template <typename T>
auto expm1(at::vec::Vectorized<T> x) {
  if constexpr (executorch::runtime::is_floating_point<T>::value) {
    return x.expm1();
  } else {
    return at::vec::convert<float>(x).expm1();
  }
}
#endif
} // namespace math
17 | 42 | Tensor& expm1_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
|
18 |
| - return internal::unary_ufunc_realhbbf16_to_floathbf16( |
19 |
| - std::expm1, ctx, in, out); |
| 43 | + ET_KERNEL_CHECK(ctx, tensor_is_floating_type(out), InvalidArgument, out); |
| 44 | + |
| 45 | + // Resize for dynamic shape |
| 46 | + ET_KERNEL_CHECK_MSG( |
| 47 | + ctx, |
| 48 | + resize_tensor(out, in.sizes()) == Error::Ok, |
| 49 | + InvalidArgument, |
| 50 | + out, |
| 51 | + "Failed to resize output tensor."); |
| 52 | + |
| 53 | + ET_KERNEL_CHECK( |
| 54 | + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); |
| 55 | + |
| 56 | + static constexpr const char op_name[] = "expm1.out"; |
| 57 | + ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] { |
| 58 | + utils::apply_unitensor_elementwise_fn< |
| 59 | + CTYPE_IN, |
| 60 | + op_name, |
| 61 | + utils::SupportedTensorDtypes::FLOATHBF16>( |
| 62 | + [](auto x) { return math::expm1(x); }, |
| 63 | + ctx, |
| 64 | + in, |
| 65 | + utils::SupportedTensorDtypes::REALHBBF16, |
| 66 | + out); |
| 67 | + }); |
| 68 | + |
| 69 | + return out; |
20 | 70 | }
|
21 | 71 |
|
22 | 72 | } // namespace native
|
|
0 commit comments