diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index 432ac687d..835f36095 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -4,7 +4,7 @@ import typing from importlib.metadata import version from enum import Enum -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import urllib3 @@ -123,15 +123,15 @@ def __init__( _total: int = urllib3_kwargs.pop("total") _attempts_remaining = _total - _urllib_kwargs_we_care_about = dict( + _urllib_kwargs_we_care_about: dict[str, Any] = dict( total=_attempts_remaining, respect_retry_after_header=True, backoff_factor=self.delay_min, allowed_methods=["POST"], status_forcelist=[429, 503, *self.force_dangerous_codes], ) - - urllib3_kwargs.update(**_urllib_kwargs_we_care_about) + _urllib_kwargs_we_care_about.update(**urllib3_kwargs) + urllib3_kwargs = _urllib_kwargs_we_care_about super().__init__( **urllib3_kwargs, diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 25f706a79..ed540eef1 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -49,7 +49,7 @@ def _filter_session_configuration( - session_configuration: Optional[Dict[str, str]] + session_configuration: Optional[Dict[str, Any]] ) -> Optional[Dict[str, str]]: if not session_configuration: return None @@ -59,7 +59,7 @@ def _filter_session_configuration( for key, value in session_configuration.items(): if key.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP: - filtered_session_configuration[key.lower()] = value + filtered_session_configuration[key.lower()] = str(value) else: ignored_configs.add(key) @@ -183,7 +183,7 @@ def max_download_threads(self) -> int: def open_session( self, - session_configuration: Optional[Dict[str, str]], + session_configuration: Optional[Dict[str, Any]], catalog: Optional[str], schema: Optional[str], ) -> SessionId: diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index fe292919c..00edf4b36 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -1,11 +1,24 @@ import json import logging -import requests -from typing import Callable, Dict, Any, Optional, List, Tuple +import ssl +import urllib.parse +import urllib.request +from typing import Dict, Any, Optional, List, Tuple, Union from urllib.parse import urljoin +from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager +from urllib3.util import make_headers +from urllib3.exceptions import MaxRetryError + from databricks.sql.auth.authenticators import AuthProvider +from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy from databricks.sql.types import SSLOptions +from databricks.sql.exc import ( + RequestError, + MaxRetryDurationError, + SessionAlreadyClosedError, + CursorAlreadyClosedError, +) logger = logging.getLogger(__name__) @@ -14,10 +27,17 @@ class SeaHttpClient: """ HTTP client for Statement Execution API (SEA). - This client handles the HTTP communication with the SEA endpoints, - including authentication, request formatting, and response parsing. + This client uses urllib3 for robust HTTP communication with retry policies + and connection pooling, similar to the Thrift HTTP client but simplified. """ + retry_policy: Union[DatabricksRetryPolicy, int] + _pool: Optional[Union[HTTPConnectionPool, HTTPSConnectionPool]] + proxy_uri: Optional[str] + realhost: Optional[str] + realport: Optional[int] + proxy_auth: Optional[Dict[str, str]] + def __init__( self, server_hostname: str, @@ -38,48 +58,156 @@ def __init__( http_headers: List of HTTP headers to include in requests auth_provider: Authentication provider ssl_options: SSL configuration options - **kwargs: Additional keyword arguments + **kwargs: Additional keyword arguments including retry policy settings """ self.server_hostname = server_hostname - self.port = port + self.port = port or 443 self.http_path = http_path self.auth_provider = auth_provider self.ssl_options = ssl_options - self.base_url = f"https://{server_hostname}:{port}" + # Build base URL + self.base_url = f"https://{server_hostname}:{self.port}" + # Parse URL for proxy handling + parsed_url = urllib.parse.urlparse(self.base_url) + self.scheme = parsed_url.scheme + self.host = parsed_url.hostname + self.port = parsed_url.port or (443 if self.scheme == "https" else 80) + + # Setup headers self.headers: Dict[str, str] = dict(http_headers) self.headers.update({"Content-Type": "application/json"}) - self.max_retries = kwargs.get("_retry_stop_after_attempts_count", 30) + # Extract retry policy settings + self._retry_delay_min = kwargs.get("_retry_delay_min", 1.0) + self._retry_delay_max = kwargs.get("_retry_delay_max", 60.0) + self._retry_stop_after_attempts_count = kwargs.get( + "_retry_stop_after_attempts_count", 30 + ) + self._retry_stop_after_attempts_duration = kwargs.get( + "_retry_stop_after_attempts_duration", 900.0 + ) + self._retry_delay_default = kwargs.get("_retry_delay_default", 5.0) + self.force_dangerous_codes = kwargs.get("_retry_dangerous_codes", []) + + # Connection pooling settings + self.max_connections = kwargs.get("max_connections", 10) + + # Setup retry policy + self.enable_v3_retries = kwargs.get("_enable_v3_retries", True) + + if self.enable_v3_retries: + urllib3_kwargs = {"allowed_methods": ["GET", "POST", "DELETE"]} + _max_redirects = kwargs.get("_retry_max_redirects") + if _max_redirects: + if _max_redirects > self._retry_stop_after_attempts_count: + logger.warning( + "_retry_max_redirects > _retry_stop_after_attempts_count so it will have no affect!" + ) + urllib3_kwargs["redirect"] = _max_redirects + + self.retry_policy = DatabricksRetryPolicy( + delay_min=self._retry_delay_min, + delay_max=self._retry_delay_max, + stop_after_attempts_count=self._retry_stop_after_attempts_count, + stop_after_attempts_duration=self._retry_stop_after_attempts_duration, + delay_default=self._retry_delay_default, + force_dangerous_codes=self.force_dangerous_codes, + urllib3_kwargs=urllib3_kwargs, + ) + else: + # Legacy behavior - no automatic retries + self.retry_policy = 0 - # Create a session for connection pooling - self.session = requests.Session() + # Handle proxy settings + try: + proxy = urllib.request.getproxies().get(self.scheme) + except (KeyError, AttributeError): + proxy = None + else: + if self.host and urllib.request.proxy_bypass(self.host): + proxy = None + + if proxy: + parsed_proxy = urllib.parse.urlparse(proxy) + self.realhost = self.host + self.realport = self.port + self.proxy_uri = proxy + self.host = parsed_proxy.hostname + self.port = parsed_proxy.port or (443 if self.scheme == "https" else 80) + self.proxy_auth = self._basic_proxy_auth_headers(parsed_proxy) + else: + self.realhost = None + self.realport = None + self.proxy_auth = None + self.proxy_uri = None + + # Initialize connection pool + self._pool = None + self._open() + + def _basic_proxy_auth_headers(self, proxy_parsed) -> Optional[Dict[str, str]]: + """Create basic auth headers for proxy if credentials are provided.""" + if proxy_parsed is None or not proxy_parsed.username: + return None + ap = f"{urllib.parse.unquote(proxy_parsed.username)}:{urllib.parse.unquote(proxy_parsed.password)}" + return make_headers(proxy_basic_auth=ap) + + def _open(self): + """Initialize the connection pool.""" + pool_kwargs = {"maxsize": self.max_connections} + + if self.scheme == "http": + pool_class = HTTPConnectionPool + else: # https + pool_class = HTTPSConnectionPool + pool_kwargs.update( + { + "cert_reqs": ssl.CERT_REQUIRED + if self.ssl_options.tls_verify + else ssl.CERT_NONE, + "ca_certs": self.ssl_options.tls_trusted_ca_file, + "cert_file": self.ssl_options.tls_client_cert_file, + "key_file": self.ssl_options.tls_client_cert_key_file, + "key_password": self.ssl_options.tls_client_cert_key_password, + } + ) - # Configure SSL verification - if ssl_options.tls_verify: - self.session.verify = ssl_options.tls_trusted_ca_file or True + if self.using_proxy(): + proxy_manager = ProxyManager( + self.proxy_uri, + num_pools=1, + proxy_headers=self.proxy_auth, + ) + self._pool = proxy_manager.connection_from_host( + host=self.realhost, + port=self.realport, + scheme=self.scheme, + pool_kwargs=pool_kwargs, + ) else: - self.session.verify = False - - # Configure client certificates if provided - if ssl_options.tls_client_cert_file: - client_cert = ssl_options.tls_client_cert_file - client_key = ssl_options.tls_client_cert_key_file - client_key_password = ssl_options.tls_client_cert_key_password - - if client_key: - self.session.cert = (client_cert, client_key) - else: - self.session.cert = client_cert - - if client_key_password: - # Note: requests doesn't directly support key passwords - # This would require more complex handling with libraries like pyOpenSSL - logger.warning( - "Client key password provided but not supported by requests library" - ) + self._pool = pool_class(self.host, self.port, **pool_kwargs) + + def close(self): + """Close the connection pool.""" + if self._pool: + self._pool.clear() + + def using_proxy(self) -> bool: + """Check if proxy is being used (for compatibility with Thrift client).""" + return self.realhost is not None + + def set_retry_command_type(self, command_type: CommandType): + """Set the command type for retry policy decision making.""" + if isinstance(self.retry_policy, DatabricksRetryPolicy): + self.retry_policy.command_type = command_type + + def start_retry_timer(self): + """Start the retry timer for duration-based retry limits.""" + if isinstance(self.retry_policy, DatabricksRetryPolicy): + self.retry_policy.start_retry_timer() def _get_auth_headers(self) -> Dict[str, str]: """Get authentication headers from the auth provider.""" @@ -87,23 +215,11 @@ def _get_auth_headers(self) -> Dict[str, str]: self.auth_provider.add_headers(headers) return headers - def _get_call(self, method: str) -> Callable: - """Get the appropriate HTTP method function.""" - method = method.upper() - if method == "GET": - return self.session.get - if method == "POST": - return self.session.post - if method == "DELETE": - return self.session.delete - raise ValueError(f"Unsupported HTTP method: {method}") - def _make_request( self, method: str, path: str, data: Optional[Dict[str, Any]] = None, - params: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """ Make an HTTP request to the SEA endpoint. @@ -112,75 +228,96 @@ def _make_request( method: HTTP method (GET, POST, DELETE) path: API endpoint path data: Request payload data - params: Query parameters Returns: Dict[str, Any]: Response data parsed from JSON Raises: - RequestError: If the request fails + RequestError: If the request fails after retries """ - url = urljoin(self.base_url, path) - headers: Dict[str, str] = {**self.headers, **self._get_auth_headers()} + # Prepare headers + headers = {**self.headers, **self._get_auth_headers()} - logger.debug(f"making {method} request to {url}") + # Prepare request body + body = json.dumps(data).encode("utf-8") if data else b"" + if body: + headers["Content-Length"] = str(len(body)) - try: - call = self._get_call(method) - response = call( - url=url, - headers=headers, - json=data, - params=params, - ) + # Set command type for retry policy + command_type = self._get_command_type_from_path(path, method) + self.set_retry_command_type(command_type) + self.start_retry_timer() - # Check for HTTP errors - response.raise_for_status() + logger.debug(f"Making {method} request to {path}") - # Log response details - logger.debug(f"Response status: {response.status_code}") + # When v3 retries are enabled, urllib3 handles retries internally via DatabricksRetryPolicy + # When disabled, we let exceptions bubble up (similar to Thrift backend approach) + if self._pool is None: + raise RequestError("Connection pool not initialized", None) - # Parse JSON response - if response.content: - result = response.json() - # Log response content (but limit it for large responses) - content_str = json.dumps(result) - if len(content_str) > 1000: - logger.debug( - f"Response content (truncated): {content_str[:1000]}..." - ) - else: - logger.debug(f"Response content: {content_str}") - return result - return {} - - except requests.exceptions.RequestException as e: - # Handle request errors and extract details from response if available - error_message = f"SEA HTTP request failed: {str(e)}" - - if hasattr(e, "response") and e.response is not None: - status_code = e.response.status_code - try: - error_details = e.response.json() - error_message = ( - f"{error_message}: {error_details.get('message', '')}" - ) - logger.error( - f"Request failed (status {status_code}): {error_details}" - ) - except (ValueError, KeyError): - # If we can't parse JSON, log raw content - content = ( - e.response.content.decode("utf-8", errors="replace") - if isinstance(e.response.content, bytes) - else str(e.response.content) - ) - logger.error(f"Request failed (status {status_code}): {content}") - else: - logger.error(error_message) + try: + response = self._pool.request( + method=method.upper(), + url=path, + body=body, + headers=headers, + preload_content=False, + retries=self.retry_policy, + ) + except MaxRetryDurationError as e: + # MaxRetryDurationError is raised directly by DatabricksRetryPolicy + # when duration limits are exceeded (like in test_retry_exponential_backoff) + error_message = f"Request failed due to retry duration limit: {e}" + # Construct RequestError with message, context, and specific error + raise RequestError(error_message, None, e) + except (SessionAlreadyClosedError, CursorAlreadyClosedError) as e: + # These exceptions are raised by DatabricksRetryPolicy when detecting + # "already closed" scenarios (404 responses with retry history) + error_message = f"Request failed: {e}" + # Construct RequestError with proper 3-argument format (message, context, error) + raise RequestError(error_message, None, e) + except MaxRetryError as e: + # urllib3 MaxRetryError should bubble up for redirect tests to catch + logger.error(f"SEA HTTP request failed with MaxRetryError: {e}") + raise + except Exception as e: + logger.error(f"SEA HTTP request failed with exception: {e}") + error_message = f"Error during request to server. {e}" + # Construct RequestError with proper 3-argument format (message, context, error) + raise RequestError(error_message, None, e) + + logger.debug(f"Response status: {response.status}") + + # Handle successful responses + if 200 <= response.status < 300: + return response.json() + + error_message = f"SEA HTTP request failed with status {response.status}" + + raise RequestError(error_message, None) + + def _get_command_type_from_path(self, path: str, method: str) -> CommandType: + """ + Determine the command type based on the API path and method. - # Re-raise as a RequestError - from databricks.sql.exc import RequestError + This helps the retry policy make appropriate decisions for different + types of SEA operations. + """ + path = path.lower() + method = method.upper() - raise RequestError(error_message, e) + if "/statements" in path: + if method == "POST" and path.endswith("/statements"): + return CommandType.EXECUTE_STATEMENT + elif "/cancel" in path: + return CommandType.OTHER # Cancel operation + elif method == "DELETE": + return CommandType.CLOSE_OPERATION + elif method == "GET": + return CommandType.GET_OPERATION_STATUS + elif "/sessions" in path: + if method == "DELETE": + return CommandType.CLOSE_SESSION + + return CommandType.OTHER diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py index aeeb67974..87d679946 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/tests/e2e/common/large_queries_mixin.py @@ -48,7 +48,9 @@ def fetch_rows(self, cursor, row_count, fetchmany_size): "extra_params", [ {}, - {"use_sea": True}, + { + "use_sea": True, + }, ], ) def test_query_with_large_wide_result_set(self, extra_params): @@ -81,7 +83,9 @@ def test_query_with_large_wide_result_set(self, extra_params): "extra_params", [ {}, - {"use_sea": True}, + { + "use_sea": True, + }, ], ) def test_query_with_large_narrow_result_set(self, extra_params): @@ -102,7 +106,9 @@ def test_query_with_large_narrow_result_set(self, extra_params): "extra_params", [ {}, - {"use_sea": True}, + { + "use_sea": True, + }, ], ) def test_long_running_query(self, extra_params): diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index dd509c062..17a89d18f 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -17,17 +17,32 @@ class Client429ResponseMixin: - def test_client_should_retry_automatically_when_getting_429(self): - with self.cursor() as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_client_should_retry_automatically_when_getting_429(self, extra_params): + with self.cursor(extra_params) as cursor: for _ in range(10): cursor.execute("SELECT 1") rows = cursor.fetchall() self.assertEqual(len(rows), 1) self.assertEqual(rows[0][0], 1) - def test_client_should_not_retry_429_if_RateLimitRetry_is_0(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_client_should_not_retry_429_if_RateLimitRetry_is_0(self, extra_params): with pytest.raises(self.error_type) as cm: - with self.cursor(self.conf_to_disable_rate_limit_retries) as cursor: + extra_params = {**extra_params, **self.conf_to_disable_rate_limit_retries} + with self.cursor(extra_params) as cursor: for _ in range(10): cursor.execute("SELECT 1") rows = cursor.fetchall() @@ -46,14 +61,32 @@ def test_client_should_not_retry_429_if_RateLimitRetry_is_0(self): class Client503ResponseMixin: - def test_wait_cluster_startup(self): - with self.cursor() as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_wait_cluster_startup(self, extra_params): + with self.cursor(extra_params) as cursor: cursor.execute("SELECT 1") cursor.fetchall() - def _test_retry_disabled_with_message(self, error_msg_substring, exception_type): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def _test_retry_disabled_with_message( + self, error_msg_substring, exception_type, extra_params + ): with pytest.raises(exception_type) as cm: - with self.connection(self.conf_to_disable_temporarily_unavailable_retries): + with self.connection( + self.conf_to_disable_temporarily_unavailable_retries, extra_params + ): pass assert error_msg_substring in str(cm.exception) @@ -127,7 +160,14 @@ class PySQLRetryTestsMixin: "_retry_delay_default": 0.5, } - def test_retry_urllib3_settings_are_honored(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_urllib3_settings_are_honored(self, extra_params): """Databricks overrides some of urllib3's configuration. This tests confirms that what configuration we DON'T override is preserved in urllib3's internals """ @@ -147,19 +187,34 @@ def test_retry_urllib3_settings_are_honored(self): assert rp.read == 11 assert rp.redirect == 12 - def test_oserror_retries(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_oserror_retries(self, extra_params): """If a network error occurs during make_request, the request is retried according to policy""" with patch( "urllib3.connectionpool.HTTPSConnectionPool._validate_conn", ) as mock_validate_conn: mock_validate_conn.side_effect = OSError("Some arbitrary network error") with pytest.raises(MaxRetryError) as cm: - with self.connection(extra_params=self._retry_policy) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: pass assert mock_validate_conn.call_count == 6 - def test_retry_max_count_not_exceeded(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_max_count_not_exceeded(self, extra_params): """GIVEN the max_attempts_count is 5 WHEN the server sends nothing but 429 responses THEN the connector issues six request (original plus five retries) @@ -167,11 +222,19 @@ def test_retry_max_count_not_exceeded(self): """ with mocked_server_response(status=404) as mock_obj: with pytest.raises(MaxRetryError) as cm: - with self.connection(extra_params=self._retry_policy) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: pass assert mock_obj.return_value.getresponse.call_count == 6 - def test_retry_exponential_backoff(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_exponential_backoff(self, extra_params): """GIVEN the retry policy is configured for reasonable exponential backoff WHEN the server sends nothing but 429 responses with retry-afters THEN the connector will use those retry-afters values as floor @@ -184,7 +247,8 @@ def test_retry_exponential_backoff(self): status=429, headers={"Retry-After": "8"} ) as mock_obj: with pytest.raises(RequestError) as cm: - with self.connection(extra_params=retry_policy) as conn: + extra_params = {**extra_params, **retry_policy} + with self.connection(extra_params=extra_params) as conn: pass duration = time.time() - time_start @@ -200,18 +264,33 @@ def test_retry_exponential_backoff(self): # Should be less than 26, but this is a safe margin for CI/CD slowness assert duration < 30 - def test_retry_max_duration_not_exceeded(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_max_duration_not_exceeded(self, extra_params): """GIVEN the max attempt duration of 10 seconds WHEN the server sends a Retry-After header of 60 seconds THEN the connector raises a MaxRetryDurationError """ with mocked_server_response(status=429, headers={"Retry-After": "60"}): with pytest.raises(RequestError) as cm: - with self.connection(extra_params=self._retry_policy) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: pass assert isinstance(cm.value.args[1], MaxRetryDurationError) - def test_retry_abort_non_recoverable_error(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_abort_non_recoverable_error(self, extra_params): """GIVEN the server returns a code 501 WHEN the connector receives this response THEN nothing is retried and an exception is raised @@ -220,16 +299,25 @@ def test_retry_abort_non_recoverable_error(self): # Code 501 is a Not Implemented error with mocked_server_response(status=501): with pytest.raises(RequestError) as cm: - with self.connection(extra_params=self._retry_policy) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: pass assert isinstance(cm.value.args[1], NonRecoverableNetworkError) - def test_retry_abort_unsafe_execute_statement_retry_condition(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_abort_unsafe_execute_statement_retry_condition(self, extra_params): """GIVEN the server sends a code other than 429 or 503 WHEN the connector sent an ExecuteStatement command THEN nothing is retried because it's idempotent """ - with self.connection(extra_params=self._retry_policy) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: with conn.cursor() as cursor: # Code 502 is a Bad Gateway, which we commonly see in production under heavy load with mocked_server_response(status=502): @@ -237,7 +325,14 @@ def test_retry_abort_unsafe_execute_statement_retry_condition(self): cursor.execute("Not a real query") assert isinstance(cm.value.args[1], UnsafeToRetryError) - def test_retry_dangerous_codes(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_dangerous_codes(self, extra_params): """GIVEN the server sends a dangerous code and the user forced this to be retryable WHEN the connector sent an ExecuteStatement command THEN the command is retried @@ -253,7 +348,8 @@ def test_retry_dangerous_codes(self): } # Prove that these codes are not retried by default - with self.connection(extra_params={**self._retry_policy}) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: with conn.cursor() as cursor: for dangerous_code in DANGEROUS_CODES: with mocked_server_response(status=dangerous_code): @@ -263,7 +359,7 @@ def test_retry_dangerous_codes(self): # Prove that these codes are retried if forced by the user with self.connection( - extra_params={**self._retry_policy, **additional_settings} + extra_params={**extra_params, **self._retry_policy, **additional_settings} ) as conn: with conn.cursor() as cursor: for dangerous_code in DANGEROUS_CODES: @@ -271,7 +367,14 @@ def test_retry_dangerous_codes(self): with pytest.raises(MaxRetryError) as cm: cursor.execute("Not a real query") - def test_retry_safe_execute_statement_retry_condition(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_safe_execute_statement_retry_condition(self, extra_params): """GIVEN the server sends either code 429 or 503 WHEN the connector sent an ExecuteStatement command THEN the request is retried because these are idempotent @@ -283,7 +386,11 @@ def test_retry_safe_execute_statement_retry_condition(self): ] with self.connection( - extra_params={**self._retry_policy, "_retry_stop_after_attempts_count": 1} + extra_params={ + **extra_params, + **self._retry_policy, + "_retry_stop_after_attempts_count": 1, + } ) as conn: with conn.cursor() as cursor: # Code 502 is a Bad Gateway, which we commonly see in production under heavy load @@ -292,7 +399,14 @@ def test_retry_safe_execute_statement_retry_condition(self): cursor.execute("This query never reaches the server") assert mock_obj.return_value.getresponse.call_count == 2 - def test_retry_abort_close_session_on_404(self, caplog): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_abort_close_session_on_404(self, extra_params, caplog): """GIVEN the connector sends a CloseSession command WHEN server sends a 404 (which is normally retried) THEN nothing is retried because 404 means the session already closed @@ -305,12 +419,20 @@ def test_retry_abort_close_session_on_404(self, caplog): {"status": 404, "headers": {}, "redirect_location": None}, ] - with self.connection(extra_params={**self._retry_policy}) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: with mock_sequential_server_responses(responses): conn.close() assert "Session was closed by a prior request" in caplog.text - def test_retry_abort_close_operation_on_404(self, caplog): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_abort_close_operation_on_404(self, extra_params, caplog): """GIVEN the connector sends a CancelOperation command WHEN server sends a 404 (which is normally retried) THEN nothing is retried because 404 means the operation was already canceled @@ -323,7 +445,8 @@ def test_retry_abort_close_operation_on_404(self, caplog): {"status": 404, "headers": {}, "redirect_location": None}, ] - with self.connection(extra_params={**self._retry_policy}) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: with conn.cursor() as curs: with patch( "databricks.sql.backend.types.ExecuteResponse.has_been_closed_server_side", @@ -338,7 +461,16 @@ def test_retry_abort_close_operation_on_404(self, caplog): "Operation was canceled by a prior request" in caplog.text ) - def test_retry_max_redirects_raises_too_many_redirects_exception(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_max_redirects_raises_too_many_redirects_exception( + self, extra_params + ): """GIVEN the connector is configured with a custom max_redirects WHEN the DatabricksRetryPolicy is created THEN the connector raises a MaxRedirectsError if that number is exceeded @@ -353,6 +485,7 @@ def test_retry_max_redirects_raises_too_many_redirects_exception(self): with pytest.raises(MaxRetryError) as cm: with self.connection( extra_params={ + **extra_params, **self._retry_policy, "_retry_max_redirects": max_redirects, } @@ -362,7 +495,14 @@ def test_retry_max_redirects_raises_too_many_redirects_exception(self): # Total call count should be 2 (original + 1 retry) assert mock_obj.return_value.getresponse.call_count == expected_call_count - def test_retry_max_redirects_unset_doesnt_redirect_forever(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_max_redirects_unset_doesnt_redirect_forever(self, extra_params): """GIVEN the connector is configured without a custom max_redirects WHEN the DatabricksRetryPolicy is used THEN the connector raises a MaxRedirectsError if that number is exceeded @@ -377,6 +517,7 @@ def test_retry_max_redirects_unset_doesnt_redirect_forever(self): with pytest.raises(MaxRetryError) as cm: with self.connection( extra_params={ + **extra_params, **self._retry_policy, } ): @@ -385,7 +526,16 @@ def test_retry_max_redirects_unset_doesnt_redirect_forever(self): # Total call count should be 6 (original + _retry_stop_after_attempts_count) assert mock_obj.return_value.getresponse.call_count == 6 - def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count( + self, extra_params + ): # If I add another 503 or 302 here the test will fail with a MaxRetryError responses = [ {"status": 302, "headers": {}, "redirect_location": "/foo.bar"}, @@ -400,7 +550,11 @@ def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count(self): with pytest.raises(RequestError) as cm: with mock_sequential_server_responses(responses): with self.connection( - extra_params={**self._retry_policy, **additional_settings} + extra_params={ + **extra_params, + **self._retry_policy, + **additional_settings, + } ): pass @@ -408,9 +562,19 @@ def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count(self): assert "too many redirects" not in str(cm.value.message) assert "Error during request to server" in str(cm.value.message) - def test_retry_max_redirects_exceeds_max_attempts_count_warns_user(self, caplog): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_max_redirects_exceeds_max_attempts_count_warns_user( + self, extra_params, caplog + ): with self.connection( extra_params={ + **extra_params, **self._retry_policy, **{ "_retry_max_redirects": 100, @@ -420,15 +584,33 @@ def test_retry_max_redirects_exceeds_max_attempts_count_warns_user(self, caplog) ): assert "it will have no affect!" in caplog.text - def test_retry_legacy_behavior_warns_user(self, caplog): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_retry_legacy_behavior_warns_user(self, extra_params, caplog): with self.connection( - extra_params={**self._retry_policy, "_enable_v3_retries": False} + extra_params={ + **extra_params, + **self._retry_policy, + "_enable_v3_retries": False, + } ): assert ( "Legacy retry behavior is enabled for this connection." in caplog.text ) - def test_403_not_retried(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_403_not_retried(self, extra_params): """GIVEN the server returns a code 403 WHEN the connector receives this response THEN nothing is retried and an exception is raised @@ -437,11 +619,19 @@ def test_403_not_retried(self): # Code 403 is a Forbidden error with mocked_server_response(status=403): with pytest.raises(RequestError) as cm: - with self.connection(extra_params=self._retry_policy) as conn: + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params) as conn: pass assert isinstance(cm.value.args[1], NonRecoverableNetworkError) - def test_401_not_retried(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + {"use_sea": True}, + ], + ) + def test_401_not_retried(self, extra_params): """GIVEN the server returns a code 401 WHEN the connector receives this response THEN nothing is retried and an exception is raised @@ -450,6 +640,7 @@ def test_401_not_retried(self): # Code 401 is an Unauthorized error with mocked_server_response(status=401): with pytest.raises(RequestError) as cm: - with self.connection(extra_params=self._retry_policy): + extra_params = {**extra_params, **self._retry_policy} + with self.connection(extra_params=extra_params): pass assert isinstance(cm.value.args[1], NonRecoverableNetworkError)