Skip to content

Commit 45fae3e

Browse files
authored
feat(core): Support serialization of opaque objects (#75)
Serialization/deserialization support for `mlc.Opaque`. Part of #73. For any MLC object, it now supports: - `.json(...)` method adds argument `fn_opaque_serialize`, which allows users to supply a customized serialization function; - `.from_json(...)` adds argument `fn_opaque_deserialize`, which allows users to supply a customized deserialization function. By default, the serialization and deserialization methods are: ```python @mlc.Func.register("mlc.Opaque.default.serialize") def _default_serialize(opaques: list[Any]) -> str: return jsonpickle.dumps(list(opaques)) @mlc.Func.register("mlc.Opaque.default.deserialize") def _default_deserialize(json_str: str) -> list[Any]: return jsonpickle.loads(json_str) ``` `jsonpickle` is not a perfect library - usually `pickle` or `cloudpickle` could be substantially better in terms of feature completeness, but it actually gives a balance between serializability and readability.
1 parent a8aa1d6 commit 45fae3e

File tree

10 files changed

+288
-91
lines changed

10 files changed

+288
-91
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name: CI
33
on: [push, pull_request]
44
env:
55
CIBW_BUILD_VERBOSITY: 3
6-
CIBW_TEST_REQUIRES: "pytest torch"
6+
CIBW_TEST_REQUIRES: "pytest torch jsonpickle"
77
CIBW_TEST_COMMAND: "pytest -svv --durations=20 {project}/tests/python/"
88
CIBW_ENVIRONMENT: "MLC_SHOW_CPP_STACKTRACES=1"
99
CIBW_REPAIR_WHEEL_COMMAND_LINUX: >

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ repos:
2525
rev: "v1.14.1"
2626
hooks:
2727
- id: mypy
28-
additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "torch"]
28+
additional_dependencies: ['numpy >= 1.22', "ml-dtypes >= 0.1", "pytest", "torch", "jsonpickle"]
2929
args: [--show-error-codes]
3030
- repo: https://github.com/pre-commit/mirrors-clang-format
3131
rev: "v19.1.6"

