
Commit eb753ac

Merge branch 'master' into fix-wb-ids-download

2 parents 9b36235 + 21999e9

19 files changed

Lines changed: 925 additions & 81 deletions


import-automation/workflow/ingestion-helper/README.md

Lines changed: 8 additions & 0 deletions
@@ -61,8 +61,16 @@ Updates the version of an import, records version history, and updates the status
 Initializes the Spanner database by creating all necessary tables and uploading proto descriptors.
 
 * This action requires no payload parameters. It automatically reads `schema.sql` and `storage.pb` from the container directory to provision the database schema and proto descriptors.
+* `enableEmbeddings` (Optional): Boolean to enable creation of embedding tables and models.
 * **Note on Protos**: The `storage.pb` file is generated during the Docker build process. The `Dockerfile` fetches `storage.proto` from the `datacommonsorg/import` GitHub repository and compiles it into `storage.pb`.
 
+#### `embedding_ingestion`
+Triggers the generation of embeddings for updated nodes in Spanner. It fetches nodes of specific types (e.g., `StatisticalVariable`, `Topic`) that have been updated, generates embeddings using a remote ML model in Spanner, and stores the results in the `NodeEmbeddings` table.
+
+* `enableEmbeddings` (Optional): Boolean to override the default setting for enabling embeddings. If the resolved value (the request value, or the service default when omitted) is false, the request is rejected and no embeddings are generated.
+* **Flags**:
+  - `--node_types`: A comma-separated list of node types to process (default: `StatisticalVariable,Topic`). This is a command-line flag for the service, not a request parameter.
+
 ## Local Development and Testing
 
 To run the helper service locally and test its functionality:
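As context for the **Note on Protos** above: the compile step amounts to producing a serialized `FileDescriptorSet`. A sketch of an equivalent compile step using `grpcio-tools` (an assumption for illustration; the actual `Dockerfile` may invoke a bare `protoc` binary instead, though the flags shown are standard protoc options):

```python
# Hypothetical equivalent of the Dockerfile's proto-compilation step,
# using grpcio-tools instead of a bare protoc binary.
from grpc_tools import protoc

# Compile storage.proto (fetched from datacommonsorg/import) into a
# serialized FileDescriptorSet that can be uploaded as storage.pb.
exit_code = protoc.main([
    "grpc_tools.protoc",
    "-I.",
    "--include_imports",
    "--descriptor_set_out=storage.pb",
    "storage.proto",
])
assert exit_code == 0, "protoc compilation failed"
```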
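For illustration, a request that exercises the new `embedding_ingestion` action might look like the following sketch, assuming the service is running locally via functions-framework on port 8080 (the URL is a placeholder; the `actionType` and `enableEmbeddings` fields follow the handler added in `main.py` below):

```python
import requests

# Placeholder endpoint; substitute the deployed Cloud Function / Cloud Run URL.
INGESTION_HELPER_URL = "http://localhost:8080"

resp = requests.post(
    INGESTION_HELPER_URL,
    json={
        "actionType": "embedding_ingestion",
        # Optional override; if the resolved value is false the service
        # rejects the request with HTTP 400.
        "enableEmbeddings": True,
    },
)
# On success the handler responds with e.g. "OK [Affected rows: 42]".
print(resp.status_code, resp.text)
```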
import-automation/workflow/ingestion-helper/embedding_utils.py

Lines changed: 169 additions & 0 deletions

@@ -0,0 +1,169 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper utilities for embedding workflows."""

import itertools
import logging
import time
from datetime import datetime
from google.cloud.spanner_v1.param_types import TIMESTAMP, STRING, Array, Struct, StructField

_BATCH_SIZE = 1000


def get_latest_lock_timestamp(database):
    """Gets the latest AcquiredTimestamp from the IngestionLock table.

    Args:
        database: google.cloud.spanner.Database object.

    Returns:
        The latest AcquiredTimestamp as a datetime object, or None if no entries exist.
    """
    time_lock_sql = "SELECT MAX(AcquiredTimestamp) FROM IngestionLock"
    try:
        with database.snapshot() as snapshot:
            results = snapshot.execute_sql(time_lock_sql)
            for row in results:
                return row[0]
    except Exception as e:
        logging.error(f"Error fetching latest lock timestamp: {e}")
        raise
    return None


def get_updated_nodes(database, timestamp, node_types):
    """Gets subject_ids and names from the Node table where update_timestamp > timestamp.

    Yields results to avoid loading everything into memory.

    Args:
        database: google.cloud.spanner.Database object.
        timestamp: datetime object to filter by.
        node_types: A list of strings representing the node types to filter by.

    Yields:
        Dictionaries containing subject_id, name, and types.
    """
    timestamp_condition = "update_timestamp > @timestamp" if timestamp else "TRUE"

    updated_node_sql = f"""
        SELECT subject_id, name, types FROM Node
        WHERE name IS NOT NULL
        AND {timestamp_condition}
        AND EXISTS (
            SELECT 1 FROM UNNEST(types) AS t WHERE t IN UNNEST(@node_types)
        )
    """

    params = {"node_types": node_types}
    param_types = {"node_types": Array(STRING)}

    if timestamp:
        logging.info(f"Filtering valid nodes updated after {timestamp}")
        params["timestamp"] = timestamp
        param_types["timestamp"] = TIMESTAMP
    else:
        logging.info("No timestamp provided, reading all valid nodes.")

    try:
        with database.snapshot() as snapshot:
            results = snapshot.execute_sql(updated_node_sql, params=params, param_types=param_types, timeout=300)
            fields = None
            for row in results:
                if fields is None:
                    fields = [field.name for field in results.fields]
                yield dict(zip(fields, row))
    except Exception as e:
        logging.error(f"Error fetching updated nodes: {e}")
        raise


def filter_and_convert_nodes(nodes_generator):
    """Filters out nodes without a name and converts dictionaries to tuples.

    Reads from a generator and yields results.

    Args:
        nodes_generator: A generator yielding dictionaries containing subject_id, name, and types.

    Yields:
        Tuples of (subject_id, embedding_content, types).
    """
    for node in nodes_generator:
        if node.get("name"):
            yield (node.get("subject_id"), node.get("name"), node.get("types"))


def generate_embeddings_partitioned(database, nodes_generator):
    """Generates embeddings in batches using standard transactions.

    Processes nodes in chunks of _BATCH_SIZE to avoid transaction size limits.
    Accepts a generator to avoid loading all nodes into memory.

    Args:
        database: google.cloud.spanner.Database object.
        nodes_generator: A generator yielding (subject_id, embedding_content, types) tuples.

    Returns:
        The number of affected rows.
    """
    total_rows_affected = 0

    logging.info(f"Generating embeddings in batches of {_BATCH_SIZE}.")

    embeddings_sql = """
        INSERT OR UPDATE INTO NodeEmbeddings (subject_id, embedding_content, embeddings, types)
        SELECT subject_id, content, embeddings.values, types
        FROM ML.PREDICT(
            MODEL NodeEmbeddingModel,
            (SELECT subject_id, embedding_content AS content, types, "RETRIEVAL_QUERY" AS task_type FROM UNNEST(@nodes))
        )
    """

    struct_type = Struct([
        StructField("subject_id", STRING),
        StructField("embedding_content", STRING),
        StructField("types", Array(STRING))
    ])

    def chunked(iterable, n):
        it = iter(iterable)
        while True:
            chunk = list(itertools.islice(it, n))
            if not chunk:
                break
            yield chunk

    for batch in chunked(nodes_generator, _BATCH_SIZE):
        params = {"nodes": batch}
        param_types = {"nodes": Array(struct_type)}

        def _execute_dml(transaction):
            return transaction.execute_update(embeddings_sql, params=params, param_types=param_types, timeout=300)

        try:
            row_count = database.run_in_transaction(_execute_dml)
            total_rows_affected += row_count
            logging.info(f"Processed batch of {len(batch)} nodes. Affected total {total_rows_affected} rows.")
            time.sleep(0.5)
        except Exception as e:
            logging.error(f"Error executing batch transaction: {e}")
            raise

    logging.info(f"Completed batch processing. Total affected rows: {total_rows_affected}")
    return total_rows_affected
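Taken together, these helpers form a small pipeline: read the last ingestion-lock timestamp, stream the nodes updated since then, drop unnamed nodes, and upsert embeddings batch by batch. A minimal driver sketch with placeholder instance and database IDs (this is the same sequence that `main.py` wires into the Cloud Function handler below):

```python
from google.cloud import spanner

from embedding_utils import (
    get_latest_lock_timestamp,
    get_updated_nodes,
    filter_and_convert_nodes,
    generate_embeddings_partitioned,
)

# Placeholder IDs for illustration; use the real ingestion instance/database.
client = spanner.Client()
database = client.instance("my-instance").database("my-database")

# Only re-embed nodes touched since the last ingestion lock was acquired.
since = get_latest_lock_timestamp(database)
nodes = get_updated_nodes(database, since, ["StatisticalVariable", "Topic"])
rows = generate_embeddings_partitioned(database, filter_and_convert_nodes(nodes))
print(f"Upserted embeddings for {rows} rows")
```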
import-automation/workflow/ingestion-helper/embedding_utils_test.py

Lines changed: 166 additions & 0 deletions

@@ -0,0 +1,166 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest.mock import MagicMock, patch
from datetime import datetime

from embedding_utils import (
    get_latest_lock_timestamp,
    get_updated_nodes,
    filter_and_convert_nodes,
    generate_embeddings_partitioned
)


class MockField:

    def __init__(self, name):
        self.name = name


class MockResults:
    """Mimics a Spanner result set: iterable rows plus a .fields attribute."""

    def __init__(self, rows, field_names):
        self.rows = rows
        self.fields = [MockField(name) for name in field_names]

    def __iter__(self):
        return iter(self.rows)


class TestEmbeddingUtils(unittest.TestCase):

    def test_get_latest_lock_timestamp(self):
        mock_database = MagicMock()
        mock_snapshot = MagicMock()
        mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
        expected_timestamp = datetime(2026, 4, 20, 12, 0, 0)
        mock_snapshot.execute_sql.return_value = [(expected_timestamp,)]

        timestamp = get_latest_lock_timestamp(mock_database)
        self.assertEqual(timestamp, expected_timestamp)

    def test_get_updated_nodes(self):
        mock_database = MagicMock()
        mock_snapshot = MagicMock()
        mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot

        mock_snapshot.execute_sql.return_value = MockResults(
            rows=[("dc/1", "Node 1", ["Topic"])],
            field_names=["subject_id", "name", "types"]
        )

        nodes = list(get_updated_nodes(mock_database, None, ["Topic"]))

        # Verify Spanner call
        mock_snapshot.execute_sql.assert_called_once()
        args, kwargs = mock_snapshot.execute_sql.call_args
        query = args[0]
        self.assertIn("SELECT subject_id, name, types FROM Node", query)
        self.assertIn("TRUE", query)
        self.assertEqual(kwargs["params"], {"node_types": ["Topic"]})

        self.assertEqual(len(nodes), 1)
        self.assertEqual(nodes[0]["subject_id"], "dc/1")
        self.assertEqual(nodes[0]["name"], "Node 1")
        self.assertEqual(nodes[0]["types"], ["Topic"])

    def test_get_updated_nodes_with_timestamp(self):
        mock_database = MagicMock()
        mock_snapshot = MagicMock()
        mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot

        mock_snapshot.execute_sql.return_value = MockResults(
            rows=[("dc/2", "Node 2", ["Topic"])],
            field_names=["subject_id", "name", "types"]
        )

        test_timestamp = datetime(2026, 4, 25, 0, 0, 0)
        nodes = list(get_updated_nodes(mock_database, test_timestamp, ["Topic"]))

        # Verify Spanner call
        mock_snapshot.execute_sql.assert_called_once()
        args, kwargs = mock_snapshot.execute_sql.call_args
        query = args[0]
        self.assertIn("SELECT subject_id, name, types FROM Node", query)
        self.assertIn("update_timestamp > @timestamp", query)
        self.assertEqual(kwargs["params"], {"node_types": ["Topic"], "timestamp": test_timestamp})

        self.assertEqual(len(nodes), 1)
        self.assertEqual(nodes[0]["subject_id"], "dc/2")

    def test_filter_and_convert_nodes(self):
        nodes = [
            {"subject_id": "dc/1", "name": "Node 1", "types": ["Topic"]},
            {"subject_id": "dc/2", "name": None, "types": ["StatisticalVariable"]},
            {"subject_id": "dc/3", "name": "Node 3", "types": ["Topic", "StatisticalVariable"]},
            {"subject_id": "dc/4", "name": "", "types": ["StatisticalVariable"]}
        ]

        converted = list(filter_and_convert_nodes(nodes))
        self.assertEqual(len(converted), 2)
        self.assertEqual(converted[0], ("dc/1", "Node 1", ["Topic"]))
        self.assertEqual(converted[1], ("dc/3", "Node 3", ["Topic", "StatisticalVariable"]))

    @patch('embedding_utils._BATCH_SIZE', 2)
    def test_generate_embeddings_partitioned(self):
        mock_database = MagicMock()

        nodes = [
            ("dc/1", "Node 1", ["Topic"]),
            ("dc/2", "Node 2", ["Topic"]),
            ("dc/3", "Node 3", ["Topic"]),
            ("dc/4", "Node 4", ["Topic"]),
            ("dc/5", "Node 5", ["Topic"]),
            ("dc/6", "Node 6", ["Topic"]),
            ("dc/7", "Node 7", ["Topic"]),
            ("dc/8", "Node 8", ["Topic"])
        ]

        transactions = []

        def side_effect(func):
            mock_transaction = MagicMock()
            mock_transaction.execute_update.return_value = 2
            transactions.append(mock_transaction)
            return func(mock_transaction)

        mock_database.run_in_transaction.side_effect = side_effect

        affected_rows = generate_embeddings_partitioned(mock_database, nodes)
        self.assertEqual(affected_rows, 8)
        self.assertEqual(mock_database.run_in_transaction.call_count, 4)

        # Verify execute_update calls
        self.assertEqual(len(transactions), 4)
        for i, tx in enumerate(transactions):
            tx.execute_update.assert_called_once()
            args, kwargs = tx.execute_update.call_args
            self.assertIn("INSERT OR UPDATE INTO NodeEmbeddings", args[0])

            # Verify batch content
            batch = kwargs["params"]["nodes"]
            self.assertEqual(len(batch), 2)
            self.assertEqual(batch[0][0], f"dc/{i*2 + 1}")
            self.assertEqual(batch[1][0], f"dc/{i*2 + 2}")


if __name__ == '__main__':
    unittest.main()

import-automation/workflow/ingestion-helper/main.py

Lines changed: 23 additions & 0 deletions
@@ -1,6 +1,7 @@
 import functions_framework
 from spanner_client import SpannerClient
 from storage_client import StorageClient
+from embedding_utils import get_latest_lock_timestamp, get_updated_nodes, filter_and_convert_nodes, generate_embeddings_partitioned
 import logging
 import os
 from absl import flags
@@ -33,6 +34,9 @@
     'enable_embeddings',
     os.environ.get('ENABLE_EMBEDDINGS', 'false').lower() == 'true',
     'Enable embeddings')
+flags.DEFINE_list(
+    'node_types', ['StatisticalVariable', 'Topic'],
+    'Node types to generate embeddings for')
 
 if not FLAGS.is_parsed():
     FLAGS(['ingestion_helper'])
@@ -214,5 +218,24 @@ def ingestion_helper(request):
                                              FLAGS.enable_embeddings)
         spanner.initialize_database(enable_embeddings=enable_embeddings)
         return ('OK', 200)
+    elif actionType == 'embedding_ingestion':
+        logging.info("Action: embedding_ingestion")
+        enable_embeddings = request_json.get('enableEmbeddings',
+                                             FLAGS.enable_embeddings)
+        if not enable_embeddings:
+            logging.info("Embeddings not enabled, skipping.")
+            return ('Invalid request on embedding ingestion.', 400)
+
+        node_types = FLAGS.node_types
+        try:
+            logging.info(f"Job started. Fetching all nodes for types: {node_types}")
+            timestamp = get_latest_lock_timestamp(spanner.database)
+            nodes = get_updated_nodes(spanner.database, timestamp, node_types)
+            converted_nodes = filter_and_convert_nodes(nodes)
+            affected_rows = generate_embeddings_partitioned(spanner.database, converted_nodes)
+            return (f"OK [Affected rows: {affected_rows}]", 200)
+        except Exception as e:
+            logging.error(f"Embedding ingestion failed: {e}")
+            return (f"Error: {e}", 500)
     else:
         return (f'Unknown actionType: {actionType}', 400)
