10
10
#include < ostream>
11
11
#include < sstream>
12
12
#include < stdexcept>
13
+ #include < string>
13
14
#include < unordered_map>
14
15
15
16
namespace mlc {
@@ -1300,8 +1301,11 @@ Tensor TensorFromBytes(const uint8_t *data_ptr, int64_t max_size) {
1300
1301
1301
1302
/* ***************** Serialize / Deserialize ******************/
1302
1303
1303
- inline mlc::Str Serialize (Any any) {
1304
+ inline mlc::Str Serialize (Any any, FuncObj *fn_opaque_serialize ) {
1304
1305
using mlc::base::TypeTraits;
1306
+ // Section 1. Define two lookups
1307
+ // 1) `type_keys` and `get_json_type_index`, which maps a `type_key` to type index/key to that in JSON
1308
+ // 2) `opaques` and `get_opaque_index`, which maps an `OpaqueObj` to its index in the list of opaques
1305
1309
std::vector<const char *> type_keys;
1306
1310
auto get_json_type_index = [type_key2index = std::unordered_map<const char *, int32_t >(),
1307
1311
&type_keys](const char *type_key) mutable -> int32_t {
@@ -1313,8 +1317,27 @@ inline mlc::Str Serialize(Any any) {
1313
1317
type_keys.push_back (type_key);
1314
1318
return type_index;
1315
1319
};
1316
- using TObj2Idx = std::unordered_map<Object *, int32_t >;
1317
- using TJsonTypeIndex = decltype (get_json_type_index);
1320
+ using TGetJSONTypeIndex = decltype (get_json_type_index);
1321
+ struct OpaqueHash {
1322
+ size_t operator ()(const OpaqueObj *opaque) const { return std::hash<void *>{}(opaque->handle ); }
1323
+ };
1324
+ struct OpaqueEq {
1325
+ bool operator ()(const OpaqueObj *lhs, const OpaqueObj *rhs) const { return lhs->handle == rhs->handle ; }
1326
+ };
1327
+ UList opaques;
1328
+ auto get_opaque_index = [opaque2index = std::unordered_map<const OpaqueObj *, int32_t , OpaqueHash, OpaqueEq>(),
1329
+ &opaques](const OpaqueObj *opaque) mutable -> int32_t {
1330
+ if (auto it = opaque2index.find (opaque); it != opaque2index.end ()) {
1331
+ return it->second ;
1332
+ }
1333
+ int32_t type_index = static_cast <int32_t >(opaque2index.size ());
1334
+ opaque2index[opaque] = type_index;
1335
+ opaques.push_back (opaque);
1336
+ return type_index;
1337
+ };
1338
+ // Section 2. Define `Emitter`, which emits a singleton type of:
1339
+ // - POD: bool, int, float, string, DLDataType, DLDevice, void*
1340
+ // - Object that has been known previously
1318
1341
struct Emitter {
1319
1342
MLC_INLINE void operator ()(MLCTypeField *, const Any *any) { EmitAny (any); }
1320
1343
// clang-format off
@@ -1381,57 +1404,77 @@ inline mlc::Str Serialize(Any any) {
1381
1404
if (!obj) {
1382
1405
MLC_THROW (InternalError) << " This should never happen: null object pointer during EmitObject" ;
1383
1406
}
1384
- int32_t obj_idx = obj2index ->at (obj);
1407
+ int32_t obj_idx = topo_index ->at (obj);
1385
1408
if (obj_idx == -1 ) {
1386
1409
MLC_THROW (InternalError) << " This should never happen: topological ordering violated" ;
1387
1410
}
1388
1411
(*os) << " , " << obj_idx;
1389
1412
}
1390
1413
std::ostringstream *os;
1391
- TJsonTypeIndex *get_json_type_index;
1392
- const TObj2Idx *obj2index ;
1414
+ TGetJSONTypeIndex *get_json_type_index;
1415
+ const std::unordered_map<Object *, int32_t > *topo_index ;
1393
1416
};
1394
-
1395
- std::unordered_map<Object *, int32_t > topo_indices;
1417
+ // Section 3. Define `on_visit` method for topological traversal of the graph
1418
+ // Inside `values` section, every `on_visit` generates one of the following:
1419
+ // 1) string `s`: for string literal `s`
1420
+ // 2) list -> normal case, its layout is:
1421
+ // - [0] = json_type_index
1422
+ // - [1...] = each field of the type
1423
+ // * int: refer to `values[i]`
1424
+ // * list: TODO: explain
1425
+ // * str / bool / float / None: literals
1396
1426
std::vector<TensorObj *> tensors;
1397
1427
std::ostringstream os;
1398
- auto on_visit = [&topo_indices, get_json_type_index = &get_json_type_index, os = &os, &tensors,
1399
- is_first_object = true ](Object *object, MLCTypeInfo *type_info) mutable -> void {
1428
+ auto on_visit =
1429
+ [get_json_type_index = &get_json_type_index, os = &os, &tensors, &get_opaque_index, is_first_object = true ,
1430
+ topo_indices = std::unordered_map<Object *, int32_t >()](Object *object, MLCTypeInfo *type_info) mutable -> void {
1431
+ // Step 1. Allocate `topo_index` assigned to the current object
1400
1432
int32_t &topo_index = topo_indices[object];
1401
1433
if (topo_index == 0 ) {
1402
1434
topo_index = static_cast <int32_t >(topo_indices.size ()) - 1 ;
1403
1435
} else {
1404
1436
MLC_THROW (InternalError) << " This should never happen: object already visited" ;
1405
1437
}
1406
- Emitter emitter{os, get_json_type_index, &topo_indices};
1407
1438
if (is_first_object) {
1408
1439
is_first_object = false ;
1409
1440
} else {
1410
1441
os->put (' ,' );
1411
1442
}
1443
+ // Step 2. Print the current object
1444
+ // Special case: string
1412
1445
if (StrObj *str = object->as <StrObj>()) {
1413
1446
str->PrintEscape (*os);
1414
1447
return ;
1415
1448
}
1449
+ // [0] = json_type_index
1416
1450
(*os) << ' [' << (*get_json_type_index)(type_info->type_key );
1451
+ // [1...] = each field of the type. A few possible cases:
1452
+ // 1) list
1453
+ // 2) dict
1454
+ // 3) tensor
1455
+ // 4) opaque
1456
+ // 5) a normal dataclass
1417
1457
if (UListObj *list = object->as <UListObj>()) {
1458
+ Emitter emitter{os, get_json_type_index, &topo_indices};
1418
1459
for (Any &any : *list) {
1419
- emitter ( nullptr , &any);
1460
+ emitter. EmitAny ( &any);
1420
1461
}
1421
1462
} else if (UDictObj *dict = object->as <UDictObj>()) {
1463
+ Emitter emitter{os, get_json_type_index, &topo_indices};
1422
1464
for (auto &kv : *dict) {
1423
- emitter ( nullptr , &kv.first );
1424
- emitter ( nullptr , &kv.second );
1465
+ emitter. EmitAny ( &kv.first );
1466
+ emitter. EmitAny ( &kv.second );
1425
1467
}
1426
1468
} else if (TensorObj *tensor = object->as <TensorObj>()) {
1427
1469
(*os) << " , " << tensors.size ();
1428
1470
tensors.push_back (tensor);
1471
+ } else if (OpaqueObj *opaque = object->as <OpaqueObj>()) {
1472
+ int32_t opaque_index = get_opaque_index (opaque);
1473
+ (*os) << " , " << opaque_index;
1429
1474
} else if (object->IsInstance <FuncObj>() || object->IsInstance <ErrorObj>()) {
1430
1475
MLC_THROW (TypeError) << " Unserializable type: " << object->GetTypeKey ();
1431
- } else if (object->IsInstance <OpaqueObj>()) {
1432
- MLC_THROW (TypeError) << " Cannot serialize `mlc.Opaque` of type: "
1433
- << object->DynCast <OpaqueObj>()->opaque_type_name ;
1434
1476
} else {
1477
+ Emitter emitter{os, get_json_type_index, &topo_indices};
1435
1478
VisitFields (object, type_info, emitter);
1436
1479
}
1437
1480
os->put (' ]' );
@@ -1481,12 +1524,25 @@ inline mlc::Str Serialize(Any any) {
1481
1524
}
1482
1525
os << " ]" ;
1483
1526
}
1527
+ if (!opaques.empty ()) {
1528
+ os << " , \" opaques\" :" ;
1529
+ if (!fn_opaque_serialize) {
1530
+ fn_opaque_serialize = Func::GetGlobal (" mlc.Opaque.default.serialize" , true );
1531
+ }
1532
+ if (!fn_opaque_serialize) {
1533
+ MLC_THROW (ValueError) << " Cannot find serialization function `mlc.Opaque.default.serialize`. Register it with "
1534
+ " `mlc.Func.register(\" mlc.Opaque.default.serialize\" )(serialize_func)`" ;
1535
+ }
1536
+ Str opaque_repr = (*fn_opaque_serialize)(opaques);
1537
+ opaque_repr->PrintEscape (os);
1538
+ }
1484
1539
os << " }" ;
1485
1540
return os.str ();
1486
1541
}
1487
1542
1488
- inline Any Deserialize (const char *json_str, int64_t json_str_len) {
1543
+ inline Any Deserialize (const char *json_str, int64_t json_str_len, FuncObj *fn_opaque_deserialize ) {
1489
1544
int32_t json_type_index_tensor = -1 ;
1545
+ int32_t json_type_index_opaque = -1 ;
1490
1546
// Step 0. Parse JSON string
1491
1547
UDict json_obj = JSONLoads (json_str, json_str_len);
1492
1548
// Step 1. type_key => constructors
@@ -1496,10 +1552,12 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len) {
1496
1552
for (Str type_key : type_keys) {
1497
1553
int32_t type_index = Lib::GetTypeIndex (type_key->data ());
1498
1554
FuncObj *func = nullptr ;
1499
- if (type_index != kMLCTensor ) {
1500
- func = Lib::_init (type_index);
1501
- } else {
1555
+ if (type_index == kMLCTensor ) {
1502
1556
json_type_index_tensor = static_cast <int32_t >(constructors.size ());
1557
+ } else if (type_index == kMLCOpaque ) {
1558
+ json_type_index_opaque = static_cast <int32_t >(constructors.size ());
1559
+ } else {
1560
+ func = Lib::_init (type_index);
1503
1561
}
1504
1562
constructors.push_back (func);
1505
1563
}
@@ -1522,45 +1580,71 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len) {
1522
1580
json_obj->erase (" tensors" );
1523
1581
std::reverse (tensors.begin (), tensors.end ());
1524
1582
}
1525
- // Step 3. Translate JSON object to objects
1583
+ // Step 3. Handle opaque objects
1584
+ UList opaques;
1585
+ if (json_obj.count (" opaques" )) {
1586
+ if (!fn_opaque_deserialize) {
1587
+ fn_opaque_deserialize = Func::GetGlobal (" mlc.Opaque.default.deserialize" , true );
1588
+ }
1589
+ if (!fn_opaque_deserialize) {
1590
+ MLC_THROW (ValueError)
1591
+ << " Cannot find deserialization function `mlc.Opaque.default.deserialize`. Register it with "
1592
+ " `mlc.Func.register(\" mlc.Opaque.default.deserialize\" )(deserialize_func)`" ;
1593
+ }
1594
+ opaques = (*fn_opaque_deserialize)(json_obj->at (" opaques" )).operator UList ();
1595
+ }
1596
+ // Step 4. Translate JSON object to objects
1526
1597
UList values = json_obj->at (" values" );
1527
1598
for (int64_t i = 0 ; i < values->size (); ++i) {
1528
- Any obj = values[i];
1529
- if (obj.type_index == kMLCList ) {
1530
- UList list = obj.operator UList ();
1531
- int32_t json_type_index = list[0 ];
1599
+ Any &value = values[i];
1600
+ // every `value` is
1601
+ // 1) integer `i` -> refer to `values[i]`
1602
+ // 2) string `s` -> string literal `s`
1603
+ // 3) list -> normal case, its layout is:
1604
+ // - [0] = json_type_index
1605
+ // - [1...] = each field of the type
1606
+ // * int: refer to `values[i]`
1607
+ // * list: TODO: explain
1608
+ // * str / bool / float / None: literals
1609
+ // TODO: how about kMLCBool, kMLCFloat, kMLCNone?
1610
+ if (UListObj *list = value.as <UListObj>()) {
1611
+ // Layout of the list:
1612
+ int32_t json_type_index = (*list)[0 ];
1532
1613
if (json_type_index == json_type_index_tensor) {
1533
- values[i] = tensors[list[1 ].operator int32_t ()];
1534
- continue ;
1535
- }
1536
- for (int64_t j = 1 ; j < list.size (); ++j) {
1537
- Any arg = list[j];
1538
- if (arg.type_index == kMLCInt ) {
1539
- int64_t k = arg;
1540
- if (k < i) {
1541
- list[j] = values[k];
1614
+ int32_t idx = (*list)[1 ];
1615
+ value = tensors[idx];
1616
+ } else if (json_type_index == json_type_index_opaque) {
1617
+ int32_t idx = (*list)[1 ];
1618
+ value = opaques[idx];
1619
+ } else {
1620
+ for (int64_t j = 1 ; j < list->size (); ++j) {
1621
+ Any arg = (*list)[j];
1622
+ if (arg.type_index == kMLCInt ) {
1623
+ int64_t k = arg;
1624
+ if (k < i) {
1625
+ (*list)[j] = values[k];
1626
+ } else {
1627
+ MLC_THROW (ValueError) << " Invalid reference when parsing type `" << type_keys[json_type_index]
1628
+ << " `: referring #" << k << " at #" << i << " . v = " << value;
1629
+ }
1630
+ } else if (arg.type_index == kMLCList ) {
1631
+ (*list)[j] = invoke_init (arg.operator UList ());
1632
+ } else if (arg.type_index == kMLCStr || arg.type_index == kMLCBool || arg.type_index == kMLCFloat ||
1633
+ arg.type_index == kMLCNone ) {
1634
+ // Do nothing
1542
1635
} else {
1543
- MLC_THROW (ValueError) << " Invalid reference when parsing type `" << type_keys[json_type_index]
1544
- << " `: referring #" << k << " at #" << i << " . v = " << obj;
1636
+ MLC_THROW (ValueError) << " Unexpected value: " << arg;
1545
1637
}
1546
- } else if (arg.type_index == kMLCList ) {
1547
- list[j] = invoke_init (arg.operator UList ());
1548
- } else if (arg.type_index == kMLCStr || arg.type_index == kMLCBool || arg.type_index == kMLCFloat ||
1549
- arg.type_index == kMLCNone ) {
1550
- // Do nothing
1551
- } else {
1552
- MLC_THROW (ValueError) << " Unexpected value: " << arg;
1553
1638
}
1639
+ value = invoke_init (UList (list));
1554
1640
}
1555
- values[i] = invoke_init (list);
1556
- } else if (obj.type_index == kMLCInt ) {
1557
- int32_t k = obj;
1558
- values[i] = values[k];
1559
- } else if (obj.type_index == kMLCStr ) {
1641
+ } else if (value.type_index == kMLCInt ) {
1642
+ int32_t k = value;
1643
+ value = values[k];
1644
+ } else if (value.type_index == kMLCStr ) {
1560
1645
// Do nothing
1561
- // TODO: how about kMLCBool, kMLCFloat, kMLCNone?
1562
1646
} else {
1563
- MLC_THROW (ValueError) << " Unexpected value: " << obj ;
1647
+ MLC_THROW (ValueError) << " Unexpected value: " << value ;
1564
1648
}
1565
1649
}
1566
1650
return values->back ();
@@ -1617,16 +1701,18 @@ Any JSONLoads(AnyView json_str) {
1617
1701
}
1618
1702
}
1619
1703
1620
- Any JSONDeserialize (AnyView json_str) {
1704
+ Any JSONDeserialize (AnyView json_str, FuncObj *fn_opaque_deserialize ) {
1621
1705
if (json_str.type_index == kMLCRawStr ) {
1622
- return ::mlc::Deserialize (json_str.operator const char *(), -1 );
1706
+ return ::mlc::Deserialize (json_str.operator const char *(), -1 , fn_opaque_deserialize );
1623
1707
} else {
1624
1708
StrObj *js = json_str.operator StrObj *();
1625
- return ::mlc::Deserialize (js->data (), js->size ());
1709
+ return ::mlc::Deserialize (js->data (), js->size (), fn_opaque_deserialize );
1626
1710
}
1627
1711
}
1628
1712
1629
- Str JSONSerialize (AnyView source) { return ::mlc::Serialize (source); }
1713
+ Str JSONSerialize (AnyView source, FuncObj *fn_opaque_serialize) {
1714
+ return ::mlc::Serialize (source, fn_opaque_serialize);
1715
+ }
1630
1716
1631
1717
Str TensorToBytes (const TensorObj *src) {
1632
1718
return ::mlc::TensorToBytes (&src->tensor ); //
0 commit comments