@@ -60,10 +60,9 @@ using op_call_result =
6060
6161template <
6262 typename CTYPE_COMMON,
63- const char * op_name,
6463 typename Op,
65- typename ... Args>
66- inline void apply_elementwise_fn (
64+ typename ... Args>
65+ inline bool validate_elementwise_fn_inputs (
6766 const Op& compute_fun,
6867 KernelRuntimeContext& ctx,
6968 const Tensor& out,
@@ -72,7 +71,6 @@ inline void apply_elementwise_fn(
7271 static_assert (
7372 (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
7473 ...));
75- constexpr auto kNumInputs = sizeof ...(inputs);
7674 constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
7775 const auto check_input_dtype = [](auto input, auto compute_type) {
7876 return internal::check_tensor_dtype (
@@ -82,7 +80,33 @@ inline void apply_elementwise_fn(
8280 ctx,
8381 (check_input_dtype (inputs, compute_type) && ...) &&
8482 internal::check_tensor_dtype (out, out_dtypes, compute_type),
85- InvalidArgument, );
83+ InvalidArgument, false );
84+
85+ return true ;
86+ }
87+
88+ template <
89+ typename CTYPE_COMMON,
90+ const char * op_name,
91+ typename Op,
92+ typename ... Args>
93+ inline void apply_elementwise_fn (
94+ const Op& compute_fun,
95+ KernelRuntimeContext& ctx,
96+ const Tensor& out,
97+ SupportedTensorDtypes out_dtypes,
98+ Args... inputs) {
99+ const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMMON>(
100+ compute_fun,
101+ ctx,
102+ out,
103+ out_dtypes,
104+ inputs...);
105+ if (!inputs_valid) {
106+ return ;
107+ }
108+
109+ constexpr auto kNumInputs = sizeof ...(inputs);
86110
87111 struct InputInfo {
88112 load_to_common_fn<CTYPE_COMMON> load_to_common;
@@ -135,6 +159,7 @@ inline void apply_elementwise_fn(
135159}
136160} // namespace internal
137161
162+ /// DEPRECATED: prefer the variant with out_dtypes in the template argument list.
138163template <typename CTYPE_COMMON, const char * op_name, typename Op>
139164inline void apply_unitensor_elementwise_fn (
140165 const Op& compute_fun,
@@ -147,19 +172,75 @@ inline void apply_unitensor_elementwise_fn(
147172 compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
148173}
149174
175+ template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
176+ inline void apply_unitensor_elementwise_fn (
177+ const Op& compute_fun,
178+ KernelRuntimeContext& ctx,
179+ const Tensor& a,
180+ SupportedTensorDtypes a_dtypes,
181+ const Tensor& out) {
182+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
183+ compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
184+ }
185+
186+ /**
187+ * DEPRECATED: prefer the variant with out_dtypes in the template argument list.
188+ */
189+ template <typename CTYPE_COMMON, const char * op_name, typename Op>
190+ inline void apply_bitensor_elementwise_fn (
191+ const Op& compute_fun,
192+ KernelRuntimeContext& ctx,
193+ const Tensor& a,
194+ SupportedTensorDtypes a_dtypes,
195+ const Tensor& b,
196+ SupportedTensorDtypes b_dtypes,
197+ const Tensor& out,
198+ SupportedTensorDtypes out_dtypes) {
199+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
200+ compute_fun,
201+ ctx,
202+ out,
203+ out_dtypes,
204+ std::make_pair (&a, a_dtypes),
205+ std::make_pair (&b, b_dtypes));
206+ }
207+
150208/**
151209 * Useful for bi-tensor elementwise operators. For each element of the inputs,
152210 * perform a computation and write to the corresponding element of the output.
153211 * Tensor broadcasting is applied wherever it is required.
154212 */
155- template <typename CTYPE_COMMON, const char * op_name, typename Op>
213+ template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
156214inline void apply_bitensor_elementwise_fn (
157215 const Op& compute_fun,
158216 KernelRuntimeContext& ctx,
159217 const Tensor& a,
160218 SupportedTensorDtypes a_dtypes,
161219 const Tensor& b,
162220 SupportedTensorDtypes b_dtypes,
221+ const Tensor& out) {
222+ internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
223+ compute_fun,
224+ ctx,
225+ out,
226+ out_dtypes,
227+ std::make_pair (&a, a_dtypes),
228+ std::make_pair (&b, b_dtypes));
229+ }
230+
231+ /**
232+ * DEPRECATED: prefer the variant with out_dtypes in the template argument list.
233+ */
234+ template <typename CTYPE_COMMON, const char * op_name, typename Op>
235+ inline void apply_tritensor_elementwise_fn (
236+ const Op& compute_fun,
237+ KernelRuntimeContext& ctx,
238+ const Tensor& a,
239+ SupportedTensorDtypes a_dtypes,
240+ const Tensor& b,
241+ SupportedTensorDtypes b_dtypes,
242+ const Tensor& c,
243+ SupportedTensorDtypes c_dtypes,
163244 const Tensor& out,
164245 SupportedTensorDtypes out_dtypes) {
165246 internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
@@ -168,7 +249,8 @@ inline void apply_bitensor_elementwise_fn(
168249 out,
169250 out_dtypes,
170251 std::make_pair (&a, a_dtypes),
171- std::make_pair (&b, b_dtypes));
252+ std::make_pair (&b, b_dtypes),
253+ std::make_pair (&c, c_dtypes));
172254}
173255
174256/* *
@@ -191,7 +273,7 @@ inline void apply_bitensor_elementwise_fn(
191273 * static constexpr const char op_name[] = "my_op";
192274 * apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
193275 */
194- template <typename CTYPE_COMMON, const char * op_name, typename Op>
276+ template <typename CTYPE_COMMON, const char * op_name, SupportedTensorDtypes out_dtypes, typename Op>
195277inline void apply_tritensor_elementwise_fn (
196278 const Op& compute_fun,
197279 KernelRuntimeContext& ctx,
@@ -201,8 +283,7 @@ inline void apply_tritensor_elementwise_fn(
201283 SupportedTensorDtypes b_dtypes,
202284 const Tensor& c,
203285 SupportedTensorDtypes c_dtypes,
204- const Tensor& out,
205- SupportedTensorDtypes out_dtypes) {
286+ const Tensor& out) {
206287 internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
207288 compute_fun,
208289 ctx,
0 commit comments