Skip to content

Testing for telemetry #616

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: telemetry
Choose a base branch
from
2 changes: 1 addition & 1 deletion src/databricks/sql/telemetry/telemetry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def get_telemetry_client(session_id_hex):
if session_id_hex in TelemetryClientFactory._clients:
return TelemetryClientFactory._clients[session_id_hex]
else:
logger.error(
logger.debug(
"Telemetry client not initialized for connection %s",
session_id_hex,
)
Expand Down
85 changes: 85 additions & 0 deletions tests/e2e/test_concurrent_telemetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import threading
from unittest.mock import patch
import pytest

from databricks.sql.telemetry.telemetry_client import TelemetryClient, TelemetryClientFactory
from tests.e2e.test_driver import PySQLPytestTestCase

def run_in_threads(target, num_threads, pass_index=False):
"""Helper to run target function in multiple threads."""
threads = [
threading.Thread(target=target, args=(i,) if pass_index else ())
for i in range(num_threads)
]
for t in threads:
t.start()
for t in threads:
t.join()


class TestE2ETelemetry(PySQLPytestTestCase):

@pytest.fixture(autouse=True)
def telemetry_setup_teardown(self):
"""
This fixture ensures the TelemetryClientFactory is in a clean state
before each test and shuts it down afterward. Using a fixture makes
this robust and automatic.
"""
# --- SETUP ---
if TelemetryClientFactory._executor:
TelemetryClientFactory._executor.shutdown(wait=True)
TelemetryClientFactory._clients.clear()
TelemetryClientFactory._executor = None
TelemetryClientFactory._initialized = False

yield # This is where the test runs

# --- TEARDOWN ---
if TelemetryClientFactory._executor:
TelemetryClientFactory._executor.shutdown(wait=True)
TelemetryClientFactory._executor = None
TelemetryClientFactory._initialized = False

def test_concurrent_queries_sends_telemetry(self):
"""
An E2E test where concurrent threads execute real queries against
the staging endpoint, while we capture and verify the generated telemetry.
"""
num_threads = 5
captured_telemetry = []
captured_telemetry_lock = threading.Lock()
captured_responses = []
captured_responses_lock = threading.Lock()

original_send_telemetry = TelemetryClient._send_telemetry
original_callback = TelemetryClient._telemetry_request_callback

def send_telemetry_wrapper(self_client, events):
with captured_telemetry_lock:
captured_telemetry.extend(events)
original_send_telemetry(self_client, events)

with patch.object(TelemetryClient, "_send_telemetry", send_telemetry_wrapper):

def execute_query_worker(thread_id):
"""Each thread creates a connection and executes a query."""
with self.connection(extra_params={"enable_telemetry": True}) as conn:
with conn.cursor() as cursor:
cursor.execute(f"SELECT {thread_id}")
cursor.fetchall()

# Run the workers concurrently
run_in_threads(execute_query_worker, num_threads, pass_index=True)

if TelemetryClientFactory._executor:
TelemetryClientFactory._executor.shutdown(wait=True)

# --- VERIFICATION ---
assert len(captured_telemetry) == num_threads * 3 # 4 events per thread (initial_telemetry_log, 2 latency_logs (execute, fetchall))

events_with_latency = [
e for e in captured_telemetry
if e.entry.sql_driver_log.operation_latency_ms is not None and e.entry.sql_driver_log.sql_statement_id is not None
]
assert len(events_with_latency) == num_threads * 2 # 2 events per thread (execute, fetchall)
179 changes: 177 additions & 2 deletions tests/unit/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import pytest
import requests
from unittest.mock import patch, MagicMock
import threading
import random
import time
from concurrent.futures import ThreadPoolExecutor

from databricks.sql.telemetry.telemetry_client import (
TelemetryClient,
NoopTelemetryClient,
TelemetryClientFactory,
TelemetryHelper,
BaseTelemetryClient
)
from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow
from databricks.sql.auth.authenticators import (
Expand Down Expand Up @@ -283,4 +286,176 @@ def test_factory_shutdown_flow(self, telemetry_system_reset):
# Close second client - factory should shut down
TelemetryClientFactory.close(session2)
assert TelemetryClientFactory._initialized is False
assert TelemetryClientFactory._executor is None
assert TelemetryClientFactory._executor is None


# A helper function to run a target in multiple threads and wait for them.
def run_in_threads(target, num_threads, pass_index=False):
"""Creates, starts, and joins a specified number of threads.

Args:
target: The function to run in each thread
num_threads: Number of threads to create
pass_index: If True, passes the thread index (0, 1, 2, ...) as first argument
"""
threads = [
threading.Thread(target=target, args=(i,) if pass_index else ())
for i in range(num_threads)
]
for t in threads:
t.start()
for t in threads:
t.join()


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we add these in a separate file? @jprakash-db what's the sop in python?

class TestTelemetryRaceConditions:
"""Tests for race conditions in multithreaded scenarios."""

@pytest.fixture(autouse=True)
def clean_factory(self):
"""A fixture to automatically reset the factory's state before each test."""
# Clean up at the start of each test
if TelemetryClientFactory._executor:
TelemetryClientFactory._executor.shutdown(wait=True)
TelemetryClientFactory._clients.clear()
TelemetryClientFactory._executor = None
TelemetryClientFactory._initialized = False

yield

# Clean up at the end of each test
if TelemetryClientFactory._executor:
TelemetryClientFactory._executor.shutdown(wait=True)
TelemetryClientFactory._clients.clear()
TelemetryClientFactory._executor = None
TelemetryClientFactory._initialized = False

def test_factory_concurrent_initialization_of_DIFFERENT_clients(self):
"""
Tests that multiple threads creating DIFFERENT clients concurrently
share a single ThreadPoolExecutor and all clients are created successfully.
"""
num_threads = 20

def create_client(thread_id):
TelemetryClientFactory.initialize_telemetry_client(
telemetry_enabled=True,
session_id_hex=f"session_{thread_id}",
auth_provider=None,
host_url="test-host",
)

run_in_threads(create_client, 20, pass_index=True)

# ASSERT: The factory was properly initialized
assert TelemetryClientFactory._initialized is True
assert TelemetryClientFactory._executor is not None
assert isinstance(TelemetryClientFactory._executor, ThreadPoolExecutor)

# ASSERT: All clients were successfully created
assert len(TelemetryClientFactory._clients) == num_threads

# ASSERT: All TelemetryClient instances share the same executor
telemetry_clients = [
client for client in TelemetryClientFactory._clients.values()
if isinstance(client, TelemetryClient)
]
assert len(telemetry_clients) == num_threads

shared_executor = TelemetryClientFactory._executor
for client in telemetry_clients:
assert client._executor is shared_executor

def test_factory_concurrent_initialization_of_SAME_client(self):
"""
Tests that multiple threads trying to initialize the SAME client
result in only one client instance being created.
"""
session_id = "shared-session"
num_threads = 20

def create_same_client():
TelemetryClientFactory.initialize_telemetry_client(
telemetry_enabled=True,
session_id_hex=session_id,
auth_provider=None,
host_url="test-host",
)

run_in_threads(create_same_client, num_threads)

# ASSERT: Only one client was created in the factory.
assert len(TelemetryClientFactory._clients) == 1
client = TelemetryClientFactory.get_telemetry_client(session_id)
assert isinstance(client, TelemetryClient)

def test_client_concurrent_event_export(self):
"""
Tests that no events are lost when multiple threads call _export_event
on the same client instance concurrently.
"""
client = TelemetryClient(True, "session-1", None, "host", MagicMock())
# Mock _flush to prevent auto-flushing when batch size threshold is reached
original_flush = client._flush
client._flush = MagicMock()

num_threads = 5
events_per_thread = 10

def add_events():
for i in range(events_per_thread):
client._export_event(f"event-{i}")

run_in_threads(add_events, num_threads)

# ASSERT: The batch contains all events from all threads, none were lost.
total_expected_events = num_threads * events_per_thread
assert len(client._events_batch) == total_expected_events

# Restore original flush method for cleanup
client._flush = original_flush

def test_client_concurrent_flush(self):
"""
Tests that if multiple threads trigger _flush at the same time,
the underlying send operation is only called once for the batch.
"""
client = TelemetryClient(True, "session-1", None, "host", MagicMock())
client._send_telemetry = MagicMock()

# Pre-fill the batch so there's something to flush
client._events_batch = ["event"] * 5

def call_flush():
client._flush()

run_in_threads(call_flush, 10)

# ASSERT: The send operation was called exactly once.
# This proves the lock prevents multiple threads from sending the same batch.
client._send_telemetry.assert_called_once()
# ASSERT: The event batch is now empty.
assert len(client._events_batch) == 0

def test_factory_concurrent_create_and_close(self):
"""
Tests that concurrently creating and closing different clients
doesn't corrupt the factory state and correctly shuts down the executor.
"""
num_ops = 50

def create_and_close_client(i):
session_id = f"session_{i}"
TelemetryClientFactory.initialize_telemetry_client(
telemetry_enabled=True, session_id_hex=session_id, auth_provider=None, host_url="host"
)
# Small sleep to increase chance of interleaving operations
time.sleep(random.uniform(0, 0.01))
TelemetryClientFactory.close(session_id)

run_in_threads(create_and_close_client, num_ops, pass_index=True)

# ASSERT: After all operations, the factory should be empty and reset.
assert not TelemetryClientFactory._clients
assert TelemetryClientFactory._executor is None
assert not TelemetryClientFactory._initialized
Loading