Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions aiokafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,12 @@ def send(self, request, expect_response=True):
f"No connection to broker at {self._host}:{self._port}"
)

if self._writer.is_closing():
self.close(reason=CloseReason.CONNECTION_BROKEN)
raise Errors.KafkaConnectionError(
f"Connection at {self._host}:{self._port} is closing"
)

correlation_id = self._next_correlation_id()
header = request.build_request_header(
correlation_id=correlation_id, client_id=self._client_id
Expand Down
1 change: 1 addition & 0 deletions requirements-ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ Pygments==2.18.0
gssapi==1.9.0
async-timeout==4.0.3
cramjam==2.9.0
uvloop==0.21.0
89 changes: 84 additions & 5 deletions tests/test_conn.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import asyncio
import gc
import socket
import struct
import sys
from collections.abc import AsyncIterable, Iterable
from typing import Any
from unittest import mock

import pytest
import pytest_asyncio

from aiokafka.conn import AIOKafkaConnection, VersionInfo, create_conn
from aiokafka.errors import (
Expand Down Expand Up @@ -144,7 +148,7 @@
with self.assertRaises(KafkaConnectionError):
await conn.send(request)

conn._writer = mock.MagicMock()
conn._writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
conn._writer.write.side_effect = OSError("mocked writer is closed")

with self.assertRaises(KafkaConnectionError):
Expand Down Expand Up @@ -173,7 +177,7 @@
return resp

reader.readexactly.side_effect = [first_resp(), second_resp()]
writer = mock.MagicMock()
writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))

conn._reader = reader
conn._writer = writer
Expand Down Expand Up @@ -208,7 +212,7 @@
return resp

reader.readexactly.side_effect = [first_resp(), second_resp()]
writer = mock.MagicMock()
writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))

conn._reader = reader
conn._writer = writer
Expand Down Expand Up @@ -237,7 +241,7 @@
# setup reader
reader = mock.MagicMock()
reader.readexactly.return_value = invoke_osserror()
writer = mock.MagicMock()
writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))

conn._reader = reader
conn._writer = writer
Expand Down Expand Up @@ -394,7 +398,7 @@
# setup connection with mocked transport and protocol
conn = AIOKafkaConnection(host="", port=9999)
conn.close = mock.MagicMock()
conn._writer = mock.MagicMock()
conn._writer = mock.MagicMock(is_closing=mock.Mock(return_value=False))
out_buffer = []
conn._writer.write = mock.Mock(side_effect=out_buffer.append)
conn._reader = mock.MagicMock()
Expand Down Expand Up @@ -424,3 +428,78 @@
conn._send_sasl_token(b"Super data")
# We don't need to close 2ce
self.assertEqual(conn.close.call_count, 1)


@pytest.mark.skipif(sys.platform == "win32", reason="Uvloop doesn't support Windows")
class TestClosedSocket:
@pytest.fixture(
params=(
pytest.param("asyncio", id="asyncio"),
pytest.param("uvloop", id="uvloop"),
),
)
def event_loop_policy(
self, request: pytest.FixtureRequest
) -> Iterable[asyncio.AbstractEventLoopPolicy]:
if request.param == "asyncio":
policy = asyncio.DefaultEventLoopPolicy()
elif request.param == "uvloop":
import uvloop

policy = uvloop.EventLoopPolicy()
else:
raise ValueError(f"loop {request.param} is not supported")

Check warning on line 451 in tests/test_conn.py

View check run for this annotation

Codecov / codecov/patch

tests/test_conn.py#L451

Added line #L451 was not covered by tests

yield policy

@pytest.fixture()
def server(self, unused_tcp_port: int) -> Iterable[tuple[str, int, socket.socket]]:
host = "localhost"
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind((host, unused_tcp_port))
sock.listen(8)
sock.setblocking(False)

yield host, unused_tcp_port, sock

sock.close()

@pytest_asyncio.fixture()
async def conn(
self, server: tuple[str, int, socket.socket]
) -> AsyncIterable[AIOKafkaConnection]:
host, port, _ = server

conn = AIOKafkaConnection(host=host, port=port, request_timeout_ms=1000)
conn._create_reader_task = mock.Mock()

yield conn

fut = conn.close()
if fut:
await fut

Check notice

Code scanning / CodeQL

Statement has no effect Note test

This statement has no effect.

@pytest.mark.asyncio
async def test_send_to_closed_socket(
self, server: tuple[str, int, socket.socket], conn: AIOKafkaConnection
) -> None:
host, port, sock = server

request = MetadataRequest([])

with pytest.raises(
KafkaConnectionError,
match=f"KafkaConnectionError: No connection to broker at {host}:{port}",
):
await conn.send(request)

await conn.connect()

sock.close()
await asyncio.sleep(0.1)

with pytest.raises(
KafkaConnectionError,
match=f"KafkaConnectionError: Connection at {host}:{port} is closing",
):
await conn.send(request)
Loading