Skip to content

Commit 90cd24f

Browse files
quantumsheepMonitob
authored andcommitted
feat: add generator
1 parent 337d5e0 commit 90cd24f

File tree

4 files changed

+678
-0
lines changed

4 files changed

+678
-0
lines changed

generator.py

Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
import importlib
2+
import inspect
3+
from dataclasses import fields
4+
from datetime import datetime
5+
from enum import Enum
6+
from inspect import Parameter, isclass
7+
import pkgutil
8+
from types import FunctionType, NoneType
9+
from typing import (
10+
Dict,
11+
List,
12+
Optional,
13+
Type,
14+
Union,
15+
get_args,
16+
get_origin,
17+
get_type_hints,
18+
)
19+
20+
from jinja2 import Environment, FileSystemLoader, select_autoescape
21+
22+
from scaleway import ALL_REGIONS, WaitForOptions
23+
24+
import scaleway
25+
26+
27+
class FieldTypeDescriptor:
28+
name: str
29+
choices: List[str]
30+
31+
def __init__(self, field_type: type):
32+
self.name = "???"
33+
34+
if get_origin(field_type) == Union:
35+
generic_args = get_args(field_type)
36+
field_type = generic_args[0]
37+
38+
self.name = ""
39+
if field_type == str:
40+
self.name = "str"
41+
elif field_type == int:
42+
self.name = "int"
43+
elif field_type == bool:
44+
self.name = "bool"
45+
elif get_origin(field_type) is dict:
46+
self.name = "dict"
47+
elif get_origin(field_type) is list:
48+
self.name = "list"
49+
elif field_type == float:
50+
self.name = "float"
51+
elif field_type is datetime:
52+
self.name = "str"
53+
elif isclass(field_type):
54+
if issubclass(field_type, Enum):
55+
self.name = "str"
56+
self.choices = [str(choice) for choice in field_type]
57+
else:
58+
self.name = "dict"
59+
60+
if self.name == "":
61+
raise Exception(f"Unknown type: {field_type}")
62+
63+
64+
class FieldDescriptor:
65+
name: str
66+
type: FieldTypeDescriptor
67+
required: bool
68+
description: str
69+
70+
def __init__(self, name: str, parameter: Parameter):
71+
self.name = name
72+
73+
if isinstance(parameter, Parameter):
74+
self.type = FieldTypeDescriptor(parameter.annotation)
75+
self.required = parameter.default == inspect.Parameter.empty
76+
else:
77+
self.type = FieldTypeDescriptor(parameter)
78+
self.required = True
79+
80+
self.description = ""
81+
82+
if name == "region":
83+
self.type.choices = ALL_REGIONS
84+
85+
86+
class MethodDescriptor:
87+
name: str
88+
request_fields: List[FieldDescriptor]
89+
response_fields: List[FieldDescriptor]
90+
91+
def __init__(self, method: FunctionType):
92+
self.name = method.__name__
93+
self.request_fields: List[FieldDescriptor] = []
94+
self.response_fields: List[FieldDescriptor] = []
95+
96+
signature = inspect.signature(method)
97+
for name, parameter in signature.parameters.items():
98+
if name == "self":
99+
continue
100+
101+
if get_origin(parameter.annotation) is Union:
102+
first_arg = get_args(parameter.annotation)[0]
103+
if get_origin(first_arg) is WaitForOptions:
104+
continue
105+
106+
self.request_fields.append(FieldDescriptor(name, parameter))
107+
108+
return_type = method.__annotations__["return"]
109+
if return_type != NoneType:
110+
origin = get_origin(return_type)
111+
if origin == list or origin == Union:
112+
return_type = get_args(return_type)[0]
113+
114+
hints = get_type_hints(return_type)
115+
116+
for field in fields(return_type):
117+
self.response_fields.append(
118+
FieldDescriptor(field.name, hints[field.name])
119+
)
120+
121+
@property
122+
def required_request_fields(self):
123+
return [field for field in self.request_fields if field.required]
124+
125+
def has_request_field(self, name: str) -> bool:
126+
for field in self.request_fields:
127+
if field.name == name:
128+
return True
129+
130+
return False
131+
132+
def has_response_field(self, name: str) -> bool:
133+
for field in self.response_fields:
134+
if field.name == name:
135+
return True
136+
137+
return False
138+
139+
140+
class APIDescriptor:
141+
namespace: str
142+
group: str
143+
144+
api_class: Type[object]
145+
146+
name: str
147+
method_create: Optional[MethodDescriptor]
148+
method_get: Optional[MethodDescriptor]
149+
method_update: Optional[MethodDescriptor]
150+
method_delete: Optional[MethodDescriptor]
151+
method_list: Optional[MethodDescriptor]
152+
method_wait_for: Optional[MethodDescriptor]
153+
154+
request_id_field: Optional[FieldDescriptor]
155+
response_id_field: Optional[FieldDescriptor]
156+
157+
def __init__(
158+
self,
159+
api_class: Type[object],
160+
namespace: str,
161+
group: str,
162+
methods: List[MethodDescriptor],
163+
):
164+
self.api_class = api_class
165+
self.namespace = namespace
166+
self.group = group
167+
self.name = f"{namespace}_{group}" if group != "" else namespace
168+
169+
self.method_create = None
170+
self.method_get = None
171+
self.method_update = None
172+
self.method_delete = None
173+
self.method_list = None
174+
self.method_wait_for = None
175+
176+
for method in methods:
177+
if method.name.startswith("create_"):
178+
self.method_create = method
179+
elif method.name.startswith("get_"):
180+
self.method_get = method
181+
elif method.name.startswith("update_"):
182+
self.method_update = method
183+
elif method.name.startswith("delete_"):
184+
self.method_delete = method
185+
elif method.name.startswith("list_"):
186+
self.method_list = method
187+
elif method.name.startswith("wait_for_"):
188+
self.method_wait_for = method
189+
190+
self.request_id_field = None
191+
self.response_id_field = None
192+
193+
method_with_id_field = (
194+
self.method_get or self.method_update or self.method_update
195+
)
196+
if method_with_id_field is None:
197+
raise Exception(f"Unable to find method with ID field for {self.name}")
198+
199+
for field in method_with_id_field.request_fields:
200+
if f"{group}_id" in field.name:
201+
self.request_id_field = field
202+
break
203+
204+
for field in method_with_id_field.response_fields:
205+
if f"id" in field.name:
206+
self.response_id_field = field
207+
break
208+
209+
if self.request_id_field is None:
210+
if len(method_with_id_field.request_fields) == 0:
211+
raise Exception(
212+
f"Unable to find request ID field for {self.name} (no request fields)"
213+
)
214+
215+
self.request_id_field = method_with_id_field.request_fields[0]
216+
217+
if self.response_id_field is None:
218+
if len(method_with_id_field.response_fields) == 0:
219+
raise Exception(
220+
f"Unable to find response ID field for {self.name} (no response fields)"
221+
)
222+
223+
self.response_id_field = method_with_id_field.response_fields[0]
224+
225+
@property
226+
def class_import_path(self) -> str:
227+
return ".".join(self.api_class.__module__.split(".")[:-1])
228+
229+
@property
230+
def class_name(self) -> str:
231+
return self.api_class.__name__
232+
233+
234+
def get_api_descriptors(namespace: str, api_class: Type[object]) -> List[APIDescriptor]:
235+
apis: List[APIDescriptor] = []
236+
237+
prefixes = ["create_", "get_", "update_", "delete_", "list_", "wait_for_"]
238+
239+
groups: Dict[str, List[MethodDescriptor]] = {}
240+
241+
for name, method in inspect.getmembers(api_class, predicate=inspect.isfunction):
242+
for prefix in prefixes:
243+
if not name.startswith(prefix):
244+
continue
245+
246+
parts = name.split("_", 1)
247+
group = parts[1]
248+
249+
if prefix == "list_":
250+
if not group.endswith("_all"):
251+
continue
252+
253+
group = group.removesuffix("s_all")
254+
255+
try:
256+
method_descriptor = MethodDescriptor(method)
257+
258+
if group not in groups:
259+
groups[group] = []
260+
groups[group].append(method_descriptor)
261+
except Exception as e:
262+
print(f"Error processing method {name}: {e}")
263+
264+
for group, methods in groups.items():
265+
try:
266+
api = APIDescriptor(api_class, namespace, group, methods)
267+
if (
268+
api.method_create is None
269+
or api.method_get is None
270+
or api.method_delete is None
271+
):
272+
continue
273+
apis.append(api)
274+
except Exception as e:
275+
print(f"Error processing API {namespace}.{group}: {e}")
276+
277+
return apis
278+
279+
280+
def main() -> None:
281+
modules = pkgutil.iter_modules(scaleway.__path__)
282+
apis: Dict[str, Type[object]] = {}
283+
284+
for _, product, _ in modules:
285+
module = importlib.import_module(f"scaleway.{product}")
286+
versions = pkgutil.iter_modules(module.__path__)
287+
288+
for _, version, _ in versions:
289+
module = importlib.import_module(f"scaleway.{product}.{version}")
290+
291+
for name, api in inspect.getmembers(module, isclass):
292+
if name.endswith("API"):
293+
apis[f"{product}_{version}"] = api
294+
295+
env = Environment(
296+
loader=FileSystemLoader("templates"),
297+
autoescape=select_autoescape(),
298+
extensions=["jinja2.ext.loopcontrols"],
299+
)
300+
301+
module_names: List[str] = []
302+
303+
for name, api in apis.items():
304+
descriptors = get_api_descriptors(name, api)
305+
306+
for descriptor in descriptors:
307+
module_code = env.get_template("module.py.jinja").render(api=descriptor)
308+
with open(f"plugins/modules/scaleway_{descriptor.name}.py", "w") as f:
309+
f.write(module_code)
310+
311+
module_names.append(descriptor.name)
312+
313+
with open(f"meta/runtime.yml", "w") as f:
314+
content = env.get_template("runtime.yml.jinja").render(module_names=module_names)
315+
f.write(content)
316+
317+
if __name__ == "__main__":
318+
main()

0 commit comments

Comments
 (0)