From 11bc165e8b9c586f887b4a046493d3cf0842db02 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 03:57:07 +0000 Subject: [PATCH 01/15] preliminary (robust) SEA HTTP Client Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/utils/http_client.py | 335 +++++++++++++----- 1 file changed, 237 insertions(+), 98 deletions(-) diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index fe292919c..25811462a 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -1,11 +1,19 @@ import json import logging -import requests -from typing import Callable, Dict, Any, Optional, List, Tuple +import ssl +import urllib.parse +import urllib.request +from typing import Dict, Any, Optional, List, Tuple, Union from urllib.parse import urljoin +from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager +from urllib3.exceptions import HTTPError, MaxRetryError +from urllib3.util import make_headers + from databricks.sql.auth.authenticators import AuthProvider +from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy from databricks.sql.types import SSLOptions +from databricks.sql.exc import RequestError logger = logging.getLogger(__name__) @@ -14,8 +22,8 @@ class SeaHttpClient: """ HTTP client for Statement Execution API (SEA). - This client handles the HTTP communication with the SEA endpoints, - including authentication, request formatting, and response parsing. + This client uses urllib3 for robust HTTP communication with retry policies + and connection pooling, similar to the Thrift HTTP client but simplified. """ def __init__( @@ -38,48 +46,136 @@ def __init__( http_headers: List of HTTP headers to include in requests auth_provider: Authentication provider ssl_options: SSL configuration options - **kwargs: Additional keyword arguments + **kwargs: Additional keyword arguments including retry policy settings """ self.server_hostname = server_hostname - self.port = port + self.port = port or 443 self.http_path = http_path self.auth_provider = auth_provider self.ssl_options = ssl_options - self.base_url = f"https://{server_hostname}:{port}" + # Build base URL + self.base_url = f"https://{server_hostname}:{self.port}" + + # Parse URL for proxy handling + parsed_url = urllib.parse.urlparse(self.base_url) + self.scheme = parsed_url.scheme + self.host = parsed_url.hostname + self.port = parsed_url.port or (443 if self.scheme == "https" else 80) + # Setup headers self.headers: Dict[str, str] = dict(http_headers) self.headers.update({"Content-Type": "application/json"}) - self.max_retries = kwargs.get("_retry_stop_after_attempts_count", 30) - - # Create a session for connection pooling - self.session = requests.Session() + # Extract retry policy settings + self._retry_delay_min = kwargs.get("_retry_delay_min", 1.0) + self._retry_delay_max = kwargs.get("_retry_delay_max", 60.0) + self._retry_stop_after_attempts_count = kwargs.get( + "_retry_stop_after_attempts_count", 30 + ) + self._retry_stop_after_attempts_duration = kwargs.get( + "_retry_stop_after_attempts_duration", 900.0 + ) + self._retry_delay_default = kwargs.get("_retry_delay_default", 5.0) + self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", []) + + # Connection pooling settings + self.max_connections = kwargs.get("max_connections", 10) + + # Setup retry policy + self.enable_v3_retries = kwargs.get("_enable_v3_retries", True) + + if self.enable_v3_retries: + self.retry_policy = DatabricksRetryPolicy( + delay_min=self._retry_delay_min, + delay_max=self._retry_delay_max, + stop_after_attempts_count=self._retry_stop_after_attempts_count, + stop_after_attempts_duration=self._retry_stop_after_attempts_duration, + delay_default=self._retry_delay_default, + force_dangerous_codes=self.force_dangerous_codes, + urllib3_kwargs={"allowed_methods": ["GET", "POST", "DELETE"]}, + ) + else: + # Legacy behavior - no automatic retries + self.retry_policy = 0 - # Configure SSL verification - if ssl_options.tls_verify: - self.session.verify = ssl_options.tls_trusted_ca_file or True + # Handle proxy settings + try: + proxy = urllib.request.getproxies().get(self.scheme) + except (KeyError, AttributeError): + proxy = None + else: + if urllib.request.proxy_bypass(self.host): + proxy = None + + if proxy: + parsed_proxy = urllib.parse.urlparse(proxy) + self.realhost = self.host + self.realport = self.port + self.proxy_uri = proxy + self.host = parsed_proxy.hostname + self.port = parsed_proxy.port + self.proxy_auth = self._basic_proxy_auth_headers(parsed_proxy) else: - self.session.verify = False - - # Configure client certificates if provided - if ssl_options.tls_client_cert_file: - client_cert = ssl_options.tls_client_cert_file - client_key = ssl_options.tls_client_cert_key_file - client_key_password = ssl_options.tls_client_cert_key_password - - if client_key: - self.session.cert = (client_cert, client_key) - else: - self.session.cert = client_cert - - if client_key_password: - # Note: requests doesn't directly support key passwords - # This would require more complex handling with libraries like pyOpenSSL - logger.warning( - "Client key password provided but not supported by requests library" - ) + self.realhost = self.realport = self.proxy_auth = self.proxy_uri = None + + # Initialize connection pool + self._pool = None + self._open() + + def _basic_proxy_auth_headers(self, proxy_parsed) -> Optional[Dict[str, str]]: + """Create basic auth headers for proxy if credentials are provided.""" + if proxy_parsed is None or not proxy_parsed.username: + return None + ap = f"{urllib.parse.unquote(proxy_parsed.username)}:{urllib.parse.unquote(proxy_parsed.password)}" + return make_headers(proxy_basic_auth=ap) + + def _open(self): + """Initialize the connection pool.""" + pool_kwargs = {"maxsize": self.max_connections} + + if self.scheme == "http": + pool_class = HTTPConnectionPool + else: # https + pool_class = HTTPSConnectionPool + pool_kwargs.update({ + "cert_reqs": ssl.CERT_REQUIRED if self.ssl_options.tls_verify else ssl.CERT_NONE, + "ca_certs": self.ssl_options.tls_trusted_ca_file, + "cert_file": self.ssl_options.tls_client_cert_file, + "key_file": self.ssl_options.tls_client_cert_key_file, + "key_password": self.ssl_options.tls_client_cert_key_password, + }) + + if self.proxy_uri: + proxy_manager = ProxyManager( + self.proxy_uri, + num_pools=1, + proxy_headers=self.proxy_auth, + ) + self._pool = proxy_manager.connection_from_host( + host=self.realhost, + port=self.realport, + scheme=self.scheme, + pool_kwargs=pool_kwargs, + ) + else: + self._pool = pool_class(self.host, self.port, **pool_kwargs) + + def close(self): + """Close the connection pool.""" + if self._pool: + self._pool.clear() + + def set_retry_command_type(self, command_type: CommandType): + """Set the command type for retry policy decision making.""" + if isinstance(self.retry_policy, DatabricksRetryPolicy): + self.retry_policy.command_type = command_type + + def start_retry_timer(self): + """Start the retry timer for duration-based retry limits.""" + if isinstance(self.retry_policy, DatabricksRetryPolicy): + self.retry_policy.start_retry_timer() def _get_auth_headers(self) -> Dict[str, str]: """Get authentication headers from the auth provider.""" @@ -87,17 +183,6 @@ def _get_auth_headers(self) -> Dict[str, str]: self.auth_provider.add_headers(headers) return headers - def _get_call(self, method: str) -> Callable: - """Get the appropriate HTTP method function.""" - method = method.upper() - if method == "GET": - return self.session.get - if method == "POST": - return self.session.post - if method == "DELETE": - return self.session.delete - raise ValueError(f"Unsupported HTTP method: {method}") - def _make_request( self, method: str, @@ -118,69 +203,123 @@ def _make_request( Dict[str, Any]: Response data parsed from JSON Raises: - RequestError: If the request fails + RequestError: If the request fails after retries """ - url = urljoin(self.base_url, path) - headers: Dict[str, str] = {**self.headers, **self._get_auth_headers()} + # Build full URL + if path.startswith("/"): + url = path + else: + url = f"/{path.lstrip('/')}" + + # Prepare headers + headers = {**self.headers, **self._get_auth_headers()} - logger.debug(f"making {method} request to {url}") + # Prepare request body + body = json.dumps(data).encode("utf-8") if data else b"" + if body: + headers["Content-Length"] = str(len(body)) + + # Set command type for retry policy + command_type = self._get_command_type_from_path(path, method) + self.set_retry_command_type(command_type) + self.start_retry_timer() + + logger.debug(f"Making {method} request to {url}") try: - call = self._get_call(method) - response = call( + response = self._pool.request( + method=method.upper(), url=url, + body=body, headers=headers, - json=data, - params=params, + preload_content=True, + retries=self.retry_policy, ) - # Check for HTTP errors - response.raise_for_status() - - # Log response details - logger.debug(f"Response status: {response.status_code}") - - # Parse JSON response - if response.content: - result = response.json() - # Log response content (but limit it for large responses) - content_str = json.dumps(result) - if len(content_str) > 1000: - logger.debug( - f"Response content (truncated): {content_str[:1000]}..." - ) + logger.debug(f"Response status: {response.status}") + + # Handle successful responses + if 200 <= response.status < 300: + if response.data: + try: + result = json.loads(response.data.decode("utf-8")) + logger.debug("Successfully parsed JSON response") + return result + except (json.JSONDecodeError, UnicodeDecodeError) as e: + logger.error(f"Failed to parse JSON response: {e}") + raise RequestError(f"Invalid JSON response: {e}", e) + return {} + + # Handle error responses + error_message = f"SEA HTTP request failed with status {response.status}" + + try: + if response.data: + error_details = json.loads(response.data.decode("utf-8")) + if isinstance(error_details, dict) and "message" in error_details: + error_message = f"{error_message}: {error_details['message']}" + logger.error(f"Request failed: {error_details}") + except (json.JSONDecodeError, UnicodeDecodeError): + # Log raw response if we can't parse JSON + content = response.data.decode("utf-8", errors="replace") if response.data else "" + logger.error(f"Request failed with non-JSON response: {content}") + + raise RequestError(error_message, None) + + except MaxRetryError as e: + # Extract the most recent error from the retry history + error_message = f"SEA request failed after retries: {str(e)}" + + if hasattr(e, "reason") and e.reason: + if hasattr(e.reason, "response"): + # Extract status code and body from the final failed response + response = e.reason.response + error_message = f"SEA request failed after retries (status {response.status})" + try: + if response.data: + error_details = json.loads(response.data.decode("utf-8")) + if isinstance(error_details, dict) and "message" in error_details: + error_message = f"{error_message}: {error_details['message']}" + except (json.JSONDecodeError, UnicodeDecodeError): + pass else: - logger.debug(f"Response content: {content_str}") - return result - return {} - - except requests.exceptions.RequestException as e: - # Handle request errors and extract details from response if available - error_message = f"SEA HTTP request failed: {str(e)}" - - if hasattr(e, "response") and e.response is not None: - status_code = e.response.status_code - try: - error_details = e.response.json() - error_message = ( - f"{error_message}: {error_details.get('message', '')}" - ) - logger.error( - f"Request failed (status {status_code}): {error_details}" - ) - except (ValueError, KeyError): - # If we can't parse JSON, log raw content - content = ( - e.response.content.decode("utf-8", errors="replace") - if isinstance(e.response.content, bytes) - else str(e.response.content) - ) - logger.error(f"Request failed (status {status_code}): {content}") - else: - logger.error(error_message) - - # Re-raise as a RequestError - from databricks.sql.exc import RequestError + error_message = f"SEA request failed after retries: {str(e.reason)}" + + logger.error(error_message) + raise RequestError(error_message, e) + except HTTPError as e: + error_message = f"SEA HTTP error: {str(e)}" + logger.error(error_message) raise RequestError(error_message, e) + + except Exception as e: + error_message = f"Unexpected error in SEA request: {str(e)}" + logger.error(error_message) + raise RequestError(error_message, e) + + def _get_command_type_from_path(self, path: str, method: str) -> CommandType: + """ + Determine the command type based on the API path and method. + + This helps the retry policy make appropriate decisions for different + types of SEA operations. + """ + path = path.lower() + method = method.upper() + + if "/statements" in path: + if method == "POST" and path.endswith("/statements"): + return CommandType.EXECUTE_STATEMENT + elif "/cancel" in path: + return CommandType.OTHER # Cancel operation + elif method == "DELETE": + return CommandType.CLOSE_OPERATION + elif method == "GET": + return CommandType.GET_OPERATION_STATUS + elif "/sessions" in path: + if method == "DELETE": + return CommandType.CLOSE_SESSION + + return CommandType.OTHER From d389316ca9334c78c9957e440d93ee0e3ec57c60 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 04:00:12 +0000 Subject: [PATCH 02/15] prevent catching of MaxRetryError and HttpError in client Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/utils/http_client.py | 110 +++++++----------- 1 file changed, 39 insertions(+), 71 deletions(-) diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index 25811462a..054d57c00 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -7,7 +7,6 @@ from urllib.parse import urljoin from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager -from urllib3.exceptions import HTTPError, MaxRetryError from urllib3.util import make_headers from databricks.sql.auth.authenticators import AuthProvider @@ -227,77 +226,46 @@ def _make_request( logger.debug(f"Making {method} request to {url}") - try: - response = self._pool.request( - method=method.upper(), - url=url, - body=body, - headers=headers, - preload_content=True, - retries=self.retry_policy, - ) + # When v3 retries are enabled, urllib3 handles retries internally via DatabricksRetryPolicy + # When disabled, we let exceptions bubble up (similar to Thrift backend approach) + response = self._pool.request( + method=method.upper(), + url=url, + body=body, + headers=headers, + preload_content=True, + retries=self.retry_policy, + ) - logger.debug(f"Response status: {response.status}") - - # Handle successful responses - if 200 <= response.status < 300: - if response.data: - try: - result = json.loads(response.data.decode("utf-8")) - logger.debug("Successfully parsed JSON response") - return result - except (json.JSONDecodeError, UnicodeDecodeError) as e: - logger.error(f"Failed to parse JSON response: {e}") - raise RequestError(f"Invalid JSON response: {e}", e) - return {} - - # Handle error responses - error_message = f"SEA HTTP request failed with status {response.status}" - - try: - if response.data: - error_details = json.loads(response.data.decode("utf-8")) - if isinstance(error_details, dict) and "message" in error_details: - error_message = f"{error_message}: {error_details['message']}" - logger.error(f"Request failed: {error_details}") - except (json.JSONDecodeError, UnicodeDecodeError): - # Log raw response if we can't parse JSON - content = response.data.decode("utf-8", errors="replace") if response.data else "" - logger.error(f"Request failed with non-JSON response: {content}") - - raise RequestError(error_message, None) - - except MaxRetryError as e: - # Extract the most recent error from the retry history - error_message = f"SEA request failed after retries: {str(e)}" - - if hasattr(e, "reason") and e.reason: - if hasattr(e.reason, "response"): - # Extract status code and body from the final failed response - response = e.reason.response - error_message = f"SEA request failed after retries (status {response.status})" - try: - if response.data: - error_details = json.loads(response.data.decode("utf-8")) - if isinstance(error_details, dict) and "message" in error_details: - error_message = f"{error_message}: {error_details['message']}" - except (json.JSONDecodeError, UnicodeDecodeError): - pass - else: - error_message = f"SEA request failed after retries: {str(e.reason)}" - - logger.error(error_message) - raise RequestError(error_message, e) - - except HTTPError as e: - error_message = f"SEA HTTP error: {str(e)}" - logger.error(error_message) - raise RequestError(error_message, e) - - except Exception as e: - error_message = f"Unexpected error in SEA request: {str(e)}" - logger.error(error_message) - raise RequestError(error_message, e) + logger.debug(f"Response status: {response.status}") + + # Handle successful responses + if 200 <= response.status < 300: + if response.data: + try: + result = json.loads(response.data.decode("utf-8")) + logger.debug("Successfully parsed JSON response") + return result + except (json.JSONDecodeError, UnicodeDecodeError) as e: + logger.error(f"Failed to parse JSON response: {e}") + raise RequestError(f"Invalid JSON response: {e}", e) + return {} + + # Handle error responses + error_message = f"SEA HTTP request failed with status {response.status}" + + try: + if response.data: + error_details = json.loads(response.data.decode("utf-8")) + if isinstance(error_details, dict) and "message" in error_details: + error_message = f"{error_message}: {error_details['message']}" + logger.error(f"Request failed: {error_details}") + except (json.JSONDecodeError, UnicodeDecodeError): + # Log raw response if we can't parse JSON + content = response.data.decode("utf-8", errors="replace") if response.data else "" + logger.error(f"Request failed with non-JSON response: {content}") + + raise RequestError(error_message, None) def _get_command_type_from_path(self, path: str, method: str) -> CommandType: """ From cc48cafe4a834a4aaff13a1bb8c21a7d732e981f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 04:00:35 +0000 Subject: [PATCH 03/15] formatting (black) Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/utils/http_client.py | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index 054d57c00..096977eb3 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -56,7 +56,7 @@ def __init__( # Build base URL self.base_url = f"https://{server_hostname}:{self.port}" - + # Parse URL for proxy handling parsed_url = urllib.parse.urlparse(self.base_url) self.scheme = parsed_url.scheme @@ -84,7 +84,7 @@ def __init__( # Setup retry policy self.enable_v3_retries = kwargs.get("_enable_v3_retries", True) - + if self.enable_v3_retries: self.retry_policy = DatabricksRetryPolicy( delay_min=self._retry_delay_min, @@ -138,13 +138,17 @@ def _open(self): pool_class = HTTPConnectionPool else: # https pool_class = HTTPSConnectionPool - pool_kwargs.update({ - "cert_reqs": ssl.CERT_REQUIRED if self.ssl_options.tls_verify else ssl.CERT_NONE, - "ca_certs": self.ssl_options.tls_trusted_ca_file, - "cert_file": self.ssl_options.tls_client_cert_file, - "key_file": self.ssl_options.tls_client_cert_key_file, - "key_password": self.ssl_options.tls_client_cert_key_password, - }) + pool_kwargs.update( + { + "cert_reqs": ssl.CERT_REQUIRED + if self.ssl_options.tls_verify + else ssl.CERT_NONE, + "ca_certs": self.ssl_options.tls_trusted_ca_file, + "cert_file": self.ssl_options.tls_client_cert_file, + "key_file": self.ssl_options.tls_client_cert_key_file, + "key_password": self.ssl_options.tls_client_cert_key_password, + } + ) if self.proxy_uri: proxy_manager = ProxyManager( @@ -253,7 +257,7 @@ def _make_request( # Handle error responses error_message = f"SEA HTTP request failed with status {response.status}" - + try: if response.data: error_details = json.loads(response.data.decode("utf-8")) @@ -262,7 +266,9 @@ def _make_request( logger.error(f"Request failed: {error_details}") except (json.JSONDecodeError, UnicodeDecodeError): # Log raw response if we can't parse JSON - content = response.data.decode("utf-8", errors="replace") if response.data else "" + content = ( + response.data.decode("utf-8", errors="replace") if response.data else "" + ) logger.error(f"Request failed with non-JSON response: {content}") raise RequestError(error_message, None) @@ -270,13 +276,13 @@ def _make_request( def _get_command_type_from_path(self, path: str, method: str) -> CommandType: """ Determine the command type based on the API path and method. - + This helps the retry policy make appropriate decisions for different types of SEA operations. """ path = path.lower() method = method.upper() - + if "/statements" in path: if method == "POST" and path.endswith("/statements"): return CommandType.EXECUTE_STATEMENT @@ -289,5 +295,5 @@ def _get_command_type_from_path(self, path: str, method: str) -> CommandType: elif "/sessions" in path: if method == "DELETE": return CommandType.CLOSE_SESSION - + return CommandType.OTHER From 6a1274ffe5025beb73248b9168786aa422c06873 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 27 Jun 2025 04:19:37 +0000 Subject: [PATCH 04/15] fix type annotations Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/utils/http_client.py | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index 096977eb3..5a5755762 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -25,6 +25,13 @@ class SeaHttpClient: and connection pooling, similar to the Thrift HTTP client but simplified. """ + retry_policy: Union[DatabricksRetryPolicy, int] + _pool: Optional[Union[HTTPConnectionPool, HTTPSConnectionPool]] + proxy_uri: Optional[str] + realhost: Optional[str] + realport: Optional[int] + proxy_auth: Optional[Dict[str, str]] + def __init__( self, server_hostname: str, @@ -105,7 +112,7 @@ def __init__( except (KeyError, AttributeError): proxy = None else: - if urllib.request.proxy_bypass(self.host): + if self.host and urllib.request.proxy_bypass(self.host): proxy = None if proxy: @@ -114,10 +121,13 @@ def __init__( self.realport = self.port self.proxy_uri = proxy self.host = parsed_proxy.hostname - self.port = parsed_proxy.port + self.port = parsed_proxy.port or (443 if self.scheme == "https" else 80) self.proxy_auth = self._basic_proxy_auth_headers(parsed_proxy) else: - self.realhost = self.realport = self.proxy_auth = self.proxy_uri = None + self.realhost = None + self.realport = None + self.proxy_auth = None + self.proxy_uri = None # Initialize connection pool self._pool = None @@ -150,7 +160,7 @@ def _open(self): } ) - if self.proxy_uri: + if self.using_proxy(): proxy_manager = ProxyManager( self.proxy_uri, num_pools=1, @@ -170,6 +180,10 @@ def close(self): if self._pool: self._pool.clear() + def using_proxy(self) -> bool: + """Check if proxy is being used (for compatibility with Thrift client).""" + return self.realhost is not None + def set_retry_command_type(self, command_type: CommandType): """Set the command type for retry policy decision making.""" if isinstance(self.retry_policy, DatabricksRetryPolicy): @@ -232,6 +246,9 @@ def _make_request( # When v3 retries are enabled, urllib3 handles retries internally via DatabricksRetryPolicy # When disabled, we let exceptions bubble up (similar to Thrift backend approach) + if self._pool is None: + raise RequestError("Connection pool not initialized", None) + response = self._pool.request( method=method.upper(), url=url, From 4651cd608f875bd3f530b375cbe3ec358b657858 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 30 Jun 2025 02:20:53 +0000 Subject: [PATCH 05/15] pass test_retry_exponential_backoff Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/utils/http_client.py | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index 5a5755762..382997e3a 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -12,7 +12,7 @@ from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy from databricks.sql.types import SSLOptions -from databricks.sql.exc import RequestError +from databricks.sql.exc import RequestError, MaxRetryDurationError logger = logging.getLogger(__name__) @@ -249,14 +249,21 @@ def _make_request( if self._pool is None: raise RequestError("Connection pool not initialized", None) - response = self._pool.request( - method=method.upper(), - url=url, - body=body, - headers=headers, - preload_content=True, - retries=self.retry_policy, - ) + try: + response = self._pool.request( + method=method.upper(), + url=url, + body=body, + headers=headers, + preload_content=True, + retries=self.retry_policy, + ) + except MaxRetryDurationError as e: + # MaxRetryDurationError is raised directly by DatabricksRetryPolicy + # when duration limits are exceeded (like in test_retry_exponential_backoff) + error_message = f"Request failed due to retry duration limit: {e}" + # Construct RequestError with message, context, and specific error (like Thrift backend) + raise RequestError(error_message, None, e) logger.debug(f"Response status: {response.status}") From d67eb7bc418a53c332e739bec753527f9c7734b1 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 30 Jun 2025 02:25:43 +0000 Subject: [PATCH 06/15] prevent parsing empty response data (get test_retry_abort_non_recoverable_error to pass) Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/utils/http_client.py | 34 +++++++++++++------ 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index 382997e3a..b15527891 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -284,16 +284,30 @@ def _make_request( try: if response.data: - error_details = json.loads(response.data.decode("utf-8")) - if isinstance(error_details, dict) and "message" in error_details: - error_message = f"{error_message}: {error_details['message']}" - logger.error(f"Request failed: {error_details}") - except (json.JSONDecodeError, UnicodeDecodeError): - # Log raw response if we can't parse JSON - content = ( - response.data.decode("utf-8", errors="replace") if response.data else "" - ) - logger.error(f"Request failed with non-JSON response: {content}") + decoded_data = response.data.decode("utf-8") + # Ensure we have a string before attempting JSON parsing + if isinstance(decoded_data, str): + error_details = json.loads(decoded_data) + if isinstance(error_details, dict) and "message" in error_details: + error_message = f"{error_message}: {error_details['message']}" + logger.error(f"Request failed: {error_details}") + else: + # Handle case where decode returns non-string (e.g., MagicMock in tests) + logger.error( + f"Request failed with non-string response data: {type(decoded_data)}" + ) + except (json.JSONDecodeError, UnicodeDecodeError, TypeError): + # Log raw response if we can't parse JSON or if we get unexpected types + try: + content = ( + response.data.decode("utf-8", errors="replace") + if response.data + else "" + ) + logger.error(f"Request failed with non-JSON response: {content}") + except (AttributeError, TypeError): + # Handle case where response.data itself might be a mock + logger.error(f"Request failed with unparseable response data") raise RequestError(error_message, None) From 2caf38d8c0de1a82203a57a2843ffa709a984507 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 30 Jun 2025 03:29:40 +0000 Subject: [PATCH 07/15] more defensive parsing, allow more method types in urllib3 Signed-off-by: varun-edachali-dbx --- src/databricks/sql/auth/retry.py | 4 +- .../sql/backend/sea/utils/http_client.py | 76 +++++++++++++------ 2 files changed, 54 insertions(+), 26 deletions(-) diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index 432ac687d..ed79fbf15 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -130,8 +130,8 @@ def __init__( allowed_methods=["POST"], status_forcelist=[429, 503, *self.force_dangerous_codes], ) - - urllib3_kwargs.update(**_urllib_kwargs_we_care_about) + _urllib_kwargs_we_care_about.update(**urllib3_kwargs) + urllib3_kwargs = _urllib_kwargs_we_care_about super().__init__( **urllib3_kwargs, diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index b15527891..068d6a0d9 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -12,7 +12,12 @@ from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy from databricks.sql.types import SSLOptions -from databricks.sql.exc import RequestError, MaxRetryDurationError +from databricks.sql.exc import ( + RequestError, + MaxRetryDurationError, + SessionAlreadyClosedError, + CursorAlreadyClosedError, +) logger = logging.getLogger(__name__) @@ -264,6 +269,12 @@ def _make_request( error_message = f"Request failed due to retry duration limit: {e}" # Construct RequestError with message, context, and specific error (like Thrift backend) raise RequestError(error_message, None, e) + except (SessionAlreadyClosedError, CursorAlreadyClosedError) as e: + # These exceptions are raised by DatabricksRetryPolicy when detecting + # "already closed" scenarios (404 responses with retry history) + error_message = f"Request failed: {e}" + # Construct RequestError with proper 3-argument format (message, context, error) like Thrift backend + raise RequestError(error_message, None, e) logger.debug(f"Response status: {response.status}") @@ -282,34 +293,51 @@ def _make_request( # Handle error responses error_message = f"SEA HTTP request failed with status {response.status}" + # Try to extract additional error details from response, but don't fail if we can't + error_message = self._try_add_error_details_to_message(response, error_message) + + raise RequestError(error_message, None) + + def _try_add_error_details_to_message(self, response, error_message: str) -> str: + """ + Try to extract error details from response and add to error message. + This method is defensive and will not raise exceptions if parsing fails. + It handles mock objects and malformed responses gracefully. + """ try: - if response.data: + # Check if response.data exists and is accessible + if not hasattr(response, "data") or response.data is None: + return error_message + + # Try to decode the response data + try: decoded_data = response.data.decode("utf-8") - # Ensure we have a string before attempting JSON parsing - if isinstance(decoded_data, str): - error_details = json.loads(decoded_data) - if isinstance(error_details, dict) and "message" in error_details: - error_message = f"{error_message}: {error_details['message']}" - logger.error(f"Request failed: {error_details}") - else: - # Handle case where decode returns non-string (e.g., MagicMock in tests) - logger.error( - f"Request failed with non-string response data: {type(decoded_data)}" - ) - except (json.JSONDecodeError, UnicodeDecodeError, TypeError): - # Log raw response if we can't parse JSON or if we get unexpected types + except (AttributeError, UnicodeDecodeError, TypeError): + # response.data might be a mock object or not bytes + return error_message + + # Ensure we have a string before attempting JSON parsing + if not isinstance(decoded_data, str): + return error_message + + # Try to parse as JSON try: - content = ( - response.data.decode("utf-8", errors="replace") - if response.data - else "" + error_details = json.loads(decoded_data) + if isinstance(error_details, dict) and "message" in error_details: + enhanced_message = f"{error_message}: {error_details['message']}" + logger.error(f"Request failed: {error_details}") + return enhanced_message + except json.JSONDecodeError: + # Not valid JSON, log what we can + logger.debug( + f"Request failed with non-JSON response: {decoded_data[:200]}" ) - logger.error(f"Request failed with non-JSON response: {content}") - except (AttributeError, TypeError): - # Handle case where response.data itself might be a mock - logger.error(f"Request failed with unparseable response data") - raise RequestError(error_message, None) + except Exception: + # Catch-all for any unexpected issues (e.g., mock objects with unexpected behavior) + logger.debug("Could not parse error response data") + + return error_message def _get_command_type_from_path(self, path: str, method: str) -> CommandType: """ From 3e55ddd349779a281239f2eaf9992d693b7e9d36 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 30 Jun 2025 03:32:44 +0000 Subject: [PATCH 08/15] allow Any values in session_conf, cast to String as done in Thrift backend Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 78b05c065..3294cf406 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -3,7 +3,7 @@ import logging import time import re -from typing import Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set +from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set from databricks.sql.backend.sea.models.base import ExternalLink, ResultManifest from databricks.sql.backend.sea.utils.constants import ( @@ -49,7 +49,7 @@ def _filter_session_configuration( - session_configuration: Optional[Dict[str, str]] + session_configuration: Optional[Dict[str, Any]] ) -> Optional[Dict[str, str]]: if not session_configuration: return None @@ -59,7 +59,7 @@ def _filter_session_configuration( for key, value in session_configuration.items(): if key.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP: - filtered_session_configuration[key.lower()] = value + filtered_session_configuration[key.lower()] = str(value) else: ignored_configs.add(key) @@ -183,7 +183,7 @@ def max_download_threads(self) -> int: def open_session( self, - session_configuration: Optional[Dict[str, str]], + session_configuration: Optional[Dict[str, Any]], catalog: Optional[str], schema: Optional[str], ) -> SessionId: From 4afff395b6473a2385736b1a5e85bade74e7b982 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 30 Jun 2025 04:38:27 +0000 Subject: [PATCH 09/15] account for max_redirects in SEA backend Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/utils/http_client.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index 068d6a0d9..225ed3456 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -8,6 +8,7 @@ from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager from urllib3.util import make_headers +from urllib3.exceptions import MaxRetryError from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy @@ -98,6 +99,15 @@ def __init__( self.enable_v3_retries = kwargs.get("_enable_v3_retries", True) if self.enable_v3_retries: + urllib3_kwargs = {"allowed_methods": ["GET", "POST", "DELETE"]} + _max_redirects = kwargs.get("_retry_max_redirects") + if _max_redirects: + if _max_redirects > self._retry_stop_after_attempts_count: + logger.warning( + "_retry_max_redirects > _retry_stop_after_attempts_count so it will have no affect!" + ) + urllib3_kwargs["redirect"] = _max_redirects + self.retry_policy = DatabricksRetryPolicy( delay_min=self._retry_delay_min, delay_max=self._retry_delay_max, @@ -105,7 +115,7 @@ def __init__( stop_after_attempts_duration=self._retry_stop_after_attempts_duration, delay_default=self._retry_delay_default, force_dangerous_codes=self.force_dangerous_codes, - urllib3_kwargs={"allowed_methods": ["GET", "POST", "DELETE"]}, + urllib3_kwargs=urllib3_kwargs, ) else: # Legacy behavior - no automatic retries @@ -275,6 +285,18 @@ def _make_request( error_message = f"Request failed: {e}" # Construct RequestError with proper 3-argument format (message, context, error) like Thrift backend raise RequestError(error_message, None, e) + except MaxRetryError as e: + # urllib3 MaxRetryError should bubble up for redirect tests to catch + # Don't convert to RequestError, let the test framework handle it + logger.error(f"SEA HTTP request failed with MaxRetryError: {e}") + raise + except Exception as e: + # Broad exception handler like Thrift backend to catch any unexpected errors + # (including test mocking issues like StopIteration) + logger.error(f"SEA HTTP request failed with exception: {e}") + error_message = f"Error during request to server. {e}" + # Construct RequestError with proper 3-argument format (message, context, error) like Thrift backend + raise RequestError(error_message, None, e) logger.debug(f"Response status: {response.status}") From 01d49cd63aa229a6f855de8eb16cfa600aac17e1 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 1 Jul 2025 01:31:38 +0000 Subject: [PATCH 10/15] return empty JsonQueue if no data Signed-off-by: varun-edachali-dbx --- src/databricks/sql/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 5e8a807e6..2cdfc8fe3 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -173,7 +173,7 @@ def build_queue( lz4_compressed=lz4_compressed, description=description, ) - raise ProgrammingError("No result data or external links found") + return JsonQueue([]) class JsonQueue(ResultSetQueue): From 3d8aa7f7ac556064dc5127548979ac45e88af3e7 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 2 Jul 2025 08:15:03 +0000 Subject: [PATCH 11/15] do not preload content? Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/utils/http_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index 225ed3456..4af6930b8 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -270,7 +270,7 @@ def _make_request( url=url, body=body, headers=headers, - preload_content=True, + preload_content=False, retries=self.retry_policy, ) except MaxRetryDurationError as e: From 51aa9bebfadeb4c36c1ffdd15e896da3ea0d67f8 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 7 Jul 2025 14:45:41 +0530 Subject: [PATCH 12/15] add sea tag on large queries Signed-off-by: varun-edachali-dbx --- tests/e2e/common/large_queries_mixin.py | 41 +++++++++++++++++++++---- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py index 1181ef154..87d679946 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/tests/e2e/common/large_queries_mixin.py @@ -2,6 +2,8 @@ import math import time +import pytest + log = logging.getLogger(__name__) @@ -42,7 +44,16 @@ def fetch_rows(self, cursor, row_count, fetchmany_size): + "assuming 10K fetch size." ) - def test_query_with_large_wide_result_set(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) + def test_query_with_large_wide_result_set(self, extra_params): resultSize = 300 * 1000 * 1000 # 300 MB width = 8192 # B rows = resultSize // width @@ -52,7 +63,7 @@ def test_query_with_large_wide_result_set(self): fetchmany_size = 10 * 1024 * 1024 // width # This is used by PyHive tests to determine the buffer size self.arraysize = 1000 - with self.cursor() as cursor: + with self.cursor(extra_params) as cursor: for lz4_compression in [False, True]: cursor.connection.lz4_compression = lz4_compression uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)]) @@ -68,7 +79,16 @@ def test_query_with_large_wide_result_set(self): assert row[0] == row_id # Verify no rows are dropped in the middle. assert len(row[1]) == 36 - def test_query_with_large_narrow_result_set(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) + def test_query_with_large_narrow_result_set(self, extra_params): resultSize = 300 * 1000 * 1000 # 300 MB width = 8 # sizeof(long) rows = resultSize / width @@ -77,12 +97,21 @@ def test_query_with_large_narrow_result_set(self): fetchmany_size = 10 * 1024 * 1024 // width # This is used by PyHive tests to determine the buffer size self.arraysize = 10000000 - with self.cursor() as cursor: + with self.cursor(extra_params) as cursor: cursor.execute("SELECT * FROM RANGE({rows})".format(rows=rows)) for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)): assert row[0] == row_id - def test_long_running_query(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + }, + ], + ) + def test_long_running_query(self, extra_params): """Incrementally increase query size until it takes at least 3 minutes, and asserts that the query completes successfully. """ @@ -92,7 +121,7 @@ def test_long_running_query(self): duration = -1 scale0 = 10000 scale_factor = 1 - with self.cursor() as cursor: + with self.cursor(extra_params) as cursor: while duration < min_duration: assert scale_factor < 1024, "Detected infinite loop" start = time.time() From 461e762aa7fb7d32babdba6677d34e6da1acc64a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Mon, 7 Jul 2025 18:18:15 +0530 Subject: [PATCH 13/15] simplify error handling Signed-off-by: varun-edachali-dbx --- .../sql/backend/sea/utils/http_client.py | 78 ++----------------- 1 file changed, 7 insertions(+), 71 deletions(-) diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index 4af6930b8..00edf4b36 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -220,7 +220,6 @@ def _make_request( method: str, path: str, data: Optional[Dict[str, Any]] = None, - params: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """ Make an HTTP request to the SEA endpoint. @@ -229,7 +228,6 @@ def _make_request( method: HTTP method (GET, POST, DELETE) path: API endpoint path data: Request payload data - params: Query parameters Returns: Dict[str, Any]: Response data parsed from JSON @@ -238,12 +236,6 @@ def _make_request( RequestError: If the request fails after retries """ - # Build full URL - if path.startswith("/"): - url = path - else: - url = f"/{path.lstrip('/')}" - # Prepare headers headers = {**self.headers, **self._get_auth_headers()} @@ -257,7 +249,7 @@ def _make_request( self.set_retry_command_type(command_type) self.start_retry_timer() - logger.debug(f"Making {method} request to {url}") + logger.debug(f"Making {method} request to {path}") # When v3 retries are enabled, urllib3 handles retries internally via DatabricksRetryPolicy # When disabled, we let exceptions bubble up (similar to Thrift backend approach) @@ -267,7 +259,7 @@ def _make_request( try: response = self._pool.request( method=method.upper(), - url=url, + url=path, body=body, headers=headers, preload_content=False, @@ -277,90 +269,34 @@ def _make_request( # MaxRetryDurationError is raised directly by DatabricksRetryPolicy # when duration limits are exceeded (like in test_retry_exponential_backoff) error_message = f"Request failed due to retry duration limit: {e}" - # Construct RequestError with message, context, and specific error (like Thrift backend) + # Construct RequestError with message, context, and specific error raise RequestError(error_message, None, e) except (SessionAlreadyClosedError, CursorAlreadyClosedError) as e: # These exceptions are raised by DatabricksRetryPolicy when detecting # "already closed" scenarios (404 responses with retry history) error_message = f"Request failed: {e}" - # Construct RequestError with proper 3-argument format (message, context, error) like Thrift backend + # Construct RequestError with proper 3-argument format (message, context, error) raise RequestError(error_message, None, e) except MaxRetryError as e: # urllib3 MaxRetryError should bubble up for redirect tests to catch - # Don't convert to RequestError, let the test framework handle it logger.error(f"SEA HTTP request failed with MaxRetryError: {e}") raise except Exception as e: - # Broad exception handler like Thrift backend to catch any unexpected errors - # (including test mocking issues like StopIteration) logger.error(f"SEA HTTP request failed with exception: {e}") error_message = f"Error during request to server. {e}" - # Construct RequestError with proper 3-argument format (message, context, error) like Thrift backend + # Construct RequestError with proper 3-argument format (message, context, error) raise RequestError(error_message, None, e) logger.debug(f"Response status: {response.status}") # Handle successful responses if 200 <= response.status < 300: - if response.data: - try: - result = json.loads(response.data.decode("utf-8")) - logger.debug("Successfully parsed JSON response") - return result - except (json.JSONDecodeError, UnicodeDecodeError) as e: - logger.error(f"Failed to parse JSON response: {e}") - raise RequestError(f"Invalid JSON response: {e}", e) - return {} - - # Handle error responses - error_message = f"SEA HTTP request failed with status {response.status}" + return response.json() - # Try to extract additional error details from response, but don't fail if we can't - error_message = self._try_add_error_details_to_message(response, error_message) + error_message = f"SEA HTTP request failed with status {response.status}" raise RequestError(error_message, None) - def _try_add_error_details_to_message(self, response, error_message: str) -> str: - """ - Try to extract error details from response and add to error message. - This method is defensive and will not raise exceptions if parsing fails. - It handles mock objects and malformed responses gracefully. - """ - try: - # Check if response.data exists and is accessible - if not hasattr(response, "data") or response.data is None: - return error_message - - # Try to decode the response data - try: - decoded_data = response.data.decode("utf-8") - except (AttributeError, UnicodeDecodeError, TypeError): - # response.data might be a mock object or not bytes - return error_message - - # Ensure we have a string before attempting JSON parsing - if not isinstance(decoded_data, str): - return error_message - - # Try to parse as JSON - try: - error_details = json.loads(decoded_data) - if isinstance(error_details, dict) and "message" in error_details: - enhanced_message = f"{error_message}: {error_details['message']}" - logger.error(f"Request failed: {error_details}") - return enhanced_message - except json.JSONDecodeError: - # Not valid JSON, log what we can - logger.debug( - f"Request failed with non-JSON response: {decoded_data[:200]}" - ) - - except Exception: - # Catch-all for any unexpected issues (e.g., mock objects with unexpected behavior) - logger.debug("Could not parse error response data") - - return error_message - def _get_command_type_from_path(self, path: str, method: str) -> CommandType: """ Determine the command type based on the API path and method. From fd1e6cf0d4e49a189dbc1b84bdad70227f84523a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 8 Jul 2025 15:45:55 +0530 Subject: [PATCH 14/15] stop mypy complaints Signed-off-by: varun-edachali-dbx --- src/databricks/sql/auth/retry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index ed79fbf15..835f36095 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -4,7 +4,7 @@ import typing from importlib.metadata import version from enum import Enum -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import urllib3 @@ -123,7 +123,7 @@ def __init__( _total: int = urllib3_kwargs.pop("total") _attempts_remaining = _total - _urllib_kwargs_we_care_about = dict( + _urllib_kwargs_we_care_about: dict[str, Any] = dict( total=_attempts_remaining, respect_retry_after_header=True, backoff_factor=self.delay_min, From 15378dead4bdaf5c3f7bf9de915f777399f7a64f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 8 Jul 2025 15:56:04 +0530 Subject: [PATCH 15/15] run retry tests Signed-off-by: varun-edachali-dbx --- tests/e2e/common/retry_test_mixins.py | 273 ++++++++++++++++++++++---- 1 file changed, 232 insertions(+), 41 deletions(-) diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index dd509c062..17a89d18f 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -17,17 +17,32 @@ class Client429ResponseMixin: - def test_client_should_retry_automatically_when_getting_429(self): - with self.cursor() as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_client_should_retry_automatically_when_getting_429(self, extra_params): + with self.cursor(extra_params) as cursor: for _ in range(10): cursor.execute("SELECT 1") rows = cursor.fetchall() self.assertEqual(len(rows), 1) self.assertEqual(rows[0][0], 1) - def test_client_should_not_retry_429_if_RateLimitRetry_is_0(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_client_should_not_retry_429_if_RateLimitRetry_is_0(self, extra_params): with pytest.raises(self.error_type) as cm: - with self.cursor(self.conf_to_disable_rate_limit_retries) as cursor: + extra_params = {**extra_params, **self.conf_to_disable_rate_limit_retries} + with self.cursor(extra_params) as cursor: for _ in range(10): cursor.execute("SELECT 1") rows = cursor.fetchall() @@ -46,14 +61,32 @@ def test_client_should_not_retry_429_if_RateLimitRetry_is_0(self): class Client503ResponseMixin: - def test_wait_cluster_startup(self): - with self.cursor() as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_wait_cluster_startup(self, extra_params): + with self.cursor(extra_params) as cursor: cursor.execute("SELECT 1") cursor.fetchall() - def _test_retry_disabled_with_message(self, error_msg_substring, exception_type): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def _test_retry_disabled_with_message( + self, error_msg_substring, exception_type, extra_params + ): with pytest.raises(exception_type) as cm: - with self.connection(self.conf_to_disable_temporarily_unavailable_retries): + with self.connection( + self.conf_to_disable_temporarily_unavailable_retries, extra_params + ): pass assert error_msg_substring in str(cm.exception) @@ -127,7 +160,14 @@ class PySQLRetryTestsMixin: "_retry_delay_default": 0.5, } - def test_retry_urllib3_settings_are_honored(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_urllib3_settings_are_honored(self, extra_params): """Databricks overrides some of urllib3's configuration. This tests confirms that what configuration we DON'T override is preserved in urllib3's internals """ @@ -147,19 +187,34 @@ def test_retry_urllib3_settings_are_honored(self): assert rp.read == 11 assert rp.redirect == 12 - def test_oserror_retries(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_oserror_retries(self, extra_params): """If a network error occurs during make_request, the request is retried according to policy""" with patch( "urllib3.connectionpool.HTTPSConnectionPool._validate_conn", ) as mock_validate_conn: mock_validate_conn.side_effect = OSError("Some arbitrary network error") with pytest.raises(MaxRetryError) as cm: - with self.connection(extra_params=self._retry_policy) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: pass assert mock_validate_conn.call_count == 6 - def test_retry_max_count_not_exceeded(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_max_count_not_exceeded(self, extra_params): """GIVEN the max_attempts_count is 5 WHEN the server sends nothing but 429 responses THEN the connector issues six request (original plus five retries) @@ -167,11 +222,19 @@ def test_retry_max_count_not_exceeded(self): """ with mocked_server_response(status=404) as mock_obj: with pytest.raises(MaxRetryError) as cm: - with self.connection(extra_params=self._retry_policy) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: pass assert mock_obj.return_value.getresponse.call_count == 6 - def test_retry_exponential_backoff(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_exponential_backoff(self, extra_params): """GIVEN the retry policy is configured for reasonable exponential backoff WHEN the server sends nothing but 429 responses with retry-afters THEN the connector will use those retry-afters values as floor @@ -184,7 +247,8 @@ def test_retry_exponential_backoff(self): status=429, headers={"Retry-After": "8"} ) as mock_obj: with pytest.raises(RequestError) as cm: - with self.connection(extra_params=retry_policy) as conn: + extra_params = {**extra_params, **retry_policy} + with self.connection(extra_params=extra_params) as conn: pass duration = time.time() - time_start @@ -200,18 +264,33 @@ def test_retry_exponential_backoff(self): # Should be less than 26, but this is a safe margin for CI/CD slowness assert duration < 30 - def test_retry_max_duration_not_exceeded(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_max_duration_not_exceeded(self, extra_params): """GIVEN the max attempt duration of 10 seconds WHEN the server sends a Retry-After header of 60 seconds THEN the connector raises a MaxRetryDurationError """ with mocked_server_response(status=429, headers={"Retry-After": "60"}): with pytest.raises(RequestError) as cm: - with self.connection(extra_params=self._retry_policy) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: pass assert isinstance(cm.value.args[1], MaxRetryDurationError) - def test_retry_abort_non_recoverable_error(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_abort_non_recoverable_error(self, extra_params): """GIVEN the server returns a code 501 WHEN the connector receives this response THEN nothing is retried and an exception is raised @@ -220,16 +299,25 @@ def test_retry_abort_non_recoverable_error(self): # Code 501 is a Not Implemented error with mocked_server_response(status=501): with pytest.raises(RequestError) as cm: - with self.connection(extra_params=self._retry_policy) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: pass assert isinstance(cm.value.args[1], NonRecoverableNetworkError) - def test_retry_abort_unsafe_execute_statement_retry_condition(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_abort_unsafe_execute_statement_retry_condition(self, extra_params): """GIVEN the server sends a code other than 429 or 503 WHEN the connector sent an ExecuteStatement command THEN nothing is retried because it's idempotent """ - with self.connection(extra_params=self._retry_policy) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: with conn.cursor() as cursor: # Code 502 is a Bad Gateway, which we commonly see in production under heavy load with mocked_server_response(status=502): @@ -237,7 +325,14 @@ def test_retry_abort_unsafe_execute_statement_retry_condition(self): cursor.execute("Not a real query") assert isinstance(cm.value.args[1], UnsafeToRetryError) - def test_retry_dangerous_codes(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_dangerous_codes(self, extra_params): """GIVEN the server sends a dangerous code and the user forced this to be retryable WHEN the connector sent an ExecuteStatement command THEN the command is retried @@ -253,7 +348,8 @@ def test_retry_dangerous_codes(self): } # Prove that these codes are not retried by default - with self.connection(extra_params={**self._retry_policy}) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: with conn.cursor() as cursor: for dangerous_code in DANGEROUS_CODES: with mocked_server_response(status=dangerous_code): @@ -263,7 +359,7 @@ def test_retry_dangerous_codes(self): # Prove that these codes are retried if forced by the user with self.connection( - extra_params={**self._retry_policy, **additional_settings} + extra_params={**extra_params, **self._retry_policy, **additional_settings} ) as conn: with conn.cursor() as cursor: for dangerous_code in DANGEROUS_CODES: @@ -271,7 +367,14 @@ def test_retry_dangerous_codes(self): with pytest.raises(MaxRetryError) as cm: cursor.execute("Not a real query") - def test_retry_safe_execute_statement_retry_condition(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_safe_execute_statement_retry_condition(self, extra_params): """GIVEN the server sends either code 429 or 503 WHEN the connector sent an ExecuteStatement command THEN the request is retried because these are idempotent @@ -283,7 +386,11 @@ def test_retry_safe_execute_statement_retry_condition(self): ] with self.connection( - extra_params={**self._retry_policy, "_retry_stop_after_attempts_count": 1} + extra_params={ + **extra_params, + **self._retry_policy, + "_retry_stop_after_attempts_count": 1, + } ) as conn: with conn.cursor() as cursor: # Code 502 is a Bad Gateway, which we commonly see in production under heavy load @@ -292,7 +399,14 @@ def test_retry_safe_execute_statement_retry_condition(self): cursor.execute("This query never reaches the server") assert mock_obj.return_value.getresponse.call_count == 2 - def test_retry_abort_close_session_on_404(self, caplog): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_abort_close_session_on_404(self, extra_params, caplog): """GIVEN the connector sends a CloseSession command WHEN server sends a 404 (which is normally retried) THEN nothing is retried because 404 means the session already closed @@ -305,12 +419,20 @@ def test_retry_abort_close_session_on_404(self, caplog): {"status": 404, "headers": {}, "redirect_location": None}, ] - with self.connection(extra_params={**self._retry_policy}) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: with mock_sequential_server_responses(responses): conn.close() assert "Session was closed by a prior request" in caplog.text - def test_retry_abort_close_operation_on_404(self, caplog): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_abort_close_operation_on_404(self, extra_params, caplog): """GIVEN the connector sends a CancelOperation command WHEN server sends a 404 (which is normally retried) THEN nothing is retried because 404 means the operation was already canceled @@ -323,7 +445,8 @@ def test_retry_abort_close_operation_on_404(self, caplog): {"status": 404, "headers": {}, "redirect_location": None}, ] - with self.connection(extra_params={**self._retry_policy}) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: with conn.cursor() as curs: with patch( "databricks.sql.backend.types.ExecuteResponse.has_been_closed_server_side", @@ -338,7 +461,16 @@ def test_retry_abort_close_operation_on_404(self, caplog): "Operation was canceled by a prior request" in caplog.text ) - def test_retry_max_redirects_raises_too_many_redirects_exception(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_max_redirects_raises_too_many_redirects_exception( + self, extra_params + ): """GIVEN the connector is configured with a custom max_redirects WHEN the DatabricksRetryPolicy is created THEN the connector raises a MaxRedirectsError if that number is exceeded @@ -353,6 +485,7 @@ def test_retry_max_redirects_raises_too_many_redirects_exception(self): with pytest.raises(MaxRetryError) as cm: with self.connection( extra_params={ + **extra_params, **self._retry_policy, "_retry_max_redirects": max_redirects, } @@ -362,7 +495,14 @@ def test_retry_max_redirects_raises_too_many_redirects_exception(self): # Total call count should be 2 (original + 1 retry) assert mock_obj.return_value.getresponse.call_count == expected_call_count - def test_retry_max_redirects_unset_doesnt_redirect_forever(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_max_redirects_unset_doesnt_redirect_forever(self, extra_params): """GIVEN the connector is configured without a custom max_redirects WHEN the DatabricksRetryPolicy is used THEN the connector raises a MaxRedirectsError if that number is exceeded @@ -377,6 +517,7 @@ def test_retry_max_redirects_unset_doesnt_redirect_forever(self): with pytest.raises(MaxRetryError) as cm: with self.connection( extra_params={ + **extra_params, **self._retry_policy, } ): @@ -385,7 +526,16 @@ def test_retry_max_redirects_unset_doesnt_redirect_forever(self): # Total call count should be 6 (original + _retry_stop_after_attempts_count) assert mock_obj.return_value.getresponse.call_count == 6 - def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count( + self, extra_params + ): # If I add another 503 or 302 here the test will fail with a MaxRetryError responses = [ {"status": 302, "headers": {}, "redirect_location": "/foo.bar"}, @@ -400,7 +550,11 @@ def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count(self): with pytest.raises(RequestError) as cm: with mock_sequential_server_responses(responses): with self.connection( - extra_params={**self._retry_policy, **additional_settings} + extra_params={ + **extra_params, + **self._retry_policy, + **additional_settings, + } ): pass @@ -408,9 +562,19 @@ def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count(self): assert "too many redirects" not in str(cm.value.message) assert "Error during request to server" in str(cm.value.message) - def test_retry_max_redirects_exceeds_max_attempts_count_warns_user(self, caplog): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_max_redirects_exceeds_max_attempts_count_warns_user( + self, extra_params, caplog + ): with self.connection( extra_params={ + **extra_params, **self._retry_policy, **{ "_retry_max_redirects": 100, @@ -420,15 +584,33 @@ def test_retry_max_redirects_exceeds_max_attempts_count_warns_user(self, caplog) ): assert "it will have no affect!" in caplog.text - def test_retry_legacy_behavior_warns_user(self, caplog): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_legacy_behavior_warns_user(self, extra_params, caplog): with self.connection( - extra_params={**self._retry_policy, "_enable_v3_retries": False} + extra_params={ + **extra_params, + **self._retry_policy, + "_enable_v3_retries": False, + } ): assert ( "Legacy retry behavior is enabled for this connection." in caplog.text ) - def test_403_not_retried(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_403_not_retried(self, extra_params): """GIVEN the server returns a code 403 WHEN the connector receives this response THEN nothing is retried and an exception is raised @@ -437,11 +619,19 @@ def test_403_not_retried(self): # Code 403 is a Forbidden error with mocked_server_response(status=403): with pytest.raises(RequestError) as cm: - with self.connection(extra_params=self._retry_policy) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: pass assert isinstance(cm.value.args[1], NonRecoverableNetworkError) - def test_401_not_retried(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_401_not_retried(self, extra_params): """GIVEN the server returns a code 401 WHEN the connector receives this response THEN nothing is retried and an exception is raised @@ -450,6 +640,7 @@ def test_401_not_retried(self): # Code 401 is an Unauthorized error with mocked_server_response(status=401): with pytest.raises(RequestError) as cm: - with self.connection(extra_params=self._retry_policy): + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params): pass assert isinstance(cm.value.args[1], NonRecoverableNetworkError)