cpp/registry.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ namespace mlc {
1919
namespace registry {
2020

2121
Any JSONLoads(AnyView json_str);
22-
Any JSONDeserialize(AnyView json_str);
23-
Str JSONSerialize(AnyView source);
22+
Any JSONDeserialize(AnyView json_str, FuncObj *fn_opaque_deserialize);
23+
Str JSONSerialize(AnyView source, FuncObj *fn_opaque_serialize);
2424
bool StructuralEqual(AnyView lhs, AnyView rhs, bool bind_free_vars, bool assert_mode);
2525
int64_t StructuralHash(AnyView root);
2626
Optional<Str> StructuralEqualFailReason(AnyView lhs, AnyView rhs, bool bind_free_vars);

cpp/structure.cc

Lines changed: 141 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <ostream>
1111
#include <sstream>
1212
#include <stdexcept>
13+
#include <string>
1314
#include <unordered_map>
1415

1516
namespace mlc {
@@ -1300,8 +1301,11 @@ Tensor TensorFromBytes(const uint8_t *data_ptr, int64_t max_size) {
13001301

13011302
/****************** Serialize / Deserialize ******************/
13021303

1303-
inline mlc::Str Serialize(Any any) {
1304+
inline mlc::Str Serialize(Any any, FuncObj *fn_opaque_serialize) {
13041305
using mlc::base::TypeTraits;
1306+
// Section 1. Define two lookups
1307+
// 1) `type_keys` and `get_json_type_index`, which maps a `type_key` to type index/key to that in JSON
1308+
// 2) `opaques` and `get_opaque_index`, which maps an `OpaqueObj` to its index in the list of opaques
13051309
std::vector<const char *> type_keys;
13061310
auto get_json_type_index = [type_key2index = std::unordered_map<const char *, int32_t>(),
13071311
&type_keys](const char *type_key) mutable -> int32_t {
@@ -1313,8 +1317,27 @@ inline mlc::Str Serialize(Any any) {
13131317
type_keys.push_back(type_key);
13141318
return type_index;
13151319
};
1316-
using TObj2Idx = std::unordered_map<Object *, int32_t>;
1317-
using TJsonTypeIndex = decltype(get_json_type_index);
1320+
using TGetJSONTypeIndex = decltype(get_json_type_index);
1321+
struct OpaqueHash {
1322+
size_t operator()(const OpaqueObj *opaque) const { return std::hash<void *>{}(opaque->handle); }
1323+
};
1324+
struct OpaqueEq {
1325+
bool operator()(const OpaqueObj *lhs, const OpaqueObj *rhs) const { return lhs->handle == rhs->handle; }
1326+
};
1327+
UList opaques;
1328+
auto get_opaque_index = [opaque2index = std::unordered_map<const OpaqueObj *, int32_t, OpaqueHash, OpaqueEq>(),
1329+
&opaques](const OpaqueObj *opaque) mutable -> int32_t {
1330+
if (auto it = opaque2index.find(opaque); it != opaque2index.end()) {
1331+
return it->second;
1332+
}
1333+
int32_t type_index = static_cast<int32_t>(opaque2index.size());
1334+
opaque2index[opaque] = type_index;
1335+
opaques.push_back(opaque);
1336+
return type_index;
1337+
};
1338+
// Section 2. Define `Emitter`, which emits a singleton type of:
1339+
// - POD: bool, int, float, string, DLDataType, DLDevice, void*
1340+
// - Object that has been known previously
13181341
struct Emitter {
13191342
MLC_INLINE void operator()(MLCTypeField *, const Any *any) { EmitAny(any); }
13201343
// clang-format off
@@ -1381,57 +1404,77 @@ inline mlc::Str Serialize(Any any) {
13811404
if (!obj) {
13821405
MLC_THROW(InternalError) << "This should never happen: null object pointer during EmitObject";
13831406
}
1384-
int32_t obj_idx = obj2index->at(obj);
1407+
int32_t obj_idx = topo_index->at(obj);
13851408
if (obj_idx == -1) {
13861409
MLC_THROW(InternalError) << "This should never happen: topological ordering violated";
13871410
}
13881411
(*os) << ", " << obj_idx;
13891412
}
13901413
std::ostringstream *os;
1391-
TJsonTypeIndex *get_json_type_index;
1392-
const TObj2Idx *obj2index;
1414+
TGetJSONTypeIndex *get_json_type_index;
1415+
const std::unordered_map<Object *, int32_t> *topo_index;
13931416
};
1394-
1395-
std::unordered_map<Object *, int32_t> topo_indices;
1417+
// Section 3. Define `on_visit` method for topological traversal of the graph
1418+
// Inside `values` section, every `on_visit` generates one of the following:
1419+
// 1) string `s`: for string literal `s`
1420+
// 2) list -> normal case, its layout is:
1421+
// - [0] = json_type_index
1422+
// - [1...] = each field of the type
1423+
// * int: refer to `values[i]`
1424+
// * list: TODO: explain
1425+
// * str / bool / float / None: literals
13961426
std::vector<TensorObj *> tensors;
13971427
std::ostringstream os;
1398-
auto on_visit = [&topo_indices, get_json_type_index = &get_json_type_index, os = &os, &tensors,
1399-
is_first_object = true](Object *object, MLCTypeInfo *type_info) mutable -> void {
1428+
auto on_visit =
1429+
[get_json_type_index = &get_json_type_index, os = &os, &tensors, &get_opaque_index, is_first_object = true,
1430+
topo_indices = std::unordered_map<Object *, int32_t>()](Object *object, MLCTypeInfo *type_info) mutable -> void {
1431+
// Step 1. Allocate `topo_index` assigned to the current object
14001432
int32_t &topo_index = topo_indices[object];
14011433
if (topo_index == 0) {
14021434
topo_index = static_cast<int32_t>(topo_indices.size()) - 1;
14031435
} else {
14041436
MLC_THROW(InternalError) << "This should never happen: object already visited";
14051437
}
1406-
Emitter emitter{os, get_json_type_index, &topo_indices};
14071438
if (is_first_object) {
14081439
is_first_object = false;
14091440
} else {
14101441
os->put(',');
14111442
}
1443+
// Step 2. Print the current object
1444+
// Special case: string
14121445
if (StrObj *str = object->as<StrObj>()) {
14131446
str->PrintEscape(*os);
14141447
return;
14151448
}
1449+
// [0] = json_type_index
14161450
(*os) << '[' << (*get_json_type_index)(type_info->type_key);
1451+
// [1...] = each field of the type. A few possible cases:
1452+
// 1) list
1453+
// 2) dict
1454+
// 3) tensor
1455+
// 4) opaque
1456+
// 5) a normal dataclass
14171457
if (UListObj *list = object->as<UListObj>()) {
1458+
Emitter emitter{os, get_json_type_index, &topo_indices};
14181459
for (Any &any : *list) {
1419-
emitter(nullptr, &any);
1460+
emitter.EmitAny(&any);
14201461
}
14211462
} else if (UDictObj *dict = object->as<UDictObj>()) {
1463+
Emitter emitter{os, get_json_type_index, &topo_indices};
14221464
for (auto &kv : *dict) {
1423-
emitter(nullptr, &kv.first);
1424-
emitter(nullptr, &kv.second);
1465+
emitter.EmitAny(&kv.first);
1466+
emitter.EmitAny(&kv.second);
14251467
}
14261468
} else if (TensorObj *tensor = object->as<TensorObj>()) {
14271469
(*os) << ", " << tensors.size();
14281470
tensors.push_back(tensor);
1471+
} else if (OpaqueObj *opaque = object->as<OpaqueObj>()) {
1472+
int32_t opaque_index = get_opaque_index(opaque);
1473+
(*os) << ", " << opaque_index;
14291474
} else if (object->IsInstance<FuncObj>() || object->IsInstance<ErrorObj>()) {
14301475
MLC_THROW(TypeError) << "Unserializable type: " << object->GetTypeKey();
1431-
} else if (object->IsInstance<OpaqueObj>()) {
1432-
MLC_THROW(TypeError) << "Cannot serialize `mlc.Opaque` of type: "
1433-
<< object->DynCast<OpaqueObj>()->opaque_type_name;
14341476
} else {
1477+
Emitter emitter{os, get_json_type_index, &topo_indices};
14351478
VisitFields(object, type_info, emitter);
14361479
}
14371480
os->put(']');
@@ -1481,12 +1524,25 @@ inline mlc::Str Serialize(Any any) {
14811524
}
14821525
os << "]";
14831526
}
1527+
if (!opaques.empty()) {
1528+
os << ", \"opaques\":";
1529+
if (!fn_opaque_serialize) {
1530+
fn_opaque_serialize = Func::GetGlobal("mlc.Opaque.default.serialize", true);
1531+
}
1532+
if (!fn_opaque_serialize) {
1533+
MLC_THROW(ValueError) << "Cannot find serialization function `mlc.Opaque.default.serialize`. Register it with "
1534+
"`mlc.Func.register(\"mlc.Opaque.default.serialize\")(serialize_func)`";
1535+
}
1536+
Str opaque_repr = (*fn_opaque_serialize)(opaques);
1537+
opaque_repr->PrintEscape(os);
1538+
}
14841539
os << "}";
14851540
return os.str();
14861541
}
14871542

1488-
inline Any Deserialize(const char *json_str, int64_t json_str_len) {
1543+
inline Any Deserialize(const char *json_str, int64_t json_str_len, FuncObj *fn_opaque_deserialize) {
14891544
int32_t json_type_index_tensor = -1;
1545+
int32_t json_type_index_opaque = -1;
14901546
// Step 0. Parse JSON string
14911547
UDict json_obj = JSONLoads(json_str, json_str_len);
14921548
// Step 1. type_key => constructors
@@ -1496,10 +1552,12 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len) {
14961552
for (Str type_key : type_keys) {
14971553
int32_t type_index = Lib::GetTypeIndex(type_key->data());
14981554
FuncObj *func = nullptr;
1499-
if (type_index != kMLCTensor) {
1500-
func = Lib::_init(type_index);
1501-
} else {
1555+
if (type_index == kMLCTensor) {
15021556
json_type_index_tensor = static_cast<int32_t>(constructors.size());
1557+
} else if (type_index == kMLCOpaque) {
1558+
json_type_index_opaque = static_cast<int32_t>(constructors.size());
1559+
} else {
1560+
func = Lib::_init(type_index);
15031561
}
15041562
constructors.push_back(func);
15051563
}
@@ -1522,45 +1580,71 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len) {
15221580
json_obj->erase("tensors");
15231581
std::reverse(tensors.begin(), tensors.end());
15241582
}
1525-
// Step 3. Translate JSON object to objects
1583+
// Step 3. Handle opaque objects
1584+
UList opaques;
1585+
if (json_obj.count("opaques")) {
1586+
if (!fn_opaque_deserialize) {
1587+
fn_opaque_deserialize = Func::GetGlobal("mlc.Opaque.default.deserialize", true);
1588+
}
1589+
if (!fn_opaque_deserialize) {
1590+
MLC_THROW(ValueError)
1591+
<< "Cannot find deserialization function `mlc.Opaque.default.deserialize`. Register it with "
1592+
"`mlc.Func.register(\"mlc.Opaque.default.deserialize\")(deserialize_func)`";
1593+
}
1594+
opaques = (*fn_opaque_deserialize)(json_obj->at("opaques")).operator UList();
1595+
}
1596+
// Step 4. Translate JSON object to objects
15261597
UList values = json_obj->at("values");
15271598
for (int64_t i = 0; i < values->size(); ++i) {
1528-
Any obj = values[i];
1529-
if (obj.type_index == kMLCList) {
1530-
UList list = obj.operator UList();
1531-
int32_t json_type_index = list[0];
1599+
Any &value = values[i];
1600+
// every `value` is
1601+
// 1) integer `i` -> refer to `values[i]`
1602+
// 2) string `s` -> string literal `s`
1603+
// 3) list -> normal case, its layout is:
1604+
// - [0] = json_type_index
1605+
// - [1...] = each field of the type
1606+
// * int: refer to `values[i]`
1607+
// * list: TODO: explain
1608+
// * str / bool / float / None: literals
1609+
// TODO: how about kMLCBool, kMLCFloat, kMLCNone?
1610+
if (UListObj *list = value.as<UListObj>()) {
1611+
// Layout of the list:
1612+
int32_t json_type_index = (*list)[0];
15321613
if (json_type_index == json_type_index_tensor) {
1533-
values[i] = tensors[list[1].operator int32_t()];
1534-
continue;
1535-
}
1536-
for (int64_t j = 1; j < list.size(); ++j) {
1537-
Any arg = list[j];
1538-
if (arg.type_index == kMLCInt) {
1539-
int64_t k = arg;
1540-
if (k < i) {
1541-
list[j] = values[k];
1614+
int32_t idx = (*list)[1];
1615+
value = tensors[idx];
1616+
} else if (json_type_index == json_type_index_opaque) {
1617+
int32_t idx = (*list)[1];
1618+
value = opaques[idx];
1619+
} else {
1620+
for (int64_t j = 1; j < list->size(); ++j) {
1621+
Any arg = (*list)[j];
1622+
if (arg.type_index == kMLCInt) {
1623+
int64_t k = arg;
1624+
if (k < i) {
1625+
(*list)[j] = values[k];
1626+
} else {
1627+
MLC_THROW(ValueError) << "Invalid reference when parsing type `" << type_keys[json_type_index]
1628+
<< "`: referring #" << k << " at #" << i << ". v = " << value;
1629+
}
1630+
} else if (arg.type_index == kMLCList) {
1631+
(*list)[j] = invoke_init(arg.operator UList());
1632+
} else if (arg.type_index == kMLCStr || arg.type_index == kMLCBool || arg.type_index == kMLCFloat ||
1633+
arg.type_index == kMLCNone) {
1634+
// Do nothing
15421635
} else {
1543-
MLC_THROW(ValueError) << "Invalid reference when parsing type `" << type_keys[json_type_index]
1544-
<< "`: referring #" << k << " at #" << i << ". v = " << obj;
1636+
MLC_THROW(ValueError) << "Unexpected value: " << arg;
15451637
}
1546-
} else if (arg.type_index == kMLCList) {
1547-
list[j] = invoke_init(arg.operator UList());
1548-
} else if (arg.type_index == kMLCStr || arg.type_index == kMLCBool || arg.type_index == kMLCFloat ||
1549-
arg.type_index == kMLCNone) {
1550-
// Do nothing
1551-
} else {
1552-
MLC_THROW(ValueError) << "Unexpected value: " << arg;
15531638
}
1639+
value = invoke_init(UList(list));
15541640
}
1555-
values[i] = invoke_init(list);
1556-
} else if (obj.type_index == kMLCInt) {
1557-
int32_t k = obj;
1558-
values[i] = values[k];
1559-
} else if (obj.type_index == kMLCStr) {
1641+
} else if (value.type_index == kMLCInt) {
1642+
int32_t k = value;
1643+
value = values[k];
1644+
} else if (value.type_index == kMLCStr) {
15601645
// Do nothing
1561-
// TODO: how about kMLCBool, kMLCFloat, kMLCNone?
15621646
} else {
1563-
MLC_THROW(ValueError) << "Unexpected value: " << obj;
1647+
MLC_THROW(ValueError) << "Unexpected value: " << value;
15641648
}
15651649
}
15661650
return values->back();
@@ -1617,16 +1701,18 @@ Any JSONLoads(AnyView json_str) {
16171701
}
16181702
}
16191703

1620-
Any JSONDeserialize(AnyView json_str) {
1704+
Any JSONDeserialize(AnyView json_str, FuncObj *fn_opaque_deserialize) {
16211705
if (json_str.type_index == kMLCRawStr) {
1622-
return ::mlc::Deserialize(json_str.operator const char *(), -1);
1706+
return ::mlc::Deserialize(json_str.operator const char *(), -1, fn_opaque_deserialize);
16231707
} else {
16241708
StrObj *js = json_str.operator StrObj *();
1625-
return ::mlc::Deserialize(js->data(), js->size());
1709+
return ::mlc::Deserialize(js->data(), js->size(), fn_opaque_deserialize);
16261710
}
16271711
}
16281712

1629-
Str JSONSerialize(AnyView source) { return ::mlc::Serialize(source); }
1713+
Str JSONSerialize(AnyView source, FuncObj *fn_opaque_serialize) {
1714+
return ::mlc::Serialize(source, fn_opaque_serialize);
1715+
}
16301716

16311717
Str TensorToBytes(const TensorObj *src) {
16321718
return ::mlc::TensorToBytes(&src->tensor); //

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ authors = [{ name = "MLC Authors", email = "[email protected]" }]
2525
"mlc.config" = "mlc.config:main"
2626

2727
[project.optional-dependencies]
28-
tests = ['pytest', 'torch']
28+
tests = ['pytest', 'torch', 'jsonpickle']
2929
dev = [
3030
"cython>=3.1",
3131
"pre-commit",
@@ -35,6 +35,7 @@ dev = [
3535
"ruff",
3636
"mypy",
3737
"torch",
38+
"jsonpickle",
3839
]
3940

4041
[build-system]

python/mlc/_cython/core.pyx

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -360,25 +360,25 @@ cdef class PyAny:
360360
return (base.new_object, (type(self),), self.__getstate__())
361361

362362
def __getstate__(self):
363-
return {"mlc_json": func_call(_SERIALIZE, (self,))}
363+
return {"mlc_json": func_call(_SERIALIZE, (self, None))}
364364

365365
def __setstate__(self, state):
366-
cdef PyAny ret = func_call(_DESERIALIZE, (state["mlc_json"], ))
366+
cdef PyAny ret = func_call(_DESERIALIZE, (state["mlc_json"], None))
367367
cdef MLCAny tmp = self._mlc_any
368368
self._mlc_any = ret._mlc_any
369369
ret._mlc_any = tmp
370370

371-
def _mlc_json(self):
372-
return func_call(_SERIALIZE, (self,))
371+
def _mlc_json(self, fn_opaque_serialize):
372+
return func_call(_SERIALIZE, (self, fn_opaque_serialize))
373373

374374
def _mlc_swap(self, PyAny other):
375375
cdef MLCAny tmp = self._mlc_any
376376
self._mlc_any = other._mlc_any
377377
other._mlc_any = tmp
378378

379379
@staticmethod
380-
def _mlc_from_json(mlc_json):
381-
return func_call(_DESERIALIZE, (mlc_json,))
380+
def _mlc_from_json(mlc_json, fn_opaque_deserialize):
381+
return func_call(_DESERIALIZE, (mlc_json, fn_opaque_deserialize))
382382

383383
@staticmethod
384384
def _mlc_eq_s(PyAny lhs, PyAny rhs, bint bind_free_vars, bint assert_mode) -> bool:
@@ -1442,8 +1442,8 @@ cpdef void func_init(PyAny self, object callable):
14421442
self._mlc_any = ret._mlc_any
14431443
ret._mlc_any = _MLCAnyNone()
14441444

1445-
cpdef void opaque_init(PyAny self, object callable):
1446-
cdef PyAny ret = _pyany_from_opaque(callable)
1445+
cpdef void opaque_init(PyAny self, object opaque):
1446+
cdef PyAny ret = _pyany_from_opaque(opaque)
14471447
self._mlc_any = ret._mlc_any
14481448
ret._mlc_any = _MLCAnyNone()
14491449

0 commit comments

Comments
 (0)