Skip to content

feat: Add types for snowflake.connector.connect #2271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/snowflake/connector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

import logging
from logging import NullHandler
from typing import TYPE_CHECKING

from typing_extensions import Unpack

from snowflake.connector.externals_utils.externals_setup import setup_external_libraries

Expand Down Expand Up @@ -45,13 +48,26 @@
from .log_configuration import EasyLoggingConfigPython
from .version import VERSION

if TYPE_CHECKING:
from os import PathLike

from .connection import SnowflakeConnectionConfig

logging.getLogger(__name__).addHandler(NullHandler())
setup_external_libraries()


@wraps(SnowflakeConnection.__init__)
def Connect(**kwargs) -> SnowflakeConnection:
return SnowflakeConnection(**kwargs)
def Connect(
connection_name: str | None = None,
connections_file_path: PathLike[str] | None = None,
**kwargs: Unpack[SnowflakeConnectionConfig],
) -> SnowflakeConnection:
return SnowflakeConnection(
connection_name=connection_name,
connections_file_path=connections_file_path,
**kwargs,
)


connect = Connect
Expand Down
14 changes: 7 additions & 7 deletions src/snowflake/connector/backoff_policies.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import random
from typing import Callable, Iterator
from typing import Callable, Generator

"""This module provides common implementations of backoff policies

Expand Down Expand Up @@ -38,7 +38,7 @@ def mixed_backoff(
base: int = DEFAULT_BACKOFF_BASE,
cap: int = DEFAULT_BACKOFF_CAP,
enable_jitter: bool = DEFAULT_ENABLE_JITTER,
) -> Callable[..., Iterator[int]]:
) -> Callable[[], Generator[int]]:
Copy link
Author

@max-muoto max-muoto Apr 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding more accurate types in this module to make sure these functions will properly fulfill the annotation for backoff_policy

"""Randomly chooses between exponential and constant backoff. Uses equal jitter.

Args:
Expand All @@ -52,7 +52,7 @@ def mixed_backoff(
Callable: generator function implementing the mixed backoff policy
"""

def generator():
def generator() -> Generator[int]:
cnt = 0
sleep = base

Expand Down Expand Up @@ -80,7 +80,7 @@ def linear_backoff(
base: int = DEFAULT_BACKOFF_BASE,
cap: int = DEFAULT_BACKOFF_CAP,
enable_jitter: bool = DEFAULT_ENABLE_JITTER,
) -> Callable[..., Iterator[int]]:
) -> Callable[[], Generator[int]]:
"""Standard linear backoff. Uses full jitter.

Args:
Expand All @@ -94,7 +94,7 @@ def linear_backoff(
Callable: generator function implementing the linear backoff policy
"""

def generator():
def generator() -> Generator[int]:
sleep = base

yield sleep
Expand All @@ -113,7 +113,7 @@ def exponential_backoff(
base: int = DEFAULT_BACKOFF_BASE,
cap: int = DEFAULT_BACKOFF_CAP,
enable_jitter: bool = DEFAULT_ENABLE_JITTER,
) -> Callable[..., Iterator[int]]:
) -> Callable[[], Generator[int]]:
"""Standard exponential backoff. Uses full jitter.

Args:
Expand All @@ -127,7 +127,7 @@ def exponential_backoff(
Callable: generator function implementing the exponential backoff policy
"""

def generator():
def generator() -> Generator[int]:
sleep = base

yield sleep
Expand Down
74 changes: 68 additions & 6 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,23 @@
from logging import getLogger
from threading import Lock
from types import TracebackType
from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generator,
Iterable,
Iterator,
NamedTuple,
Sequence,
TypedDict,
)
from uuid import UUID

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from typing_extensions import Unpack

from . import errors, proxy
from ._query_context_cache import QueryContextCache
Expand Down Expand Up @@ -123,6 +134,10 @@
from .util_text import construct_hostname, parse_account, split_statements
from .wif_util import AttestationProvider

if TYPE_CHECKING:
from os import PathLike


DEFAULT_CLIENT_PREFETCH_THREADS = 4
MAX_CLIENT_PREFETCH_THREADS = 10
MAX_CLIENT_FETCH_THREADS = 1024
Expand Down Expand Up @@ -378,6 +393,53 @@ class TypeAndBinding(NamedTuple):
binding: str | None


class SnowflakeConnectionConfig(TypedDict):
"""Configuration type for the SnowflakeConnection."""

insecure_mode: bool
disable_ocsp_checks: bool
ocsp_fail_open: bool
session_id: int
user: str
host: str
port: int
region: str
proxy_host: str
proxy_port: str
proxy_user: str
proxy_password: str
account: str
database: str
schema: str
warehouse: str
role: str
login_timeout: int
network_timeout: int
socket_timeout: int
backoff_policy: Callable[[], Generator[int]]
client_session_keep_alive_heartbeat_frequency: int
client_prefetch_threads: int
client_fetch_threads: int
rest: SnowflakeRestful
application: str
errorhandler: Callable
converter_class: type[SnowflakeConverter]
validate_default_parameters: bool
is_pyformat: bool
consent_cache_id_token: str
enable_stage_s3_privatelink_for_us_east_1: bool
enable_connection_diag: bool
connection_diag_log_path: PathLike[str] | str
connection_diag_whitelist_path: PathLike[str] | str
connection_diag_allowlist_path: PathLike[str] | str
json_result_force_utf8_decoding: bool
server_session_keep_alive: bool
token_file_path: PathLike[str] | str
unsafe_file_write: bool
gcs_use_virtual_endpoints: bool
check_arrow_conversion_error_on_every_column: bool


class SnowflakeConnection:
"""Implementation of the connection object for the Snowflake Database.

Expand Down Expand Up @@ -448,8 +510,8 @@ class SnowflakeConnection:
def __init__(
self,
connection_name: str | None = None,
connections_file_path: pathlib.Path | None = None,
**kwargs,
connections_file_path: PathLike[str] | None = None,
**kwargs: Unpack[SnowflakeConnectionConfig],
) -> None:
"""Create a new SnowflakeConnection.

Expand Down Expand Up @@ -651,7 +713,7 @@ def socket_timeout(self) -> int | None:
return int(self._socket_timeout) if self._socket_timeout is not None else None

@property
def _backoff_generator(self) -> Iterator:
def _backoff_generator(self) -> Generator[int]:
return self._backoff_policy()

@property
Expand Down Expand Up @@ -983,7 +1045,7 @@ def autocommit(self, mode) -> None:
except Error as e:
if e.sqlstate == SQLSTATE_FEATURE_NOT_SUPPORTED:
logger.debug(
"Autocommit feature is not enabled for this " "connection. Ignored"
"Autocommit feature is not enabled for this connection. Ignored"
)

def commit(self) -> None:
Expand Down Expand Up @@ -1166,7 +1228,7 @@ def __open_connection(self):
elif self._authenticator == EXTERNAL_BROWSER_AUTHENTICATOR:
self._session_parameters[
PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL
] = (self._client_store_temporary_credential if IS_LINUX else True)
] = self._client_store_temporary_credential if IS_LINUX else True
auth.read_temporary_credentials(
self.host,
self.user,
Expand Down
Loading