Skip to content

Commit c9cda4e

Browse files
authored
Merge pull request #48 from pyiron/refactoring
add type hints
2 parents c55b9ff + 4fa3515 commit c9cda4e

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

sphinx_parser/src/generator.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import builtins
22
import keyword
33
import os
4+
from typing import Any
45

56
import yaml
67
from black import FileMode, format_str
@@ -9,7 +10,7 @@
910
predefined = ["description", "default", "data_type", "required", "alias", "units"]
1011

1112

12-
def _find_alias(all_data: dict, head: list | None = None):
13+
def _find_alias(all_data: dict, head: list | None = None) -> dict[str, str]:
1314
"""
1415
Find all aliases in the data structure.
1516
@@ -31,13 +32,13 @@ def _find_alias(all_data: dict, head: list | None = None):
3132
return results
3233

3334

34-
def _replace_alias(all_data: dict):
35+
def _replace_alias(all_data: dict) -> dict:
3536
for key, value in _find_alias(all_data).items():
3637
_set(all_data, key, _get(all_data, value))
3738
return all_data
3839

3940

40-
def _get(obj: dict, path: str, sep: str = "/"):
41+
def _get(obj: dict, path: str, sep: str = "/") -> Any:
4142
"""
4243
Get a value from a nested dictionary.
4344
@@ -69,13 +70,13 @@ def _set(obj: dict, path: str, value):
6970
obj[last] = value
7071

7172

72-
def _get_safe_parameter_name(name: str):
73+
def _get_safe_parameter_name(name: str) -> str:
7374
if keyword.iskeyword(name) or name in dir(builtins):
7475
name = name + "_"
7576
return name
7677

7778

78-
def _get_docstring_line(data: dict, key: str):
79+
def _get_docstring_line(data: dict, key: str) -> str:
7980
"""
8081
Get a single line for the docstring.
8182
@@ -108,7 +109,12 @@ def _get_docstring_line(data: dict, key: str):
108109
return line
109110

110111

111-
def _get_docstring(all_data, description=None, indent=indent, predefined=predefined):
112+
def _get_docstring(
113+
all_data: dict,
114+
description: str | None = None,
115+
indent: str = indent,
116+
predefined: list[str] = predefined,
117+
) -> list[str]:
112118
txt = [indent + '"""']
113119
if description is not None:
114120
txt.append(f"{indent}{description}\n")
@@ -123,7 +129,7 @@ def _get_docstring(all_data, description=None, indent=indent, predefined=predefi
123129
return txt
124130

125131

126-
def _get_input_arg(key, entry, indent=indent):
132+
def _get_input_arg(key: str, entry: dict, indent: str = indent) -> str:
127133
t = entry.get("data_type", "dict")
128134
units = "".join(entry.get("units", "").split())
129135
if not entry.get("required", False) and units != "":
@@ -136,7 +142,7 @@ def _get_input_arg(key, entry, indent=indent):
136142
return t
137143

138144

139-
def _rename_keys(data):
145+
def _rename_keys(data: dict) -> dict:
140146
d_1 = {_get_safe_parameter_name(key): value for key, value in data.items()}
141147
d_2 = {
142148
key: d
@@ -148,11 +154,11 @@ def _rename_keys(data):
148154

149155

150156
def _get_function(
151-
data,
157+
data: dict,
152158
function_name: list[str],
153-
predefined=predefined,
154-
is_kwarg=False,
155-
):
159+
predefined: list[str] = predefined,
160+
is_kwarg: bool = False,
161+
) -> str:
156162
d = _rename_keys(data)
157163
func = []
158164
if is_kwarg:
@@ -187,7 +193,9 @@ def _get_function(
187193
return "\n".join(result)
188194

189195

190-
def _get_all_function_names(all_data, head="", predefined=predefined):
196+
def _get_all_function_names(
197+
all_data: dict, head: str = "", predefined: list[str] = predefined
198+
) -> list[str]:
191199
key_lst = []
192200
for tag, data in all_data.items():
193201
if tag not in predefined and data.get("data_type", "dict") == "dict":
@@ -196,7 +204,7 @@ def _get_all_function_names(all_data, head="", predefined=predefined):
196204
return key_lst
197205

198206

199-
def _get_class(all_data):
207+
def _get_class(all_data: dict) -> str:
200208
fnames = _get_all_function_names(all_data)
201209
txt = ""
202210
for name in fnames:
@@ -216,7 +224,7 @@ def _get_class(all_data):
216224
return txt
217225

218226

219-
def _get_file_content(yml_file_name="input_data.yml"):
227+
def _get_file_content(yml_file_name: str = "input_data.yml") -> str:
220228
file_location = os.path.join(os.path.dirname(__file__), yml_file_name)
221229
with open(file_location, "r") as f:
222230
file_content = f.read()
@@ -246,7 +254,7 @@ def _get_file_content(yml_file_name="input_data.yml"):
246254
return file_content
247255

248256

249-
def export_class(yml_file_name="input_data.yml", py_file_name="input.py"):
257+
def export_class(yml_file_name: str = "input_data.yml", py_file_name: str = "input.py"):
250258
file_content = _get_file_content(yml_file_name)
251259
with open(os.path.join(os.path.dirname(__file__), "..", py_file_name), "w") as f:
252260
f.write(file_content)

0 commit comments

Comments
 (0)