diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 71b012b7..eabb82b9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,30 +14,30 @@ repos: - id: check-toml - id: check-added-large-files - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.0 + rev: v0.12.3 hooks: - - id: ruff + - id: ruff-check types_or: [python, pyi, jupyter] args: [--fix] - id: ruff-format types_or: [python, pyi, jupyter] - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v1.14.1" + rev: "v1.17.0" hooks: - id: mypy additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "torch", "jsonpickle"] args: [--show-error-codes] - repo: https://github.com/pre-commit/mirrors-clang-format - rev: "v19.1.6" + rev: "v20.1.8" hooks: - id: clang-format - repo: https://github.com/MarcoGorelli/cython-lint - rev: v0.16.6 + rev: v0.16.7 hooks: - id: cython-lint - id: double-quote-cython-strings - repo: https://github.com/scop/pre-commit-shfmt - rev: v3.10.0-2 + rev: v3.12.0-2 hooks: - id: shfmt - repo: https://github.com/shellcheck-py/shellcheck-py @@ -50,7 +50,7 @@ repos: # - id: cmake-format # - id: cmake-lint - repo: https://github.com/compilerla/conventional-pre-commit - rev: v4.0.0 + rev: v4.2.0 hooks: - id: conventional-pre-commit stages: [commit-msg] diff --git a/pyproject.toml b/pyproject.toml index 15ab33fc..cbd83e1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ line-length = 100 indent-width = 4 target-version = "py39" include = ["pyproject.toml", "python/**/*.py", "tests/python/**/*.py"] -select = [ +lint.select = [ "UP", # pyupgrade, https://docs.astral.sh/ruff/rules/#pyupgrade-up "PL", # pylint, https://docs.astral.sh/ruff/rules/#pylint-pl "I", # isort, https://docs.astral.sh/ruff/rules/#isort-i @@ -88,12 +88,12 @@ select = [ "PTH", # flake8-use-pathlib, https://docs.astral.sh/ruff/rules/#flake8-use-pathlib-pth # "D", # pydocstyle, https://docs.astral.sh/ruff/rules/#pydocstyle-d ] -ignore = [ +lint.ignore = [ "PLR2004", # pylint: magic-value-comparison "ANN401", # flake8-annotations: any-type ] -fixable = ["ALL"] -unfixable = [] +lint.fixable = ["ALL"] +lint.unfixable = [] [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] diff --git a/python/mlc/_cython/base.py b/python/mlc/_cython/base.py index 362f4174..874579c6 100644 --- a/python/mlc/_cython/base.py +++ b/python/mlc/_cython/base.py @@ -183,7 +183,7 @@ class TypeMethod: kind: int # 0: member method, 1: static method def as_callable(self) -> Callable[..., typing.Any]: - from .core import func_call # type: ignore[import-not-found] + from .core import func_call # type: ignore[import-not-found] # noqa: PLC0415 func = self.func if self.kind == 0: # member method @@ -219,14 +219,16 @@ class TypeInfo: d_fields: tuple[Field, ...] = () def get_parent(self) -> TypeInfo: - from .core import type_index2cached_py_type_info # type: ignore[import-not-found] + from .core import ( # type: ignore[import-not-found] # noqa: PLC0415 + type_index2cached_py_type_info, + ) type_index = self.type_ancestors[-1] return type_index2cached_py_type_info(type_index) def translate_exception_to_c(exception: Exception) -> tuple[bytes, int, bytes]: - from .core import str_py2c # type: ignore[import-not-found] + from .core import str_py2c # type: ignore[import-not-found] # noqa: PLC0415 def _kind() -> bytes: kind: str = exception.__class__.__name__ @@ -256,7 +258,7 @@ def _bytes_info() -> bytes: def translate_exception_from_c(err: Error) -> Exception: - from .core import error_pycode_fake # type: ignore[import-not-found] + from .core import error_pycode_fake # type: ignore[import-not-found] # noqa: PLC0415 kind, info = err.kind, err._info if info: @@ -306,7 +308,7 @@ def dtype_normalize(dtype: typing.Any) -> str | DataType: def _torch_dtype_to_str() -> str | None: if "torch" not in sys.modules: return None - import torch + import torch # noqa: PLC0415 if not isinstance(dtype, torch.dtype): return None @@ -324,7 +326,7 @@ def _numpy_dtype_to_str() -> str | None: return np_dtype if (torch_dtype := _torch_dtype_to_str()) is not None: return torch_dtype - from mlc.core.dtype import DataType + from mlc.core.dtype import DataType # noqa: PLC0415 return dtype if isinstance(dtype, DataType) else str(dtype) @@ -333,7 +335,7 @@ def device_normalize(device: typing.Any) -> str | Device: def _torch_device_to_str() -> str | None: if "torch" not in sys.modules: return None - import torch + import torch # noqa: PLC0415 if not isinstance(device, torch.device): return None @@ -416,7 +418,7 @@ def attach_method( def c_class_core(type_key: str) -> Callable[[type[ClsType]], type[ClsType]]: def decorator(super_type_cls: type[ClsType]) -> type[ClsType]: - from .core import ( # type: ignore[import-not-found] + from .core import ( # type: ignore[import-not-found] # noqa: PLC0415 type_index2type_methods, type_key2py_type_info, ) diff --git a/python/mlc/config.py b/python/mlc/config.py index e99127a4..50a31f57 100644 --- a/python/mlc/config.py +++ b/python/mlc/config.py @@ -34,7 +34,7 @@ def probe_vcvarsall() -> Path: def probe_msvc() -> tuple[Path, ...]: - import setuptools # type: ignore[import-not-found,import-untyped] + import setuptools # type: ignore[import-not-found,import-untyped] # noqa: PLC0415 results = [] if (path := shutil.which("cl.exe", mode=os.X_OK)) is not None: @@ -67,7 +67,7 @@ def probe_compiler() -> tuple[Path, ...]: def display_build_info() -> None: - from mlc.core import Func + from mlc.core import Func # noqa: PLC0415 info = Func.get("mlc.core.BuildInfo")() for k in sorted(info.keys()): diff --git a/python/mlc/core/device.py b/python/mlc/core/device.py index 452f8bb5..10b55487 100644 --- a/python/mlc/core/device.py +++ b/python/mlc/core/device.py @@ -39,12 +39,12 @@ def __hash__(self) -> int: return hash((Device, *self._device_pair)) def torch(self) -> torch.device: - import torch + import torch # noqa: PLC0415 return torch.device(str(self)) @staticmethod def register(name: str) -> int: - from .func import Func + from .func import Func # noqa: PLC0415 return Func.get("mlc.base.DeviceTypeRegister")(name) diff --git a/python/mlc/core/dict.py b/python/mlc/core/dict.py index 8fdd9410..285f28e6 100644 --- a/python/mlc/core/dict.py +++ b/python/mlc/core/dict.py @@ -159,6 +159,10 @@ def __ne__(self, other: Any) -> bool: def py(self) -> dict[K, V]: return container_to_py(self) + def __hash__(self) -> int: + # TODO: hash by elements + return hash((type(self), self._mlc_address)) + class _DictKeysView(KeysView[K]): def __init__(self, mapping: Dict[K, V]) -> None: diff --git a/python/mlc/core/dtype.py b/python/mlc/core/dtype.py index 631d5d4f..966ff96f 100644 --- a/python/mlc/core/dtype.py +++ b/python/mlc/core/dtype.py @@ -62,7 +62,7 @@ def __hash__(self) -> int: return hash((DataType, *self._dtype_triple)) def torch(self) -> torch.dtype: - import torch + import torch # noqa: PLC0415 if (ret := getattr(torch, str(self), None)) is not None: if isinstance(ret, torch.dtype): @@ -74,6 +74,6 @@ def numpy(self) -> np.dtype: @staticmethod def register(name: str, bits: int) -> int: - from .func import Func + from .func import Func # noqa: PLC0415 return Func.get("mlc.base.DataTypeRegister")(name, bits) diff --git a/python/mlc/core/list.py b/python/mlc/core/list.py index 56424cbe..90a0a652 100644 --- a/python/mlc/core/list.py +++ b/python/mlc/core/list.py @@ -126,6 +126,10 @@ def __delitem__(self, i: int) -> None: def py(self) -> list[T]: return container_to_py(self) + def __hash__(self) -> int: + # TODO: hash by elements + return hash((type(self), self._mlc_address)) + def _normalize_index(i: int, length: int) -> int: if not -length <= i < length: diff --git a/python/mlc/core/opaque.py b/python/mlc/core/opaque.py index f0fab0ac..b775eb8d 100644 --- a/python/mlc/core/opaque.py +++ b/python/mlc/core/opaque.py @@ -49,13 +49,13 @@ def register( def _default_serialize(opaques: list[Any]) -> str: - import jsonpickle # type: ignore[import-untyped] + import jsonpickle # type: ignore[import-untyped] # noqa: PLC0415 return jsonpickle.dumps(list(opaques)) def _default_deserialize(json_str: str) -> list[Any]: - import jsonpickle # type: ignore[import-untyped] + import jsonpickle # type: ignore[import-untyped] # noqa: PLC0415 return jsonpickle.loads(json_str) diff --git a/python/mlc/core/tensor.py b/python/mlc/core/tensor.py index cb0c1b67..56483a77 100644 --- a/python/mlc/core/tensor.py +++ b/python/mlc/core/tensor.py @@ -77,7 +77,7 @@ def numpy(self) -> np.ndarray: return np.from_dlpack(self) def torch(self) -> torch.Tensor: - import torch + import torch # noqa: PLC0415 return torch.from_dlpack(self) diff --git a/python/mlc/core/typing.py b/python/mlc/core/typing.py index 0c0b58b7..0984195c 100644 --- a/python/mlc/core/typing.py +++ b/python/mlc/core/typing.py @@ -124,8 +124,8 @@ def from_py(ann: type) -> Type: elif (type_info := getattr(ann, "_mlc_type_info", None)) is not None: return AtomicType(type_info.type_index) elif (origin := typing.get_origin(ann)) is not None: - from mlc.core import Dict as MLCDict - from mlc.core import List as MLCList + from mlc.core import Dict as MLCDict # noqa: PLC0415 + from mlc.core import List as MLCList # noqa: PLC0415 args = typing.get_args(ann) if (origin is list) or (origin is MLCList): diff --git a/python/mlc/dataclasses/utils.py b/python/mlc/dataclasses/utils.py index ceda3aa6..cf9bc450 100644 --- a/python/mlc/dataclasses/utils.py +++ b/python/mlc/dataclasses/utils.py @@ -457,6 +457,6 @@ def replace(obj: Any, /, **changes: Any) -> Any: def stringify(obj: Any) -> str: - from mlc.core.func import Func + from mlc.core.func import Func # noqa: PLC0415 return Func.get("mlc.core.Stringify")(obj) diff --git a/python/mlc/parser/env.py b/python/mlc/parser/env.py index f273f182..994e0205 100644 --- a/python/mlc/parser/env.py +++ b/python/mlc/parser/env.py @@ -110,7 +110,7 @@ def _getfile(obj: Any) -> str: return PY_GETFILE(obj) mod = getattr(obj, "__module__", None) if mod is not None: - import sys + import sys # noqa: PLC0415 file = getattr(sys.modules[mod], "__file__", None) if file is not None: @@ -132,7 +132,7 @@ def _findsource(obj: Any) -> tuple[list[str], int]: # noqa: PLR0912 if not inspect.isclass(obj): return PY_FINDSOURCE(obj) - import linecache + import linecache # noqa: PLC0415 file = inspect.getsourcefile(obj) if file: diff --git a/python/mlc/printer/cprint.py b/python/mlc/printer/cprint.py index fe52bcfc..feb6a707 100644 --- a/python/mlc/printer/cprint.py +++ b/python/mlc/printer/cprint.py @@ -36,15 +36,15 @@ def cprint(printable: str, style: str | None = None) -> None: return # pylint: disable=import-outside-toplevel - from pygments import highlight # type: ignore[import-untyped] - from pygments.formatters import ( # type: ignore[import-untyped] + from pygments import highlight # type: ignore[import-untyped] # noqa: PLC0415 + from pygments.formatters import ( # type: ignore[import-untyped] # noqa: PLC0415 HtmlFormatter, Terminal256Formatter, ) - from pygments.lexers.python import Python3Lexer # type: ignore[import-untyped] + from pygments.lexers.python import Python3Lexer # type: ignore[import-untyped] # noqa: PLC0415 if is_in_notebook: - from IPython import display # type: ignore[import-not-found] + from IPython import display # type: ignore[import-not-found] # noqa: PLC0415 formatter = HtmlFormatter(style=pygment_style) formatter.noclasses = True # inline styles @@ -58,8 +58,8 @@ def _get_pygments_style( style: str | None, is_in_notebook: bool, ) -> pygments.style.Style | str | None: - from pygments.style import Style # type: ignore[import-untyped] - from pygments.token import ( # type: ignore[import-untyped] + from pygments.style import Style # type: ignore[import-untyped] # noqa: PLC0415 + from pygments.token import ( # type: ignore[import-untyped] # noqa: PLC0415 Comment, Keyword, Name, diff --git a/python/mlc/sym/analyzer.py b/python/mlc/sym/analyzer.py index 475552e0..1e1fe122 100644 --- a/python/mlc/sym/analyzer.py +++ b/python/mlc/sym/analyzer.py @@ -20,7 +20,7 @@ def bind( bound: Range | Expr | int | float, allow_override: bool = False, ) -> None: - from .expr import Expr, Range, const + from .expr import Expr, Range, const # noqa: PLC0415 if isinstance(bound, Range): Analyzer._C(b"_bind_range", self, v, bound, allow_override) @@ -40,7 +40,7 @@ def can_prove_less(self, a: Expr, b: int) -> bool: return Analyzer._C(b"can_prove_less", self, a, b) def can_prove_equal(self, a: Expr, b: Expr | int) -> bool: - from .expr import Expr, const + from .expr import Expr, const # noqa: PLC0415 assert isinstance(a, Expr) if isinstance(b, int): diff --git a/python/mlc/sym/expr.py b/python/mlc/sym/expr.py index 91548276..c47f4063 100644 --- a/python/mlc/sym/expr.py +++ b/python/mlc/sym/expr.py @@ -370,7 +370,7 @@ def from_const(dtype: DataType | str, min: int, extent: int) -> Range: def const(dtype: DataType | str, a: Expr | int | float) -> Expr: - from .expr import BoolImm, FloatImm, IntImm + from .expr import BoolImm, FloatImm, IntImm # noqa: PLC0415 if isinstance(dtype, str): dtype = DataType(dtype) diff --git a/python/mlc/sym/op.py b/python/mlc/sym/op.py index 84dec289..b07f83c8 100644 --- a/python/mlc/sym/op.py +++ b/python/mlc/sym/op.py @@ -9,7 +9,7 @@ def _binary_op_args(a: Expr | int | float, b: Expr | int | float) -> tuple[Expr, Expr]: - from .expr import Expr, const + from .expr import Expr, const # noqa: PLC0415 a_is_expr = isinstance(a, Expr) b_is_expr = isinstance(b, Expr) @@ -107,7 +107,7 @@ def select( def let(var: Var, value: Expr | int | float, body: Expr) -> Let: - from .expr import Let, const + from .expr import Let, const # noqa: PLC0415 if isinstance(value, (int, float)): value = const(var.dtype, value) @@ -121,7 +121,7 @@ def ramp( *, dtype: DataType | str | None = None, ) -> Ramp: - from .expr import IntImm, Ramp + from .expr import IntImm, Ramp # noqa: PLC0415 if isinstance(stride, int): stride = IntImm(stride, base.dtype) @@ -136,7 +136,7 @@ def broadcast( *, dtype: DataType | str | None = None, ) -> Broadcast: - from .expr import Broadcast + from .expr import Broadcast # noqa: PLC0415 if dtype is None: dtype = f"{value.dtype}x{lanes}" diff --git a/tests/python/test_core_func.py b/tests/python/test_core_func.py index 1583886d..7ccfc686 100644 --- a/tests/python/test_core_func.py +++ b/tests/python/test_core_func.py @@ -36,7 +36,7 @@ def test_cxx_float(x: float) -> None: @pytest.mark.parametrize("x", [0x0, 0xDEADBEEF]) def test_cxx_ptr(x: int) -> None: - import ctypes + import ctypes # noqa: PLC0415 func = mlc.Func.get("mlc.testing.cxx_ptr") y = func(ctypes.c_void_p(x))