Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import List, Optional, TypedDict

from integration_tests.subroutes import di_subrouter, static_router, sub_router
from robyn import Headers, Request, Response, Robyn, SSEMessage, SSEResponse, WebSocketDisconnect, jsonify, serve_file, serve_html
from robyn import Headers, Request, Response, Robyn, SSEMessage, SSEResponse, StreamingResponse, WebSocketDisconnect, jsonify, serve_file, serve_html
from robyn.authentication import AuthenticationHandler, BearerGetter, Identity
from robyn.robyn import QueryParams, Url
from robyn.templating import JinjaTemplate
Expand Down Expand Up @@ -1646,6 +1646,64 @@ def sync_pydantic_return_list(user: UserCreate) -> list[UserCreate]:
async def async_pydantic_return_list(user: UserCreate) -> list[UserCreate]:
return [user, user]

# --- Binary streaming endpoints ---


@app.get("/stream/bytes")
def stream_bytes(request):
"""Stream binary data using bytes chunks"""

def bytes_generator():
# Generate 3 chunks of known binary data
for i in range(3):
yield bytes([i] * 1024) # 1KB chunks filled with the chunk index

return StreamingResponse(
content=bytes_generator(),
media_type="application/octet-stream",
headers=Headers({"Content-Type": "application/octet-stream"}),
)


@app.get("/stream/bytes_file")
def stream_bytes_file(request):
"""Stream a file in binary mode using yield from"""
test_file = os.path.join(current_file_path, "build", "index.html")

def file_generator():
with open(test_file, "rb") as f:
while True:
chunk = f.read(512)
if not chunk:
break
yield chunk

return StreamingResponse(
content=file_generator(),
media_type="application/octet-stream",
headers=Headers(
{
"Content-Type": "application/octet-stream",
"Content-Disposition": "attachment; filename=index.html",
}
),
)


@app.get("/stream/mixed_text")
def stream_mixed_text(request):
"""Stream text data using string chunks (ensures str still works)"""

def text_generator():
for i in range(3):
yield f"text chunk {i}\n"

return StreamingResponse(
content=text_generator(),
media_type="text/plain",
headers=Headers({"Content-Type": "text/plain"}),
)


def main():
app.set_response_header("server", "robyn")
Expand Down
79 changes: 79 additions & 0 deletions integration_tests/test_binary_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import os

import pytest
import requests

from integration_tests.helpers.http_methods_helpers import BASE_URL


@pytest.mark.benchmark
def test_stream_bytes_basic(session):
"""Test that binary bytes can be streamed without error"""
response = requests.get(f"{BASE_URL}/stream/bytes", stream=True, timeout=5)
assert response.status_code == 200
assert response.headers.get("Content-Type") == "application/octet-stream"

# Collect all streamed data
data = b""
for chunk in response.iter_content(chunk_size=None):
if chunk:
data += chunk

# We expect 3 chunks of 1024 bytes each
assert len(data) == 3 * 1024

# Verify chunk contents: chunk i is filled with byte value i
for i in range(3):
chunk = data[i * 1024 : (i + 1) * 1024]
assert chunk == bytes([i] * 1024), f"Chunk {i} has unexpected content"


@pytest.mark.benchmark
def test_stream_bytes_no_sse_headers(session):
"""Test that binary streaming responses do NOT include SSE-specific headers"""
response = requests.get(f"{BASE_URL}/stream/bytes", stream=True, timeout=5)
assert response.status_code == 200

# SSE-specific headers should NOT be present for binary streams
assert response.headers.get("X-Accel-Buffering") is None
assert response.headers.get("Pragma") is None
assert response.headers.get("Expires") is None


@pytest.mark.benchmark
def test_stream_bytes_file(session):
"""Test streaming a file in binary mode"""
response = requests.get(f"{BASE_URL}/stream/bytes_file", stream=True, timeout=5)
assert response.status_code == 200
assert response.headers.get("Content-Type") == "application/octet-stream"
assert "attachment" in response.headers.get("Content-Disposition", "")

# Collect all streamed data
streamed_data = b""
for chunk in response.iter_content(chunk_size=None):
if chunk:
streamed_data += chunk

# Read the original file to compare
test_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "build", "index.html")
with open(test_file, "rb") as f:
original_data = f.read()

assert streamed_data == original_data, "Streamed file content does not match original"


@pytest.mark.benchmark
def test_stream_text_still_works(session):
"""Test that string-based streaming still works after the bytes change"""
response = requests.get(f"{BASE_URL}/stream/mixed_text", stream=True, timeout=5)
assert response.status_code == 200
assert response.headers.get("Content-Type") == "text/plain"

content = b""
for chunk in response.iter_content(chunk_size=None):
if chunk:
content += chunk

text = content.decode("utf-8")
for i in range(3):
assert f"text chunk {i}" in text
12 changes: 9 additions & 3 deletions robyn/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def serve_file(file_path: str, file_name: Optional[str] = None) -> FileResponse:
class AsyncGeneratorWrapper:
"""Optimized true-streaming wrapper for async generators"""

