
fix(llmobs): avoid raising errors during llmobs integration span processing #10713

Merged: 10 commits, Sep 23, 2024
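Every integration call site below now funnels through a single `llmobs_set_tags(span, args, kwargs, response=None, operation="")` entry point, and the per-call-site `is_pc_sampled_llmobs()` guards disappear. Below is a minimal sketch of how the shared base class could absorb both the gating and the error handling the PR title promises; only the method signature is confirmed by the call sites in this diff, while the class name, the `llmobs_enabled` flag, and the try/except body are assumptions.

```python
import logging

log = logging.getLogger(__name__)


class BaseLLMIntegration:
    """Minimal sketch of the shared integration base class (name assumed)."""

    llmobs_enabled = True  # assumed gating flag; the real check may differ

    def llmobs_set_tags(self, span, args, kwargs, response=None, operation=""):
        # Signature matches the call sites in this diff.
        if not self.llmobs_enabled:
            return
        try:
            self._llmobs_set_tags(span, args, kwargs, response, operation)
        except Exception:
            # The point of the PR: tagging failures are logged, never raised
            # into the traced application.
            log.warning("Error setting LLMObs tags for span %r", span, exc_info=True)

    def _llmobs_set_tags(self, span, args, kwargs, response=None, operation=""):
        # Each concrete integration (openai, anthropic, bedrock, ...) overrides this.
        raise NotImplementedError
```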
8 changes: 3 additions & 5 deletions ddtrace/_trace/trace_handlers.py
```diff
@@ -675,11 +675,10 @@ def _on_botocore_patched_bedrock_api_call_started(ctx, request_params):
 def _on_botocore_patched_bedrock_api_call_exception(ctx, exc_info):
     span = ctx[ctx["call_key"]]
     span.set_exc_info(*exc_info)
-    prompt = ctx["prompt"]
     model_name = ctx["model_name"]
     integration = ctx["bedrock_integration"]
-    if integration.is_pc_sampled_llmobs(span) and "embed" not in model_name:
-        integration.llmobs_set_tags(span, formatted_response=None, prompt=prompt, err=True)
+    if "embed" not in model_name:
+        integration.llmobs_set_tags(span, args=[], kwargs={"prompt": ctx["prompt"]})
     span.finish()
@@ -721,8 +720,7 @@ def _on_botocore_bedrock_process_response(
         span.set_tag_str(
             "bedrock.response.choices.{}.finish_reason".format(i), str(formatted_response["finish_reason"][i])
         )
-    if integration.is_pc_sampled_llmobs(span):
-        integration.llmobs_set_tags(span, formatted_response=formatted_response, prompt=ctx["prompt"])
+    integration.llmobs_set_tags(span, args=[], kwargs={"prompt": ctx["prompt"]}, response=formatted_response)
     span.finish()
```
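On the exception path the old code forwarded `err=True` explicitly; the new call passes only the prompt, so the receiving side can infer failure from the span itself (`set_exc_info` has already run). A hedged sketch of what the bedrock override of the base-class entry point might do with that; the class name and tag keys are illustrative, not ddtrace's real ones.

```python
class BedrockIntegration(BaseLLMIntegration):  # class name assumed
    def _llmobs_set_tags(self, span, args, kwargs, response=None, operation=""):
        # Only the call shape (kwargs={"prompt": ...}; response omitted on error)
        # comes from this diff; the tag keys below are illustrative.
        span.set_tag_str("ml_obs.meta.input.prompt", str(kwargs.get("prompt", "")))
        if response is None or span.error:
            # Exception path: span.set_exc_info() already recorded the failure,
            # so no err=True flag needs to travel through this call.
            return
        for i, text in enumerate(response.get("text", [])):
            span.set_tag_str("ml_obs.meta.output.{}".format(i), str(text))
```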
9 changes: 1 addition & 8 deletions ddtrace/contrib/internal/anthropic/_streaming.py
```diff
@@ -154,16 +154,9 @@ def _process_finished_stream(integration, span, args, kwargs, streamed_chunks):
     # builds the response message given streamed chunks and sets according span tags
     try:
         resp_message = _construct_message(streamed_chunks)
-
         if integration.is_pc_sampled_span(span):
             _tag_streamed_chat_completion_response(integration, span, resp_message)
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                span=span,
-                resp=resp_message,
-                args=args,
-                kwargs=kwargs,
-            )
+        integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=resp_message)
     except Exception:
         log.warning("Error processing streamed completion/chat response.", exc_info=True)
```
6 changes: 2 additions & 4 deletions ddtrace/contrib/internal/anthropic/patch.py
```diff
@@ -105,8 +105,7 @@ def traced_chat_model_generate(anthropic, pin, func, instance, args, kwargs):
     finally:
         # we don't want to finish the span if it is a stream as it will get finished once the iterator is exhausted
         if span.error or not stream:
-            if integration.is_pc_sampled_llmobs(span):
-                integration.llmobs_set_tags(span=span, resp=chat_completions, args=args, kwargs=kwargs)
+            integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=chat_completions)
             span.finish()
     return chat_completions
@@ -178,8 +177,7 @@ async def traced_async_chat_model_generate(anthropic, pin, func, instance, args, kwargs):
     finally:
         # we don't want to finish the span if it is a stream as it will get finished once the iterator is exhausted
         if span.error or not stream:
-            if integration.is_pc_sampled_llmobs(span):
-                integration.llmobs_set_tags(span=span, resp=chat_completions, args=args, kwargs=kwargs)
+            integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=chat_completions)
             span.finish()
     return chat_completions
```
22 changes: 14 additions & 8 deletions ddtrace/contrib/internal/google_generativeai/_utils.py
```diff
@@ -30,10 +30,13 @@ def __iter__(self):
         else:
             tag_response(self._dd_span, self.__wrapped__, self._dd_integration, self._model_instance)
         finally:
-            if self._dd_integration.is_pc_sampled_llmobs(self._dd_span):
-                self._dd_integration.llmobs_set_tags(
-                    self._dd_span, self._args, self._kwargs, self._model_instance, self.__wrapped__
-                )
+            self._kwargs["instance"] = self._model_instance
+            self._dd_integration.llmobs_set_tags(
+                self._dd_span,
+                args=self._args,
+                kwargs=self._kwargs,
+                response=self.__wrapped__,
+            )
             self._dd_span.finish()
@@ -48,10 +51,13 @@ async def __aiter__(self):
         else:
             tag_response(self._dd_span, self.__wrapped__, self._dd_integration, self._model_instance)
         finally:
-            if self._dd_integration.is_pc_sampled_llmobs(self._dd_span):
-                self._dd_integration.llmobs_set_tags(
-                    self._dd_span, self._args, self._kwargs, self._model_instance, self.__wrapped__
-                )
+            self._kwargs["instance"] = self._model_instance
+            self._dd_integration.llmobs_set_tags(
+                self._dd_span,
+                args=self._args,
+                kwargs=self._kwargs,
+                response=self.__wrapped__,
+            )
             self._dd_span.finish()
```
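The google_generativeai call sites can no longer pass the model object positionally, so they tuck it into `kwargs` under `"instance"` before calling the unified entry point. A sketch of how the override on the receiving side might read it back; the `"instance"` key is confirmed by this diff, but the class name, model attribute, and tag key are assumptions.

```python
class GoogleGenerativeAIIntegration(BaseLLMIntegration):  # class name assumed
    def _llmobs_set_tags(self, span, args, kwargs, response=None, operation=""):
        # Recover the GenerativeModel smuggled through kwargs by the caller.
        instance = kwargs.get("instance")
        model_name = getattr(instance, "model_name", "") if instance is not None else ""
        span.set_tag_str("ml_obs.meta.model_name", str(model_name))  # tag key illustrative
```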
8 changes: 4 additions & 4 deletions ddtrace/contrib/internal/google_generativeai/patch.py
```diff
@@ -60,8 +60,8 @@ def traced_generate(genai, pin, func, instance, args, kwargs):
     finally:
         # streamed spans will be finished separately once the stream generator is exhausted
         if span.error or not stream:
-            if integration.is_pc_sampled_llmobs(span):
-                integration.llmobs_set_tags(span, args, kwargs, instance, generations)
+            kwargs["instance"] = instance
+            integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=generations)
             span.finish()
     return generations
@@ -90,8 +90,8 @@ async def traced_agenerate(genai, pin, func, instance, args, kwargs):
     finally:
         # streamed spans will be finished separately once the stream generator is exhausted
         if span.error or not stream:
-            if integration.is_pc_sampled_llmobs(span):
-                integration.llmobs_set_tags(span, args, kwargs, instance, generations)
+            kwargs["instance"] = instance
+            integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=generations)
             span.finish()
     return generations
```
94 changes: 14 additions & 80 deletions ddtrace/contrib/internal/langchain/patch.py
```diff
@@ -244,14 +244,7 @@ def traced_llm_generate(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "llm",
-                span,
-                prompts,
-                completions,
-                error=bool(span.error),
-            )
+        integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=completions, operation="llm")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
         if integration.is_pc_sampled_log(span):
@@ -322,14 +315,7 @@ async def traced_llm_agenerate(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "llm",
-                span,
-                prompts,
-                completions,
-                error=bool(span.error),
-            )
+        integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=completions, operation="llm")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
         if integration.is_pc_sampled_log(span):
@@ -438,14 +424,7 @@ def traced_chat_model_generate(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "chat",
-                span,
-                chat_messages,
-                chat_completions,
-                error=bool(span.error),
-            )
+        integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=chat_completions, operation="chat")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
         if integration.is_pc_sampled_log(span):
@@ -570,14 +549,7 @@ async def traced_chat_model_agenerate(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "chat",
-                span,
-                chat_messages,
-                chat_completions,
-                error=bool(span.error),
-            )
+        integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=chat_completions, operation="chat")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
         if integration.is_pc_sampled_log(span):
@@ -662,14 +634,7 @@ def traced_embedding(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "embedding",
-                span,
-                input_texts,
-                embeddings,
-                error=bool(span.error),
-            )
+        integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=embeddings, operation="embedding")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
         if integration.is_pc_sampled_log(span):
@@ -717,8 +682,7 @@ def traced_chain_call(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags("chain", span, inputs, final_outputs, error=bool(span.error))
+        integration.llmobs_set_tags(span, args=[], kwargs=inputs, response=final_outputs, operation="chain")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
         if integration.is_pc_sampled_log(span):
@@ -774,8 +738,7 @@ async def traced_chain_acall(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags("chain", span, inputs, final_outputs, error=bool(span.error))
+        integration.llmobs_set_tags(span, args=[], kwargs=inputs, response=final_outputs, operation="chain")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
         if integration.is_pc_sampled_log(span):
@@ -847,8 +810,7 @@ def traced_lcel_runnable_sequence(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags("chain", span, inputs, final_output, error=bool(span.error))
+        integration.llmobs_set_tags(span, args=[], kwargs=inputs, response=final_output, operation="chain")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
     return final_output
@@ -894,8 +856,7 @@ async def traced_lcel_runnable_sequence_async(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags("chain", span, inputs, final_output, error=bool(span.error))
+        integration.llmobs_set_tags(span, args=[], kwargs=inputs, response=final_output, operation="chain")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
     return final_output
@@ -953,14 +914,7 @@ def traced_similarity_search(langchain, pin, func, instance, args, kwargs):
         integration.metric(span, "incr", "request.error", 1)
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "retrieval",
-                span,
-                query,
-                documents,
-                error=bool(span.error),
-            )
+        integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=documents, operation="retrieval")
         span.finish()
         integration.metric(span, "dist", "request.duration", span.duration_ns)
         if integration.is_pc_sampled_log(span):
@@ -1024,18 +978,8 @@ def traced_base_tool_invoke(langchain, pin, func, instance, args, kwargs):
         span.set_exc_info(*sys.exc_info())
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "tool",
-                span,
-                {
-                    "input": tool_input,
-                    "config": config if config else {},
-                    "info": tool_info if tool_info else {},
-                },
-                tool_output,
-                error=bool(span.error),
-            )
+        tool_inputs = {"input": tool_input, "config": config or {}, "info": tool_info or {}}
+        integration.llmobs_set_tags(span, args=[], kwargs=tool_inputs, response=tool_output, operation="tool")
         span.finish()
     return tool_output
@@ -1085,18 +1029,8 @@ async def traced_base_tool_ainvoke(langchain, pin, func, instance, args, kwargs):
         span.set_exc_info(*sys.exc_info())
         raise
     finally:
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "tool",
-                span,
-                {
-                    "input": tool_input,
-                    "config": config if config else {},
-                    "info": tool_info if tool_info else {},
-                },
-                tool_output,
-                error=bool(span.error),
-            )
+        tool_inputs = {"input": tool_input, "config": config or {}, "info": tool_info or {}}
+        integration.llmobs_set_tags(span, args=[], kwargs=tool_inputs, response=tool_output, operation="tool")
         span.finish()
     return tool_output
```
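The langchain integration collapses six differently shaped calls into one signature distinguished by `operation`. Note also that the explicit `error=bool(span.error)` argument disappears everywhere: the span already carries its error state, so the callee can read it directly. One plausible shape for the receiving dispatch, assuming a class and handler naming scheme that do not appear in this diff:

```python
class LangChainIntegration(BaseLLMIntegration):  # class and handler names assumed
    # operation values confirmed by the call sites above:
    #   "llm", "chat", "chain", "embedding", "retrieval", "tool"
    def _llmobs_set_tags(self, span, args, kwargs, response=None, operation=""):
        handler = getattr(self, "_handle_" + operation, None)  # e.g. _handle_chain
        if handler is None:
            return  # unknown operation: tag nothing rather than raise
        handler(span, args, kwargs, response)
```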
14 changes: 3 additions & 11 deletions ddtrace/contrib/internal/openai/_endpoint_hooks.py
```diff
@@ -9,7 +9,6 @@
 from ddtrace.contrib.internal.openai.utils import _process_finished_stream
 from ddtrace.contrib.internal.openai.utils import _tag_tool_calls
 from ddtrace.internal.utils.version import parse_version
-from ddtrace.llmobs._constants import SPAN_KIND


 API_VERSION = "v1"
@@ -189,8 +188,6 @@ class _CompletionHook(_BaseCompletionHook):

     def _record_request(self, pin, integration, span, args, kwargs):
         super()._record_request(pin, integration, span, args, kwargs)
-        if integration.is_pc_sampled_llmobs(span):
-            span.set_tag_str(SPAN_KIND, "llm")
         if integration.is_pc_sampled_span(span):
             prompt = kwargs.get("prompt", "")
             if isinstance(prompt, str):
@@ -212,8 +209,7 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error):
             integration.log(
                 span, "info" if error is None else "error", "sampled %s" % self.OPERATION_ID, attrs=attrs_dict
             )
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags("completion", resp, span, kwargs, err=error)
+        integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=resp, operation="completion")
         if not resp:
             return
         for choice in resp.choices:
@@ -247,8 +243,6 @@ class _ChatCompletionHook(_BaseCompletionHook):

     def _record_request(self, pin, integration, span, args, kwargs):
         super()._record_request(pin, integration, span, args, kwargs)
-        if integration.is_pc_sampled_llmobs(span):
-            span.set_tag_str(SPAN_KIND, "llm")
         for idx, m in enumerate(kwargs.get("messages", [])):
             role = getattr(m, "role", "")
             name = getattr(m, "name", "")
@@ -274,8 +268,7 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error):
             integration.log(
                 span, "info" if error is None else "error", "sampled %s" % self.OPERATION_ID, attrs=attrs_dict
             )
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags("chat", resp, span, kwargs, err=error)
+        integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=resp, operation="chat")
         if not resp:
             return
         for choice in resp.choices:
@@ -319,8 +312,7 @@ def _record_request(self, pin, integration, span, args, kwargs):

     def _record_response(self, pin, integration, span, args, kwargs, resp, error):
         resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags("embedding", resp, span, kwargs, err=error)
+        integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=resp, operation="embedding")
         if not resp:
             return
         span.set_metric("openai.response.embeddings_count", len(resp.data))
```
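The `SPAN_KIND` import disappears because `_record_request` no longer tags the span kind at request time. A hedged guess at where that moved, assuming the `operation=` argument now drives it inside the integration; the constant's value and the mapping are illustrative, not confirmed by this diff.

```python
SPAN_KIND = "_ml_obs.meta.span.kind"  # assumed value of the removed constant

class OpenAIIntegration(BaseLLMIntegration):  # class name assumed
    def _llmobs_set_tags(self, span, args, kwargs, response=None, operation=""):
        # "completion" and "chat" are LLM calls; other operations keep their own kind.
        kind = "llm" if operation in ("completion", "chat") else operation
        span.set_tag_str(SPAN_KIND, kind)
```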
6 changes: 2 additions & 4 deletions ddtrace/contrib/internal/openai/utils.py
```diff
@@ -208,10 +208,8 @@ def _process_finished_stream(integration, span, kwargs, streamed_chunks, is_completion=False):
         if integration.is_pc_sampled_span(span):
             _tag_streamed_response(integration, span, formatted_completions)
         _set_token_metrics(span, integration, formatted_completions, prompts, request_messages, kwargs)
-        if integration.is_pc_sampled_llmobs(span):
-            integration.llmobs_set_tags(
-                "completion" if is_completion else "chat", None, span, kwargs, formatted_completions, None
-            )
+        operation = "completion" if is_completion else "chat"
+        integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=formatted_completions, operation=operation)
     except Exception:
         log.warning("Error processing streamed completion/chat response.", exc_info=True)
```
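Taken together with the base-class sketch at the top, the behavior change is easy to check: a tagging bug in any integration should now surface as a log line, not as an exception in the user's request path. A tiny demonstration against that sketch, with all names carried over from its stated assumptions:

```python
class BrokenIntegration(BaseLLMIntegration):
    def _llmobs_set_tags(self, span, args, kwargs, response=None, operation=""):
        raise ValueError("malformed response")  # simulate a processing bug

# Before this PR a failure like this could propagate into the traced application;
# now the guarded entry point swallows it and logs a warning instead.
BrokenIntegration().llmobs_set_tags(span=None, args=[], kwargs={})
```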