Skip to content

Commit c563b2c

Browse files
committed
Handle Tuple return type
1 parent 4e6fed6 commit c563b2c

File tree

1 file changed

+127
-56
lines changed

1 file changed

+127
-56
lines changed

application_sdk/decorators/automation_activity.py

Lines changed: 127 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import inspect
66
import json
7-
from typing import Any, Callable, Dict, List, Union, get_args, get_origin
7+
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
88

99
from pydantic import BaseModel
1010
import requests
@@ -14,107 +14,180 @@
1414

1515
logger = get_logger(__name__)
1616

17-
1817
# Global registry to collect decorated activities
1918
ACTIVITY_SPECS: List[dict[str, Any]] = []
2019

2120

2221
def _type_to_json_schema(annotation: Any) -> Dict[str, Any]:
2322
"""Convert Python type annotation to JSON Schema using reflection."""
23+
# Pydantic models
2424
if inspect.isclass(annotation) and issubclass(annotation, BaseModel):
2525
schema = annotation.model_json_schema(mode="serialization")
2626
if "$defs" in schema:
27-
defs = schema.pop("$defs")
28-
schema = {**schema, **defs}
27+
schema = {**schema, **schema.pop("$defs")}
2928
return schema
3029

3130
origin = get_origin(annotation)
32-
if origin is Union:
33-
args = get_args(annotation)
34-
if type(None) in args:
35-
non_none_types = [arg for arg in args if arg is not type(None)]
36-
if non_none_types:
37-
schema = _type_to_json_schema(non_none_types[0])
38-
schema["nullable"] = True
39-
return schema
31+
args = get_args(annotation)
32+
33+
# Optional types (Union with None)
34+
if origin is Union and type(None) in args:
35+
non_none = [a for a in args if a is not type(None)]
36+
if non_none:
37+
schema = _type_to_json_schema(non_none[0])
38+
schema["nullable"] = True
39+
return schema
4040

41+
# Dict types
4142
if origin is dict or annotation is dict:
42-
args = get_args(annotation)
43-
if args:
44-
return {
45-
"type": "object",
46-
"additionalProperties": _type_to_json_schema(args[1]),
47-
}
48-
return {"type": "object"}
43+
return {
44+
"type": "object",
45+
"additionalProperties": _type_to_json_schema(args[1]) if args else {},
46+
}
4947

48+
# List types
5049
if origin is list or annotation is list:
51-
args = get_args(annotation)
50+
return {"type": "array", "items": _type_to_json_schema(args[0]) if args else {}}
51+
52+
# Tuple types
53+
if origin is tuple or annotation is tuple:
5254
if args:
53-
return {"type": "array", "items": _type_to_json_schema(args[0])}
55+
item_schemas = [_type_to_json_schema(arg) for arg in args]
56+
return {
57+
"type": "array",
58+
"items": item_schemas,
59+
"minItems": len(item_schemas),
60+
"maxItems": len(item_schemas),
61+
}
5462
return {"type": "array"}
5563

56-
type_mapping = {
57-
str: {"type": "string"},
58-
int: {"type": "integer"},
59-
float: {"type": "number"},
60-
bool: {"type": "boolean"},
61-
Any: {},
62-
}
63-
64-
if annotation in type_mapping:
65-
return type_mapping[annotation]
64+
# Primitive types
65+
type_map = {str: "string", int: "integer", float: "number", bool: "boolean"}
66+
if annotation in type_map:
67+
return {"type": type_map[annotation]}
6668

6769
return {}
6870

6971