def __init__(self, async_gen: AsyncGenerator[str, None]):
def __init__(self, async_gen: AsyncGenerator[Union[str, bytes], None]):
self.async_gen = async_gen
self._loop = None
self._iterator = None
Expand Down Expand Up @@ -124,7 +124,10 @@ async def get_next():
class StreamingResponse:
def __init__(
self,
content: Union[Generator[str, None, None], AsyncGenerator[str, None]],
content: Union[
Generator[Union[str, bytes], None, None],
AsyncGenerator[Union[str, bytes], None],
],
status_code: Optional[int] = None,
headers: Optional[Headers] = None,
media_type: str = "text/event-stream",
Expand All @@ -149,7 +152,10 @@ def __init__(


def SSEResponse(
content: Union[Generator[str, None, None], AsyncGenerator[str, None]],
content: Union[
Generator[Union[str, bytes], None, None],
AsyncGenerator[Union[str, bytes], None],
],
status_code: Optional[int] = None,
headers: Optional[Headers] = None,
) -> StreamingResponse:
Expand Down
65 changes: 44 additions & 21 deletions src/types/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub struct StreamingResponse {
pub status_code: u16,
pub headers: Headers,
pub content_generator: Py<PyAny>,
pub media_type: String,
}

#[derive(Debug)]
Expand Down Expand Up @@ -85,11 +86,17 @@ impl Responder for Response {
}

impl StreamingResponse {
pub fn new(status_code: u16, headers: Headers, content_generator: Py<PyAny>) -> Self {
pub fn new(
status_code: u16,
headers: Headers,
content_generator: Py<PyAny>,
media_type: String,
) -> Self {
Self {
status_code,
headers,
content_generator,
media_type,
}
}
}
Expand All @@ -104,13 +111,25 @@ impl Responder for StreamingResponse {

apply_hashmap_headers(&mut response_builder, &self.headers);

// Optimized headers for SSE streaming
response_builder
.append_header(("Connection", "keep-alive"))
.append_header(("X-Accel-Buffering", "no")) // Disable nginx buffering
.append_header(("Cache-Control", "no-cache, no-store, must-revalidate"))
.append_header(("Pragma", "no-cache"))
.append_header(("Expires", "0"));
// Only add SSE-specific headers for event-stream responses if not already present
if self.media_type == "text/event-stream" {
if !self.headers.contains("Connection".to_string()) {
response_builder.append_header(("Connection", "keep-alive"));
}
if !self.headers.contains("X-Accel-Buffering".to_string()) {
response_builder.append_header(("X-Accel-Buffering", "no")); // Disable nginx buffering
}
if !self.headers.contains("Cache-Control".to_string()) {
response_builder
.append_header(("Cache-Control", "no-cache, no-store, must-revalidate"));
}
if !self.headers.contains("Pragma".to_string()) {
response_builder.append_header(("Pragma", "no-cache"));
}
if !self.headers.contains("Expires".to_string()) {
response_builder.append_header(("Expires", "0"));
}
}

// Create the optimized stream from the Python generator
let stream = create_python_stream(self.content_generator);
Expand All @@ -129,7 +148,15 @@ fn create_python_stream(
let gen = generator.bind(py);

match gen.call_method0("__next__") {
Ok(value) => value.extract::<String>().ok().map(|s| (s, generator)),
Ok(value) => {
if let Ok(py_bytes) = value.downcast::<PyBytes>() {
Some((py_bytes.as_bytes().to_vec(), generator))
} else if let Ok(s) = value.extract::<String>() {
Some((s.into_bytes(), generator))
} else {
None
}
}
Err(e) => {
if !e.is_instance_of::<pyo3::exceptions::PyStopIteration>(py) {
log::error!("Generator error: {}", e);
Expand All @@ -141,7 +168,7 @@ fn create_python_stream(
})
.await
{
Ok(Some((string_value, generator))) => Some((Ok(Bytes::from(string_value)), generator)),
Ok(Some((data, generator))) => Some((Ok(Bytes::from(data)), generator)),
_ => None,
}
}))
Expand Down Expand Up @@ -282,7 +309,6 @@ impl PyStreamingResponse {
let mut headers = Headers::new(None);
if media_type == "text/event-stream" {
headers.set("Content-Type".to_string(), "text/event-stream".to_string());
headers.set("Cache-Control".to_string(), "no-cache".to_string());
headers.set("Connection".to_string(), "keep-alive".to_string());
} else {
// For non-SSE streaming responses, still set appropriate headers
Expand Down Expand Up @@ -443,18 +469,15 @@ impl FromPyObject<'_, '_> for StreamingResponse {
.and_then(|a| a.extract())
.unwrap_or_else(|_| "text/event-stream".to_string());

if media_type == "text/event-stream" {
headers.set("Content-Type".to_string(), "text/event-stream".to_string());
if headers.get("Cache-Control".to_string()).is_none() {
headers.set("Cache-Control".to_string(), "no-cache".to_string());
}
if headers.get("Connection".to_string()).is_none() {
headers.set("Connection".to_string(), "keep-alive".to_string());
}
}
headers.set("Content-Type".to_string(), media_type.clone());

let content: pyo3::Py<PyAny> = obj.getattr("content")?.unbind();

Ok(StreamingResponse::new(status_code, headers, content))
Ok(StreamingResponse::new(
status_code,
headers,
content,
media_type,
))
}
}
Loading