Skip to content

Commit 405cfdc

Browse files
feat: add structural reform support for AI agent (#67)
* feat: add simulation_modifier field to Policy model Stores Python code for structural reforms (custom variable formulas). Includes database migration. * feat: add structural reform support for AI agent - Add simulation_modifier field to Policy model (stores Python code) - Update Modal functions to apply modifiers to simulations - Add agent upload endpoints for results (deciles, poverty, inequality) - Update agent system prompt with structural reform workflow - Database migration for policies.simulation_modifier column The agent can now write custom variable formulas via simulation modifiers and run them on UK/US economy simulations. * feat: add execute_python tool for agent to test modifier code - Added execute_python tool that runs Python code in a sandbox - Updated system prompt to instruct agent to test modifiers before submitting - Agent can now validate syntax and logic before running expensive simulations * test: add tests for structural reform and execute_python functionality Tests for: - Creating policies with simulation_modifier - execute_python_code function (basic, error handling, modifier validation) - PolicyRead including simulation_modifier field * test: add integration tests for agent execute_python tool - Added dotenv loading to conftest.py for API keys - Added test_agent_uses_execute_python_tool - verifies agent can execute Python - Added test_agent_validates_modifier_code - verifies agent can validate modifiers
1 parent 75315cc commit 405cfdc

File tree

15 files changed

+859
-53
lines changed

15 files changed

+859
-53
lines changed

scripts/seed.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,9 @@ def seed_model(model_version, session, lite: bool = False) -> TaxBenefitModelVer
223223
seen_names.add(p.name)
224224

225225
filter_msg = f" Filtered to {len(parameters_to_add)} user-facing parameters"
226-
filter_msg += f" (from {len(model_version.parameters)} total, deduplicated by name)"
226+
filter_msg += (
227+
f" (from {len(model_version.parameters)} total, deduplicated by name)"
228+
)
227229
if lite and skipped_state_params > 0:
228230
filter_msg += f", skipped {skipped_state_params} state params (lite mode)"
229231
console.print(filter_msg)
@@ -626,7 +628,9 @@ def main():
626628

627629
with logfire.span("database_seeding"):
628630
mode_str = " (lite mode)" if args.lite else ""
629-
console.print(f"[bold green]PolicyEngine database seeding{mode_str}[/bold green]\n")
631+
console.print(
632+
f"[bold green]PolicyEngine database seeding{mode_str}[/bold green]\n"
633+
)
630634

631635
with next(get_quiet_session()) as session:
632636
# Seed UK model

src/policyengine_api/agent_sandbox.py

Lines changed: 142 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,56 @@ def configure_logfire(traceparent: str | None = None):
7676
- POST /analysis/economic-impact with tax_benefit_model_name, policy_id and dataset_id
7777
- GET /analysis/economic-impact/{report_id} for results (includes decile_impacts and program_statistics)
7878
79+
4. **Structural reforms** (custom variable formulas):
80+
For reforms that can't be expressed as parameter changes (e.g., new benefits, eligibility changes):
81+
82+
**IMPORTANT: Always test your modifier code first using the execute_python tool!**
83+
84+
Steps:
85+
a. Write your simulation_modifier code
86+
b. Test it with execute_python to check for syntax errors and basic logic
87+
c. POST /agent/results/policy-with-modifier to create the policy
88+
d. Use the policy_id in /analysis/economic-impact as normal
89+
90+
Example test with execute_python:
91+
```python
92+
# Test the modifier code compiles and basic logic works
93+
from numpy import where
94+
95+
def modify(simulation):
96+
# Your modifier code here
97+
pass
98+
99+
# Test the function exists and is callable
100+
print(f"modify function defined: {callable(modify)}")
101+
102+
# Test any helper logic
103+
income = 15000
104+
benefit = 1000 if income < 20000 else 0
105+
print(f"Test case: income={income} -> benefit={benefit}")
106+
```
107+
108+
Example simulation_modifier for a new benefit:
109+
```python
110+
def modify(simulation):
111+
from policyengine_core.variables import Variable
112+
from policyengine_core.periods import YEAR
113+
from numpy import where
114+
115+
Person = simulation.tax_benefit_system.entities_by_name()["person"]
116+
117+
@simulation.tax_benefit_system.variable("my_new_benefit")
118+
class my_new_benefit(Variable):
119+
value_type = float
120+
entity = Person
121+
definition_period = YEAR
122+
label = "My new benefit"
123+
124+
def formula(person, period, parameters):
125+
income = person("employment_income", period)
126+
return where(income < 20000, 1000, 0)
127+
```
128+
79129
## Response formatting
80130
81131
Follow PolicyEngine's writing style:
@@ -124,6 +174,66 @@ def configure_logfire(traceparent: str | None = None):
124174
},
125175
}
126176

