diff --git a/instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/__init__.py b/instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/__init__.py index 8c608b7655..5b3890714d 100644 --- a/instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/__init__.py @@ -137,6 +137,7 @@ _logger = logging.getLogger(__name__) _OTEL_CURSOR_FACTORY_KEY = "_otel_orig_cursor_factory" +_OTEL_SERVER_CURSOR_FACTORY_KEY = "_otel_orig_server_cursor_factory" class PsycopgInstrumentor(BaseInstrumentor): @@ -231,9 +232,15 @@ def instrument_connection(connection, tracer_provider=None): setattr( connection, _OTEL_CURSOR_FACTORY_KEY, connection.cursor_factory ) + setattr( + connection, _OTEL_SERVER_CURSOR_FACTORY_KEY, connection.server_cursor_factory + ) connection.cursor_factory = _new_cursor_factory( tracer_provider=tracer_provider ) + connection.server_cursor_factory = _new_cursor_factory( + tracer_provider=tracer_provider + ) connection._is_instrumented_by_opentelemetry = True else: _logger.warning( @@ -247,6 +254,9 @@ def uninstrument_connection(connection): connection.cursor_factory = getattr( connection, _OTEL_CURSOR_FACTORY_KEY, None ) + connection.server_cursor_factory = getattr( + connection, _OTEL_SERVER_CURSOR_FACTORY_KEY, None + ) return connection @@ -267,6 +277,11 @@ def wrapped_connection( kwargs["cursor_factory"] = _new_cursor_factory(**new_factory_kwargs) connection = connect_method(*args, **kwargs) self.get_connection_attributes(connection) + + connection.server_cursor_factory = _new_cursor_factory( + db_api=self, base_factory=getattr(connection, "server_cursor_factory", None) + ) + return connection @@ -287,6 +302,10 @@ async def wrapped_connection( ) connection = await connect_method(*args, **kwargs) self.get_connection_attributes(connection) + + connection.server_cursor_factory = _new_cursor_async_factory( + db_api=self, base_factory=getattr(connection, "server_cursor_factory", None) + ) return connection diff --git a/instrumentation/opentelemetry-instrumentation-psycopg/tests/test_psycopg_integration.py b/instrumentation/opentelemetry-instrumentation-psycopg/tests/test_psycopg_integration.py index 4ddaad9174..353a0511db 100644 --- a/instrumentation/opentelemetry-instrumentation-psycopg/tests/test_psycopg_integration.py +++ b/instrumentation/opentelemetry-instrumentation-psycopg/tests/test_psycopg_integration.py @@ -13,6 +13,7 @@ # limitations under the License. import types +from typing import Optional from unittest import IsolatedAsyncioTestCase, mock import psycopg @@ -83,10 +84,14 @@ class MockConnection: def __init__(self, *args, **kwargs): self.cursor_factory = kwargs.pop("cursor_factory", None) + self.server_cursor_factory = None - def cursor(self): - if self.cursor_factory: + def cursor(self, name: Optional[str] = None): + if not name and self.cursor_factory: return self.cursor_factory(self) + + if name and self.server_cursor_factory: + return self.server_cursor_factory(self) return MockCursor() def get_dsn_parameters(self): # pylint: disable=no-self-use @@ -102,15 +107,18 @@ class MockAsyncConnection: def __init__(self, *args, **kwargs): self.cursor_factory = kwargs.pop("cursor_factory", None) + self.server_cursor_factory = None @staticmethod async def connect(*args, **kwargs): return MockAsyncConnection(**kwargs) - def cursor(self): - if self.cursor_factory: - cur = self.cursor_factory(self) - return cur + def cursor(self, name: Optional[str] = None): + if not name and self.cursor_factory: + return self.cursor_factory(self) + + if name and self.server_cursor_factory: + return self.server_cursor_factory(self) return MockAsyncCursor() def execute(self, query, params=None, *, prepare=None, binary=False): @@ -197,6 +205,36 @@ def test_instrumentor(self): spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) + def test_instrumentor_with_named_cursor(self): + PsycopgInstrumentor().instrument() + + cnx = psycopg.connect(database="test") + + cursor = cnx.cursor(name="named_cursor") + + query = "SELECT * FROM test" + cursor.execute(query) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + span = spans_list[0] + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationScope( + span, opentelemetry.instrumentation.psycopg + ) + + # check that no spans are generated after uninstrument + PsycopgInstrumentor().uninstrument() + + cnx = psycopg.connect(database="test") + cursor = cnx.cursor(name="named_cursor") + query = "SELECT * FROM test" + cursor.execute(query) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + # pylint: disable=unused-argument def test_instrumentor_with_connection_class(self): PsycopgInstrumentor().instrument() @@ -228,6 +266,36 @@ def test_instrumentor_with_connection_class(self): spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) + def test_instrumentor_with_connection_class_and_named_cursor(self): + PsycopgInstrumentor().instrument() + + cnx = psycopg.Connection.connect(database="test") + + cursor = cnx.cursor(name="named_cursor") + + query = "SELECT * FROM test" + cursor.execute(query) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + span = spans_list[0] + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationScope( + span, opentelemetry.instrumentation.psycopg + ) + + # check that no spans are generated after uninstrument + PsycopgInstrumentor().uninstrument() + + cnx = psycopg.Connection.connect(database="test") + cursor = cnx.cursor(name="named_cursor") + query = "SELECT * FROM test" + cursor.execute(query) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + def test_span_name(self): PsycopgInstrumentor().instrument() @@ -314,6 +382,23 @@ def test_instrument_connection(self): spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) + # pylint: disable=unused-argument + def test_instrument_connection_with_named_cursor(self): + cnx = psycopg.connect(database="test") + query = "SELECT * FROM test" + cursor = cnx.cursor(name="named_cursor") + cursor.execute(query) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 0) + + cnx = PsycopgInstrumentor().instrument_connection(cnx) + cursor = cnx.cursor(name="named_cursor") + cursor.execute(query) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + # pylint: disable=unused-argument def test_instrument_connection_with_instrument(self): cnx = psycopg.connect(database="test") @@ -368,6 +453,23 @@ def test_uninstrument_connection_with_instrument_connection(self): spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) + def test_uninstrument_connection_with_instrument_connection_and_named_cursor(self): + cnx = psycopg.connect(database="test") + PsycopgInstrumentor().instrument_connection(cnx) + query = "SELECT * FROM test" + cursor = cnx.cursor(name="named_cursor") + cursor.execute(query) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + + cnx = PsycopgInstrumentor().uninstrument_connection(cnx) + cursor = cnx.cursor(name="named_cursor") + cursor.execute(query) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + @mock.patch("opentelemetry.instrumentation.dbapi.wrap_connect") def test_sqlcommenter_enabled(self, event_mocked): cnx = psycopg.connect(database="test") @@ -419,6 +521,33 @@ async def test_async_connection(): spans_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans_list), 1) + async def test_wrap_async_connection_class_with_named_cursor(self): + PsycopgInstrumentor().instrument() + + async def test_async_connection(): + acnx = await psycopg.AsyncConnection.connect("test") + async with acnx as cnx: + async with cnx.cursor(name="named_cursor") as cursor: + await cursor.execute("SELECT * FROM test") + + await test_async_connection() + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + span = spans_list[0] + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationScope( + span, opentelemetry.instrumentation.psycopg + ) + + # check that no spans are generated after uninstrument + PsycopgInstrumentor().uninstrument() + + await test_async_connection() + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + # pylint: disable=unused-argument async def test_instrumentor_with_async_connection_class(self): PsycopgInstrumentor().instrument()