Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: prototype MDX protocol #1264

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions google/api/field_behavior_pb2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# type: ignore
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: google/api/field_behavior.proto
# isort: skip_file
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder

# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b"\n\x1fgoogle/api/field_behavior.proto\x12\ngoogle.api\x1a google/protobuf/descriptor.proto*\xb6\x01\n\rFieldBehavior\x12\x1e\n\x1a\x46IELD_BEHAVIOR_UNSPECIFIED\x10\x00\x12\x0c\n\x08OPTIONAL\x10\x01\x12\x0c\n\x08REQUIRED\x10\x02\x12\x0f\n\x0bOUTPUT_ONLY\x10\x03\x12\x0e\n\nINPUT_ONLY\x10\x04\x12\r\n\tIMMUTABLE\x10\x05\x12\x12\n\x0eUNORDERED_LIST\x10\x06\x12\x15\n\x11NON_EMPTY_DEFAULT\x10\x07\x12\x0e\n\nIDENTIFIER\x10\x08:Q\n\x0e\x66ield_behavior\x12\x1d.google.protobuf.FieldOptions\x18\x9c\x08 \x03(\x0e\x32\x19.google.api.FieldBehaviorBp\n\x0e\x63om.google.apiB\x12\x46ieldBehaviorProtoP\x01ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\xa2\x02\x04GAPIb\x06proto3"
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(
DESCRIPTOR, "google.api.field_behavior_pb2", _globals
)
if _descriptor._USE_C_DESCRIPTORS == False:
google_dot_protobuf_dot_descriptor__pb2.FieldOptions.RegisterExtension(
field_behavior
)

DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b"\n\016com.google.apiB\022FieldBehaviorProtoP\001ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\242\002\004GAPI"
_globals["_FIELDBEHAVIOR"]._serialized_start = 82
_globals["_FIELDBEHAVIOR"]._serialized_end = 264
# @@protoc_insertion_point(module_scope)
8 changes: 7 additions & 1 deletion google/cloud/sql/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def __init__(
self._sqladmin_api_endpoint = DEFAULT_SERVICE_ENDPOINT
else:
self._sqladmin_api_endpoint = sqladmin_api_endpoint
# asyncpg does not currently support using metadata exchange
# only use metadata exchange for sync drivers
self._use_metadata = False if driver == "asyncpg" else True
self._user_agent = user_agent

async def _get_metadata(
Expand Down Expand Up @@ -204,7 +207,10 @@ async def _get_ephemeral(

url = f"{self._sqladmin_api_endpoint}/sql/{API_VERSION}/projects/{project}/instances/{instance}:generateEphemeralCert"

data = {"public_key": pub_key}
data = {
"public_key": pub_key,
"use_metadata_exchange": self._use_metadata,
}

if enable_iam_auth:
# down-scope credentials with only IAM login scope (refreshes them too)
Expand Down
92 changes: 91 additions & 1 deletion google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
import logging
import os
import socket
import struct
from threading import Thread
from types import TracebackType
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

import google.auth
from google.auth.credentials import Credentials
Expand All @@ -44,11 +45,17 @@
from google.cloud.sql.connector.resolver import DnsResolver
from google.cloud.sql.connector.utils import format_database_user
from google.cloud.sql.connector.utils import generate_keys
import google.cloud.sql.proto.cloud_sql_metadata_exchange_pb2 as connectorspb

if TYPE_CHECKING:
import ssl

logger = logging.getLogger(name=__name__)

ASYNC_DRIVERS = ["asyncpg"]
SERVER_PROXY_PORT = 3307
# the maximum amount of time to wait before aborting a metadata exchange
IO_TIMEOUT = 30
_DEFAULT_SCHEME = "https://"
_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
_SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}"
Expand Down Expand Up @@ -391,6 +398,9 @@ async def connect_async(
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
server_hostname=ip_address,
)
# Perform Metadata Exchange Protocol
metadata_partial = partial(self.metadata_exchange, sock)
sock = await self._loop.run_in_executor(None, metadata_partial)
# If this connection was opened using a domain name, then store it
# for later in case we need to forcibly close it on failover.
if conn_info.conn_name.domain_name:
Expand All @@ -409,6 +419,86 @@ async def connect_async(
await monitored_cache.force_refresh()
raise

def metadata_exchange(self, sock: ssl.SSLSocket) -> ssl.SSLSocket:
"""
Sends metadata about the connection prior to the database
protocol taking over.
The exchange consists of four steps:
1. Prepare a CloudSQLConnectRequest including the socket protocol and
the user agent.
2. Write the size of the message as a big endian uint32 (4 bytes) to
the server followed by the serialized message. The length does not
include the initial four bytes.
3. Read a big endian uint32 (4 bytes) from the server. This is the
CloudSQLConnectResponse message length and does not include the
initial four bytes.
4. Parse the response using the message length in step 3. If the
response is not OK, return the response's error. If there is no error,
the metadata exchange has succeeded and the connection is complete.
Args:
sock (ssl.SSLSocket): The mTLS/SSL socket to perform metadata
exchange on.
Returns:
sock (ssl.SSLSocket): mTLS/SSL socket connected to Cloud SQL Proxy
server.
"""
# form metadata exchange request
req = connectorspb.CloudSQLConnectRequest(
user_agent=f"{self._client._user_agent}", # type: ignore
protocol_type=connectorspb.CloudSQLConnectRequest.TCP,
)

# set I/O timeout
sock.settimeout(IO_TIMEOUT)

# pack big-endian unsigned integer (4 bytes)
packed_len = struct.pack(">I", req.ByteSize())

# send metadata message length and request message
sock.sendall(packed_len + req.SerializeToString())

# form metadata exchange response
resp = connectorspb.CloudSQLConnectResponse()

# read metadata message length (4 bytes)
message_len_buffer_size = struct.Struct(">I").size
message_len_buffer = b""
while message_len_buffer_size > 0:
chunk = sock.recv(message_len_buffer_size)
if not chunk:
raise RuntimeError(
"Connection closed while getting metadata exchange length!"
)
message_len_buffer += chunk
message_len_buffer_size -= len(chunk)

(message_len,) = struct.unpack(">I", message_len_buffer)

# read metadata exchange message
buffer = b""
while message_len > 0:
chunk = sock.recv(message_len)
if not chunk:
raise RuntimeError(
"Connection closed while performing metadata exchange!"
)
buffer += chunk
message_len -= len(chunk)

# parse metadata exchange response from buffer
resp.ParseFromString(buffer)

# reset socket back to blocking mode
sock.setblocking(True)

# validate metadata exchange response
if resp.response_code != connectorspb.CloudSQLConnectResponse.OK:
raise ValueError(
f"Metadata Exchange request has failed with error: {resp.error}"
)

return sock

async def _remove_cached(
self, instance_connection_string: str, enable_iam_auth: bool
) -> None:
Expand Down
56 changes: 56 additions & 0 deletions google/cloud/sql/proto/cloud_sql_metadata_exchange_pb2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: google/cloud/sql/v1beta4/cloud_sql_metadata_exchange.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder

# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


from google.api import field_behavior_pb2 as google_dot_api_dot_field__behavior__pb2

DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n:google/cloud/sql/v1beta4/cloud_sql_metadata_exchange.proto\x12\x18google.cloud.sql.v1beta4\x1a\x1fgoogle/api/field_behavior.proto"\xc8\x01\n\x16\x43loudSQLConnectRequest\x12\x17\n\nuser_agent\x18\x01 \x01(\tB\x03\xe0\x41\x01\x12T\n\rprotocol_type\x18\x02 \x01(\x0e\x32=.google.cloud.sql.v1beta4.CloudSQLConnectRequest.ProtocolType"?\n\x0cProtocolType\x12\x1d\n\x19PROTOCOL_TYPE_UNSPECIFIED\x10\x00\x12\x07\n\x03TCP\x10\x01\x12\x07\n\x03UDS\x10\x02"\xc6\x01\n\x17\x43loudSQLConnectResponse\x12U\n\rresponse_code\x18\x01 \x01(\x0e\x32>.google.cloud.sql.v1beta4.CloudSQLConnectResponse.ResponseCode\x12\x12\n\x05\x65rror\x18\x02 \x01(\tB\x03\xe0\x41\x01"@\n\x0cResponseCode\x12\x1d\n\x19RESPONSE_CODE_UNSPECIFIED\x10\x00\x12\x06\n\x02OK\x10\x01\x12\t\n\x05\x45RROR\x10\x02\x62\x06proto3'
)

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(
DESCRIPTOR, "google.cloud.sql.v1beta4.cloud_sql_metadata_exchange_pb2", globals()
)
if _descriptor._USE_C_DESCRIPTORS == False:

DESCRIPTOR._options = None
_CLOUDSQLCONNECTREQUEST.fields_by_name["user_agent"]._options = None
_CLOUDSQLCONNECTREQUEST.fields_by_name["user_agent"]._serialized_options = (
b"\340A\001"
)
_CLOUDSQLCONNECTRESPONSE.fields_by_name["error"]._options = None
_CLOUDSQLCONNECTRESPONSE.fields_by_name["error"]._serialized_options = b"\340A\001"
_CLOUDSQLCONNECTREQUEST._serialized_start = 122
_CLOUDSQLCONNECTREQUEST._serialized_end = 322
_CLOUDSQLCONNECTREQUEST_PROTOCOLTYPE._serialized_start = 259
_CLOUDSQLCONNECTREQUEST_PROTOCOLTYPE._serialized_end = 322
_CLOUDSQLCONNECTRESPONSE._serialized_start = 325
_CLOUDSQLCONNECTRESPONSE._serialized_end = 523
_CLOUDSQLCONNECTRESPONSE_RESPONSECODE._serialized_start = 459
_CLOUDSQLCONNECTRESPONSE_RESPONSECODE._serialized_end = 523
# @@protoc_insertion_point(module_scope)
67 changes: 67 additions & 0 deletions google/cloud/sql/proto/cloud_sql_metadata_exchange_pb2.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import ClassVar as _ClassVar
from typing import Optional as _Optional
from typing import Union as _Union

from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper

from google.api import field_behavior_pb2 as _field_behavior_pb2

DESCRIPTOR: _descriptor.FileDescriptor

class CloudSQLConnectRequest(_message.Message):
__slots__ = ["protocol_type", "user_agent"]

class ProtocolType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = [] # type: ignore

PROTOCOL_TYPE_FIELD_NUMBER: _ClassVar[int]
PROTOCOL_TYPE_UNSPECIFIED: CloudSQLConnectRequest.ProtocolType
TCP: CloudSQLConnectRequest.ProtocolType
UDS: CloudSQLConnectRequest.ProtocolType
USER_AGENT_FIELD_NUMBER: _ClassVar[int]
protocol_type: CloudSQLConnectRequest.ProtocolType
user_agent: str
def __init__(
self,
user_agent: _Optional[str] = ...,
protocol_type: _Optional[
_Union[CloudSQLConnectRequest.ProtocolType, str]
] = ...,
) -> None: ...

class CloudSQLConnectResponse(_message.Message):
__slots__ = ["error", "response_code"]

class ResponseCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = [] # type: ignore

ERROR: CloudSQLConnectResponse.ResponseCode
ERROR_FIELD_NUMBER: _ClassVar[int]
OK: CloudSQLConnectResponse.ResponseCode
RESPONSE_CODE_FIELD_NUMBER: _ClassVar[int]
RESPONSE_CODE_UNSPECIFIED: CloudSQLConnectResponse.ResponseCode
error: str
response_code: CloudSQLConnectResponse.ResponseCode
def __init__(
self,
response_code: _Optional[
_Union[CloudSQLConnectResponse.ResponseCode, str]
] = ...,
error: _Optional[str] = ...,
) -> None: ...
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ dependencies = [
"dnspython>=2.0.0",
"Requests",
"google-auth>=2.28.0",
"protobuf",
]
dynamic = ["version"]

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ cryptography==44.0.2
dnspython==2.7.0
Requests==2.32.3
google-auth==2.38.0
protobuf==6.30.0
Loading
Loading