diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 2227bcbd..dab57113 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -400,7 +400,7 @@ def get_telemetry_client(session_id_hex): if session_id_hex in TelemetryClientFactory._clients: return TelemetryClientFactory._clients[session_id_hex] else: - logger.error( + logger.debug( "Telemetry client not initialized for connection %s", session_id_hex, ) diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py new file mode 100644 index 00000000..5593588f --- /dev/null +++ b/tests/e2e/test_concurrent_telemetry.py @@ -0,0 +1,85 @@ +import threading +from unittest.mock import patch +import pytest + +from databricks.sql.telemetry.telemetry_client import TelemetryClient, TelemetryClientFactory +from tests.e2e.test_driver import PySQLPytestTestCase + +def run_in_threads(target, num_threads, pass_index=False): + """Helper to run target function in multiple threads.""" + threads = [ + threading.Thread(target=target, args=(i,) if pass_index else ()) + for i in range(num_threads) + ] + for t in threads: + t.start() + for t in threads: + t.join() + + +class TestE2ETelemetry(PySQLPytestTestCase): + + @pytest.fixture(autouse=True) + def telemetry_setup_teardown(self): + """ + This fixture ensures the TelemetryClientFactory is in a clean state + before each test and shuts it down afterward. Using a fixture makes + this robust and automatic. 
+ """ + # --- SETUP --- + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._clients.clear() + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False + + yield # This is where the test runs + + # --- TEARDOWN --- + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False + + def test_concurrent_queries_sends_telemetry(self): + """ + An E2E test where concurrent threads execute real queries against + the staging endpoint, while we capture and verify the generated telemetry. + """ + num_threads = 5 + captured_telemetry = [] + captured_telemetry_lock = threading.Lock() + captured_responses = [] + captured_responses_lock = threading.Lock() + + original_send_telemetry = TelemetryClient._send_telemetry + original_callback = TelemetryClient._telemetry_request_callback + + def send_telemetry_wrapper(self_client, events): + with captured_telemetry_lock: + captured_telemetry.extend(events) + original_send_telemetry(self_client, events) + + with patch.object(TelemetryClient, "_send_telemetry", send_telemetry_wrapper): + + def execute_query_worker(thread_id): + """Each thread creates a connection and executes a query.""" + with self.connection(extra_params={"enable_telemetry": True}) as conn: + with conn.cursor() as cursor: + cursor.execute(f"SELECT {thread_id}") + cursor.fetchall() + + # Run the workers concurrently + run_in_threads(execute_query_worker, num_threads, pass_index=True) + + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + + # --- VERIFICATION --- + assert len(captured_telemetry) == num_threads * 3 # 3 events per thread (initial_telemetry_log, 2 latency_logs (execute, fetchall)) + + events_with_latency = [ + e for e in captured_telemetry + if e.entry.sql_driver_log.operation_latency_ms is not None and 
e.entry.sql_driver_log.sql_statement_id is not None + ] + assert len(events_with_latency) == num_threads * 2 # 2 events per thread (execute, fetchall) \ No newline at end of file diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index fcf3fa70..f5a4b37f 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -2,13 +2,16 @@ import pytest import requests from unittest.mock import patch, MagicMock +import threading +import random +import time +from concurrent.futures import ThreadPoolExecutor from databricks.sql.telemetry.telemetry_client import ( TelemetryClient, NoopTelemetryClient, TelemetryClientFactory, TelemetryHelper, - BaseTelemetryClient ) from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow from databricks.sql.auth.authenticators import ( @@ -283,4 +286,176 @@ def test_factory_shutdown_flow(self, telemetry_system_reset): # Close second client - factory should shut down TelemetryClientFactory.close(session2) assert TelemetryClientFactory._initialized is False - assert TelemetryClientFactory._executor is None \ No newline at end of file + assert TelemetryClientFactory._executor is None + + +# A helper function to run a target in multiple threads and wait for them. +def run_in_threads(target, num_threads, pass_index=False): + """Creates, starts, and joins a specified number of threads. + + Args: + target: The function to run in each thread + num_threads: Number of threads to create + pass_index: If True, passes the thread index (0, 1, 2, ...) 
as first argument + """ + threads = [ + threading.Thread(target=target, args=(i,) if pass_index else ()) + for i in range(num_threads) + ] + for t in threads: + t.start() + for t in threads: + t.join() + + +class TestTelemetryRaceConditions: + """Tests for race conditions in multithreaded scenarios.""" + + @pytest.fixture(autouse=True) + def clean_factory(self): + """A fixture to automatically reset the factory's state before each test.""" + # Clean up at the start of each test + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._clients.clear() + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False + + yield + + # Clean up at the end of each test + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._clients.clear() + TelemetryClientFactory._executor = None + TelemetryClientFactory._initialized = False + + def test_factory_concurrent_initialization_of_DIFFERENT_clients(self): + """ + Tests that multiple threads creating DIFFERENT clients concurrently + share a single ThreadPoolExecutor and all clients are created successfully. 
+ """ + num_threads = 20 + + def create_client(thread_id): + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=f"session_{thread_id}", + auth_provider=None, + host_url="test-host", + ) + + run_in_threads(create_client, 20, pass_index=True) + + # ASSERT: The factory was properly initialized + assert TelemetryClientFactory._initialized is True + assert TelemetryClientFactory._executor is not None + assert isinstance(TelemetryClientFactory._executor, ThreadPoolExecutor) + + # ASSERT: All clients were successfully created + assert len(TelemetryClientFactory._clients) == num_threads + + # ASSERT: All TelemetryClient instances share the same executor + telemetry_clients = [ + client for client in TelemetryClientFactory._clients.values() + if isinstance(client, TelemetryClient) + ] + assert len(telemetry_clients) == num_threads + + shared_executor = TelemetryClientFactory._executor + for client in telemetry_clients: + assert client._executor is shared_executor + + def test_factory_concurrent_initialization_of_SAME_client(self): + """ + Tests that multiple threads trying to initialize the SAME client + result in only one client instance being created. + """ + session_id = "shared-session" + num_threads = 20 + + def create_same_client(): + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id, + auth_provider=None, + host_url="test-host", + ) + + run_in_threads(create_same_client, num_threads) + + # ASSERT: Only one client was created in the factory. + assert len(TelemetryClientFactory._clients) == 1 + client = TelemetryClientFactory.get_telemetry_client(session_id) + assert isinstance(client, TelemetryClient) + + def test_client_concurrent_event_export(self): + """ + Tests that no events are lost when multiple threads call _export_event + on the same client instance concurrently. 
+ """ + client = TelemetryClient(True, "session-1", None, "host", MagicMock()) + # Mock _flush to prevent auto-flushing when batch size threshold is reached + original_flush = client._flush + client._flush = MagicMock() + + num_threads = 5 + events_per_thread = 10 + + def add_events(): + for i in range(events_per_thread): + client._export_event(f"event-{i}") + + run_in_threads(add_events, num_threads) + + # ASSERT: The batch contains all events from all threads, none were lost. + total_expected_events = num_threads * events_per_thread + assert len(client._events_batch) == total_expected_events + + # Restore original flush method for cleanup + client._flush = original_flush + + def test_client_concurrent_flush(self): + """ + Tests that if multiple threads trigger _flush at the same time, + the underlying send operation is only called once for the batch. + """ + client = TelemetryClient(True, "session-1", None, "host", MagicMock()) + client._send_telemetry = MagicMock() + + # Pre-fill the batch so there's something to flush + client._events_batch = ["event"] * 5 + + def call_flush(): + client._flush() + + run_in_threads(call_flush, 10) + + # ASSERT: The send operation was called exactly once. + # This proves the lock prevents multiple threads from sending the same batch. + client._send_telemetry.assert_called_once() + # ASSERT: The event batch is now empty. + assert len(client._events_batch) == 0 + + def test_factory_concurrent_create_and_close(self): + """ + Tests that concurrently creating and closing different clients + doesn't corrupt the factory state and correctly shuts down the executor. 
+ """ + num_ops = 50 + + def create_and_close_client(i): + session_id = f"session_{i}" + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, session_id_hex=session_id, auth_provider=None, host_url="host" + ) + # Small sleep to increase chance of interleaving operations + time.sleep(random.uniform(0, 0.01)) + TelemetryClientFactory.close(session_id) + + run_in_threads(create_and_close_client, num_ops, pass_index=True) + + # ASSERT: After all operations, the factory should be empty and reset. + assert not TelemetryClientFactory._clients + assert TelemetryClientFactory._executor is None + assert not TelemetryClientFactory._initialized \ No newline at end of file