|
1 | 1 | import functools |
| 2 | +import inspect |
2 | 3 |
|
3 | 4 | from typing import List, Union |
4 | 5 |
|
5 | 6 | from domaintools.docstring_patcher import DocstringPatcher |
| 7 | +from domaintools.request_validator import RequestValidator |
6 | 8 |
|
7 | 9 |
|
8 | 10 | def api_endpoint(spec_name: str, path: str, methods: Union[str, List[str]]): |
9 | 11 | """ |
10 | | - Decorator to tag a method as an API endpoint. |
| 12 | + Decorator to tag a method as an API endpoint AND validate inputs. |
11 | 13 |
|
12 | 14 | Args: |
13 | 15 | spec_name: The key for the spec in api_instance.specs |
14 | 16 | path: The API path (e.g., "/users") |
15 | 17 | methods: A single method ("get") or list of methods (["get", "post"]) |
16 | | - that this function handles. |
17 | 18 | """ |
18 | 19 |
|
19 | 20 | def decorator(func): |
20 | 21 | func._api_spec_name = spec_name |
21 | 22 | func._api_path = path |
22 | 23 |
|
23 | | - # Always store the methods as a list |
24 | | - if isinstance(methods, str): |
25 | | - func._api_methods = [methods] |
26 | | - else: |
27 | | - func._api_methods = methods |
| 24 | + # Normalize methods to a list |
| 25 | + normalized_methods = [methods] if isinstance(methods, str) else methods |
| 26 | + func._api_methods = normalized_methods |
| 27 | + |
| 28 | + # Get the signature of the original function ONCE |
| 29 | + sig = inspect.signature(func) |
28 | 30 |
|
29 | 31 | @functools.wraps(func) |
30 | 32 | def wrapper(self, *args, **kwargs): |
| 33 | + |
| 34 | + try: |
| 35 | + bound_args = sig.bind(*args, **kwargs) |
| 36 | + except TypeError: |
| 37 | + # If arguments don't match signature, let the actual func raise the error |
| 38 | + return func(*args, **kwargs) |
| 39 | + |
| 40 | + arguments = bound_args.arguments |
| 41 | + |
| 42 | + # Robustly find 'self' (it's usually the first argument in bound_args) |
| 43 | + # We look for the first value in arguments, or try to get 'self' explicitly. |
| 44 | + instance = arguments.pop("self", None) |
| 45 | + if not instance and args: |
| 46 | + instance = args[0] |
| 47 | + |
| 48 | + # Retrieve the Spec from the instance |
| 49 | + # We assume 'self' has a .specs attribute (like DocstringPatcher expects) |
| 50 | + spec = getattr(self, "specs", {}).get(spec_name) |
| 51 | + |
| 52 | + if spec: |
| 53 | + # Determine which HTTP method is currently being executed. |
| 54 | + # If the function allows dynamic methods (e.g. method="POST"), use that. |
| 55 | + # Otherwise, default to the first method defined in the decorator. |
| 56 | + current_method = kwargs.get("method", normalized_methods[0]) |
| 57 | + |
| 58 | + # Run Validation |
| 59 | + # This will raise a ValueError and stop execution if validation fails. |
| 60 | + try: |
| 61 | + RequestValidator.validate( |
| 62 | + spec=spec, |
| 63 | + path=path, |
| 64 | + method=current_method, |
| 65 | + parameters=arguments, |
| 66 | + ) |
| 67 | + except ValueError as e: |
| 68 | + print(f"[Validation Error] {e}") |
| 69 | + raise e |
| 70 | + |
| 71 | + # Proceed with the original function call |
31 | 72 | return func(*args, **kwargs) |
32 | 73 |
|
33 | | - # Copy all tags to the wrapper |
| 74 | + # Copy tags to wrapper for the DocstringPatcher to find |
34 | 75 | wrapper._api_spec_name = func._api_spec_name |
35 | 76 | wrapper._api_path = func._api_path |
36 | 77 | wrapper._api_methods = func._api_methods |
| 78 | + |
37 | 79 | return wrapper |
38 | 80 |
|
39 | 81 | return decorator |
|
0 commit comments