177+
# Python execution tool for testing code
178+
EXECUTE_PYTHON_TOOL = {
179+
"name": "execute_python",
180+
"description": "Execute Python code and return the output. Use this to test simulation modifier code before submitting it. The code runs in a sandboxed environment with numpy available. Returns stdout/stderr and any exceptions.",
181+
"input_schema": {
182+
"type": "object",
183+
"properties": {
184+
"code": {
185+
"type": "string",
186+
"description": "Python code to execute. Should include print statements to show results.",
187+
}
188+
},
189+
"required": ["code"],
190+
},
191+
}
192+
193+
194+
def execute_python_code(code: str) -> str:
195+
"""Execute Python code in a restricted environment and return output."""
196+
import io
197+
import sys
198+
import traceback
199+
200+
# Capture stdout/stderr
201+
old_stdout = sys.stdout
202+
old_stderr = sys.stderr
203+
sys.stdout = captured_out = io.StringIO()
204+
sys.stderr = captured_err = io.StringIO()
205+
206+
result = ""
207+
try:
208+
# Create a restricted namespace with common imports available
209+
namespace = {
210+
"__builtins__": __builtins__,
211+
}
212+
213+
# Execute the code
214+
exec(code, namespace)
215+
216+
stdout_val = captured_out.getvalue()
217+
stderr_val = captured_err.getvalue()
218+
219+
if stdout_val:
220+
result += f"Output:\n{stdout_val}"
221+
if stderr_val:
222+
result += f"\nStderr:\n{stderr_val}"
223+
if not stdout_val and not stderr_val:
224+
result = "Code executed successfully (no output)"
225+
226+
except Exception as e:
227+
result = (
228+
f"Error: {type(e).__name__}: {e}\n\nTraceback:\n{traceback.format_exc()}"
229+
)
230+
231+
finally:
232+
sys.stdout = old_stdout
233+
sys.stderr = old_stderr
234+
235+
return result[:5000] # Limit output length
236+
127237

128238
def fetch_openapi_spec(api_base_url: str) -> dict:
129239
"""Fetch and cache OpenAPI spec."""
@@ -235,8 +345,7 @@ def openapi_to_claude_tools(spec: dict) -> list[dict]:
235345

236346
prop = schema_to_json_schema(spec, param_schema)
237347
prop["description"] = (
238-
param.get("description", "")
239-
+ f" (in: {param_in})"
348+
param.get("description", "") + f" (in: {param_in})"
240349
)
241350
properties[param_name] = prop
242351

@@ -268,16 +377,18 @@ def openapi_to_claude_tools(spec: dict) -> list[dict]:
268377
if required:
269378
input_schema["required"] = list(set(required))
270379

271-
tools.append({
272-
"name": tool_name,
273-
"description": full_desc[:1024], # Claude has limits
274-
"input_schema": input_schema,
275-
"_meta": {
276-
"path": path,
277-
"method": method,
278-
"parameters": operation.get("parameters", []),
279-
},
280-
})
380+
tools.append(
381+
{
382+
"name": tool_name,
383+
"description": full_desc[:1024], # Claude has limits
384+
"input_schema": input_schema,
385+
"_meta": {
386+
"path": path,
387+
"method": method,
388+
"parameters": operation.get("parameters", []),
389+
},
390+
}
391+
)
281392

282393
return tools
283394

@@ -347,7 +458,9 @@ def execute_api_tool(
347458
url, params=query_params, json=body_data, headers=headers, timeout=60
348459
)
349460
elif method == "delete":
350-
resp = requests.delete(url, params=query_params, headers=headers, timeout=60)
461+
resp = requests.delete(
462+
url, params=query_params, headers=headers, timeout=60
463+
)
351464
else:
352465
return f"Unsupported method: {method}"
353466

