Alloy language connector (#34156)
* Add AlloyDB language connector support.

* Add test.

* Trigger test.

* Add link to WriteToJdbc.

---------

Co-authored-by: Claude <[email protected]>
claudevdm and Claude authored Mar 5, 2025
1 parent 9da7006 commit 0733fa5
Showing 3 changed files with 217 additions and 2 deletions.
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"modification": 4
"modification": 5
}
129 changes: 128 additions & 1 deletion sdks/python/apache_beam/ml/rag/ingestion/alloydb.py
@@ -17,6 +17,7 @@
import json
import logging
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Callable
from typing import Dict
@@ -37,6 +38,73 @@
_LOGGER = logging.getLogger(__name__)


@dataclass
class AlloyDBLanguageConnectorConfig:
"""Configuration options for AlloyDB Java language connector.
Contains all parameters needed to configure a connection using the AlloyDB
Java connector via JDBC. For details see
https://github.com/GoogleCloudPlatform/alloydb-java-connector/blob/main/docs/jdbc.md
Attributes:
database_name: Name of the database to connect to.
instance_name: Fully qualified instance name. Format:
'projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances
/<INSTANCE>'
ip_type: IP type to use for the connection. One of 'PRIVATE' (default),
'PUBLIC', or 'PSC'.
enable_iam_auth: Whether to enable IAM authentication. Default is False.
target_principal: Optional service account to impersonate for
connection.
delegates: Optional comma-separated list of service accounts for
delegated impersonation.
admin_service_endpoint: Optional custom API service endpoint.
quota_project: Optional project ID for quota and billing.
"""
database_name: str
instance_name: str
ip_type: str = "PRIVATE"
enable_iam_auth: bool = False
target_principal: Optional[str] = None
delegates: Optional[List[str]] = None
admin_service_endpoint: Optional[str] = None
quota_project: Optional[str] = None

def to_jdbc_url(self) -> str:
"""Convert options to a properly formatted JDBC URL.
Returns:
JDBC URL string configured with all options.
"""
# Base URL with database name
url = f"jdbc:postgresql:///{self.database_name}?"

# Add required properties
properties = {
"socketFactory": "com.google.cloud.alloydb.SocketFactory",
"alloydbInstanceName": self.instance_name,
"alloydbIpType": self.ip_type
}

if self.enable_iam_auth:
properties["alloydbEnableIAMAuth"] = "true"

if self.target_principal:
properties["alloydbTargetPrincipal"] = self.target_principal

if self.delegates:
properties["alloydbDelegates"] = ",".join(self.delegates)

if self.admin_service_endpoint:
properties["alloydbAdminServiceEndpoint"] = self.admin_service_endpoint

if self.quota_project:
properties["alloydbQuotaProject"] = self.quota_project

property_string = "&".join(f"{k}={v}" for k, v in properties.items())
return url + property_string
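
# Illustrative output (a sketch, not part of the committed code): with the
# defaults above, a config like
#   AlloyDBLanguageConnectorConfig(
#       database_name="mydb",
#       instance_name="projects/p/locations/us-central1/clusters/c/instances/i")
# renders via to_jdbc_url() as (wrapped here for readability):
#   jdbc:postgresql:///mydb?socketFactory=com.google.cloud.alloydb.SocketFactory
#     &alloydbInstanceName=projects/p/locations/us-central1/clusters/c/instances/i
#     &alloydbIpType=PRIVATE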


@dataclass
class AlloyDBConnectionConfig:
"""Configuration for AlloyDB database connection.
@@ -58,6 +126,10 @@ class AlloyDBConnectionConfig:
max_connections: Optional number of connections in the pool.
Use negative for no limit.
write_batch_size: Optional write batch size for bulk operations.
additional_jdbc_args: Additional arguments that will be passed to
WriteToJdbc. These may include 'driver_jars', 'expansion_service',
'classpath', etc. See full set of args at
:class:`~apache_beam.io.jdbc.WriteToJdbc`
Example:
>>> config = AlloyDBConnectionConfig(
@@ -76,6 +148,60 @@ class AlloyDBConnectionConfig:
autosharding: Optional[bool] = None
max_connections: Optional[int] = None
write_batch_size: Optional[int] = None
additional_jdbc_args: Dict[str, Any] = field(default_factory=dict)

@classmethod
def with_language_connector(
cls,
connector_options: AlloyDBLanguageConnectorConfig,
username: str,
password: str,
connection_properties: Optional[Dict[str, str]] = None,
connection_init_sqls: Optional[List[str]] = None,
autosharding: Optional[bool] = None,
max_connections: Optional[int] = None,
write_batch_size: Optional[int] = None) -> 'AlloyDBConnectionConfig':
"""Create AlloyDBConnectionConfig using the AlloyDB language connector.
Args:
connector_options: AlloyDB language connector configuration options.
username: Database username. For IAM auth, this should be the IAM
user email.
password: Database password. Can be empty string when using IAM
auth.
connection_properties: Additional JDBC connection properties.
connection_init_sqls: SQL statements to execute on connection.
autosharding: Enable autosharding.
max_connections: Max connections in pool.
write_batch_size: Write batch size.
Returns:
Configured AlloyDBConnectionConfig instance.
Example:
>>> options = AlloyDBLanguageConnectorConfig(
... database_name="mydb",
... instance_name="projects/my-project/locations/us-central1/"
... "clusters/my-cluster/instances/my-instance",
... ip_type="PUBLIC",
... enable_iam_auth=True
... )
"""
return cls(
jdbc_url=connector_options.to_jdbc_url(),
username=username,
password=password,
connection_properties=connection_properties,
connection_init_sqls=connection_init_sqls,
autosharding=autosharding,
max_connections=max_connections,
write_batch_size=write_batch_size,
additional_jdbc_args={
'classpath': [
"org.postgresql:postgresql:42.2.16",
"com.google.cloud:alloydb-jdbc-connector:1.2.0"
]
})


@dataclass
@@ -713,4 +839,5 @@ def expand(self, pcoll: beam.PCollection[Chunk]):
connection_init_sqls,
autosharding=self.config.connection_config.autosharding,
max_connections=self.config.connection_config.max_connections,
-write_batch_size=self.config.connection_config.write_batch_size))
+write_batch_size=self.config.connection_config.write_batch_size,
+**self.config.connection_config.additional_jdbc_args))
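
The `additional_jdbc_args` pass-through above is what carries the language
connector's classpath into WriteToJdbc. A minimal end-to-end sketch, assuming
a hypothetical project/instance path and table name, and an upstream `chunks`
collection of Chunk values:

    import apache_beam as beam
    from apache_beam.ml.rag.ingestion.alloydb import AlloyDBConnectionConfig
    from apache_beam.ml.rag.ingestion.alloydb import AlloyDBLanguageConnectorConfig
    from apache_beam.ml.rag.ingestion.alloydb import AlloyDBVectorWriterConfig

    connector_options = AlloyDBLanguageConnectorConfig(
        database_name="mydb",
        instance_name="projects/my-project/locations/us-central1/"
        "clusters/my-cluster/instances/my-instance",
        ip_type="PUBLIC")
    connection_config = AlloyDBConnectionConfig.with_language_connector(
        connector_options=connector_options,
        username="user",
        password="secret")
    writer_config = AlloyDBVectorWriterConfig(
        connection_config=connection_config, table_name="embeddings")

    with beam.Pipeline() as p:
        # chunks: Chunk elements produced upstream (placeholder for brevity).
        _ = (p | beam.Create(chunks) | writer_config.create_write_transform())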
88 changes: 88 additions & 0 deletions sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py
@@ -33,6 +33,7 @@
from apache_beam.coders.row_coder import RowCoder
from apache_beam.io.jdbc import ReadFromJdbc
from apache_beam.ml.rag.ingestion.alloydb import AlloyDBConnectionConfig
from apache_beam.ml.rag.ingestion.alloydb import AlloyDBLanguageConnectorConfig
from apache_beam.ml.rag.ingestion.alloydb import AlloyDBVectorWriterConfig
from apache_beam.ml.rag.ingestion.alloydb import ColumnSpec
from apache_beam.ml.rag.ingestion.alloydb import ColumnSpecsBuilder
@@ -328,6 +329,93 @@ def test_default_schema(self):
equal_to([expected_last_n]),
label=f"last_{sample_size}_check")

def test_language_connector(self):
"""Test language connector."""
self.skip_if_dataflow_runner()

connector_options = AlloyDBLanguageConnectorConfig(
database_name=self.database,
instance_name="projects/apache-beam-testing/locations/us-central1/\
clusters/testing-psc/instances/testing-psc-1",
ip_type="PSC")
connection_config = AlloyDBConnectionConfig.with_language_connector(
connector_options=connector_options,
username=self.username,
password=self.password)
config = AlloyDBVectorWriterConfig(
connection_config=connection_config, table_name=self.default_table_name)

# Create test chunks
num_records = 150
sample_size = min(500, num_records // 2)
chunks = ChunkTestUtils.get_expected_values(0, num_records)

self.write_test_pipeline.not_use_test_runner_api = True

with self.write_test_pipeline as p:
_ = (p | beam.Create(chunks) | config.create_write_transform())

self.read_test_pipeline.not_use_test_runner_api = True
read_query = f"""
SELECT
CAST(id AS VARCHAR(255)),
CAST(content AS VARCHAR(255)),
CAST(embedding AS text),
CAST(metadata AS text)
FROM {self.default_table_name}
"""

with self.read_test_pipeline as p:
rows = (
p
| ReadFromJdbc(
table_name=self.default_table_name,
driver_class_name="org.postgresql.Driver",
jdbc_url=connector_options.to_jdbc_url(),
username=self.username,
password=self.password,
query=read_query,
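# The classpath below pins the same two artifacts the write path stages
# through additional_jdbc_args: the PostgreSQL driver and the AlloyDB
# language-connector (JDBC socket factory).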
classpath=[
"org.postgresql:postgresql:42.2.16",
"com.google.cloud:alloydb-jdbc-connector:1.2.0"
]))

count_result = rows | "Count All" >> beam.combiners.Count.Globally()
assert_that(count_result, equal_to([num_records]), label='count_check')

chunks = (rows | "To Chunks" >> beam.Map(row_to_chunk))
chunk_hashes = chunks | "Hash Chunks" >> beam.CombineGlobally(HashingFn())
assert_that(
chunk_hashes,
equal_to([generate_expected_hash(num_records)]),
label='hash_check')

# Sample validation
first_n = (
chunks
| "Key on Index" >> beam.Map(key_on_id)
| f"Get First {sample_size}" >> beam.transforms.combiners.Top.Of(
sample_size, key=lambda x: x[0], reverse=True)
| "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs]))
expected_first_n = ChunkTestUtils.get_expected_values(0, sample_size)
assert_that(
first_n,
equal_to([expected_first_n]),
label=f"first_{sample_size}_check")

last_n = (
chunks
| "Key on Index 2" >> beam.Map(key_on_id)
| f"Get Last {sample_size}" >> beam.transforms.combiners.Top.Of(
sample_size, key=lambda x: x[0])
| "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs]))
expected_last_n = ChunkTestUtils.get_expected_values(
num_records - sample_size, num_records)[::-1]
assert_that(
last_n,
equal_to([expected_last_n]),
label=f"last_{sample_size}_check")

def test_custom_specs(self):
"""Test custom specifications for ID, embedding, and content."""
self.skip_if_dataflow_runner()
