1010#include < ostream>
1111#include < sstream>
1212#include < stdexcept>
13+ #include < string>
1314#include < unordered_map>
1415
1516namespace mlc {
@@ -1300,8 +1301,11 @@ Tensor TensorFromBytes(const uint8_t *data_ptr, int64_t max_size) {
13001301
13011302/* ***************** Serialize / Deserialize ******************/
13021303
1303- inline mlc::Str Serialize (Any any) {
1304+ inline mlc::Str Serialize (Any any, FuncObj *fn_opaque_serialize ) {
13041305 using mlc::base::TypeTraits;
1306+ // Section 1. Define two lookups
1307+ // 1) `type_keys` and `get_json_type_index`, which map a `type_key` to its type index in the JSON
1308+ // 2) `opaques` and `get_opaque_index`, which map an `OpaqueObj` to its index in the list of opaques
13051309 std::vector<const char *> type_keys;
13061310 auto get_json_type_index = [type_key2index = std::unordered_map<const char *, int32_t >(),
13071311 &type_keys](const char *type_key) mutable -> int32_t {
@@ -1313,8 +1317,27 @@ inline mlc::Str Serialize(Any any) {
13131317 type_keys.push_back (type_key);
13141318 return type_index;
13151319 };
1316- using TObj2Idx = std::unordered_map<Object *, int32_t >;
1317- using TJsonTypeIndex = decltype (get_json_type_index);
1320+ using TGetJSONTypeIndex = decltype (get_json_type_index);
1321+ struct OpaqueHash {
1322+ size_t operator ()(const OpaqueObj *opaque) const { return std::hash<void *>{}(opaque->handle ); }
1323+ };
1324+ struct OpaqueEq {
1325+ bool operator ()(const OpaqueObj *lhs, const OpaqueObj *rhs) const { return lhs->handle == rhs->handle ; }
1326+ };
1327+ UList opaques;
1328+ auto get_opaque_index = [opaque2index = std::unordered_map<const OpaqueObj *, int32_t , OpaqueHash, OpaqueEq>(),
1329+ &opaques](const OpaqueObj *opaque) mutable -> int32_t {
1330+ if (auto it = opaque2index.find (opaque); it != opaque2index.end ()) {
1331+ return it->second ;
1332+ }
1333+ int32_t type_index = static_cast <int32_t >(opaque2index.size ());
1334+ opaque2index[opaque] = type_index;
1335+ opaques.push_back (opaque);
1336+ return type_index;
1337+ };
1338+ // Section 2. Define `Emitter`, which emits a single value of one of the following kinds:
1339+ // - POD: bool, int, float, string, DLDataType, DLDevice, void*
1340+ // - Object that has been known previously
13181341 struct Emitter {
13191342 MLC_INLINE void operator ()(MLCTypeField *, const Any *any) { EmitAny (any); }
13201343 // clang-format off
@@ -1381,57 +1404,77 @@ inline mlc::Str Serialize(Any any) {
13811404 if (!obj) {
13821405 MLC_THROW (InternalError) << " This should never happen: null object pointer during EmitObject" ;
13831406 }
1384- int32_t obj_idx = obj2index ->at (obj);
1407+ int32_t obj_idx = topo_index ->at (obj);
13851408 if (obj_idx == -1 ) {
13861409 MLC_THROW (InternalError) << " This should never happen: topological ordering violated" ;
13871410 }
13881411 (*os) << " , " << obj_idx;
13891412 }
13901413 std::ostringstream *os;
1391- TJsonTypeIndex *get_json_type_index;
1392- const TObj2Idx *obj2index ;
1414+ TGetJSONTypeIndex *get_json_type_index;
1415+ const std::unordered_map<Object *, int32_t > *topo_index ;
13931416 };
1394-
1395- std::unordered_map<Object *, int32_t > topo_indices;
1417+ // Section 3. Define `on_visit` method for topological traversal of the graph
1418+ // Inside the `values` section, every `on_visit` call generates one of the following:
1419+ // 1) string `s`: for string literal `s`
1420+ // 2) list -> normal case, its layout is:
1421+ // - [0] = json_type_index
1422+ // - [1...] = each field of the type
1423+ // * int: refer to `values[i]`
1424+ // * list: TODO: explain
1425+ // * str / bool / float / None: literals
13961426 std::vector<TensorObj *> tensors;
13971427 std::ostringstream os;
1398- auto on_visit = [&topo_indices, get_json_type_index = &get_json_type_index, os = &os, &tensors,
1399- is_first_object = true ](Object *object, MLCTypeInfo *type_info) mutable -> void {
1428+ auto on_visit =
1429+ [get_json_type_index = &get_json_type_index, os = &os, &tensors, &get_opaque_index, is_first_object = true ,
1430+ topo_indices = std::unordered_map<Object *, int32_t >()](Object *object, MLCTypeInfo *type_info) mutable -> void {
1431+ // Step 1. Allocate `topo_index` assigned to the current object
14001432 int32_t &topo_index = topo_indices[object];
14011433 if (topo_index == 0 ) {
14021434 topo_index = static_cast <int32_t >(topo_indices.size ()) - 1 ;
14031435 } else {
14041436 MLC_THROW (InternalError) << " This should never happen: object already visited" ;
14051437 }
1406- Emitter emitter{os, get_json_type_index, &topo_indices};
14071438 if (is_first_object) {
14081439 is_first_object = false ;
14091440 } else {
14101441 os->put (' ,' );
14111442 }
1443+ // Step 2. Print the current object
1444+ // Special case: string
14121445 if (StrObj *str = object->as <StrObj>()) {
14131446 str->PrintEscape (*os);
14141447 return ;
14151448 }
1449+ // [0] = json_type_index
14161450 (*os) << ' [' << (*get_json_type_index)(type_info->type_key );
1451+ // [1...] = each field of the type. A few possible cases:
1452+ // 1) list
1453+ // 2) dict
1454+ // 3) tensor
1455+ // 4) opaque
1456+ // 5) a normal dataclass
14171457 if (UListObj *list = object->as <UListObj>()) {
1458+ Emitter emitter{os, get_json_type_index, &topo_indices};
14181459 for (Any &any : *list) {
1419- emitter ( nullptr , &any);
1460+ emitter. EmitAny ( &any);
14201461 }
14211462 } else if (UDictObj *dict = object->as <UDictObj>()) {
1463+ Emitter emitter{os, get_json_type_index, &topo_indices};
14221464 for (auto &kv : *dict) {
1423- emitter ( nullptr , &kv.first );
1424- emitter ( nullptr , &kv.second );
1465+ emitter. EmitAny ( &kv.first );
1466+ emitter. EmitAny ( &kv.second );
14251467 }
14261468 } else if (TensorObj *tensor = object->as <TensorObj>()) {
14271469 (*os) << " , " << tensors.size ();
14281470 tensors.push_back (tensor);
1471+ } else if (OpaqueObj *opaque = object->as <OpaqueObj>()) {
1472+ int32_t opaque_index = get_opaque_index (opaque);
1473+ (*os) << " , " << opaque_index;
14291474 } else if (object->IsInstance <FuncObj>() || object->IsInstance <ErrorObj>()) {
14301475 MLC_THROW (TypeError) << " Unserializable type: " << object->GetTypeKey ();
1431- } else if (object->IsInstance <OpaqueObj>()) {
1432- MLC_THROW (TypeError) << " Cannot serialize `mlc.Opaque` of type: "
1433- << object->DynCast <OpaqueObj>()->opaque_type_name ;
14341476 } else {
1477+ Emitter emitter{os, get_json_type_index, &topo_indices};
14351478 VisitFields (object, type_info, emitter);
14361479 }
14371480 os->put (' ]' );
@@ -1481,12 +1524,25 @@ inline mlc::Str Serialize(Any any) {
14811524 }
14821525 os << " ]" ;
14831526 }
1527+ if (!opaques.empty ()) {
1528+ os << " , \" opaques\" :" ;
1529+ if (!fn_opaque_serialize) {
1530+ fn_opaque_serialize = Func::GetGlobal (" mlc.Opaque.default.serialize" , true );
1531+ }
1532+ if (!fn_opaque_serialize) {
1533+ MLC_THROW (ValueError) << " Cannot find serialization function `mlc.Opaque.default.serialize`. Register it with "
1534+ " `mlc.Func.register(\" mlc.Opaque.default.serialize\" )(serialize_func)`" ;
1535+ }
1536+ Str opaque_repr = (*fn_opaque_serialize)(opaques);
1537+ opaque_repr->PrintEscape (os);
1538+ }
14841539 os << " }" ;
14851540 return os.str ();
14861541}
14871542
1488- inline Any Deserialize (const char *json_str, int64_t json_str_len) {
1543+ inline Any Deserialize (const char *json_str, int64_t json_str_len, FuncObj *fn_opaque_deserialize ) {
14891544 int32_t json_type_index_tensor = -1 ;
1545+ int32_t json_type_index_opaque = -1 ;
14901546 // Step 0. Parse JSON string
14911547 UDict json_obj = JSONLoads (json_str, json_str_len);
14921548 // Step 1. type_key => constructors
@@ -1496,10 +1552,12 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len) {
14961552 for (Str type_key : type_keys) {
14971553 int32_t type_index = Lib::GetTypeIndex (type_key->data ());
14981554 FuncObj *func = nullptr ;
1499- if (type_index != kMLCTensor ) {
1500- func = Lib::_init (type_index);
1501- } else {
1555+ if (type_index == kMLCTensor ) {
15021556 json_type_index_tensor = static_cast <int32_t >(constructors.size ());
1557+ } else if (type_index == kMLCOpaque ) {
1558+ json_type_index_opaque = static_cast <int32_t >(constructors.size ());
1559+ } else {
1560+ func = Lib::_init (type_index);
15031561 }
15041562 constructors.push_back (func);
15051563 }
@@ -1522,45 +1580,71 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len) {
15221580 json_obj->erase (" tensors" );
15231581 std::reverse (tensors.begin (), tensors.end ());
15241582 }
1525- // Step 3. Translate JSON object to objects
1583+ // Step 3. Handle opaque objects
1584+ UList opaques;
1585+ if (json_obj.count (" opaques" )) {
1586+ if (!fn_opaque_deserialize) {
1587+ fn_opaque_deserialize = Func::GetGlobal (" mlc.Opaque.default.deserialize" , true );
1588+ }
1589+ if (!fn_opaque_deserialize) {
1590+ MLC_THROW (ValueError)
1591+ << " Cannot find deserialization function `mlc.Opaque.default.deserialize`. Register it with "
1592+ " `mlc.Func.register(\" mlc.Opaque.default.deserialize\" )(deserialize_func)`" ;
1593+ }
1594+ opaques = (*fn_opaque_deserialize)(json_obj->at (" opaques" )).operator UList ();
1595+ }
1596+ // Step 4. Translate JSON object to objects
15261597 UList values = json_obj->at (" values" );
15271598 for (int64_t i = 0 ; i < values->size (); ++i) {
1528- Any obj = values[i];
1529- if (obj.type_index == kMLCList ) {
1530- UList list = obj.operator UList ();
1531- int32_t json_type_index = list[0 ];
1599+ Any &value = values[i];
1600+ // every `value` is
1601+ // 1) integer `i` -> refer to `values[i]`
1602+ // 2) string `s` -> string literal `s`
1603+ // 3) list -> normal case, its layout is:
1604+ // - [0] = json_type_index
1605+ // - [1...] = each field of the type
1606+ // * int: refer to `values[i]`
1607+ // * list: TODO: explain
1608+ // * str / bool / float / None: literals
1609+ // TODO: how about kMLCBool, kMLCFloat, kMLCNone?
1610+ if (UListObj *list = value.as <UListObj>()) {
1611+ // Layout of the list:
1612+ int32_t json_type_index = (*list)[0 ];
15321613 if (json_type_index == json_type_index_tensor) {
1533- values[i] = tensors[list[1 ].operator int32_t ()];
1534- continue ;
1535- }
1536- for (int64_t j = 1 ; j < list.size (); ++j) {
1537- Any arg = list[j];
1538- if (arg.type_index == kMLCInt ) {
1539- int64_t k = arg;
1540- if (k < i) {
1541- list[j] = values[k];
1614+ int32_t idx = (*list)[1 ];
1615+ value = tensors[idx];
1616+ } else if (json_type_index == json_type_index_opaque) {
1617+ int32_t idx = (*list)[1 ];
1618+ value = opaques[idx];
1619+ } else {
1620+ for (int64_t j = 1 ; j < list->size (); ++j) {
1621+ Any arg = (*list)[j];
1622+ if (arg.type_index == kMLCInt ) {
1623+ int64_t k = arg;
1624+ if (k < i) {
1625+ (*list)[j] = values[k];
1626+ } else {
1627+ MLC_THROW (ValueError) << " Invalid reference when parsing type `" << type_keys[json_type_index]
1628+ << " `: referring #" << k << " at #" << i << " . v = " << value;
1629+ }
1630+ } else if (arg.type_index == kMLCList ) {
1631+ (*list)[j] = invoke_init (arg.operator UList ());
1632+ } else if (arg.type_index == kMLCStr || arg.type_index == kMLCBool || arg.type_index == kMLCFloat ||
1633+ arg.type_index == kMLCNone ) {
1634+ // Do nothing
15421635 } else {
1543- MLC_THROW (ValueError) << " Invalid reference when parsing type `" << type_keys[json_type_index]
1544- << " `: referring #" << k << " at #" << i << " . v = " << obj;
1636+ MLC_THROW (ValueError) << " Unexpected value: " << arg;
15451637 }
1546- } else if (arg.type_index == kMLCList ) {
1547- list[j] = invoke_init (arg.operator UList ());
1548- } else if (arg.type_index == kMLCStr || arg.type_index == kMLCBool || arg.type_index == kMLCFloat ||
1549- arg.type_index == kMLCNone ) {
1550- // Do nothing
1551- } else {
1552- MLC_THROW (ValueError) << " Unexpected value: " << arg;
15531638 }
1639+ value = invoke_init (UList (list));
15541640 }
1555- values[i] = invoke_init (list);
1556- } else if (obj.type_index == kMLCInt ) {
1557- int32_t k = obj;
1558- values[i] = values[k];
1559- } else if (obj.type_index == kMLCStr ) {
1641+ } else if (value.type_index == kMLCInt ) {
1642+ int32_t k = value;
1643+ value = values[k];
1644+ } else if (value.type_index == kMLCStr ) {
15601645 // Do nothing
1561- // TODO: how about kMLCBool, kMLCFloat, kMLCNone?
15621646 } else {
1563- MLC_THROW (ValueError) << " Unexpected value: " << obj ;
1647+ MLC_THROW (ValueError) << " Unexpected value: " << value ;
15641648 }
15651649 }
15661650 return values->back ();
@@ -1617,16 +1701,18 @@ Any JSONLoads(AnyView json_str) {
16171701 }
16181702}
16191703
// JSONDeserialize: thin entry point that forwards a JSON string to ::mlc::Deserialize.
//   json_str              - either a raw C string (`kMLCRawStr`) or a `StrObj`.
//   fn_opaque_deserialize - forwarded unchanged; when it is null and the JSON has an
//                           "opaques" section, ::mlc::Deserialize falls back to the global
//                           `mlc.Opaque.default.deserialize`.
// Returns whatever ::mlc::Deserialize returns.
// (Diff view: `-` lines show the previous signature/calls, `+` lines the new ones.)
1620- Any JSONDeserialize (AnyView json_str) {
1704+ Any JSONDeserialize (AnyView json_str, FuncObj *fn_opaque_deserialize ) {
16211705 if (json_str.type_index == kMLCRawStr ) {
// Raw C string: no length is known here, so -1 is passed as `json_str_len`
// (presumably meaning "NUL-terminated" -- confirm against JSONLoads).
1622- return ::mlc::Deserialize (json_str.operator const char *(), -1 );
1706+ return ::mlc::Deserialize (json_str.operator const char *(), -1 , fn_opaque_deserialize );
16231707 } else {
// Otherwise the view is converted to a StrObj, whose data pointer and size are passed explicitly.
16241708 StrObj *js = json_str.operator StrObj *();
1625- return ::mlc::Deserialize (js->data (), js->size ());
1709+ return ::mlc::Deserialize (js->data (), js->size (), fn_opaque_deserialize );
16261710 }
16271711}
16281712
// JSONSerialize: thin entry point that forwards `source` to ::mlc::Serialize and returns its Str result.
//   fn_opaque_serialize - forwarded unchanged; when it is null and opaque objects were
//                         collected, ::mlc::Serialize falls back to the global
//                         `mlc.Opaque.default.serialize`.
// (Diff view: the `-` line is the previous one-argument form, `+` lines the new form.)
1629- Str JSONSerialize (AnyView source) { return ::mlc::Serialize (source); }
1713+ Str JSONSerialize (AnyView source, FuncObj *fn_opaque_serialize) {
1714+ return ::mlc::Serialize (source, fn_opaque_serialize);
1715+ }
16301716
16311717Str TensorToBytes (const TensorObj *src) {
16321718 return ::mlc::TensorToBytes (&src->tensor ); //
0 commit comments