Skip to content

Commit

Permalink
tests: add unit test for propagating TracerProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Nov 13, 2024
1 parent 5cddcbe commit 0ddd907
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@
import mock
from google.cloud.spanner_v1 import DirectedReadOptions

hasOtelInstalled = False

try:
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
InMemorySpanExporter,
)
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.sampling import ALWAYS_ON
from opentelemetry import trace
hasOtelInstalled = True
except ImportError:
pass


def _make_credentials():
import google.auth.credentials
Expand Down Expand Up @@ -686,3 +700,37 @@ def test_list_instances_w_options(self):
retry=mock.ANY,
timeout=mock.ANY,
)

def test_observability_options(self):
if not hasOtelInstalled:
return

global_tracer_provider = TracerProvider(sampler=ALWAYS_ON)
trace.set_tracer_provider(global_tracer_provider)
global_trace_exporter = InMemorySpanExporter()
global_tracer_provider.add_span_processor(SimpleSpanProcessor(global_trace_exporter))

inject_tracer_provider = TracerProvider(sampler=ALWAYS_ON)
inject_trace_exporter = InMemorySpanExporter()
inject_tracer_provider.add_span_processor(SimpleSpanProcessor(inject_trace_exporter))
observability_options = dict(tracer_provder=inject_tracer_provider, enable_extended_tracing=True)
credentials = _make_credentials()
client = self._make_one(project=self.PROJECT, credentials=credentials, observability_options=observability_options)

instance = client.instance(
self.INSTANCE_ID,
self.CONFIGURATION_NAME,
display_name=self.DISPLAY_NAME,
node_count=self.NODE_COUNT,
labels=self.LABELS,
)

db = instance.database(self.DATABASE_ID, enable_interceptors_in_tests=True)
response = dict()
db.execute_sql.return_value = response
db.execute_sql('SELECT 1')

from_global_spans = global_trace_exporter.get_finished_spans()
from_inject_spans = inject_trace_exporter.get_finished_spans()
self.assertEqual(len(from_global_spans), 0, 'Expecting no spans from the global trace exporter')
self.assertEqual(len(from_global_spans) > 0, 'Expecting at least 1 span from the injected trace exporter')

0 comments on commit 0ddd907

Please sign in to comment.