@@ -2142,13 +2142,95 @@ static uint64_t reduce_max_uint64(const uint64_t* data, int nitems) {
 
 static double reduce_prod_float32_nan_safe(const float* data, int nitems) {
     if (nitems <= 0) return 1.0;
+#if defined(__AVX__) || defined(__AVX2__)
+    int i = 0;
+    __m256d vprod0 = _mm256_set1_pd(1.0);
+    __m256d vprod1 = _mm256_set1_pd(1.0);
+    int nan_mask = 0;
+    const int limit = nitems & ~7;
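+    /* Handle 8 floats per iteration: widen to double in two independent
+       accumulators and record any NaN lanes via an unordered self-compare. */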
+    for (; i < limit; i += 8) {
+        __m256 v = _mm256_loadu_ps(data + i);
+        nan_mask |= _mm256_movemask_ps(_mm256_cmp_ps(v, v, _CMP_UNORD_Q));
+        __m128 vlow = _mm256_castps256_ps128(v);
+        __m128 vhigh = _mm256_extractf128_ps(v, 1);
+        __m256d vlo = _mm256_cvtps_pd(vlow);
+        __m256d vhi = _mm256_cvtps_pd(vhigh);
+        vprod0 = _mm256_mul_pd(vprod0, vlo);
+        vprod1 = _mm256_mul_pd(vprod1, vhi);
+    }
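+    /* Horizontal reduction: fold the two accumulators, then the four lanes. */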
+    __m256d vprod = _mm256_mul_pd(vprod0, vprod1);
+    __m128d low = _mm256_castpd256_pd128(vprod);
+    __m128d high = _mm256_extractf128_pd(vprod, 1);
+    __m128d prod128 = _mm_mul_pd(low, high);
+    prod128 = _mm_mul_sd(prod128, _mm_unpackhi_pd(prod128, prod128));
+    double acc = _mm_cvtsd_f64(prod128);
+    if (nan_mask) return NAN;
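+    /* Scalar tail for the remaining nitems % 8 elements. */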
+    for (; i < nitems; i++) {
+        double v = (double)data[i];
+        acc *= v;
+        if (v != v) return v;
+    }
+    return acc;
+#elif defined(__SSE2__)
+    int i = 0;
+    __m128d vprod0 = _mm_set1_pd(1.0);
+    __m128d vprod1 = _mm_set1_pd(1.0);
+    int nan_mask = 0;
+    const int limit = nitems & ~3;
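+    /* Same strategy as the AVX path, four floats per iteration. */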
+    for (; i < limit; i += 4) {
+        __m128 v = _mm_loadu_ps(data + i);
+        nan_mask |= _mm_movemask_ps(_mm_cmpunord_ps(v, v));
+        __m128 vhigh = _mm_movehl_ps(v, v);
+        __m128d vlo = _mm_cvtps_pd(v);
+        __m128d vhi = _mm_cvtps_pd(vhigh);
+        vprod0 = _mm_mul_pd(vprod0, vlo);
+        vprod1 = _mm_mul_pd(vprod1, vhi);
+    }
+    __m128d prod128 = _mm_mul_pd(vprod0, vprod1);
+    prod128 = _mm_mul_sd(prod128, _mm_unpackhi_pd(prod128, prod128));
+    double acc = _mm_cvtsd_f64(prod128);
+    if (nan_mask) return NAN;
+    for (; i < nitems; i++) {
+        double v = (double)data[i];
+        acc *= v;
+        if (v != v) return v;
+    }
+    return acc;
+#elif (defined(__ARM_NEON) || defined(__ARM_NEON__)) && defined(__aarch64__)
+    int i = 0;
+    float64x2_t vprod0 = vdupq_n_f64(1.0);
+    float64x2_t vprod1 = vdupq_n_f64(1.0);
+    uint32x4_t vnan = vdupq_n_u32(0);
+    const int limit = nitems & ~3;
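+    /* NEON has no unordered compare, so NaN lanes are flagged by inverting
+       a self-equality compare (vceqq_f32 is false only for NaN). */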
+    for (; i < limit; i += 4) {
+        float32x4_t v = vld1q_f32(data + i);
+        uint32x4_t eq = vceqq_f32(v, v);
+        vnan = vorrq_u32(vnan, veorq_u32(eq, vdupq_n_u32(~0U)));
+        float64x2_t vlo = vcvt_f64_f32(vget_low_f32(v));
+        float64x2_t vhi = vcvt_f64_f32(vget_high_f32(v));
+        vprod0 = vmulq_f64(vprod0, vlo);
+        vprod1 = vmulq_f64(vprod1, vhi);
+    }
+    float64x2_t vprod = vmulq_f64(vprod0, vprod1);
+    double acc = vgetq_lane_f64(vprod, 0) * vgetq_lane_f64(vprod, 1);
+    uint32x4_t nan_or = vorrq_u32(vnan, vextq_u32(vnan, vnan, 2));
+    nan_or = vorrq_u32(nan_or, vextq_u32(nan_or, nan_or, 1));
+    if (vgetq_lane_u32(nan_or, 0)) return NAN;
+    for (; i < nitems; i++) {
+        double v = (double)data[i];
+        acc *= v;
+        if (v != v) return v;
+    }
+    return acc;
+#else
     double acc = 1.0;
     for (int i = 0; i < nitems; i++) {
         double v = (double)data[i];
         acc *= v;
         if (v != v) return v;
     }
     return acc;
+#endif
 }
 
 static double reduce_prod_float64_nan_safe(const double* data, int nitems) {
@@ -2227,15 +2309,97 @@ static double reduce_prod_float64_nan_safe(const double* data, int nitems) {
 
 static double reduce_sum_float32_nan_safe(const float* data, int nitems) {
     if (nitems <= 0) return 0.0;
+#if defined(__AVX__) || defined(__AVX2__)
+    int i = 0;
+    __m256d vsum0 = _mm256_setzero_pd();
+    __m256d vsum1 = _mm256_setzero_pd();
+    int nan_mask = 0;
+    const int limit = nitems & ~7;
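+    /* Same structure as the float32 product kernel, with addition in place of
+       multiplication: widen 8 floats to double across two accumulators. */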
+    for (; i < limit; i += 8) {
+        __m256 v = _mm256_loadu_ps(data + i);
+        nan_mask |= _mm256_movemask_ps(_mm256_cmp_ps(v, v, _CMP_UNORD_Q));
+        __m128 vlow = _mm256_castps256_ps128(v);
+        __m128 vhigh = _mm256_extractf128_ps(v, 1);
+        __m256d vlo = _mm256_cvtps_pd(vlow);
+        __m256d vhi = _mm256_cvtps_pd(vhigh);
+        vsum0 = _mm256_add_pd(vsum0, vlo);
+        vsum1 = _mm256_add_pd(vsum1, vhi);
+    }
+    __m256d vsum = _mm256_add_pd(vsum0, vsum1);
+    __m128d low = _mm256_castpd256_pd128(vsum);
+    __m128d high = _mm256_extractf128_pd(vsum, 1);
+    __m128d sum128 = _mm_add_pd(low, high);
+    sum128 = _mm_add_sd(sum128, _mm_unpackhi_pd(sum128, sum128));
+    double acc = _mm_cvtsd_f64(sum128);
+    if (nan_mask) return NAN;
+    for (; i < nitems; i++) {
+        double v = (double)data[i];
+        acc += v;
+        if (v != v) return v;
+    }
+    return acc;
+#elif defined(__SSE2__)
+    int i = 0;
+    __m128d vsum0 = _mm_setzero_pd();
+    __m128d vsum1 = _mm_setzero_pd();
+    int nan_mask = 0;
+    const int limit = nitems & ~3;
+    for (; i < limit; i += 4) {
+        __m128 v = _mm_loadu_ps(data + i);
+        nan_mask |= _mm_movemask_ps(_mm_cmpunord_ps(v, v));
+        __m128 vhigh = _mm_movehl_ps(v, v);
+        __m128d vlo = _mm_cvtps_pd(v);
+        __m128d vhi = _mm_cvtps_pd(vhigh);
+        vsum0 = _mm_add_pd(vsum0, vlo);
+        vsum1 = _mm_add_pd(vsum1, vhi);
+    }
+    __m128d sum128 = _mm_add_pd(vsum0, vsum1);
+    sum128 = _mm_add_sd(sum128, _mm_unpackhi_pd(sum128, sum128));
+    double acc = _mm_cvtsd_f64(sum128);
+    if (nan_mask) return NAN;
+    for (; i < nitems; i++) {
+        double v = (double)data[i];
+        acc += v;
+        if (v != v) return v;
+    }
+    return acc;
+#elif (defined(__ARM_NEON) || defined(__ARM_NEON__)) && defined(__aarch64__)
+    int i = 0;
+    float64x2_t vsum0 = vdupq_n_f64(0.0);
+    float64x2_t vsum1 = vdupq_n_f64(0.0);
+    uint32x4_t vnan = vdupq_n_u32(0);
+    const int limit = nitems & ~3;
+    for (; i < limit; i += 4) {
+        float32x4_t v = vld1q_f32(data + i);
+        uint32x4_t eq = vceqq_f32(v, v);
+        vnan = vorrq_u32(vnan, veorq_u32(eq, vdupq_n_u32(~0U)));
+        float64x2_t vlo = vcvt_f64_f32(vget_low_f32(v));
+        float64x2_t vhi = vcvt_f64_f32(vget_high_f32(v));
+        vsum0 = vaddq_f64(vsum0, vlo);
+        vsum1 = vaddq_f64(vsum1, vhi);
+    }
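+    /* vaddvq_f64 sums both lanes, giving the horizontal total directly. */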
+    double acc = vaddvq_f64(vaddq_f64(vsum0, vsum1));
+    uint32x4_t nan_or = vorrq_u32(vnan, vextq_u32(vnan, vnan, 2));
+    nan_or = vorrq_u32(nan_or, vextq_u32(nan_or, nan_or, 1));
+    if (vgetq_lane_u32(nan_or, 0)) return NAN;
+    for (; i < nitems; i++) {
+        double v = (double)data[i];
+        acc += v;
+        if (v != v) return v;
+    }
+    return acc;
+#else
     double acc = 0.0;
     for (int i = 0; i < nitems; i++) {
         double v = (double)data[i];
         acc += v;
         if (v != v) return v;
     }
     return acc;
+#endif
 }
 
+
 static double reduce_sum_float64_nan_safe(const double* data, int nitems) {
     if (nitems <= 0) return 0.0;
 #if defined(__AVX__) || defined(__AVX2__)
@@ -2310,6 +2474,90 @@ static double reduce_sum_float64_nan_safe(const double* data, int nitems) {
 #endif
 }
 
+static int64_t reduce_sum_int32(const int32_t* data, int nitems) {
+    if (nitems <= 0) return 0;
+#if defined(__AVX2__)
+    int i = 0;
+    __m256i acc0 = _mm256_setzero_si256();
+    __m256i acc1 = _mm256_setzero_si256();
+    const int limit = nitems & ~7;
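+    /* Sign-extend each block of 8 int32 values to int64 before accumulating,
+       so intermediate sums cannot overflow 32 bits. */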
+    for (; i < limit; i += 8) {
+        __m256i v = _mm256_loadu_si256((const __m256i*)(data + i));
+        __m128i vlow = _mm256_castsi256_si128(v);
+        __m128i vhigh = _mm256_extracti128_si256(v, 1);
+        __m256i vlow64 = _mm256_cvtepi32_epi64(vlow);
+        __m256i vhigh64 = _mm256_cvtepi32_epi64(vhigh);
+        acc0 = _mm256_add_epi64(acc0, vlow64);
+        acc1 = _mm256_add_epi64(acc1, vhigh64);
+    }
+    acc0 = _mm256_add_epi64(acc0, acc1);
+    int64_t tmp[4];
+    _mm256_storeu_si256((__m256i*)tmp, acc0);
+    int64_t acc = tmp[0] + tmp[1] + tmp[2] + tmp[3];
+    for (; i < nitems; i++) {
+        acc += data[i];
+    }
+    return acc;
+#else
+    int64_t acc = 0;
+    for (int i = 0; i < nitems; i++) {
+        acc += data[i];
+    }
+    return acc;
+#endif
+}
+
+static uint64_t reduce_sum_uint32(const uint32_t* data, int nitems) {
+    if (nitems <= 0) return 0;
+#if defined(__AVX2__)
+    int i = 0;
+    __m256i acc0 = _mm256_setzero_si256();
+    __m256i acc1 = _mm256_setzero_si256();
+    const int limit = nitems & ~7;
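+    /* Same shape as reduce_sum_int32, but _mm256_cvtepu32_epi64 zero-extends
+       the unsigned inputs to 64 bits. */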
+    for (; i < limit; i += 8) {
+        __m256i v = _mm256_loadu_si256((const __m256i*)(data + i));
+        __m128i vlow = _mm256_castsi256_si128(v);
+        __m128i vhigh = _mm256_extracti128_si256(v, 1);
+        __m256i vlow64 = _mm256_cvtepu32_epi64(vlow);
+        __m256i vhigh64 = _mm256_cvtepu32_epi64(vhigh);
+        acc0 = _mm256_add_epi64(acc0, vlow64);
+        acc1 = _mm256_add_epi64(acc1, vhigh64);
+    }
+    acc0 = _mm256_add_epi64(acc0, acc1);
+    uint64_t tmp[4];
+    _mm256_storeu_si256((__m256i*)tmp, acc0);
+    uint64_t acc = tmp[0] + tmp[1] + tmp[2] + tmp[3];
+    for (; i < nitems; i++) {
+        acc += data[i];
+    }
+    return acc;
+#elif (defined(__ARM_NEON) || defined(__ARM_NEON__)) && defined(__aarch64__)
+    int i = 0;
+    uint64x2_t acc0 = vdupq_n_u64(0);
+    uint64x2_t acc1 = vdupq_n_u64(0);
+    const int limit = nitems & ~3;
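+    /* vmovl_u32 widens each uint32 half to 64-bit lanes before adding. */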
+    for (; i < limit; i += 4) {
+        uint32x4_t v = vld1q_u32(data + i);
+        uint64x2_t lo = vmovl_u32(vget_low_u32(v));
+        uint64x2_t hi = vmovl_u32(vget_high_u32(v));
+        acc0 = vaddq_u64(acc0, lo);
+        acc1 = vaddq_u64(acc1, hi);
+    }
+    uint64x2_t accv = vaddq_u64(acc0, acc1);
+    uint64_t acc = vgetq_lane_u64(accv, 0) + vgetq_lane_u64(accv, 1);
+    for (; i < nitems; i++) {
+        acc += data[i];
+    }
+    return acc;
+#else
+    uint64_t acc = 0;
+    for (int i = 0; i < nitems; i++) {
+        acc += data[i];
+    }
+    return acc;
+#endif
+}
+
 static double comma(double a, double b) {
     (void)a;
     return b;
@@ -4953,7 +5201,7 @@ static void eval_reduction(const me_expr* n, int output_nitems) {
                 for (int i = 0; i < nitems; i++) acc *= data[i];
             }
             else {
-                for (int i = 0; i < nitems; i++) acc += data[i];
+                acc = reduce_sum_int32(data, nitems);
             }
             ((int64_t*)write_ptr)[0] = acc;
         }
@@ -5105,7 +5353,7 @@ static void eval_reduction(const me_expr* n, int output_nitems) {
                 for (int i = 0; i < nitems; i++) acc *= data[i];
             }
             else {
-                for (int i = 0; i < nitems; i++) acc += data[i];
+                acc = reduce_sum_uint32(data, nitems);
             }
             ((uint64_t*)write_ptr)[0] = acc;
         }