Skip to content

Commit 0cfa50b

Browse files
authored
Merge pull request #458 from howjmay/vqdmlal_high_lane
feat: Add vqdmlal_high_lane_[s16|s32]
2 parents b5f1a82 + 7de1c40 commit 0cfa50b

File tree

3 files changed

+72
-10
lines changed

3 files changed

+72
-10
lines changed

neon2rvv.h

+18-6
Original file line numberDiff line numberDiff line change
@@ -8993,15 +8993,15 @@ FORCE_INLINE uint64x2_t vmlal_high_laneq_u32(uint64x2_t a, uint32x4_t b, uint32x
89938993
__riscv_vwmaccu_vv_u64m2(__riscv_vlmul_ext_v_u64m1_u64m2(a), b_high, c_dup, 2));
89948994
}
89958995

8996-
FORCE_INLINE int32x4_t vqdmlal_lane_s16(int32x4_t a, int16x4_t b, int16x4_t c, const int __d) {
8997-
vint16m1_t c_dup = __riscv_vrgather_vx_i16m1(c, __d, 4);
8996+
FORCE_INLINE int32x4_t vqdmlal_lane_s16(int32x4_t a, int16x4_t b, int16x4_t c, const int lane) {
8997+
vint16m1_t c_dup = __riscv_vrgather_vx_i16m1(c, lane, 4);
89988998
vint32m1_t bc_mul = __riscv_vlmul_trunc_v_i32m2_i32m1(__riscv_vwmul_vv_i32m2(b, c_dup, 4));
89998999
vint32m1_t bc_mulx2 = __riscv_vmul_vx_i32m1(bc_mul, 2, 4);
90009000
return __riscv_vadd_vv_i32m1(a, bc_mulx2, 4);
90019001
}
90029002

9003-
FORCE_INLINE int64x2_t vqdmlal_lane_s32(int64x2_t a, int32x2_t b, int32x2_t c, const int __d) {
9004-
vint32m1_t c_dup = __riscv_vrgather_vx_i32m1(c, __d, 2);
9003+
FORCE_INLINE int64x2_t vqdmlal_lane_s32(int64x2_t a, int32x2_t b, int32x2_t c, const int lane) {
9004+
vint32m1_t c_dup = __riscv_vrgather_vx_i32m1(c, lane, 2);
90059005
vint64m1_t bc_mul = __riscv_vlmul_trunc_v_i64m2_i64m1(__riscv_vwmul_vv_i64m2(b, c_dup, 2));
90069006
vint64m1_t bc_mulx2 = __riscv_vmul_vx_i64m1(bc_mul, 2, 2);
90079007
return __riscv_vadd_vv_i64m1(a, bc_mulx2, 2);
@@ -9011,9 +9011,21 @@ FORCE_INLINE int64x2_t vqdmlal_lane_s32(int64x2_t a, int32x2_t b, int32x2_t c, c
90119011

90129012
// FORCE_INLINE int64_t vqdmlals_lane_s32(int64_t a, int32_t b, int32x2_t v, const int lane);
90139013

9014-
// FORCE_INLINE int32x4_t vqdmlal_high_lane_s16(int32x4_t a, int16x8_t b, int16x4_t v, const int lane);
9014+
FORCE_INLINE int32x4_t vqdmlal_high_lane_s16(int32x4_t a, int16x8_t b, int16x4_t c, const int lane) {
9015+
vint16m1_t b_high = __riscv_vslidedown_vx_i16m1(b, 4, 8);
9016+
vint16m1_t c_dup = __riscv_vrgather_vx_i16m1(c, lane, 4);
9017+
vint32m1_t bc_mul = __riscv_vlmul_trunc_v_i32m2_i32m1(__riscv_vwmul_vv_i32m2(b_high, c_dup, 4));
9018+
vint32m1_t bc_mulx2 = __riscv_vmul_vx_i32m1(bc_mul, 2, 4);
9019+
return __riscv_vadd_vv_i32m1(a, bc_mulx2, 4);
9020+
}
90159021

9016-
// FORCE_INLINE int64x2_t vqdmlal_high_lane_s32(int64x2_t a, int32x4_t b, int32x2_t v, const int lane);
9022+
FORCE_INLINE int64x2_t vqdmlal_high_lane_s32(int64x2_t a, int32x4_t b, int32x2_t c, const int lane) {
9023+
vint32m1_t b_high = __riscv_vslidedown_vx_i32m1(b, 2, 4);
9024+
vint32m1_t c_dup = __riscv_vrgather_vx_i32m1(c, lane, 2);
9025+
vint64m1_t bc_mul = __riscv_vlmul_trunc_v_i64m2_i64m1(__riscv_vwmul_vv_i64m2(b_high, c_dup, 2));
9026+
vint64m1_t bc_mulx2 = __riscv_vmul_vx_i64m1(bc_mul, 2, 2);
9027+
return __riscv_vadd_vv_i64m1(a, bc_mulx2, 2);
9028+
}
90179029

90189030
// FORCE_INLINE int32x4_t vqdmlal_laneq_s16(int32x4_t a, int16x4_t b, int16x8_t v, const int lane);
90199031

tests/impl.cpp

+52-2
Original file line numberDiff line numberDiff line change
@@ -31892,9 +31892,59 @@ result_t test_vqdmlalh_lane_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
3189231892

3189331893
result_t test_vqdmlals_lane_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
3189431894

31895-
result_t test_vqdmlal_high_lane_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
31895+
result_t test_vqdmlal_high_lane_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
31896+
#ifdef ENABLE_TEST_ALL
31897+
const int32_t *_a = (int32_t *)impl.test_cases_int_pointer1;
31898+
const int16_t *_b = (int16_t *)impl.test_cases_int_pointer2;
31899+
const int16_t *_c = (int16_t *)impl.test_cases_int_pointer3;
31900+
int32x4_t a = vld1q_s32(_a);
31901+
int16x8_t b = vld1q_s16(_b);
31902+
int16x4_t c = vld1_s16(_c);
31903+
;
31904+
int32x4_t d;
31905+
int32_t _d[4];
31906+
#define TEST_IMPL(IDX) \
31907+
for (int i = 0; i < 4; i++) { \
31908+
_d[i] = sat_add(_a[i], sat_dmull(_b[i + 4], _c[IDX])); \
31909+
} \
31910+
d = vqdmlal_high_lane_s16(a, b, c, IDX); \
31911+
CHECK_RESULT(validate_int32(d, _d[0], _d[1], _d[2], _d[3]))
31912+
31913+
IMM_4_ITER
31914+
#undef TEST_IMPL
3189631915

31897-
result_t test_vqdmlal_high_lane_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
31916+
return TEST_SUCCESS;
31917+
#else
31918+
return TEST_UNIMPL;
31919+
#endif // ENABLE_TEST_ALL
31920+
}
31921+
31922+
result_t test_vqdmlal_high_lane_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
31923+
#ifdef ENABLE_TEST_ALL
31924+
const int64_t *_a = (int64_t *)impl.test_cases_int_pointer1;
31925+
const int32_t *_b = (int32_t *)impl.test_cases_int_pointer2;
31926+
const int32_t *_c = (int32_t *)impl.test_cases_int_pointer3;
31927+
int64x2_t a = vld1q_s64(_a);
31928+
int32x4_t b = vld1q_s32(_b);
31929+
int32x2_t c = vld1_s32(_c);
31930+
int64x2_t d;
31931+
int64_t _d[2];
31932+
31933+
#define TEST_IMPL(IDX) \
31934+
for (int i = 0; i < 2; i++) { \
31935+
_d[i] = sat_add(_a[i], sat_dmull(_b[i + 2], _c[IDX])); \
31936+
} \
31937+
d = vqdmlal_high_lane_s32(a, b, c, IDX); \
31938+
CHECK_RESULT(validate_int64(d, _d[0], _d[1]))
31939+
31940+
IMM_2_ITER
31941+
#undef TEST_IMPL
31942+
31943+
return TEST_SUCCESS;
31944+
#else
31945+
return TEST_UNIMPL;
31946+
#endif // ENABLE_TEST_ALL
31947+
}
3189831948

3189931949
result_t test_vqdmlal_laneq_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
3190031950

tests/impl.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -1950,8 +1950,8 @@
19501950
_(vqdmlal_lane_s32) \
19511951
/*_(vqdmlalh_lane_s16) */ \
19521952
/*_(vqdmlals_lane_s32) */ \
1953-
/*_(vqdmlal_high_lane_s16) */ \
1954-
/*_(vqdmlal_high_lane_s32) */ \
1953+
_(vqdmlal_high_lane_s16) \
1954+
_(vqdmlal_high_lane_s32) \
19551955
/*_(vqdmlal_laneq_s16) */ \
19561956
/*_(vqdmlal_laneq_s32) */ \
19571957
/*_(vqdmlalh_laneq_s16) */ \

0 commit comments

Comments
 (0)