Skip to content

Commit cea3f91

Browse files
committed
Handle Tuple return type
1 parent 4e6fed6 commit cea3f91

File tree

1 file changed

+159
-57
lines changed

1 file changed

+159
-57
lines changed

application_sdk/decorators/automation_activity.py

Lines changed: 159 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import inspect
66
import json
7-
from typing import Any, Callable, Dict, List, Union, get_args, get_origin
7+
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
88

99
from pydantic import BaseModel
1010
import requests
@@ -14,107 +14,211 @@
1414

1515
logger = get_logger(__name__)
1616

17-
1817
# Global registry to collect decorated activities
1918
ACTIVITY_SPECS: List[dict[str, Any]] = []
2019

2120

2221
def _type_to_json_schema(annotation: Any) -> Dict[str, Any]:
2322
"""Convert Python type annotation to JSON Schema using reflection."""
23+
# Pydantic models
2424
if inspect.isclass(annotation) and issubclass(annotation, BaseModel):
2525
schema = annotation.model_json_schema(mode="serialization")
2626
if "$defs" in schema:
27-
defs = schema.pop("$defs")
28-
schema = {**schema, **defs}
27+
schema = {**schema, **schema.pop("$defs")}
2928
return schema
3029

3130
origin = get_origin(annotation)
32-
if origin is Union:
33-
args = get_args(annotation)
34-
if type(None) in args:
35-
non_none_types = [arg for arg in args if arg is not type(None)]
36-
if non_none_types:
37-
schema = _type_to_json_schema(non_none_types[0])
38-
schema["nullable"] = True
39-
return schema
31+
args = get_args(annotation)
32+
33+
# Optional types (Union with None)
34+
if origin is Union and type(None) in args:
35+
non_none = [a for a in args if a is not type(None)]
36+
if non_none:
37+
schema = _type_to_json_schema(non_none[0])
38+
schema["nullable"] = True
39+
return schema
4040

41+
# Dict types
4142
if origin is dict or annotation is dict:
42-
args = get_args(annotation)
43-
if args:
44-
return {
45-
"type": "object",
46-
"additionalProperties": _type_to_json_schema(args[1]),
47-
}
48-
return {"type": "object"}
43+
return {
44+
"type": "object",
45+
"additionalProperties": _type_to_json_schema(args[1]) if args else {},
46+
}
4947

48+
# List types
5049
if origin is list or annotation is list:
51-
args = get_args(annotation)
50+
return {"type": "array", "items": _type_to_json_schema(args[0]) if args else {}}
51+
52+
# Tuple types
53+
if origin is tuple or annotation is tuple:
5254
if args:
53-
return {"type": "array", "items": _type_to_json_schema(args[0])}
55+
item_schemas = [_type_to_json_schema(arg) for arg in args]
56+
return {
57+
"type": "array",
58+
"items": item_schemas,
59+
"minItems": len(item_schemas),
60+
"maxItems": len(item_schemas),
61+
}
5462
return {"type": "array"}
5563

56-
type_mapping = {
57-
str: {"type": "string"},
58-
int: {"type": "integer"},
59-
float: {"type": "number"},
60-
bool: {"type": "boolean"},
61-
Any: {},
62-
}
63-
64-
if annotation in type_mapping:
65-
return type_mapping[annotation]
64+
# Primitive types
65+
type_map = {str: "string", int: "integer", float: "number", bool: "boolean"}
66+
if annotation in type_map:
67+
return {"type": type_map[annotation]}
6668

6769
return {}
6870

6971

