Skip to content

Commit

Permalink
Merge pull request #17 from modelcontextprotocol/davidsp/workflows
Browse files Browse the repository at this point in the history
Github workflows for ruff and pyright
  • Loading branch information
dsp-ant authored Oct 14, 2024
2 parents 9475815 + 211b5f0 commit db5ca59
Show file tree
Hide file tree
Showing 20 changed files with 823 additions and 107 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/check-format.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: ruff

on:
push:
branches: [main]
pull_request:

jobs:
format:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v3
with:
enable-cache: true

- name: "Set up Python"
uses: actions/setup-python@v5
with:
python-version-file: ".python-version"

- name: Install the project
run: uv sync --frozen --all-extras --dev

- name: Run ruff format check
run: uv run --frozen ruff check .
29 changes: 29 additions & 0 deletions .github/workflows/check-types.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: typecheck

on:
push:
branches: [main]
pull_request:

jobs:
typecheck:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v3
with:
enable-cache: true

- name: "Set up Python"
uses: actions/setup-python@v5
with:
python-version-file: ".python-version"

- name: Install the project
run: uv sync --frozen --all-extras --dev

- name: Run pyright
run: uv run --frozen pyright
19 changes: 14 additions & 5 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,19 @@ jobs:

steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5

- name: Install uv
uses: astral-sh/setup-uv@v3
with:
enable-cache: true

- name: "Set up Python"
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version-file: ".python-version"

- name: Install the project
run: uv sync --frozen --all-extras --dev

- run: pip install .
- run: pip install -U pytest trio
- run: pytest
- name: Run pytest
run: uv run --frozen pytest
4 changes: 3 additions & 1 deletion mcp_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
ReadResourceResult,
Resource,
ResourceUpdatedNotification,
Role as SamplingRole,
SamplingMessage,
ServerCapabilities,
ServerNotification,
Expand All @@ -49,6 +48,9 @@
Tool,
UnsubscribeRequest,
)
from .types import (
Role as SamplingRole,
)

__all__ = [
"CallToolRequest",
Expand Down
3 changes: 2 additions & 1 deletion mcp_python/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ async def initialize(self) -> InitializeResult:

if result.protocolVersion != SUPPORTED_PROTOCOL_VERSION:
raise RuntimeError(
f"Unsupported protocol version from the server: {result.protocolVersion}"
"Unsupported protocol version from the server: "
f"{result.protocolVersion}"
)

await self.send_notification(
Expand Down
24 changes: 19 additions & 5 deletions mcp_python/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@ def remove_request_params(url: str) -> str:


@asynccontextmanager
async def sse_client(url: str, headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5):
async def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5,
sse_read_timeout: float = 60 * 5,
):
"""
Client transport for SSE.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
"""
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
Expand Down Expand Up @@ -67,7 +73,10 @@ async def sse_reader(
or url_parsed.scheme
!= endpoint_parsed.scheme
):
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
error_msg = (
"Endpoint origin does not match "
f"connection origin: {endpoint_url}"
)
logger.error(error_msg)
raise ValueError(error_msg)

Expand Down Expand Up @@ -104,11 +113,16 @@ async def post_writer(endpoint_url: str):
logger.debug(f"Sending client message: {message}")
response = await client.post(
endpoint_url,
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
json=message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(
f"Client message sent successfully: {response.status_code}"
"Client message sent successfully: "
f"{response.status_code}"
)
except Exception as exc:
logger.error(f"Error in post_writer: {exc}")
Expand Down
3 changes: 2 additions & 1 deletion mcp_python/client/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class StdioServerParameters(BaseModel):
@asynccontextmanager
async def stdio_client(server: StdioServerParameters):
"""
Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout.
Client transport for stdio: this will connect to a server by spawning a
process and communicating with it over stdin/stdout.
"""
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
Expand Down
57 changes: 33 additions & 24 deletions mcp_python/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ def __init__(self, name: str):

def create_initialization_options(self) -> types.InitializationOptions:
"""Create initialization options from this server instance."""

def pkg_version(package: str) -> str:
try:
from importlib.metadata import version

return version(package)
except Exception:
return "unknown"
Expand All @@ -69,16 +71,17 @@ def pkg_version(package: str) -> str:
)

def get_capabilities(self) -> ServerCapabilities:
"""Convert existing handlers to a ServerCapabilities object."""
def get_capability(req_type: type) -> dict[str, Any] | None:
return {} if req_type in self.request_handlers else None
"""Convert existing handlers to a ServerCapabilities object."""

return ServerCapabilities(
prompts=get_capability(ListPromptsRequest),
resources=get_capability(ListResourcesRequest),
tools=get_capability(ListPromptsRequest),
logging=get_capability(SetLevelRequest)
)
def get_capability(req_type: type) -> dict[str, Any] | None:
return {} if req_type in self.request_handlers else None

return ServerCapabilities(
prompts=get_capability(ListPromptsRequest),
resources=get_capability(ListResourcesRequest),
tools=get_capability(ListPromptsRequest),
logging=get_capability(SetLevelRequest),
)

@property
def request_context(self) -> RequestContext:
Expand All @@ -87,7 +90,7 @@ def request_context(self) -> RequestContext:

def list_prompts(self):
def decorator(func: Callable[[], Awaitable[list[Prompt]]]):
logger.debug(f"Registering handler for PromptListRequest")
logger.debug("Registering handler for PromptListRequest")

async def handler(_: Any):
prompts = await func()
Expand All @@ -103,17 +106,19 @@ def get_prompt(self):
GetPromptRequest,
GetPromptResult,
ImageContent,
Role as Role,
SamplingMessage,
TextContent,
)
from mcp_python.types import (
Role as Role,
)

