Skip to content

Commit a396d10

Browse files
authored
Fix request model creation and MCP execution for pipeline wrappers with only run_api_async implemented (#125)
* remove unused imports * Fix request_model creation for pipelines with only run_api_async implemented * run correctly pipeline as MCP tools when they have only run_api_async implemented
1 parent dd6a405 commit a396d10

File tree

6 files changed

+98
-31
lines changed

6 files changed

+98
-31
lines changed

src/hayhooks/cli/mcp.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import typer
32
import uvicorn
43
import sys

src/hayhooks/server/utils/deploy_utils.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -423,10 +423,28 @@ def add_pipeline_to_registry(
423423
clog.debug("Running setup()")
424424
pipeline_wrapper.setup()
425425

426-
docstring = docstring_parser.parse(inspect.getdoc(pipeline_wrapper.run_api) or "")
426+
# Determine which run_api method to use for creating request model (prefer async if available)
427+
if pipeline_wrapper._is_run_api_async_implemented:
428+
run_method_to_inspect = pipeline_wrapper.run_api_async
429+
clog.debug("Using `run_api_async` for metadata creation.")
430+
elif pipeline_wrapper._is_run_api_implemented:
431+
run_method_to_inspect = pipeline_wrapper.run_api
432+
clog.debug("Using `run_api` for metadata creation.")
433+
else:
434+
# If neither run_api nor run_api_async is implemented, skip creating request model
435+
run_method_to_inspect = None
436+
clog.debug("No run_api method implemented, skipping request model creation.")
437+
438+
if run_method_to_inspect:
439+
docstring = docstring_parser.parse(inspect.getdoc(run_method_to_inspect) or "")
440+
request_model = create_request_model_from_callable(run_method_to_inspect, f'{pipeline_name}Run', docstring)
441+
else:
442+
docstring = docstring_parser.Docstring()
443+
request_model = None
444+
427445
metadata = {
428446
"description": docstring.short_description or "",
429-
"request_model": create_request_model_from_callable(pipeline_wrapper.run_api, f'{pipeline_name}Run', docstring),
447+
"request_model": request_model,
430448
"skip_mcp": pipeline_wrapper.skip_mcp,
431449
}
432450

src/hayhooks/server/utils/mcp_utils.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import traceback
23
from enum import Enum
34
from pathlib import Path
45
from typing import TYPE_CHECKING, List, Union
@@ -13,8 +14,10 @@
1314
)
1415
from hayhooks.settings import settings
1516
from hayhooks.server.pipelines import registry
17+
from hayhooks.server.pipelines.registry import PipelineType
1618
from haystack.lazy_imports import LazyImport
1719
from hayhooks.server.routers.deploy import PipelineFilesRequest
20+
from fastapi.concurrency import run_in_threadpool
1821

1922

2023
with LazyImport("Run 'pip install \"mcp\"' to install MCP.") as mcp_import:
@@ -133,12 +136,21 @@ async def run_pipeline_as_tool(
133136
mcp_import.check()
134137

135138
log.debug(f"Calling pipeline as tool '{name}' with arguments: {arguments}")
136-
pipeline_wrapper: Union[BasePipelineWrapper, None] = registry.get(name)
139+
pipeline: Union[PipelineType, None] = registry.get(name)
137140

138-
if not pipeline_wrapper:
141+
if not pipeline:
139142
raise ValueError(f"Pipeline '{name}' not found")
140143

141-
result = await asyncio.to_thread(pipeline_wrapper.run_api, **arguments)
144+
# Only BasePipelineWrapper instances support run_api/run_api_async methods
145+
if not isinstance(pipeline, BasePipelineWrapper):
146+
raise ValueError(f"Pipeline '{name}' is not a BasePipelineWrapper and cannot be used as an MCP tool")
147+
148+
# Use the same async/sync pattern as in deploy_utils.py
149+
if pipeline._is_run_api_async_implemented:
150+
result = await pipeline.run_api_async(**arguments)
151+
else:
152+
result = await run_in_threadpool(pipeline.run_api, **arguments)
153+
142154
log.trace(f"Pipeline '{name}' returned result: {result}")
143155

144156
return [TextContent(text=result, type="text")]
@@ -206,6 +218,8 @@ async def call_tool(name: str, arguments: dict) -> List[TextContent | ImageConte
206218

207219
except Exception as exc:
208220
msg = f"General unhandled error in call_tool for tool '{name}': {exc}"
221+
if settings.show_tracebacks:
222+
msg += f"\n{traceback.format_exc()}"
209223
log.error(msg)
210224
raise Exception(msg) from exc
211225

tests/test_deploy_utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
create_pipeline_wrapper_instance,
1616
deploy_pipeline_files,
1717
undeploy_pipeline,
18+
add_pipeline_to_registry,
1819
)
1920
from hayhooks.server.exceptions import (
2021
PipelineFilesError,
@@ -378,3 +379,29 @@ def test_undeploy_pipeline_without_app(test_settings):
378379

379380
# 4. Assert pipeline files are deleted
380381
assert not pipeline_dir.exists()
382+
383+
384+
def test_add_pipeline_to_registry_with_async_run_api():
385+
pipeline_name = "async_question_answer"
386+
pipeline_wrapper_path = Path("tests/test_files/files/async_question_answer/pipeline_wrapper.py")
387+
pipeline_yml_path = Path("tests/test_files/files/async_question_answer/question_answer.yml")
388+
files = {
389+
"pipeline_wrapper.py": pipeline_wrapper_path.read_text(),
390+
"question_answer.yml": pipeline_yml_path.read_text(),
391+
}
392+
393+
pipeline_wrapper = add_pipeline_to_registry(pipeline_name=pipeline_name, files=files, save_files=False)
394+
assert registry.get(pipeline_name) == pipeline_wrapper
395+
396+
metadata = registry.get_metadata(pipeline_name)
397+
assert metadata is not None
398+
assert "request_model" in metadata
399+
assert metadata["request_model"] is not None
400+
401+
assert pipeline_wrapper._is_run_api_async_implemented is True
402+
assert pipeline_wrapper._is_run_api_implemented is False
403+
404+
request_model = metadata["request_model"]
405+
schema = request_model.model_json_schema()
406+
assert "question" in schema["properties"]
407+
assert schema["properties"]["question"]["type"] == "string"

tests/test_files/files/async_question_answer/pipeline_wrapper.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,32 +21,15 @@ def setup(self) -> None:
2121
async def run_api_async(self, question: str) -> str:
2222
log.trace(f"Running pipeline with question: {question}")
2323

24-
result = await self.pipeline.run_async(
25-
{
26-
"prompt_builder": {
27-
"template": [
28-
ChatMessage.from_system(SYSTEM_MESSAGE),
29-
ChatMessage.from_user(question),
30-
]
31-
}
32-
}
33-
)
34-
return result["llm"]["replies"][0].text
24+
return "This is a mock response from the pipeline"
3525

3626
async def run_chat_completion_async(self, model: str, messages: List[dict], body: dict) -> AsyncGenerator:
3727
log.trace(f"Running pipeline with model: {model}, messages: {messages}, body: {body}")
3828

39-
question = get_last_user_message(messages)
40-
log.trace(f"Question: {question}")
41-
42-
return async_streaming_generator(
43-
pipeline=self.pipeline,
44-
pipeline_run_args={
45-
"prompt_builder": {
46-
"template": [
47-
ChatMessage.from_system(SYSTEM_MESSAGE),
48-
ChatMessage.from_user(question),
49-
]
50-
},
51-
},
52-
)
29+
mock_response = "This is a mock response from the pipeline"
30+
31+
async def mock_generator():
32+
for word in mock_response.split():
33+
yield word + " "
34+
35+
return mock_generator()

tests/test_it_mcp_server.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,19 @@ def deploy_chat_with_website_mcp_pipeline():
4343
return pipeline_name
4444

4545

46+
@pytest.fixture
47+
def deploy_async_question_answer_mcp_pipeline():
48+
pipeline_name = "async_question_answer"
49+
pipeline_wrapper_path = Path("tests/test_files/files/async_question_answer/pipeline_wrapper.py")
50+
pipeline_yml_path = Path("tests/test_files/files/async_question_answer/question_answer.yml")
51+
files = {
52+
"pipeline_wrapper.py": pipeline_wrapper_path.read_text(),
53+
"question_answer.yml": pipeline_yml_path.read_text(),
54+
}
55+
add_pipeline_to_registry(pipeline_name=pipeline_name, files=files)
56+
return pipeline_name
57+
58+
4659
@pytest.fixture
4760
def mcp_server_instance() -> "Server":
4861
return create_mcp_server()
@@ -105,6 +118,19 @@ async def test_call_pipeline_as_tool(mcp_server_instance, deploy_chat_with_websi
105118
assert result.content == [TextContent(type="text", text="This is a mock response from the pipeline")]
106119

107120

121+
@pytest.mark.asyncio
122+
async def test_call_async_pipeline_as_tool(mcp_server_instance, deploy_async_question_answer_mcp_pipeline):
123+
async with client_session(mcp_server_instance) as client:
124+
result = await client.call_tool(
125+
deploy_async_question_answer_mcp_pipeline, {"question": "What is the capital of France?"}
126+
)
127+
128+
assert isinstance(result, CallToolResult)
129+
130+
# In the deployed pipeline, the response is mocked
131+
assert result.content == [TextContent(type="text", text="This is a mock response from the pipeline")]
132+
133+
108134
@pytest.mark.asyncio
109135
async def test_call_pipeline_as_tool_with_invalid_arguments(mcp_server_instance, deploy_chat_with_website_mcp_pipeline):
110136
async with client_session(mcp_server_instance) as client:

0 commit comments

Comments
 (0)