diff --git a/README.md b/README.md index dd29393..4ab55ce 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ Implemented dataclasses (in `types.py`): - `GenAI` - base class - `LLMInvocation` - `EmbeddingInvocation` +- `RetrievalInvocation` - `Workflow` - `AgentInvocation` - `Step` diff --git a/util/opentelemetry-util-genai/CHANGELOG.md b/util/opentelemetry-util-genai/CHANGELOG.md index 0f94723..03e0065 100644 --- a/util/opentelemetry-util-genai/CHANGELOG.md +++ b/util/opentelemetry-util-genai/CHANGELOG.md @@ -2,6 +2,10 @@ All notable changes to this repository are documented in this file. +## Unreleased + +- Added `RetrievalInvocation` type to support retrieval/search operations in GenAI workflows + ## Version 0.1.4 - 2025-11-07 - Initial 0.1.4 release of splunk-otel-util-genai diff --git a/util/opentelemetry-util-genai/examples/retrievals_example.py b/util/opentelemetry-util-genai/examples/retrievals_example.py new file mode 100644 index 0000000..fbbbe5c --- /dev/null +++ b/util/opentelemetry-util-genai/examples/retrievals_example.py @@ -0,0 +1,416 @@ +#!/usr/bin/env python3 +"""Example demonstrating OpenTelemetry GenAI telemetry for retrieval operations. + +This example shows: +1. Basic retrieval invocation lifecycle +2. Retrieval with vector search +3. Retrieval with text query and metadata +4. Retrieval with custom attributes +5. Error handling for retrieval operations +6. Retrieval with agent context +7. Metrics and span emission for retrievals +""" + +import time + +from opentelemetry import _logs as logs +from opentelemetry import trace +from opentelemetry.sdk._logs import LoggerProvider +from opentelemetry.sdk._logs.export import ( + ConsoleLogExporter, + SimpleLogRecordProcessor, +) +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import ( + ConsoleMetricExporter, + PeriodicExportingMetricReader, +) +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import ( + ConsoleSpanExporter, + SimpleSpanProcessor, +) +from opentelemetry.util.genai.handler import get_telemetry_handler +from opentelemetry.util.genai.types import Error, RetrievalInvocation + + +def setup_telemetry(): + """Set up OpenTelemetry providers for tracing, metrics, and logging.""" + # Set up tracing + trace_provider = TracerProvider() + trace_provider.add_span_processor( + SimpleSpanProcessor(ConsoleSpanExporter()) + ) + trace.set_tracer_provider(trace_provider) + + # Set up metrics + metric_reader = PeriodicExportingMetricReader( + ConsoleMetricExporter(), export_interval_millis=5000 + ) + meter_provider = MeterProvider(metric_readers=[metric_reader]) + + # Set up logging (for events) + logger_provider = LoggerProvider() + logger_provider.add_log_record_processor( + SimpleLogRecordProcessor(ConsoleLogExporter()) + ) + logs.set_logger_provider(logger_provider) + + return trace_provider, meter_provider, logger_provider + + +def example_basic_retrieval(): + """Example 1: Basic retrieval invocation with text query.""" + print("\n" + "=" * 60) + print("Example 1: Basic Retrieval Invocation") + print("=" * 60) + + handler = get_telemetry_handler() + + # Create retrieval invocation + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="What is OpenTelemetry?", + top_k=5, + retriever_type="vector_store", + provider="pinecone", + ) + + # Start the retrieval operation + handler.start_retrieval(retrieval) + time.sleep(0.05) # Simulate API call + + # Simulate response - populate results + retrieval.documents_retrieved = 5 + retrieval.results = [ + {"id": "doc1", "score": 0.95, "content": "OpenTelemetry is..."}, + {"id": "doc2", "score": 0.89, "content": "OTEL provides..."}, + {"id": "doc3", "score": 0.85, "content": "Observability with..."}, + {"id": "doc4", "score": 0.82, "content": "Tracing and metrics..."}, + {"id": "doc5", "score": 0.78, "content": "Distributed tracing..."}, + ] + + # Finish the retrieval operation + handler.stop_retrieval(retrieval) + + print("✓ Completed retrieval for text query") + print(f" Query: {retrieval.query}") + print(f" Documents retrieved: {retrieval.documents_retrieved}") + print(f" Provider: {retrieval.provider}") + + +def example_vector_search(): + """Example 2: Retrieval with vector search.""" + print("\n" + "=" * 60) + print("Example 2: Vector Search Retrieval") + print("=" * 60) + + handler = get_telemetry_handler() + + # Create retrieval with query vector + query_vector = [0.1, 0.2, 0.3, 0.4, 0.5] * 100 # 500-dim vector + + retrieval = RetrievalInvocation( + operation_name="retrieval", + query_vector=query_vector, + top_k=10, + retriever_type="vector_store", + provider="chroma", + framework="langchain", + ) + + # Start the retrieval operation + handler.start_retrieval(retrieval) + time.sleep(0.08) # Simulate API call + + # Simulate response + retrieval.documents_retrieved = 10 + retrieval.results = [ + {"id": f"doc{i}", "score": 0.95 - i * 0.05} for i in range(10) + ] + + # Finish the retrieval operation + handler.stop_retrieval(retrieval) + + print("✓ Completed vector search retrieval") + print(f" Vector dimensions: {len(query_vector)}") + print(f" Documents retrieved: {retrieval.documents_retrieved}") + print(f" Framework: {retrieval.framework}") + + +def example_retrieval_with_metadata(): + """Example 3: Retrieval with result metadata.""" + print("\n" + "=" * 60) + print("Example 3: Retrieval with Metadata") + print("=" * 60) + + handler = get_telemetry_handler() + + # Create retrieval + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="machine learning tutorials", + top_k=3, + retriever_type="hybrid_search", + provider="weaviate", + framework="langchain", + ) + + # Start the retrieval operation + handler.start_retrieval(retrieval) + time.sleep(0.06) # Simulate API call + + # Simulate response with metadata + retrieval.documents_retrieved = 3 + retrieval.results = [ + { + "id": "tut1", + "score": 0.92, + "content": "Intro to ML", + "metadata": {"category": "tutorial", "difficulty": "beginner"}, + }, + { + "id": "tut2", + "score": 0.88, + "content": "Python ML basics", + "metadata": {"category": "tutorial", "difficulty": "beginner"}, + }, + { + "id": "tut3", + "score": 0.85, + "content": "Getting started with ML", + "metadata": {"category": "tutorial", "difficulty": "beginner"}, + }, + ] + + # Finish the retrieval operation + handler.stop_retrieval(retrieval) + + print("✓ Completed retrieval with metadata") + print(f" Query: {retrieval.query}") + print(f" Retriever type: {retrieval.retriever_type}") + print(f" Documents retrieved: {retrieval.documents_retrieved}") + + +def example_retrieval_with_custom_attributes(): + """Example 4: Retrieval with custom attributes.""" + print("\n" + "=" * 60) + print("Example 4: Retrieval with Custom Attributes") + print("=" * 60) + + handler = get_telemetry_handler() + + # Create retrieval with custom attributes + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="customer support documentation", + top_k=5, + retriever_type="semantic_search", + provider="qdrant", + attributes={ + "collection_name": "support_docs", + "user_id": "user-789", + "session_id": "session-456", + "search_type": "semantic", + }, + ) + + # Start the retrieval operation + handler.start_retrieval(retrieval) + time.sleep(0.07) # Simulate API call + + # Simulate response + retrieval.documents_retrieved = 5 + + # Finish the retrieval operation + handler.stop_retrieval(retrieval) + + print("✓ Completed retrieval with custom attributes") + print(f" Query: {retrieval.query}") + print(f" Custom attributes: {retrieval.attributes}") + + +def example_retrieval_with_agent_context(): + """Example 5: Retrieval within an agent context.""" + print("\n" + "=" * 60) + print("Example 5: Retrieval with Agent Context") + print("=" * 60) + + handler = get_telemetry_handler() + + # Create retrieval with agent context + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="latest product updates", + top_k=7, + retriever_type="vector_store", + provider="milvus", + framework="langchain", + agent_name="product_assistant", + agent_id="agent-123", + ) + + # Start the retrieval operation + handler.start_retrieval(retrieval) + time.sleep(0.05) # Simulate API call + + # Simulate response + retrieval.documents_retrieved = 7 + + # Finish the retrieval operation + handler.stop_retrieval(retrieval) + + print("✓ Completed retrieval with agent context") + print(f" Agent: {retrieval.agent_name} (ID: {retrieval.agent_id})") + print(f" Query: {retrieval.query}") + print(f" Documents retrieved: {retrieval.documents_retrieved}") + + +def example_retrieval_error(): + """Example 6: Handling retrieval errors.""" + print("\n" + "=" * 60) + print("Example 6: Retrieval Error Handling") + print("=" * 60) + + handler = get_telemetry_handler() + + # Create retrieval invocation + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="test query", + top_k=5, + retriever_type="vector_store", + provider="pinecone", + ) + + # Start the retrieval operation + handler.start_retrieval(retrieval) + time.sleep(0.03) # Simulate API call + + # Simulate an error + error = Error( + message="Connection timeout to vector store", + type=TimeoutError, + ) + + # Fail the retrieval operation + handler.fail_retrieval(retrieval, error) + + print("✗ Retrieval failed with error") + print(f" Error: {error.message}") + print(f" Provider: {retrieval.provider}") + + +def example_multiple_retrievals(): + """Example 7: Multiple sequential retrievals.""" + print("\n" + "=" * 60) + print("Example 7: Multiple Sequential Retrievals") + print("=" * 60) + + handler = get_telemetry_handler() + + queries = [ + "What is machine learning?", + "How does deep learning work?", + "Explain neural networks", + ] + + for idx, query_text in enumerate(queries, 1): + retrieval = RetrievalInvocation( + operation_name="retrieval", + query=query_text, + top_k=5, + retriever_type="vector_store", + provider="pinecone", + attributes={"query_index": idx}, + ) + + handler.start_retrieval(retrieval) + time.sleep(0.04) # Simulate API call + + # Simulate response + retrieval.documents_retrieved = 5 + + handler.stop_retrieval(retrieval) + print(f" ✓ Completed retrieval {idx}/{len(queries)}") + + print(f"✓ Completed all {len(queries)} retrievals") + + +def example_hybrid_retrieval(): + """Example 8: Hybrid retrieval combining text and vector search.""" + print("\n" + "=" * 60) + print("Example 8: Hybrid Retrieval") + print("=" * 60) + + handler = get_telemetry_handler() + + # Create hybrid retrieval with both query and vector + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="artificial intelligence applications", + query_vector=[0.2] * 768, # 768-dim vector + top_k=8, + retriever_type="hybrid_search", + provider="elasticsearch", + framework="langchain", + attributes={ + "alpha": 0.5, # Balance between text and vector search + "boost_query": True, + }, + ) + + # Start the retrieval operation + handler.start_retrieval(retrieval) + time.sleep(0.09) # Simulate API call + + # Simulate response + retrieval.documents_retrieved = 8 + retrieval.results = [ + {"id": f"doc{i}", "score": 0.9 - i * 0.05, "type": "hybrid"} + for i in range(8) + ] + + # Finish the retrieval operation + handler.stop_retrieval(retrieval) + + print("✓ Completed hybrid retrieval") + print(f" Query: {retrieval.query}") + print(f" Vector dimensions: {len(retrieval.query_vector)}") + print(f" Retriever type: {retrieval.retriever_type}") + print(f" Documents retrieved: {retrieval.documents_retrieved}") + + +def main(): + """Run all retrieval examples.""" + print("\n" + "=" * 60) + print("OpenTelemetry GenAI Retrievals Examples") + print("=" * 60) + + # Set up telemetry + trace_provider, meter_provider, logger_provider = setup_telemetry() + + # Run examples + example_basic_retrieval() + example_vector_search() + example_retrieval_with_metadata() + example_retrieval_with_custom_attributes() + example_retrieval_with_agent_context() + example_retrieval_error() + example_multiple_retrievals() + example_hybrid_retrieval() + + # Force flush to ensure all telemetry is exported + print("\n" + "=" * 60) + print("Flushing telemetry data...") + print("=" * 60) + trace_provider.force_flush() + meter_provider.force_flush() + logger_provider.force_flush() + + print("\n✓ All examples completed successfully!") + print("Check the console output above for spans, metrics, and events.\n") + + +if __name__ == "__main__": + main() diff --git a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/attributes.py b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/attributes.py index 3c64da6..a496582 100644 --- a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/attributes.py +++ b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/attributes.py @@ -55,6 +55,13 @@ GEN_AI_EMBEDDINGS_INPUT_TEXTS = "gen_ai.embeddings.input.texts" GEN_AI_REQUEST_ENCODING_FORMATS = "gen_ai.request.encoding_formats" +# Retrieval attributes +GEN_AI_RETRIEVAL_TYPE = "gen_ai.retrieval.type" +GEN_AI_RETRIEVAL_QUERY_TEXT = "gen_ai.retrieval.query.text" +GEN_AI_RETRIEVAL_TOP_K = "gen_ai.retrieval.top_k" +GEN_AI_RETRIEVAL_DOCUMENTS_RETRIEVED = "gen_ai.retrieval.documents_retrieved" +GEN_AI_RETRIEVAL_DOCUMENTS = "gen_ai.retrieval.documents" + # Server attributes (from semantic conventions) SERVER_ADDRESS = "server.address" SERVER_PORT = "server.port" diff --git a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/emitters/metrics.py b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/emitters/metrics.py index f5291c1..c2f7ceb 100644 --- a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/emitters/metrics.py +++ b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/emitters/metrics.py @@ -18,6 +18,7 @@ EmbeddingInvocation, Error, LLMInvocation, + RetrievalInvocation, ToolCall, Workflow, ) @@ -50,6 +51,9 @@ def __init__(self, meter: Optional[Meter] = None): self._agent_duration_histogram: Histogram = ( instruments.agent_duration_histogram ) + self._retrieval_duration_histogram: Histogram = ( + instruments.retrieval_duration_histogram + ) def on_start(self, obj: Any) -> None: # no-op for metrics return None @@ -146,6 +150,9 @@ def on_end(self, obj: Any) -> None: span=getattr(embedding_invocation, "span", None), ) + if isinstance(obj, RetrievalInvocation): + self._record_retrieval_metrics(obj) + def on_error(self, error: Error, obj: Any) -> None: # Handle new agentic types if isinstance(obj, Workflow): @@ -242,6 +249,9 @@ def on_error(self, error: Error, obj: Any) -> None: span=getattr(embedding_invocation, "span", None), ) + if isinstance(obj, RetrievalInvocation): + self._record_retrieval_metrics(obj, error) + def handles(self, obj: Any) -> bool: return isinstance( obj, @@ -251,6 +261,7 @@ def handles(self, obj: Any) -> bool: Workflow, AgentInvocation, EmbeddingInvocation, + RetrievalInvocation, ), ) @@ -306,3 +317,40 @@ def _record_agent_metrics(self, agent: AgentInvocation) -> None: self._agent_duration_histogram.record( duration, attributes=metric_attrs, context=context ) + + def _record_retrieval_metrics( + self, retrieval: RetrievalInvocation, error: Optional[Error] = None + ) -> None: + """Record metrics for a retrieval operation.""" + if retrieval.end_time is None: + return + duration = retrieval.end_time - retrieval.start_time + metric_attrs = { + GenAI.GEN_AI_OPERATION_NAME: retrieval.operation_name, + } + if retrieval.retriever_type: + metric_attrs["gen_ai.retrieval.type"] = retrieval.retriever_type + if retrieval.framework: + metric_attrs["gen_ai.framework"] = retrieval.framework + if retrieval.provider: + metric_attrs[GenAI.GEN_AI_PROVIDER_NAME] = retrieval.provider + # Add agent context if available + if retrieval.agent_name: + metric_attrs[GenAI.GEN_AI_AGENT_NAME] = retrieval.agent_name + if retrieval.agent_id: + metric_attrs[GenAI.GEN_AI_AGENT_ID] = retrieval.agent_id + # Add error type if present + if error is not None and getattr(error, "type", None) is not None: + metric_attrs[ErrorAttributes.ERROR_TYPE] = error.type.__qualname__ + + context = None + span = getattr(retrieval, "span", None) + if span is not None: + try: + context = trace.set_span_in_context(span) + except (ValueError, RuntimeError): # pragma: no cover - defensive + context = None + + self._retrieval_duration_histogram.record( + duration, attributes=metric_attrs, context=context + ) diff --git a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/emitters/span.py b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/emitters/span.py index 5d1eab0..d69a9b3 100644 --- a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/emitters/span.py +++ b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/emitters/span.py @@ -26,6 +26,9 @@ GEN_AI_OUTPUT_MESSAGES, GEN_AI_PROVIDER_NAME, GEN_AI_REQUEST_ENCODING_FORMATS, + GEN_AI_RETRIEVAL_DOCUMENTS_RETRIEVED, + GEN_AI_RETRIEVAL_QUERY_TEXT, + GEN_AI_RETRIEVAL_TOP_K, GEN_AI_STEP_ASSIGNED_AGENT, GEN_AI_STEP_NAME, GEN_AI_STEP_OBJECTIVE, @@ -47,6 +50,7 @@ EmbeddingInvocation, Error, LLMInvocation, + RetrievalInvocation, Step, ToolCall, Workflow, @@ -201,9 +205,10 @@ def _apply_start_attrs(self, invocation: GenAIType): provider = getattr(invocation, "provider", None) if provider: span.set_attribute(GEN_AI_PROVIDER_NAME, provider) - # framework (named field) - if isinstance(invocation, LLMInvocation) and invocation.framework: - span.set_attribute("gen_ai.framework", invocation.framework) + # framework (named field) - applies to all invocation types + framework = getattr(invocation, "framework", None) + if framework: + span.set_attribute("gen_ai.framework", framework) # function definitions (semantic conv derived from structured list) if isinstance(invocation, LLMInvocation): _apply_function_definitions(span, invocation.request_functions) @@ -302,6 +307,8 @@ def on_start( self._apply_start_attrs(invocation) elif isinstance(invocation, EmbeddingInvocation): self._start_embedding(invocation) + elif isinstance(invocation, RetrievalInvocation): + self._start_retrieval(invocation) else: # Use operation field for span name (defaults to "chat") operation = getattr(invocation, "operation", "chat") @@ -335,6 +342,8 @@ def on_end(self, invocation: LLMInvocation | EmbeddingInvocation) -> None: self._finish_step(invocation) elif isinstance(invocation, EmbeddingInvocation): self._finish_embedding(invocation) + elif isinstance(invocation, RetrievalInvocation): + self._finish_retrieval(invocation) else: span = getattr(invocation, "span", None) if span is None: @@ -359,6 +368,8 @@ def on_error( self._error_step(error, invocation) elif isinstance(invocation, EmbeddingInvocation): self._error_embedding(error, invocation) + elif isinstance(invocation, RetrievalInvocation): + self._error_retrieval(error, invocation) else: span = getattr(invocation, "span", None) if span is None: @@ -771,3 +782,79 @@ def _error_embedding( token.__exit__(None, None, None) # type: ignore[misc] except Exception: pass + + # ---- Retrieval lifecycle --------------------------------------------- + def _start_retrieval(self, retrieval: RetrievalInvocation) -> None: + """Start a retrieval span.""" + span_name = f"{retrieval.operation_name}" + if retrieval.provider: + span_name = f"{retrieval.operation_name} {retrieval.provider}" + parent_span = getattr(retrieval, "parent_span", None) + parent_ctx = ( + trace.set_span_in_context(parent_span) + if parent_span is not None + else None + ) + cm = self._tracer.start_as_current_span( + span_name, + kind=SpanKind.CLIENT, + end_on_exit=False, + context=parent_ctx, + ) + span = cm.__enter__() + self._attach_span(retrieval, span, cm) + self._apply_start_attrs(retrieval) + + # Set retrieval-specific start attributes + if retrieval.server_address: + span.set_attribute(SERVER_ADDRESS, retrieval.server_address) + if retrieval.server_port: + span.set_attribute(SERVER_PORT, retrieval.server_port) + if retrieval.top_k is not None: + span.set_attribute(GEN_AI_RETRIEVAL_TOP_K, retrieval.top_k) + if self._capture_content and retrieval.query: + span.set_attribute(GEN_AI_RETRIEVAL_QUERY_TEXT, retrieval.query) + + def _finish_retrieval(self, retrieval: RetrievalInvocation) -> None: + """Finish a retrieval span.""" + span = retrieval.span + if span is None: + return + # Apply finish-time semantic conventions + if retrieval.documents_retrieved is not None: + span.set_attribute( + GEN_AI_RETRIEVAL_DOCUMENTS_RETRIEVED, + retrieval.documents_retrieved, + ) + token = retrieval.context_token + if token is not None and hasattr(token, "__exit__"): + try: + token.__exit__(None, None, None) # type: ignore[misc] + except Exception: + pass + span.end() + + def _error_retrieval( + self, error: Error, retrieval: RetrievalInvocation + ) -> None: + """Fail a retrieval span with error status.""" + span = retrieval.span + if span is None: + return + span.set_status(Status(StatusCode.ERROR, error.message)) + if span.is_recording(): + span.set_attribute( + ErrorAttributes.ERROR_TYPE, error.type.__qualname__ + ) + # Set error type from invocation if available + if retrieval.error_type: + span.set_attribute( + ErrorAttributes.ERROR_TYPE, retrieval.error_type + ) + token = retrieval.context_token + if token is not None and hasattr(token, "__exit__"): + try: + token.__exit__(None, None, None) # type: ignore[misc] + except Exception: + pass + span.end() diff --git a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/handler.py b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/handler.py index b71b1ac..544593e 100644 --- a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/handler.py +++ b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/handler.py @@ -91,6 +91,7 @@ def genai_debug_log(*_args: Any, **_kwargs: Any) -> None: # type: ignore EvaluationResult, GenAI, LLMInvocation, + RetrievalInvocation, Step, ToolCall, Workflow, @@ -475,6 +476,70 @@ def fail_embedding( pass return invocation + def start_retrieval( + self, invocation: RetrievalInvocation + ) -> RetrievalInvocation: + """Start a retrieval invocation and create a pending span entry.""" + self._refresh_capture_content() + if ( + not invocation.agent_name or not invocation.agent_id + ) and self._agent_context_stack: + top_name, top_id = self._agent_context_stack[-1] + if not invocation.agent_name: + invocation.agent_name = top_name + if not invocation.agent_id: + invocation.agent_id = top_id + invocation.start_time = time.time() + self._emitter.on_start(invocation) + span = getattr(invocation, "span", None) + if span is not None: + self._span_registry[str(invocation.run_id)] = span + self._entity_registry[str(invocation.run_id)] = invocation + return invocation + + def stop_retrieval( + self, invocation: RetrievalInvocation + ) -> RetrievalInvocation: + """Finalize a retrieval invocation successfully and end its span.""" + invocation.end_time = time.time() + + # Determine if this invocation should be sampled for evaluation + invocation.sample_for_evaluation = self._should_sample_for_evaluation( + invocation.trace_id + ) + + self._emitter.on_end(invocation) + self._notify_completion(invocation) + self._entity_registry.pop(str(invocation.run_id), None) + # Force flush metrics if a custom provider with force_flush is present + if ( + hasattr(self, "_meter_provider") + and self._meter_provider is not None + ): + try: # pragma: no cover + self._meter_provider.force_flush() # type: ignore[attr-defined] + except Exception: + pass + return invocation + + def fail_retrieval( + self, invocation: RetrievalInvocation, error: Error + ) -> RetrievalInvocation: + """Fail a retrieval invocation and end its span with error status.""" + invocation.end_time = time.time() + self._emitter.on_error(error, invocation) + self._notify_completion(invocation) + self._entity_registry.pop(str(invocation.run_id), None) + if ( + hasattr(self, "_meter_provider") + and self._meter_provider is not None + ): + try: # pragma: no cover + self._meter_provider.force_flush() # type: ignore[attr-defined] + except Exception: + pass + return invocation + # ToolCall lifecycle -------------------------------------------------- def start_tool_call(self, invocation: ToolCall) -> ToolCall: """Start a tool call invocation and create a pending span entry.""" @@ -880,6 +945,8 @@ def start(self, obj: Any) -> Any: return self.start_llm(obj) if isinstance(obj, EmbeddingInvocation): return self.start_embedding(obj) + if isinstance(obj, RetrievalInvocation): + return self.start_retrieval(obj) if isinstance(obj, ToolCall): return self.start_tool_call(obj) return obj @@ -960,6 +1027,8 @@ def finish(self, obj: Any) -> Any: return self.stop_llm(obj) if isinstance(obj, EmbeddingInvocation): return self.stop_embedding(obj) + if isinstance(obj, RetrievalInvocation): + return self.stop_retrieval(obj) if isinstance(obj, ToolCall): return self.stop_tool_call(obj) return obj @@ -976,6 +1045,8 @@ def fail(self, obj: Any, error: Error) -> Any: return self.fail_llm(obj, error) if isinstance(obj, EmbeddingInvocation): return self.fail_embedding(obj, error) + if isinstance(obj, RetrievalInvocation): + return self.fail_retrieval(obj, error) if isinstance(obj, ToolCall): return self.fail_tool_call(obj, error) return obj diff --git a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/instruments.py b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/instruments.py index fd7381c..9d88d62 100644 --- a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/instruments.py +++ b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/instruments.py @@ -42,3 +42,8 @@ def __init__(self, meter: Meter): unit="s", description="Duration of agent operations", ) + self.retrieval_duration_histogram: Histogram = meter.create_histogram( + name="gen_ai.retrieval.duration", + unit="s", + description="Duration of retrieval operations", + ) diff --git a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/types.py b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/types.py index 12424b7..8e6b8a9 100644 --- a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/types.py +++ b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/types.py @@ -318,6 +318,51 @@ class EmbeddingInvocation(GenAI): error_type: Optional[str] = None +@dataclass +class RetrievalInvocation(GenAI): + """Represents a single retrieval/search invocation.""" + + # Required attribute + operation_name: str = field( + default="retrieval", + metadata={"semconv": GenAIAttributes.GEN_AI_OPERATION_NAME}, + ) + + # Recommended attributes + retriever_type: Optional[str] = field( + default=None, + metadata={"semconv": "gen_ai.retrieval.type"}, + ) + request_model: Optional[str] = field( + default=None, + metadata={"semconv": GenAIAttributes.GEN_AI_REQUEST_MODEL}, + ) + query: Optional[str] = field( + default=None, + metadata={"semconv": "gen_ai.retrieval.query.text"}, + ) + top_k: Optional[int] = field( + default=None, + metadata={"semconv": "gen_ai.retrieval.top_k"}, + ) + documents_retrieved: Optional[int] = field( + default=None, + metadata={"semconv": "gen_ai.retrieval.documents_retrieved"}, + ) + + # Opt-in attribute + results: list[dict[str, Any]] = field( + default_factory=list, + metadata={"semconv": "gen_ai.retrieval.documents"}, + ) + + # Additional utility fields (not in semantic conventions) + query_vector: Optional[list[float]] = None + server_port: Optional[int] = None + server_address: Optional[str] = None + error_type: Optional[str] = None + + @dataclass class Workflow(GenAI): """Represents a workflow orchestrating multiple agents and steps. @@ -429,6 +474,7 @@ class Step(GenAI): "GenAI", "LLMInvocation", "EmbeddingInvocation", + "RetrievalInvocation", "Error", "EvaluationResult", # agentic AI types diff --git a/util/opentelemetry-util-genai/tests/test_retrieval_invocation.py b/util/opentelemetry-util-genai/tests/test_retrieval_invocation.py new file mode 100644 index 0000000..96341b0 --- /dev/null +++ b/util/opentelemetry-util-genai/tests/test_retrieval_invocation.py @@ -0,0 +1,483 @@ +"""Tests for RetrievalInvocation lifecycle and telemetry.""" + +import pytest + +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, +) +from opentelemetry.semconv._incubating.attributes import ( + gen_ai_attributes as GenAI, +) +from opentelemetry.util.genai.attributes import ( + GEN_AI_RETRIEVAL_DOCUMENTS_RETRIEVED, + GEN_AI_RETRIEVAL_TOP_K, + GEN_AI_RETRIEVAL_TYPE, +) +from opentelemetry.util.genai.handler import get_telemetry_handler +from opentelemetry.util.genai.types import Error, RetrievalInvocation + + +def test_retrieval_invocation_basic_lifecycle(): + """Test basic start/stop lifecycle for retrieval invocation.""" + handler = get_telemetry_handler() + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="test query", + top_k=5, + retriever_type="vector_store", + provider="pinecone", + ) + + # Start should assign span + result = handler.start_retrieval(retrieval) + assert result is retrieval + assert retrieval.span is not None + assert retrieval.start_time is not None + + # Stop should set end_time and end span + retrieval.documents_retrieved = 5 + handler.stop_retrieval(retrieval) + assert retrieval.end_time is not None + assert retrieval.end_time >= retrieval.start_time + + +def test_retrieval_invocation_with_error(): + """Test error handling for retrieval invocation.""" + handler = get_telemetry_handler() + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="failing query", + top_k=10, + retriever_type="vector_store", + provider="chroma", + ) + + handler.start_retrieval(retrieval) + assert retrieval.span is not None + + # Fail the retrieval + error = Error(message="Connection timeout", type=TimeoutError) + handler.fail_retrieval(retrieval, error) + assert retrieval.end_time is not None + + +def test_retrieval_invocation_creates_span_with_attributes(): + """Test that retrieval invocation creates span with correct attributes.""" + # Set up in-memory span exporter + span_exporter = InMemorySpanExporter() + tracer_provider = TracerProvider() + tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter)) + + handler = get_telemetry_handler() + span_emitters = list(handler._emitter.emitters_for("span")) + if span_emitters: + span_emitters[0]._tracer = tracer_provider.get_tracer(__name__) + + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="OpenTelemetry documentation", + top_k=7, + retriever_type="semantic_search", + provider="weaviate", + framework="langchain", + ) + + handler.start_retrieval(retrieval) + retrieval.documents_retrieved = 7 + handler.stop_retrieval(retrieval) + + # Get exported spans + spans = span_exporter.get_finished_spans() + assert len(spans) == 1 + + span = spans[0] + attrs = span.attributes + + # Check required attributes + assert attrs[GenAI.GEN_AI_OPERATION_NAME] == "retrieval" + + # Check recommended attributes + assert attrs[GEN_AI_RETRIEVAL_TYPE] == "semantic_search" + assert attrs[GEN_AI_RETRIEVAL_TOP_K] == 7 + assert attrs[GEN_AI_RETRIEVAL_DOCUMENTS_RETRIEVED] == 7 + + # Check provider and framework + assert attrs[GenAI.GEN_AI_PROVIDER_NAME] == "weaviate" + assert attrs.get("gen_ai.framework") == "langchain" + + +def test_retrieval_invocation_with_vector_search(): + """Test retrieval with query vector.""" + handler = get_telemetry_handler() + query_vector = [0.1, 0.2, 0.3] * 256 # 768-dim vector + + retrieval = RetrievalInvocation( + operation_name="retrieval", + query_vector=query_vector, + top_k=10, + retriever_type="vector_store", + provider="pinecone", + ) + + handler.start_retrieval(retrieval) + assert retrieval.span is not None + assert retrieval.query_vector == query_vector + + retrieval.documents_retrieved = 10 + handler.stop_retrieval(retrieval) + assert retrieval.end_time is not None + + +def test_retrieval_invocation_with_hybrid_search(): + """Test retrieval with both text query and vector.""" + handler = get_telemetry_handler() + + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="machine learning", + query_vector=[0.5] * 384, + top_k=15, + retriever_type="hybrid_search", + provider="elasticsearch", + ) + + handler.start_retrieval(retrieval) + assert retrieval.span is not None + assert retrieval.query == "machine learning" + assert len(retrieval.query_vector) == 384 + + retrieval.documents_retrieved = 15 + handler.stop_retrieval(retrieval) + + +def test_retrieval_invocation_with_agent_context(): + """Test retrieval within agent context.""" + handler = get_telemetry_handler() + + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="product information", + top_k=5, + retriever_type="vector_store", + provider="milvus", + agent_name="product_assistant", + agent_id="agent-123", + ) + + handler.start_retrieval(retrieval) + assert retrieval.span is not None + assert retrieval.agent_name == "product_assistant" + assert retrieval.agent_id == "agent-123" + + retrieval.documents_retrieved = 5 + handler.stop_retrieval(retrieval) + + +def test_retrieval_invocation_with_custom_attributes(): + """Test retrieval with custom attributes.""" + handler = get_telemetry_handler() + + custom_attrs = { + "collection_name": "docs", + "user_id": "user-456", + "session_id": "session-789", + } + + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="custom search", + top_k=3, + retriever_type="vector_store", + provider="qdrant", + attributes=custom_attrs, + ) + + handler.start_retrieval(retrieval) + assert retrieval.span is not None + assert retrieval.attributes == custom_attrs + + retrieval.documents_retrieved = 3 + handler.stop_retrieval(retrieval) + + +def test_retrieval_invocation_with_results(): + """Test retrieval with result documents.""" + handler = get_telemetry_handler() + + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="test", + top_k=2, + retriever_type="vector_store", + provider="pinecone", + ) + + handler.start_retrieval(retrieval) + + # Populate results + retrieval.documents_retrieved = 2 + retrieval.results = [ + {"id": "doc1", "score": 0.95, "content": "First document"}, + {"id": "doc2", "score": 0.87, "content": "Second document"}, + ] + + handler.stop_retrieval(retrieval) + assert len(retrieval.results) == 2 + assert retrieval.results[0]["score"] == 0.95 + + +def test_retrieval_invocation_semantic_convention_attributes(): + """Test that semantic convention attributes are properly extracted.""" + retrieval = RetrievalInvocation( + operation_name="retrieval", + request_model="text-embedding-ada-002", + query="semantic test", + top_k=5, + retriever_type="vector_store", + provider="test_provider", + ) + + semconv_attrs = retrieval.semantic_convention_attributes() + + # Check that semantic convention attributes are present + assert GenAI.GEN_AI_OPERATION_NAME in semconv_attrs + assert semconv_attrs[GenAI.GEN_AI_OPERATION_NAME] == "retrieval" + assert GenAI.GEN_AI_REQUEST_MODEL in semconv_attrs + assert ( + semconv_attrs[GenAI.GEN_AI_REQUEST_MODEL] == "text-embedding-ada-002" + ) + assert "gen_ai.retrieval.type" in semconv_attrs + assert semconv_attrs["gen_ai.retrieval.type"] == "vector_store" + assert "gen_ai.retrieval.query.text" in semconv_attrs + assert semconv_attrs["gen_ai.retrieval.query.text"] == "semantic test" + assert "gen_ai.retrieval.top_k" in semconv_attrs + assert semconv_attrs["gen_ai.retrieval.top_k"] == 5 + + +def test_retrieval_invocation_minimal_required_fields(): + """Test retrieval with only required fields.""" + handler = get_telemetry_handler() + + # Only operation_name is required + retrieval = RetrievalInvocation( + operation_name="retrieval", + ) + + handler.start_retrieval(retrieval) + assert retrieval.span is not None + + handler.stop_retrieval(retrieval) + assert retrieval.end_time is not None + + +def test_retrieval_invocation_multiple_sequential(): + """Test multiple sequential retrieval invocations.""" + handler = get_telemetry_handler() + + queries = ["query1", "query2", "query3"] + retrievals = [] + + for query in queries: + retrieval = RetrievalInvocation( + operation_name="retrieval", + query=query, + top_k=5, + retriever_type="vector_store", + provider="pinecone", + ) + handler.start_retrieval(retrieval) + retrieval.documents_retrieved = 5 + handler.stop_retrieval(retrieval) + retrievals.append(retrieval) + + # All should have completed successfully + assert len(retrievals) == 3 + for retrieval in retrievals: + assert retrieval.span is not None + assert retrieval.end_time is not None + + +def test_generic_start_finish_for_retrieval(): + """Test generic handler methods route to retrieval lifecycle.""" + handler = get_telemetry_handler() + + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="generic test", + top_k=5, + retriever_type="vector_store", + provider="test", + ) + + # Generic methods should route to retrieval lifecycle + handler.start(retrieval) + assert retrieval.span is not None + + handler.finish(retrieval) + assert retrieval.end_time is not None + + # Test fail path + retrieval2 = RetrievalInvocation( + operation_name="retrieval", + query="fail test", + top_k=3, + ) + handler.start(retrieval2) + handler.fail(retrieval2, Error(message="test error", type=RuntimeError)) + assert retrieval2.end_time is not None + + +def test_retrieval_invocation_span_name(): + """Test that span name is correctly formatted.""" + span_exporter = InMemorySpanExporter() + tracer_provider = TracerProvider() + tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter)) + + handler = get_telemetry_handler() + span_emitters = list(handler._emitter.emitters_for("span")) + if span_emitters: + span_emitters[0]._tracer = tracer_provider.get_tracer(__name__) + + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="test", + provider="pinecone", + ) + + handler.start_retrieval(retrieval) + handler.stop_retrieval(retrieval) + + spans = span_exporter.get_finished_spans() + assert len(spans) == 1 + # Span name should be "retrieval pinecone" + assert spans[0].name == "retrieval pinecone" + + +def test_retrieval_invocation_without_provider(): + """Test retrieval without provider (span name should be just operation).""" + span_exporter = InMemorySpanExporter() + tracer_provider = TracerProvider() + tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter)) + + handler = get_telemetry_handler() + span_emitters = list(handler._emitter.emitters_for("span")) + if span_emitters: + span_emitters[0]._tracer = tracer_provider.get_tracer(__name__) + + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="test", + ) + + handler.start_retrieval(retrieval) + handler.stop_retrieval(retrieval) + + spans = span_exporter.get_finished_spans() + assert len(spans) == 1 + # Span name should be just "retrieval" + assert spans[0].name == "retrieval" + + +@pytest.mark.parametrize( + "retriever_type,provider", + [ + ("vector_store", "pinecone"), + ("semantic_search", "weaviate"), + ("hybrid_search", "elasticsearch"), + ("keyword_search", "opensearch"), + ], +) +def test_retrieval_invocation_different_types(retriever_type, provider): + """Test retrieval with different retriever types and providers.""" + handler = get_telemetry_handler() + + retrieval = RetrievalInvocation( + operation_name="retrieval", + query=f"test {retriever_type}", + top_k=5, + retriever_type=retriever_type, + provider=provider, + ) + + handler.start_retrieval(retrieval) + assert retrieval.span is not None + assert retrieval.retriever_type == retriever_type + assert retrieval.provider == provider + + retrieval.documents_retrieved = 5 + handler.stop_retrieval(retrieval) + assert retrieval.end_time is not None + + +def test_retrieval_invocation_with_server_and_model_attributes(): + """Test retrieval with server address, port, and model attributes.""" + span_exporter = InMemorySpanExporter() + tracer_provider = TracerProvider() + tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter)) + + handler = get_telemetry_handler() + span_emitters = list(handler._emitter.emitters_for("span")) + if span_emitters: + span_emitters[0]._tracer = tracer_provider.get_tracer(__name__) + + retrieval = RetrievalInvocation( + operation_name="retrieval", + request_model="text-embedding-ada-002", + query="test query", + top_k=5, + retriever_type="vector_store", + provider="weaviate", + server_address="localhost", + server_port=8080, + ) + + handler.start_retrieval(retrieval) + retrieval.documents_retrieved = 5 + handler.stop_retrieval(retrieval) + + spans = span_exporter.get_finished_spans() + assert len(spans) == 1 + + span = spans[0] + attrs = span.attributes + + # Check new attributes + assert attrs[GenAI.GEN_AI_REQUEST_MODEL] == "text-embedding-ada-002" + assert attrs["server.address"] == "localhost" + assert attrs["server.port"] == 8080 + + +def test_retrieval_invocation_with_error_type(): + """Test retrieval with error_type attribute.""" + span_exporter = InMemorySpanExporter() + tracer_provider = TracerProvider() + tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter)) + + handler = get_telemetry_handler() + span_emitters = list(handler._emitter.emitters_for("span")) + if span_emitters: + span_emitters[0]._tracer = tracer_provider.get_tracer(__name__) + + retrieval = RetrievalInvocation( + operation_name="retrieval", + query="test query", + top_k=5, + retriever_type="vector_store", + provider="pinecone", + error_type="ConnectionError", + ) + + handler.start_retrieval(retrieval) + error = Error(message="Connection failed", type=ConnectionError) + handler.fail_retrieval(retrieval, error) + + spans = span_exporter.get_finished_spans() + assert len(spans) == 1 + + span = spans[0] + attrs = span.attributes + + # Check error type attribute (should be set from invocation.error_type) + assert attrs["error.type"] == "ConnectionError"