def decorator(
func: Callable[
[str, dict[str, str] | None], Awaitable[types.PromptResponse]
],
):
logger.debug(f"Registering handler for GetPromptRequest")
logger.debug("Registering handler for GetPromptRequest")

async def handler(req: GetPromptRequest):
prompt_get = await func(req.params.name, req.params.arguments)
Expand Down Expand Up @@ -149,7 +154,7 @@ async def handler(req: GetPromptRequest):

def list_resources(self):
def decorator(func: Callable[[], Awaitable[list[Resource]]]):
logger.debug(f"Registering handler for ListResourcesRequest")
logger.debug("Registering handler for ListResourcesRequest")

async def handler(_: Any):
resources = await func()
Expand All @@ -169,7 +174,7 @@ def read_resource(self):
)

def decorator(func: Callable[[AnyUrl], Awaitable[str | bytes]]):
logger.debug(f"Registering handler for ReadResourceRequest")
logger.debug("Registering handler for ReadResourceRequest")

async def handler(req: ReadResourceRequest):
result = await func(req.params.uri)
Expand Down Expand Up @@ -204,7 +209,7 @@ def set_logging_level(self):
from mcp_python.types import EmptyResult

def decorator(func: Callable[[LoggingLevel], Awaitable[None]]):
logger.debug(f"Registering handler for SetLevelRequest")
logger.debug("Registering handler for SetLevelRequest")

async def handler(req: SetLevelRequest):
await func(req.params.level)
Expand All @@ -219,7 +224,7 @@ def subscribe_resource(self):
from mcp_python.types import EmptyResult

def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
logger.debug(f"Registering handler for SubscribeRequest")
logger.debug("Registering handler for SubscribeRequest")

async def handler(req: SubscribeRequest):
await func(req.params.uri)
Expand All @@ -234,7 +239,7 @@ def unsubscribe_resource(self):
from mcp_python.types import EmptyResult

def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
logger.debug(f"Registering handler for UnsubscribeRequest")
logger.debug("Registering handler for UnsubscribeRequest")

async def handler(req: UnsubscribeRequest):
await func(req.params.uri)
Expand All @@ -249,7 +254,7 @@ def call_tool(self):
from mcp_python.types import CallToolResult

def decorator(func: Callable[..., Awaitable[Any]]):
logger.debug(f"Registering handler for CallToolRequest")
logger.debug("Registering handler for CallToolRequest")

async def handler(req: CallToolRequest):
result = await func(req.params.name, **(req.params.arguments or {}))
Expand All @@ -264,7 +269,7 @@ def progress_notification(self):
def decorator(
func: Callable[[str | int, float, float | None], Awaitable[None]],
):
logger.debug(f"Registering handler for ProgressNotification")
logger.debug("Registering handler for ProgressNotification")

async def handler(req: ProgressNotification):
await func(
Expand All @@ -286,7 +291,7 @@ def decorator(
Awaitable[Completion | None],
],
):
logger.debug(f"Registering handler for CompleteRequest")
logger.debug("Registering handler for CompleteRequest")

async def handler(req: CompleteRequest):
completion = await func(req.params.ref, req.params.argument)
Expand All @@ -307,10 +312,12 @@ async def run(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
initialization_options: types.InitializationOptions
initialization_options: types.InitializationOptions,
):
with warnings.catch_warnings(record=True) as w:
async with ServerSession(read_stream, write_stream, initialization_options) as session:
async with ServerSession(
read_stream, write_stream, initialization_options
) as session:
async for message in session.incoming_messages:
logger.debug(f"Received message: {message}")

Expand Down Expand Up @@ -359,14 +366,16 @@ async def run(

handler = self.notification_handlers[type(notify)]
logger.debug(
f"Dispatching notification of type {type(notify).__name__}"
f"Dispatching notification of type "
f"{type(notify).__name__}"
)

try:
await handler(notify)
except Exception as err:
logger.error(
f"Uncaught exception in notification handler: {err}"
f"Uncaught exception in notification handler: "
f"{err}"
)

for warning in w:
Expand Down
18 changes: 15 additions & 3 deletions mcp_python/server/__main__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import importlib.metadata
import logging
import sys
import importlib.metadata

import anyio

from mcp_python.server.session import ServerSession
from mcp_python.server.types import InitializationOptions
from mcp_python.server.stdio import stdio_server
from mcp_python.server.types import InitializationOptions
from mcp_python.types import ServerCapabilities

if not sys.warnoptions:
Expand All @@ -30,7 +31,18 @@ async def receive_loop(session: ServerSession):
async def main():
version = importlib.metadata.version("mcp_python")
async with stdio_server() as (read_stream, write_stream):
async with ServerSession(read_stream, write_stream, InitializationOptions(server_name="mcp_python", server_version=version, capabilities=ServerCapabilities())) as session, write_stream:
async with (
ServerSession(
read_stream,
write_stream,
InitializationOptions(
server_name="mcp_python",
server_version=version,
capabilities=ServerCapabilities(),
),
) as session,
write_stream,
):
await receive_loop(session)


Expand Down
Loading

0 comments on commit db5ca59

Please sign in to comment.