Skip to content
14 changes: 7 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ path = "src/databricks/labs/blueprint/__about__.py"
dependencies = [
"databricks-labs-blueprint[yaml]",
"coverage[toml]~=7.4.4",
"mypy~=1.9.0",
"pylint~=3.1.0",
"pylint-pytest==2.0.0a0",
"mypy~=1.18.0",
"pylint~=4.0.0",
# "pylint-pytest==2.0.0a0" # Incorrect dependency constraint (pylint<4), installed separately below.
"databricks-labs-pylint~=0.3.0",
"pytest~=8.1.0",
"pytest-cov~=4.1.0",
Expand All @@ -55,6 +55,10 @@ dependencies = [
"types-requests~=2.31.0",
]

post-install-commands = [
"pip install --no-deps pylint-pytest==2.0.0a0", # See above; installed here to avoid dependency conflict.
]

# store virtual env as the child of this folder. Helps VSCode (and PyCharm) to run better
path = ".venv"

Expand Down Expand Up @@ -207,10 +211,6 @@ py-version = "3.10"
# source root.
# source-roots =

# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode = true

Comment on lines -210 to -213
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was dropped because the option is no longer supported by pylint. (Version 4+ always behaves as if this option is on.)

# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
# unsafe-load-any-extension =
Expand Down
53 changes: 32 additions & 21 deletions src/databricks/labs/blueprint/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def get_argument_type(self, argument_name: str) -> str | None:
return annotation.__name__


# Union of all value types that may appear as keyword arguments to a command
# function: the primitive flag types plus the injected SDK clients and the
# interactive Prompts helper (see _build_args, which performs the injection).
_CommandArg = int | str | bool | float | WorkspaceClient | AccountClient | Prompts


class App:
def __init__(self, __file: str):
self._mapping: dict[str, Command] = {}
Expand Down Expand Up @@ -94,27 +98,9 @@ def _route(self, raw):
log_level = "info"
databricks_logger = logging.getLogger("databricks")
databricks_logger.setLevel(log_level.upper())
kwargs = {k.replace("-", "_"): v for k, v in flags.items() if v != ""}
cmd = self._mapping[command]
# modify kwargs to match the type of the argument
for kwarg in list(kwargs.keys()):
match cmd.get_argument_type(kwarg):
case "int":
kwargs[kwarg] = int(kwargs[kwarg])
case "bool":
kwargs[kwarg] = kwargs[kwarg].lower() == "true"
case "float":
kwargs[kwarg] = float(kwargs[kwarg])
kwargs = self._build_args(cmd, flags)
try:
if cmd.needs_workspace_client():
self._patch_databricks_host()
kwargs["w"] = self._workspace_client()
elif cmd.is_account:
self._patch_databricks_host()
kwargs["a"] = self._account_client()
prompts_argument = cmd.prompts_argument_name()
if prompts_argument:
kwargs[prompts_argument] = Prompts()
cmd.fn(**kwargs)
except Exception as err: # pylint: disable=broad-exception-caught
logger = self._logger.getChild(command)
Expand All @@ -123,6 +109,31 @@ def _route(self, raw):
else:
logger.error(f"{err.__class__.__name__}: {err}")

def _build_args(self, cmd: Command, flags: dict[str, str]) -> dict[str, _CommandArg]:
    """Assemble the keyword arguments for *cmd* from the raw command-line flags.

    Flag names are normalised (dashes become underscores) and empty values are
    dropped. String values are then coerced to the type declared on the
    command's signature, and the special arguments (workspace/account client,
    prompts) are injected when the command declares them.
    """
    converters = {
        "int": int,
        "bool": lambda text: text.lower() == "true",
        "float": float,
    }
    args: dict[str, _CommandArg] = {
        name.replace("-", "_"): raw for name, raw in flags.items() if raw != ""
    }
    # Coerce each (still-string) flag value to the declared argument type.
    for name in list(args):
        raw = args[name]
        if not isinstance(raw, str):
            continue
        convert = converters.get(cmd.get_argument_type(name))
        if convert is not None:
            args[name] = convert(raw)
    if cmd.needs_workspace_client():
        self._patch_databricks_host()
        args["w"] = self._workspace_client()
    elif cmd.is_account:
        self._patch_databricks_host()
        args["a"] = self._account_client()
    prompts_argument = cmd.prompts_argument_name()
    if prompts_argument:
        args[prompts_argument] = Prompts()
    return args

@classmethod
def fix_databricks_host(cls, host: str) -> str:
"""Emulate the way the Go SDK fixes the Databricks host before using it.
Expand Down Expand Up @@ -171,13 +182,13 @@ def _patch_databricks_host(self) -> None:
self._logger.warning(f"Working around DATABRICKS_HOST normalization issue: {host} -> {fixed_host}")
os.environ["DATABRICKS_HOST"] = fixed_host

def _account_client(self):
def _account_client(self) -> AccountClient:
    """Construct an account-level SDK client tagged with this product's name and version."""
    product_info = self._product_info
    return AccountClient(
        product=product_info.product_name(),
        product_version=product_info.version(),
    )

def _workspace_client(self):
def _workspace_client(self) -> WorkspaceClient:
return WorkspaceClient(
product=self._product_info.product_name(),
product_version=self._product_info.version(),
Expand Down
2 changes: 1 addition & 1 deletion src/databricks/labs/blueprint/installation.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ def _unmarshal_union(cls, inst, path, type_ref):
"""The `_unmarshal_union` method is a private method that is used to deserialize a dictionary to an object
of type `type_ref`. This method is called by the `load` method."""
for variant in get_args(type_ref):
if variant == type(None) and inst is None:
if variant == types.NoneType and inst is None:
return None
try:
value = cls._unmarshal(inst, path, variant)
Expand Down
138 changes: 80 additions & 58 deletions src/databricks/labs/blueprint/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from collections.abc import Generator, Iterable, Sequence
from io import BytesIO, StringIO
from pathlib import Path, PurePath
from typing import BinaryIO, Literal, NoReturn, TextIO, TypeVar
from typing import BinaryIO, ClassVar, Literal, NoReturn, TextIO, TypeVar
from urllib.parse import quote_from_bytes as urlquote_from_bytes

from databricks.sdk import WorkspaceClient
Expand Down Expand Up @@ -127,7 +127,7 @@ class _DatabricksPath(Path, abc.ABC): # pylint: disable=too-many-public-methods
_str: str
_hash: int

parser = _posixpath
parser: ClassVar = _posixpath

# Compatibility attribute, for when superclass implementations get invoked on python <= 3.11.
_flavour = object()
Expand Down Expand Up @@ -236,7 +236,7 @@ def rmdir(self, recursive: bool = False) -> None: ...
def unlink(self, missing_ok: bool = False) -> None: ...

@abstractmethod
def open(
def open( # pylint: disable=too-many-positional-arguments
self,
mode: str = "r",
buffering: int = -1,
Expand All @@ -246,10 +246,10 @@ def open(
): ...

@abstractmethod
def is_dir(self) -> bool: ...
def is_dir(self, *, follow_symlinks: bool = True) -> bool: ...

@abstractmethod
def is_file(self) -> bool: ...
def is_file(self, *, follow_symlinks: bool = True) -> bool: ...

@abstractmethod
def rename(self: P, target: str | bytes | os.PathLike) -> P: ...
Expand All @@ -258,7 +258,7 @@ def rename(self: P, target: str | bytes | os.PathLike) -> P: ...
def replace(self: P, target: str | bytes | os.PathLike) -> P: ...

@abstractmethod
def iterdir(self: P) -> Generator[P, None, None]: ...
def iterdir(self: P) -> Generator[P]: ...

def __reduce__(self) -> NoReturn:
# Cannot support pickling because we can't pickle the workspace client.
Expand Down Expand Up @@ -587,7 +587,10 @@ def glob(
pattern: str | bytes | os.PathLike,
*,
case_sensitive: bool | None = None,
) -> Generator[P, None, None]:
recurse_symlinks: bool = False,
) -> Generator[P]:
if recurse_symlinks:
raise NotImplementedError("recurse_symlinks is not supported for Databricks paths")
pattern_parts = self._prepare_pattern(pattern)
if case_sensitive is None:
case_sensitive = True
Expand All @@ -599,7 +602,10 @@ def rglob(
pattern: str | bytes | os.PathLike,
*,
case_sensitive: bool | None = None,
) -> Generator[P, None, None]:
recurse_symlinks: bool = False,
) -> Generator[P]:
if recurse_symlinks:
raise NotImplementedError("recurse_symlinks is not supported for Databricks paths")
pattern_parts = ("**", *self._prepare_pattern(pattern))
if case_sensitive is None:
case_sensitive = True
Expand Down Expand Up @@ -675,7 +681,7 @@ def unlink(self, missing_ok: bool = False) -> None:
raise FileNotFoundError(f"{self.as_posix()} does not exist")
self._ws.dbfs.delete(self.as_posix())

def open(
def open( # pylint: disable=too-many-positional-arguments
self,
mode: str = "r",
buffering: int = -1,
Expand Down Expand Up @@ -731,18 +737,22 @@ def stat(self, *, follow_symlinks=True) -> os.stat_result:
) # 8
return os.stat_result(seq)

def is_dir(self) -> bool:
def is_dir(self, *, follow_symlinks: bool = True) -> bool:
"""Return True if the path points to a DBFS directory."""
if not follow_symlinks:
raise NotImplementedError("follow_symlinks is not supported for DBFS paths")
try:
return bool(self._file_info.is_dir)
except DatabricksError:
return False

def is_file(self) -> bool:
"""Return True if the path points to a file in Databricks Workspace."""
def is_file(self, *, follow_symlinks: bool = True) -> bool:
"""Return True if the path points to a DBFS file."""
if not follow_symlinks:
raise NotImplementedError("follow_symlinks is not supported for DBFS paths")
return not self.is_dir()

def iterdir(self) -> Generator[DBFSPath, None, None]:
def iterdir(self) -> Generator[DBFSPath]:
for child in self._ws.dbfs.list(self.as_posix()):
yield self._from_file_info(self._ws, child)

Expand Down Expand Up @@ -821,7 +831,7 @@ def unlink(self, missing_ok: bool = False) -> None:
if not missing_ok:
raise FileNotFoundError(f"{self.as_posix()} does not exist") from e

def open(
def open( # pylint: disable=too-many-positional-arguments
self,
mode: str = "r",
buffering: int = -1,
Expand All @@ -842,8 +852,8 @@ def open(
return _TextUploadIO(self._ws, self.as_posix())
raise ValueError(f"invalid mode: {mode}")

def read_text(self, encoding=None, errors=None):
with self.open(mode="r", encoding=encoding, errors=errors) as f:
def read_text(self, encoding=None, errors=None, newline=None) -> str:
with self.open(mode="r", encoding=encoding, errors=errors, newline=newline) as f:
return f.read()

@property
Expand Down Expand Up @@ -881,15 +891,19 @@ def stat(self, *, follow_symlinks=True) -> os.stat_result:
seq[stat.ST_CTIME] = float(self._object_info.created_at) / 1000.0 if self._object_info.created_at else -1.0 # 9
return os.stat_result(seq)

def is_dir(self) -> bool:
def is_dir(self, *, follow_symlinks: bool = True) -> bool:
"""Return True if the path points to a directory in Databricks Workspace."""
if not follow_symlinks:
raise NotImplementedError("follow_symlinks is not supported for Workspace paths")
try:
return self._object_info.object_type == ObjectType.DIRECTORY
except DatabricksError:
return False

def is_file(self) -> bool:
def is_file(self, *, follow_symlinks: bool = True) -> bool:
"""Return True if the path points to a file in Databricks Workspace."""
if not follow_symlinks:
raise NotImplementedError("follow_symlinks is not supported for Workspace paths")
try:
return self._object_info.object_type == ObjectType.FILE
except DatabricksError:
Expand All @@ -902,7 +916,7 @@ def is_notebook(self) -> bool:
except DatabricksError:
return False

def iterdir(self) -> Generator[WorkspacePath, None, None]:
def iterdir(self) -> Generator[WorkspacePath]:
for child in self._ws.workspace.list(self.as_posix()):
yield self._from_object_info(self._ws, child)

Expand Down Expand Up @@ -1068,47 +1082,56 @@ def _detect_encoding_bom(
)


def _read_xml_encoding(binary_io: BinaryIO) -> tuple[bytes, str] | None:
"""Read the XML encoding from the start of a binary file, if present."""
maybe_xml: bytes = binary_io.read(4)
# Useful to know here, an XML declaration must start with '<?xml', and:
# - '<' is 0x3C.
# - '?' is 0x3F.
# References:
# - https://www.w3.org/TR/xml/#sec-guessing-no-ext-info
# - https://www.w3.org/TR/2006/REC-xml11-20060816/#sec-guessing-no-ext-info
match maybe_xml:
case b"\0\0\0\x3c":
# Potentially 32-bit BE (1234) encoding.
sniff_with = "utf-32-be"
case b"\x3c\0\0\0":
# Potentially 32-bit LE (4321) encoding.
sniff_with = "utf-32-le"
case b"\0\0\x3c\0" | b"\0\x3c\0\0":
# Potentially non-standard 32-bit encodings.
logger.warning("XML declaration with non-standard 32-bit encoding detected; not supported.")
return None
case b"\x00\x3c\x00\x3f":
# Potentially 16-bit BE (12) encoding.
sniff_with = "utf-16-be"
case b"\x3c\x00\x3f\x00":
# Potentially 16-bit LE (21) encoding.
sniff_with = "utf-16-le"
case b"\x3c\x3f\x78\x6d":
# Potentially 8-bit, UTF-8 or other ASCII-compatible encoding.
sniff_with = "us-ascii"
case b"\x4c\x6f\xa7\x94":
# Something EBCDIC-ish, oh-my.
sniff_with = "cp037"
case _:
logger.debug(f"No XML declaration detected in the first 4 bytes: {maybe_xml!r}")
return None
logger.debug(f"XML declaration detected, sniffing further with encoding: {sniff_with}")
maybe_xml += binary_io.read(_XML_ENCODING_SNIFF_LIMIT - 4)
return maybe_xml, sniff_with


def _detect_encoding_xml(binary_io: BinaryIO, *, preserve_position: bool) -> str | None:
position = binary_io.tell() if preserve_position else None
try:
maybe_xml: bytes = binary_io.read(4)
# Useful to know here, an XML declaration must start with '<?xml', and:
# - '<' is 0x3C.
# - '?' is 0x3F.
# References:
# - https://www.w3.org/TR/xml/#sec-guessing-no-ext-info
# - https://www.w3.org/TR/2006/REC-xml11-20060816/#sec-guessing-no-ext-info
match maybe_xml:
case b"\0\0\0\x3c":
# Potentially 32-bit BE (1234) encoding.
sniff_with = "utf-32-be"
case b"\x3c\0\0\0":
# Potentially 32-bit LE (4321) encoding.
sniff_with = "utf-32-le"
case b"\0\0\x3c\0" | b"\0\x3c\0\0":
# Potentially non-standard 32-bit encodings.
logger.warning("XML declaration with non-standard 32-bit encoding detected; not supported.")
return None
case b"\x00\x3c\x00\x3f":
# Potentially 16-bit BE (12) encoding.
sniff_with = "utf-16-be"
case b"\x3c\x00\x3f\x00":
# Potentially 16-bit LE (21) encoding.
sniff_with = "utf-16-le"
case b"\x3c\x3f\x78\x6d":
# Potentially 8-bit, UTF-8 or other ASCII-compatible encoding.
sniff_with = "us-ascii"
case b"\x4c\x6f\xa7\x94":
# Something EBCDIC-ish, oh-my.
sniff_with = "cp037"
case _:
logger.debug(f"No XML declaration detected in the first 4 bytes: {maybe_xml!r}")
return None
logger.debug(f"XML declaration detected, sniffing further with encoding: {sniff_with}")
maybe_xml += binary_io.read(_XML_ENCODING_SNIFF_LIMIT - 4)
read_sample = _read_xml_encoding(binary_io)
finally:
if position is not None:
binary_io.seek(position)
if read_sample is None:
return None
maybe_xml, sniff_with = read_sample

# Try to decode the XML declaration with the sniffed encoding; XML is designed so that the declaration can be
# read with the common subset of related encodings, up to where the actual encoding is specified.
Expand All @@ -1118,10 +1141,9 @@ def _detect_encoding_xml(binary_io: BinaryIO, *, preserve_position: bool) -> str
if encoding:
logger.debug(f"XML declaration encoding detected: {encoding}")
# TODO: XML encodings come from the IATA list, maybe they need to mapped/checked against Python's names.
else:
logger.debug("XML declaration without encoding detected, must be utf-8")
encoding = "utf-8"
return encoding
return encoding
logger.debug("XML declaration without encoding detected, must be utf-8")
return "utf-8"
return None


Expand Down
Loading