Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ MLC provides Pythonic dataclasses:
import mlc.dataclasses as mlcd

@mlcd.py_class("demo.MyClass")
class MyClass(mlcd.PyClass):
class MyClass:
a: int
b: str
c: float | None
Expand All @@ -60,10 +60,10 @@ AttributeError: 'MyClass' object has no attribute 'non_exist' and no __dict__ fo
**Serialization**. MLC dataclasses are picklable and JSON-serializable.

```python
>>> MyClass.from_json(instance.json())
>>> mlc.json_loads(mlc.json_dumps(instance))
demo.MyClass(a=12, b='test', c=None)

>>> import pickle; pickle.loads(pickle.dumps(instance))
>>> pickle.loads(pickle.dumps(instance))
demo.MyClass(a=12, b='test', c=None)
```

Expand Down Expand Up @@ -114,10 +114,11 @@ By annotating IR definitions with `structure`, MLC supports structural equality
<details><summary> Define a toy IR with `structure`. </summary>

```python
import mlc
import mlc.dataclasses as mlcd

@mlcd.py_class
class Expr(mlcd.PyClass):
class Expr:
def __add__(self, other):
return Add(a=self, b=other)

Expand Down Expand Up @@ -146,16 +147,16 @@ class Let(Expr):
>>> L1 = Let(rhs=x + y, lhs=z, body=z) # let z = x + y; z
>>> L2 = Let(rhs=y + z, lhs=x, body=x) # let x = y + z; x
>>> L3 = Let(rhs=x + x, lhs=z, body=z) # let z = x + x; z
>>> L1.eq_s(L2)
>>> mlc.eq_s(L1, L2)
True
>>> L1.eq_s(L3, assert_mode=True)
>>> mlc.eq_s(L1, L3, assert_mode=True)
ValueError: Structural equality check failed at {root}.rhs.b: Inconsistent binding. RHS has been bound to a different node while LHS is not bound
```

**Structural hashing**. The structure of MLC dataclasses can be hashed via `hash_s`, which guarantees if two dataclasses are alpha-equivalent, they will share the same structural hash:

