@@ -44,9 +44,20 @@ def get_configs(path: Path) -> List[str]:
44
44
default = 'configs/config_example.json' ,
45
45
help = 'The path to a configuration file or '
46
46
'a directory that contains configuration files' )
47
+ parser .add_argument ('--device' , '--devices' , default = 'host cpu gpu none' , type = str , nargs = '+' ,
48
+ choices = ('host' , 'cpu' , 'gpu' , 'none' ),
49
+ help = 'Availible execution context devices. '
50
+ 'This parameter only marks devices as available, '
51
+ 'make sure to add the device to the config file '
52
+ 'to run it on a specific device' )
47
53
parser .add_argument ('--dummy-run' , default = False , action = 'store_true' ,
48
54
help = 'Run configuration parser and datasets generation '
49
55
'without benchmarks running' )
56
+ parser .add_argument ('--dtype' , '--dtypes' , type = str , default = "float32 float64" , nargs = '+' ,
57
+ choices = ("float32" , "float64" ),
58
+ help = 'Available floating point data types'
59
+ 'This parameter only marks dtype as available, '
60
+ 'make sure to add the dtype parameter to the config file ' )
50
61
parser .add_argument ('--no-intel-optimized' , default = False , action = 'store_true' ,
51
62
help = 'Use Scikit-learn without Intel optimizations' )
52
63
parser .add_argument ('--output-file' , default = 'results.json' ,
@@ -93,6 +104,28 @@ def get_configs(path: Path) -> List[str]:
93
104
for params_set in config ['cases' ]:
94
105
params = common_params .copy ()
95
106
params .update (params_set .copy ())
107
+
108
+ device = []
109
+ if 'device' not in params :
110
+ if 'sklearn' in params ['lib' ]:
111
+ logging .info ('The device parameter value is not defined in config, '
112
+ 'none is used' )
113
+ device = ['none' ]
114
+ elif not isinstance (params ['device' ], list ):
115
+ device = [params ['device' ]]
116
+ else :
117
+ device = params ['device' ]
118
+ params ["device" ] = [dv for dv in device if dv in args .device ]
119
+
120
+ dtype = []
121
+ if 'dtype' not in params :
122
+ dtype = ['float64' ]
123
+ elif not isinstance (params ['dtype' ], list ):
124
+ dtype = [params ['dtype' ]]
125
+ else :
126
+ dtype = params ['dtype' ]
127
+ params ['dtype' ] = [dt for dt in dtype if dt in args .dtype ]
128
+
96
129
algorithm = params ['algorithm' ]
97
130
libs = params ['lib' ]
98
131
if not isinstance (libs , list ):
0 commit comments