Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,14 @@
_COMPONENT_OUTPUT_KEY = "haystack.component.output"
_COMPONENT_INPUT_KEY = "haystack.component.input"

# Context var used to keep track of tracing-related info.
# This is mainly useful for parent spans.
# External session metadata for trace correlation (Haystack system)
# Stores trace_id, user_id, session_id, tags, version for root trace creation
tracing_context_var: ContextVar[Dict[Any, Any]] = ContextVar("tracing_context")

# Internal span execution hierarchy for our tracer
# Manages parent-child relationships and prevents cross-request span interleaving
span_stack_var: ContextVar[Optional[List["LangfuseSpan"]]] = ContextVar("span_stack", default=None)


class LangfuseSpan(Span):
"""
Expand Down Expand Up @@ -265,6 +269,7 @@ def create_span(self, context: SpanContext) -> LangfuseSpan:
)
raise RuntimeError(message)

# Get external tracing context for root trace creation (correlation metadata)
tracing_ctx = tracing_context_var.get({})
if not context.parent_span:
# Create a new trace when there's no parent span
Expand Down Expand Up @@ -360,6 +365,7 @@ def __init__(
"before importing Haystack."
)
self._tracer = tracer
# Keep _context as deprecated shim to avoid AttributeError if anyone uses it
self._context: List[LangfuseSpan] = []
self._name = name
self._public = public
Expand Down Expand Up @@ -391,7 +397,12 @@ def trace(
# Create span using the handler
span = self._span_handler.create_span(span_context)

self._context.append(span)
# Build new span hierarchy: copy existing stack, add new span, save for restoration
prev_stack = span_stack_var.get()
new_stack = (prev_stack or []).copy()
new_stack.append(span)
token = span_stack_var.set(new_stack)

span.set_tags(tags)

try:
Expand All @@ -414,10 +425,8 @@ def trace(
cleanup_error=cleanup_error,
)
finally:
# CRITICAL: Always pop context to prevent corruption
# This is especially important for nested pipeline scenarios
if self._context and self._context[-1] == span:
self._context.pop()
# Restore previous span stack using saved token - ensures proper cleanup
span_stack_var.reset(token)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we safely assume this line will never throw an error?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

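No mention of it in the Python docs anywhere. Per PEP 567, `ContextVar.reset` only raises if the token was already used (`RuntimeError`) or was created by a different `ContextVar` or in a different `Context` (`ValueError`). Here the token is created and reset once within the same `trace` call, so it seems like it can't happen.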


if self.enforce_flush:
self.flush()
Expand All @@ -431,7 +440,9 @@ def current_span(self) -> Optional[Span]:

:return: The current span if available, else None.
"""
return self._context[-1] if self._context else None
# Get top of span stack (most recent span) from context-local storage
stack = span_stack_var.get()
return stack[-1] if stack else None

def get_trace_url(self) -> str:
"""
Expand Down
127 changes: 105 additions & 22 deletions integrations/langfuse/tests/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,57 @@
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
import datetime
import logging
import sys
import json
from typing import Optional
from unittest.mock import MagicMock, Mock, patch

import pytest
from haystack import Pipeline, component
from haystack.dataclasses import ChatMessage, ToolCall

from haystack_integrations.tracing.langfuse.tracer import (
_COMPONENT_OUTPUT_KEY, DefaultSpanHandler, LangfuseSpan, LangfuseTracer,
SpanContext)
_COMPONENT_OUTPUT_KEY,
DefaultSpanHandler,
LangfuseSpan,
LangfuseTracer,
SpanContext,
)
from haystack_integrations.components.connectors.langfuse import LangfuseConnector


class MockSpan:
def __init__(self):
def __init__(self, name="mock_span"):
self._data = {}
self._span = self
self.operation_name = "operation_name"
self.operation_name = name
self._name = name

def raw_span(self):
return self

def span(self, name=None):
# assert correct operation name passed to the span
assert name == "operation_name"
return self
# Return a new mock span for child spans
return MockSpan(name=name or "child_span")

def update(self, **kwargs):
self._data.update(kwargs)

def generation(self, name=None):
return self
# Return a new mock span for generation spans
return MockSpan(name=name or "generation_span")

def end(self):
pass


class MockTracer:

def trace(self, name, **kwargs):
return MockSpan()
# Return a unique mock span for each trace call
return MockSpan(name=name)

def flush(self):
pass
Expand All @@ -59,7 +68,6 @@ def handle(self, span: LangfuseSpan, component_type: Optional[str]) -> None:


class TestLangfuseSpan:

# LangfuseSpan can be initialized with a span object
def test_initialized_with_span_object(self):
mock_span = Mock()
Expand Down Expand Up @@ -232,7 +240,8 @@ def test_initialization(self):
langfuse_instance = Mock()
tracer = LangfuseTracer(tracer=langfuse_instance, name="Haystack", public=True)
assert tracer._tracer == langfuse_instance
assert tracer._context == []
# Check behavioral state instead of internal _context list
assert tracer.current_span() is None
assert tracer._name == "Haystack"
assert tracer._public

Expand All @@ -255,13 +264,14 @@ def test_create_new_span(self):

# check that the trace method is called on the tracer instance with the provided operation name and tags
with tracer.trace("operation_name", tags={"tag1": "value1", "tag2": "value2"}) as span:
assert len(tracer._context) == 1, "The trace span should have been added to the the root context span"
# Check that there is a current active span during tracing
assert tracer.current_span() is not None, "There should be an active span during tracing"
assert tracer.current_span() == span, "The current span should be the active span"
assert span.raw_span().operation_name == "operation_name"
assert span.raw_span().metadata == {"tag1": "value1", "tag2": "value2"}

assert (
len(tracer._context) == 0
), "The trace span should have been popped, and the root span is closed as well"
# Check that the span is cleaned up after tracing
assert tracer.current_span() is None, "There should be no active span after tracing completes"

# check that update method is called on the span instance with the provided key value pairs
def test_update_span_with_pipeline_input_output_data(self):
Expand Down Expand Up @@ -324,12 +334,12 @@ def test_handle_tool_invoker(self):
assert mock_span.update.call_count >= 1
name_update_call = None
for call in mock_span.update.call_args_list:
if 'name' in call[1]:
if "name" in call[1]:
name_update_call = call
break

assert name_update_call is not None, "No call to update the span name was made"
updated_name = name_update_call[1]['name']
updated_name = name_update_call[1]["name"]

# verify the format of the updated span name to be: `original_component_name - [list_of_tool_names]`
assert updated_name != "tool_invoker", f"Expected 'tool_invoker` to be upddated with tool names"
Expand Down Expand Up @@ -369,8 +379,7 @@ def test_update_span_flush_disable(self, monkeypatch):
monkeypatch.setenv("HAYSTACK_LANGFUSE_ENFORCE_FLUSH", "false")
tracer_mock = Mock()

from haystack_integrations.tracing.langfuse.tracer import \
LangfuseTracer
from haystack_integrations.tracing.langfuse.tracer import LangfuseTracer

tracer = LangfuseTracer(tracer=tracer_mock, name="Haystack", public=False)
with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.input_data": "hello"}) as span:
Expand All @@ -385,11 +394,12 @@ def test_context_is_empty_after_tracing(self):
with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.input_data": "hello"}) as span:
pass

assert tracer._context == []
# Check behavioral state instead of internal _context list
assert tracer.current_span() is None

def test_init_with_tracing_disabled(self, monkeypatch, caplog):
# Clear haystack modules because ProxyTracer is initialized whenever haystack is imported
modules_to_clear = [name for name in sys.modules if name.startswith('haystack')]
modules_to_clear = [name for name in sys.modules if name.startswith("haystack")]
for name in modules_to_clear:
sys.modules.pop(name, None)

Expand All @@ -400,3 +410,76 @@ def test_init_with_tracing_disabled(self, monkeypatch, caplog):

LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False)
assert "tracing is disabled" in caplog.text

def test_async_concurrency_span_isolation(self):
    """
    Check that two concurrently running async traces each see only their own spans.

    Two coroutines open an outer and an inner trace on the same tracer and
    suspend in between, so their executions interleave. Each coroutine
    records what `current_span()` reports at four points (inside outer,
    inside inner, after inner, after outer); the recorded values must follow
    that coroutine's own nesting and never expose the other coroutine's spans.
    """
    tracer = LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False)

    # One log per simulated request; entries are (label, opened_span, current_span).
    logs = {"task1": [], "task2": []}

    async def simulate_request(task_id: str, log: list):
        """Open nested traces, suspending inside each one, and record the current span."""
        with tracer.trace(f"outer_operation_{task_id}") as outer:
            log.append(("outer", outer, tracer.current_span()))

            # Suspend so the other coroutine gets a chance to run in between.
            await asyncio.sleep(0.01)

            with tracer.trace(f"inner_operation_{task_id}") as inner:
                log.append(("inner", inner, tracer.current_span()))

                # Suspend again while the inner span is open.
                await asyncio.sleep(0.01)

                # The innermost span of this coroutine must still be current.
                assert tracer.current_span() == inner

            # Leaving the inner trace makes the outer span current again.
            log.append(("after_inner", None, tracer.current_span()))
            assert tracer.current_span() == outer

        # Leaving the outer trace leaves no active span.
        log.append(("after_outer", None, tracer.current_span()))
        assert tracer.current_span() is None

    async def run_both():
        """Run both simulated requests concurrently."""
        await asyncio.gather(*(simulate_request(task_id, log) for task_id, log in logs.items()))

    asyncio.run(run_both())

    first, second = logs["task1"], logs["task2"]

    # Both coroutines ran to completion and recorded all four checkpoints.
    assert len(first) == 4
    assert len(second) == 4

    # The spans opened by one coroutine are not the spans opened by the other.
    assert first[0][1] != second[0][1]
    assert first[1][1] != second[1][1]

    # Within each coroutine the observed sequence is: outer -> inner -> outer -> None.
    for log in (first, second):
        outer_span = log[0][1]
        inner_span = log[1][1]
        assert log[0][2] == outer_span  # current_span during outer
        assert log[1][2] == inner_span  # current_span during inner
        assert log[2][2] == outer_span  # current_span after inner
        assert log[3][2] is None  # current_span after outer

Loading