```python
>>> L1_hash, L2_hash, L3_hash = L1.hash_s(), L2.hash_s(), L3.hash_s()
>>> L1_hash, L2_hash, L3_hash = mlc.hash_s(L1), mlc.hash_s(L2), mlc.hash_s(L3)
>>> assert L1_hash == L2_hash
>>> assert L1_hash != L3_hash
```
Expand Down
4 changes: 2 additions & 2 deletions cpp/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
namespace mlc {
namespace registry {

Any JSONLoads(AnyView json_str);
Any JSONParse(AnyView json_str);
Any JSONDeserialize(AnyView json_str, FuncObj *fn_opaque_deserialize);
Str JSONSerialize(AnyView source, FuncObj *fn_opaque_serialize);
bool StructuralEqual(AnyView lhs, AnyView rhs, bool bind_free_vars, bool assert_mode);
Expand Down Expand Up @@ -646,7 +646,7 @@ inline TypeTable *TypeTable::New() {
self->SetFunc("mlc.base.DeviceTypeRegister",
Func([self](const char *name) { return self->DeviceTypeRegister(name); }).get());
self->SetFunc("mlc.core.Stringify", Func(::mlc::core::StringifyWithFields).get());
self->SetFunc("mlc.core.JSONLoads", Func(::mlc::registry::JSONLoads).get());
self->SetFunc("mlc.core.JSONParse", Func(::mlc::registry::JSONParse).get());
self->SetFunc("mlc.core.JSONSerialize", Func(::mlc::registry::JSONSerialize).get());
self->SetFunc("mlc.core.JSONDeserialize", Func(::mlc::registry::JSONDeserialize).get());
self->SetFunc("mlc.core.StructuralEqual", Func(::mlc::registry::StructuralEqual).get());
Expand Down
10 changes: 5 additions & 5 deletions cpp/structure.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using mlc::core::VisitStructure;

/****************** JSON ******************/

inline Any JSONLoads(const char *json_str, int64_t json_str_len) {
inline Any JSONParse(const char *json_str, int64_t json_str_len) {
struct JSONParser {
Any Parse() {
SkipWhitespace();
Expand Down Expand Up @@ -1552,7 +1552,7 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len, FuncObj *fn_o
int32_t json_type_index_tensor = -1;
int32_t json_type_index_opaque = -1;
// Step 0. Parse JSON string
UDict json_obj = JSONLoads(json_str, json_str_len);
UDict json_obj = JSONParse(json_str, json_str_len);
// Step 1. type_key => constructors
UList type_keys = json_obj->at("type_keys");
std::vector<FuncObj *> constructors;
Expand Down Expand Up @@ -1700,12 +1700,12 @@ Any CopyShallow(AnyView source) { return CopyShallowImpl(source); }
Any CopyDeep(AnyView source) { return CopyDeepImpl(source); }
void CopyReplace(int32_t num_args, const AnyView *args, Any *ret) { CopyReplaceImpl(num_args, args, ret); }

Any JSONLoads(AnyView json_str) {
Any JSONParse(AnyView json_str) {
if (json_str.type_index == kMLCRawStr) {
return ::mlc::JSONLoads(json_str.operator const char *(), -1);
return ::mlc::JSONParse(json_str.operator const char *(), -1);
} else {
StrObj *js = json_str.operator StrObj *();
return ::mlc::JSONLoads(js->data(), js->size());
return ::mlc::JSONParse(js->data(), js->size());
}
}

Expand Down
8 changes: 7 additions & 1 deletion python/mlc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@
Tensor,
build_info,
dep_graph,
eq_ptr,
eq_s,
eq_s_fail_reason,
hash_s,
json_dumps,
json_loads,
json_parse,
typing,
)
from .core.dep_graph import DepGraph, DepNode
from .dataclasses import PyClass, c_class, py_class
from .dataclasses import c_class, py_class

try:
from ._version import __version__, __version_tuple__ # type: ignore[import-not-found]
Expand Down
4 changes: 2 additions & 2 deletions python/mlc/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from .dict import Dict
from .dtype import DataType
from .error import Error
from .func import Func, build_info, json_loads
from .func import Func, build_info, json_parse
from .list import List
from .object import Object
from .object import Object, eq_ptr, eq_s, eq_s_fail_reason, hash_s, json_dumps, json_loads
from .object_path import ObjectPath
from .opaque import Opaque
from .tensor import Tensor
6 changes: 3 additions & 3 deletions python/mlc/core/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ def decorator(func: _CallableType) -> _CallableType:
return decorator


def json_loads(s: str) -> Any:
return _json_loads(s)
def json_parse(s: str) -> Any:
return _json_parse(s)


def build_info() -> dict[str, Any]:
return _build_info()


_json_loads = Func.get("mlc.core.JSONLoads")
_json_parse = Func.get("mlc.core.JSONParse")
_build_info = Func.get("mlc.core.BuildInfo")
152 changes: 115 additions & 37 deletions python/mlc/core/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@

from mlc._cython import PyAny, TypeInfo, c_class_core

try:
from warnings import deprecated # type: ignore[attr-defined]
except ImportError:
from typing_extensions import deprecated


@c_class_core("object.Object")
class Object(PyAny):
Expand All @@ -21,42 +26,6 @@ def id_(self) -> int:
def is_(self, other: Object) -> bool:
return isinstance(other, Object) and self._mlc_address == other._mlc_address

def json(
self,
fn_opaque_serialize: Callable[[list[typing.Any]], str] | None = None,
) -> str:
return super()._mlc_json(fn_opaque_serialize)

@staticmethod
def from_json(
json_str: str,
fn_opaque_deserialize: Callable[[str], list[typing.Any]] | None = None,
) -> Object:
return PyAny._mlc_from_json(json_str, fn_opaque_deserialize) # type: ignore[attr-defined]

def eq_s(
self,
other: Object,
*,
bind_free_vars: bool = True,
assert_mode: bool = False,
) -> bool:
return PyAny._mlc_eq_s(self, other, bind_free_vars, assert_mode) # type: ignore[attr-defined]

def eq_s_fail_reason(
self,
other: Object,
*,
bind_free_vars: bool = True,
) -> tuple[bool, str]:
return PyAny._mlc_eq_s_fail_reason(self, other, bind_free_vars)

def hash_s(self) -> int:
return PyAny._mlc_hash_s(self) # type: ignore[attr-defined]

def eq_ptr(self, other: typing.Any) -> bool:
return isinstance(other, Object) and self._mlc_address == other._mlc_address

def __copy__(self: Object) -> Object:
return PyAny._mlc_copy_shallow(self) # type: ignore[attr-defined]

Expand All @@ -74,7 +43,7 @@ def __hash__(self) -> int:
return hash((type(self), self._mlc_address))

def __eq__(self, other: typing.Any) -> bool:
return self.eq_ptr(other)
return eq_ptr(self, other)

def __ne__(self, other: typing.Any) -> bool:
return not self == other
Expand Down Expand Up @@ -103,3 +72,112 @@ def swap(self, other: typing.Any) -> None:
self._mlc_swap(other)
else:
raise TypeError(f"Cannot different types: `{type(self)}` and `{type(other)}`")

@deprecated(
"Method `.json` is deprecated. Use `mlc.json_dumps` instead.",
stacklevel=2,
)
def json(
self,
fn_opaque_serialize: Callable[[list[typing.Any]], str] | None = None,
) -> str:
return json_dumps(self, fn_opaque_serialize)

@staticmethod
@deprecated(
"Method `.from_json` is deprecated. Use `mlc.json_loads` instead.",
stacklevel=2,
)
def from_json(
json_str: str,
fn_opaque_deserialize: Callable[[str], list[typing.Any]] | None = None,
) -> Object:
return json_loads(json_str, fn_opaque_deserialize)

@deprecated(
"Method `.eq_s` is deprecated. Use `mlc.eq_s` instead.",
stacklevel=2,
)
def eq_s(
self,
other: Object,
*,
bind_free_vars: bool = True,
assert_mode: bool = False,
) -> bool:
return eq_s(self, other, bind_free_vars=bind_free_vars, assert_mode=assert_mode)

@deprecated(
"Method `.eq_s_fail_reason` is deprecated. Use `mlc.eq_s_fail_reason` instead.",
stacklevel=2,
)
def eq_s_fail_reason(
self,
other: Object,
*,
bind_free_vars: bool = True,
) -> tuple[bool, str]:
return eq_s_fail_reason(self, other, bind_free_vars=bind_free_vars)

@deprecated(
"Method `.hash_s` is deprecated. Use `mlc.hash_s` instead.",
stacklevel=2,
)
def hash_s(self) -> int:
return hash_s(self)

@deprecated(
"Method `.eq_ptr` is deprecated. Use `mlc.eq_ptr` instead.",
stacklevel=2,
)
def eq_ptr(self, other: typing.Any) -> bool:
return eq_ptr(self, other)


def json_dumps(
object: typing.Any,
fn_opaque_serialize: Callable[[list[typing.Any]], str] | None = None,
) -> str:
assert isinstance(object, Object), f"Expected `mlc.Object`, got `{type(object)}`"
return object._mlc_json(fn_opaque_serialize) # type: ignore[attr-defined]


def json_loads(
json_str: str,
fn_opaque_deserialize: Callable[[str], list[typing.Any]] | None = None,
) -> Object:
return PyAny._mlc_from_json(json_str, fn_opaque_deserialize) # type: ignore[attr-defined]


def eq_s(
lhs: typing.Any,
rhs: typing.Any,
*,
bind_free_vars: bool = True,
assert_mode: bool = False,
) -> bool:
assert isinstance(lhs, Object), f"Expected `mlc.Object`, got `{type(lhs)}`"
assert isinstance(rhs, Object), f"Expected `mlc.Object`, got `{type(rhs)}`"
return PyAny._mlc_eq_s(lhs, rhs, bind_free_vars, assert_mode) # type: ignore[attr-defined]


def eq_s_fail_reason(
lhs: typing.Any,
rhs: typing.Any,
*,
bind_free_vars: bool = True,
) -> tuple[bool, str]:
assert isinstance(lhs, Object), f"Expected `mlc.Object`, got `{type(lhs)}`"
assert isinstance(rhs, Object), f"Expected `mlc.Object`, got `{type(rhs)}`"
return PyAny._mlc_eq_s_fail_reason(lhs, rhs, bind_free_vars)


def hash_s(obj: typing.Any) -> int:
assert isinstance(obj, Object), f"Expected `mlc.Object`, got `{type(obj)}`"
return PyAny._mlc_hash_s(obj) # type: ignore[attr-defined]


def eq_ptr(lhs: typing.Any, rhs: typing.Any) -> bool:
assert isinstance(lhs, Object), f"Expected `mlc.Object`, got `{type(lhs)}`"
assert isinstance(rhs, Object), f"Expected `mlc.Object`, got `{type(rhs)}`"
return lhs._mlc_address == rhs._mlc_address
4 changes: 3 additions & 1 deletion python/mlc/dataclasses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from mlc.core.object import Object as PyClass # for backward compatibility

from .c_class import c_class
from .py_class import PyClass, py_class
from .py_class import py_class
from .utils import (
Structure,
add_vtable_method,
Expand Down
Loading
Loading