99from modelspec .base_types import print_
1010from modelspec .base_types import EvaluableExpression
1111
12+ from random import Random
13+ from typing import Union
14+
1215verbose = False
1316
1417
15- def load_json (filename ):
18+ def load_json (filename : str ):
1619 """
1720 Load a generic JSON file
21+
22+ Args:
23+ filename: The name of the JSON file to load
1824 """
1925
2026 with open (filename ) as f :
@@ -23,19 +29,25 @@ def load_json(filename):
2329 return data
2430
2531
26- def load_yaml (filename ):
32+ def load_yaml (filename : str ):
2733 """
2834 Load a generic YAML file
35+
36+ Args:
37+ filename: The name of the YAML file to load
2938 """
3039 with open (filename ) as f :
3140 data = yaml .load (f , Loader = yaml .SafeLoader )
3241
3342 return data
3443
3544
36- def load_bson (filename ):
45+ def load_bson (filename : str ):
3746 """
3847 Load a generic BSON file
48+
49+ Args:
50+ filename: The name of the BSON file to load
3951 """
4052 with open (filename , "rb" ) as infile :
4153 data_encoded = infile .read ()
@@ -211,11 +223,26 @@ def _params_info(parameters, multiline=False):
211223FORMAT_TENSORFLOW = "tensorflow"
212224
213225
214- def evaluate (expr , parameters = {}, rng = None , array_format = FORMAT_NUMPY , verbose = False ):
226+ def evaluate (
227+ expr : Union [int , float , str , list , dict ],
228+ parameters : dict = {},
229+ rng : Random = None ,
230+ array_format : str = FORMAT_NUMPY ,
231+ verbose : bool = False ,
232+ cast_to_int : bool = False ,
233+ ):
215234 """
216235 Evaluate a general string like expression (e.g. "2 * weight") using a dict
217236 of parameters (e.g. {'weight':10}). Returns floats, ints, etc. if that's what's
218237 given in expr
238+
239+ Args:
240+ expr: The expression to convert
241+ parameters: A dict of the parameters which can be substituted in to the expression
242+ rng: The random number generator to use
243+ array_format: numpy or tensorflow
244+ verbose: Print the calculations
245+ cast_to_int: return an int for float/string values if castable
219246 """
220247
221248 if array_format == FORMAT_TENSORFLOW :
@@ -233,7 +260,7 @@ def evaluate(expr, parameters={}, rng=None, array_format=FORMAT_NUMPY, verbose=F
233260 expr
234261 ] # replace with the value in parameters & check whether it's float/int...
235262 if verbose :
236- print_ ("Using for that param: %s" % _val_info (expr ), verbose )
263+ print_ (" Using for that param: %s" % _val_info (expr ), verbose )
237264
238265 if type (expr ) == str :
239266 try :
@@ -242,26 +269,28 @@ def evaluate(expr, parameters={}, rng=None, array_format=FORMAT_NUMPY, verbose=F
242269 else :
243270 expr = int (expr )
244271 except :
245- pass
246- try :
247- if array_format == FORMAT_TENSORFLOW :
248- expr = tf .constant (float (expr ))
249- else :
250- expr = float (expr )
251- except :
252- pass
272+
273+ try :
274+ if array_format == FORMAT_TENSORFLOW :
275+ expr = tf .constant (float (expr ))
276+ else :
277+ expr = float (expr )
278+ except :
279+ pass
253280
254281 if type (expr ) == list :
255282 if verbose :
256- print_ ("Returning a list in format: %s" % array_format , verbose )
283+ print_ (" Returning a list in format: %s" % array_format , verbose )
257284 if array_format == FORMAT_TENSORFLOW :
258285 return tf .constant (expr , dtype = tf .float64 )
259286 else :
260287 return np .array (expr )
261288
262289 if type (expr ) == np .ndarray :
263290 if verbose :
264- print_ ("Returning a numpy array in format: %s" % array_format , verbose )
291+ print_ (
292+ " Returning a numpy array in format: %s" % array_format , verbose
293+ )
265294 if array_format == FORMAT_TENSORFLOW :
266295 return tf .convert_to_tensor (expr , dtype = tf .float64 )
267296 else :
@@ -270,22 +299,22 @@ def evaluate(expr, parameters={}, rng=None, array_format=FORMAT_NUMPY, verbose=F
270299 if "Tensor" in type (expr ).__name__ :
271300 if verbose :
272301 print_ (
273- "Returning a tensorflow Tensor in format: %s" % array_format ,
302+ " Returning a tensorflow Tensor in format: %s" % array_format ,
274303 verbose ,
275304 )
276305 if array_format == FORMAT_NUMPY :
277306 return expr .numpy ()
278307 else :
279308 return expr
280309
281- if int (expr ) == expr :
310+ if int (expr ) == expr and cast_to_int :
282311 if verbose :
283- print_ ("Returning int: %s" % int (expr ), verbose )
312+ print_ (" Returning int: %s" % int (expr ), verbose )
284313 return int (expr )
285314 else : # will have failed if not number
286315 if verbose :
287- print_ ("Returning float: %s" % expr , verbose )
288- return float ( expr )
316+ print_ (" Returning {}: {}" . format ( type ( expr ), expr ) , verbose )
317+ return expr
289318 except :
290319 try :
291320 if rng :
@@ -299,7 +328,7 @@ def evaluate(expr, parameters={}, rng=None, array_format=FORMAT_NUMPY, verbose=F
299328
300329 if verbose :
301330 print_ (
302- "Trying to eval [%s] with Python using %s..."
331+ " Trying to eval [%s] with Python using %s..."
303332 % (expr , parameters .keys ()),
304333 verbose ,
305334 )
@@ -308,13 +337,14 @@ def evaluate(expr, parameters={}, rng=None, array_format=FORMAT_NUMPY, verbose=F
308337
309338 if verbose :
310339 print_ (
311- "Evaluated with Python: {} = {}" .format (expr , _val_info (v )), verbose
340+ " Evaluated with Python: {} = {}" .format (expr , _val_info (v )),
341+ verbose ,
312342 )
313343
314344 if (type (v ) == float or type (v ) == str ) and int (v ) == v :
315345
316346 if verbose :
317- print_ ("Returning int: %s" % int (v ), verbose )
347+ print_ (" Returning int: %s" % int (v ), verbose )
318348
319349 if array_format == FORMAT_TENSORFLOW :
320350 return tf .constant (int (v ))
@@ -323,7 +353,7 @@ def evaluate(expr, parameters={}, rng=None, array_format=FORMAT_NUMPY, verbose=F
323353 return v
324354 except Exception as e :
325355 if verbose :
326- print_ (f"Returning without altering: { expr } (error: { e } )" , verbose )
356+ print_ (f" Returning without altering: { expr } (error: { e } )" , verbose )
327357 return expr
328358
329359
0 commit comments