diff --git a/libs/community/langchain_community/callbacks/__init__.py b/libs/community/langchain_community/callbacks/__init__.py index 5d36b91f4..e99280f2f 100644 --- a/libs/community/langchain_community/callbacks/__init__.py +++ b/libs/community/langchain_community/callbacks/__init__.py @@ -23,6 +23,10 @@ from langchain_community.callbacks.arthur_callback import ( ArthurCallbackHandler, ) + from langchain_community.callbacks.bigquery_callback import ( + AsyncBigQueryCallbackHandler, + BigQueryCallbackHandler, + ) from langchain_community.callbacks.clearml_callback import ( ClearMLCallbackHandler, ) @@ -93,6 +97,8 @@ "ArgillaCallbackHandler": "langchain_community.callbacks.argilla_callback", "ArizeCallbackHandler": "langchain_community.callbacks.arize_callback", "ArthurCallbackHandler": "langchain_community.callbacks.arthur_callback", + "AsyncBigQueryCallbackHandler": "langchain_community.callbacks.bigquery_callback", + "BigQueryCallbackHandler": "langchain_community.callbacks.bigquery_callback", "ClearMLCallbackHandler": "langchain_community.callbacks.clearml_callback", "CometCallbackHandler": "langchain_community.callbacks.comet_ml_callback", "ContextCallbackHandler": "langchain_community.callbacks.context_callback", @@ -131,6 +137,8 @@ def __getattr__(name: str) -> Any: "ArgillaCallbackHandler", "ArizeCallbackHandler", "ArthurCallbackHandler", + "AsyncBigQueryCallbackHandler", + "BigQueryCallbackHandler", "ClearMLCallbackHandler", "CometCallbackHandler", "ContextCallbackHandler", diff --git a/libs/community/langchain_community/callbacks/bigquery_callback.py b/libs/community/langchain_community/callbacks/bigquery_callback.py new file mode 100644 index 000000000..d578e0cfe --- /dev/null +++ b/libs/community/langchain_community/callbacks/bigquery_callback.py @@ -0,0 +1,1342 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import threading +from datetime import UTC, datetime +from typing import Any, Dict, List, Optional, Union + +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler +from langchain_core.messages import BaseMessage +from langchain_core.outputs import LLMResult +from langchain_core.utils import guard_import + +def _jsonify_safely(data: Any) -> Any: + """Recursively converts non-serializable objects to strings.""" + if isinstance(data, dict): + return {key: _jsonify_safely(value) for key, value in data.items()} + if isinstance(data, list): + return [_jsonify_safely(item) for item in data] + try: + json.dumps(data) + return data + except (TypeError, OverflowError): + return str(data) + + +def import_google_cloud_bigquery() -> Any: + """Import google-cloud-bigquery and its dependencies.""" + return ( + guard_import("google.cloud.bigquery"), + guard_import("google.auth", pip_name="google-auth"), + guard_import("google.api_core.gapic_v1.client_info"), + guard_import( + "google.cloud.bigquery_storage_v1.services.big_query_write.async_client" + ), + guard_import( + "google.cloud.bigquery_storage_v1.services.big_query_write.client" + ), + guard_import("google.cloud.bigquery_storage_v1"), + guard_import("pyarrow"), + ) + + +class AsyncBigQueryCallbackHandler(AsyncCallbackHandler): + """ + Callback Handler that logs to Google BigQuery. + + This handler captures key events during an agent's lifecycle—such as user + interactions, tool executions, LLM requests/responses, and errors—and + streams them to a BigQuery table for analysis and monitoring. 
+ + It uses the BigQuery Write API for efficient, high-throughput streaming + ingestion. If the destination table does not exist, the handler will + attempt to create it based on a predefined schema. + """ + + def __init__( + self, + project_id: str, + dataset_id: str, + table_id: str = "agent_events", + max_content_length: int = 200 * 1024, + ): + """Initializes the BigQueryCallbackHandler. + + Args: + project_id: Google Cloud project ID. + dataset_id: BigQuery dataset ID. + table_id: BigQuery table ID for agent events. + max_content_length: The maximum length of content to log before truncating. + """ + super().__init__() + ( + self.bigquery, + self.google_auth, + self.gapic_client_info, + self.async_client, + self.sync_client, + self.bq_storage, + self.pa, + ) = import_google_cloud_bigquery() + self.BigQueryWriteAsyncClient = self.async_client.BigQueryWriteAsyncClient + self.BigQueryWriteClient = self.sync_client.BigQueryWriteClient + self._project_id, self._dataset_id, self._table_id = ( + project_id, + dataset_id, + table_id, + ) + self._max_content_length = max_content_length + self._bq_client = None + self._write_client = None + self._init_lock = asyncio.Lock() + self._arrow_schema = None + self._schema = [ + self.bigquery.SchemaField("timestamp", "TIMESTAMP"), + self.bigquery.SchemaField("event_type", "STRING"), + self.bigquery.SchemaField("run_id", "STRING"), + self.bigquery.SchemaField("parent_run_id", "STRING"), + self.bigquery.SchemaField("content", "STRING"), + self.bigquery.SchemaField("serialized", "STRING"), + self.bigquery.SchemaField("tags", "STRING"), + self.bigquery.SchemaField("metadata", "STRING"), + self.bigquery.SchemaField("error_message", "STRING"), + self.bigquery.SchemaField("is_truncated", "BOOLEAN"), + ] + self.action_records: list = [] + + async def _ensure_init(self) -> bool: + """Ensures BigQuery clients are initialized.""" + if self._write_client: + return True + async with self._init_lock: + if self._write_client: + return True + try: + creds, _ = await asyncio.to_thread( + self.google_auth.default, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + client_info = self.gapic_client_info.ClientInfo( + user_agent="langchain-bigquery-callback" + ) + self._bq_client = self.bigquery.Client( + project=self._project_id, credentials=creds, client_info=client_info + ) + + # Create dataset and table asynchronously + if self._bq_client: + # Run sync methods in a thread to avoid blocking the event loop. 
+ await asyncio.to_thread( + self._bq_client.create_dataset, self._dataset_id, exists_ok=True + ) + table = self.bigquery.Table( + f"{self._project_id}.{self._dataset_id}.{self._table_id}", + schema=self._schema, + ) + await asyncio.to_thread( + self._bq_client.create_table, table, exists_ok=True + ) + + self._write_client = self.BigQueryWriteAsyncClient( + credentials=creds, # type: ignore + client_info=client_info, + ) + self._arrow_schema = self._bq_to_arrow_schema(self._schema) + return True + except Exception as e: # pylint: disable=broad-exception-caught + logging.error("BQ Init Failed: %s", e) + return False + + def _bq_to_arrow_scalars(self, bq_scalar: str) -> Any: + """Converts a BigQuery scalar type string to a PyArrow data type.""" + _BQ_TO_ARROW_SCALARS = { + "BOOL": self.pa.bool_(), + "BOOLEAN": self.pa.bool_(), + "BYTES": self.pa.binary(), + "DATE": self.pa.date32(), + "DATETIME": self.pa.timestamp("us", tz=None), + "FLOAT": self.pa.float64(), + "FLOAT64": self.pa.float64(), + "GEOGRAPHY": self.pa.string(), + "INT64": self.pa.int64(), + "INTEGER": self.pa.int64(), + "JSON": self.pa.string(), + "NUMERIC": self.pa.decimal128(38, 9), + "BIGNUMERIC": self.pa.decimal256(76, 38), + "STRING": self.pa.string(), + "TIME": self.pa.time64("us"), + "TIMESTAMP": self.pa.timestamp("us", tz="UTC"), + } + return _BQ_TO_ARROW_SCALARS.get(bq_scalar) + + def _bq_to_arrow_data_type(self, field: Any) -> Any: + """Converts a BigQuery schema field to a PyArrow data type.""" + if field.mode == "REPEATED": + inner = self._bq_to_arrow_data_type( + self.bigquery.SchemaField( + field.name, + field.field_type, + fields=field.fields, + range_element_type=getattr(field, "range_element_type", None), + ) + ) + return self.pa.list_(inner) if inner else None + + field_type_upper = field.field_type.upper() if field.field_type else "" + if field_type_upper in ("RECORD", "STRUCT"): + arrow_fields = [ + self._bq_to_arrow_field(subfield) for subfield in field.fields + ] + return self.pa.struct(arrow_fields) + + constructor = self._bq_to_arrow_scalars(field_type_upper) + if constructor: + return constructor + else: + logging.warning( + "Failed to convert BigQuery field '%s': unsupported type '%s'.", + field.name, + field.field_type, + ) + return None + + def _bq_to_arrow_field(self, bq_field: Any) -> Any: + """Converts a BigQuery SchemaField to a PyArrow Field.""" + arrow_type = self._bq_to_arrow_data_type(bq_field) + if arrow_type: + return self.pa.field( + bq_field.name, + arrow_type, + nullable=(bq_field.mode != "REPEATED"), + ) + return None + + def _bq_to_arrow_schema(self, bq_schema_list: List[Any]) -> Any: + """Converts a list of BigQuery SchemaFields to a PyArrow Schema.""" + arrow_fields = [ + af for af in (self._bq_to_arrow_field(f) for f in bq_schema_list) if af + ] + return self.pa.schema(arrow_fields) + + def _truncate_content_safely(self, content: str) -> str: + """Truncates the content string if it exceeds the configured max length.""" + if len(content.encode("utf-8")) > self._max_content_length: + truncated_content = content.encode("utf-8")[: self._max_content_length].decode("utf-8", "ignore") + return f"{truncated_content} [TRUNCATED_MAX_BYTES:{self._max_content_length}]" + return content + + async def _log(self, data: dict) -> None: + """Schedules a log entry to be written.""" + row = { + "timestamp": datetime.now(UTC), + "event_type": None, + "run_id": None, + "parent_run_id": None, + "content": None, + "serialized": None, + "tags": None, + "metadata": None, + "error_message": None, + 
"is_truncated": False, + } + row.update(data) + + await self._perform_write(row) + + async def _perform_write(self, row: dict) -> None: + """Actual write operation.""" + try: + if not await self._ensure_init() or not self._write_client or not self._arrow_schema: + return + + pydict = {field.name: [row.get(field.name)] for field in self._arrow_schema} + batch = self.pa.RecordBatch.from_pydict(pydict, schema=self._arrow_schema) + + write_stream = f"projects/{self._project_id}/datasets/{self._dataset_id}/tables/{self._table_id}/_default" + request = self.bq_storage.types.AppendRowsRequest( + write_stream=write_stream, + ) + # Correctly attach Arrow data to the `arrow_rows` field. + request.arrow_rows.writer_schema.serialized_schema = ( + self._arrow_schema.serialize().to_pybytes() + ) + request.arrow_rows.rows.serialized_record_batch = ( + batch.serialize().to_pybytes() + ) + + # This is an async call + # Write with protection against immediate cancellation + async for resp in await asyncio.shield( + self._write_client.append_rows(iter([request])) + ): + if resp.error.code != 0: + logging.error("BQ Write Error: %s", resp.error.message) + + except RuntimeError as e: + if "Event loop is closed" not in str(e): + logging.exception("BQ Runtime Error: %s", e) + + except Exception as e: # pylint: disable=broad-exception-caught + logging.error("BQ Write Failed: %s", e) + + async def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM starts.""" + content_str = json.dumps({"prompts": prompts}) + data = { + "event_type": "LLM_START", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "serialized": json.dumps(_jsonify_safely(serialized)) + if serialized + else None, + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + await self._log(data) + + async def on_llm_new_token( + self, + token: str, + *, + chunk: Optional[Any] = None, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> Any: + """Run when a new token is generated.""" + content_str = json.dumps({"token": token}) + data = { + "event_type": "LLM_NEW_TOKEN", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + await self._log(data) + + async def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> Any: + """Run when a chat model starts.""" + message_dicts = [[msg.dict() for msg in m] for m in messages] + content_str = json.dumps({"messages": _jsonify_safely(message_dicts)}) + data = { + "event_type": "CHAT_MODEL_START", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "serialized": json.dumps(_jsonify_safely(serialized)) if 
serialized else None, + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + await self._log(data) + + async def on_llm_end( + self, + response: LLMResult, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM ends running.""" + metadata = kwargs.get("metadata") or {} + for generations in response.generations: + for generation in generations: + content_str = json.dumps({"response": generation.text}) + data = { + "event_type": "LLM_RESPONSE", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(metadata)), + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + await self._log(data) + + async def on_llm_error( + self, + error: BaseException, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM errors.""" + data = { + "event_type": "LLM_ERROR", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": None, + "error_message": str(error), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + } + await self._log(data) + + async def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when chain starts running.""" + content_str = json.dumps({"inputs": _jsonify_safely(inputs)}) + data = { + "event_type": "CHAIN_START", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "serialized": json.dumps(_jsonify_safely(serialized)) if serialized else None, + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + await self._log(data) + + async def on_text( + self, + text: str, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run on arbitrary text.""" + content_str = json.dumps({"text": text}) + data = { + "event_type": "TEXT", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + await self._log(data) + + async def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> Any: + """Run when a retriever starts.""" + content_str = json.dumps({"query": query}) + data = { + "event_type": "RETRIEVER_START", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "serialized": json.dumps(_jsonify_safely(serialized)) if serialized else None, + "tags": json.dumps(_jsonify_safely(tags or [])), + 
"is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + await self._log(data) + + async def on_retriever_end( + self, + documents: Any, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> Any: + """Run when a retriever ends.""" + docs = [doc.dict() for doc in documents] + content_str = json.dumps({"documents": _jsonify_safely(docs)}) + data = { + "event_type": "RETRIEVER_END", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + await self._log(data) + + async def on_retriever_error( + self, + error: BaseException, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> Any: + """Run when a retriever errors.""" + data = { + "event_type": "RETRIEVER_ERROR", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": None, + "error_message": str(error), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + } + await self._log(data) + + async def close(self) -> None: + """ + Shuts down the callback handler, ensuring all logs are flushed and clients are + properly closed. This should be called before application exit. + + Once your Langchain application has completed its tasks, ensure that you call + the `close` method to finalize the logging process. + """ + logging.info("BQ Callback: Shutdown started.") + + # Use getattr for safe access in case transport is not present. 
+ if self._write_client and hasattr(self._write_client, "close"): + try: + logging.info("BQ Callback: Closing write client.") + await self._write_client.close() + except Exception as e: # pylint: disable=broad-exception-caught + logging.warning("BQ Callback: Error closing write client: %s", e) + if self._bq_client: + try: + self._bq_client.close() + except Exception as e: # pylint: disable=broad-exception-caught + logging.warning("BQ Callback: Error closing BQ client: %s", e) + + self._write_client = None + self._bq_client = None + logging.info("BQ Callback: Shutdown complete.") + + async def on_chain_end( + self, + outputs: Union[Dict[str, Any], Any], + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when chain ends running.""" + content_str = json.dumps({"outputs": _jsonify_safely(outputs)}) + data = { + "event_type": "CHAIN_END", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(kwargs.get("tags", []))), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + await self._log(data) + + async def on_chain_error( + self, + error: BaseException, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when chain errors.""" + data = { + "event_type": "CHAIN_ERROR", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": None, + "error_message": str(error), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + } + await self._log(data) + + async def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when tool starts running.""" + content_str = json.dumps({"input": input_str}) + data = { + "event_type": "TOOL_START", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "serialized": json.dumps(_jsonify_safely(serialized)) if serialized else None, + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + await self._log(data) + + async def on_tool_end( + self, + output: Any, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when tool ends running.""" + content_str = json.dumps({"output": str(output)}) + data = { + "event_type": "TOOL_END", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + await self._log(data) + + async def on_tool_error( + self, + error: BaseException, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when tool errors.""" + data = { + "event_type": "TOOL_ERROR", + "run_id": str(run_id), + "parent_run_id": 
str(parent_run_id), + "content": None, + "error_message": str(error), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + } + await self._log(data) + + async def on_agent_action( + self, + action: AgentAction, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> Any: + """Run on agent action.""" + content_str = json.dumps({"tool": action.tool, "input": str(action.tool_input)}) + data = { + "event_type": "AGENT_ACTION", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + await self._log(data) + + async def on_agent_finish( + self, + finish: AgentFinish, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when agent ends running.""" + content_str = json.dumps({"output": _jsonify_safely(finish.return_values)}) + data = { + "event_type": "AGENT_FINISH", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + await self._log(data) + +class BigQueryCallbackHandler(BaseCallbackHandler): + """ + Callback Handler that logs to Google BigQuery. + + This handler captures key events during an agent's lifecycle—such as user + interactions, tool executions, LLM requests/responses, and errors—and + streams them to a BigQuery table for analysis and monitoring. + + It uses the BigQuery Write API for efficient, high-throughput streaming + ingestion. If the destination table does not exist, the handler will + attempt to create it based on a predefined schema. + """ + + def __init__( + self, + project_id: str, + dataset_id: str, + table_id: str = "agent_events", + max_content_length: int = 200 * 1024, + ): + """Initializes the BigQueryCallbackHandler. + + Args: + project_id: Google Cloud project ID. + dataset_id: BigQuery dataset ID. + table_id: BigQuery table ID for agent events. + max_content_length: The maximum length of content to log before truncating. 
+ """ + super().__init__() + ( + self.bigquery, + self.google_auth, + self.gapic_client_info, + _, # async_client + self.sync_client, + self.bq_storage, + self.pa, + ) = import_google_cloud_bigquery() + self.BigQueryWriteClient = self.sync_client.BigQueryWriteClient + self._project_id, self._dataset_id, self._table_id = ( + project_id, + dataset_id, + table_id, + ) + self._max_content_length = max_content_length + self._bq_client = None + self._write_client = None + self._init_lock = threading.Lock() + self._arrow_schema = None + self._schema = [ + self.bigquery.SchemaField("timestamp", "TIMESTAMP"), + self.bigquery.SchemaField("event_type", "STRING"), + self.bigquery.SchemaField("run_id", "STRING"), + self.bigquery.SchemaField("parent_run_id", "STRING"), + self.bigquery.SchemaField("content", "STRING"), + self.bigquery.SchemaField("serialized", "STRING"), + self.bigquery.SchemaField("tags", "STRING"), + self.bigquery.SchemaField("metadata", "STRING"), + self.bigquery.SchemaField("error_message", "STRING"), + self.bigquery.SchemaField("is_truncated", "BOOLEAN"), + ] + self.action_records: list = [] + + def _ensure_init(self) -> bool: + """Ensures BigQuery clients are initialized.""" + if self._write_client: + return True + with self._init_lock: + if self._write_client: + return True + try: + creds, _ = self.google_auth.default( + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + client_info = self.gapic_client_info.ClientInfo( + user_agent="langchain-bigquery-callback" + ) + self._bq_client = self.bigquery.Client( + project=self._project_id, credentials=creds, client_info=client_info + ) + + if self._bq_client: + self._bq_client.create_dataset(self._dataset_id, exists_ok=True) + table = self.bigquery.Table( + f"{self._project_id}.{self._dataset_id}.{self._table_id}", + schema=self._schema, + ) + self._bq_client.create_table(table, exists_ok=True) + + self._write_client = self.BigQueryWriteClient( + credentials=creds, # type: ignore + client_info=client_info, + ) + self._arrow_schema = self._bq_to_arrow_schema(self._schema) + return True + except Exception as e: # pylint: disable=broad-exception-caught + logging.error("BQ Init Failed: %s", e) + return False + + def _bq_to_arrow_scalars(self, bq_scalar: str) -> Any: + """Converts a BigQuery scalar type string to a PyArrow data type.""" + _BQ_TO_ARROW_SCALARS = { + "BOOL": self.pa.bool_(), + "BOOLEAN": self.pa.bool_(), + "BYTES": self.pa.binary(), + "DATE": self.pa.date32(), + "DATETIME": self.pa.timestamp("us", tz=None), + "FLOAT": self.pa.float64(), + "FLOAT64": self.pa.float64(), + "GEOGRAPHY": self.pa.string(), + "INT64": self.pa.int64(), + "INTEGER": self.pa.int64(), + "JSON": self.pa.string(), + "NUMERIC": self.pa.decimal128(38, 9), + "BIGNUMERIC": self.pa.decimal256(76, 38), + "STRING": self.pa.string(), + "TIME": self.pa.time64("us"), + "TIMESTAMP": self.pa.timestamp("us", tz="UTC"), + } + return _BQ_TO_ARROW_SCALARS.get(bq_scalar) + + def _bq_to_arrow_data_type(self, field: Any) -> Any: + """Converts a BigQuery schema field to a PyArrow data type.""" + if field.mode == "REPEATED": + inner = self._bq_to_arrow_data_type( + self.bigquery.SchemaField( + field.name, + field.field_type, + fields=field.fields, + range_element_type=getattr(field, "range_element_type", None), + ) + ) + return self.pa.list_(inner) if inner else None + + field_type_upper = field.field_type.upper() if field.field_type else "" + if field_type_upper in ("RECORD", "STRUCT"): + arrow_fields = [ + self._bq_to_arrow_field(subfield) for subfield in 
field.fields + ] + return self.pa.struct(arrow_fields) + + constructor = self._bq_to_arrow_scalars(field_type_upper) + if constructor: + return constructor + else: + logging.warning( + "Failed to convert BigQuery field '%s': unsupported type '%s'.", + field.name, + field.field_type, + ) + return None + + def _bq_to_arrow_field(self, bq_field: Any) -> Any: + """Converts a BigQuery SchemaField to a PyArrow Field.""" + arrow_type = self._bq_to_arrow_data_type(bq_field) + if arrow_type: + return self.pa.field( + bq_field.name, + arrow_type, + nullable=(bq_field.mode != "REPEATED"), + ) + return None + + def _bq_to_arrow_schema(self, bq_schema_list: List[Any]) -> Any: + """Converts a list of BigQuery SchemaFields to a PyArrow Schema.""" + arrow_fields = [ + af for af in (self._bq_to_arrow_field(f) for f in bq_schema_list) if af + ] + return self.pa.schema(arrow_fields) + + def _truncate_content_safely(self, content: str) -> str: + """Truncates the content string if it exceeds the configured max length.""" + if len(content.encode("utf-8")) > self._max_content_length: + truncated_content = content.encode("utf-8")[: self._max_content_length].decode("utf-8", "ignore") + return f"{truncated_content} [TRUNCATED_MAX_BYTES:{self._max_content_length}]" + return content + + def _log(self, data: dict) -> None: + """Schedules a log entry to be written.""" + row = { + "timestamp": datetime.now(UTC), + "event_type": None, + "run_id": None, + "parent_run_id": None, + "content": None, + "serialized": None, + "tags": None, + "metadata": None, + "error_message": None, + "is_truncated": False, + } + row.update(data) + + self._perform_write(row) + + def _perform_write(self, row: dict) -> None: + """Actual write operation.""" + try: + if not self._ensure_init() or not self._write_client or not self._arrow_schema: + return + + pydict = {field.name: [row.get(field.name)] for field in self._arrow_schema} + batch = self.pa.RecordBatch.from_pydict(pydict, schema=self._arrow_schema) + + write_stream = f"projects/{self._project_id}/datasets/{self._dataset_id}/tables/{self._table_id}/_default" + request = self.bq_storage.types.AppendRowsRequest( + write_stream=write_stream, + ) + request.arrow_rows.writer_schema.serialized_schema = ( + self._arrow_schema.serialize().to_pybytes() + ) + request.arrow_rows.rows.serialized_record_batch = ( + batch.serialize().to_pybytes() + ) + + # This is a sync call + resp = self._write_client.append_rows(iter([request])) + for r in resp: + if r.error.code != 0: + logging.error("BQ Write Error: %s", r.error.message) + + except Exception as e: # pylint: disable=broad-exception-caught + logging.error("BQ Write Failed: %s", e) + + def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM starts.""" + content_str = json.dumps({"prompts": prompts}) + data = { + "event_type": "LLM_START", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "serialized": json.dumps(_jsonify_safely(serialized)) + if serialized + else None, + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + self._log(data) + + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: Any, + 
parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> Any: + """Run when a chat model starts.""" + message_dicts = [[msg.dict() for msg in m] for m in messages] + content_str = json.dumps({"messages": _jsonify_safely(message_dicts)}) + data = { + "event_type": "CHAT_MODEL_START", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "serialized": json.dumps(_jsonify_safely(serialized)) if serialized else None, + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + self._log(data) + + def on_llm_end( + self, + response: LLMResult, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM ends running.""" + metadata = kwargs.get("metadata") or {} + for generations in response.generations: + for generation in generations: + content_str = json.dumps({"response": generation.text}) + data = { + "event_type": "LLM_RESPONSE", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(metadata)), + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + self._log(data) + + def on_llm_new_token( + self, + token: str, + *, + chunk: Optional[Any] = None, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> Any: + """Run when a new token is generated.""" + content_str = json.dumps({"token": token}) + data = { + "event_type": "LLM_NEW_TOKEN", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + self._log(data) + + def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when chain starts running.""" + content_str = json.dumps({"inputs": _jsonify_safely(inputs)}) + data = { + "event_type": "CHAIN_START", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "serialized": json.dumps(_jsonify_safely(serialized)) + if serialized + else None, + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + self._log(data) + + def on_chain_end( + self, + outputs: Union[Dict[str, Any], Any], + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when chain ends running.""" + content_str = json.dumps({"outputs": _jsonify_safely(outputs)}) + data = { + "event_type": "CHAIN_END", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + 
"is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + self._log(data) + + def on_chain_error( + self, + error: BaseException, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when chain errors.""" + data = { + "event_type": "CHAIN_ERROR", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": None, + "error_message": str(error), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + } + self._log(data) + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when tool starts running.""" + content_str = json.dumps({"input": input_str}) + data = { + "event_type": "TOOL_START", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "serialized": json.dumps(_jsonify_safely(serialized)) + if serialized + else None, + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + self._log(data) + + def on_tool_end( + self, + output: Any, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when tool ends running.""" + content_str = json.dumps({"output": str(output)}) + data = { + "event_type": "TOOL_END", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + self._log(data) + + def on_tool_error( + self, + error: BaseException, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when tool errors.""" + data = { + "event_type": "TOOL_ERROR", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": None, + "error_message": str(error), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + } + self._log(data) + + def on_text( + self, + text: str, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run on arbitrary text.""" + content_str = json.dumps({"text": text}) + data = { + "event_type": "TEXT", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + self._log(data) + + def on_agent_action( + self, + action: AgentAction, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> Any: + """Run on agent action.""" + content_str = json.dumps({"tool": action.tool, "input": str(action.tool_input)}) + data = { + "event_type": "AGENT_ACTION", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": 
self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + self._log(data) + + def on_agent_finish( + self, + finish: AgentFinish, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when agent ends running.""" + content_str = json.dumps({"output": _jsonify_safely(finish.return_values)}) + data = { + "event_type": "AGENT_FINISH", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + self._log(data) + + def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> Any: + """Run when a retriever starts.""" + content_str = json.dumps({"query": query}) + data = { + "event_type": "RETRIEVER_START", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "serialized": json.dumps(_jsonify_safely(serialized)) + if serialized + else None, + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + self._log(data) + + def on_retriever_end( + self, + documents: Any, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> Any: + """Run when a retriever ends.""" + docs = [doc.dict() for doc in documents] + content_str = json.dumps({"documents": _jsonify_safely(docs)}) + data = { + "event_type": "RETRIEVER_END", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": self._truncate_content_safely(content_str), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + "is_truncated": len(content_str.encode("utf-8")) > self._max_content_length, + } + self._log(data) + + def on_retriever_error( + self, + error: BaseException, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> Any: + """Run when a retriever errors.""" + data = { + "event_type": "RETRIEVER_ERROR", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": None, + "error_message": str(error), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + } + self._log(data) + + def on_llm_error( + self, + error: BaseException, + *, + run_id: Any, + parent_run_id: Optional[Any] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + """Run when LLM errors.""" + data = { + "event_type": "LLM_ERROR", + "run_id": str(run_id), + "parent_run_id": str(parent_run_id), + "content": None, + "error_message": str(error), + "metadata": json.dumps(_jsonify_safely(kwargs.get("metadata", {}))), + "tags": json.dumps(_jsonify_safely(tags or [])), + } + self._log(data) + + def close(self) -> None: + """ + Shuts down the callback handler, ensuring all logs are flushed and 
clients are + properly closed. This should be called before application exit. + + Once your Langchain application has completed its tasks, ensure that you call + the `close` method to finalize the logging process. + """ + logging.info("BQ Callback: Shutdown started.") + + if self._write_client and hasattr(self._write_client, "close"): + try: + logging.info("BQ Callback: Closing write client.") + self._write_client.close() + except Exception as e: # pylint: disable=broad-exception-caught + logging.warning("BQ Callback: Error closing write client: %s", e) + if self._bq_client: + try: + self._bq_client.close() + except Exception as e: # pylint: disable=broad-exception-caught + logging.warning("BQ Callback: Error closing BQ client: %s", e) + + self._write_client = None + self._bq_client = None + logging.info("BQ Callback: Shutdown complete.") diff --git a/libs/community/tests/unit_tests/callbacks/test_bigquery_callback.py b/libs/community/tests/unit_tests/callbacks/test_bigquery_callback.py new file mode 100644 index 000000000..62ed0d515 --- /dev/null +++ b/libs/community/tests/unit_tests/callbacks/test_bigquery_callback.py @@ -0,0 +1,565 @@ +"""Unit tests for BigQueryCallbackHandler.""" + +from typing import Any, Dict, Generator +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest +from langchain_core.agents import AgentAction, AgentFinish +from langchain_core.outputs import LLMResult +from langchain_core.messages import HumanMessage + +from langchain_community.callbacks.bigquery_callback import ( + AsyncBigQueryCallbackHandler, + BigQueryCallbackHandler, +) +from langchain_core.documents import Document + +@pytest.fixture +def mock_bigquery_clients() -> Generator[Dict[str, Any], None, None]: + """Mocks the BigQuery clients and dependencies.""" + with patch( + "langchain_community.callbacks.bigquery_callback.import_google_cloud_bigquery" + ) as mock_import: + mock_bigquery = MagicMock() + mock_google_auth = MagicMock() + mock_gapic_client_info = MagicMock() + mock_async_client_module = MagicMock() + mock_sync_client_module = MagicMock() + mock_bq_storage = MagicMock() + mock_pa = MagicMock() + + mock_import.return_value = ( + mock_bigquery, + mock_google_auth, + mock_gapic_client_info, + mock_async_client_module, + mock_sync_client_module, + mock_bq_storage, + mock_pa, + ) + + # Mock the async client instance + mock_async_write_client = AsyncMock() + mock_async_client_module.BigQueryWriteAsyncClient.return_value = ( + mock_async_write_client + ) + + # Mock the sync client instance + mock_sync_write_client = MagicMock() + mock_sync_client_module.BigQueryWriteClient.return_value = ( + mock_sync_write_client + ) + + # Mock the sync BigQuery client instance + mock_bq_client = MagicMock() + mock_bigquery.Client.return_value = mock_bq_client + + # Mock google auth to avoid real authentication + mock_google_auth.default = MagicMock(return_value=(None, "test-project")) + + yield { + "mock_bigquery": mock_bigquery, + "mock_google_auth": mock_google_auth, + "mock_async_write_client": mock_async_write_client, + "mock_sync_write_client": mock_sync_write_client, + "mock_bq_client": mock_bq_client, + "mock_pa": mock_pa, + } + + +@pytest.fixture +async def handler(mock_bigquery_clients: Dict[str, Any]) -> AsyncBigQueryCallbackHandler: + """ + Returns an initialized `AsyncBigQueryCallbackHandler` with mocked clients. 
+ """ + handler = AsyncBigQueryCallbackHandler( + project_id="test-project", + dataset_id="test_dataset", + table_id="test_table", + ) + # Ensure initialization is run + await handler._ensure_init() + return handler + + +@pytest.fixture +def sync_handler( + mock_bigquery_clients: Dict[str, Any] +) -> BigQueryCallbackHandler: + """ + Returns an initialized `BigQueryCallbackHandler` with mocked clients. + """ + handler = BigQueryCallbackHandler( + project_id="test-project", + dataset_id="test_dataset", + table_id="test_table", + ) + # Ensure initialization is run + handler._ensure_init() + return handler + + +@pytest.mark.asyncio +async def test_async_on_llm_start( + handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any] +) -> None: + """Test that on_llm_start logs the correct event.""" + run_id = uuid4() + parent_run_id = uuid4() + await handler.on_llm_start( + serialized={"name": "test_llm"}, + prompts=["test prompt"], + run_id=run_id, + parent_run_id=parent_run_id, + ) + + mock_write_client = mock_bigquery_clients["mock_async_write_client"] + assert mock_write_client.append_rows.call_count == 1 + + +def test_sync_on_llm_start( + sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any] +) -> None: + """Test that sync on_llm_start logs the correct event.""" + run_id = uuid4() + parent_run_id = uuid4() + sync_handler.on_llm_start( + serialized={"name": "test_llm"}, + prompts=["test prompt"], + run_id=run_id, + parent_run_id=parent_run_id, + ) + + mock_write_client = mock_bigquery_clients["mock_sync_write_client"] + assert mock_write_client.append_rows.call_count == 1 + + +@pytest.mark.asyncio +async def test_async_on_llm_end( + handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any] +) -> None: + """Test that on_llm_end logs the correct event.""" + response = LLMResult(generations=[], llm_output={"model_name": "test_model"}) + await handler.on_llm_end(response, run_id=uuid4()) + + # on_llm_end might not log if there are no generations. Let's add one. + response.generations.append([MagicMock(text="test generation")]) + await handler.on_llm_end(response, run_id=uuid4()) + + mock_write_client = mock_bigquery_clients["mock_async_write_client"] + assert mock_write_client.append_rows.call_count == 1 + + +def test_sync_on_llm_end( + sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any] +) -> None: + """Test that sync on_llm_end logs the correct event.""" + response = LLMResult(generations=[], llm_output={"model_name": "test_model"}) + sync_handler.on_llm_end(response, run_id=uuid4()) + + # on_llm_end might not log if there are no generations. Let's add one. 
+ response.generations.append([MagicMock(text="test generation")]) + sync_handler.on_llm_end(response, run_id=uuid4()) + + mock_write_client = mock_bigquery_clients["mock_sync_write_client"] + assert mock_write_client.append_rows.call_count == 1 + + +@pytest.mark.asyncio +async def test_async_on_chain_start( + handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any] +) -> None: + """Test that on_chain_start logs the correct event.""" + await handler.on_chain_start( + serialized={"name": "test_chain"}, inputs={"input": "test"}, run_id=uuid4() + ) + mock_write_client = mock_bigquery_clients["mock_async_write_client"] + assert mock_write_client.append_rows.call_count == 1 + + +def test_sync_on_chain_start( + sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any] +) -> None: + """Test that sync on_chain_start logs the correct event.""" + sync_handler.on_chain_start( + serialized={"name": "test_chain"}, inputs={"input": "test"}, run_id=uuid4() + ) + mock_write_client = mock_bigquery_clients["mock_sync_write_client"] + assert mock_write_client.append_rows.call_count == 1 + + +@pytest.mark.asyncio +async def test_async_on_chain_end( + handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any] +) -> None: + """Test that on_chain_end logs the correct event.""" + await handler.on_chain_end(outputs={"output": "test"}, run_id=uuid4()) + mock_write_client = mock_bigquery_clients["mock_async_write_client"] + assert mock_write_client.append_rows.call_count == 1 + + +def test_sync_on_chain_end( + sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any] +) -> None: + """Test that sync on_chain_end logs the correct event.""" + sync_handler.on_chain_end(outputs={"output": "test"}, run_id=uuid4()) + mock_write_client = mock_bigquery_clients["mock_sync_write_client"] + assert mock_write_client.append_rows.call_count == 1 + + +@pytest.mark.asyncio +async def test_async_on_tool_start( + handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any] +) -> None: + """Test that on_tool_start logs the correct event.""" + await handler.on_tool_start( + serialized={"name": "test_tool"}, input_str="test", run_id=uuid4() + ) + mock_write_client = mock_bigquery_clients["mock_async_write_client"] + assert mock_write_client.append_rows.call_count == 1 + + +def test_sync_on_tool_start( + sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any] +) -> None: + """Test that sync on_tool_start logs the correct event.""" + sync_handler.on_tool_start( + serialized={"name": "test_tool"}, input_str="test", run_id=uuid4() + ) + mock_write_client = mock_bigquery_clients["mock_sync_write_client"] + assert mock_write_client.append_rows.call_count == 1 + + +@pytest.mark.asyncio +async def test_async_on_agent_action( + handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any] +) -> None: + """Test that on_agent_action logs the correct event.""" + action = AgentAction(tool="test_tool", tool_input="test", log="test log") + await handler.on_agent_action(action, run_id=uuid4()) + mock_write_client = mock_bigquery_clients["mock_async_write_client"] + assert mock_write_client.append_rows.call_count == 1 + +def test_sync_on_agent_action( + sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any] +) -> None: + """Test that sync on_agent_action logs the correct event.""" + action = AgentAction(tool="test_tool", tool_input="test", log="test log") + sync_handler.on_agent_action(action, 
run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_sync_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_async_on_agent_finish(
+    handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that on_agent_finish logs the correct event."""
+    finish = AgentFinish(return_values={"output": "test"}, log="test log")
+    await handler.on_agent_finish(finish, run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_async_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+def test_sync_on_agent_finish(
+    sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that sync on_agent_finish logs the correct event."""
+    finish = AgentFinish(return_values={"output": "test"}, log="test log")
+    sync_handler.on_agent_finish(finish, run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_sync_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_async_on_llm_error(
+    handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that on_llm_error logs the correct event."""
+    await handler.on_llm_error(Exception("test error"), run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_async_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+def test_sync_on_llm_error(
+    sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that sync on_llm_error logs the correct event."""
+    sync_handler.on_llm_error(Exception("test error"), run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_sync_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_async_on_chat_model_start(
+    handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that on_chat_model_start logs the correct event."""
+    await handler.on_chat_model_start(
+        serialized={"name": "test_chat_model"},
+        messages=[[HumanMessage(content="test")]],
+        run_id=uuid4(),
+    )
+    mock_write_client = mock_bigquery_clients["mock_async_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+def test_sync_on_chat_model_start(
+    sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that sync on_chat_model_start logs the correct event."""
+    sync_handler.on_chat_model_start(
+        serialized={"name": "test_chat_model"},
+        messages=[[HumanMessage(content="test")]],
+        run_id=uuid4(),
+    )
+    mock_write_client = mock_bigquery_clients["mock_sync_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_async_on_retriever_end(
+    handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that on_retriever_end logs the correct event."""
+    documents = [Document(page_content="test document")]
+    await handler.on_retriever_end(documents, run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_async_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+def test_sync_on_retriever_end(
+    sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that sync on_retriever_end logs the correct event."""
+    documents = [Document(page_content="test document")]
+    sync_handler.on_retriever_end(documents, run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_sync_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+def test_sync_ensure_init_creates_dataset_and_table(
+    mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that sync _ensure_init creates dataset and table if they don't exist."""
+    handler = BigQueryCallbackHandler(
+        project_id="test-project",
+        dataset_id="test_dataset",
+        table_id="test_table",
+    )
+    initialized = handler._ensure_init()
+
+    assert initialized is True
+    mock_bq_client = mock_bigquery_clients["mock_bq_client"]
+    mock_bq_client.create_dataset.assert_called_once_with(
+        "test_dataset", exists_ok=True
+    )
+    mock_bq_client.create_table.assert_called_once()
+
+
+def test_sync_init_failure(mock_bigquery_clients: Dict[str, Any]) -> None:
+    """Test that sync initialization failure is handled gracefully."""
+    mock_bigquery_clients["mock_google_auth"].default.side_effect = Exception(
+        "Auth failed"
+    )
+    handler = BigQueryCallbackHandler(
+        project_id="test-project", dataset_id="test_dataset"
+    )
+    initialized = handler._ensure_init()
+    assert not initialized
+
+
+@pytest.mark.asyncio
+async def test_ensure_init_creates_dataset_and_table(
+    mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that _ensure_init creates dataset and table if they don't exist."""
+    handler = AsyncBigQueryCallbackHandler(
+        project_id="test-project",
+        dataset_id="test_dataset",
+        table_id="test_table",
+    )
+    await handler._ensure_init()
+
+    mock_bq_client = mock_bigquery_clients["mock_bq_client"]
+    mock_bq_client.create_dataset.assert_called_once_with(
+        "test_dataset", exists_ok=True
+    )
+    mock_bq_client.create_table.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_init_failure(mock_bigquery_clients: Dict[str, Any]) -> None:
+    """Test that initialization failure is handled gracefully."""
+    mock_bigquery_clients["mock_google_auth"].default.side_effect = Exception(
+        "Auth failed"
+    )
+    handler = AsyncBigQueryCallbackHandler(
+        project_id="test-project",
+        dataset_id="test_dataset",
+        table_id="test_table",
+    )
+    initialized = await handler._ensure_init()
+    assert not initialized
+
+    # Verify that no write is attempted if init failed
+    await handler.on_llm_start(
+        serialized={"name": "test_llm"}, prompts=["test"], run_id=uuid4()
+    )
+    mock_write_client = mock_bigquery_clients["mock_async_write_client"]
+    mock_write_client.append_rows.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_async_close(
+    handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that the close method closes clients."""
+    await handler.close()
+    mock_async_write_client = mock_bigquery_clients["mock_async_write_client"]
+    mock_async_write_client.close.assert_called_once()
+    mock_bq_client = mock_bigquery_clients["mock_bq_client"]
+    mock_bq_client.close.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_async_on_llm_new_token(
+    handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that on_llm_new_token logs the correct event."""
+    await handler.on_llm_new_token(token="new", run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_async_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+def test_sync_on_llm_new_token(
+    sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that sync on_llm_new_token logs the correct event."""
+    sync_handler.on_llm_new_token(token="new", run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_sync_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_async_on_tool_end(
+    handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that on_tool_end logs the correct event."""
+    await handler.on_tool_end(output="test output", run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_async_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+def test_sync_on_tool_end(
+    sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that sync on_tool_end logs the correct event."""
+    sync_handler.on_tool_end(output="test output", run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_sync_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_async_on_tool_error(
+    handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that on_tool_error logs the correct event."""
+    await handler.on_tool_error(Exception("tool error"), run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_async_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+def test_sync_on_tool_error(
+    sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that sync on_tool_error logs the correct event."""
+    sync_handler.on_tool_error(Exception("tool error"), run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_sync_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_async_on_chain_error(
+    handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that on_chain_error logs the correct event."""
+    await handler.on_chain_error(Exception("chain error"), run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_async_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+def test_sync_on_chain_error(
+    sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that sync on_chain_error logs the correct event."""
+    sync_handler.on_chain_error(Exception("chain error"), run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_sync_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_async_on_retriever_start(
+    handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that on_retriever_start logs the correct event."""
+    await handler.on_retriever_start(
+        serialized={"name": "test_retriever"}, query="test query", run_id=uuid4()
+    )
+    mock_write_client = mock_bigquery_clients["mock_async_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+def test_sync_on_retriever_start(
+    sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that sync on_retriever_start logs the correct event."""
+    sync_handler.on_retriever_start(
+        serialized={"name": "test_retriever"}, query="test query", run_id=uuid4()
+    )
+    mock_write_client = mock_bigquery_clients["mock_sync_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_async_on_retriever_error(
+    handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that on_retriever_error logs the correct event."""
+    await handler.on_retriever_error(Exception("retriever error"), run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_async_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+def test_sync_on_retriever_error(
+    sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that sync on_retriever_error logs the correct event."""
+    sync_handler.on_retriever_error(Exception("retriever error"), run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_sync_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_async_on_text(
+    handler: AsyncBigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that on_text logs the correct event."""
+    await handler.on_text("some text", run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_async_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+def test_sync_on_text(
+    sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that sync on_text logs the correct event."""
+    sync_handler.on_text("some text", run_id=uuid4())
+    mock_write_client = mock_bigquery_clients["mock_sync_write_client"]
+    assert mock_write_client.append_rows.call_count == 1
+
+
+def test_sync_close(
+    sync_handler: BigQueryCallbackHandler, mock_bigquery_clients: Dict[str, Any]
+) -> None:
+    """Test that the sync close method closes clients."""
+    sync_handler.close()
+    mock_sync_write_client = mock_bigquery_clients["mock_sync_write_client"]
+    mock_sync_write_client.close.assert_called_once()
+    mock_bq_client = mock_bigquery_clients["mock_bq_client"]
+    mock_bq_client.close.assert_called_once()
diff --git a/libs/community/tests/unit_tests/callbacks/test_imports.py b/libs/community/tests/unit_tests/callbacks/test_imports.py
index 566099cbd..050a2e6d3 100644
--- a/libs/community/tests/unit_tests/callbacks/test_imports.py
+++ b/libs/community/tests/unit_tests/callbacks/test_imports.py
@@ -28,6 +28,8 @@
     "UpTrainCallbackHandler",
     "UpstashRatelimitError",
    "UpstashRatelimitHandler",
+    "AsyncBigQueryCallbackHandler",
+    "BigQueryCallbackHandler",
 ]
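
For context on how the handlers introduced by this diff are expected to be wired up, the sketch below shows one plausible way to attach `BigQueryCallbackHandler` to a runnable and release its clients afterwards. It is illustrative only: the project, dataset, and table IDs are placeholders, `FakeListLLM` stands in for a real model, and actually running it requires the `google-cloud-bigquery`, `google-cloud-bigquery-storage`, and `pyarrow` packages plus Application Default Credentials with BigQuery access.

```python
# Illustrative usage sketch (not part of the diff). Project/dataset IDs below
# are placeholders; FakeListLLM stands in for a real chat model or agent.
from langchain_community.callbacks import BigQueryCallbackHandler
from langchain_community.llms.fake import FakeListLLM

# Events are streamed to <project>.<dataset>.<table> via the BigQuery Write
# API; per the handler docstring, the table is created if it does not exist.
handler = BigQueryCallbackHandler(
    project_id="my-project",   # placeholder GCP project
    dataset_id="agent_logs",   # placeholder dataset
    table_id="agent_events",   # default table name
)

llm = FakeListLLM(responses=["hello from the fake model"])

# Passing the handler via the run config makes it receive on_llm_start /
# on_llm_end callbacks, each of which is written as a row to BigQuery.
result = llm.invoke("ping", config={"callbacks": [handler]})
print(result)

# Flush and release the BigQuery clients when done, mirroring the
# test_sync_close behaviour exercised above.
handler.close()
```

The async variant follows the same pattern with `AsyncBigQueryCallbackHandler` and `await handler.close()` inside an event loop.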