@@ -60,8 +60,43 @@ using op_call_result =
6060
6161template <
6262 typename CTYPE_COMMON,
63+ typename CTYPE_OUT,
6364 typename Op,
64- typename ... Args>
65+ typename ... Args>
66+ inline void dtype_specialized_elementwise_fn_impl (
67+ const Op& compute_fun,
68+ KernelRuntimeContext& ctx,
69+ const Tensor& out,
70+ Args... inputs) {
71+ constexpr auto kNumInputs = sizeof ...(inputs);
72+ ET_DCHECK (((inputs.first ->element_size () == sizeof (CTYPE_COMMON)) && ...));
73+
74+ std::array<const CTYPE_COMMON*, kNumInputs > inputs_data_ptrs = {
75+ inputs.first ->template const_data_ptr <CTYPE_COMMON>()...};
76+
77+ CTYPE_OUT* const data_out = out.mutable_data_ptr <CTYPE_OUT>();
78+
79+ ::executorch::extension::parallel_for (
80+ 0 ,
81+ out.numel(),
82+ ::executorch::extension::internal::GRAIN_SIZE,
83+ [&](const auto begin, const auto end) {
84+ const auto range =
85+ BroadcastIndexesRange<kNumInputs >(out, (*inputs.first )...);
86+ auto begin_it = range.begin ();
87+ begin_it += begin;
88+ for (; (*begin_it)[0 ] < end; ++begin_it) {
89+ const auto & indexes = *begin_it;
90+ std::array<CTYPE_COMMON, kNumInputs > loaded_inputs;
91+ for (const auto idx : c10::irange (kNumInputs )) {
92+ loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1 ]];
93+ }
94+ data_out[indexes[0 ]] = std::apply (compute_fun, loaded_inputs);
95+ }
96+ });
97+ }
98+
99+ template <typename CTYPE_COMMON, typename Op, typename ... Args>
65100inline bool validate_elementwise_fn_inputs (
66101 const Op& compute_fun,
67102 KernelRuntimeContext& ctx,
@@ -80,7 +115,8 @@ inline bool validate_elementwise_fn_inputs(
80115 ctx,
81116 (check_input_dtype (inputs, compute_type) && ...) &&
82117 internal::check_tensor_dtype (out, out_dtypes, compute_type),
83- InvalidArgument, false );
118+ InvalidArgument,
119+ false );
84120
85121 return true ;
86122}
@@ -90,22 +126,12 @@ template <
90126 const char * op_name,
91127 typename Op,
92128 typename ... Args>
93- inline void apply_elementwise_fn (
129+ inline void apply_elementwise_fn_generic_impl (
94130 const Op& compute_fun,
95131 KernelRuntimeContext& ctx,
96132 const Tensor& out,
97133 SupportedTensorDtypes out_dtypes,
98134 Args... inputs) {
99- const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
100- compute_fun,
101- ctx,
102- out,
103- out_dtypes,
104- inputs...);
105- if (!inputs_valid) {
106- return ;
107- }
108-
109135 constexpr auto kNumInputs = sizeof ...(inputs);
110136
111137 struct InputInfo {
@@ -157,6 +183,63 @@ inline void apply_elementwise_fn(
157183 }
158184 });
159185}
186+
187+ template <
188+ typename CTYPE_COMMON,
189+ const char * op_name,
190+ typename Op,
191+ typename ... Args>
192+ inline void apply_elementwise_fn_runtime_out_dtypes (
193+ const Op& compute_fun,
194+ KernelRuntimeContext& ctx,
195+ const Tensor& out,
196+ SupportedTensorDtypes out_dtypes,
197+ Args... inputs) {
198+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
199+ compute_fun, ctx, out, out_dtypes, inputs...);
200+ if (!inputs_valid) {
201+ return ;
202+ }
203+
204+ apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
205+ compute_fun, ctx, out, out_dtypes, inputs...);
206+ }
207+
208+ template <
209+ typename CTYPE_COMMON,
210+ const char * op_name,
211+ SupportedTensorDtypes out_dtypes,
212+ typename Op,
213+ typename ... Args>
214+ inline void apply_elementwise_fn (
215+ const Op& compute_fun,
216+ KernelRuntimeContext& ctx,
217+ const Tensor& out,
218+ Args... inputs) {
219+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
220+ compute_fun, ctx, out, out_dtypes, inputs...);
221+ if (!inputs_valid) {
222+ return ;
223+ }
224+
225+ constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
226+ const bool all_inputs_compute_dtype =
227+ ((inputs.first ->scalar_type () == compute_type) && ...);
228+
229+ constexpr ScalarType out_specialized_scalar_type =
230+ specialized_output_scalar_type<CTYPE_COMMON>(out_dtypes);
231+ if (all_inputs_compute_dtype &&
232+ out.scalar_type () == out_specialized_scalar_type) {
233+ using CTYPE_OUT =
234+ typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
235+ dtype_specialized_elementwise_fn_impl<CTYPE_COMMON, CTYPE_OUT>(
236+ compute_fun, ctx, out, inputs...);
237+ return ;
238+ }
239+
240+ apply_elementwise_fn_generic_impl<CTYPE_COMMON, op_name>(
241+ compute_fun, ctx, out, out_dtypes, inputs...);
242+ }
160243} // namespace internal
161244
162245// / DEPRECATED: prefer the variant with out_dtypes in the template argument.
@@ -168,19 +251,23 @@ inline void apply_unitensor_elementwise_fn(
168251 SupportedTensorDtypes a_dtypes,
169252 const Tensor& out,
170253 SupportedTensorDtypes out_dtypes) {
171- internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
254+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
172255 compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
173256}
174257
175- template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
258+ template <
259+ typename CTYPE_COMMON,
260+ const char * op_name,
261+ SupportedTensorDtypes out_dtypes,
262+ typename Op>
176263inline void apply_unitensor_elementwise_fn (
177264 const Op& compute_fun,
178265 KernelRuntimeContext& ctx,
179266 const Tensor& a,
180267 SupportedTensorDtypes a_dtypes,
181268 const Tensor& out) {
182- internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
183- compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
269+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
270+ compute_fun, ctx, out, std::make_pair (&a, a_dtypes));
184271}
185272
186273/* *
@@ -196,7 +283,7 @@ inline void apply_bitensor_elementwise_fn(
196283 SupportedTensorDtypes b_dtypes,
197284 const Tensor& out,
198285 SupportedTensorDtypes out_dtypes) {
199- internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
286+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
200287 compute_fun,
201288 ctx,
202289 out,
@@ -210,7 +297,11 @@ inline void apply_bitensor_elementwise_fn(
210297 * perform a computation and write to the corresponding element of the output.
211298 * Tensor broadcasting is applied wherever it is required.
212299 */
213- template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
300+ template <
301+ typename CTYPE_COMMON,
302+ const char * op_name,
303+ SupportedTensorDtypes out_dtypes,
304+ typename Op>
214305inline void apply_bitensor_elementwise_fn (
215306 const Op& compute_fun,
216307 KernelRuntimeContext& ctx,
@@ -219,11 +310,10 @@ inline void apply_bitensor_elementwise_fn(
219310 const Tensor& b,
220311 SupportedTensorDtypes b_dtypes,
221312 const Tensor& out) {
222- internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
313+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
223314 compute_fun,
224315 ctx,
225316 out,
226- out_dtypes,
227317 std::make_pair (&a, a_dtypes),
228318 std::make_pair (&b, b_dtypes));
229319}
@@ -243,7 +333,7 @@ inline void apply_tritensor_elementwise_fn(
243333 SupportedTensorDtypes c_dtypes,
244334 const Tensor& out,
245335 SupportedTensorDtypes out_dtypes) {
246- internal::apply_elementwise_fn <CTYPE_COMMON, op_name>(
336+ internal::apply_elementwise_fn_runtime_out_dtypes <CTYPE_COMMON, op_name>(
247337 compute_fun,
248338 ctx,
249339 out,
@@ -273,7 +363,11 @@ inline void apply_tritensor_elementwise_fn(
273363 * static constexpr const char op_name[] = "my_op";
274364 * apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
275365 */
276- template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
366+ template <
367+ typename CTYPE_COMMON,
368+ const char * op_name,
369+ SupportedTensorDtypes out_dtypes,
370+ typename Op>
277371inline void apply_tritensor_elementwise_fn (
278372 const Op& compute_fun,
279373 KernelRuntimeContext& ctx,
@@ -284,11 +378,10 @@ inline void apply_tritensor_elementwise_fn(
284378 const Tensor& c,
285379 SupportedTensorDtypes c_dtypes,
286380 const Tensor& out) {
287- internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
381+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name, out_dtypes >(
288382 compute_fun,
289383 ctx,
290384 out,
291- out_dtypes,
292385 std::make_pair (&a, a_dtypes),
293386 std::make_pair (&b, b_dtypes),
294387 std::make_pair (&c, c_dtypes));
0 commit comments