From e20b33c6d53ecf4662b81222bdb8c63cb47654a7 Mon Sep 17 00:00:00 2001 From: Mashed Potato <38517644+potatomashed@users.noreply.github.com> Date: Wed, 16 Jul 2025 00:39:10 -0700 Subject: [PATCH] feat(dataclasses): Remove PyClass --- README.md | 15 +- cpp/registry.h | 4 +- cpp/structure.cc | 10 +- python/mlc/__init__.py | 8 +- python/mlc/core/__init__.py | 4 +- python/mlc/core/func.py | 6 +- python/mlc/core/object.py | 152 +++++++++++++----- python/mlc/dataclasses/__init__.py | 4 +- python/mlc/dataclasses/py_class.py | 85 ++++++---- python/mlc/dataclasses/utils.py | 6 +- python/mlc/printer/ast.py | 2 +- python/mlc/printer/ir_printer.py | 4 +- python/mlc/sym/expr.py | 4 +- python/mlc/testing/dataclasses.py | 6 +- python/mlc/testing/toy_ir/ir.py | 2 +- tests/python/test_core_dep_graph.py | 4 +- tests/python/test_core_dict.py | 7 +- tests/python/test_core_json.py | 8 +- tests/python/test_core_list.py | 7 +- tests/python/test_core_opaque.py | 14 +- tests/python/test_core_tensor.py | 6 +- tests/python/test_dataclasses_copy.py | 78 ++++----- tests/python/test_dataclasses_fields.py | 2 +- tests/python/test_dataclasses_py_class.py | 38 ++--- tests/python/test_dataclasses_serialize.py | 20 +-- tests/python/test_dataclasses_structure.py | 124 +++++++------- tests/python/test_parser_toy_ir_parser.py | 7 +- tests/python/test_printer_ir_printer.py | 9 +- .../test_sym_analyzer_canonical_simplify.py | 5 +- .../test_sym_analyzer_rewrite_simplify.py | 5 +- tests/python/test_sym_analyzer_simplify.py | 3 +- 31 files changed, 376 insertions(+), 273 deletions(-) diff --git a/README.md b/README.md index 53323b38..519fe3c1 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ MLC provides Pythonic dataclasses: import mlc.dataclasses as mlcd @mlcd.py_class("demo.MyClass") -class MyClass(mlcd.PyClass): +class MyClass: a: int b: str c: float | None @@ -60,10 +60,10 @@ AttributeError: 'MyClass' object has no attribute 'non_exist' and no __dict__ fo **Serialization**. MLC dataclasses are picklable and JSON-serializable. ```python ->>> MyClass.from_json(instance.json()) +>>> mlc.json_loads(mlc.json_dumps(instance)) demo.MyClass(a=12, b='test', c=None) ->>> import pickle; pickle.loads(pickle.dumps(instance)) +>>> pickle.loads(pickle.dumps(instance)) demo.MyClass(a=12, b='test', c=None) ``` @@ -114,10 +114,11 @@ By annotating IR definitions with `structure`, MLC supports structural equality
Define a toy IR with `structure`. ```python +import mlc import mlc.dataclasses as mlcd @mlcd.py_class -class Expr(mlcd.PyClass): +class Expr: def __add__(self, other): return Add(a=self, b=other) @@ -146,16 +147,16 @@ class Let(Expr): >>> L1 = Let(rhs=x + y, lhs=z, body=z) # let z = x + y; z >>> L2 = Let(rhs=y + z, lhs=x, body=x) # let x = y + z; x >>> L3 = Let(rhs=x + x, lhs=z, body=z) # let z = x + x; z ->>> L1.eq_s(L2) +>>> mlc.eq_s(L1, L2) True ->>> L1.eq_s(L3, assert_mode=True) +>>> mlc.eq_s(L1, L3, assert_mode=True) ValueError: Structural equality check failed at {root}.rhs.b: Inconsistent binding. RHS has been bound to a different node while LHS is not bound ``` **Structural hashing**. The structure of MLC dataclasses can be hashed via `hash_s`, which guarantees if two dataclasses are alpha-equivalent, they will share the same structural hash: ```python ->>> L1_hash, L2_hash, L3_hash = L1.hash_s(), L2.hash_s(), L3.hash_s() +>>> L1_hash, L2_hash, L3_hash = mlc.hash_s(L1), mlc.hash_s(L2), mlc.hash_s(L3) >>> assert L1_hash == L2_hash >>> assert L1_hash != L3_hash ``` diff --git a/cpp/registry.h b/cpp/registry.h index 6779ee82..95c98d0b 100644 --- a/cpp/registry.h +++ b/cpp/registry.h @@ -18,7 +18,7 @@ namespace mlc { namespace registry { -Any JSONLoads(AnyView json_str); +Any JSONParse(AnyView json_str); Any JSONDeserialize(AnyView json_str, FuncObj *fn_opaque_deserialize); Str JSONSerialize(AnyView source, FuncObj *fn_opaque_serialize); bool StructuralEqual(AnyView lhs, AnyView rhs, bool bind_free_vars, bool assert_mode); @@ -646,7 +646,7 @@ inline TypeTable *TypeTable::New() { self->SetFunc("mlc.base.DeviceTypeRegister", Func([self](const char *name) { return self->DeviceTypeRegister(name); }).get()); self->SetFunc("mlc.core.Stringify", Func(::mlc::core::StringifyWithFields).get()); - self->SetFunc("mlc.core.JSONLoads", Func(::mlc::registry::JSONLoads).get()); + self->SetFunc("mlc.core.JSONParse", Func(::mlc::registry::JSONParse).get()); self->SetFunc("mlc.core.JSONSerialize", Func(::mlc::registry::JSONSerialize).get()); self->SetFunc("mlc.core.JSONDeserialize", Func(::mlc::registry::JSONDeserialize).get()); self->SetFunc("mlc.core.StructuralEqual", Func(::mlc::registry::StructuralEqual).get()); diff --git a/cpp/structure.cc b/cpp/structure.cc index 7d2cde70..4c89b412 100644 --- a/cpp/structure.cc +++ b/cpp/structure.cc @@ -23,7 +23,7 @@ using mlc::core::VisitStructure; /****************** JSON ******************/ -inline Any JSONLoads(const char *json_str, int64_t json_str_len) { +inline Any JSONParse(const char *json_str, int64_t json_str_len) { struct JSONParser { Any Parse() { SkipWhitespace(); @@ -1552,7 +1552,7 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len, FuncObj *fn_o int32_t json_type_index_tensor = -1; int32_t json_type_index_opaque = -1; // Step 0. Parse JSON string - UDict json_obj = JSONLoads(json_str, json_str_len); + UDict json_obj = JSONParse(json_str, json_str_len); // Step 1. type_key => constructors UList type_keys = json_obj->at("type_keys"); std::vector constructors; @@ -1700,12 +1700,12 @@ Any CopyShallow(AnyView source) { return CopyShallowImpl(source); } Any CopyDeep(AnyView source) { return CopyDeepImpl(source); } void CopyReplace(int32_t num_args, const AnyView *args, Any *ret) { CopyReplaceImpl(num_args, args, ret); } -Any JSONLoads(AnyView json_str) { +Any JSONParse(AnyView json_str) { if (json_str.type_index == kMLCRawStr) { - return ::mlc::JSONLoads(json_str.operator const char *(), -1); + return ::mlc::JSONParse(json_str.operator const char *(), -1); } else { StrObj *js = json_str.operator StrObj *(); - return ::mlc::JSONLoads(js->data(), js->size()); + return ::mlc::JSONParse(js->data(), js->size()); } } diff --git a/python/mlc/__init__.py b/python/mlc/__init__.py index 41016c5f..c4c0e2e6 100644 --- a/python/mlc/__init__.py +++ b/python/mlc/__init__.py @@ -13,11 +13,17 @@ Tensor, build_info, dep_graph, + eq_ptr, + eq_s, + eq_s_fail_reason, + hash_s, + json_dumps, json_loads, + json_parse, typing, ) from .core.dep_graph import DepGraph, DepNode -from .dataclasses import PyClass, c_class, py_class +from .dataclasses import c_class, py_class try: from ._version import __version__, __version_tuple__ # type: ignore[import-not-found] diff --git a/python/mlc/core/__init__.py b/python/mlc/core/__init__.py index bc249abd..3f9fab73 100644 --- a/python/mlc/core/__init__.py +++ b/python/mlc/core/__init__.py @@ -3,9 +3,9 @@ from .dict import Dict from .dtype import DataType from .error import Error -from .func import Func, build_info, json_loads +from .func import Func, build_info, json_parse from .list import List -from .object import Object +from .object import Object, eq_ptr, eq_s, eq_s_fail_reason, hash_s, json_dumps, json_loads from .object_path import ObjectPath from .opaque import Opaque from .tensor import Tensor diff --git a/python/mlc/core/func.py b/python/mlc/core/func.py index bc58d2b8..700ca5a6 100644 --- a/python/mlc/core/func.py +++ b/python/mlc/core/func.py @@ -51,13 +51,13 @@ def decorator(func: _CallableType) -> _CallableType: return decorator -def json_loads(s: str) -> Any: - return _json_loads(s) +def json_parse(s: str) -> Any: + return _json_parse(s) def build_info() -> dict[str, Any]: return _build_info() -_json_loads = Func.get("mlc.core.JSONLoads") +_json_parse = Func.get("mlc.core.JSONParse") _build_info = Func.get("mlc.core.BuildInfo") diff --git a/python/mlc/core/object.py b/python/mlc/core/object.py index 28176e84..90731656 100644 --- a/python/mlc/core/object.py +++ b/python/mlc/core/object.py @@ -5,6 +5,11 @@ from mlc._cython import PyAny, TypeInfo, c_class_core +try: + from warnings import deprecated # type: ignore[attr-defined] +except ImportError: + from typing_extensions import deprecated + @c_class_core("object.Object") class Object(PyAny): @@ -21,42 +26,6 @@ def id_(self) -> int: def is_(self, other: Object) -> bool: return isinstance(other, Object) and self._mlc_address == other._mlc_address - def json( - self, - fn_opaque_serialize: Callable[[list[typing.Any]], str] | None = None, - ) -> str: - return super()._mlc_json(fn_opaque_serialize) - - @staticmethod - def from_json( - json_str: str, - fn_opaque_deserialize: Callable[[str], list[typing.Any]] | None = None, - ) -> Object: - return PyAny._mlc_from_json(json_str, fn_opaque_deserialize) # type: ignore[attr-defined] - - def eq_s( - self, - other: Object, - *, - bind_free_vars: bool = True, - assert_mode: bool = False, - ) -> bool: - return PyAny._mlc_eq_s(self, other, bind_free_vars, assert_mode) # type: ignore[attr-defined] - - def eq_s_fail_reason( - self, - other: Object, - *, - bind_free_vars: bool = True, - ) -> tuple[bool, str]: - return PyAny._mlc_eq_s_fail_reason(self, other, bind_free_vars) - - def hash_s(self) -> int: - return PyAny._mlc_hash_s(self) # type: ignore[attr-defined] - - def eq_ptr(self, other: typing.Any) -> bool: - return isinstance(other, Object) and self._mlc_address == other._mlc_address - def __copy__(self: Object) -> Object: return PyAny._mlc_copy_shallow(self) # type: ignore[attr-defined] @@ -74,7 +43,7 @@ def __hash__(self) -> int: return hash((type(self), self._mlc_address)) def __eq__(self, other: typing.Any) -> bool: - return self.eq_ptr(other) + return eq_ptr(self, other) def __ne__(self, other: typing.Any) -> bool: return not self == other @@ -103,3 +72,112 @@ def swap(self, other: typing.Any) -> None: self._mlc_swap(other) else: raise TypeError(f"Cannot different types: `{type(self)}` and `{type(other)}`") + + @deprecated( + "Method `.json` is deprecated. Use `mlc.json_dumps` instead.", + stacklevel=2, + ) + def json( + self, + fn_opaque_serialize: Callable[[list[typing.Any]], str] | None = None, + ) -> str: + return json_dumps(self, fn_opaque_serialize) + + @staticmethod + @deprecated( + "Method `.from_json` is deprecated. Use `mlc.json_loads` instead.", + stacklevel=2, + ) + def from_json( + json_str: str, + fn_opaque_deserialize: Callable[[str], list[typing.Any]] | None = None, + ) -> Object: + return json_loads(json_str, fn_opaque_deserialize) + + @deprecated( + "Method `.eq_s` is deprecated. Use `mlc.eq_s` instead.", + stacklevel=2, + ) + def eq_s( + self, + other: Object, + *, + bind_free_vars: bool = True, + assert_mode: bool = False, + ) -> bool: + return eq_s(self, other, bind_free_vars=bind_free_vars, assert_mode=assert_mode) + + @deprecated( + "Method `.eq_s_fail_reason` is deprecated. Use `mlc.eq_s_fail_reason` instead.", + stacklevel=2, + ) + def eq_s_fail_reason( + self, + other: Object, + *, + bind_free_vars: bool = True, + ) -> tuple[bool, str]: + return eq_s_fail_reason(self, other, bind_free_vars=bind_free_vars) + + @deprecated( + "Method `.hash_s` is deprecated. Use `mlc.hash_s` instead.", + stacklevel=2, + ) + def hash_s(self) -> int: + return hash_s(self) + + @deprecated( + "Method `.eq_ptr` is deprecated. Use `mlc.eq_ptr` instead.", + stacklevel=2, + ) + def eq_ptr(self, other: typing.Any) -> bool: + return eq_ptr(self, other) + + +def json_dumps( + object: typing.Any, + fn_opaque_serialize: Callable[[list[typing.Any]], str] | None = None, +) -> str: + assert isinstance(object, Object), f"Expected `mlc.Object`, got `{type(object)}`" + return object._mlc_json(fn_opaque_serialize) # type: ignore[attr-defined] + + +def json_loads( + json_str: str, + fn_opaque_deserialize: Callable[[str], list[typing.Any]] | None = None, +) -> Object: + return PyAny._mlc_from_json(json_str, fn_opaque_deserialize) # type: ignore[attr-defined] + + +def eq_s( + lhs: typing.Any, + rhs: typing.Any, + *, + bind_free_vars: bool = True, + assert_mode: bool = False, +) -> bool: + assert isinstance(lhs, Object), f"Expected `mlc.Object`, got `{type(lhs)}`" + assert isinstance(rhs, Object), f"Expected `mlc.Object`, got `{type(rhs)}`" + return PyAny._mlc_eq_s(lhs, rhs, bind_free_vars, assert_mode) # type: ignore[attr-defined] + + +def eq_s_fail_reason( + lhs: typing.Any, + rhs: typing.Any, + *, + bind_free_vars: bool = True, +) -> tuple[bool, str]: + assert isinstance(lhs, Object), f"Expected `mlc.Object`, got `{type(lhs)}`" + assert isinstance(rhs, Object), f"Expected `mlc.Object`, got `{type(rhs)}`" + return PyAny._mlc_eq_s_fail_reason(lhs, rhs, bind_free_vars) + + +def hash_s(obj: typing.Any) -> int: + assert isinstance(obj, Object), f"Expected `mlc.Object`, got `{type(obj)}`" + return PyAny._mlc_hash_s(obj) # type: ignore[attr-defined] + + +def eq_ptr(lhs: typing.Any, rhs: typing.Any) -> bool: + assert isinstance(lhs, Object), f"Expected `mlc.Object`, got `{type(lhs)}`" + assert isinstance(rhs, Object), f"Expected `mlc.Object`, got `{type(rhs)}`" + return lhs._mlc_address == rhs._mlc_address diff --git a/python/mlc/dataclasses/__init__.py b/python/mlc/dataclasses/__init__.py index 0bd81e32..54e8e7ad 100644 --- a/python/mlc/dataclasses/__init__.py +++ b/python/mlc/dataclasses/__init__.py @@ -1,5 +1,7 @@ +from mlc.core.object import Object as PyClass # for backward compatibility + from .c_class import c_class -from .py_class import PyClass, py_class +from .py_class import py_class from .utils import ( Structure, add_vtable_method, diff --git a/python/mlc/dataclasses/py_class.py b/python/mlc/dataclasses/py_class.py index 4c0e5a41..1d0c3c87 100644 --- a/python/mlc/dataclasses/py_class.py +++ b/python/mlc/dataclasses/py_class.py @@ -1,5 +1,10 @@ from __future__ import annotations +try: + from typing import dataclass_transform +except ImportError: + from typing_extensions import dataclass_transform + import ctypes import functools import typing @@ -21,6 +26,7 @@ ) from mlc.core import Object +from .utils import Field as _Field from .utils import ( Structure, add_vtable_methods_for_type_cls, @@ -30,28 +36,20 @@ structure_parse, structure_to_c, ) +from .utils import field as _field -ClsType = typing.TypeVar("ClsType") - - -class PyClass(Object): - _mlc_type_info = Object._mlc_type_info +InputClsType = typing.TypeVar("InputClsType") - def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: - raise NotImplementedError - def __str__(self) -> str: - return self.__repr__() - - -def py_class( # noqa: PLR0915 +@dataclass_transform(field_specifiers=(_field, _Field)) +def py_class( type_key: str | type | None = None, *, init: bool = True, repr: bool = True, frozen: bool = False, structure: typing.Literal["bind", "nobind", "var"] | None = None, -) -> Callable[[type[ClsType]], type[ClsType]]: +) -> Callable[[type[InputClsType]], type[InputClsType]]: if isinstance(type_key, type): return py_class( type_key=None, @@ -65,18 +63,12 @@ def py_class( # noqa: PLR0915 f"but got: {structure}" ) - def decorator(super_type_cls: type[ClsType]) -> type[ClsType]: + def decorator(super_type_cls: type[InputClsType]) -> type[InputClsType]: nonlocal type_key if type_key is None: type_key = f"{super_type_cls.__module__}.{super_type_cls.__qualname__}" assert isinstance(type_key, str) - if not issubclass(super_type_cls, PyClass): - raise TypeError( - "Not a subclass of `mlc.PyClass`: " - f"`{super_type_cls.__module__}.{super_type_cls.__qualname__}`" - ) - # Step 1. Create the type according to its parent type parent_type_info: TypeInfo = get_parent_type(super_type_cls)._mlc_type_info # type: ignore[attr-defined] type_info: TypeInfo = type_create(parent_type_info.type_index, type_key) @@ -96,15 +88,11 @@ def decorator(super_type_cls: type[ClsType]) -> type[ClsType]: mlc_init = make_mlc_init(fields) # Step 3. Create the proxy class with the fields as properties - @functools.wraps(super_type_cls, updated=()) - class type_cls(super_type_cls): # type: ignore[valid-type,misc] - __slots__ = () - - def _mlc_init(self, *args: typing.Any) -> None: - mlc_init(self, *args) - - def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> type[ClsType]: - return type_create_instance(cls, type_index, num_bytes) + type_cls: type[InputClsType] = _create_cls( + cls=super_type_cls, + mlc_init=mlc_init, + mlc_new=lambda cls, *args, **kwargs: type_create_instance(cls, type_index, num_bytes), + ) type_info.type_cls = type_cls setattr(type_cls, "_mlc_type_info", type_info) @@ -183,22 +171,49 @@ class CType(ctypes.Structure): def _method_repr( type_key: str, fields: list[TypeField], -) -> Callable[[ClsType], str]: +) -> Callable[[InputClsType], str]: field_names = tuple(field.name for field in fields) - def method(self: ClsType) -> str: + def method(self: InputClsType) -> str: fields = (f"{name}={getattr(self, name)!r}" for name in field_names) return f"{type_key}({', '.join(fields)})" return method -def _method_new( - type_cls: type[ClsType], -) -> Callable[..., ClsType]: - def method(*args: typing.Any) -> ClsType: +def _method_new(type_cls: type[InputClsType]) -> Callable[..., InputClsType]: + def method(*args: typing.Any) -> InputClsType: obj = type_cls.__new__(type_cls) obj._mlc_init(*args) # type: ignore[attr-defined] return obj return method + + +def _create_cls( + cls: type, + mlc_init: Callable[..., None], + mlc_new: Callable[..., None], +) -> type[InputClsType]: + cls_name = cls.__name__ + cls_bases = cls.__bases__ + attrs = dict(cls.__dict__) + if cls_bases == (object,): + cls_bases = (Object,) + + def _add_method(fn: Callable, fn_name: str) -> None: + attrs[fn_name] = fn + fn.__module__ = cls.__module__ + fn.__name__ = fn_name + fn.__qualname__ = f"{cls_name}.{fn_name}" + + attrs["__slots__"] = () + attrs.pop("__dict__", None) + attrs.pop("__weakref__", None) + _add_method(mlc_init, "_mlc_init") + _add_method(mlc_new, "__new__") + + new_cls = type(cls_name, cls_bases, attrs) + new_cls.__module__ = cls.__module__ + new_cls = functools.wraps(cls, updated=())(new_cls) # type: ignore + return new_cls diff --git a/python/mlc/dataclasses/utils.py b/python/mlc/dataclasses/utils.py index 10c21fb5..ceda3aa6 100644 --- a/python/mlc/dataclasses/utils.py +++ b/python/mlc/dataclasses/utils.py @@ -20,6 +20,7 @@ type_index2type_methods, type_table, ) +from mlc.core import Object from mlc.core import typing as mlc_typing KIND_MAP = {None: 0, "nobind": 1, "bind": 2, "var": 3} @@ -284,10 +285,7 @@ def get_parent_type(type_cls: type) -> type: for base in type_cls.__bases__: if hasattr(base, "_mlc_type_info"): return base - raise ValueError( - f"No parent type found for `{type_cls.__module__}.{type_cls.__qualname__}`. " - f"The type must inherit from `mlc.Object`." - ) + return Object def add_vtable_method( diff --git a/python/mlc/printer/ast.py b/python/mlc/printer/ast.py index 821bad11..aaab8e3a 100644 --- a/python/mlc/printer/ast.py +++ b/python/mlc/printer/ast.py @@ -18,7 +18,7 @@ class PrinterConfig(Object): @mlcd.c_class("mlc.printer.ast.Node") -class Node(mlcd.PyClass): +class Node(Object): source_paths: list[ObjectPath] = mlcd.field(default_factory=list) def to_python(self, config: Optional[PrinterConfig] = None) -> str: diff --git a/python/mlc/printer/ir_printer.py b/python/mlc/printer/ir_printer.py index 44804b92..81927611 100644 --- a/python/mlc/printer/ir_printer.py +++ b/python/mlc/printer/ir_printer.py @@ -1,6 +1,6 @@ import contextlib from collections.abc import Generator -from typing import Any, Optional, TypeVar, Union +from typing import Any, Optional, TypeVar import mlc.dataclasses as mlcd from mlc.core import Func, Object, ObjectPath @@ -102,7 +102,7 @@ def frame_push(self, frame: Any) -> None: def frame_pop(self) -> None: IRPrinter._C(b"frame_pop", self) - def __call__(self, obj: Union[Node, int, str, bool, float, None], path: ObjectPath) -> Node: + def __call__(self, obj: Any, path: ObjectPath) -> Node: return IRPrinter._C(b"__call__", self, obj, path) @contextlib.contextmanager diff --git a/python/mlc/sym/expr.py b/python/mlc/sym/expr.py index 9896608b..91548276 100644 --- a/python/mlc/sym/expr.py +++ b/python/mlc/sym/expr.py @@ -337,7 +337,7 @@ class Call(Expr): @mlcd.c_class("mlc.sym.Op") -class Op(mlcd.PyClass): +class Op(Object): name: str @staticmethod @@ -346,7 +346,7 @@ def get(name: str) -> Op: @mlcd.c_class("mlc.sym.Range", init=False) -class Range(mlcd.PyClass): +class Range(Object): min: Expr extent: Expr diff --git a/python/mlc/testing/dataclasses.py b/python/mlc/testing/dataclasses.py index 15e24fb7..d7e53284 100644 --- a/python/mlc/testing/dataclasses.py +++ b/python/mlc/testing/dataclasses.py @@ -52,7 +52,7 @@ def i64_plus_one(self) -> int: @mlc.py_class("mlc.testing.py_class") -class PyClassForTest(mlc.PyClass): +class PyClassForTest: bool_: bool i8: int # `py_class` doesn't support `int8`, it will effectively be `int64_t` i16: int # `py_class` doesn't support `int16`, it will effectively be `int64_t` @@ -104,11 +104,11 @@ def visit_fields(obj: mlc.Object) -> list[tuple[str, str, Any]]: return list(zip(types, names, values)) -def field_get(obj: mlc.Object, name: str) -> Any: +def field_get(obj: Any, name: str) -> Any: return _C_FieldGet(obj, name) -def field_set(obj: mlc.Object, name: str, value: Any) -> None: +def field_set(obj: Any, name: str, value: Any) -> None: _C_FieldSet(obj, name, value) diff --git a/python/mlc/testing/toy_ir/ir.py b/python/mlc/testing/toy_ir/ir.py index e02a1c78..1ebd364d 100644 --- a/python/mlc/testing/toy_ir/ir.py +++ b/python/mlc/testing/toy_ir/ir.py @@ -6,7 +6,7 @@ @mlcd.py_class -class Node(mlcd.PyClass): ... +class Node: ... @mlcd.py_class diff --git a/tests/python/test_core_dep_graph.py b/tests/python/test_core_dep_graph.py index 67b31ac5..4e599c0d 100644 --- a/tests/python/test_core_dep_graph.py +++ b/tests/python/test_core_dep_graph.py @@ -4,7 +4,7 @@ @mlcd.py_class(repr=False) -class Var(mlcd.PyClass): +class Var: name: str def __str__(self) -> str: @@ -12,7 +12,7 @@ def __str__(self) -> str: @mlcd.py_class(repr=False) -class Stmt(mlcd.PyClass): +class Stmt: args: list[Var] outs: list[Var] diff --git a/tests/python/test_core_dict.py b/tests/python/test_core_dict.py index 276bc1d6..a3dda549 100644 --- a/tests/python/test_core_dict.py +++ b/tests/python/test_core_dict.py @@ -1,5 +1,6 @@ from collections.abc import Callable +import mlc import pytest from mlc import Dict @@ -88,10 +89,10 @@ def test_dict_eq() -> None: assert b == a assert a == Dict(b) assert Dict(b) == a - assert not a.eq_ptr(Dict(b)) - assert not Dict(b).eq_ptr(a) + assert not mlc.eq_ptr(a, Dict(b)) + assert not mlc.eq_ptr(Dict(b), a) assert a == a # noqa: PLR0124 - assert a.eq_ptr(a) + assert mlc.eq_ptr(a, a) def test_dict_ne_0() -> None: diff --git a/tests/python/test_core_json.py b/tests/python/test_core_json.py index 35abdeaa..2a7a3aa5 100644 --- a/tests/python/test_core_json.py +++ b/tests/python/test_core_json.py @@ -3,16 +3,16 @@ import mlc -def test_json_loads_bool() -> None: +def test_json_parse_bool() -> None: src = json.dumps([True, False]) - result = mlc.json_loads(src) + result = mlc.json_parse(src) assert isinstance(result, mlc.List) and len(result) == 2 assert result[0] == True assert result[1] == False -def test_json_loads_none() -> None: +def test_json_parse_none() -> None: src = json.dumps([None]) - result = mlc.json_loads(src) + result = mlc.json_parse(src) assert isinstance(result, mlc.List) and len(result) == 1 assert result[0] is None diff --git a/tests/python/test_core_list.py b/tests/python/test_core_list.py index fdf8db4f..220c8414 100644 --- a/tests/python/test_core_list.py +++ b/tests/python/test_core_list.py @@ -1,6 +1,7 @@ from collections.abc import Callable, Sequence from typing import Any +import mlc import pytest from mlc import DataType, Device, List @@ -146,9 +147,9 @@ def test_list_eq() -> None: assert tuple(a) == a assert a == b assert b == a - assert not a.eq_ptr(b) - assert not b.eq_ptr(a) - assert a.eq_ptr(a) + assert not mlc.eq_ptr(a, b) + assert not mlc.eq_ptr(b, a) + assert mlc.eq_ptr(a, a) assert a == a # noqa: PLR0124 diff --git a/tests/python/test_core_opaque.py b/tests/python/test_core_opaque.py index 08a52acf..1e1d9799 100644 --- a/tests/python/test_core_opaque.py +++ b/tests/python/test_core_opaque.py @@ -30,7 +30,7 @@ def __hash__(self) -> int: @mlc.dataclasses.py_class(structure="bind") -class Wrapper(mlc.dataclasses.PyClass): +class Wrapper: field: Any = mlc.dataclasses.field(structure="nobind") @@ -78,30 +78,30 @@ def test_opaque_dataclass() -> None: def test_opaque_dataclass_eq_s() -> None: a1 = Wrapper(field=MyType(a=10)) a2 = Wrapper(field=MyType(a=10)) - a1.eq_s(a2, assert_mode=True) + mlc.eq_s(a1, a2, assert_mode=True) def test_opaque_dataclass_eq_s_fail() -> None: a1 = Wrapper(field=MyType(a=10)) a2 = Wrapper(field=MyType(a=20)) with pytest.raises(ValueError) as exc_info: - a1.eq_s(a2, assert_mode=True) + mlc.eq_s(a1, a2, assert_mode=True) assert str(exc_info.value).startswith("Structural equality check failed at {root}.field") def test_opaque_dataclass_hash_s() -> None: a1 = Wrapper(field=MyType(a=10)) - assert isinstance(a1.hash_s(), int) + assert isinstance(mlc.hash_s(a1), int) def test_opaque_serialize() -> None: obj_1 = Wrapper(field=MyType(a=10)) - json_str = obj_1.json() + json_str = mlc.json_dumps(obj_1) js = json.loads(json_str) assert js["opaques"] == '[{"py/object": "test_core_opaque.MyType", "a": 10}]' assert js["values"] == [[0, 0], [1, 0]] assert js["type_keys"] == ["mlc.core.Opaque", "test_core_opaque.Wrapper"] - obj_2 = Wrapper.from_json(json_str) + obj_2 = mlc.json_loads(json_str) assert isinstance(obj_2.field, MyType) assert obj_2.field.a == 10 @@ -111,7 +111,7 @@ def test_opaque_serialize_with_alias() -> None: a2 = MyType(a=20) a3 = MyType(a=30) obj_1 = Wrapper(field=[a1, a2, a3, a3, a2, a1]) - obj_2 = Wrapper.from_json(obj_1.json()) + obj_2 = mlc.json_loads(mlc.json_dumps(obj_1)) assert obj_2.field[0] is obj_2.field[5] assert obj_2.field[1] is obj_2.field[4] assert obj_2.field[2] is obj_2.field[3] diff --git a/tests/python/test_core_tensor.py b/tests/python/test_core_tensor.py index 5de0d05e..c77b30f9 100644 --- a/tests/python/test_core_tensor.py +++ b/tests/python/test_core_tensor.py @@ -120,11 +120,11 @@ def test_torch_strides() -> None: def test_tensor_serialize() -> None: a = mlc.Tensor(np.arange(24, dtype=np.int16).reshape(2, 3, 4)) - a_json = mlc.List([a, a]).json() - b = mlc.List.from_json(a_json) + a_json = mlc.json_dumps(mlc.List([a, a])) + b = mlc.json_loads(a_json) assert isinstance(b, mlc.List) assert len(b) == 2 assert isinstance(b[0], mlc.Tensor) assert isinstance(b[1], mlc.Tensor) - assert b[0].eq_ptr(b[1]) + assert mlc.eq_ptr(b[0], b[1]) assert np.array_equal(a.numpy(), b[0].numpy()) diff --git a/tests/python/test_dataclasses_copy.py b/tests/python/test_dataclasses_copy.py index 5fb23c36..d84dd6aa 100644 --- a/tests/python/test_dataclasses_copy.py +++ b/tests/python/test_dataclasses_copy.py @@ -6,7 +6,7 @@ @mlc.py_class(init=False) -class CustomInit(mlc.PyClass): +class CustomInit: a: int b: str @@ -31,10 +31,10 @@ def mlc_class_for_test() -> PyClassForTest: f32=2, f64=2.5, raw_ptr=mlc.Ptr(0xDEADBEEF), - dtype="float8", - device="cuda:0", + dtype="float8", # type: ignore[arg-type] + device="cuda:0", # type: ignore[arg-type] any="hello", - func=lambda x: x + 1, + func=lambda x: x + 1, # type: ignore[arg-type] ulist=[1, 2.0, "three", lambda: 4], udict={"1": 1, "2": 2.0, "3": "three", "4": lambda: 4}, str_="world", @@ -51,9 +51,9 @@ def mlc_class_for_test() -> PyClassForTest: opt_i64=-64, opt_f64=-2.5, opt_raw_ptr=mlc.Ptr(0xBEEFDEAD), - opt_dtype="float16", - opt_device="cuda:0", - opt_func=lambda x: x - 1, + opt_dtype="float16", # type: ignore[arg-type] + opt_device="cuda:0", # type: ignore[arg-type] + opt_func=lambda x: x - 1, # type: ignore[arg-type] opt_ulist=[1, 2.0, "three", lambda: 4], opt_udict={"1": 1, "2": 2.0, "3": "three", "4": lambda: 4}, opt_str="world", @@ -83,15 +83,15 @@ def test_copy_shallow(mlc_class_for_test: PyClassForTest) -> None: assert src.device == dst.device assert src.any == dst.any assert src.func(1) == dst.func(1) - assert src.ulist.eq_ptr(dst.ulist) # type: ignore - assert src.udict.eq_ptr(dst.udict) # type: ignore + assert mlc.eq_ptr(src.ulist, dst.ulist) # type: ignore + assert mlc.eq_ptr(src.udict, dst.udict) # type: ignore assert src.str_ == dst.str_ - assert src.list_any.eq_ptr(dst.list_any) # type: ignore - assert src.list_list_int.eq_ptr(dst.list_list_int) # type: ignore - assert src.dict_any_any.eq_ptr(dst.dict_any_any) # type: ignore - assert src.dict_str_any.eq_ptr(dst.dict_str_any) # type: ignore - assert src.dict_any_str.eq_ptr(dst.dict_any_str) # type: ignore - assert src.dict_str_list_int.eq_ptr(dst.dict_str_list_int) # type: ignore + assert mlc.eq_ptr(src.list_any, dst.list_any) # type: ignore + assert mlc.eq_ptr(src.list_list_int, dst.list_list_int) # type: ignore + assert mlc.eq_ptr(src.dict_any_any, dst.dict_any_any) # type: ignore + assert mlc.eq_ptr(src.dict_str_any, dst.dict_str_any) # type: ignore + assert mlc.eq_ptr(src.dict_any_str, dst.dict_any_str) # type: ignore + assert mlc.eq_ptr(src.dict_str_list_int, dst.dict_str_list_int) # type: ignore assert src.opt_bool == dst.opt_bool assert src.opt_i64 == dst.opt_i64 assert src.opt_f64 == dst.opt_f64 @@ -99,15 +99,15 @@ def test_copy_shallow(mlc_class_for_test: PyClassForTest) -> None: assert src.opt_dtype == dst.opt_dtype assert src.opt_device == dst.opt_device assert src.opt_func(2) == dst.opt_func(2) # type: ignore[misc] - assert src.opt_ulist.eq_ptr(dst.opt_ulist) # type: ignore - assert src.opt_udict.eq_ptr(dst.opt_udict) # type: ignore + assert mlc.eq_ptr(src.opt_ulist, dst.opt_ulist) # type: ignore + assert mlc.eq_ptr(src.opt_udict, dst.opt_udict) # type: ignore assert src.opt_str == dst.opt_str - assert src.opt_list_any.eq_ptr(dst.opt_list_any) # type: ignore - assert src.opt_list_list_int.eq_ptr(dst.opt_list_list_int) # type: ignore - assert src.opt_dict_any_any.eq_ptr(dst.opt_dict_any_any) # type: ignore - assert src.opt_dict_str_any.eq_ptr(dst.opt_dict_str_any) # type: ignore - assert src.opt_dict_any_str.eq_ptr(dst.opt_dict_any_str) # type: ignore - assert src.opt_dict_str_list_int.eq_ptr(dst.opt_dict_str_list_int) # type: ignore + assert mlc.eq_ptr(src.opt_list_any, dst.opt_list_any) # type: ignore + assert mlc.eq_ptr(src.opt_list_list_int, dst.opt_list_list_int) # type: ignore + assert mlc.eq_ptr(src.opt_dict_any_any, dst.opt_dict_any_any) # type: ignore + assert mlc.eq_ptr(src.opt_dict_str_any, dst.opt_dict_str_any) # type: ignore + assert mlc.eq_ptr(src.opt_dict_any_str, dst.opt_dict_any_str) # type: ignore + assert mlc.eq_ptr(src.opt_dict_str_list_int, dst.opt_dict_str_list_int) # type: ignore def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None: @@ -127,7 +127,7 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None: assert src.any == dst.any assert src.func(1) == dst.func(1) assert ( - not src.ulist.eq_ptr(dst.ulist) # type: ignore + not mlc.eq_ptr(src.ulist, dst.ulist) # type: ignore and len(src.ulist) == len(dst.ulist) and src.ulist[0] == dst.ulist[0] and src.ulist[1] == dst.ulist[1] @@ -135,7 +135,7 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None: and src.ulist[3]() == dst.ulist[3]() ) assert ( - not src.udict.eq_ptr(dst.udict) # type: ignore + not mlc.eq_ptr(src.udict, dst.udict) # type: ignore and len(src.udict) == len(dst.udict) and src.udict["1"] == dst.udict["1"] and src.udict["2"] == dst.udict["2"] @@ -144,7 +144,7 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None: ) assert src.str_ == dst.str_ assert ( - not src.list_any.eq_ptr(dst.list_any) # type: ignore + not mlc.eq_ptr(src.list_any, dst.list_any) # type: ignore and len(src.list_any) == len(dst.list_any) and src.list_any[0] == dst.list_any[0] and src.list_any[1] == dst.list_any[1] @@ -152,13 +152,13 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None: and src.list_any[3]() == dst.list_any[3]() ) assert ( - not src.list_list_int.eq_ptr(dst.list_list_int) # type: ignore + not mlc.eq_ptr(src.list_list_int, dst.list_list_int) # type: ignore and len(src.list_list_int) == len(dst.list_list_int) and tuple(src.list_list_int[0]) == tuple(dst.list_list_int[0]) and tuple(src.list_list_int[1]) == tuple(dst.list_list_int[1]) ) assert ( - not src.dict_any_any.eq_ptr(dst.dict_any_any) # type: ignore + not mlc.eq_ptr(src.dict_any_any, dst.dict_any_any) # type: ignore and len(src.dict_any_any) == len(dst.dict_any_any) and src.dict_any_any[1] == dst.dict_any_any[1] and src.dict_any_any[2.0] == dst.dict_any_any[2.0] @@ -166,7 +166,7 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None: and src.dict_any_any[4]() == dst.dict_any_any[4]() ) assert ( - not src.dict_str_any.eq_ptr(dst.dict_str_any) # type: ignore + not mlc.eq_ptr(src.dict_str_any, dst.dict_str_any) # type: ignore and len(src.dict_str_any) == len(dst.dict_str_any) and src.dict_str_any["1"] == dst.dict_str_any["1"] and src.dict_str_any["2.0"] == dst.dict_str_any["2.0"] @@ -174,7 +174,7 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None: and src.dict_str_any["4"]() == dst.dict_str_any["4"]() ) assert ( - not src.dict_any_str.eq_ptr(dst.dict_any_str) # type: ignore + not mlc.eq_ptr(src.dict_any_str, dst.dict_any_str) # type: ignore and len(src.dict_any_str) == len(dst.dict_any_str) and src.dict_any_str[1] == dst.dict_any_str[1] and src.dict_any_str[2.0] == dst.dict_any_str[2.0] @@ -182,7 +182,7 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None: and src.dict_any_str[4] == dst.dict_any_str[4] ) assert ( - not src.dict_str_list_int.eq_ptr(dst.dict_str_list_int) # type: ignore + not mlc.eq_ptr(src.dict_str_list_int, dst.dict_str_list_int) # type: ignore and len(src.dict_str_list_int) == len(dst.dict_str_list_int) and tuple(src.dict_str_list_int["1"]) == tuple(dst.dict_str_list_int["1"]) and tuple(src.dict_str_list_int["2"]) == tuple(dst.dict_str_list_int["2"]) @@ -194,7 +194,7 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None: assert src.opt_device == dst.opt_device assert src.opt_func(2) == dst.opt_func(2) # type: ignore[misc] assert ( - not src.opt_ulist.eq_ptr(dst.opt_ulist) # type: ignore + not mlc.eq_ptr(src.opt_ulist, dst.opt_ulist) # type: ignore and len(src.opt_ulist) == len(dst.opt_ulist) # type: ignore[arg-type] and src.opt_ulist[0] == dst.opt_ulist[0] # type: ignore[index] and src.opt_ulist[1] == dst.opt_ulist[1] # type: ignore[index] @@ -202,7 +202,7 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None: and src.opt_ulist[3]() == dst.opt_ulist[3]() # type: ignore[index] ) assert ( - not src.opt_udict.eq_ptr(dst.opt_udict) # type: ignore + not mlc.eq_ptr(src.opt_udict, dst.opt_udict) # type: ignore and len(src.opt_udict) == len(dst.opt_udict) # type: ignore[arg-type] and src.opt_udict["1"] == dst.opt_udict["1"] # type: ignore[index] and src.opt_udict["2"] == dst.opt_udict["2"] # type: ignore[index] @@ -211,7 +211,7 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None: ) assert src.opt_str == dst.opt_str assert ( - not src.opt_list_any.eq_ptr(dst.opt_list_any) # type: ignore + not mlc.eq_ptr(src.opt_list_any, dst.opt_list_any) # type: ignore and len(src.opt_list_any) == len(dst.opt_list_any) # type: ignore[arg-type] and src.opt_list_any[0] == dst.opt_list_any[0] # type: ignore[index] and src.opt_list_any[1] == dst.opt_list_any[1] # type: ignore[index] @@ -219,13 +219,13 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None: and src.opt_list_any[3]() == dst.opt_list_any[3]() # type: ignore[index] ) assert ( - not src.opt_list_list_int.eq_ptr(dst.opt_list_list_int) # type: ignore + not mlc.eq_ptr(src.opt_list_list_int, dst.opt_list_list_int) # type: ignore and len(src.opt_list_list_int) == len(dst.opt_list_list_int) # type: ignore[arg-type] and tuple(src.opt_list_list_int[0]) == tuple(dst.opt_list_list_int[0]) # type: ignore[index] and tuple(src.opt_list_list_int[1]) == tuple(dst.opt_list_list_int[1]) # type: ignore[index] ) assert ( - not src.opt_dict_any_any.eq_ptr(dst.opt_dict_any_any) # type: ignore + not mlc.eq_ptr(src.opt_dict_any_any, dst.opt_dict_any_any) # type: ignore and len(src.opt_dict_any_any) == len(dst.opt_dict_any_any) # type: ignore[arg-type] and src.opt_dict_any_any[1] == dst.opt_dict_any_any[1] # type: ignore[index] and src.opt_dict_any_any[2.0] == dst.opt_dict_any_any[2.0] # type: ignore[index] @@ -233,7 +233,7 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None: and src.opt_dict_any_any[4]() == dst.opt_dict_any_any[4]() # type: ignore[index] ) assert ( - not src.opt_dict_str_any.eq_ptr(dst.opt_dict_str_any) # type: ignore + not mlc.eq_ptr(src.opt_dict_str_any, dst.opt_dict_str_any) # type: ignore and len(src.opt_dict_str_any) == len(dst.opt_dict_str_any) # type: ignore[arg-type] and src.opt_dict_str_any["1"] == dst.opt_dict_str_any["1"] # type: ignore[index] and src.opt_dict_str_any["2.0"] == dst.opt_dict_str_any["2.0"] # type: ignore[index] @@ -241,7 +241,7 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None: and src.opt_dict_str_any["4"]() == dst.opt_dict_str_any["4"]() # type: ignore[index] ) assert ( - not src.opt_dict_any_str.eq_ptr(dst.opt_dict_any_str) # type: ignore + not mlc.eq_ptr(src.opt_dict_any_str, dst.opt_dict_any_str) # type: ignore and len(src.opt_dict_any_str) == len(dst.opt_dict_any_str) # type: ignore[arg-type] and src.opt_dict_any_str[1] == dst.opt_dict_any_str[1] # type: ignore[index] and src.opt_dict_any_str[2.0] == dst.opt_dict_any_str[2.0] # type: ignore[index] @@ -249,7 +249,7 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None: and src.opt_dict_any_str[4] == dst.opt_dict_any_str[4] # type: ignore[index] ) assert ( - not src.opt_dict_str_list_int.eq_ptr(dst.opt_dict_str_list_int) # type: ignore + not mlc.eq_ptr(src.opt_dict_str_list_int, dst.opt_dict_str_list_int) # type: ignore and len(src.opt_dict_str_list_int) == len(dst.opt_dict_str_list_int) # type: ignore[arg-type] and tuple(src.opt_dict_str_list_int["1"]) == tuple(dst.opt_dict_str_list_int["1"]) # type: ignore[index] and tuple(src.opt_dict_str_list_int["2"]) == tuple(dst.opt_dict_str_list_int["2"]) # type: ignore[index] diff --git a/tests/python/test_dataclasses_fields.py b/tests/python/test_dataclasses_fields.py index 17b816b8..0ac0ec0a 100644 --- a/tests/python/test_dataclasses_fields.py +++ b/tests/python/test_dataclasses_fields.py @@ -646,7 +646,7 @@ def test_mlc_class_mem_fn(mlc_class_for_test: MLCClassForTest) -> None: def test_stringify(mlc_class_for_test: MLCClassForTest) -> None: obj = mlc_class_for_test - type_key = type(mlc_class_for_test)._mlc_type_info.type_key + type_key = type(mlc_class_for_test)._mlc_type_info.type_key # type: ignore[union-attr] expected = ( type_key + """@0x(bool_=False, i8=8, i16=16, i32=32, i64=64, f32=1.500000, f64=2.500000, raw_ptr=0x0000deadbeef, dtype=float8, device=cuda:0, any="hello", func=object.Func@0x, ulist=[1, 2.000000, "three", object.Func@0x], udict={"2": 2.000000, "4": object.Func@0x, "1": 1, "3": "three"}, str_="world", str_readonly="world", list_any=[1, 2.000000, "three", object.Func@0x], list_list_int=[[1, 2, 3], [4, 5, 6]], dict_any_any={2.000000: 2, 4: object.Func@0x, 1: 1.000000, "three": "four"}, dict_str_any={"4": object.Func@0x, "1": 1.000000, "2.0": 2, "three": "four"}, dict_any_str={2.000000: "2", 4: "5", 1: "1.0", "three": "four"}, dict_str_list_int={"2": [4, 5, 6], "1": [1, 2, 3]}, opt_bool=True, opt_i64=-64, opt_f64=None, opt_raw_ptr=None, opt_dtype=None, opt_device=cuda:0, opt_func=None, opt_ulist=None, opt_udict=None, opt_str=None, opt_list_any=[1, 2.000000, "three", object.Func@0x], opt_list_list_int=[[1, 2, 3], [4, 5, 6]], opt_dict_any_any=None, opt_dict_str_any={"4": object.Func@0x, "1": 1.000000, "2.0": 2, "three": "four"}, opt_dict_any_str={2.000000: "2", 4: "5", 1: "1.0", "three": "four"}, opt_dict_str_list_int={"2": [4, 5, 6], "1": [1, 2, 3]})""" diff --git a/tests/python/test_dataclasses_py_class.py b/tests/python/test_dataclasses_py_class.py index 5b4aa5b6..6027bc2e 100644 --- a/tests/python/test_dataclasses_py_class.py +++ b/tests/python/test_dataclasses_py_class.py @@ -6,7 +6,7 @@ @mlcd.py_class("mlc.testing.py_class_base") -class Base(mlcd.PyClass): +class Base: base_a: int base_b: str @@ -27,7 +27,7 @@ class Derived(Base): @mlcd.py_class("mlc.testing.py_class_base_with_default") -class BaseWithDefault(mlcd.PyClass): +class BaseWithDefault: base_a: int base_b: list[int] = mlcd.field(default_factory=list) @@ -40,17 +40,17 @@ class DerivedWithDefault(BaseWithDefault): @mlcd.py_class("mlc.testing.DerivedDerived") class DerivedDerived(DerivedWithDefault): - derived_derived_a: str + derived_derived_a: str # type: ignore[misc] @mlcd.py_class("mlc.testing.py_class_derived_with_default_interleaved") class DerivedWithDefaultInterleaved(BaseWithDefault): - derived_a: int + derived_a: int # type: ignore[misc] derived_b: Optional[str] = "1234" @mlcd.py_class("mlc.testing.py_class_post_init") -class PostInit(mlcd.PyClass): +class PostInit: a: int b: str @@ -59,19 +59,19 @@ def __post_init__(self) -> None: @mlcd.py_class("mlc.testing.py_class_frozen", frozen=True) -class Frozen(mlcd.PyClass): +class Frozen: a: int b: str @mlcd.py_class -class ContainerFields(mlcd.PyClass): +class ContainerFields: a: list[int] b: dict[int, int] @mlcd.py_class(frozen=True) -class FrozenContainerFields(mlcd.PyClass): +class FrozenContainerFields: a: list[int] b: dict[int, int] @@ -86,7 +86,7 @@ def test_base() -> None: def test_derived() -> None: - derived = Derived(1.0, "b", 2, "c") + derived = Derived(1, "b", 2, "c") derived_str = "mlc.testing.py_class_derived(base_a=1, base_b='b', derived_a=2.0, derived_b='c')" assert derived.base_a == 1 assert derived.base_b == "b" @@ -100,7 +100,7 @@ def test_repr_in_list() -> None: target = mlc.List[Base]( [ Base(1, "a"), - Derived(1.0, "b", 2, "c"), + Derived(1, "b", 2, "c"), ], ) target_str_0 = "mlc.testing.py_class_base(base_a=1, base_b='a')" @@ -124,7 +124,7 @@ def test_default_in_derived() -> None: def test_default_in_derived_interleaved() -> None: - derived = DerivedWithDefaultInterleaved(12, 34) + derived = DerivedWithDefaultInterleaved(12, 34) # type: ignore[call-arg,arg-type] assert derived.base_a == 12 assert isinstance(derived.base_b, mlc.List) and len(derived.base_b) == 0 assert derived.derived_a == 34 @@ -142,7 +142,7 @@ def test_post_init() -> None: def test_frozen_set_fail() -> None: frozen = Frozen(1, "a") with pytest.raises(AttributeError) as e: - frozen.a = 2 + frozen.a = 2 # type: ignore[misc] # depends on Python version, there are a few possible error messages assert str(e.value) in [ "property 'a' of 'Frozen' object has no setter", @@ -154,18 +154,18 @@ def test_frozen_set_fail() -> None: def test_frozen_force_set() -> None: frozen = Frozen(1, "a") - frozen._mlc_setattr("a", 2) + frozen._mlc_setattr("a", 2) # type: ignore[attr-defined] assert frozen.a == 2 assert frozen.b == "a" - frozen._mlc_setattr("b", "b") + frozen._mlc_setattr("b", "b") # type: ignore[attr-defined] assert frozen.a == 2 assert frozen.b == "b" def test_derived_derived() -> None: # __init__(base_a, derived_derived_a, base_b, derived_a, derived_b) - obj = DerivedDerived(1, "a", [1, 2], 2, "b") + obj = DerivedDerived(1, "a", [1, 2], 2, "b") # type: ignore[arg-type] assert obj.base_a == 1 assert obj.derived_derived_a == "a" assert isinstance(obj.base_b, mlc.List) and len(obj.base_b) == 2 @@ -193,19 +193,19 @@ def test_frozen_container_fields() -> None: assert obj.a == [1, 2] assert obj.b == {1: 2} - assert obj.a.frozen - assert obj.b.frozen + assert obj.a.frozen # type: ignore[attr-defined] + assert obj.b.frozen # type: ignore[attr-defined] e: pytest.ExceptionInfo with pytest.raises(AttributeError) as e: - obj.a = [2, 3] + obj.a = [2, 3] # type: ignore[misc] assert str(e.value) in [ "property 'a' of 'FrozenContainerFields' object has no setter", "can't set attribute", ] with pytest.raises(AttributeError) as e: - obj.b = {2: 3} + obj.b = {2: 3} # type: ignore[misc] assert str(e.value) in [ "property 'b' of 'FrozenContainerFields' object has no setter", "can't set attribute", diff --git a/tests/python/test_dataclasses_serialize.py b/tests/python/test_dataclasses_serialize.py index 1e491575..cca4170b 100644 --- a/tests/python/test_dataclasses_serialize.py +++ b/tests/python/test_dataclasses_serialize.py @@ -6,7 +6,7 @@ @mlc.dataclasses.py_class("mlc.testing.serialize", init=False) -class ObjTest(mlc.PyClass): +class ObjTest: a: int b: float c: str @@ -20,7 +20,7 @@ def __init__(self, b: float, c: str, a: int, d: bool) -> None: @mlc.dataclasses.py_class("mlc.testing.serialize_opt") -class ObjTestOpt(mlc.PyClass): +class ObjTestOpt: a: Optional[int] b: Optional[float] c: Optional[str] @@ -28,17 +28,17 @@ class ObjTestOpt(mlc.PyClass): @mlc.dataclasses.py_class("mlc.testing.AnyContainer") -class AnyContainer(mlc.PyClass): +class AnyContainer: field: Any def test_json() -> None: obj = ObjTest(a=1, b=2.0, c="3", d=True) - obj_json = obj.json() + obj_json = mlc.json_dumps(obj) obj_json_dict = json.loads(obj_json) assert obj_json_dict["type_keys"] == ["mlc.testing.serialize", "int"] assert obj_json_dict["values"] == ["3", [0, [1, 1], 2.0, 0, True]] - obj_from_json: ObjTest = ObjTest.from_json(obj_json) + obj_from_json: ObjTest = mlc.json_loads(obj_json) assert obj.a == obj_from_json.a assert obj.b == obj_from_json.b assert obj.c == obj_from_json.c @@ -57,7 +57,7 @@ def test_pickle() -> None: def test_json_opt_0() -> None: obj = ObjTestOpt(1, 2.0, "3", True) - obj_json = obj.json() + obj_json = mlc.json_dumps(obj) obj_json_dict = json.loads(obj_json) assert obj_json_dict["type_keys"] == ["mlc.testing.serialize_opt", "int"] assert obj_json_dict["values"] == [ @@ -78,7 +78,7 @@ def test_json_opt_0() -> None: True, ], ] - obj_from_json: ObjTestOpt = ObjTestOpt.from_json(obj_json) + obj_from_json: ObjTestOpt = mlc.json_loads(obj_json) assert obj.a == obj_from_json.a assert obj.b == obj_from_json.b assert obj.c == obj_from_json.c @@ -87,11 +87,11 @@ def test_json_opt_0() -> None: def test_json_opt_1() -> None: obj = ObjTestOpt(None, None, None, None) - obj_json = obj.json() + obj_json = mlc.json_dumps(obj) obj_json_dict = json.loads(obj_json) assert obj_json_dict["type_keys"] == ["mlc.testing.serialize_opt"] assert obj_json_dict["values"] == [[0, None, None, None, None]] - obj_from_json: ObjTestOpt = ObjTestOpt.from_json(obj_json) + obj_from_json: ObjTestOpt = mlc.json_loads(obj_json) assert obj.a == obj_from_json.a assert obj.b == obj_from_json.b assert obj.c == obj_from_json.c @@ -103,7 +103,7 @@ def test_json_dag() -> None: dct = mlc.Dict({"a": 1, "b": 2.0, "c": "3", "d": True, "v": lst}) big_lst = mlc.List([lst, dct, lst, dct]) obj_1 = AnyContainer([big_lst, big_lst]) - obj_2: AnyContainer = AnyContainer.from_json(obj_1.json()) + obj_2: AnyContainer = mlc.json_loads(mlc.json_dumps(obj_1)) assert obj_2.field[0].is_(obj_2.field[1]) assert obj_2.field[0] == big_lst big_lst = obj_2.field[0] diff --git a/tests/python/test_dataclasses_structure.py b/tests/python/test_dataclasses_structure.py index f0323f4a..dd88a4b0 100644 --- a/tests/python/test_dataclasses_structure.py +++ b/tests/python/test_dataclasses_structure.py @@ -6,7 +6,7 @@ @mlcd.py_class -class Expr(mlcd.PyClass): +class Expr: def __add__(self, other: Expr) -> Expr: return Add(a=self, b=other) @@ -37,7 +37,7 @@ class Let(Expr): @mlcd.py_class(structure="bind") -class Func(mlcd.PyClass): +class Func: name: str = mlcd.field(structure=None) args: list[Var] = mlcd.field(structure="bind") body: Expr @@ -49,17 +49,17 @@ class Constant(Expr): @mlcd.py_class(structure="bind") -class TensorType(mlcd.PyClass): +class TensorType: shape: list[int] dtype: mlc.DataType @mlcd.py_class(structure="bind") -class Stmt(mlcd.PyClass): ... +class Stmt: ... @mlcd.py_class(structure="bind") -class FuncStmts(mlcd.PyClass): +class FuncStmts: name: str = mlcd.field(structure=None) args: list[Var] = mlcd.field(structure="bind") stmts: list[Stmt] @@ -75,24 +75,24 @@ def test_free_var_1() -> None: x = Var("x") lhs = x rhs = x - lhs.eq_s(rhs, bind_free_vars=True, assert_mode=True) - assert lhs.hash_s() == rhs.hash_s() + mlc.eq_s(lhs, rhs, bind_free_vars=True, assert_mode=True) + assert mlc.hash_s(lhs) == mlc.hash_s(rhs) with pytest.raises(ValueError) as e: - lhs.eq_s(rhs, bind_free_vars=False, assert_mode=True) + mlc.eq_s(lhs, rhs, bind_free_vars=False, assert_mode=True) assert str(e.value) == "Structural equality check failed at {root}: Unbound variable" - assert str(e.value) == lhs.eq_s_fail_reason(rhs, bind_free_vars=False) + assert str(e.value) == mlc.eq_s_fail_reason(lhs, rhs, bind_free_vars=False) def test_free_var_2() -> None: x = Var("x") lhs = Add(x, x) rhs = Add(x, x) - lhs.eq_s(rhs, bind_free_vars=True, assert_mode=True) - assert lhs.hash_s() == rhs.hash_s() + mlc.eq_s(lhs, rhs, bind_free_vars=True, assert_mode=True) + assert mlc.hash_s(lhs) == mlc.hash_s(rhs) with pytest.raises(ValueError) as e: - lhs.eq_s(rhs, bind_free_vars=False, assert_mode=True) + mlc.eq_s(lhs, rhs, bind_free_vars=False, assert_mode=True) assert str(e.value) == "Structural equality check failed at {root}.a: Unbound variable" - assert str(e.value) == lhs.eq_s_fail_reason(rhs, bind_free_vars=False) + assert str(e.value) == mlc.eq_s_fail_reason(lhs, rhs, bind_free_vars=False) def test_cyclic() -> None: @@ -101,46 +101,46 @@ def test_cyclic() -> None: z = Var("z") lhs = x + y + z rhs = y + z + x - lhs.eq_s(rhs, bind_free_vars=True, assert_mode=True) - assert lhs.hash_s() == rhs.hash_s() + mlc.eq_s(lhs, rhs, bind_free_vars=True, assert_mode=True) + assert mlc.hash_s(lhs) == mlc.hash_s(rhs) def test_tensor_type() -> None: - t1 = TensorType(shape=(1, 2, 3), dtype="float32") - t2 = TensorType(shape=(1, 2, 3), dtype="float32") - t3 = TensorType(shape=(1, 2, 3), dtype="int32") - t4 = TensorType(shape=(1, 2), dtype="float32") + t1 = TensorType(shape=(1, 2, 3), dtype="float32") # type: ignore[arg-type] + t2 = TensorType(shape=(1, 2, 3), dtype="float32") # type: ignore[arg-type] + t3 = TensorType(shape=(1, 2, 3), dtype="int32") # type: ignore[arg-type] + t4 = TensorType(shape=(1, 2), dtype="float32") # type: ignore[arg-type] # t1 == t2 - t1.eq_s(t2, bind_free_vars=False, assert_mode=True) - assert t1.hash_s() == t2.hash_s() + mlc.eq_s(t1, t2, bind_free_vars=False, assert_mode=True) + assert mlc.hash_s(t1) == mlc.hash_s(t2) # t1 != t3, dtype mismatch with pytest.raises(ValueError) as e: - t1.eq_s(t3, bind_free_vars=False, assert_mode=True) + mlc.eq_s(t1, t3, bind_free_vars=False, assert_mode=True) assert str(e.value) == "Structural equality check failed at {root}.dtype: float32 vs int32" - assert str(e.value) == t1.eq_s_fail_reason(t3, bind_free_vars=False) - assert t1.hash_s() != t3.hash_s() + assert str(e.value) == mlc.eq_s_fail_reason(t1, t3, bind_free_vars=False) + assert mlc.hash_s(t1) != mlc.hash_s(t3) # t1 != t4, shape mismatch with pytest.raises(ValueError) as e: - t1.eq_s(t4, bind_free_vars=False, assert_mode=True) + mlc.eq_s(t1, t4, bind_free_vars=False, assert_mode=True) assert ( str(e.value) == "Structural equality check failed at {root}.shape: List length mismatch: 3 vs 2" ) - assert str(e.value) == t1.eq_s_fail_reason(t4, bind_free_vars=False) - assert t1.hash_s() != t4.hash_s() + assert str(e.value) == mlc.eq_s_fail_reason(t1, t4, bind_free_vars=False) + assert mlc.hash_s(t1) != mlc.hash_s(t4) def test_constant() -> None: c1 = Constant(1) c2 = Constant(1) c3 = Constant(2) - c1.eq_s(c2, bind_free_vars=False, assert_mode=True) + mlc.eq_s(c1, c2, bind_free_vars=False, assert_mode=True) with pytest.raises(ValueError) as e: - c1.eq_s(c3, bind_free_vars=False, assert_mode=True) + mlc.eq_s(c1, c3, bind_free_vars=False, assert_mode=True) assert str(e.value) == "Structural equality check failed at {root}.value: 1 vs 2" - assert str(e.value) == c1.eq_s_fail_reason(c3, bind_free_vars=False) - assert c1.hash_s() == c2.hash_s() - assert c1.hash_s() != c3.hash_s() + assert str(e.value) == mlc.eq_s_fail_reason(c1, c3, bind_free_vars=False) + assert mlc.hash_s(c1) == mlc.hash_s(c2) + assert mlc.hash_s(c1) != mlc.hash_s(c3) def test_let_1() -> None: @@ -157,12 +157,12 @@ def test_let_1() -> None: """ lhs = Let(rhs=x + x, lhs=y, body=y) rhs = Let(rhs=y + y, lhs=x, body=x) - lhs.eq_s(rhs, bind_free_vars=True, assert_mode=True) + mlc.eq_s(lhs, rhs, bind_free_vars=True, assert_mode=True) with pytest.raises(ValueError) as e: - lhs.eq_s(rhs, bind_free_vars=False, assert_mode=True) + mlc.eq_s(lhs, rhs, bind_free_vars=False, assert_mode=True) assert str(e.value) == "Structural equality check failed at {root}.rhs.a: Unbound variable" - assert str(e.value) == lhs.eq_s_fail_reason(rhs, bind_free_vars=False) - assert lhs.hash_s() == rhs.hash_s() + assert str(e.value) == mlc.eq_s_fail_reason(lhs, rhs, bind_free_vars=False) + assert mlc.hash_s(lhs) == mlc.hash_s(rhs) def test_let_2() -> None: @@ -176,13 +176,13 @@ def test_let_2() -> None: l1 = Let(rhs=Constant(1), lhs=v1, body=v1) l2 = Let(rhs=Constant(1), lhs=v2, body=v2) l3 = Let(rhs=Constant(2), lhs=v1, body=v1) - l1.eq_s(l2, bind_free_vars=True, assert_mode=True) + mlc.eq_s(l1, l2, bind_free_vars=True, assert_mode=True) with pytest.raises(ValueError) as e: - l1.eq_s(l3, bind_free_vars=True, assert_mode=True) + mlc.eq_s(l1, l3, bind_free_vars=True, assert_mode=True) assert str(e.value) == "Structural equality check failed at {root}.rhs.value: 1 vs 2" - assert str(e.value) == l1.eq_s_fail_reason(l3, bind_free_vars=True) - assert l1.hash_s() == l2.hash_s() - assert l1.hash_s() != l3.hash_s() + assert str(e.value) == mlc.eq_s_fail_reason(l1, l3, bind_free_vars=True) + assert mlc.hash_s(l1) == mlc.hash_s(l2) + assert mlc.hash_s(l1) != mlc.hash_s(l3) def test_non_scoped_compute_1() -> None: @@ -201,8 +201,8 @@ def test_non_scoped_compute_1() -> None: z = x + x lhs = y + y rhs = y + z - lhs.eq_s(rhs, bind_free_vars=True, assert_mode=True) - assert lhs.hash_s() == rhs.hash_s() + mlc.eq_s(lhs, rhs, bind_free_vars=True, assert_mode=True) + assert mlc.hash_s(lhs) == mlc.hash_s(rhs) def test_non_scoped_compute_2() -> None: @@ -222,13 +222,13 @@ def test_non_scoped_compute_2() -> None: lhs = AddBind(y, y) rhs = AddBind(y, z) with pytest.raises(ValueError) as e: - lhs.eq_s(rhs, bind_free_vars=True, assert_mode=True) + mlc.eq_s(lhs, rhs, bind_free_vars=True, assert_mode=True) assert ( str(e.value) == "Structural equality check failed at {root}.b: Inconsistent binding. " "LHS has been bound to a different node while RHS is not bound" ) - assert str(e.value) == lhs.eq_s_fail_reason(rhs, bind_free_vars=True) - assert lhs.hash_s() != rhs.hash_s() + assert str(e.value) == mlc.eq_s_fail_reason(lhs, rhs, bind_free_vars=True) + assert mlc.hash_s(lhs) != mlc.hash_s(rhs) def test_func_1() -> None: @@ -250,8 +250,8 @@ def test_func_1() -> None: z = Var("z") lhs = Func("lhs", args=[x, y], body=Let(rhs=x + y, lhs=z, body=z + z)) rhs = Func("rhs", args=[y, x], body=Let(rhs=y + x, lhs=z, body=z + z)) - lhs.eq_s(rhs, bind_free_vars=False, assert_mode=True) - assert lhs.hash_s() == rhs.hash_s() + mlc.eq_s(lhs, rhs, bind_free_vars=False, assert_mode=True) + assert mlc.hash_s(lhs) == mlc.hash_s(rhs) def test_func_2() -> None: @@ -276,17 +276,17 @@ def test_func_2() -> None: l2 = Func("l2", args=[z, x, y, d], body=z + x + y) l3 = Func("l3", args=[x, y, z], body=x + y + z) # l1 == l2 - l1.eq_s(l2, bind_free_vars=False, assert_mode=True) + mlc.eq_s(l1, l2, bind_free_vars=False, assert_mode=True) # l1 != l3, arg length mismatch with pytest.raises(ValueError) as e: - l1.eq_s(l3, bind_free_vars=False, assert_mode=True) + mlc.eq_s(l1, l3, bind_free_vars=False, assert_mode=True) assert ( str(e.value) == "Structural equality check failed at {root}.args: List length mismatch: 4 vs 3" ) - assert str(e.value) == l1.eq_s_fail_reason(l3, bind_free_vars=False) - assert l1.hash_s() == l2.hash_s() - assert l1.hash_s() != l3.hash_s() + assert str(e.value) == mlc.eq_s_fail_reason(l1, l3, bind_free_vars=False) + assert mlc.hash_s(l1) == mlc.hash_s(l2) + assert mlc.hash_s(l1) != mlc.hash_s(l3) def test_func_stmts() -> None: @@ -321,16 +321,16 @@ def test_func_stmts() -> None: AssignStmt(rhs=a + b, lhs=c), ], ) - func_f.eq_s(func_f, bind_free_vars=False, assert_mode=True) + mlc.eq_s(func_f, func_f, bind_free_vars=False, assert_mode=True) with pytest.raises(ValueError) as e: - func_f.eq_s(func_g, bind_free_vars=False, assert_mode=True) + mlc.eq_s(func_f, func_g, bind_free_vars=False, assert_mode=True) assert ( str(e.value) == "Structural equality check failed at {root}.stmts[0].rhs.b: Inconsistent binding. " "LHS has been bound to a different node while RHS is not bound" ) - assert str(e.value) == func_f.eq_s_fail_reason(func_g, bind_free_vars=False) - assert func_f.hash_s() != func_g.hash_s() + assert str(e.value) == mlc.eq_s_fail_reason(func_f, func_g, bind_free_vars=False) + assert mlc.hash_s(func_f) != mlc.hash_s(func_g) def test_global_var() -> None: @@ -340,12 +340,12 @@ def test_global_var() -> None: lhs = x + y + z rhs = z + y + x with pytest.raises(ValueError) as e: - lhs.eq_s(rhs, bind_free_vars=True, assert_mode=True) + mlc.eq_s(lhs, rhs, bind_free_vars=True, assert_mode=True) assert str(e.value) == "Structural equality check failed at {root}.a.a.is_global: False vs True" - assert str(e.value) == lhs.eq_s_fail_reason(rhs, bind_free_vars=True) - assert lhs.hash_s() != rhs.hash_s() + assert str(e.value) == mlc.eq_s_fail_reason(lhs, rhs, bind_free_vars=True) + assert mlc.hash_s(lhs) != mlc.hash_s(rhs) lhs = x + y + z rhs = x + z + y - lhs.eq_s(rhs, bind_free_vars=True, assert_mode=True) - assert lhs.hash_s() == rhs.hash_s() + mlc.eq_s(lhs, rhs, bind_free_vars=True, assert_mode=True) + assert mlc.hash_s(lhs) == mlc.hash_s(rhs) diff --git a/tests/python/test_parser_toy_ir_parser.py b/tests/python/test_parser_toy_ir_parser.py index 3c502a15..8be55de6 100644 --- a/tests/python/test_parser_toy_ir_parser.py +++ b/tests/python/test_parser_toy_ir_parser.py @@ -1,7 +1,8 @@ from __future__ import annotations +import mlc from mlc.testing import toy_ir -from mlc.testing.toy_ir import Add, Assign, Func, Var +from mlc.testing.toy_ir import Add, Assign, Func, Stmt, Var def test_parse_func() -> None: @@ -16,7 +17,7 @@ def _expected() -> Func: c = Var(name="_c") d = Var(name="_d") e = Var(name="_e") - stmts = [ + stmts: list[Stmt] = [ Assign(lhs=d, rhs=Add(a, b)), Assign(lhs=e, rhs=Add(d, c)), ] @@ -25,4 +26,4 @@ def _expected() -> Func: result = toy_ir.parse_func(source_code) expected = _expected() - result.eq_s(expected, assert_mode=True) + mlc.eq_s(result, expected, assert_mode=True) diff --git a/tests/python/test_printer_ir_printer.py b/tests/python/test_printer_ir_printer.py index fe16ecde..11f1fb13 100644 --- a/tests/python/test_printer_ir_printer.py +++ b/tests/python/test_printer_ir_printer.py @@ -3,7 +3,7 @@ import mlc.printer as mlcp import pytest from mlc.printer import ObjectPath -from mlc.testing.toy_ir import Add, Assign, Func, Var +from mlc.testing.toy_ir import Add, Assign, Func, Stmt, Var def test_var_print() -> None: @@ -37,7 +37,7 @@ def test_func_print() -> None: c = Var(name="c") d = Var(name="d") e = Var(name="e") - stmts = [ + stmts: list[Stmt] = [ Assign(lhs=d, rhs=Add(a, b)), Assign(lhs=e, rhs=Add(d, c)), ] @@ -84,13 +84,10 @@ def test_print_bool() -> None: def test_duplicated_vars() -> None: a = Var(name="a") b = Var(name="a") - stmts = [ - Assign(lhs=b, rhs=Add(a, a)), - ] f = Func( name="f", args=[a], - stmts=stmts, + stmts=[Assign(lhs=b, rhs=Add(a, a))], ret=b, ) assert ( diff --git a/tests/python/test_sym_analyzer_canonical_simplify.py b/tests/python/test_sym_analyzer_canonical_simplify.py index eea15d7a..c8c69555 100644 --- a/tests/python/test_sym_analyzer_canonical_simplify.py +++ b/tests/python/test_sym_analyzer_canonical_simplify.py @@ -4,6 +4,7 @@ from collections.abc import Mapping from types import MappingProxyType +import mlc import pytest from mlc import sym as S from mlc.sym import floordiv as fld @@ -52,13 +53,13 @@ def test_body(self, param: Param) -> None: analyzer = S.Analyzer() self._add_bound(analyzer, param) after = canonical_simplify(analyzer, param.before) - if not param.after.eq_s(after): + if not mlc.eq_s(param.after, after): raise AssertionError( "CanonicalSimplify did not produce the expected result.\n" f"Before: {param.before}\n" f"Expected: {param.after}\n" f"Actual: {after}\n" - f"Reason: {param.after.eq_s_fail_reason(after)}" + f"Reason: {mlc.eq_s_fail_reason(param.after, after)}" ) def _add_bound(self, analyzer: S.Analyzer, param: Param) -> None: diff --git a/tests/python/test_sym_analyzer_rewrite_simplify.py b/tests/python/test_sym_analyzer_rewrite_simplify.py index ce984593..02293b09 100644 --- a/tests/python/test_sym_analyzer_rewrite_simplify.py +++ b/tests/python/test_sym_analyzer_rewrite_simplify.py @@ -3,6 +3,7 @@ import dataclasses import typing +import mlc import pytest from mlc import sym as S from mlc.sym import floordiv as fld @@ -57,13 +58,13 @@ def test_body(self, param: Param) -> None: analyzer = S.Analyzer() with enter_constraint(analyzer, param.constraints): after = rewrite_simplify(analyzer, param.before) - if not param.after.eq_s(after): + if not mlc.eq_s(param.after, after): raise AssertionError( "RewriteSimplify did not produce the expected result.\n" f"Before: {param.before}\n" f"Expected: {param.after}\n" f"Actual: {after}\n" - f"Reason: {param.after.eq_s_fail_reason(after)}" + f"Reason: {mlc.eq_s_fail_reason(param.after, after)}" ) diff --git a/tests/python/test_sym_analyzer_simplify.py b/tests/python/test_sym_analyzer_simplify.py index 829fa785..99e4bc1d 100644 --- a/tests/python/test_sym_analyzer_simplify.py +++ b/tests/python/test_sym_analyzer_simplify.py @@ -3,6 +3,7 @@ from types import MappingProxyType from typing import Literal +import mlc import pytest from mlc import sym as S @@ -28,7 +29,7 @@ def test_index_flatten(analyzer: S.Analyzer) -> None: before = (i_flattened) // 12 * 12 + (i_flattened) % 12 // 4 * 4 + (i_flattened) % 4 expected_after = i_flattened after = analyzer.simplify(before) - expected_after.eq_s(after, assert_mode=True) + mlc.eq_s(expected_after, after, assert_mode=True) @pytest.mark.parametrize(