Skip to content

Commit 2cbb314

Browse files
authored
feat(core): Support __replace__ in general (#26)
This PR adds the `mlc.dataclasses.replace` function, which mimics the behavior of Python's native `dataclasses.replace`. In principle, we want something fully compatible with Python's native dataclasses, but for now there are certain limitations: - `__post_init__` will not be called upon replacement; - POD types, such as `mlc.DataType` and `mlc.Device`, do not support `__replace__`. Support for these can be added case by case later if needed.
1 parent e5b4e04 commit 2cbb314

File tree

7 files changed

+91
-1
lines changed

7 files changed

+91
-1
lines changed

cpp/registry.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ bool StructuralEqual(AnyView lhs, AnyView rhs, bool bind_free_vars, bool assert_
2525
int64_t StructuralHash(AnyView root);
2626
Any CopyShallow(AnyView root);
2727
Any CopyDeep(AnyView root);
28+
// Packed-call entry point registered as "mlc.core.CopyReplace".
// `args[0]` is the object to copy; the remaining arguments are consumed as
// alternating (field name, replacement value) pairs. The newly constructed
// object is written to `*ret`.
void CopyReplace(int32_t num_args, const AnyView *args, Any *ret);
2829
Str DocToPythonScript(mlc::printer::Node node, mlc::printer::PrinterConfig cfg);
2930
UDict BuildInfo();
3031

@@ -650,6 +651,7 @@ inline TypeTable *TypeTable::New() {
650651
self->SetFunc("mlc.core.StructuralHash", Func(::mlc::registry::StructuralHash).get());
651652
self->SetFunc("mlc.core.CopyShallow", Func(::mlc::registry::CopyShallow).get());
652653
self->SetFunc("mlc.core.CopyDeep", Func(::mlc::registry::CopyDeep).get());
654+
self->SetFunc("mlc.core.CopyReplace", Func(::mlc::registry::CopyReplace).get());
653655
self->SetFunc("mlc.core.BuildInfo", Func(::mlc::registry::BuildInfo).get());
654656
self->SetFunc("mlc.core.TensorToBytes", Func(::mlc::registry::TensorToBytes).get());
655657
self->SetFunc("mlc.core.TensorFromBytes", Func(::mlc::registry::TensorFromBytes).get());

cpp/structure.cc

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -952,7 +952,9 @@ inline Any CopyShallowImpl(AnyView source) {
952952
return UList(list->begin(), list->end());
953953
} else if (UDictObj *dict = source.TryCast<UDictObj>()) {
954954
return UDict(dict->begin(), dict->end());
955-
} else if (source.IsInstance<StrObj>() || source.IsInstance<ErrorObj>() || source.IsInstance<FuncObj>()) {
955+
} else if (source.IsInstance<StrObj>() || source.IsInstance<ErrorObj>() || source.IsInstance<FuncObj>() ||
956+
source.IsInstance<TensorObj>()) {
957+
// TODO: do we want to shallow copy these types at all?
956958
return source;
957959
}
958960
struct Copier {
@@ -987,6 +989,62 @@ inline Any CopyShallowImpl(AnyView source) {
987989
return ret;
988990
}
989991

992+
inline void CopyReplaceImpl(int32_t num_args, const AnyView *args, Any *ret) {
993+
if (num_args <= 0) {
994+
MLC_THROW(InternalError) << "InternalError: `CopyReplace` requires at least one argument";
995+
}
996+
AnyView source = args[0];
997+
int32_t type_index = source.type_index;
998+
if (::mlc::base::IsTypeIndexPOD(type_index)) {
999+
MLC_THROW(TypeError) << "TypeError: `__replace__` doesn't work on a POD type: " << source;
1000+
} else if (source.IsInstance<StrObj>() || source.IsInstance<ErrorObj>() || source.IsInstance<FuncObj>() ||
1001+
source.IsInstance<UListObj>() || source.IsInstance<UDictObj>() || source.IsInstance<TensorObj>()) {
1002+
MLC_THROW(TypeError) << "TypeError: `__replace__` doesn't work on type: " << source.GetTypeKey();
1003+
}
1004+
struct Copier {
1005+
MLC_INLINE void operator()(MLCTypeField *f, const Any *any) { AddField(f->name, AnyView(*any)); }
1006+
MLC_INLINE void operator()(MLCTypeField *f, ObjectRef *obj) { AddField(f->name, AnyView(*obj)); }
1007+
MLC_INLINE void operator()(MLCTypeField *f, Optional<ObjectRef> *opt) { AddField(f->name, AnyView(*opt)); }
1008+
MLC_INLINE void operator()(MLCTypeField *f, Optional<bool> *opt) { AddField(f->name, AnyView(*opt)); }
1009+
MLC_INLINE void operator()(MLCTypeField *f, Optional<int64_t> *opt) { AddField(f->name, AnyView(*opt)); }
1010+
MLC_INLINE void operator()(MLCTypeField *f, Optional<double> *opt) { AddField(f->name, AnyView(*opt)); }
1011+
MLC_INLINE void operator()(MLCTypeField *f, Optional<DLDevice> *opt) { AddField(f->name, AnyView(*opt)); }
1012+
MLC_INLINE void operator()(MLCTypeField *f, Optional<DLDataType> *opt) { AddField(f->name, AnyView(*opt)); }
1013+
MLC_INLINE void operator()(MLCTypeField *f, bool *v) { AddField(f->name, AnyView(*v)); }
1014+
MLC_INLINE void operator()(MLCTypeField *f, int8_t *v) { AddField(f->name, AnyView(*v)); }
1015+
MLC_INLINE void operator()(MLCTypeField *f, int16_t *v) { AddField(f->name, AnyView(*v)); }
1016+
MLC_INLINE void operator()(MLCTypeField *f, int32_t *v) { AddField(f->name, AnyView(*v)); }
1017+
MLC_INLINE void operator()(MLCTypeField *f, int64_t *v) { AddField(f->name, AnyView(*v)); }
1018+
MLC_INLINE void operator()(MLCTypeField *f, float *v) { AddField(f->name, AnyView(*v)); }
1019+
MLC_INLINE void operator()(MLCTypeField *f, double *v) { AddField(f->name, AnyView(*v)); }
1020+
MLC_INLINE void operator()(MLCTypeField *f, DLDataType *v) { AddField(f->name, AnyView(*v)); }
1021+
MLC_INLINE void operator()(MLCTypeField *f, DLDevice *v) { AddField(f->name, AnyView(*v)); }
1022+
MLC_INLINE void operator()(MLCTypeField *f, Optional<void *> *v) { AddField(f->name, AnyView(*v)); }
1023+
MLC_INLINE void operator()(MLCTypeField *f, void **v) { AddField(f->name, AnyView(*v)); }
1024+
MLC_INLINE void operator()(MLCTypeField *f, const char **v) { AddField(f->name, AnyView(*v)); }
1025+
1026+
void AddField(std::string_view name, AnyView v) {
1027+
if (auto it = replacements->find(name); it != replacements->end()) {
1028+
fields->push_back(it->second);
1029+
} else {
1030+
fields->push_back(v);
1031+
}
1032+
}
1033+
std::vector<AnyView> *fields;
1034+
std::unordered_map<std::string_view, AnyView> *replacements;
1035+
};
1036+
std::unordered_map<std::string_view, AnyView> replacements;
1037+
for (int32_t i = 1; i < num_args; i += 2) {
1038+
const char *name = args[i];
1039+
replacements[name] = args[i + 1];
1040+
}
1041+
FuncObj *init_func = Lib::_init(type_index);
1042+
MLCTypeInfo *type_info = Lib::GetTypeInfo(type_index);
1043+
std::vector<AnyView> fields;
1044+
VisitFields(source.operator Object *(), type_info, Copier{&fields, &replacements});
1045+
::mlc::base::FuncCall(init_func, static_cast<int32_t>(fields.size()), fields.data(), ret);
1046+
}
1047+
9901048
inline Any CopyDeepImpl(AnyView source) {
9911049
if (::mlc::base::IsTypeIndexPOD(source.type_index)) {
9921050
return source;
@@ -1508,6 +1566,7 @@ int64_t StructuralHash(AnyView root) {
15081566

15091567
// Registry-facing shallow copy (exposed as "mlc.core.CopyShallow");
// all of the work happens in `CopyShallowImpl`.
Any CopyShallow(AnyView source) {
  return CopyShallowImpl(source);
}
15101568
// Registry-facing deep copy (exposed as "mlc.core.CopyDeep");
// all of the work happens in `CopyDeepImpl`.
Any CopyDeep(AnyView source) {
  return CopyDeepImpl(source);
}
1569+
// Registry-facing copy-with-replacement (exposed as "mlc.core.CopyReplace").
// Forwards the packed arguments untouched to `CopyReplaceImpl`, which writes
// the resulting object into `*ret`.
void CopyReplace(int32_t num_args, const AnyView *args, Any *ret) {
  CopyReplaceImpl(num_args, args, ret);
}
15111570

15121571
Any JSONLoads(AnyView json_str) {
15131572
if (json_str.type_index == kMLCRawStr) {

python/mlc/_cython/core.pyx

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,10 @@ cdef class PyAny:
389389
def _mlc_copy_deep(PyAny x) -> PyAny:
390390
return func_call(_COPY_DEEP, (x,))
391391

392+
@staticmethod
def _mlc_copy_replace(*args) -> PyAny:
    """Invoke the ``mlc.core.CopyReplace`` function with ``args`` passed through unchanged.

    The positional arguments are forwarded as a single tuple to ``func_call``.
    """
    packed_args = args
    return func_call(_COPY_REPLACE, packed_args)
395+
392396
@classmethod
393397
def _C(cls, bytes name, *args):
394398
cdef int32_t type_index = cls._mlc_type_info.type_index
@@ -1672,6 +1676,7 @@ cdef PyAny _STRUCUTRAL_EQUAL = func_get_untyped("mlc.core.StructuralEqual")
16721676
cdef PyAny _STRUCUTRAL_HASH = func_get_untyped("mlc.core.StructuralHash")
16731677
cdef PyAny _COPY_SHALLOW = func_get_untyped("mlc.core.CopyShallow")
16741678
cdef PyAny _COPY_DEEP = func_get_untyped("mlc.core.CopyDeep")
1679+
cdef PyAny _COPY_REPLACE = func_get_untyped("mlc.core.CopyReplace")
16751680
cdef PyAny _TENSOR_TO_DLPACK = func_get_untyped("mlc.core.TensorToDLPack")
16761681

16771682
cdef MLCVTableHandle _VTABLE_STR = _vtable_get_global(b"__str__")

python/mlc/core/object.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ def __copy__(self: Object) -> Object:
4141
def __deepcopy__(self: Object, memo: dict[int, Object] | None) -> Object:
4242
return PyAny._mlc_copy_deep(self)
4343

44+
def __replace__(self: Object, /, **changes: typing.Any) -> Object:
    """Return the result of ``PyAny._mlc_copy_replace`` for this object and ``changes``.

    The keyword arguments are flattened into a positional sequence of the form
    ``self, name_0, value_0, name_1, value_1, ...`` before being forwarded.
    """
    flat_args: list[typing.Any] = [self]
    for name_and_value in changes.items():
        # Each item is a (name, value) tuple; `extend` appends both in order.
        flat_args.extend(name_and_value)
    return PyAny._mlc_copy_replace(*flat_args)
50+
4451
def __hash__(self) -> int:
4552
return hash((type(self), self._mlc_address))
4653

python/mlc/dataclasses/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@
55
add_vtable_method,
66
field,
77
prototype,
8+
replace,
89
vtable_method,
910
)

python/mlc/dataclasses/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,3 +445,7 @@ def prototype(
445445
else:
446446
raise ValueError(f"Invalid `lang`: {lang}")
447447
return "\n\n".join(fn(i) for i in type_info_list)
448+
449+
450+
def replace(obj: Any, /, **changes: Any) -> Any:
    """Return ``obj.__replace__(**changes)``.

    Parameters
    ----------
    obj
        Any object that provides a ``__replace__`` method.
    **changes
        Keyword arguments forwarded verbatim to ``obj.__replace__``.

    Raises
    ------
    AttributeError
        If ``obj`` has no ``__replace__`` attribute.
    """
    replacer = obj.__replace__
    return replacer(**changes)

tests/python/test_dataclasses_copy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,3 +270,15 @@ def test_copy_deep_dataclass(test_obj: CustomInit) -> None:
270270
assert src != dst
271271
assert src.a == dst.a
272272
assert src.b == dst.b
273+
274+
275+
def test_copy_replace_dataclass(test_obj: CustomInit) -> None:
    """`mlc.dataclasses.replace` overrides `a`, carries `b` over, and leaves the source untouched."""
    original = test_obj
    replaced = mlc.dataclasses.replace(original, a=2)
    # The result is a distinct object that differs only in the replaced field.
    assert original != replaced
    assert original.a != replaced.a
    assert original.b == replaced.b
    # The source object keeps its initial values.
    assert original.a == 1
    assert original.b == "hello"
    # The new object carries the override plus the untouched field.
    assert replaced.a == 2
    assert replaced.b == "hello"

0 commit comments

Comments
 (0)