@@ -225,17 +225,27 @@ template <>
225225struct TensorCheck <Int4x2> {
226226 void operator ()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params,
227227 const std::string& /* provider_type*/ ) const {
228- ORT_UNUSED_PARAMETER (params);
228+ const bool has_abs_err = params.absolute_error .has_value ();
229+ Tensor expected_sorted, actual_sorted;
229230 const Int4x2* cur_expected;
230231 const Int4x2* cur_actual;
231232 const auto size = narrow<size_t >(actual.Shape ().Size ());
232233 cur_expected = expected.Data <Int4x2>();
233234 cur_actual = actual.Data <Int4x2>();
235+ double threshold = 0 .0f ;
236+ if (has_abs_err) {
237+ threshold = *(params.absolute_error );
238+ }
234239
235240 for (size_t i = 0 ; i < size; ++i) {
236241 size_t r = i >> 1 ;
237242 size_t c = i & 0x1 ;
238- EXPECT_EQ (cur_expected[r].GetElem (c), cur_actual[r].GetElem (c)) << " i:" << i;
243+ // TODO: the relative error is not used for int4 yet.
244+ if (has_abs_err) {
245+ EXPECT_NEAR (cur_expected[r].GetElem (c), cur_actual[r].GetElem (c), threshold) << " i:" << i;
246+ } else {
247+ EXPECT_EQ (cur_expected[r].GetElem (c), cur_actual[r].GetElem (c)) << " i:" << i;
248+ }
239249 }
240250 }
241251};
@@ -244,17 +254,28 @@ template <>
244254struct TensorCheck <UInt4x2> {
245255 void operator ()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params,
246256 const std::string& /* provider_type*/ ) const {
247- ORT_UNUSED_PARAMETER (params);
257+ const bool has_abs_err = params.absolute_error .has_value ();
258+ Tensor expected_sorted, actual_sorted;
248259 const UInt4x2* cur_expected;
249260 const UInt4x2* cur_actual;
250261 const auto size = narrow<size_t >(actual.Shape ().Size ());
251262 cur_expected = expected.Data <UInt4x2>();
252263 cur_actual = actual.Data <UInt4x2>();
253264
254- for (size_t i = 0 ; i < size; ++i) {
265+ double threshold = 0 .0f ;
266+ if (has_abs_err) {
267+ threshold = *(params.absolute_error );
268+ }
269+
270+ for (size_t i = 0 ; i < static_cast <size_t >(size); ++i) {
255271 size_t r = i >> 1 ;
256272 size_t c = i & 0x1 ;
257- EXPECT_EQ (cur_expected[r].GetElem (c), cur_actual[r].GetElem (c)) << " i:" << i;
273+ // TODO: the relative error is not used for int4 yet.
274+ if (has_abs_err) {
275+ EXPECT_NEAR (cur_expected[r].GetElem (c), cur_actual[r].GetElem (c), threshold) << " i:" << i;
276+ } else {
277+ EXPECT_EQ (cur_expected[r].GetElem (c), cur_actual[r].GetElem (c)) << " i:" << i;
278+ }
258279 }
259280 }
260281};
@@ -292,7 +313,7 @@ struct TensorCheck<uint8_t> {
292313 // For any other EPs, we still expect an exact match for the results
293314 // TODO: Verify if DML can possibly have a ROUNDING_MODE parameter and conform to the other EPs #41968513
294315 if ((provider_type == kNnapiExecutionProvider || provider_type == kDmlExecutionProvider ||
295- provider_type == kXnnpackExecutionProvider ) &&
316+ provider_type == kXnnpackExecutionProvider || provider_type == kOpenVINOExecutionProvider ) &&
296317 (has_abs_err || has_rel_err)) {
297318 double threshold = has_abs_err ? *(params.absolute_error )
298319 : 0.0 ;
@@ -357,6 +378,49 @@ struct TensorCheck<int8_t> {
357378 }
358379};
359380
381+ template <>
382+ struct TensorCheck <uint16_t > {
383+ void operator ()(const Tensor& expected,
384+ const Tensor& actual,
385+ const ValidateOutputParams& params,
386+ const std::string& ) const {
387+ const bool has_abs_err = params.absolute_error .has_value ();
388+ const bool has_rel_err = params.relative_error .has_value ();
389+
390+ Tensor expected_sorted, actual_sorted;
391+ const uint16_t * cur_expected;
392+ const uint16_t * cur_actual;
393+ const auto size = actual.Shape ().Size ();
394+ if (params.sort_output ) {
395+ sort_expected_and_actual_buffers<uint16_t >(expected, expected_sorted, actual, actual_sorted);
396+ cur_expected = expected_sorted.Data <uint16_t >();
397+ cur_actual = actual_sorted.Data <uint16_t >();
398+ } else {
399+ cur_expected = expected.Data <uint16_t >();
400+ cur_actual = actual.Data <uint16_t >();
401+ }
402+
403+ if (has_abs_err || has_rel_err) {
404+ double threshold = has_abs_err ? *(params.absolute_error )
405+ : 0.0 ;
406+
407+ for (int64_t i = 0 ; i < size; ++i) {
408+ if (has_rel_err) {
409+ EXPECT_NEAR (cur_expected[i], cur_actual[i],
410+ *(params.relative_error ) * cur_expected[i]) // expected[i] is unsigned, can't be negative
411+ << " i:" << i;
412+ } else { // has_abs_err
413+ EXPECT_NEAR (cur_expected[i], cur_actual[i], threshold) << " i:" << i;
414+ }
415+ }
416+ } else {
417+ for (int64_t i = 0 ; i < size; ++i) {
418+ EXPECT_EQ (cur_expected[i], cur_actual[i]) << " i:" << i;
419+ }
420+ }
421+ }
422+ };
423+
360424template <>
361425struct TensorCheck <double > {
362426 void operator ()(const Tensor& expected,
0 commit comments