@@ -952,7 +952,9 @@ inline Any CopyShallowImpl(AnyView source) {
952
952
return UList (list->begin (), list->end ());
953
953
} else if (UDictObj *dict = source.TryCast <UDictObj>()) {
954
954
return UDict (dict->begin (), dict->end ());
955
- } else if (source.IsInstance <StrObj>() || source.IsInstance <ErrorObj>() || source.IsInstance <FuncObj>()) {
955
+ } else if (source.IsInstance <StrObj>() || source.IsInstance <ErrorObj>() || source.IsInstance <FuncObj>() ||
956
+ source.IsInstance <TensorObj>()) {
957
+ // TODO: do we want to shallow copy these types at all?
956
958
return source;
957
959
}
958
960
struct Copier {
@@ -987,6 +989,62 @@ inline Any CopyShallowImpl(AnyView source) {
987
989
return ret;
988
990
}
989
991
992
+ inline void CopyReplaceImpl (int32_t num_args, const AnyView *args, Any *ret) {
993
+ if (num_args <= 0 ) {
994
+ MLC_THROW (InternalError) << " InternalError: `CopyReplace` requires at least one argument" ;
995
+ }
996
+ AnyView source = args[0 ];
997
+ int32_t type_index = source.type_index ;
998
+ if (::mlc::base::IsTypeIndexPOD (type_index)) {
999
+ MLC_THROW (TypeError) << " TypeError: `__replace__` doesn't work on a POD type: " << source;
1000
+ } else if (source.IsInstance <StrObj>() || source.IsInstance <ErrorObj>() || source.IsInstance <FuncObj>() ||
1001
+ source.IsInstance <UListObj>() || source.IsInstance <UDictObj>() || source.IsInstance <TensorObj>()) {
1002
+ MLC_THROW (TypeError) << " TypeError: `__replace__` doesn't work on type: " << source.GetTypeKey ();
1003
+ }
1004
+ struct Copier {
1005
+ MLC_INLINE void operator ()(MLCTypeField *f, const Any *any) { AddField (f->name , AnyView (*any)); }
1006
+ MLC_INLINE void operator ()(MLCTypeField *f, ObjectRef *obj) { AddField (f->name , AnyView (*obj)); }
1007
+ MLC_INLINE void operator ()(MLCTypeField *f, Optional<ObjectRef> *opt) { AddField (f->name , AnyView (*opt)); }
1008
+ MLC_INLINE void operator ()(MLCTypeField *f, Optional<bool > *opt) { AddField (f->name , AnyView (*opt)); }
1009
+ MLC_INLINE void operator ()(MLCTypeField *f, Optional<int64_t > *opt) { AddField (f->name , AnyView (*opt)); }
1010
+ MLC_INLINE void operator ()(MLCTypeField *f, Optional<double > *opt) { AddField (f->name , AnyView (*opt)); }
1011
+ MLC_INLINE void operator ()(MLCTypeField *f, Optional<DLDevice> *opt) { AddField (f->name , AnyView (*opt)); }
1012
+ MLC_INLINE void operator ()(MLCTypeField *f, Optional<DLDataType> *opt) { AddField (f->name , AnyView (*opt)); }
1013
+ MLC_INLINE void operator ()(MLCTypeField *f, bool *v) { AddField (f->name , AnyView (*v)); }
1014
+ MLC_INLINE void operator ()(MLCTypeField *f, int8_t *v) { AddField (f->name , AnyView (*v)); }
1015
+ MLC_INLINE void operator ()(MLCTypeField *f, int16_t *v) { AddField (f->name , AnyView (*v)); }
1016
+ MLC_INLINE void operator ()(MLCTypeField *f, int32_t *v) { AddField (f->name , AnyView (*v)); }
1017
+ MLC_INLINE void operator ()(MLCTypeField *f, int64_t *v) { AddField (f->name , AnyView (*v)); }
1018
+ MLC_INLINE void operator ()(MLCTypeField *f, float *v) { AddField (f->name , AnyView (*v)); }
1019
+ MLC_INLINE void operator ()(MLCTypeField *f, double *v) { AddField (f->name , AnyView (*v)); }
1020
+ MLC_INLINE void operator ()(MLCTypeField *f, DLDataType *v) { AddField (f->name , AnyView (*v)); }
1021
+ MLC_INLINE void operator ()(MLCTypeField *f, DLDevice *v) { AddField (f->name , AnyView (*v)); }
1022
+ MLC_INLINE void operator ()(MLCTypeField *f, Optional<void *> *v) { AddField (f->name , AnyView (*v)); }
1023
+ MLC_INLINE void operator ()(MLCTypeField *f, void **v) { AddField (f->name , AnyView (*v)); }
1024
+ MLC_INLINE void operator ()(MLCTypeField *f, const char **v) { AddField (f->name , AnyView (*v)); }
1025
+
1026
+ void AddField (std::string_view name, AnyView v) {
1027
+ if (auto it = replacements->find (name); it != replacements->end ()) {
1028
+ fields->push_back (it->second );
1029
+ } else {
1030
+ fields->push_back (v);
1031
+ }
1032
+ }
1033
+ std::vector<AnyView> *fields;
1034
+ std::unordered_map<std::string_view, AnyView> *replacements;
1035
+ };
1036
+ std::unordered_map<std::string_view, AnyView> replacements;
1037
+ for (int32_t i = 1 ; i < num_args; i += 2 ) {
1038
+ const char *name = args[i];
1039
+ replacements[name] = args[i + 1 ];
1040
+ }
1041
+ FuncObj *init_func = Lib::_init (type_index);
1042
+ MLCTypeInfo *type_info = Lib::GetTypeInfo (type_index);
1043
+ std::vector<AnyView> fields;
1044
+ VisitFields (source.operator Object *(), type_info, Copier{&fields, &replacements});
1045
+ ::mlc::base::FuncCall (init_func, static_cast <int32_t >(fields.size()), fields.data(), ret);
1046
+ }
1047
+
990
1048
inline Any CopyDeepImpl (AnyView source) {
991
1049
if (::mlc::base::IsTypeIndexPOD (source.type_index )) {
992
1050
return source;
@@ -1508,6 +1566,7 @@ int64_t StructuralHash(AnyView root) {
1508
1566
1509
1567
Any CopyShallow (AnyView source) { return CopyShallowImpl (source); }
1510
1568
Any CopyDeep (AnyView source) { return CopyDeepImpl (source); }
1569
+ void CopyReplace (int32_t num_args, const AnyView *args, Any *ret) { CopyReplaceImpl (num_args, args, ret); }
1511
1570
1512
1571
Any JSONLoads (AnyView json_str) {
1513
1572
if (json_str.type_index == kMLCRawStr ) {
0 commit comments