Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions include/tvm/ffi/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1358,15 +1358,17 @@ TVM_FFI_DLL const TVMFFIByteArray* TVMFFIBacktrace(const char* filename, int lin
* If the static_tindex is non-negative, the function will
* allocate a runtime type index.
* Otherwise, we will populate the type table and return the static index.
* If parent_type_index is -2, the function queries the existing type index:
* it returns the registered index for type_key, or -2 if type_key is not registered.
*
* \param type_key The type key.
* \param type_depth The type depth.
* \param static_type_index Static type index if any, can be -1, which means this is a dynamic index
* \param num_child_slots Number of slots reserved for its children.
* \param child_slots_can_overflow Whether to allow child to overflow the slots.
* \param parent_type_index Parent type index, pass in -1 if it is root.
* \param parent_type_index Parent type index, pass in -1 if it is root, or -2 to query only.
*
* \return The allocated type index.
* \return The existing or allocated type index; -2 when query-only mode misses.
*/
TVM_FFI_DLL int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray* type_key,
int32_t static_type_index, int32_t type_depth,
Expand Down
3 changes: 3 additions & 0 deletions src/ffi/object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ class TypeTable {
if (it != type_key2index_.end()) {
return type_table_[(*it).second]->type_index;
}
if (parent_type_index == -2) {
return -2;
}

// get parent's entry
Entry* parent = [&]() -> Entry* {
Expand Down
15 changes: 15 additions & 0 deletions tests/cpp/test_object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,21 @@ TEST(Object, CRTPObjectInfo) {
EXPECT_GE(info->type_index, TypeIndex::kTVMFFIDynObjectBegin);
}

TEST(Object, TypeGetOrAllocIndexQueryRegistered) {
TVMFFIByteArray type_key{TIntObj::_type_key, std::char_traits<char>::length(TIntObj::_type_key)};
EXPECT_EQ(TVMFFITypeGetOrAllocIndex(&type_key, -1, 0, 0, 0, -2), TIntObj::RuntimeTypeIndex());
}

TEST(Object, TypeGetOrAllocIndexQueryMissDoesNotRegister) {
const char* type_key_data = "test.TypeGetOrAllocIndexQueryMiss";
TVMFFIByteArray type_key{type_key_data, std::char_traits<char>::length(type_key_data)};
EXPECT_EQ(TVMFFITypeGetOrAllocIndex(&type_key, -1, 0, 0, 0, -2), -2);

int32_t type_index = -1;
EXPECT_NE(TVMFFITypeKeyToIndex(&type_key, &type_index), 0);
EXPECT_EQ(type_index, -1);
}

TEST(Object, InstanceCheck) {
ObjectPtr<Object> a = make_object<TIntObj>(11);
ObjectPtr<Object> b = make_object<TFloatObj>(11);
Expand Down
Loading