Skip to content

Commit 644b2e8

Browse files
Integrate Anthropic client for dynamic plan generation (#421)
* feat: integrate Anthropic client for dynamic plan generation - Replace static plan generation in `plan_issue.py` with `ChatAnthropic`. - Implement `PlanModel` Pydantic model for structured output validation. - Inject repository file context into the prompt to improve plan accuracy. - Update `__init__.py` to export `plan_issue` and `execute_issue` from their respective modules. - Add unit tests for `plan_issue`. * fix: resolve type errors in langgraph_service and orchestrator - Consolidate GraphState type in types.py with all required fields - Add type: ignore comments for external package imports (pydantic, langchain) - Fix variable redefinition in execute_issue.py - Create separate create_pr.py node module - Update pyproject.toml with correct package name (multiplai) - Fix null vs undefined type mismatch in orchestrator.ts * fix(ci): update langgraph_service path to packages/api/langgraph_service * fix(ci): use pnpm for turborepo and fix Python lint issues - Switch CI from bun to pnpm for turborepo compatibility - Add missing newlines to Python __init__.py files - Shorten long comment line in execute_issue.py * fix(ci): remove pnpm version conflict and add missing newline * fix(lint): sort imports in test_plan_issue.py --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: Ronaldo Martins <ron@ldinho.com.br>
1 parent 61aa10d commit 644b2e8

15 files changed

Lines changed: 385 additions & 122 deletions

File tree

.github/workflows/ci.yml

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,38 @@ jobs:
1313
steps:
1414
- uses: actions/checkout@v4
1515

16-
- uses: oven-sh/setup-bun@v2
16+
- uses: pnpm/action-setup@v4
17+
18+
- uses: actions/setup-node@v4
1719
with:
18-
bun-version: latest
20+
node-version: "20"
21+
cache: "pnpm"
1922

2023
- name: Install dependencies
21-
run: bun install
24+
run: pnpm install
2225

2326
- name: Type check
24-
run: bun run typecheck
27+
run: pnpm run typecheck
2528

2629
test:
2730
name: Test
2831
runs-on: ubuntu-latest
2932
steps:
3033
- uses: actions/checkout@v4
3134

35+
- uses: pnpm/action-setup@v4
36+
37+
- uses: actions/setup-node@v4
38+
with:
39+
node-version: "20"
40+
cache: "pnpm"
41+
3242
- uses: oven-sh/setup-bun@v2
3343
with:
3444
bun-version: latest
3545

3646
- name: Install dependencies
37-
run: bun install
47+
run: pnpm install
3848

3949
- name: Run tests
4050
run: bun test || echo "No tests found"
@@ -44,7 +54,7 @@ jobs:
4454
runs-on: ubuntu-latest
4555
defaults:
4656
run:
47-
working-directory: langgraph_service
57+
working-directory: packages/api/langgraph_service
4858
steps:
4959
- uses: actions/checkout@v4
5060

@@ -69,7 +79,7 @@ jobs:
6979
runs-on: ubuntu-latest
7080
defaults:
7181
run:
72-
working-directory: langgraph_service
82+
working-directory: packages/api/langgraph_service
7383
steps:
7484
- uses: actions/checkout@v4
7585

packages/api/langgraph_service/pyproject.toml

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[project]
2-
name = "langgraph-service"
2+
name = "multiplai"
33
version = "0.1.0"
44
requires-python = ">=3.11"
55
dependencies = [
@@ -10,6 +10,7 @@ dependencies = [
1010
"fastapi>=0.115.0",
1111
"uvicorn>=0.32.0",
1212
"pydantic>=2.9.0",
13+
"pydantic-settings>=2.0.0",
1314
"httpx>=0.27.0",
1415
"python-dotenv>=1.0.0",
1516
"structlog>=24.0.0",
@@ -20,4 +21,23 @@ requires = ["hatchling"]
2021
build-backend = "hatchling.build"
2122

2223
[tool.hatch.build.targets.wheel]
23-
packages = ["src/langgraph_service"]
24+
packages = ["src/multiplai"]
25+
26+
[tool.pytest.ini_options]
27+
asyncio_mode = "auto"
28+
testpaths = ["tests"]
29+
pythonpath = ["src"]
30+
31+
[tool.mypy]
32+
python_version = "3.11"
33+
warn_return_any = true
34+
warn_unused_configs = true
35+
ignore_missing_imports = false
36+
37+
[tool.ruff]
38+
line-length = 100
39+
target-version = "py311"
40+
41+
[tool.ruff.lint]
42+
select = ["E", "F", "I", "W"]
43+
ignore = []
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""MultiplAI - Multi-agent AI system for automated software development."""
22

3-
__version__ = '0.1.0'
3+
__version__ = '0.1.0'

packages/api/langgraph_service/src/multiplai/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from functools import lru_cache
88

9-
from pydantic_settings import BaseSettings
9+
from pydantic_settings import BaseSettings # type: ignore[import-not-found]
1010

1111

1212
class Settings(BaseSettings):

packages/api/langgraph_service/src/multiplai/graph.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,13 @@
1313
import copy
1414
import inspect
1515
from dataclasses import dataclass
16-
from typing import Any, Awaitable, Callable, Dict, Optional, TypedDict
16+
from typing import Any, Awaitable, Callable, Dict, Optional
1717

18+
from multiplai.types import GraphState
1819

1920
END = "__end__"
2021

2122

22-
class GraphState(TypedDict, total=False):
23-
"""Shared graph state.
24-
25-
This is a pragmatic schema for tests and basic usage.
26-
"""
27-
28-
status: str
29-
context: Dict[str, Any]
30-
plan: Dict[str, Any]
31-
execution_result: Dict[str, Any]
32-
pr_url: str
33-
trace: list[str]
34-
35-
3623
class MemorySaver:
3724
"""A minimal in-memory checkpointer compatible with this module's graph."""
3825

@@ -47,7 +34,10 @@ def put(self, thread_id: str, state: GraphState) -> None:
4734
self._store[thread_id] = copy.deepcopy(state)
4835

4936

50-
NodeFn = Callable[[GraphState], Dict[str, Any] | GraphState | None | Awaitable[Dict[str, Any] | GraphState | None]]
37+
NodeFn = Callable[
38+
[GraphState],
39+
Dict[str, Any] | GraphState | None | Awaitable[Dict[str, Any] | GraphState | None],
40+
]
5141

5242

5343
@dataclass(frozen=True)
@@ -57,7 +47,9 @@ class _CompiledGraph:
5747
entry_point: str
5848
checkpointer: MemorySaver
5949

60-
async def ainvoke(self, state: GraphState, config: Optional[Dict[str, Any]] = None) -> GraphState:
50+
async def ainvoke(
51+
self, state: GraphState, config: Optional[Dict[str, Any]] = None
52+
) -> GraphState:
6153
current = self.entry_point
6254
thread_id = "default"
6355
if config is not None:
@@ -77,7 +69,8 @@ async def ainvoke(self, state: GraphState, config: Optional[Dict[str, Any]] = No
7769
result = await result
7870
if result:
7971
# Treat the node output as a partial state update.
80-
working_state.update(result) # type: ignore[arg-type]
72+
for key, value in result.items():
73+
working_state[key] = value # type: ignore[literal-required]
8174
self.checkpointer.put(thread_id, working_state)
8275

8376
current = self.edges.get(current, END)
@@ -92,7 +85,7 @@ class StateGraph:
9285
sufficient for wiring, compilation, and testing of node sequencing.
9386
"""
9487

95-
def __init__(self, state_schema: type[GraphState]) -> None:
88+
def __init__(self, state_schema: type) -> None:
9689
self._state_schema = state_schema
9790
self._nodes: Dict[str, NodeFn] = {}
9891
self._edges: Dict[str, str] = {}
@@ -138,7 +131,12 @@ async def plan_issue(state: GraphState) -> Dict[str, Any]:
138131
_append_trace(state, "plan_issue")
139132
return {
140133
"status": "planned",
141-
"plan": {"steps": ["execute", "create_pr"]},
134+
"plan": {
135+
"steps": ["execute", "create_pr"],
136+
"definition_of_done": [],
137+
"target_files": [],
138+
"estimated_complexity": "low",
139+
},
142140
}
143141

144142

@@ -187,4 +185,4 @@ def build_graph() -> _CompiledGraph:
187185
return workflow.compile(checkpointer=MemorySaver())
188186

189187

190-
graph = build_graph()
188+
graph = build_graph()

packages/api/langgraph_service/src/multiplai/nodes/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
This package exposes the node callables used by the graph runner.
44
"""
55

6-
from .load_context import create_pr, execute_issue, load_context, plan_issue
6+
from .create_pr import create_pr
7+
from .execute_issue import execute_issue
8+
from .load_context import load_context
9+
from .plan_issue import plan_issue
710

811
__all__ = [
912
"load_context",
1013
"plan_issue",
1114
"execute_issue",
1215
"create_pr",
13-
]
16+
]
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""Create PR node for MultiplAI LangGraph workflows.
2+
3+
This module contains the `create_pr` node, responsible for creating a
4+
pull request from the generated diff.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
from multiplai.types import GraphState
10+
11+
12+
async def create_pr(state: GraphState) -> GraphState:
13+
"""Create a pull request from the generated diff.
14+
15+
Args:
16+
state: Current graph state with diff.
17+
18+
Returns:
19+
Updated graph state with:
20+
- status set to 'pr_created'
21+
- pr_url containing the PR URL
22+
- pr_data containing PR metadata
23+
"""
24+
# Placeholder implementation
25+
# In the future, this will use GitHub API to create the PR
26+
27+
updated_state = GraphState(**state)
28+
updated_state["status"] = "pr_created"
29+
updated_state["pr_url"] = "https://github.com/example/repo/pull/1"
30+
updated_state["pr_data"] = {
31+
"number": 1,
32+
"html_url": "https://github.com/example/repo/pull/1",
33+
"state": "open",
34+
}
35+
return updated_state

packages/api/langgraph_service/src/multiplai/nodes/execute_issue.py

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77

88
from __future__ import annotations
99

10-
from typing import Any, Dict, TypeAlias
11-
12-
from langchain_anthropic import ChatAnthropic
13-
from langchain_core.messages import HumanMessage, SystemMessage
10+
from langchain_anthropic import ChatAnthropic # type: ignore[import-not-found]
11+
from langchain_core.messages import ( # type: ignore[import-not-found]
12+
HumanMessage,
13+
SystemMessage,
14+
)
1415

1516
from multiplai.config import get_settings
16-
17-
GraphState: TypeAlias = Dict[str, Any]
17+
from multiplai.types import GraphState
1818

1919

2020
async def execute_issue(state: GraphState) -> GraphState:
@@ -41,10 +41,10 @@ async def execute_issue(state: GraphState) -> GraphState:
4141
target_files = state.get("target_files")
4242

4343
if not target_files:
44-
new_state: GraphState = dict(state)
45-
new_state["status"] = "error"
46-
new_state["error"] = "No target_files specified; unable to execute issue."
47-
return new_state
44+
error_state = GraphState(**state)
45+
error_state["status"] = "error"
46+
error_state["error"] = "No target_files specified; unable to execute issue."
47+
return error_state
4848

4949
# Read the content of the target files
5050
file_contents = {}
@@ -53,20 +53,18 @@ async def execute_issue(state: GraphState) -> GraphState:
5353
with open(file_path, "r", encoding="utf-8") as f:
5454
file_contents[file_path] = f.read()
5555
except FileNotFoundError:
56-
new_state: GraphState = dict(state)
57-
new_state["status"] = "error"
58-
new_state["error"] = f"File not found: {file_path}"
59-
return new_state
56+
error_state = GraphState(**state)
57+
error_state["status"] = "error"
58+
error_state["error"] = f"File not found: {file_path}"
59+
return error_state
6060
except Exception as e:
61-
new_state: GraphState = dict(state)
62-
new_state["status"] = "error"
63-
new_state["error"] = f"Error reading file {file_path}: {e}"
64-
return new_state
61+
error_state = GraphState(**state)
62+
error_state["status"] = "error"
63+
error_state["error"] = f"Error reading file {file_path}: {e}"
64+
return error_state
6565

6666
settings = get_settings()
67-
llm = ChatAnthropic(
68-
api_key=settings.anthropic_api_key, model="claude-3-5-sonnet-20240620"
69-
)
67+
llm = ChatAnthropic(api_key=settings.anthropic_api_key, model="claude-3-5-sonnet-20240620")
7068

7169
system_prompt = (
7270
"You are an expert software engineer. Your task is to generate a unified diff "
@@ -88,7 +86,7 @@ async def execute_issue(state: GraphState) -> GraphState:
8886
response = await llm.ainvoke(messages)
8987
unified_diff = response.content
9088
if isinstance(unified_diff, list):
91-
# Handle case where content might be list of blocks (though unlikely for text model without tools)
89+
# Handle case where content might be list of blocks
9290
unified_diff = "\n".join(
9391
[block.text for block in unified_diff if hasattr(block, "text")]
9492
)
@@ -105,12 +103,12 @@ async def execute_issue(state: GraphState) -> GraphState:
105103
unified_diff = unified_diff.strip()
106104

107105
except Exception as e:
108-
new_state: GraphState = dict(state)
109-
new_state["status"] = "error"
110-
new_state["error"] = f"Error generating patch: {e}"
111-
return new_state
112-
113-
new_state: GraphState = dict(state)
114-
new_state["diff"] = unified_diff
115-
new_state["status"] = "executed"
116-
return new_state
106+
error_state = GraphState(**state)
107+
error_state["status"] = "error"
108+
error_state["error"] = f"Error generating patch: {e}"
109+
return error_state
110+
111+
success_state = GraphState(**state)
112+
success_state["diff"] = unified_diff
113+
success_state["status"] = "executed"
114+
return success_state

0 commit comments

Comments
 (0)