Skip to content

Commit cd145d9

Browse files
authored
fix(core): always check error code from libmlc (#21)
This PR introduces a macro `MLC_CHECK_ERR`, which is always used when accessing MLC's C APIs to make sure errors are always caught in time.
1 parent 8d20121 commit cd145d9

File tree

9 files changed

+133
-46
lines changed

9 files changed

+133
-46
lines changed

cpp/c_api.cc

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ thread_local Any last_error;
5454

5555
MLC_API MLCAny MLCGetLastError() {
5656
MLCAny ret;
57-
*static_cast<Any *>(&ret) = std::move(last_error);
57+
static_cast<Any &>(ret) = std::move(last_error);
5858
return ret;
5959
}
6060

@@ -99,23 +99,44 @@ MLC_API int32_t MLCTypeAddMethod(MLCTypeTableHandle _self, int32_t type_index, M
9999
MLC_SAFE_CALL_END(&last_error);
100100
}
101101

102+
MLC_API int32_t MLCVTableCreate(MLCTypeTableHandle _self, const char *key, MLCVTableHandle *ret) {
103+
MLC_SAFE_CALL_BEGIN();
104+
*ret = new mlc::registry::MLCVTable(TypeTable::Get(_self), key);
105+
MLC_SAFE_CALL_END(&last_error);
106+
}
107+
108+
MLC_API int32_t MLCVTableDelete(MLCVTableHandle self) {
109+
MLC_SAFE_CALL_BEGIN();
110+
if (self) {
111+
delete static_cast<mlc::registry::MLCVTable *>(self);
112+
}
113+
MLC_SAFE_CALL_END(&last_error);
114+
}
115+
102116
MLC_API int32_t MLCVTableGetGlobal(MLCTypeTableHandle _self, const char *key, MLCVTableHandle *ret) {
103117
MLC_SAFE_CALL_BEGIN();
104118
*ret = TypeTable::Get(_self)->GetGlobalVTable(key);
105119
MLC_SAFE_CALL_END(&last_error);
106120
}
107121

108122
MLC_API int32_t MLCVTableGetFunc(MLCVTableHandle vtable, int32_t type_index, int32_t allow_ancestor, MLCAny *ret) {
109-
using ::mlc::registry::VTable;
123+
using ::mlc::registry::MLCVTable;
110124
MLC_SAFE_CALL_BEGIN();
111-
*static_cast<Any *>(ret) = static_cast<VTable *>(vtable)->GetFunc(type_index, allow_ancestor);
125+
*static_cast<Any *>(ret) = static_cast<MLCVTable *>(vtable)->GetFunc(type_index, allow_ancestor);
112126
MLC_SAFE_CALL_END(&last_error);
113127
}
114128

115129
MLC_API int32_t MLCVTableSetFunc(MLCVTableHandle vtable, int32_t type_index, MLCFunc *func, int32_t override_mode) {
116-
using ::mlc::registry::VTable;
130+
using ::mlc::registry::MLCVTable;
131+
MLC_SAFE_CALL_BEGIN();
132+
static_cast<MLCVTable *>(vtable)->Set(type_index, static_cast<FuncObj *>(func), override_mode);
133+
MLC_SAFE_CALL_END(&last_error);
134+
}
135+
136+
MLC_API int32_t MLCVTableCall(MLCVTableHandle vtable, int32_t num_args, MLCAny *args, MLCAny *ret) {
137+
using ::mlc::registry::MLCVTable;
117138
MLC_SAFE_CALL_BEGIN();
118-
static_cast<VTable *>(vtable)->Set(type_index, static_cast<FuncObj *>(func), override_mode);
139+
static_cast<MLCVTable *>(vtable)->Call(num_args, args, ret);
119140
MLC_SAFE_CALL_END(&last_error);
120141
}
121142

cpp/registry.h

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,11 @@ struct ResourcePool {
136136

137137
struct TypeTable;
138138

139-
struct VTable {
140-
explicit VTable(TypeTable *type_table, std::string name) : type_table(type_table), name(std::move(name)), data() {}
139+
struct MLCVTable {
140+
explicit MLCVTable(TypeTable *type_table, std::string name) : type_table(type_table), name(std::move(name)), data() {}
141141
void Set(int32_t type_index, FuncObj *func, int32_t override_mode);
142142
FuncObj *GetFunc(int32_t type_index, bool allow_ancestor) const;
143+
void Call(int32_t num_args, MLCAny *args, MLCAny *ret) const;
143144

144145
private:
145146
TypeTable *type_table;
@@ -269,7 +270,7 @@ struct TypeTable {
269270
std::vector<std::unique_ptr<TypeInfoWrapper>> type_table;
270271
std::unordered_map<std::string, MLCTypeInfo *> type_key_to_info;
271272
std::unordered_map<std::string, FuncObj *> global_funcs;
272-
std::unordered_map<std::string, std::unique_ptr<VTable>> global_vtables;
273+
std::unordered_map<std::string, std::unique_ptr<MLCVTable>> global_vtables;
273274
std::unordered_map<std::string, std::unique_ptr<DSOLibrary>> dso_libraries;
274275
std::unordered_map<std::string, DLDataType> dtype_presets{
275276
{"void", {kDLOpaqueHandle, 0, 0}},
@@ -414,18 +415,18 @@ struct TypeTable {
414415

415416
FuncObj *GetVTable(int32_t type_index, const char *attr_key, bool allow_ancestor) {
416417
if (auto it = this->global_vtables.find(attr_key); it != this->global_vtables.end()) {
417-
VTable *vtable = it->second.get();
418+
MLCVTable *vtable = it->second.get();
418419
return vtable->GetFunc(type_index, allow_ancestor);
419420
} else {
420421
return nullptr;
421422
}
422423
}
423424

424-
VTable *GetGlobalVTable(const char *name) {
425+
MLCVTable *GetGlobalVTable(const char *name) {
425426
if (auto it = this->global_vtables.find(name); it != this->global_vtables.end()) {
426427
return it->second.get();
427428
} else {
428-
std::unique_ptr<VTable> &vtable = this->global_vtables[name] = std::make_unique<VTable>(this, name);
429+
std::unique_ptr<MLCVTable> &vtable = this->global_vtables[name] = std::make_unique<MLCVTable>(this, name);
429430
return vtable.get();
430431
}
431432
}
@@ -713,7 +714,7 @@ inline TypeTable *TypeTable::New() {
713714
#undef MLC_TYPE_TABLE_INIT_TYPE_BEGIN
714715
#undef MLC_TYPE_TABLE_INIT_TYPE_END
715716

716-
inline void VTable::Set(int32_t type_index, FuncObj *func, int32_t override_mode) {
717+
inline void MLCVTable::Set(int32_t type_index, FuncObj *func, int32_t override_mode) {
717718
auto [it, success] = this->data.try_emplace(type_index, nullptr);
718719
if (!success) {
719720
if (override_mode == 0) {
@@ -740,7 +741,7 @@ inline void VTable::Set(int32_t type_index, FuncObj *func, int32_t override_mode
740741
this->type_table->pool.AddObj(func);
741742
}
742743

743-
inline FuncObj *VTable::GetFunc(int32_t type_index, bool allow_ancestor) const {
744+
inline FuncObj *MLCVTable::GetFunc(int32_t type_index, bool allow_ancestor) const {
744745
if (auto it = this->data.find(type_index); it != this->data.end()) {
745746
return it->second;
746747
}
@@ -757,6 +758,20 @@ inline FuncObj *VTable::GetFunc(int32_t type_index, bool allow_ancestor) const {
757758
return nullptr;
758759
}
759760

761+
inline void MLCVTable::Call(int32_t num_args, MLCAny *args, MLCAny *ret) const {
762+
constexpr bool allow_ancestor = false;
763+
if (num_args == 0) {
764+
MLC_THROW(ValueError) << "Calling a vtable requires at least one argument";
765+
}
766+
int32_t type_index = args[0].type_index;
767+
FuncObj *func = this->GetFunc(type_index, allow_ancestor);
768+
if (func == nullptr) {
769+
MLC_THROW(KeyError) << "VTable `" << name
770+
<< "` doesn't have type registered: " << this->type_table->GetTypeInfo(type_index)->type_key;
771+
}
772+
::mlc::base::FuncCall(func, num_args, args, ret);
773+
}
774+
760775
} // namespace registry
761776
} // namespace mlc
762777

include/mlc/base/lib.h

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,26 @@
55

66
namespace mlc {
77

8+
struct VTable {
9+
VTable(const VTable &) = delete;
10+
VTable &operator=(const VTable &) = delete;
11+
VTable(VTable &&other) noexcept : self(other.self) { other.self = nullptr; }
12+
VTable &operator=(VTable &&other) noexcept {
13+
this->Swap(other);
14+
return *this;
15+
}
16+
~VTable() { MLC_CHECK_ERR(::MLCVTableDelete(self), nullptr); }
17+
18+
template <typename R, typename... Args> R operator()(Args... args) const;
19+
template <typename Obj> VTable &Set(Func func);
20+
21+
private:
22+
friend struct Lib;
23+
VTable(MLCVTableHandle self) : self(self) {}
24+
void Swap(VTable &other) { std::swap(self, other.self); }
25+
MLCVTableHandle self;
26+
};
27+
828
struct Lib {
929
static int32_t FuncSetGlobal(const char *name, FuncObj *func, bool allow_override = false);
1030
static FuncObj *FuncGetGlobal(const char *name, bool allow_missing = false);
@@ -19,14 +39,19 @@ struct Lib {
1939
static void DataTypeRegister(const char *name, int32_t dtype_bits);
2040

2141
static FuncObj *_init(int32_t type_index) { return VTableGetFunc(init, type_index, "__init__"); }
42+
static VTable MakeVTable(const char *name) {
43+
MLCVTableHandle vtable = nullptr;
44+
MLC_CHECK_ERR(::MLCVTableCreate(_lib, name, &vtable), nullptr);
45+
return VTable(vtable);
46+
}
2247
MLC_INLINE static MLCTypeInfo *GetTypeInfo(int32_t type_index) {
23-
MLCTypeInfo *type_info;
24-
MLCTypeIndex2Info(_lib, type_index, &type_info);
48+
MLCTypeInfo *type_info = nullptr;
49+
MLC_CHECK_ERR(::MLCTypeIndex2Info(_lib, type_index, &type_info), nullptr);
2550
return type_info;
2651
}
2752
MLC_INLINE static MLCTypeInfo *GetTypeInfo(const char *type_key) {
28-
MLCTypeInfo *type_info;
29-
MLCTypeKey2Info(_lib, type_key, &type_info);
53+
MLCTypeInfo *type_info = nullptr;
54+
MLC_CHECK_ERR(::MLCTypeKey2Info(_lib, type_key, &type_info), nullptr);
3055
return type_info;
3156
}
3257
MLC_INLINE static const char *GetTypeKey(int32_t type_index) {
@@ -52,14 +77,14 @@ struct Lib {
5277
}
5378
MLC_INLINE static MLCTypeInfo *TypeRegister(int32_t parent_type_index, int32_t type_index, const char *type_key) {
5479
MLCTypeInfo *info = nullptr;
55-
MLCTypeRegister(_lib, parent_type_index, type_key, type_index, &info);
80+
MLC_CHECK_ERR(::MLCTypeRegister(_lib, parent_type_index, type_key, type_index, &info), nullptr);
5681
return info;
5782
}
5883

5984
private:
6085
static FuncObj *VTableGetFunc(MLCVTableHandle vtable, int32_t type_index, const char *vtable_name) {
6186
MLCAny func{};
62-
MLCVTableGetFunc(vtable, type_index, true, &func);
87+
MLC_CHECK_ERR(::MLCVTableGetFunc(vtable, type_index, true, &func), &func);
6388
if (!::mlc::base::IsTypeIndexPOD(func.type_index)) {
6489
::mlc::base::DecRef(func.v.v_obj);
6590
}
@@ -74,13 +99,13 @@ struct Lib {
7499
return ret;
75100
}
76101
static MLCVTableHandle VTableGetGlobal(const char *name) {
77-
MLCVTableHandle ret;
78-
MLCVTableGetGlobal(_lib, name, &ret);
102+
MLCVTableHandle ret = nullptr;
103+
MLC_CHECK_ERR(::MLCVTableGetGlobal(_lib, name, &ret), nullptr);
79104
return ret;
80105
}
81106
static MLC_SYMBOL_HIDE inline MLCTypeTableHandle _lib = []() {
82107
MLCTypeTableHandle ret = nullptr;
83-
::MLCHandleGetGlobal(&ret);
108+
MLC_CHECK_ERR(::MLCHandleGetGlobal(&ret), nullptr);
84109
return ret;
85110
}();
86111
static MLC_SYMBOL_HIDE inline MLCVTableHandle cxx_str = VTableGetGlobal("__cxx_str__");

include/mlc/base/utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@
7777
} \
7878
MLC_UNREACHABLE()
7979

80+
#define MLC_CHECK_ERR(Call, Ret) \
81+
if (int32_t err_code = (Call)) { \
82+
::mlc::base::FuncCallCheckError(err_code, (Ret)); \
83+
}
84+
8085
namespace mlc {
8186
namespace base {
8287

@@ -108,6 +113,7 @@ struct ErrorBuilder {
108113

109114
StrObj *StrCopyFromCharArray(const char *source, size_t length);
110115
void FuncCall(const void *func, int32_t num_args, const MLCAny *args, MLCAny *ret);
116+
void FuncCallCheckError(int32_t err_code, MLCAny *ret) noexcept(false);
111117
template <typename Callable> Any CallableToAny(Callable &&callable);
112118
template <typename DerivedType, typename SelfType = Object> bool IsInstanceOf(const MLCAny *self);
113119

include/mlc/c_api.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,9 @@ MLC_API int32_t MLCTypeRegisterStructure(MLCTypeTableHandle self, int32_t type_i
284284
int64_t num_sub_structures, int32_t *sub_structure_indices,
285285
int32_t *sub_structure_kinds);
286286
MLC_API int32_t MLCTypeAddMethod(MLCTypeTableHandle self, int32_t type_index, MLCTypeMethod method);
287+
MLC_API int32_t MLCVTableCreate(MLCTypeTableHandle self, const char *key, MLCVTableHandle *ret);
288+
MLC_API int32_t MLCVTableDelete(MLCVTableHandle self);
289+
MLC_API int32_t MLCVTableCall(MLCVTableHandle vtable, int32_t num_args, MLCAny *args, MLCAny *ret);
287290
MLC_API int32_t MLCVTableGetGlobal(MLCTypeTableHandle self, const char *key, MLCVTableHandle *ret);
288291
MLC_API int32_t MLCVTableGetFunc(MLCVTableHandle vtable, int32_t type_index, int32_t allow_ancestor, MLCAny *ret);
289292
MLC_API int32_t MLCVTableSetFunc(MLCVTableHandle vtable, int32_t type_index, MLCFunc *func, int32_t override_mode);

include/mlc/core/all.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,12 @@ inline Any Lib::IRPrint(AnyView obj, AnyView printer, AnyView path) {
147147
return ret;
148148
}
149149
inline int32_t Lib::FuncSetGlobal(const char *name, FuncObj *func, bool allow_override) {
150-
::MLCFuncSetGlobal(_lib, name, Any(func), allow_override);
150+
MLC_CHECK_ERR(::MLCFuncSetGlobal(_lib, name, Any(func), allow_override), nullptr);
151151
return 0;
152152
}
153153
inline FuncObj *Lib::FuncGetGlobal(const char *name, bool allow_missing) {
154154
Any ret;
155-
::MLCFuncGetGlobal(_lib, name, &ret);
155+
MLC_CHECK_ERR(::MLCFuncGetGlobal(_lib, name, &ret), &ret);
156156
if (!ret.defined() && !allow_missing) {
157157
MLC_THROW(KeyError) << "Missing global function: " << name;
158158
}
@@ -200,6 +200,19 @@ inline void Lib::DataTypeRegister(const char *name, int32_t dtype_bits) {
200200
Any ret;
201201
::mlc::base::FuncCall(func_dtype_register, 2, arg, &ret);
202202
}
203+
template <typename R, typename... Args> inline R VTable::operator()(Args... args) const {
204+
constexpr size_t N = sizeof...(Args);
205+
AnyViewArray<N> stack_args;
206+
Any ret;
207+
stack_args.Fill(std::forward<Args>(args)...);
208+
MLC_CHECK_ERR(::MLCVTableCall(self, N, stack_args.v, &ret), &ret);
209+
}
210+
template <typename Obj> inline VTable &VTable::Set(Func func) {
211+
constexpr bool override_mode = false;
212+
int32_t type_index = Obj::_type_index;
213+
MLC_CHECK_ERR(::MLCVTableSetFunc(this->self, type_index, func.get(), override_mode), nullptr);
214+
return *this;
215+
}
203216

204217
} // namespace mlc
205218

include/mlc/core/func.h

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -79,27 +79,33 @@ struct FuncRegistryHelper {
7979
const char *name;
8080
};
8181

82-
MLC_INLINE void HandleSafeCallError(int32_t err_code, MLCAny *ret) noexcept(false) {
82+
} // namespace core
83+
} // namespace mlc
84+
85+
namespace mlc {
86+
namespace base {
87+
inline void FuncCallCheckError(int32_t err_code, MLCAny *ret) noexcept(false) {
88+
Any err;
89+
if (ret != nullptr) {
90+
err = static_cast<Any &&>(*ret);
91+
} else {
92+
static_cast<MLCAny &>(err) = ::MLCGetLastError();
93+
}
8394
if (err_code == -1) { // string errors
84-
MLC_THROW(InternalError) << "Error: " << *static_cast<Any *>(ret);
95+
MLC_THROW(InternalError) << "Error: " << err;
8596
} else if (err_code == -2) { // error objects
86-
throw Exception(static_cast<Any *>(ret)->operator Ref<ErrorObj>()->AppendWith(MLC_TRACEBACK_HERE()));
97+
throw Exception(err.operator Ref<ErrorObj>()->AppendWith(MLC_TRACEBACK_HERE()));
8798
} else { // error code
8899
MLC_THROW(InternalError) << "Error code: " << err_code;
89100
}
90101
MLC_UNREACHABLE();
91102
}
92-
} // namespace core
93-
} // namespace mlc
94-
95-
namespace mlc {
96-
namespace base {
97-
MLC_INLINE void FuncCall(const void *self, int32_t num_args, const MLCAny *args, MLCAny *ret) {
103+
inline void FuncCall(const void *self, int32_t num_args, const MLCAny *args, MLCAny *ret) {
98104
const MLCFunc *func = static_cast<const MLCFunc *>(self);
99105
if (func->call && reinterpret_cast<void *>(func->safe_call) == reinterpret_cast<void *>(FuncObj::SafeCallImpl)) {
100106
func->call(func, num_args, args, ret);
101-
} else if (int32_t err_code = func->safe_call(func, num_args, args, ret)) {
102-
::mlc::core::HandleSafeCallError(err_code, ret);
107+
} else {
108+
MLC_CHECK_ERR(func->safe_call(func, num_args, args, ret), ret);
103109
}
104110
}
105111
template <int32_t num_args> inline auto GetGlobalFuncCall(const char *name) {

include/mlc/core/func_details.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,16 +190,12 @@ template <typename FuncType, typename> MLC_INLINE FuncObj *FuncObj::Allocator::N
190190
inline Ref<FuncObj> FuncObj::FromForeign(void *self, MLCDeleterType deleter, MLCFuncSafeCallType safe_call) {
191191
if (deleter == nullptr) {
192192
return Ref<FuncObj>::New([self, safe_call](int32_t num_args, const MLCAny *args, MLCAny *ret) {
193-
if (int32_t err_code = safe_call(self, num_args, args, ret); err_code != 0) {
194-
::mlc::core::HandleSafeCallError(err_code, ret);
195-
}
193+
MLC_CHECK_ERR(safe_call(self, num_args, args, ret), ret);
196194
});
197195
} else {
198196
return Ref<FuncObj>::New(
199197
[self = std::shared_ptr<void>(self, deleter), safe_call](int32_t num_args, const MLCAny *args, MLCAny *ret) {
200-
if (int32_t err_code = safe_call(self.get(), num_args, args, ret); err_code != 0) {
201-
::mlc::core::HandleSafeCallError(err_code, ret);
202-
}
198+
MLC_CHECK_ERR(safe_call(self.get(), num_args, args, ret), ret);
203199
});
204200
}
205201
}

include/mlc/core/reflection.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,14 @@ struct _Reflect {
100100
reinterpret_cast<MLCFunc *>(func_any_to_ref.v.v_obj), //
101101
kStaticFn});
102102
}
103-
MLCTypeRegisterFields(nullptr, this->type_index, this->fields.size(), this->fields.data());
104-
MLCTypeRegisterStructure(nullptr, this->type_index, static_cast<int32_t>(this->structure_kind),
105-
this->sub_structure_indices.size(), this->sub_structure_indices.data(),
106-
this->sub_structure_kinds.data());
103+
MLC_CHECK_ERR(::MLCTypeRegisterFields(nullptr, this->type_index, this->fields.size(), this->fields.data()),
104+
nullptr);
105+
MLC_CHECK_ERR(::MLCTypeRegisterStructure(nullptr, this->type_index, static_cast<int32_t>(this->structure_kind),
106+
this->sub_structure_indices.size(), this->sub_structure_indices.data(),
107+
this->sub_structure_kinds.data()),
108+
nullptr);
107109
for (const MLCTypeMethod &method : this->methods) {
108-
MLCTypeAddMethod(nullptr, this->type_index, method);
110+
MLC_CHECK_ERR(::MLCTypeAddMethod(nullptr, this->type_index, method), nullptr);
109111
}
110112
}
111113
return 0;

0 commit comments

Comments
 (0)