diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 00000000..8833cbd9 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,23 @@ +name: Run SSE Test + +on: + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r tests/requirements-dev.txt + + - name: Run test file + run: pytest tests/test_sse_client_server_hardened.py \ No newline at end of file diff --git a/tests/requirements-dev.txt b/tests/requirements-dev.txt new file mode 100644 index 00000000..43720b1a --- /dev/null +++ b/tests/requirements-dev.txt @@ -0,0 +1,5 @@ +fastapi +httpx +uvicorn +sse-starlette +anyio \ No newline at end of file diff --git a/tests/test.yml b/tests/test.yml new file mode 100644 index 00000000..106a4f02 --- /dev/null +++ b/tests/test.yml @@ -0,0 +1,32 @@ +name: Run Tests + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install main project dependencies + run: | + pip install -r requirements.txt || true + + - name: Install dev dependencies + run: | + pip install -r requirements-dev.txt + + - name: Run standalone SSE client-server test + run: | + python tests/test_sse_client_server_plain.py diff --git a/tests/test_sse_client_server.py b/tests/test_sse_client_server.py new file mode 100644 index 00000000..ff169adc --- /dev/null +++ b/tests/test_sse_client_server.py @@ -0,0 +1,45 @@ +import asyncio +from typing import AsyncGenerator, List +from fastapi import FastAPI +from starlette.responses import StreamingResponse +import uvicorn +from threading import Thread +import httpx +from mcp.client.sse import aconnect_sse + +app = FastAPI() + +@app.get("/sse") +async def sse_endpoint() -> StreamingResponse: + async def event_stream() -> AsyncGenerator[str, None]: + for i in range(3): + yield f"data: Hello {i+1}\\n\\n" + await asyncio.sleep(0.1) + return StreamingResponse(event_stream(), media_type="text/event-stream") + +def run_mock_server() -> None: + uvicorn.run(app, host="127.0.0.1", port=8012, log_level="warning") + +async def run_sse_test() -> None: + server_thread = Thread(target=run_mock_server, daemon=True) + server_thread.start() + await asyncio.sleep(1) + + messages: List[str] = [] + async with httpx.AsyncClient() as client: + async with aconnect_sse(client, "GET", "http://127.0.0.1:8012/sse") as event_source: + async for event in event_source.aiter_sse(): + if event.data: + print("Event received:", event.data) + messages.append(event.data) + if len(messages) == 3: + break + + if messages == ["Hello 1", "Hello 2", "Hello 3"]: + print("\\n Test passed!") + else: + print("\\n Test failed:", messages) + +if __name__ == "__main__": + asyncio.run(run_sse_test()) + diff --git a/tests/test_sse_client_server_cleaned.py b/tests/test_sse_client_server_cleaned.py new file mode 100644 index 00000000..fdc0c879 --- /dev/null +++ b/tests/test_sse_client_server_cleaned.py @@ -0,0 +1,43 @@ +import asyncio +from typing import AsyncGenerator, List + +from fastapi import FastAPI +from starlette.responses import StreamingResponse +import uvicorn +from threading import Thread +import httpx +from mcp.client.sse import aconnect_sse + +# Required packages: fastapi, uvicorn, httpx, httpx-sse, sse-starlette, anyio + +app = FastAPI() + +@app.get("/sse") +async def sse_endpoint() -> StreamingResponse: + async def event_stream() -> AsyncGenerator[str, None]: + for i in range(3): + yield f"data: Hello {i+1}\n\n" + await asyncio.sleep(0.1) + return StreamingResponse(event_stream(), media_type="text/event-stream") + +def run_mock_server() -> None: + uvicorn.run(app, host="127.0.0.1", port=8012, log_level="warning") + +async def test_aconnect_sse_server_response() -> None: + server_thread = Thread(target=run_mock_server, daemon=True) + server_thread.start() + await asyncio.sleep(1) + + messages: List[str] = [] + + async with httpx.AsyncClient() as client: + async with aconnect_sse(client, "GET", "http://127.0.0.1:8012/sse") as event_source: + async for event in event_source.aiter_sse(): + if event.data: + print("Event received:", event.data) + messages.append(event.data) + if len(messages) == 3: + break + + assert messages == ["Hello 1", "Hello 2", "Hello 3"] + print("\n Test passed! SSE connection via aconnect_sse worked correctly.") diff --git a/tests/test_sse_client_server_hardened.py b/tests/test_sse_client_server_hardened.py new file mode 100644 index 00000000..356c41bb --- /dev/null +++ b/tests/test_sse_client_server_hardened.py @@ -0,0 +1,43 @@ +import asyncio +from typing import AsyncGenerator + +from fastapi import FastAPI +from starlette.responses import StreamingResponse +import uvicorn +from threading import Thread +import httpx +from mcp.client.sse import aconnect_sse + + +app = FastAPI() + + +@app.get("/sse") +async def sse_endpoint() -> StreamingResponse: + async def event_stream() -> AsyncGenerator[str, None]: + for i in range(3): + yield f"data: Hello {i+1}\n\n" + await asyncio.sleep(0.1) + return StreamingResponse(event_stream(), media_type="text/event-stream") + + +def run_mock_server() -> None: + uvicorn.run(app, host="127.0.0.1", port=8012, log_level="warning") + + +async def test_aconnect_sse_server_response() -> None: + server_thread = Thread(target=run_mock_server, daemon=True) + server_thread.start() + await asyncio.sleep(1) + + messages = [] + + async with httpx.AsyncClient() as client: + async with aconnect_sse(client, "GET", "http://127.0.0.1:8012/sse") as event_source: + async for event in event_source.aiter_sse(): + if event.data: + messages.append(event.data) + if len(messages) == 3: + break + + assert messages == ["Hello 1", "Hello 2", "Hello 3"] \ No newline at end of file diff --git a/tests/test_sse_client_server_plain.py b/tests/test_sse_client_server_plain.py new file mode 100644 index 00000000..e7982d49 --- /dev/null +++ b/tests/test_sse_client_server_plain.py @@ -0,0 +1,45 @@ +import asyncio +from typing import AsyncGenerator, List + +from fastapi import FastAPI +from starlette.responses import StreamingResponse +import uvicorn +from threading import Thread +import httpx +from mcp.client.sse import aconnect_sse + +app = FastAPI() + +@app.get("/sse") +async def sse_endpoint() -> StreamingResponse: + async def event_stream() -> AsyncGenerator[str, None]: + for i in range(3): + yield f"data: Hello {i+1}\n\n" + await asyncio.sleep(0.1) + return StreamingResponse(event_stream(), media_type="text/event-stream") + +def run_mock_server() -> None: + uvicorn.run(app, host="127.0.0.1", port=8012, log_level="warning") + +async def run_sse_test() -> None: + server_thread = Thread(target=run_mock_server, daemon=True) + server_thread.start() + await asyncio.sleep(1) + + messages: List[str] = [] + async with httpx.AsyncClient() as client: + async with aconnect_sse(client, "GET", "http://127.0.0.1:8012/sse") as event_source: + async for event in event_source.aiter_sse(): + if event.data: + print("Event received:", event.data) + messages.append(event.data) + if len(messages) == 3: + break + + if messages == ["Hello 1", "Hello 2", "Hello 3"]: + print("Test passed!") + else: + print("Test failed:", messages) + +if __name__ == "__main__": + asyncio.run(run_sse_test())