From 2ed2e18a079b3c4b0b977662ecf609a01169429f Mon Sep 17 00:00:00 2001 From: Dan Noble Date: Thu, 5 Mar 2026 23:42:53 -0800 Subject: [PATCH 01/10] Updated node insert to use native cloud spanner api. Fixed bug where namespace prefixes were inserted into the database. --- .../datacommons_api/api_cli.py | 32 ++- .../datacommons_api/core/config.py | 2 +- .../endpoints/routers/node_router.py | 6 +- .../datacommons_api/services/graph_service.py | 256 ++++++++++++++---- .../services/graph_service_test.py | 127 +++++++++ .../datacommons_db/models/edge.py | 8 +- .../datacommons_db/models/node.py | 3 +- .../datacommons_db/models/observation.py | 3 +- packages/datacommons-db/pyproject.toml | 1 + uv.lock | 2 + 10 files changed, 382 insertions(+), 58 deletions(-) create mode 100644 packages/datacommons-api/datacommons_api/services/graph_service_test.py diff --git a/packages/datacommons-api/datacommons_api/api_cli.py b/packages/datacommons-api/datacommons_api/api_cli.py index 3b50566..2db660a 100644 --- a/packages/datacommons-api/datacommons_api/api_cli.py +++ b/packages/datacommons-api/datacommons_api/api_cli.py @@ -16,9 +16,10 @@ import uvicorn from datacommons_api.app import app -from datacommons_api.core.config import initialize_config +from datacommons_api.core.config import get_config, initialize_config from datacommons_api.core.logging import get_logger, setup_logging -from datacommons_db.session import initialize_db +from datacommons_db.session import get_session, initialize_db +from datacommons_api.services.graph_service import GraphService setup_logging() logger = get_logger(__name__) @@ -72,3 +73,30 @@ def start( port=port, reload=reload, ) + + +@api.command() +@click.option("--gcp-project-id", help="GCP project id.", required=True) +@click.option("--gcp-spanner-instance-id", help="GCP Spanner instance id.", required=True) +@click.option("--gcp-spanner-database-name", help="GCP Spanner database name.", required=True) +def drop_tables( + gcp_project_id: str, + gcp_spanner_instance_id: str, + gcp_spanner_database_name: str, +): + """Drop Node and Edge tables from the graph database.""" + logger.info("Dropping Node and Edge tables from the graph database") + initialize_config( + gcp_project_id=gcp_project_id, + gcp_spanner_instance_id=gcp_spanner_instance_id, + gcp_spanner_database_name=gcp_spanner_database_name, + ) + config = get_config() + db = get_session( + config.GCP_PROJECT_ID, + config.GCP_SPANNER_INSTANCE_ID, + config.GCP_SPANNER_DATABASE_NAME, + ) + graph_service = GraphService(db) + graph_service.drop_tables() + logger.info("Successfully dropped Node and Edge tables") \ No newline at end of file diff --git a/packages/datacommons-api/datacommons_api/core/config.py b/packages/datacommons-api/datacommons_api/core/config.py index a73810b..3fbba81 100644 --- a/packages/datacommons-api/datacommons_api/core/config.py +++ b/packages/datacommons-api/datacommons_api/core/config.py @@ -65,7 +65,7 @@ def validate_config_or_exit(config: Config) -> None: # Ensure GCP Spanner is configured for var in REQUIRED_ENV_VARS: if not getattr(config, var): - logger.error("Environment variable %s must be set", var) + logger.error("Config variable %s must be set", var) sys.exit(1) diff --git a/packages/datacommons-api/datacommons_api/endpoints/routers/node_router.py b/packages/datacommons-api/datacommons_api/endpoints/routers/node_router.py index 41f036c..324ff72 100644 --- a/packages/datacommons-api/datacommons_api/endpoints/routers/node_router.py +++ b/packages/datacommons-api/datacommons_api/endpoints/routers/node_router.py @@ -29,7 +29,8 @@ # JSON-LD endpoint -@router.get("/nodes/", response_model=JSONLDDocument, response_model_exclude_none=True) +@router.get("/nodes", response_model=JSONLDDocument, response_model_exclude_none=True) +@router.get("/nodes/", response_model=JSONLDDocument, response_model_exclude_none=True, include_in_schema=False) def get_nodes( limit: int = DEFAULT_NODE_FETCH_LIMIT, type_filter: Annotated[ @@ -44,7 +45,8 @@ def get_nodes( return graph_service.get_graph_nodes(limit=limit, type_filter=type_filter) -@router.post("/nodes/", response_model=UpdateResponse, response_model_exclude_none=True) +@router.post("/nodes", response_model=UpdateResponse, response_model_exclude_none=True) +@router.post("/nodes/", response_model=UpdateResponse, response_model_exclude_none=True, include_in_schema=False) def insert_nodes( jsonld: JSONLDDocument, graph_service: Annotated[GraphService, Depends(with_graph_service)] = None, diff --git a/packages/datacommons-api/datacommons_api/services/graph_service.py b/packages/datacommons-api/datacommons_api/services/graph_service.py index 941f2e6..a0920df 100644 --- a/packages/datacommons-api/datacommons_api/services/graph_service.py +++ b/packages/datacommons-api/datacommons_api/services/graph_service.py @@ -13,17 +13,19 @@ # limitations under the License. # Standard library imports +import base64 import logging - +import traceback +from google.cloud import spanner +from google.cloud.spanner_v1 import database from sqlalchemy import text from sqlalchemy.orm import Session, joinedload -# Third-party imports +from datacommons_api.core.config import get_config from datacommons_api.core.constants import DEFAULT_NODE_FETCH_LIMIT -from datacommons_db.models.edge import EdgeModel - -# Local application imports -from datacommons_db.models.node import NodeModel +from datacommons_db.models.edge import EdgeModel, EDGE_TABLE_NAME +from datacommons_db.models.node import NodeModel, NODE_TABLE_NAME +from datacommons_db.models.edge import OBJECT_VALUE_MAX_LENGTH from datacommons_schema.models.jsonld import ( GraphNode, GraphNodePropertyValue, @@ -33,6 +35,9 @@ # Configure logging logger = logging.getLogger(__name__) +# Silence OpenTelemetry warnings/errors (Spanner client integration triggers these) +logging.getLogger("opentelemetry.metrics._internal").setLevel(logging.ERROR) +logging.getLogger("opentelemetry.sdk.metrics._internal.export").setLevel(logging.CRITICAL) class GraphServiceError(Exception): """ @@ -69,13 +74,21 @@ def create_node_model(graph_node: GraphNode) -> NodeModel: types = graph_node.type if not isinstance(types, list): types = [types] - types = [t for t in types if t is not None] + types_with_namespaces = [t for t in types if t is not None] + # Remove all CURIE namespaces before storing the node id + subject_id = strip_namespace(graph_node.id) + types = [strip_namespace(t) for t in types] return NodeModel( - subject_id=graph_node.id, + subject_id=subject_id, types=types, ) +def strip_namespace(id: str) -> str: + """ + Strip all CURIE namespaces from an id. + """ + return id.split(":")[-1] def create_edge_model( subject_id: str, @@ -90,7 +103,8 @@ def create_edge_model( Args: subject_id: The ID of the source node predicate: The edge predicate - value_data: The edge value + object_id: The ID of the target node + object_value: The edge value - A string literal - A GraphNode provenance: The ID of a node that is the provenance of the edge @@ -101,17 +115,17 @@ def create_edge_model( """ # Handle lists of values by creating multiple edges edge = EdgeModel( - object_id=object_id, - predicate=predicate, - subject_id=subject_id, + object_id=strip_namespace(object_id), + predicate=strip_namespace(predicate), + subject_id=strip_namespace(subject_id), ) if provenance: - edge.provenance = provenance + edge.provenance = strip_namespace(provenance) if object_value: - edge.object_value = object_value + edge.object_value = strip_namespace(object_value) if object_id else object_value if object_value and not object_id: # If the edge value is a string, use the subject id as the object id - edge.object_id = subject_id + edge.object_id = strip_namespace(subject_id) if not object_id and not object_value: message = f"Missing object_id or object_value for edge {subject_id} {predicate}" raise GraphServiceError(message) @@ -182,11 +196,15 @@ def node_model_to_graph_node(node: NodeModel) -> GraphNode: edge_groups[edge.predicate] = [] property_value = {} - # If the edge has a literal value, add it to the property value - if edge.object_value: + + if edge.object_bytes: + # If the edge has bytes, decode them and add them to the property value + property_value["@value"] = base64.b64decode(edge.object_bytes).decode("utf-8") + elif edge.object_value: + # If the edge has a literal value, add it to the property value property_value["@value"] = edge.object_value - # If the edge has an object id, add it to the property value else: + # If the edge has an object id, add it to the property value property_value["@id"] = edge.object_id # If the edge has provenance, add it to the property value @@ -203,6 +221,123 @@ def node_model_to_graph_node(node: NodeModel) -> GraphNode: return GraphNode(**graph_node_properties) +def get_edge_val(e: EdgeModel, col: str) -> str | None: + """ + Helper function to get the value of an edge column, with support for Spanner index key length limits. + + Args: + e: The EdgeModel instance + col: The column name + + Returns: + The value of the edge column, with support for Spanner index key length limits + """ + if col in ("object_value", "object_bytes"): + val = getattr(e, "object_value") + if not val: + return None + val_bytes = val.encode("utf-8") + + # A Spanner index key incorporates both the indexed columns AND the Primary Key. + # Max index key length is 8192 bytes total. The Primary Keys can swallow up to 4096 bytes easily. + # So we must restrict object_value to 4096 bytes to guarantee the total key size is < 8192 bytes. + if col == "object_value": + if len(val_bytes) > OBJECT_VALUE_MAX_LENGTH: + # Slice to exactly OBJECT_VALUE_MAX_LENGTH bytes, dropping fragmented chars gracefully + val_truncated = val_bytes[:OBJECT_VALUE_MAX_LENGTH].decode("utf-8", errors="ignore") + return val_truncated + return val + elif col == "object_bytes": + if len(val_bytes) > OBJECT_VALUE_MAX_LENGTH: + import base64 + return base64.b64encode(val_bytes).decode("utf-8") + return None + return getattr(e, col) + +def get_node_models(jsonld: JSONLDDocument) -> list[NodeModel]: + """ + Converts a JSON-LD document into a list of NodeModel instances with their outgoing edges loaded. + """ + node_models = [] + for graph_node in jsonld.graph: + node_model = create_node_model(graph_node) + node_model.outgoing_edges = extract_edges_from_node(graph_node) + node_models.append(node_model) + return node_models + +def get_node_model_batches(node_models: list[NodeModel], batch_size: int = 1000) -> list[list[NodeModel]]: + """ + Splits a list of NodeModel instances into batches of nodes and edges. + + Args: + node_models: List of NodeModel instances + batch_size: Maximum number of nodes and edges per batch + + Returns: + List of batches of nodes and edges + """ + node_batches: list[list[NodeModel]] = [] + current_batch: list[NodeModel] = [] + current_batch_len = 0 + for node_model in node_models: + # Add node and its edges to the current batch + node_len = len(node_model.outgoing_edges) + 1 + if current_batch_len + node_len < batch_size: + current_batch.append(node_model) + current_batch_len += node_len + else: + # If the current batch is full, add it to the list of batches + node_batches.append(current_batch) + current_batch = [] + current_batch_len = 0 + # Add the last batch if it's not empty + if current_batch: + node_batches.append(current_batch) + return node_batches + +def insert_node_models_batch(node_models: list[NodeModel], spanner_batch: database.BatchCheckout): + """ + Inserts a batch of NodeModel instances into the database using Spanner API. + + Args: + node_models: List of NodeModel instances + spanner_batch: Spanner batch to insert into + + Returns: + None + """ + # Get the column names from the NodeModel and EdgeModel + node_columns = tuple(c.name for c in NodeModel.__table__.columns) + edge_columns = tuple(c.name for c in EdgeModel.__table__.columns if c.name != "object_value_tokenlist") + + # Insert nodes into the database + spanner_batch.insert_or_update( + table=NODE_TABLE_NAME, + columns=node_columns, + values=[tuple(getattr(n, col) for col in node_columns) for n in node_models], + ) + + # Delete existing edges for these nodes using a KeyRange prefix + keyset = spanner.KeySet( + ranges=[ + spanner.KeyRange(start_closed=[n.subject_id], end_closed=[n.subject_id]) + for n in node_models + ] + ) + spanner_batch.delete(table=EDGE_TABLE_NAME, keyset=keyset) + + # Insert the new edges + for node_model in node_models: + # Skip if there are no edges to avoid empty insert errors + if not node_model.outgoing_edges: + continue + spanner_batch.insert_or_update( + table=EDGE_TABLE_NAME, + columns=edge_columns, + values=[tuple(get_edge_val(e, col) for col in edge_columns) + for e in node_model.outgoing_edges], + ) + class GraphService: """ Service for managing graph database operations. @@ -219,7 +354,14 @@ def __init__(self, session: Session): session: SQLAlchemy session for database operations """ self.session = session - logger.info("Initialized GraphService with new session") + + config = get_config() + spanner_client = spanner.Client(project=config.GCP_PROJECT_ID) + instance = spanner_client.instance(config.GCP_SPANNER_INSTANCE_ID) + self.spanner_database = instance.database(config.GCP_SPANNER_DATABASE_NAME) + + # Silence Spanner client INFO logs + self.spanner_database.logger.setLevel(logging.WARNING) def get_graph_nodes( self, @@ -290,40 +432,56 @@ def _get_nodes_with_outgoing_edges( logger.debug("Retrieved %d nodes with outgoing edges", len(nodes)) return nodes - def insert_graph_nodes(self, jsonld: JSONLDDocument) -> None: + def insert_graph_nodes(self, jsonld: JSONLDDocument, batch_size: int = 1000) -> None: """ - Insert nodes and edges from a JSON-LD document into the database. - - Raises an exception if the node already exists. + Inserts nodes and edges from a JSON-LD document into the database using Spanner API. - This method processes the JSON-LD document, creating NodeModel and EdgeModel - instances for each node and its edges. It handles both literal values and - references to other nodes, preserving provenance information. + Updates the nodes and edges if they already exist. Args: jsonld: The JSON-LD document containing nodes and edges to insert """ - nodes: list[NodeModel] = [] - edges: list[EdgeModel] = [] - - logger.info("Inserting %d nodes from JSON-LD document", len(jsonld.graph)) - - # Process each node in the graph - for graph_node in jsonld.graph: - # Create node model - node_model = create_node_model(graph_node) - nodes.append(node_model) - - # Extract and create edge models - node_edges = extract_edges_from_node(graph_node) - edges.extend(node_edges) - logger.info("Inserting %d nodes and %d edges", len(nodes), len(edges)) - - # Add all nodes and edges to the session - self.session.add_all(nodes) - self.session.add_all(edges) - - # Commit the transaction - self.session.commit() - logger.info("Successfully committed all nodes and edges to database") + # Convert JSON-LD to NodeModels + node_models = get_node_models(jsonld) + node_model_batches = get_node_model_batches(node_models, batch_size) + total_edges = sum(len(node_model.outgoing_edges) for node_model in node_models) + + logger.info("Inserting %d nodes and %d edges in %d batch(es) to Spanner", len(node_models), total_edges, len(node_model_batches)) + + # Insert nodes and edges in batches + success_count = 0 + try: + for node_model_batch in node_model_batches: + with self.spanner_database.batch() as spanner_batch: + insert_node_models_batch(node_model_batch, spanner_batch) + success_count += len(node_model_batch) + except Exception as e: + error_message = f"Failed to insert nodes and edges to Spanner after {success_count}/{len(node_models)} nodes inserted" + logger.error(error_message + ": %s", e) + traceback.print_exc() + raise GraphServiceError(error_message) + + logger.info("Successfully committed %d nodes and %d edges to Spanner", success_count, total_edges) + + def drop_tables(self) -> None: + """ + Delete Node and Edge tables from the graph database. + """ + logger.info("Dropping Node and Edge tables from the graph database") + logger.info("Are you sure you want to continue? (yes/no)") + if input() == "yes": + logger.info("Dropping index EdgeByObjectValue") + query = "DROP INDEX EdgeByObjectValue" + self.session.execute(text(query)) + logger.info("Dropping table %s", EDGE_TABLE_NAME) + query = f"DROP TABLE {EDGE_TABLE_NAME}" + self.session.execute(text(query)) + logger.info("Dropping table %s", NODE_TABLE_NAME) + query = f"DROP TABLE {NODE_TABLE_NAME}" + self.session.execute(text(query)) + self.session.commit() + logger.info("Successfully dropped Node and Edge tables") + else: + logger.info("Quitting. Did not drop tables") + \ No newline at end of file diff --git a/packages/datacommons-api/datacommons_api/services/graph_service_test.py b/packages/datacommons-api/datacommons_api/services/graph_service_test.py new file mode 100644 index 0000000..ea08607 --- /dev/null +++ b/packages/datacommons-api/datacommons_api/services/graph_service_test.py @@ -0,0 +1,127 @@ +import pytest +from unittest.mock import MagicMock, patch, call + +from sqlalchemy.orm import Session +from google.cloud import spanner +from datacommons_api.core.config import Config +from datacommons_api.services.graph_service import GraphService, GraphServiceError +from datacommons_db.models.node import NodeModel +from datacommons_db.models.edge import EdgeModel +from datacommons_schema.models.jsonld import JSONLDDocument, GraphNode + +@pytest.fixture +def mock_session(): + return MagicMock(spec=Session) + +@pytest.fixture +def mock_config(): + with patch("datacommons_api.services.graph_service.get_config") as mock: + mock_config_instance = MagicMock(spec=Config) + mock_config_instance.GCP_PROJECT_ID = "test-project" + mock_config_instance.GCP_SPANNER_INSTANCE_ID = "test-instance" + mock_config_instance.GCP_SPANNER_DATABASE_NAME = "test-db" + mock.return_value = mock_config_instance + yield mock + +@pytest.fixture +def mock_spanner_client(): + with patch("datacommons_api.services.graph_service.spanner.Client") as mock: + mock_client_instance = MagicMock() + mock_instance = MagicMock() + mock_database = MagicMock() + + mock_client_instance.instance.return_value = mock_instance + mock_instance.database.return_value = mock_database + + mock.return_value = mock_client_instance + yield mock_client_instance + +@pytest.fixture +def graph_service(mock_session, mock_config, mock_spanner_client): + return GraphService(session=mock_session) + +def test_init(mock_session, mock_config, mock_spanner_client): + service = GraphService(session=mock_session) + assert service.session == mock_session + mock_spanner_client.instance.assert_called_once_with("test-instance") + mock_spanner_client.instance.return_value.database.assert_called_once_with("test-db") + +def test_get_graph_nodes(graph_service, mock_session): + # Setup mock data + mock_node = NodeModel(subject_id="test_node", types=["TestType"]) + mock_edge = EdgeModel( + subject_id="test_node", + predicate="test_predicate", + object_id="test_target" + ) + mock_node.outgoing_edges = [mock_edge] + + # Mock the query chain + mock_query = MagicMock() + mock_query.options.return_value.limit.return_value.all.return_value = [mock_node] + # Handle type filter + mock_query.filter.return_value.params.return_value.options.return_value.limit.return_value.all.return_value = [mock_node] + mock_session.query.return_value = mock_query + + # Test without filter + result = graph_service.get_graph_nodes(limit=10) + + # Verify + assert isinstance(result, JSONLDDocument) + assert len(result.graph) == 1 + assert result.graph[0].id == "test_node" + assert result.graph[0].type == ["TestType"] + assert result.graph[0].model_dump(by_alias=True, exclude_none=True)["test_predicate"] == {"@id": "test_target"} + + # Test with filter + result = graph_service.get_graph_nodes(limit=10, type_filter=["TestType"]) + assert isinstance(result, JSONLDDocument) + assert len(result.graph) == 1 + +def test_insert_graph_nodes(graph_service, mock_session, mock_spanner_client): + # Setup mock data for JSONLD + graph_node = GraphNode(**{ + "@id": "test_node", + "@type": ["TestType"], + "test_predicate": {"@id": "test_target"} + }) + mock_jsonld = JSONLDDocument( + context={"test": "http://test.com/"}, + graph=[graph_node] + ) + + mock_batch = MagicMock() + mock_database = mock_spanner_client.instance.return_value.database.return_value + mock_database.batch.return_value.__enter__.return_value = mock_batch + + # Test + graph_service.insert_graph_nodes(mock_jsonld) + + # Verify + assert mock_batch.insert_or_update.call_count == 2 + mock_batch.delete.assert_called_once() + +def test_insert_graph_nodes_error(graph_service, mock_spanner_client): + # Setup mock data that triggers an error + mock_jsonld = JSONLDDocument(context={}, graph=[GraphNode(**{"@id": "n1", "@type": "t1"})]) + + mock_database = mock_spanner_client.instance.return_value.database.return_value + mock_database.batch.side_effect = Exception("Spanner Error") + + with pytest.raises(GraphServiceError) as exc_info: + graph_service.insert_graph_nodes(mock_jsonld) + + assert "Failed to insert nodes and edges to Spanner" in str(exc_info.value) + +def test_drop_tables(graph_service, mock_session): + with patch("builtins.input", return_value="yes"): + graph_service.drop_tables() + assert mock_session.execute.call_count == 3 + mock_session.commit.assert_called_once() + + mock_session.reset_mock() + + with patch("builtins.input", return_value="no"): + graph_service.drop_tables() + assert mock_session.execute.call_count == 0 + assert mock_session.commit.call_count == 0 diff --git a/packages/datacommons-db/datacommons_db/models/edge.py b/packages/datacommons-db/datacommons_db/models/edge.py index 25e85eb..3549a59 100644 --- a/packages/datacommons-db/datacommons_db/models/edge.py +++ b/packages/datacommons-db/datacommons_db/models/edge.py @@ -20,18 +20,22 @@ from datacommons_db.models.base import Base +EDGE_TABLE_NAME = "Edge" +OBJECT_VALUE_MAX_LENGTH = 4096 + class EdgeModel(Base): """ Represents an edge in the graph. """ - __tablename__ = "Edge" + __tablename__ = EDGE_TABLE_NAME subject_id = sa.Column( String(1024), sa.ForeignKey("Node.subject_id"), primary_key=True ) predicate = sa.Column(String(1024), primary_key=True) object_id = sa.Column(String(1024), primary_key=True) - object_value = sa.Column(Text(), nullable=True) + object_value = sa.Column(String(OBJECT_VALUE_MAX_LENGTH), nullable=True) + object_bytes = sa.Column(sa.LargeBinary(), nullable=True) object_hash = sa.Column(String(64), primary_key=True, nullable=True) provenance = sa.Column(String(1024), primary_key=True, nullable=True) # Use deferred to avoid loading the node data into memory diff --git a/packages/datacommons-db/datacommons_db/models/node.py b/packages/datacommons-db/datacommons_db/models/node.py index 053367d..c6e0e0c 100644 --- a/packages/datacommons-db/datacommons_db/models/node.py +++ b/packages/datacommons-db/datacommons_db/models/node.py @@ -19,13 +19,14 @@ from datacommons_db.models.base import Base +NODE_TABLE_NAME = "Node" class NodeModel(Base): """ Represents a node in the graph. """ - __tablename__ = "Node" + __tablename__ = NODE_TABLE_NAME subject_id = sa.Column(String(1024), primary_key=True, autoincrement=False) name = sa.Column(Text(), nullable=True) types = sa.Column(ARRAY(String(1024)), nullable=True) diff --git a/packages/datacommons-db/datacommons_db/models/observation.py b/packages/datacommons-db/datacommons_db/models/observation.py index 9bfec2f..8b4dd00 100644 --- a/packages/datacommons-db/datacommons_db/models/observation.py +++ b/packages/datacommons-db/datacommons_db/models/observation.py @@ -18,13 +18,14 @@ from datacommons_db.models.base import Base +OBSERVATION_TABLE_NAME = "Observation" class ObservationModel(Base): """ Represents a statistical observation of a variable. """ - __tablename__ = "Observation" + __tablename__ = OBSERVATION_TABLE_NAME variable_measured = sa.Column(String(1024), nullable=False, primary_key=True) observation_about = sa.Column(String(1024), nullable=False, primary_key=True) diff --git a/packages/datacommons-db/pyproject.toml b/packages/datacommons-db/pyproject.toml index b81dd58..9a9aa07 100644 --- a/packages/datacommons-db/pyproject.toml +++ b/packages/datacommons-db/pyproject.toml @@ -18,6 +18,7 @@ classifiers = [ dependencies = [ "sqlalchemy", "sqlalchemy-spanner", + "google-cloud-spanner", "setuptools<=80.0.0", # Pin version to <=80 to avoid https://stackoverflow.com/questions/76043689/pkg-resources-is-deprecated-as-an-api ] diff --git a/uv.lock b/uv.lock index 2dca153..4a33497 100644 --- a/uv.lock +++ b/uv.lock @@ -453,6 +453,7 @@ requires-dist = [ name = "datacommons-db" source = { editable = "packages/datacommons-db" } dependencies = [ + { name = "google-cloud-spanner" }, { name = "setuptools" }, { name = "sqlalchemy" }, { name = "sqlalchemy-spanner" }, @@ -460,6 +461,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "google-cloud-spanner" }, { name = "setuptools", specifier = "<=80.0.0" }, { name = "sqlalchemy" }, { name = "sqlalchemy-spanner" }, From 374ed5a995e7af669eee8fd46fe6a049eee84979 Mon Sep 17 00:00:00 2001 From: Dan Noble Date: Thu, 5 Mar 2026 23:43:34 -0800 Subject: [PATCH 02/10] formatting --- .../datacommons_api/api_cli.py | 10 ++- .../endpoints/routers/node_router.py | 14 +++- .../datacommons_api/services/graph_service.py | 66 ++++++++++++----- .../services/graph_service_test.py | 72 +++++++++++-------- .../datacommons_db/models/edge.py | 1 + .../datacommons_db/models/node.py | 1 + .../datacommons_db/models/observation.py | 1 + 7 files changed, 116 insertions(+), 49 deletions(-) diff --git a/packages/datacommons-api/datacommons_api/api_cli.py b/packages/datacommons-api/datacommons_api/api_cli.py index 2db660a..2ffa9b0 100644 --- a/packages/datacommons-api/datacommons_api/api_cli.py +++ b/packages/datacommons-api/datacommons_api/api_cli.py @@ -77,8 +77,12 @@ def start( @api.command() @click.option("--gcp-project-id", help="GCP project id.", required=True) -@click.option("--gcp-spanner-instance-id", help="GCP Spanner instance id.", required=True) -@click.option("--gcp-spanner-database-name", help="GCP Spanner database name.", required=True) +@click.option( + "--gcp-spanner-instance-id", help="GCP Spanner instance id.", required=True +) +@click.option( + "--gcp-spanner-database-name", help="GCP Spanner database name.", required=True +) def drop_tables( gcp_project_id: str, gcp_spanner_instance_id: str, @@ -99,4 +103,4 @@ def drop_tables( ) graph_service = GraphService(db) graph_service.drop_tables() - logger.info("Successfully dropped Node and Edge tables") \ No newline at end of file + logger.info("Successfully dropped Node and Edge tables") diff --git a/packages/datacommons-api/datacommons_api/endpoints/routers/node_router.py b/packages/datacommons-api/datacommons_api/endpoints/routers/node_router.py index 324ff72..02d12f8 100644 --- a/packages/datacommons-api/datacommons_api/endpoints/routers/node_router.py +++ b/packages/datacommons-api/datacommons_api/endpoints/routers/node_router.py @@ -30,7 +30,12 @@ # JSON-LD endpoint @router.get("/nodes", response_model=JSONLDDocument, response_model_exclude_none=True) -@router.get("/nodes/", response_model=JSONLDDocument, response_model_exclude_none=True, include_in_schema=False) +@router.get( + "/nodes/", + response_model=JSONLDDocument, + response_model_exclude_none=True, + include_in_schema=False, +) def get_nodes( limit: int = DEFAULT_NODE_FETCH_LIMIT, type_filter: Annotated[ @@ -46,7 +51,12 @@ def get_nodes( @router.post("/nodes", response_model=UpdateResponse, response_model_exclude_none=True) -@router.post("/nodes/", response_model=UpdateResponse, response_model_exclude_none=True, include_in_schema=False) +@router.post( + "/nodes/", + response_model=UpdateResponse, + response_model_exclude_none=True, + include_in_schema=False, +) def insert_nodes( jsonld: JSONLDDocument, graph_service: Annotated[GraphService, Depends(with_graph_service)] = None, diff --git a/packages/datacommons-api/datacommons_api/services/graph_service.py b/packages/datacommons-api/datacommons_api/services/graph_service.py index a0920df..0bf451f 100644 --- a/packages/datacommons-api/datacommons_api/services/graph_service.py +++ b/packages/datacommons-api/datacommons_api/services/graph_service.py @@ -37,7 +37,10 @@ # Silence OpenTelemetry warnings/errors (Spanner client integration triggers these) logging.getLogger("opentelemetry.metrics._internal").setLevel(logging.ERROR) -logging.getLogger("opentelemetry.sdk.metrics._internal.export").setLevel(logging.CRITICAL) +logging.getLogger("opentelemetry.sdk.metrics._internal.export").setLevel( + logging.CRITICAL +) + class GraphServiceError(Exception): """ @@ -84,12 +87,14 @@ def create_node_model(graph_node: GraphNode) -> NodeModel: types=types, ) + def strip_namespace(id: str) -> str: """ Strip all CURIE namespaces from an id. """ return id.split(":")[-1] + def create_edge_model( subject_id: str, predicate: str, @@ -196,10 +201,12 @@ def node_model_to_graph_node(node: NodeModel) -> GraphNode: edge_groups[edge.predicate] = [] property_value = {} - + if edge.object_bytes: # If the edge has bytes, decode them and add them to the property value - property_value["@value"] = base64.b64decode(edge.object_bytes).decode("utf-8") + property_value["@value"] = base64.b64decode(edge.object_bytes).decode( + "utf-8" + ) elif edge.object_value: # If the edge has a literal value, add it to the property value property_value["@value"] = edge.object_value @@ -237,23 +244,27 @@ def get_edge_val(e: EdgeModel, col: str) -> str | None: if not val: return None val_bytes = val.encode("utf-8") - + # A Spanner index key incorporates both the indexed columns AND the Primary Key. # Max index key length is 8192 bytes total. The Primary Keys can swallow up to 4096 bytes easily. # So we must restrict object_value to 4096 bytes to guarantee the total key size is < 8192 bytes. if col == "object_value": if len(val_bytes) > OBJECT_VALUE_MAX_LENGTH: # Slice to exactly OBJECT_VALUE_MAX_LENGTH bytes, dropping fragmented chars gracefully - val_truncated = val_bytes[:OBJECT_VALUE_MAX_LENGTH].decode("utf-8", errors="ignore") + val_truncated = val_bytes[:OBJECT_VALUE_MAX_LENGTH].decode( + "utf-8", errors="ignore" + ) return val_truncated return val elif col == "object_bytes": if len(val_bytes) > OBJECT_VALUE_MAX_LENGTH: import base64 + return base64.b64encode(val_bytes).decode("utf-8") return None return getattr(e, col) + def get_node_models(jsonld: JSONLDDocument) -> list[NodeModel]: """ Converts a JSON-LD document into a list of NodeModel instances with their outgoing edges loaded. @@ -265,7 +276,10 @@ def get_node_models(jsonld: JSONLDDocument) -> list[NodeModel]: node_models.append(node_model) return node_models -def get_node_model_batches(node_models: list[NodeModel], batch_size: int = 1000) -> list[list[NodeModel]]: + +def get_node_model_batches( + node_models: list[NodeModel], batch_size: int = 1000 +) -> list[list[NodeModel]]: """ Splits a list of NodeModel instances into batches of nodes and edges. @@ -295,7 +309,10 @@ def get_node_model_batches(node_models: list[NodeModel], batch_size: int = 1000) node_batches.append(current_batch) return node_batches -def insert_node_models_batch(node_models: list[NodeModel], spanner_batch: database.BatchCheckout): + +def insert_node_models_batch( + node_models: list[NodeModel], spanner_batch: database.BatchCheckout +): """ Inserts a batch of NodeModel instances into the database using Spanner API. @@ -308,7 +325,11 @@ def insert_node_models_batch(node_models: list[NodeModel], spanner_batch: databa """ # Get the column names from the NodeModel and EdgeModel node_columns = tuple(c.name for c in NodeModel.__table__.columns) - edge_columns = tuple(c.name for c in EdgeModel.__table__.columns if c.name != "object_value_tokenlist") + edge_columns = tuple( + c.name + for c in EdgeModel.__table__.columns + if c.name != "object_value_tokenlist" + ) # Insert nodes into the database spanner_batch.insert_or_update( @@ -334,10 +355,13 @@ def insert_node_models_batch(node_models: list[NodeModel], spanner_batch: databa spanner_batch.insert_or_update( table=EDGE_TABLE_NAME, columns=edge_columns, - values=[tuple(get_edge_val(e, col) for col in edge_columns) - for e in node_model.outgoing_edges], + values=[ + tuple(get_edge_val(e, col) for col in edge_columns) + for e in node_model.outgoing_edges + ], ) + class GraphService: """ Service for managing graph database operations. @@ -432,7 +456,9 @@ def _get_nodes_with_outgoing_edges( logger.debug("Retrieved %d nodes with outgoing edges", len(nodes)) return nodes - def insert_graph_nodes(self, jsonld: JSONLDDocument, batch_size: int = 1000) -> None: + def insert_graph_nodes( + self, jsonld: JSONLDDocument, batch_size: int = 1000 + ) -> None: """ Inserts nodes and edges from a JSON-LD document into the database using Spanner API. @@ -447,8 +473,13 @@ def insert_graph_nodes(self, jsonld: JSONLDDocument, batch_size: int = 1000) -> node_model_batches = get_node_model_batches(node_models, batch_size) total_edges = sum(len(node_model.outgoing_edges) for node_model in node_models) - logger.info("Inserting %d nodes and %d edges in %d batch(es) to Spanner", len(node_models), total_edges, len(node_model_batches)) - + logger.info( + "Inserting %d nodes and %d edges in %d batch(es) to Spanner", + len(node_models), + total_edges, + len(node_model_batches), + ) + # Insert nodes and edges in batches success_count = 0 try: @@ -461,8 +492,12 @@ def insert_graph_nodes(self, jsonld: JSONLDDocument, batch_size: int = 1000) -> logger.error(error_message + ": %s", e) traceback.print_exc() raise GraphServiceError(error_message) - - logger.info("Successfully committed %d nodes and %d edges to Spanner", success_count, total_edges) + + logger.info( + "Successfully committed %d nodes and %d edges to Spanner", + success_count, + total_edges, + ) def drop_tables(self) -> None: """ @@ -484,4 +519,3 @@ def drop_tables(self) -> None: logger.info("Successfully dropped Node and Edge tables") else: logger.info("Quitting. Did not drop tables") - \ No newline at end of file diff --git a/packages/datacommons-api/datacommons_api/services/graph_service_test.py b/packages/datacommons-api/datacommons_api/services/graph_service_test.py index ea08607..d824fa8 100644 --- a/packages/datacommons-api/datacommons_api/services/graph_service_test.py +++ b/packages/datacommons-api/datacommons_api/services/graph_service_test.py @@ -9,10 +9,12 @@ from datacommons_db.models.edge import EdgeModel from datacommons_schema.models.jsonld import JSONLDDocument, GraphNode + @pytest.fixture def mock_session(): return MagicMock(spec=Session) + @pytest.fixture def mock_config(): with patch("datacommons_api.services.graph_service.get_config") as mock: @@ -23,104 +25,118 @@ def mock_config(): mock.return_value = mock_config_instance yield mock + @pytest.fixture def mock_spanner_client(): with patch("datacommons_api.services.graph_service.spanner.Client") as mock: mock_client_instance = MagicMock() mock_instance = MagicMock() mock_database = MagicMock() - + mock_client_instance.instance.return_value = mock_instance mock_instance.database.return_value = mock_database - + mock.return_value = mock_client_instance yield mock_client_instance + @pytest.fixture def graph_service(mock_session, mock_config, mock_spanner_client): return GraphService(session=mock_session) + def test_init(mock_session, mock_config, mock_spanner_client): service = GraphService(session=mock_session) assert service.session == mock_session mock_spanner_client.instance.assert_called_once_with("test-instance") - mock_spanner_client.instance.return_value.database.assert_called_once_with("test-db") + mock_spanner_client.instance.return_value.database.assert_called_once_with( + "test-db" + ) + def test_get_graph_nodes(graph_service, mock_session): # Setup mock data mock_node = NodeModel(subject_id="test_node", types=["TestType"]) mock_edge = EdgeModel( - subject_id="test_node", - predicate="test_predicate", - object_id="test_target" + subject_id="test_node", predicate="test_predicate", object_id="test_target" ) mock_node.outgoing_edges = [mock_edge] - + # Mock the query chain mock_query = MagicMock() mock_query.options.return_value.limit.return_value.all.return_value = [mock_node] # Handle type filter - mock_query.filter.return_value.params.return_value.options.return_value.limit.return_value.all.return_value = [mock_node] + mock_query.filter.return_value.params.return_value.options.return_value.limit.return_value.all.return_value = [ + mock_node + ] mock_session.query.return_value = mock_query - + # Test without filter result = graph_service.get_graph_nodes(limit=10) - + # Verify assert isinstance(result, JSONLDDocument) assert len(result.graph) == 1 assert result.graph[0].id == "test_node" assert result.graph[0].type == ["TestType"] - assert result.graph[0].model_dump(by_alias=True, exclude_none=True)["test_predicate"] == {"@id": "test_target"} + assert result.graph[0].model_dump(by_alias=True, exclude_none=True)[ + "test_predicate" + ] == {"@id": "test_target"} # Test with filter result = graph_service.get_graph_nodes(limit=10, type_filter=["TestType"]) assert isinstance(result, JSONLDDocument) assert len(result.graph) == 1 + def test_insert_graph_nodes(graph_service, mock_session, mock_spanner_client): # Setup mock data for JSONLD - graph_node = GraphNode(**{ - "@id": "test_node", - "@type": ["TestType"], - "test_predicate": {"@id": "test_target"} - }) + graph_node = GraphNode( + **{ + "@id": "test_node", + "@type": ["TestType"], + "test_predicate": {"@id": "test_target"}, + } + ) mock_jsonld = JSONLDDocument( - context={"test": "http://test.com/"}, - graph=[graph_node] + context={"test": "http://test.com/"}, graph=[graph_node] ) - + mock_batch = MagicMock() mock_database = mock_spanner_client.instance.return_value.database.return_value mock_database.batch.return_value.__enter__.return_value = mock_batch - + # Test graph_service.insert_graph_nodes(mock_jsonld) - + # Verify assert mock_batch.insert_or_update.call_count == 2 mock_batch.delete.assert_called_once() - + + def test_insert_graph_nodes_error(graph_service, mock_spanner_client): # Setup mock data that triggers an error - mock_jsonld = JSONLDDocument(context={}, graph=[GraphNode(**{"@id": "n1", "@type": "t1"})]) - + mock_jsonld = JSONLDDocument( + context={}, graph=[GraphNode(**{"@id": "n1", "@type": "t1"})] + ) + mock_database = mock_spanner_client.instance.return_value.database.return_value mock_database.batch.side_effect = Exception("Spanner Error") - + with pytest.raises(GraphServiceError) as exc_info: graph_service.insert_graph_nodes(mock_jsonld) - + assert "Failed to insert nodes and edges to Spanner" in str(exc_info.value) + def test_drop_tables(graph_service, mock_session): with patch("builtins.input", return_value="yes"): graph_service.drop_tables() assert mock_session.execute.call_count == 3 mock_session.commit.assert_called_once() - + mock_session.reset_mock() - + with patch("builtins.input", return_value="no"): graph_service.drop_tables() assert mock_session.execute.call_count == 0 diff --git a/packages/datacommons-db/datacommons_db/models/edge.py b/packages/datacommons-db/datacommons_db/models/edge.py index 3549a59..9aeedb5 100644 --- a/packages/datacommons-db/datacommons_db/models/edge.py +++ b/packages/datacommons-db/datacommons_db/models/edge.py @@ -23,6 +23,7 @@ EDGE_TABLE_NAME = "Edge" OBJECT_VALUE_MAX_LENGTH = 4096 + class EdgeModel(Base): """ Represents an edge in the graph. diff --git a/packages/datacommons-db/datacommons_db/models/node.py b/packages/datacommons-db/datacommons_db/models/node.py index c6e0e0c..ece650e 100644 --- a/packages/datacommons-db/datacommons_db/models/node.py +++ b/packages/datacommons-db/datacommons_db/models/node.py @@ -21,6 +21,7 @@ NODE_TABLE_NAME = "Node" + class NodeModel(Base): """ Represents a node in the graph. diff --git a/packages/datacommons-db/datacommons_db/models/observation.py b/packages/datacommons-db/datacommons_db/models/observation.py index 8b4dd00..3a6981a 100644 --- a/packages/datacommons-db/datacommons_db/models/observation.py +++ b/packages/datacommons-db/datacommons_db/models/observation.py @@ -20,6 +20,7 @@ OBSERVATION_TABLE_NAME = "Observation" + class ObservationModel(Base): """ Represents a statistical observation of a variable. From d5bc582b1e81b2f19639a3b5f0d455e33e2a7397 Mon Sep 17 00:00:00 2001 From: Dan Noble Date: Thu, 5 Mar 2026 23:51:55 -0800 Subject: [PATCH 03/10] Update packages/datacommons-api/datacommons_api/services/graph_service.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../datacommons-api/datacommons_api/services/graph_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/datacommons-api/datacommons_api/services/graph_service.py b/packages/datacommons-api/datacommons_api/services/graph_service.py index 0bf451f..dff2b70 100644 --- a/packages/datacommons-api/datacommons_api/services/graph_service.py +++ b/packages/datacommons-api/datacommons_api/services/graph_service.py @@ -243,7 +243,7 @@ def get_edge_val(e: EdgeModel, col: str) -> str | None: val = getattr(e, "object_value") if not val: return None - val_bytes = val.encode("utf-8") + val_bytes = str(val).encode("utf-8") # A Spanner index key incorporates both the indexed columns AND the Primary Key. # Max index key length is 8192 bytes total. The Primary Keys can swallow up to 4096 bytes easily. From 2d9fd0da5ff9a6fce35ddc20cdf752133d0748e8 Mon Sep 17 00:00:00 2001 From: Dan Noble Date: Thu, 5 Mar 2026 23:52:07 -0800 Subject: [PATCH 04/10] Update packages/datacommons-api/datacommons_api/services/graph_service.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../datacommons-api/datacommons_api/services/graph_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/datacommons-api/datacommons_api/services/graph_service.py b/packages/datacommons-api/datacommons_api/services/graph_service.py index dff2b70..60ac814 100644 --- a/packages/datacommons-api/datacommons_api/services/graph_service.py +++ b/packages/datacommons-api/datacommons_api/services/graph_service.py @@ -77,7 +77,7 @@ def create_node_model(graph_node: GraphNode) -> NodeModel: types = graph_node.type if not isinstance(types, list): types = [types] - types_with_namespaces = [t for t in types if t is not None] + types = [t for t in types if t is not None] # Remove all CURIE namespaces before storing the node id subject_id = strip_namespace(graph_node.id) From a6ef0df8842c8f8e6221b647945aba008168b7a2 Mon Sep 17 00:00:00 2001 From: Dan Noble Date: Thu, 5 Mar 2026 23:56:19 -0800 Subject: [PATCH 05/10] pr fedback --- .../datacommons_api/api_cli.py | 7 +++ .../datacommons_api/services/graph_service.py | 46 +++++++++-------- .../services/graph_service_test.py | 51 ++++++++++++++----- packages/datacommons-api/test_node_batches.py | 32 ++++++++++++ 4 files changed, 104 insertions(+), 32 deletions(-) create mode 100644 packages/datacommons-api/test_node_batches.py diff --git a/packages/datacommons-api/datacommons_api/api_cli.py b/packages/datacommons-api/datacommons_api/api_cli.py index 2ffa9b0..f02d726 100644 --- a/packages/datacommons-api/datacommons_api/api_cli.py +++ b/packages/datacommons-api/datacommons_api/api_cli.py @@ -83,12 +83,19 @@ def start( @click.option( "--gcp-spanner-database-name", help="GCP Spanner database name.", required=True ) +@click.option( + "--yes", is_flag=True, help="Skip confirmation prompt." +) def drop_tables( gcp_project_id: str, gcp_spanner_instance_id: str, gcp_spanner_database_name: str, + yes: bool, ): """Drop Node and Edge tables from the graph database.""" + if not yes: + click.confirm("Are you sure you want to drop the Node and Edge tables?", abort=True) + logger.info("Dropping Node and Edge tables from the graph database") initialize_config( gcp_project_id=gcp_project_id, diff --git a/packages/datacommons-api/datacommons_api/services/graph_service.py b/packages/datacommons-api/datacommons_api/services/graph_service.py index 0bf451f..440720d 100644 --- a/packages/datacommons-api/datacommons_api/services/graph_service.py +++ b/packages/datacommons-api/datacommons_api/services/graph_service.py @@ -294,16 +294,27 @@ def get_node_model_batches( current_batch: list[NodeModel] = [] current_batch_len = 0 for node_model in node_models: - # Add node and its edges to the current batch node_len = len(node_model.outgoing_edges) + 1 - if current_batch_len + node_len < batch_size: + + # If the node itself is larger than the batch_size, add it as its own batch + if node_len >= batch_size: + if current_batch: + node_batches.append(current_batch) + current_batch = [] + current_batch_len = 0 + node_batches.append([node_model]) + continue + + # Add node and its edges to the current batch + if current_batch_len + node_len <= batch_size: current_batch.append(node_model) current_batch_len += node_len else: # If the current batch is full, add it to the list of batches node_batches.append(current_batch) - current_batch = [] - current_batch_len = 0 + current_batch = [node_model] + current_batch_len = node_len + # Add the last batch if it's not empty if current_batch: node_batches.append(current_batch) @@ -503,19 +514,14 @@ def drop_tables(self) -> None: """ Delete Node and Edge tables from the graph database. """ - logger.info("Dropping Node and Edge tables from the graph database") - logger.info("Are you sure you want to continue? (yes/no)") - if input() == "yes": - logger.info("Dropping index EdgeByObjectValue") - query = "DROP INDEX EdgeByObjectValue" - self.session.execute(text(query)) - logger.info("Dropping table %s", EDGE_TABLE_NAME) - query = f"DROP TABLE {EDGE_TABLE_NAME}" - self.session.execute(text(query)) - logger.info("Dropping table %s", NODE_TABLE_NAME) - query = f"DROP TABLE {NODE_TABLE_NAME}" - self.session.execute(text(query)) - self.session.commit() - logger.info("Successfully dropped Node and Edge tables") - else: - logger.info("Quitting. Did not drop tables") + logger.info("Dropping index EdgeByObjectValue") + query = "DROP INDEX EdgeByObjectValue" + self.session.execute(text(query)) + logger.info("Dropping table %s", EDGE_TABLE_NAME) + query = f"DROP TABLE {EDGE_TABLE_NAME}" + self.session.execute(text(query)) + logger.info("Dropping table %s", NODE_TABLE_NAME) + query = f"DROP TABLE {NODE_TABLE_NAME}" + self.session.execute(text(query)) + self.session.commit() + logger.info("Successfully dropped Node and Edge tables") diff --git a/packages/datacommons-api/datacommons_api/services/graph_service_test.py b/packages/datacommons-api/datacommons_api/services/graph_service_test.py index d824fa8..b1183c2 100644 --- a/packages/datacommons-api/datacommons_api/services/graph_service_test.py +++ b/packages/datacommons-api/datacommons_api/services/graph_service_test.py @@ -4,12 +4,47 @@ from sqlalchemy.orm import Session from google.cloud import spanner from datacommons_api.core.config import Config -from datacommons_api.services.graph_service import GraphService, GraphServiceError +from datacommons_api.services.graph_service import ( + GraphService, + GraphServiceError, + get_node_model_batches, +) from datacommons_db.models.node import NodeModel from datacommons_db.models.edge import EdgeModel from datacommons_schema.models.jsonld import JSONLDDocument, GraphNode +def test_get_node_model_batches(): + node1 = NodeModel(subject_id="n1", types=["T1"]) + node1.outgoing_edges = [EdgeModel(subject_id="n1", predicate="p", object_id=f"o{i}") for i in range(5)] + + node2 = NodeModel(subject_id="n2", types=["T1"]) + node2.outgoing_edges = [EdgeModel(subject_id="n2", predicate="p", object_id=f"o{i}") for i in range(5)] + + node3 = NodeModel(subject_id="n3", types=["T1"]) + node3.outgoing_edges = [EdgeModel(subject_id="n3", predicate="p", object_id=f"o{i}") for i in range(5)] + + # 6 items per node + # batch size 10 means 10 items max. n1 = 6 items -> batch 0. n2 = 6 items -> batch 1. n3 = 6 items -> batch 2. + batches = get_node_model_batches([node1, node2, node3], batch_size=10) + assert len(batches) == 3 + assert batches[0] == [node1] + assert batches[1] == [node2] + assert batches[2] == [node3] + + # test a node larger than the batch size (6 items > batch size 5) + batches = get_node_model_batches([node1, node2], batch_size=5) + assert len(batches) == 2 + assert batches[0] == [node1] + assert batches[1] == [node2] + + # Test batch size 12. n1 + n2 = 12 items -> batch 0. n3 = 6 items -> batch 1. + batches = get_node_model_batches([node1, node2, node3], batch_size=12) + assert len(batches) == 2 + assert batches[0] == [node1, node2] + assert batches[1] == [node3] + + @pytest.fixture def mock_session(): return MagicMock(spec=Session) @@ -130,14 +165,6 @@ def test_insert_graph_nodes_error(graph_service, mock_spanner_client): def test_drop_tables(graph_service, mock_session): - with patch("builtins.input", return_value="yes"): - graph_service.drop_tables() - assert mock_session.execute.call_count == 3 - mock_session.commit.assert_called_once() - - mock_session.reset_mock() - - with patch("builtins.input", return_value="no"): - graph_service.drop_tables() - assert mock_session.execute.call_count == 0 - assert mock_session.commit.call_count == 0 + graph_service.drop_tables() + assert mock_session.execute.call_count == 3 + mock_session.commit.assert_called_once() diff --git a/packages/datacommons-api/test_node_batches.py b/packages/datacommons-api/test_node_batches.py new file mode 100644 index 0000000..982c1a6 --- /dev/null +++ b/packages/datacommons-api/test_node_batches.py @@ -0,0 +1,32 @@ +from datacommons_db.models.node import NodeModel +from datacommons_db.models.edge import EdgeModel +from datacommons_api.services.graph_service import get_node_model_batches + +def test_get_node_model_batches_bug(): + node1 = NodeModel(subject_id="node1", types=["TypeA"]) + # 5 edges + 1 node = 6 items + node1.outgoing_edges = [EdgeModel(subject_id="node1", predicate="p", object_id=f"obj{i}") for i in range(5)] + + node2 = NodeModel(subject_id="node2", types=["TypeA"]) + # 5 edges + 1 node = 6 items + node2.outgoing_edges = [EdgeModel(subject_id="node2", predicate="p", object_id=f"obj{i}") for i in range(5)] + + node3 = NodeModel(subject_id="node3", types=["TypeA"]) + # 5 edges + 1 node = 6 items + node3.outgoing_edges = [EdgeModel(subject_id="node3", predicate="p", object_id=f"obj{i}") for i in range(5)] + + # Total items = 18. Let's set batch size to 10. + # Node 1 (6 items) -> Batch 1 + # Node 2 (6 items) -> 6 + 6 = 12 > 10. So it hits the else block. Node 2 is skipped. + batches = get_node_model_batches([node1, node2, node3], batch_size=10) + + print(f"Number of batches: {len(batches)}") + for i, batch in enumerate(batches): + print(f"Batch {i}: {[n.subject_id for n in batch]}") + + all_nodes_in_batches = [n for batch in batches for n in batch] + print(f"Total nodes returned: {len(all_nodes_in_batches)}") + print(f"Expected: 3, Actual: {len(all_nodes_in_batches)}") + +if __name__ == "__main__": + test_get_node_model_batches_bug() From dd62e2d2a33fe259d71cac71fbd2856e4f7e82a9 Mon Sep 17 00:00:00 2001 From: Dan Noble Date: Thu, 5 Mar 2026 23:56:34 -0800 Subject: [PATCH 06/10] formatting --- .../datacommons_api/api_cli.py | 8 ++--- .../datacommons_api/services/graph_service.py | 4 +-- .../services/graph_service_test.py | 22 +++++++++----- packages/datacommons-api/test_node_batches.py | 29 +++++++++++++------ 4 files changed, 40 insertions(+), 23 deletions(-) diff --git a/packages/datacommons-api/datacommons_api/api_cli.py b/packages/datacommons-api/datacommons_api/api_cli.py index f02d726..f2ae048 100644 --- a/packages/datacommons-api/datacommons_api/api_cli.py +++ b/packages/datacommons-api/datacommons_api/api_cli.py @@ -83,9 +83,7 @@ def start( @click.option( "--gcp-spanner-database-name", help="GCP Spanner database name.", required=True ) -@click.option( - "--yes", is_flag=True, help="Skip confirmation prompt." -) +@click.option("--yes", is_flag=True, help="Skip confirmation prompt.") def drop_tables( gcp_project_id: str, gcp_spanner_instance_id: str, @@ -94,7 +92,9 @@ def drop_tables( ): """Drop Node and Edge tables from the graph database.""" if not yes: - click.confirm("Are you sure you want to drop the Node and Edge tables?", abort=True) + click.confirm( + "Are you sure you want to drop the Node and Edge tables?", abort=True + ) logger.info("Dropping Node and Edge tables from the graph database") initialize_config( diff --git a/packages/datacommons-api/datacommons_api/services/graph_service.py b/packages/datacommons-api/datacommons_api/services/graph_service.py index 440720d..97f0205 100644 --- a/packages/datacommons-api/datacommons_api/services/graph_service.py +++ b/packages/datacommons-api/datacommons_api/services/graph_service.py @@ -295,7 +295,7 @@ def get_node_model_batches( current_batch_len = 0 for node_model in node_models: node_len = len(node_model.outgoing_edges) + 1 - + # If the node itself is larger than the batch_size, add it as its own batch if node_len >= batch_size: if current_batch: @@ -314,7 +314,7 @@ def get_node_model_batches( node_batches.append(current_batch) current_batch = [node_model] current_batch_len = node_len - + # Add the last batch if it's not empty if current_batch: node_batches.append(current_batch) diff --git a/packages/datacommons-api/datacommons_api/services/graph_service_test.py b/packages/datacommons-api/datacommons_api/services/graph_service_test.py index b1183c2..fa9ed8f 100644 --- a/packages/datacommons-api/datacommons_api/services/graph_service_test.py +++ b/packages/datacommons-api/datacommons_api/services/graph_service_test.py @@ -16,14 +16,20 @@ def test_get_node_model_batches(): node1 = NodeModel(subject_id="n1", types=["T1"]) - node1.outgoing_edges = [EdgeModel(subject_id="n1", predicate="p", object_id=f"o{i}") for i in range(5)] - + node1.outgoing_edges = [ + EdgeModel(subject_id="n1", predicate="p", object_id=f"o{i}") for i in range(5) + ] + node2 = NodeModel(subject_id="n2", types=["T1"]) - node2.outgoing_edges = [EdgeModel(subject_id="n2", predicate="p", object_id=f"o{i}") for i in range(5)] - + node2.outgoing_edges = [ + EdgeModel(subject_id="n2", predicate="p", object_id=f"o{i}") for i in range(5) + ] + node3 = NodeModel(subject_id="n3", types=["T1"]) - node3.outgoing_edges = [EdgeModel(subject_id="n3", predicate="p", object_id=f"o{i}") for i in range(5)] - + node3.outgoing_edges = [ + EdgeModel(subject_id="n3", predicate="p", object_id=f"o{i}") for i in range(5) + ] + # 6 items per node # batch size 10 means 10 items max. n1 = 6 items -> batch 0. n2 = 6 items -> batch 1. n3 = 6 items -> batch 2. batches = get_node_model_batches([node1, node2, node3], batch_size=10) @@ -31,13 +37,13 @@ def test_get_node_model_batches(): assert batches[0] == [node1] assert batches[1] == [node2] assert batches[2] == [node3] - + # test a node larger than the batch size (6 items > batch size 5) batches = get_node_model_batches([node1, node2], batch_size=5) assert len(batches) == 2 assert batches[0] == [node1] assert batches[1] == [node2] - + # Test batch size 12. n1 + n2 = 12 items -> batch 0. n3 = 6 items -> batch 1. batches = get_node_model_batches([node1, node2, node3], batch_size=12) assert len(batches) == 2 diff --git a/packages/datacommons-api/test_node_batches.py b/packages/datacommons-api/test_node_batches.py index 982c1a6..240b1e8 100644 --- a/packages/datacommons-api/test_node_batches.py +++ b/packages/datacommons-api/test_node_batches.py @@ -2,31 +2,42 @@ from datacommons_db.models.edge import EdgeModel from datacommons_api.services.graph_service import get_node_model_batches + def test_get_node_model_batches_bug(): node1 = NodeModel(subject_id="node1", types=["TypeA"]) # 5 edges + 1 node = 6 items - node1.outgoing_edges = [EdgeModel(subject_id="node1", predicate="p", object_id=f"obj{i}") for i in range(5)] - + node1.outgoing_edges = [ + EdgeModel(subject_id="node1", predicate="p", object_id=f"obj{i}") + for i in range(5) + ] + node2 = NodeModel(subject_id="node2", types=["TypeA"]) # 5 edges + 1 node = 6 items - node2.outgoing_edges = [EdgeModel(subject_id="node2", predicate="p", object_id=f"obj{i}") for i in range(5)] - + node2.outgoing_edges = [ + EdgeModel(subject_id="node2", predicate="p", object_id=f"obj{i}") + for i in range(5) + ] + node3 = NodeModel(subject_id="node3", types=["TypeA"]) # 5 edges + 1 node = 6 items - node3.outgoing_edges = [EdgeModel(subject_id="node3", predicate="p", object_id=f"obj{i}") for i in range(5)] - + node3.outgoing_edges = [ + EdgeModel(subject_id="node3", predicate="p", object_id=f"obj{i}") + for i in range(5) + ] + # Total items = 18. Let's set batch size to 10. # Node 1 (6 items) -> Batch 1 # Node 2 (6 items) -> 6 + 6 = 12 > 10. So it hits the else block. Node 2 is skipped. batches = get_node_model_batches([node1, node2, node3], batch_size=10) - + print(f"Number of batches: {len(batches)}") for i, batch in enumerate(batches): print(f"Batch {i}: {[n.subject_id for n in batch]}") - + all_nodes_in_batches = [n for batch in batches for n in batch] print(f"Total nodes returned: {len(all_nodes_in_batches)}") print(f"Expected: 3, Actual: {len(all_nodes_in_batches)}") - + + if __name__ == "__main__": test_get_node_model_batches_bug() From 696e643dd0c7a3415f119f160aca1f78090d1da8 Mon Sep 17 00:00:00 2001 From: Dan Noble Date: Thu, 5 Mar 2026 23:58:06 -0800 Subject: [PATCH 07/10] Update packages/datacommons-api/datacommons_api/services/graph_service.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../datacommons_api/services/graph_service.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/packages/datacommons-api/datacommons_api/services/graph_service.py b/packages/datacommons-api/datacommons_api/services/graph_service.py index eebe7b4..38e3e46 100644 --- a/packages/datacommons-api/datacommons_api/services/graph_service.py +++ b/packages/datacommons-api/datacommons_api/services/graph_service.py @@ -500,9 +500,8 @@ def insert_graph_nodes( success_count += len(node_model_batch) except Exception as e: error_message = f"Failed to insert nodes and edges to Spanner after {success_count}/{len(node_models)} nodes inserted" - logger.error(error_message + ": %s", e) - traceback.print_exc() - raise GraphServiceError(error_message) + logger.exception(error_message) + raise GraphServiceError(error_message) from e logger.info( "Successfully committed %d nodes and %d edges to Spanner", From 68c473d21ee31d31a43ebfc07764be72c9cbe213 Mon Sep 17 00:00:00 2001 From: Dan Noble Date: Thu, 5 Mar 2026 23:58:35 -0800 Subject: [PATCH 08/10] removed extra line --- .../datacommons-api/datacommons_api/services/graph_service.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/packages/datacommons-api/datacommons_api/services/graph_service.py b/packages/datacommons-api/datacommons_api/services/graph_service.py index eebe7b4..d398ce3 100644 --- a/packages/datacommons-api/datacommons_api/services/graph_service.py +++ b/packages/datacommons-api/datacommons_api/services/graph_service.py @@ -258,8 +258,6 @@ def get_edge_val(e: EdgeModel, col: str) -> str | None: return val elif col == "object_bytes": if len(val_bytes) > OBJECT_VALUE_MAX_LENGTH: - import base64 - return base64.b64encode(val_bytes).decode("utf-8") return None return getattr(e, col) From f4d2a75d0c8699bc850634dbdd296ee496ab5493 Mon Sep 17 00:00:00 2001 From: Dan Noble Date: Fri, 6 Mar 2026 17:21:41 -0800 Subject: [PATCH 09/10] pr feedback --- .../datacommons_api/api_cli.py | 1 + .../datacommons_api/services/graph_service.py | 64 ++++++++++--------- 2 files changed, 35 insertions(+), 30 deletions(-) diff --git a/packages/datacommons-api/datacommons_api/api_cli.py b/packages/datacommons-api/datacommons_api/api_cli.py index f2ae048..7e2a7a5 100644 --- a/packages/datacommons-api/datacommons_api/api_cli.py +++ b/packages/datacommons-api/datacommons_api/api_cli.py @@ -91,6 +91,7 @@ def drop_tables( yes: bool, ): """Drop Node and Edge tables from the graph database.""" + # TODO: Refactor this method to only drop the data from the tables, not the tables themselves. if not yes: click.confirm( "Are you sure you want to drop the Node and Edge tables?", abort=True diff --git a/packages/datacommons-api/datacommons_api/services/graph_service.py b/packages/datacommons-api/datacommons_api/services/graph_service.py index 3b58bc5..10befa6 100644 --- a/packages/datacommons-api/datacommons_api/services/graph_service.py +++ b/packages/datacommons-api/datacommons_api/services/graph_service.py @@ -228,39 +228,40 @@ def node_model_to_graph_node(node: NodeModel) -> GraphNode: return GraphNode(**graph_node_properties) -def get_edge_val(e: EdgeModel, col: str) -> str | None: +def coerce_edge_val_for_db_write(e: EdgeModel, col: str) -> str | None: """ - Helper function to get the value of an edge column, with support for Spanner index key length limits. - + Coerces and truncates edge values to comply with Spanner index limits. Args: - e: The EdgeModel instance - col: The column name - + e: The EdgeModel instance containing raw data. + col: The target database column name. Returns: - The value of the edge column, with support for Spanner index key length limits + - For 'object_value': A UTF-8 string truncated to 4096 bytes (safe-decoded). + - For 'object_bytes': A Base64-encoded representation of the model's 'object_value'. + - For other columns: The raw attribute value from the model. """ - if col in ("object_value", "object_bytes"): - val = getattr(e, "object_value") - if not val: - return None - val_bytes = str(val).encode("utf-8") - - # A Spanner index key incorporates both the indexed columns AND the Primary Key. - # Max index key length is 8192 bytes total. The Primary Keys can swallow up to 4096 bytes easily. - # So we must restrict object_value to 4096 bytes to guarantee the total key size is < 8192 bytes. - if col == "object_value": - if len(val_bytes) > OBJECT_VALUE_MAX_LENGTH: - # Slice to exactly OBJECT_VALUE_MAX_LENGTH bytes, dropping fragmented chars gracefully - val_truncated = val_bytes[:OBJECT_VALUE_MAX_LENGTH].decode( - "utf-8", errors="ignore" - ) - return val_truncated - return val - elif col == "object_bytes": - if len(val_bytes) > OBJECT_VALUE_MAX_LENGTH: - return base64.b64encode(val_bytes).decode("utf-8") - return None - return getattr(e, col) + if col not in ("object_value", "object_bytes"): + return getattr(e, col) + + val = getattr(e, "object_value") + if not val: + return None + val_bytes = str(val).encode("utf-8") + + # A Spanner index key incorporates both the indexed columns AND the Primary Key. + # Max index key length is 8192 bytes total. The Primary Keys can swallow up to 4096 bytes easily. + # So we must restrict object_value to 4096 bytes to guarantee the total key size is < 8192 bytes. + if col == "object_value": + if len(val_bytes) > OBJECT_VALUE_MAX_LENGTH: + # Slice to exactly OBJECT_VALUE_MAX_LENGTH bytes, dropping fragmented chars gracefully + val_truncated = val_bytes[:OBJECT_VALUE_MAX_LENGTH].decode( + "utf-8", errors="ignore" + ) + return val_truncated + return val + elif col == "object_bytes": + if len(val_bytes) > OBJECT_VALUE_MAX_LENGTH: + return base64.b64encode(val_bytes).decode("utf-8") + return None def get_node_models(jsonld: JSONLDDocument) -> list[NodeModel]: @@ -365,7 +366,7 @@ def insert_node_models_batch( table=EDGE_TABLE_NAME, columns=edge_columns, values=[ - tuple(get_edge_val(e, col) for col in edge_columns) + tuple(coerce_edge_val_for_db_write(e, col) for col in edge_columns) for e in node_model.outgoing_edges ], ) @@ -490,6 +491,9 @@ def insert_graph_nodes( ) # Insert nodes and edges in batches + # TODO(dwnoble): this insert may fail if a node in an earlier batch references a node in a later batch. + # Also may fail if a node references a node that is in a remote knowledge graph + # Possible solution: Insert all nodes first, then insert all edges in a second pass. success_count = 0 try: for node_model_batch in node_model_batches: From 457195c338c9f7e7348a2114333c6e8e3db6a3f4 Mon Sep 17 00:00:00 2001 From: Dan Noble Date: Fri, 6 Mar 2026 17:24:05 -0800 Subject: [PATCH 10/10] Added todo --- .../datacommons-api/datacommons_api/services/graph_service.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/datacommons-api/datacommons_api/services/graph_service.py b/packages/datacommons-api/datacommons_api/services/graph_service.py index 10befa6..0b0bf88 100644 --- a/packages/datacommons-api/datacommons_api/services/graph_service.py +++ b/packages/datacommons-api/datacommons_api/services/graph_service.py @@ -253,6 +253,8 @@ def coerce_edge_val_for_db_write(e: EdgeModel, col: str) -> str | None: if col == "object_value": if len(val_bytes) > OBJECT_VALUE_MAX_LENGTH: # Slice to exactly OBJECT_VALUE_MAX_LENGTH bytes, dropping fragmented chars gracefully + # TODO: To avoid hash index collisions, we should use a deterministic hash of the object_value + # and store that along with the truncated value. val_truncated = val_bytes[:OBJECT_VALUE_MAX_LENGTH].decode( "utf-8", errors="ignore" )