|
4 | 4 |
|
5 | 5 | import inspect |
6 | 6 | import json |
7 | | -from typing import Any, Callable, Dict, List, Union, get_args, get_origin |
| 7 | +from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin |
8 | 8 |
|
9 | 9 | from pydantic import BaseModel |
10 | 10 | import requests |
|
14 | 14 |
|
15 | 15 | logger = get_logger(__name__) |
16 | 16 |
|
17 | | - |
18 | 17 | # Global registry to collect decorated activities |
19 | 18 | ACTIVITY_SPECS: List[dict[str, Any]] = [] |
20 | 19 |
|
21 | 20 |
|
22 | 21 | def _type_to_json_schema(annotation: Any) -> Dict[str, Any]: |
23 | 22 | """Convert Python type annotation to JSON Schema using reflection.""" |
| 23 | + # Pydantic models |
24 | 24 | if inspect.isclass(annotation) and issubclass(annotation, BaseModel): |
25 | 25 | schema = annotation.model_json_schema(mode="serialization") |
26 | 26 | if "$defs" in schema: |
27 | | - defs = schema.pop("$defs") |
28 | | - schema = {**schema, **defs} |
| 27 | + schema = {**schema, **schema.pop("$defs")} |
29 | 28 | return schema |
30 | 29 |
|
31 | 30 | origin = get_origin(annotation) |
32 | | - if origin is Union: |
33 | | - args = get_args(annotation) |
34 | | - if type(None) in args: |
35 | | - non_none_types = [arg for arg in args if arg is not type(None)] |
36 | | - if non_none_types: |
37 | | - schema = _type_to_json_schema(non_none_types[0]) |
38 | | - schema["nullable"] = True |
39 | | - return schema |
| 31 | + args = get_args(annotation) |
| 32 | + |
| 33 | + # Optional types (Union with None) |
| 34 | + if origin is Union and type(None) in args: |
| 35 | + non_none = [a for a in args if a is not type(None)] |
| 36 | + if non_none: |
| 37 | + schema = _type_to_json_schema(non_none[0]) |
| 38 | + schema["nullable"] = True |
| 39 | + return schema |
40 | 40 |
|
| 41 | + # Dict types |
41 | 42 | if origin is dict or annotation is dict: |
42 | | - args = get_args(annotation) |
43 | | - if args: |
44 | | - return { |
45 | | - "type": "object", |
46 | | - "additionalProperties": _type_to_json_schema(args[1]), |
47 | | - } |
48 | | - return {"type": "object"} |
| 43 | + return { |
| 44 | + "type": "object", |
| 45 | + "additionalProperties": _type_to_json_schema(args[1]) if args else {}, |
| 46 | + } |
49 | 47 |
|
| 48 | + # List types |
50 | 49 | if origin is list or annotation is list: |
51 | | - args = get_args(annotation) |
| 50 | + return {"type": "array", "items": _type_to_json_schema(args[0]) if args else {}} |
| 51 | + |
| 52 | + # Tuple types |
| 53 | + if origin is tuple or annotation is tuple: |
52 | 54 | if args: |
53 | | - return {"type": "array", "items": _type_to_json_schema(args[0])} |
| 55 | + item_schemas = [_type_to_json_schema(arg) for arg in args] |
| 56 | + return { |
| 57 | + "type": "array", |
| 58 | + "items": item_schemas, |
| 59 | + "minItems": len(item_schemas), |
| 60 | + "maxItems": len(item_schemas), |
| 61 | + } |
54 | 62 | return {"type": "array"} |
55 | 63 |
|
56 | | - type_mapping = { |
57 | | - str: {"type": "string"}, |
58 | | - int: {"type": "integer"}, |
59 | | - float: {"type": "number"}, |
60 | | - bool: {"type": "boolean"}, |
61 | | - Any: {}, |
62 | | - } |
63 | | - |
64 | | - if annotation in type_mapping: |
65 | | - return type_mapping[annotation] |
| 64 | + # Primitive types |
| 65 | + type_map = {str: "string", int: "integer", float: "number", bool: "boolean"} |
| 66 | + if annotation in type_map: |
| 67 | + return {"type": type_map[annotation]} |
66 | 68 |
|
67 | 69 | return {} |
68 | 70 |
|
69 | 71 |
|
| 72 | +def _create_object_schema( |
| 73 | + properties: Dict[str, Any], required: Optional[List[str]] = None |
| 74 | +) -> Dict[str, Any]: |
| 75 | + """Create a JSON Schema object with properties and required fields.""" |
| 76 | + schema = {"type": "object", "properties": properties} |
| 77 | + if required: |
| 78 | + schema["required"] = required |
| 79 | + return schema |
| 80 | + |
| 81 | + |
70 | 82 | def _generate_input_schema(func: Any) -> Dict[str, Any]: |
71 | 83 | """Generate JSON Schema for function inputs using reflection.""" |
72 | | - sig = inspect.signature(func) |
73 | | - properties = {} |
74 | | - required: list[str] = [] |
75 | | - |
76 | | - for param_name, param in sig.parameters.items(): |
77 | | - if param_name == "self": |
| 84 | + properties: Dict[str, Any] = {} |
| 85 | + required: List[str] = [] |
| 86 | + for name, param in inspect.signature(func).parameters.items(): |
| 87 | + if name == "self": |
78 | 88 | continue |
79 | | - |
80 | | - param_schema = ( |
| 89 | + schema = ( |
81 | 90 | _type_to_json_schema(param.annotation) |
82 | 91 | if param.annotation != inspect.Parameter.empty |
83 | 92 | else {} |
84 | 93 | ) |
85 | | - |
86 | 94 | if param.default != inspect.Parameter.empty: |
87 | 95 | if isinstance(param.default, (str, int, float, bool, type(None))): |
88 | | - param_schema["default"] = param.default |
| 96 | + schema["default"] = param.default |
89 | 97 | else: |
90 | | - required.append(param_name) |
91 | | - |
92 | | - properties[param_name] = param_schema |
| 98 | + required.append(name) |
| 99 | + properties[name] = schema |
| 100 | + return _create_object_schema(properties, required if required else None) |
93 | 101 |
|
94 | | - schema = {"type": "object", "properties": properties} |
95 | | - if required: |
96 | | - schema["required"] = required |
97 | 102 |
|
98 | | - return schema |
| 103 | +def _validate_output_field_names( |
| 104 | + func: Any, output_field_names: List[str], return_annotation: Any |
| 105 | +) -> None: |
| 106 | + """Validate output_field_names against function return type.""" |
| 107 | + origin = get_origin(return_annotation) |
| 108 | + is_tuple: bool = origin is tuple or return_annotation is tuple |
| 109 | + if not is_tuple: |
| 110 | + raise ValueError( |
| 111 | + f"output_field_names provided for '{func.__name__}', " |
| 112 | + f"but return type is not a tuple (got {return_annotation})." |
| 113 | + ) |
| 114 | + args = get_args(return_annotation) |
| 115 | + if not args: |
| 116 | + raise ValueError( |
| 117 | + f"output_field_names provided for '{func.__name__}', but tuple has no type arguments." |
| 118 | + ) |
| 119 | + if len(output_field_names) != len(args): |
| 120 | + raise ValueError( |
| 121 | + f"output_field_names length ({len(output_field_names)}) doesn't match " |
| 122 | + f"tuple length ({len(args)}) for '{func.__name__}'. Expected: {output_field_names}" |
| 123 | + ) |
99 | 124 |
|
100 | 125 |
|
101 | | -def _generate_output_schema(func: Any) -> Dict[str, Any]: |
| 126 | +def _generate_output_schema( |
| 127 | + func: Any, output_field_names: Optional[List[str]] = None |
| 128 | +) -> Dict[str, Any]: |
102 | 129 | """Generate JSON Schema for function outputs using reflection.""" |
103 | 130 | return_annotation = inspect.signature(func).return_annotation |
104 | | - if return_annotation == inspect.Signature.empty: |
| 131 | + has_return = return_annotation != inspect.Signature.empty |
| 132 | + |
| 133 | + if output_field_names: |
| 134 | + if not has_return: |
| 135 | + raise ValueError( |
| 136 | + f"output_field_names provided for '{func.__name__}', but function has no return annotation." |
| 137 | + ) |
| 138 | + _validate_output_field_names(func, output_field_names, return_annotation) |
| 139 | + |
| 140 | + if not has_return: |
105 | 141 | return {} |
106 | | - return _type_to_json_schema(return_annotation) |
| 142 | + |
| 143 | + schema: dict[str, Any] = _type_to_json_schema(return_annotation) |
| 144 | + |
| 145 | + # Convert tuple to object if field names provided |
| 146 | + if output_field_names and schema.get("type") == "array" and "items" in schema: |
| 147 | + items = schema.get("items", []) |
| 148 | + if isinstance(items, list) and len(items) == len(output_field_names): |
| 149 | + properties = {name: item for name, item in zip(output_field_names, items)} |
| 150 | + return _create_object_schema(properties, output_field_names) |
| 151 | + |
| 152 | + return schema |
107 | 153 |
|
108 | 154 |
|
109 | 155 | def automation_activity( |
110 | 156 | name: str, |
111 | 157 | description: str, |
| 158 | + output_field_names: Optional[List[str]] = None, |
112 | 159 | ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: |
113 | | - """Decorator to mark an activity for automatic registration.""" |
| 160 | + """Decorator to mark an activity for automatic registration. |
| 161 | +
|
| 162 | + Args: |
| 163 | + name: Name of the activity/tool. |
| 164 | + description: Description of what the activity does. |
| 165 | + output_field_names: Optional list of field names for tuple return types. |
| 166 | + If provided, tuple outputs will be converted to objects with named properties. |
| 167 | + Length must match the number of elements in the tuple return type. |
| 168 | +
|
| 169 | + Returns: |
| 170 | + Decorator function. |
| 171 | +
|
| 172 | + Raises: |
| 173 | + ValueError: If output_field_names is provided but function has no return annotation, |
| 174 | + or if return type is not a tuple, or if lengths don't match. |
| 175 | +
|
| 176 | + Example: |
| 177 | + >>> @automation_activity( |
| 178 | + ... name="fetch_entities", |
| 179 | + ... description="Fetch entities by DSL", |
| 180 | + ... output_field_names=["entities_path", "total_count", "chunk_count"] |
| 181 | + ... ) |
| 182 | + ... def fetch_entities(self, dsl_query: dict) -> tuple[str, int, int]: |
| 183 | + ... return ("path/to/entities", 1000, 10) |
| 184 | + """ |
114 | 185 |
|
115 | 186 | def decorator(func: Callable[..., Any]) -> Callable[..., Any]: |
116 | 187 | input_schema: dict[str, Any] = _generate_input_schema(func) |
117 | | - output_schema: dict[str, Any] = _generate_output_schema(func) |
| 188 | + output_schema: dict[str, Any] = _generate_output_schema( |
| 189 | + func, output_field_names |
| 190 | + ) |
118 | 191 |
|
119 | 192 | logger.info(f"Collected automation activity: {name}") |
120 | 193 | ACTIVITY_SPECS.append( |
@@ -161,9 +234,7 @@ def flush_activity_registrations( |
161 | 234 | ) |
162 | 235 | return |
163 | 236 |
|
164 | | - logger.info( |
165 | | - f"Registering {len(ACTIVITY_SPECS)} activities with automation engine" |
166 | | - ) |
| 237 | + logger.info(f"Registering {len(ACTIVITY_SPECS)} activities with automation engine") |
167 | 238 |
|
168 | 239 | # Generate app qualified name |
169 | 240 | app_qualified_name: str = f"default/apps/{app_name}" |
|
0 commit comments