1
1
import builtins
2
2
import keyword
3
3
import os
4
+ from typing import Any
4
5
5
6
import yaml
6
7
from black import FileMode , format_str
9
10
predefined = ["description" , "default" , "data_type" , "required" , "alias" , "units" ]
10
11
11
12
12
- def _find_alias (all_data : dict , head : list | None = None ):
13
+ def _find_alias (all_data : dict , head : list | None = None ) -> dict [ str , str ] :
13
14
"""
14
15
Find all aliases in the data structure.
15
16
@@ -31,13 +32,13 @@ def _find_alias(all_data: dict, head: list | None = None):
31
32
return results
32
33
33
34
34
- def _replace_alias (all_data : dict ):
35
+ def _replace_alias (all_data : dict ) -> dict :
35
36
for key , value in _find_alias (all_data ).items ():
36
37
_set (all_data , key , _get (all_data , value ))
37
38
return all_data
38
39
39
40
40
- def _get (obj : dict , path : str , sep : str = "/" ):
41
+ def _get (obj : dict , path : str , sep : str = "/" ) -> Any :
41
42
"""
42
43
Get a value from a nested dictionary.
43
44
@@ -69,13 +70,13 @@ def _set(obj: dict, path: str, value):
69
70
obj [last ] = value
70
71
71
72
72
- def _get_safe_parameter_name (name : str ):
73
+ def _get_safe_parameter_name (name : str ) -> str :
73
74
if keyword .iskeyword (name ) or name in dir (builtins ):
74
75
name = name + "_"
75
76
return name
76
77
77
78
78
- def _get_docstring_line (data : dict , key : str ):
79
+ def _get_docstring_line (data : dict , key : str ) -> str :
79
80
"""
80
81
Get a single line for the docstring.
81
82
@@ -108,7 +109,12 @@ def _get_docstring_line(data: dict, key: str):
108
109
return line
109
110
110
111
111
- def _get_docstring (all_data , description = None , indent = indent , predefined = predefined ):
112
+ def _get_docstring (
113
+ all_data : dict ,
114
+ description : str | None = None ,
115
+ indent : str = indent ,
116
+ predefined : list [str ] = predefined ,
117
+ ) -> list [str ]:
112
118
txt = [indent + '"""' ]
113
119
if description is not None :
114
120
txt .append (f"{ indent } { description } \n " )
@@ -123,7 +129,7 @@ def _get_docstring(all_data, description=None, indent=indent, predefined=predefi
123
129
return txt
124
130
125
131
126
- def _get_input_arg (key , entry , indent = indent ):
132
+ def _get_input_arg (key : str , entry : dict , indent : str = indent ) -> str :
127
133
t = entry .get ("data_type" , "dict" )
128
134
units = "" .join (entry .get ("units" , "" ).split ())
129
135
if not entry .get ("required" , False ) and units != "" :
@@ -136,7 +142,7 @@ def _get_input_arg(key, entry, indent=indent):
136
142
return t
137
143
138
144
139
- def _rename_keys (data ) :
145
+ def _rename_keys (data : dict ) -> dict :
140
146
d_1 = {_get_safe_parameter_name (key ): value for key , value in data .items ()}
141
147
d_2 = {
142
148
key : d
@@ -148,11 +154,11 @@ def _rename_keys(data):
148
154
149
155
150
156
def _get_function (
151
- data ,
157
+ data : dict ,
152
158
function_name : list [str ],
153
- predefined = predefined ,
154
- is_kwarg = False ,
155
- ):
159
+ predefined : list [ str ] = predefined ,
160
+ is_kwarg : bool = False ,
161
+ ) -> str :
156
162
d = _rename_keys (data )
157
163
func = []
158
164
if is_kwarg :
@@ -187,7 +193,9 @@ def _get_function(
187
193
return "\n " .join (result )
188
194
189
195
190
- def _get_all_function_names (all_data , head = "" , predefined = predefined ):
196
+ def _get_all_function_names (
197
+ all_data : dict , head : str = "" , predefined : list [str ] = predefined
198
+ ) -> list [str ]:
191
199
key_lst = []
192
200
for tag , data in all_data .items ():
193
201
if tag not in predefined and data .get ("data_type" , "dict" ) == "dict" :
@@ -196,7 +204,7 @@ def _get_all_function_names(all_data, head="", predefined=predefined):
196
204
return key_lst
197
205
198
206
199
- def _get_class (all_data ) :
207
+ def _get_class (all_data : dict ) -> str :
200
208
fnames = _get_all_function_names (all_data )
201
209
txt = ""
202
210
for name in fnames :
@@ -216,7 +224,7 @@ def _get_class(all_data):
216
224
return txt
217
225
218
226
219
- def _get_file_content (yml_file_name = "input_data.yml" ):
227
+ def _get_file_content (yml_file_name : str = "input_data.yml" ) -> str :
220
228
file_location = os .path .join (os .path .dirname (__file__ ), yml_file_name )
221
229
with open (file_location , "r" ) as f :
222
230
file_content = f .read ()
@@ -246,7 +254,7 @@ def _get_file_content(yml_file_name="input_data.yml"):
246
254
return file_content
247
255
248
256
249
- def export_class (yml_file_name = "input_data.yml" , py_file_name = "input.py" ):
257
+ def export_class (yml_file_name : str = "input_data.yml" , py_file_name : str = "input.py" ):
250
258
file_content = _get_file_content (yml_file_name )
251
259
with open (os .path .join (os .path .dirname (__file__ ), ".." , py_file_name ), "w" ) as f :
252
260
f .write (file_content )
0 commit comments