@@ -415,11 +528,10 @@ def log(msg: str) -> None:
415528
tool_lookup = {t["name"]: t for t in tools}
416529

417530
# Strip _meta from tools before sending to Claude (it doesn't need it)
418-
claude_tools = [
419-
{k: v for k, v in t.items() if k != "_meta"} for t in tools
420-
]
421-
# Add the sleep tool
531+
claude_tools = [{k: v for k, v in t.items() if k != "_meta"} for t in tools]
532+
# Add built-in tools
422533
claude_tools.append(SLEEP_TOOL)
534+
claude_tools.append(EXECUTE_PYTHON_TOOL)
423535

424536
client = anthropic.Anthropic()
425537

@@ -466,6 +578,12 @@ def log(msg: str) -> None:
466578
log(f"[SLEEP] Waiting {seconds} seconds...")
467579
time.sleep(seconds)
468580
result = f"Slept for {seconds} seconds"
581+
elif block.name == "execute_python":
582+
# Handle Python execution tool
583+
code = block.input.get("code", "")
584+
log(f"[PYTHON] Executing code ({len(code)} chars)...")
585+
result = execute_python_code(code)
586+
log(f"[PYTHON] Result: {result[:200]}")
469587
else:
470588
tool = tool_lookup.get(block.name)
471589
if tool:
@@ -477,11 +595,13 @@ def log(msg: str) -> None:
477595

478596
log(f"[TOOL_RESULT] {result[:300]}")
479597

480-
tool_results.append({
481-
"type": "tool_result",
482-
"tool_use_id": block.id,
483-
"content": result,
484-
})
598+
tool_results.append(
599+
{
600+
"type": "tool_result",
601+
"tool_use_id": block.id,
602+
"content": result,
603+
}
604+
)
485605

486606
messages.append({"role": "assistant", "content": assistant_content})
487607

src/policyengine_api/api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from . import (
66
agent,
7+
agent_results,
78
analysis,
89
change_aggregates,
910
datasets,
@@ -35,5 +36,6 @@
3536
api_router.include_router(household.router)
3637
api_router.include_router(analysis.router)
3738
api_router.include_router(agent.router)
39+
api_router.include_router(agent_results.router)
3840

3941
__all__ = ["api_router"]

src/policyengine_api/api/agent.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def get_traceparent() -> str | None:
2424
TraceContextTextMapPropagator().inject(carrier)
2525
return carrier.get("traceparent")
2626

27+
2728
router = APIRouter(prefix="/agent", tags=["agent"])
2829

2930

@@ -93,7 +94,9 @@ def _run_local_agent(
9394
from policyengine_api.agent_sandbox import _run_agent_impl
9495

9596
try:
96-
history_dicts = [{"role": m.role, "content": m.content} for m in (history or [])]
97+
history_dicts = [
98+
{"role": m.role, "content": m.content} for m in (history or [])
99+
]
97100
result = _run_agent_impl(question, api_base_url, call_id, history_dicts)
98101
_calls[call_id]["status"] = result.get("status", "completed")
99102
_calls[call_id]["result"] = result
@@ -136,9 +139,15 @@ async def run_agent(request: RunRequest) -> RunResponse:
136139

137140
traceparent = get_traceparent()
138141
run_fn = modal.Function.from_name("policyengine-sandbox", "run_agent")
139-
history_dicts = [{"role": m.role, "content": m.content} for m in request.history]
142+
history_dicts = [
143+
{"role": m.role, "content": m.content} for m in request.history
144+
]
140145
call = run_fn.spawn(
141-
request.question, api_base_url, call_id, history_dicts, traceparent=traceparent
146+
request.question,
147+
api_base_url,
148+
call_id,
149+
history_dicts,
150+
traceparent=traceparent,
142151
)
143152

144153
_calls[call_id] = {
@@ -166,7 +175,12 @@ async def run_agent(request: RunRequest) -> RunResponse:
166175
# Run in background using asyncio
167176
loop = asyncio.get_event_loop()
168177
loop.run_in_executor(
169-
None, _run_local_agent, call_id, request.question, api_base_url, request.history
178+
None,
179+
_run_local_agent,
180+
call_id,
181+
request.question,
182+
api_base_url,
183+
request.history,
170184
)
171185

172186
return RunResponse(call_id=call_id, status="running")

0 commit comments

Comments
 (0)