diff --git a/server.py b/server.py index 6e592d8..41f0e9b 100644 --- a/server.py +++ b/server.py @@ -3,7 +3,7 @@ import numpy as np import uvicorn from pydantic_settings import BaseSettings -from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect +from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect, HTTPException from datetime import datetime import msgpack import asyncio @@ -15,6 +15,7 @@ class Settings(BaseSettings): redis_url: str = "redis://localhost:6379/0" ttl: int = 60 * 60 # 1 hour + max_payload_size: int = 16 * 1024 * 1024 # 16MB max payload def build_app(settings: Settings): @@ -45,21 +46,20 @@ async def create(): await redis_client.setnx(f"seq_num:{node_id}", 0) return {"node_id": node_id} - @app.delete("/upload/{node_id}", status_code=204) - async def close(node_id): - "Declare that a dataset is done streaming." - - await redis_client.delete(f"seq_num:{node_id}") - # TODO: Shorten TTL on all extant data for this node. - return None - @app.post("/upload/{node_id}") async def append(node_id, request: Request): "Append data to a dataset." + # Check request body size limit + # Tell good-faith clients that their request is too big. + # Fix for: test_large_data_resource.py::test_large_data_resource_limits + headers = request.headers + content_length = headers.get("content-length") + if content_length and int(content_length) > settings.max_payload_size: + raise HTTPException(status_code=413, detail="Payload too large") + # get data from request body binary_data = await request.body() - headers = request.headers metadata = { "timestamp": datetime.now().isoformat(), } @@ -85,16 +85,23 @@ async def append(node_id, request: Request): # TODO: Implement two-way communication with subscribe, unsubscribe, flow control. # @app.websocket("/stream/many") - @app.post("/close/{node_id}") + @app.delete("/close/{node_id}") async def close_connection(node_id: str, request: Request): - # Parse the JSON body - body = await request.json() headers = request.headers - reason = body.get("reason", None) - - metadata = {"timestamp": datetime.now().isoformat(), "reason": reason} + # Check the node status. + # ttl returns -2 if the key does not exist. + # ttl returns -1 if the key exists but has no associated expire. + # ttl greater than 0 means that it is marked to expire. + node_ttl = await redis_client.ttl(f"seq_num:{node_id}") + if node_ttl > 0: + raise HTTPException(status_code=404, detail=f"Node expiring in {node_ttl} seconds") + if node_ttl == -2: + raise HTTPException(status_code=404, detail="Node not found") + + metadata = {"timestamp": datetime.now().isoformat()} metadata.setdefault("Content-Type", headers.get("Content-Type")) + # Increment the counter for this node. seq_num = await redis_client.incr(f"seq_num:{node_id}") @@ -109,12 +116,12 @@ async def close_connection(node_id: str, request: Request): }, ) pipeline.expire(f"data:{node_id}:{seq_num}", settings.ttl) + pipeline.expire(f"seq_num:{node_id}", settings.ttl) pipeline.publish(f"notify:{node_id}", seq_num) await pipeline.execute() return { "status": f"Connection for node {node_id} is now closed.", - "reason": reason, } @app.websocket("/stream/single/{node_id}") # one-way communcation @@ -124,6 +131,10 @@ async def websocket_endpoint( envelope_format: str = "json", seq_num: Optional[int] = None, ): + # Check if the node is streamable before accepting the websocket connection + if not await redis_client.exists(f"seq_num:{node_id}"): + raise HTTPException(status_code=404, detail="Node not found") + await websocket.accept( headers=[(b"x-server-host", socket.gethostname().encode())] ) diff --git a/tests/conftest.py b/tests/conftest.py index 91b9111..9432719 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,6 @@ def client(): """Fixture providing TestClient following ws-tests pattern.""" settings = Settings(redis_url="redis://localhost:6379/0", ttl=60 * 60) app = build_app(settings) - + with TestClient(app) as client: yield client - diff --git a/tests/test_close_endpoint.py b/tests/test_close_endpoint.py new file mode 100644 index 0000000..1b6926a --- /dev/null +++ b/tests/test_close_endpoint.py @@ -0,0 +1,37 @@ +""" +Tests for the close endpoint. +""" + +def test_close_connection_success(client): + """Test successful close of an existing connection.""" + # First create a node + response = client.post("/upload") + assert response.status_code == 200 + node_id = response.json()["node_id"] + + # Upload some data + response = client.post( + f"/upload/{node_id}", + content=b"test data", + headers={"Content-Type": "application/octet-stream"}, + ) + assert response.status_code == 200 + + # Now close the connection + response = client.delete(f"/close/{node_id}") + assert response.status_code == 200 + assert response.json()["status"] == f"Connection for node {node_id} is now closed." + + # Now close the connection again. + response = client.delete(f"/close/{node_id}") + assert response.status_code == 404 + + +def test_close_connection_not_found(client): + """Test close endpoint returns 404 for non-existent node.""" + non_existent_node_id = "definitely_non_existent_node_99999999" + + response = client.delete(f"/close/{non_existent_node_id}") + assert response.status_code == 404 + assert response.json()["detail"] == "Node not found" + diff --git a/tests/test_large_data_resource.py b/tests/test_large_data_resource.py new file mode 100644 index 0000000..5c0986c --- /dev/null +++ b/tests/test_large_data_resource.py @@ -0,0 +1,22 @@ +""" +Tests for large data handling and resource limit bugs. +""" + + +def test_large_data_resource_limits(client): + """Server should handle large data with proper resource limits.""" + + # Test: Huge payload (20MB) - should be rejected as too large + response = client.post("/upload") + assert response.status_code == 200 + node_id = response.json()["node_id"] + + huge_payload = b"\x00" * (20 * 1024 * 1024) # 20MB (exceeds 16MB limit) + response = client.post( + f"/upload/{node_id}", + content=huge_payload, + headers={"Content-Type": "application/octet-stream"}, + ) + # Should be rejected with 413 Payload Too Large due to size limits + assert response.status_code == 413 + assert "Payload too large" in response.json()["detail"] diff --git a/tests/test_websocket_timing.py b/tests/test_websocket_timing.py index ea6a9fb..5969a81 100644 --- a/tests/test_websocket_timing.py +++ b/tests/test_websocket_timing.py @@ -2,6 +2,16 @@ import numpy as np +def test_websocket_connection_to_non_existent_node(client): + """Test websocket connection to non-existent node returns 404.""" + non_existent_node_id = "definitely_non_existent_websocket_node_99999999" + + # Try to connect to websocket for non-existent node + # This should result in an HTTP 404 response during the handshake + response = client.get(f"/stream/single/{non_existent_node_id}") + assert response.status_code == 404 + + def test_subscribe_immediately_after_creation_websockets(client): """Client that subscribes immediately after node creation sees all updates in order.""" # Create node