72+
def _create_object_schema(
73+
properties: Dict[str, Any], required: Optional[List[str]] = None
74+
) -> Dict[str, Any]:
75+
"""Create a JSON Schema object with properties and required fields."""
76+
schema = {"type": "object", "properties": properties}
77+
if required:
78+
schema["required"] = required
79+
return schema
80+
81+
7082
def _generate_input_schema(func: Any) -> Dict[str, Any]:
7183
"""Generate JSON Schema for function inputs using reflection."""
72-
sig = inspect.signature(func)
73-
properties = {}
74-
required: list[str] = []
75-
76-
for param_name, param in sig.parameters.items():
77-
if param_name == "self":
84+
properties: Dict[str, Any] = {}
85+
required: List[str] = []
86+
for name, param in inspect.signature(func).parameters.items():
87+
if name == "self":
7888
continue
79-
80-
param_schema = (
89+
schema = (
8190
_type_to_json_schema(param.annotation)
8291
if param.annotation != inspect.Parameter.empty
8392
else {}
8493
)
85-
8694
if param.default != inspect.Parameter.empty:
8795
if isinstance(param.default, (str, int, float, bool, type(None))):
88-
param_schema["default"] = param.default
96+
schema["default"] = param.default
8997
else:
90-
required.append(param_name)
91-
92-
properties[param_name] = param_schema
98+
required.append(name)
99+
properties[name] = schema
100+
return _create_object_schema(properties, required if required else None)
93101

94-
schema = {"type": "object", "properties": properties}
95-
if required:
96-
schema["required"] = required
97102

98-
return schema
103+
def _validate_output_field_names(
104+
func: Any, output_field_names: List[str], return_annotation: Any
105+
) -> None:
106+
"""Validate output_field_names against function return type."""
107+
origin = get_origin(return_annotation)
108+
is_tuple: bool = origin is tuple or return_annotation is tuple
109+
if not is_tuple:
110+
raise ValueError(
111+
f"output_field_names provided for '{func.__name__}', "
112+
f"but return type is not a tuple (got {return_annotation})."
113+
)
114+
args = get_args(return_annotation)
115+
if not args:
116+
raise ValueError(
117+
f"output_field_names provided for '{func.__name__}', but tuple has no type arguments."
118+
)
119+
if len(output_field_names) != len(args):
120+
raise ValueError(
121+
f"output_field_names length ({len(output_field_names)}) doesn't match "
122+
f"tuple length ({len(args)}) for '{func.__name__}'. Expected: {output_field_names}"
123+
)
99124

100125

101-
def _generate_output_schema(func: Any) -> Dict[str, Any]:
126+
def _generate_output_schema(
127+
func: Any, output_field_names: Optional[List[str]] = None
128+
) -> Dict[str, Any]:
102129
"""Generate JSON Schema for function outputs using reflection."""
103130
return_annotation = inspect.signature(func).return_annotation
104-
if return_annotation == inspect.Signature.empty:
131+
has_return = return_annotation != inspect.Signature.empty
132+
133+
if output_field_names:
134+
if not has_return:
135+
raise ValueError(
136+
f"output_field_names provided for '{func.__name__}', but function has no return annotation."
137+
)
138+
_validate_output_field_names(func, output_field_names, return_annotation)
139+
140+
if not has_return:
105141
return {}
106-
return _type_to_json_schema(return_annotation)
142+
143+
schema: dict[str, Any] = _type_to_json_schema(return_annotation)
144+
145+
# Convert tuple to object if field names provided
146+
if output_field_names and schema.get("type") == "array" and "items" in schema:
147+
items = schema.get("items", [])
148+
if isinstance(items, list) and len(items) == len(output_field_names):
149+
properties = {name: item for name, item in zip(output_field_names, items)}
150+
return _create_object_schema(properties, output_field_names)
151+
152+
return schema
107153

108154

109155
def automation_activity(
110156
name: str,
111157
description: str,
158+
output_field_names: Optional[List[str]] = None,
112159
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
113-
"""Decorator to mark an activity for automatic registration."""
160+
"""Decorator to mark an activity for automatic registration.
161+
162+
Args:
163+
name: Name of the activity/tool.
164+
description: Description of what the activity does.
165+
output_field_names: Optional list of field names for tuple return types.
166+
If provided, tuple outputs will be converted to objects with named properties.
167+
Length must match the number of elements in the tuple return type.
168+
169+
Returns:
170+
Decorator function.
171+
172+
Raises:
173+
ValueError: If output_field_names is provided but function has no return annotation,
174+
or if return type is not a tuple, or if lengths don't match.
175+
176+
Example:
177+
>>> @automation_activity(
178+
... name="fetch_entities",
179+
... description="Fetch entities by DSL",
180+
... output_field_names=["entities_path", "total_count", "chunk_count"]
181+
... )
182+
... def fetch_entities(self, dsl_query: dict) -> tuple[str, int, int]:
183+
... return ("path/to/entities", 1000, 10)
184+
"""
114185

115186
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
116187
input_schema: dict[str, Any] = _generate_input_schema(func)
117-
output_schema: dict[str, Any] = _generate_output_schema(func)
188+
output_schema: dict[str, Any] = _generate_output_schema(
189+
func, output_field_names
190+
)
118191

119192
logger.info(f"Collected automation activity: {name}")
120193
ACTIVITY_SPECS.append(
@@ -161,9 +234,7 @@ def flush_activity_registrations(
161234
)
162235
return
163236

164-
logger.info(
165-
f"Registering {len(ACTIVITY_SPECS)} activities with automation engine"
166-
)
237+
logger.info(f"Registering {len(ACTIVITY_SPECS)} activities with automation engine")
167238

168239
# Generate app qualified name
169240
app_qualified_name: str = f"default/apps/{app_name}"

0 commit comments

Comments
 (0)