72+
def _create_object_schema(
73+
properties: Dict[str, Any], required: Optional[List[str]] = None
74+
) -> Dict[str, Any]:
75+
"""Create a JSON Schema object with properties and required fields."""
76+
schema = {"type": "object", "properties": properties}
77+
if required:
78+
schema["required"] = required
79+
return schema
80+
81+
7082
def _generate_input_schema(func: Any) -> Dict[str, Any]:
7183
"""Generate JSON Schema for function inputs using reflection."""
72-
sig = inspect.signature(func)
73-
properties = {}
74-
required: list[str] = []
75-
76-
for param_name, param in sig.parameters.items():
77-
if param_name == "self":
84+
properties: Dict[str, Any] = {}
85+
required: List[str] = []
86+
for name, param in inspect.signature(func).parameters.items():
87+
if name == "self":
7888
continue
79-
80-
param_schema = (
89+
schema = (
8190
_type_to_json_schema(param.annotation)
8291
if param.annotation != inspect.Parameter.empty
8392
else {}
8493
)
85-
8694
if param.default != inspect.Parameter.empty:
8795
if isinstance(param.default, (str, int, float, bool, type(None))):
88-
param_schema["default"] = param.default
96+
schema["default"] = param.default
8997
else:
90-
required.append(param_name)
98+
required.append(name)
99+
properties[name] = schema
100+
return _create_object_schema(properties, required if required else None)
91101

92-
properties[param_name] = param_schema
93102

94-
schema = {"type": "object", "properties": properties}
95-
if required:
96-
schema["required"] = required
97-
98-
return schema
103+
def _validate_output_field_names(
104+
func: Any, output_field_names: List[str], return_annotation: Any
105+
) -> None:
106+
"""Validate output_field_names against function return type."""
107+
origin = get_origin(return_annotation)
108+
is_tuple: bool = origin is tuple or return_annotation is tuple
109+
110+
if is_tuple:
111+
args = get_args(return_annotation)
112+
if not args:
113+
raise ValueError(
114+
f"output_field_names provided for '{func.__name__}', but tuple has no type arguments."
115+
)
116+
if len(output_field_names) != len(args):
117+
raise ValueError(
118+
f"output_field_names length ({len(output_field_names)}) doesn't match "
119+
f"tuple length ({len(args)}) for '{func.__name__}'. Expected {len(args)} field names."
120+
)
121+
else:
122+
# For single return types (Pydantic models, primitives, etc.), require exactly 1 field name
123+
if len(output_field_names) != 1:
124+
raise ValueError(
125+
f"output_field_names provided for '{func.__name__}' with single return type, "
126+
f"but length ({len(output_field_names)}) is not 1. "
127+
"For single return types, provide exactly one field name."
128+
)
99129

100130

101-
def _generate_output_schema(func: Any) -> Dict[str, Any]:
131+
def _generate_output_schema(
132+
func: Any, output_field_names: Optional[List[str]] = None
133+
) -> Dict[str, Any]:
102134
"""Generate JSON Schema for function outputs using reflection."""
103135
return_annotation = inspect.signature(func).return_annotation
104-
if return_annotation == inspect.Signature.empty:
136+
has_return = return_annotation != inspect.Signature.empty
137+
138+
if output_field_names:
139+
if not has_return:
140+
raise ValueError(
141+
f"output_field_names provided for '{func.__name__}', but function has no return annotation."
142+
)
143+
_validate_output_field_names(func, output_field_names, return_annotation)
144+
145+
if not has_return:
105146
return {}
106-
return _type_to_json_schema(return_annotation)
147+
148+
schema: dict[str, Any] = _type_to_json_schema(return_annotation)
149+
150+
# Wrap in object with named properties if output_field_names provided
151+
if output_field_names:
152+
# Handle tuple conversion to object
153+
if schema.get("type") == "array" and "items" in schema:
154+
items = schema.get("items", [])
155+
if isinstance(items, list) and len(items) == len(output_field_names):
156+
properties = {
157+
name: item for name, item in zip(output_field_names, items)
158+
}
159+
return _create_object_schema(properties, output_field_names)
160+
# Handle single return types (Pydantic models, primitives, etc.) - wrap in object
161+
elif len(output_field_names) == 1:
162+
properties = {output_field_names[0]: schema}
163+
return _create_object_schema(properties, output_field_names)
164+
165+
return schema
107166

108167

109168
def automation_activity(
110-
name: str,
111169
description: str,
170+
output_field_names: Optional[List[str]] = None,
112171
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
113-
"""Decorator to mark an activity for automatic registration."""
172+
"""Decorator to mark an activity for automatic registration.
173+
174+
Args:
175+
description: Description of what the activity does.
176+
output_field_names: Optional list of field names for return types.
177+
- For tuple return types: Length must match the number of tuple elements.
178+
Converts tuple outputs to objects with named properties.
179+
- For single return types (Pydantic models, primitives like str/int, etc.):
180+
Must have exactly one field name. Wraps the return value in an object
181+
with the given field name.
182+
183+
Returns:
184+
Decorator function.
185+
186+
Raises:
187+
ValueError: If output_field_names is provided but function has no return annotation,
188+
or if lengths don't match (tuple length for tuples, exactly 1 for single types).
189+
190+
Example:
191+
>>> # Tuple return type
192+
>>> @automation_activity(
193+
... description="Fetch entities by DSL",
194+
... output_field_names=["entities_path", "total_count", "chunk_count"]
195+
... )
196+
... def fetch_entities(self, dsl_query: dict) -> tuple[str, int, int]:
197+
... return ("path/to/entities", 1000, 10)
198+
...
199+
>>> # Pydantic model return type
200+
>>> @automation_activity(
201+
... description="Update metadata",
202+
... output_field_names=["update_metadata_output"]
203+
... )
204+
... def update_metadata(self, ...) -> UpdateMetadataOutput:
205+
... return UpdateMetadataOutput(...)
206+
...
207+
>>> # Primitive return type
208+
>>> @automation_activity(
209+
... description="Get count",
210+
... output_field_names=["count"]
211+
... )
212+
... def get_count(self) -> int:
213+
... return 42
214+
"""
114215

115216
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
217+
name: str = func.__name__
116218
input_schema: dict[str, Any] = _generate_input_schema(func)
117-
output_schema: dict[str, Any] = _generate_output_schema(func)
219+
output_schema: dict[str, Any] = _generate_output_schema(
220+
func, output_field_names
221+
)
118222

119223
logger.info(f"Collected automation activity: {name}")
120224
ACTIVITY_SPECS.append(
@@ -161,9 +265,7 @@ def flush_activity_registrations(
161265
)
162266
return
163267

164-
logger.info(
165-
f"Registering {len(ACTIVITY_SPECS)} activities with automation engine"
166-
)
268+
logger.info(f"Registering {len(ACTIVITY_SPECS)} activities with automation engine")
167269

168270
# Generate app qualified name
169271
app_qualified_name: str = f"default/apps/{app_name}"

0 commit comments

Comments
 (0)