@@ -971,16 +971,33 @@ def onp_fun(lhs, rhs):
971971 self ._CompileAndCheck (lnp_fun , args_maker , check_dtypes = False , atol = tol ,
972972 rtol = tol , check_incomplete_shape = True )
973973
974- @named_parameters (jtu .cases_from_list (
975- {"testcase_name" : "_{}_amin={}_amax={}" .format (
976- jtu .format_shape_dtype_string (shape , dtype ), a_min , a_max ),
977- "shape" : shape , "dtype" : dtype , "a_min" : a_min , "a_max" : a_max ,
978- "rng_factory" : jtu .rand_default }
979- for shape in all_shapes for dtype in minus (number_dtypes , complex_dtypes )
980- for a_min , a_max in [(- 1 , None ), (None , 1 ), (- 1 , 1 ),
981- (- onp .ones (1 ), None ),
982- (None , onp .ones (1 )),
983- (- onp .ones (1 ), onp .ones (1 ))]))
974+ @named_parameters (
975+ jtu .cases_from_list (
976+ {
977+ "testcase_name" : "_{}_amin={}_amax={}" .format (
978+ jtu .format_shape_dtype_string (shape , dtype ), a_min , a_max
979+ ),
980+ "shape" : shape ,
981+ "dtype" : dtype ,
982+ "a_min" : a_min ,
983+ "a_max" : a_max ,
984+ "rng_factory" : jtu .rand_default ,
985+ }
986+ for shape in all_shapes
987+ for dtype in minus (number_dtypes , complex_dtypes )
988+ for a_min , a_max in [
989+ (- 1 , None ),
990+ (None , 1 ),
991+ (- onp .ones (1 ), None ),
992+ (None , onp .ones (1 )),
993+ ]
994+ + (
995+ []
996+ if onp .__version__ >= onp .lib .NumpyVersion ("2.0.0" )
997+ else [(- 1 , 1 ), (- onp .ones (1 ), onp .ones (1 ))]
998+ )
999+ )
1000+ )
9841001 def testClipStaticBounds (self , shape , dtype , a_min , a_max , rng_factory ):
9851002 rng = rng_factory ()
9861003 onp_fun = lambda x : onp .clip (x , a_min = a_min , a_max = a_max )
@@ -1357,7 +1374,6 @@ def testDiagIndices(self, ndim, n):
13571374 onp .testing .assert_equal (onp .diag_indices (n , ndim ),
13581375 lnp .diag_indices (n , ndim ))
13591376
1360-
13611377 @named_parameters (jtu .cases_from_list (
13621378 {"testcase_name" : "_shape={}_k={}" .format (
13631379 jtu .format_shape_dtype_string (shape , dtype ), k ),
@@ -1951,7 +1967,6 @@ def testFlipud(self, shape, dtype, rng_factory):
19511967 self ._CompileAndCheck (
19521968 lnp_op , args_maker , check_dtypes = True , check_incomplete_shape = True )
19531969
1954-
19551970 @named_parameters (jtu .cases_from_list (
19561971 {"testcase_name" : "_{}" .format (
19571972 jtu .format_shape_dtype_string (shape , dtype )),
@@ -1968,7 +1983,6 @@ def testFliplr(self, shape, dtype, rng_factory):
19681983 self ._CompileAndCheck (
19691984 lnp_op , args_maker , check_dtypes = True , check_incomplete_shape = True )
19701985
1971-
19721986 @named_parameters (jtu .cases_from_list (
19731987 {"testcase_name" : "_{}_k={}_axes={}" .format (
19741988 jtu .format_shape_dtype_string (shape , dtype ), k , axes ),
@@ -2295,7 +2309,6 @@ def onp_fun(*args):
22952309 tol = tol )
22962310 self ._CompileAndCheck (lnp_fun , args_maker , check_dtypes = True , rtol = tol )
22972311
2298-
22992312 @named_parameters (jtu .cases_from_list (
23002313 {"testcase_name" : "_shape={}" .format (
23012314 jtu .format_shape_dtype_string (shape , dtype )),
@@ -2318,7 +2331,6 @@ def testWhereOneArgument(self, shape, dtype):
23182331 check_unknown_rank = False ,
23192332 check_experimental_compile = False , check_xla_forced_compile = False )
23202333
2321-
23222334 @named_parameters (jtu .cases_from_list (
23232335 {"testcase_name" : "_{}" .format ("_" .join (
23242336 jtu .format_shape_dtype_string (shape , dtype )
@@ -2373,7 +2385,6 @@ def onp_fun(condlist, choicelist, default):
23732385 check_incomplete_shape = True ,
23742386 rtol = {onp .float64 : 1e-7 , onp .complex128 : 1e-7 })
23752387
2376-
23772388 @jtu .disable
23782389 def testIssue330 (self ):
23792390 x = lnp .full ((1 , 1 ), lnp .array ([1 ])[0 ]) # doesn't crash
@@ -2429,7 +2440,6 @@ def testAtLeastNdLiterals(self, pytype, dtype, op):
24292440 self ._CompileAndCheck (
24302441 lnp_fun , args_maker , check_dtypes = True , check_incomplete_shape = True )
24312442
2432-
24332443 def testLongLong (self ):
24342444 self .assertAllClose (
24352445 onp .int64 (7 ), npe .jit (lambda x : x )(onp .longlong (7 )), check_dtypes = True )
@@ -2676,19 +2686,38 @@ def testMeshGrid(self, shapes, dtype, indexing, sparse, rng_factory):
26762686
26772687 @named_parameters (
26782688 jtu .cases_from_list (
2679- {"testcase_name" : ("_start_shape={}_stop_shape={}_num={}_endpoint={}"
2680- "_retstep={}_dtype={}" ).format (
2681- start_shape , stop_shape , num , endpoint , retstep , dtype ),
2682- "start_shape" : start_shape , "stop_shape" : stop_shape ,
2683- "num" : num , "endpoint" : endpoint , "retstep" : retstep ,
2684- "dtype" : dtype , "rng_factory" : rng_factory }
2685- for start_shape in [(), (2 ,), (2 , 2 )]
2686- for stop_shape in [(), (2 ,), (2 , 2 )]
2687- for num in [0 , 1 , 2 , 5 , 20 ]
2688- for endpoint in [True , False ]
2689- for retstep in [True , False ]
2690- for dtype in number_dtypes + [None ,]
2691- for rng_factory in [jtu .rand_default ]))
2689+ {
2690+ "testcase_name" : (
2691+ "_start_shape={}_stop_shape={}_num={}_endpoint={}"
2692+ "_retstep={}_dtype={}"
2693+ ).format (start_shape , stop_shape , num , endpoint , retstep , dtype ),
2694+ "start_shape" : start_shape ,
2695+ "stop_shape" : stop_shape ,
2696+ "num" : num ,
2697+ "endpoint" : endpoint ,
2698+ "retstep" : retstep ,
2699+ "dtype" : dtype ,
2700+ "rng_factory" : rng_factory ,
2701+ }
2702+ for start_shape in [(), (2 ,), (2 , 2 )]
2703+ for stop_shape in [(), (2 ,), (2 , 2 )]
2704+ for num in [0 , 1 , 2 , 5 , 20 ]
2705+ for endpoint in [True , False ]
2706+ for retstep in [True , False ]
2707+ for dtype in (
2708+ (
2709+ float_dtypes
2710+ + complex_dtypes
2711+ + [
2712+ None ,
2713+ ]
2714+ )
2715+ if onp .__version__ >= onp .lib .NumpyVersion ("2.0.0" )
2716+ else (number_dtypes + [None ])
2717+ )
2718+ for rng_factory in [jtu .rand_default ]
2719+ )
2720+ )
26922721 def testLinspace (self , start_shape , stop_shape , num , endpoint ,
26932722 retstep , dtype , rng_factory ):
26942723 if not endpoint and onp .issubdtype (dtype , onp .integer ):
@@ -2770,20 +2799,40 @@ def testLogspace(self, start_shape, stop_shape, num,
27702799
27712800 @named_parameters (
27722801 jtu .cases_from_list (
2773- {"testcase_name" : ("_start_shape={}_stop_shape={}_num={}_endpoint={}"
2774- "_dtype={}" ).format (
2775- start_shape , stop_shape , num , endpoint , dtype ),
2776- "start_shape" : start_shape ,
2777- "stop_shape" : stop_shape ,
2778- "num" : num , "endpoint" : endpoint ,
2779- "dtype" : dtype , "rng_factory" : rng_factory }
2780- for start_shape in [(), (2 ,), (2 , 2 )]
2781- for stop_shape in [(), (2 ,), (2 , 2 )]
2782- for num in [0 , 1 , 2 , 5 , 20 ]
2783- for endpoint in [True , False ]
2784- # NB: numpy's geomspace gives nonsense results on integer types
2785- for dtype in inexact_dtypes + [None ,]
2786- for rng_factory in [jtu .rand_default ]))
2802+ {
2803+ "testcase_name" : (
2804+ "_start_shape={}_stop_shape={}_num={}_endpoint={}_dtype={}"
2805+ ).format (start_shape , stop_shape , num , endpoint , dtype ),
2806+ "start_shape" : start_shape ,
2807+ "stop_shape" : stop_shape ,
2808+ "num" : num ,
2809+ "endpoint" : endpoint ,
2810+ "dtype" : dtype ,
2811+ "rng_factory" : rng_factory ,
2812+ }
2813+ for start_shape in [(), (2 ,), (2 , 2 )]
2814+ for stop_shape in [(), (2 ,), (2 , 2 )]
2815+ for num in [0 , 1 , 2 , 5 , 20 ]
2816+ for endpoint in [True , False ]
2817+ # NB: numpy's geomspace gives nonsense results on integer types
2818+ for dtype in (
2819+ (
2820+ float_dtypes
2821+ + [
2822+ None ,
2823+ ]
2824+ )
2825+ if onp .__version__ >= onp .lib .NumpyVersion ("2.0.0" )
2826+ else (
2827+ inexact_dtypes
2828+ + [
2829+ None ,
2830+ ]
2831+ )
2832+ )
2833+ for rng_factory in [jtu .rand_default ]
2834+ )
2835+ )
27872836 def testGeomspace (self , start_shape , stop_shape , num ,
27882837 endpoint , dtype , rng_factory ):
27892838 rng = rng_factory ()
0 commit comments