Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 112 additions & 3 deletions include/tvm/ffi/any.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class AnyView {
template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
TVM_FFI_INLINE T cast() const {
std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
if (!opt.has_value()) {
if (TVM_FFI_PREDICT_FALSE(!opt.has_value())) {
TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
<< TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `"
<< TypeTraits<T>::TypeStr() << "`";
Expand Down Expand Up @@ -361,6 +361,29 @@ class Any {
}
}

/**
* \brief Strictly reinterpret the Any as a type T or throw.
*
* \tparam T The type to cast to.
* \return The casted value.
* \note This function will not run fallback conversions.
*/
template <typename T,
typename = std::enable_if_t<TypeTraits<T>::storage_enabled || std::is_same_v<T, Any>>>
TVM_FFI_INLINE T as_or_throw() && {
if constexpr (std::is_same_v<T, Any>) {
return std::move(*this);
} else {
std::optional<T> result = std::move(*this).template as<T>();
if (TVM_FFI_PREDICT_FALSE(!result.has_value())) {
TVM_FFI_THROW(TypeError) << "Cannot treat type `"
<< TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` as type `"
<< TypeTraits<T>::TypeStr() << "`";
}
return *std::move(result);
}
}

/**
* \brief Try to reinterpret the Any as a type T, return std::nullopt if it is not possible.
*
Expand All @@ -382,6 +405,29 @@ class Any {
}
}

/**
* \brief Strictly reinterpret the Any as a type T or throw.
*
* \tparam T The type to cast to.
* \return The casted value.
* \note This function will not run fallback conversions.
*/
template <typename T,
typename = std::enable_if_t<TypeTraits<T>::convert_enabled || std::is_same_v<T, Any>>>
TVM_FFI_INLINE T as_or_throw() const& {
if constexpr (std::is_same_v<T, Any>) {
return *this;
} else {
std::optional<T> result = this->as<T>();
if (TVM_FFI_PREDICT_FALSE(!result.has_value())) {
TVM_FFI_THROW(TypeError) << "Cannot treat type `"
<< TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` as type `"
<< TypeTraits<T>::TypeStr() << "`";
}
return *std::move(result);
}
}

/*!
* \brief Shortcut of as Object to cast to a const pointer when T is an Object.
*
Expand All @@ -401,7 +447,7 @@ class Any {
template <typename T, typename = std::enable_if_t<TypeTraits<T>::convert_enabled>>
TVM_FFI_INLINE T cast() const& {
std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
if (!opt.has_value()) {
if (TVM_FFI_PREDICT_FALSE(!opt.has_value())) {
TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
<< TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `"
<< TypeTraits<T>::TypeStr() << "`";
Expand All @@ -421,7 +467,7 @@ class Any {
}
// slow path, try to do fallback convert
std::optional<T> opt = TypeTraits<T>::TryCastFromAnyView(&data_);
if (!opt.has_value()) {
if (TVM_FFI_PREDICT_FALSE(!opt.has_value())) {
TVM_FFI_THROW(TypeError) << "Cannot convert from type `"
<< TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `"
<< TypeTraits<T>::TypeStr() << "`";
Expand Down Expand Up @@ -824,6 +870,69 @@ struct AnyEqual {
}
};

// Defer this definition until any.h so the throwing path can depend on
// TVM_FFI_THROW(TypeError), while object.h stays below the error layer.
//! \cond Doxygen_Suppress
template <typename ObjectRefType, typename>
TVM_FFI_INLINE ObjectRefType ObjectRef::as_or_throw() const& {
if (data_ != nullptr) {
// Piggy back to Any TypeTraits for rich ObjectRef check, temp any_data will optimize away.
TVMFFIAny any_data;
any_data.type_index = data_->type_index();
TVM_FFI_UNSAFE_ASSUME(any_data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin);
any_data.zero_padding = 0;
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
any_data.v_obj = reinterpret_cast<TVMFFIObject*>(const_cast<Object*>(data_.get()));
if (TVM_FFI_PREDICT_TRUE(TypeTraits<ObjectRefType>::CheckAnyStrict(&any_data))) {
ObjectRefType result(UnsafeInit{});
result.data_ = data_;
return result;
} else {
TVM_FFI_THROW(TypeError) << "Cannot treat type `"
<< TypeTraits<ObjectRefType>::GetMismatchTypeInfo(&any_data)
<< "` as type `" << TypeTraits<ObjectRefType>::TypeStr() << "`";
}
} else {
if constexpr (ObjectRefType::_type_is_nullable) {
return ObjectRefType(UnsafeInit{});
} else {
TVM_FFI_THROW(TypeError) << "Cannot treat type `" << StaticTypeKey::kTVMFFINone
<< "` as type `" << TypeTraits<ObjectRefType>::TypeStr() << "`";
}
}
}

template <typename ObjectRefType, typename>
TVM_FFI_INLINE ObjectRefType ObjectRef::as_or_throw() && {
if (data_ != nullptr) {
// Piggy back to Any TypeTraits for rich ObjectRef check, temp any_data will optimize away.
TVMFFIAny any_data;
any_data.type_index = data_->type_index();
TVM_FFI_UNSAFE_ASSUME(any_data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin);
any_data.zero_padding = 0;
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
any_data.v_obj = reinterpret_cast<TVMFFIObject*>(const_cast<Object*>(data_.get()));
if (TVM_FFI_PREDICT_TRUE(TypeTraits<ObjectRefType>::CheckAnyStrict(&any_data))) {
ObjectRefType result(UnsafeInit{});
result.data_ = std::move(data_);
data_ = nullptr;
return result;
} else {
TVM_FFI_THROW(TypeError) << "Cannot treat type `"
<< TypeTraits<ObjectRefType>::GetMismatchTypeInfo(&any_data)
<< "` as type `" << TypeTraits<ObjectRefType>::TypeStr() << "`";
}
} else {
if constexpr (ObjectRefType::_type_is_nullable) {
return ObjectRefType(UnsafeInit{});
} else {
TVM_FFI_THROW(TypeError) << "Cannot treat type `" << StaticTypeKey::kTVMFFINone
<< "` as type `" << TypeTraits<ObjectRefType>::TypeStr() << "`";
}
}
}
Comment thread
tqchen marked this conversation as resolved.
//! \endcond

// Placed near the end because this specialization depends on error handling.
template <>
struct TypeTraits<uint64_t> : public TypeTraitsIntBase<uint64_t> {
Expand Down
14 changes: 7 additions & 7 deletions include/tvm/ffi/base_details.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@
#if defined(__clang__)
#define TVM_FFI_UNSAFE_ASSUME(cond) __builtin_assume(cond)
#elif defined(__GNUC__)
// GCC 13+ supports __attribute__((assume(...))); fall back to the void-cast
// no-op for older GCC where __builtin_assume is absent.
#if __GNUC__ >= 13
#define TVM_FFI_UNSAFE_ASSUME(cond) __attribute__((assume(cond)))
#else
#define TVM_FFI_UNSAFE_ASSUME(cond) static_cast<void>(0)
#endif
// GCC does not reliably propagate __attribute__((assume(...))) through the
// returned-aggregate/helper flows used in TVM_FFI hot paths. Lower to an
// unreachable edge instead so GCC 11/14 recover the intended codegen.
#define TVM_FFI_UNSAFE_ASSUME(cond) \
do { \
if (!(cond)) __builtin_unreachable(); \
} while (0)
#elif defined(_MSC_VER)
#define TVM_FFI_UNSAFE_ASSUME(cond) __assume(cond)
#else
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/ffi/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ namespace ffi {
template <typename RefType, typename ObjectType>
inline RefType GetRef(const ObjectType* ptr) {
using ContainerType = typename RefType::ContainerType;
static_assert(RefType::_type_container_is_exact,
"GetRef requires RefType::ContainerType to exactly describe all objects the ref "
"can hold; use ObjectRef::as<RefType>() for richer TypeTraits-based refs");
static_assert(std::is_base_of_v<ContainerType, ObjectType>,
"Can only cast to the ref of same container type");

Expand Down
4 changes: 3 additions & 1 deletion include/tvm/ffi/container/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -668,8 +668,10 @@ class Array : public ObjectRef {
return static_cast<ArrayObj*>(data_.get());
}

/*! \brief specify container node */
/// \cond Doxygen_Suppress
using ContainerType = ArrayObj;
static constexpr bool _type_container_is_exact = false;
/// \endcond

/*!
* \brief Agregate arguments into a single Array<T>
Expand Down
4 changes: 3 additions & 1 deletion include/tvm/ffi/container/dict.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,10 @@ class Dict : public ObjectRef {
}
}

/*! \brief specify container node */
/// \cond Doxygen_Suppress
using ContainerType = DictObj;
static constexpr bool _type_container_is_exact = false;
/// \endcond

/// \cond Doxygen_Suppress
/*! \brief Iterator of the hash map */
Expand Down
4 changes: 3 additions & 1 deletion include/tvm/ffi/container/list.h
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,10 @@ class List : public ObjectRef {
}
}

/*! \brief specify container node */
/// \cond Doxygen_Suppress
using ContainerType = ListObj;
static constexpr bool _type_container_is_exact = false;
/// \endcond

private:
/*!
Expand Down
4 changes: 3 additions & 1 deletion include/tvm/ffi/container/map.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,10 @@ class Map : public ObjectRef {
}
return GetMapObj();
}
/*! \brief specify container node */
/// \cond Doxygen_Suppress
using ContainerType = MapObj;
static constexpr bool _type_container_is_exact = false;
/// \endcond

/// \cond Doxygen_Suppress
/*! \brief Iterator of the hash map */
Expand Down
4 changes: 3 additions & 1 deletion include/tvm/ffi/container/tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,10 @@ class Tuple : public ObjectRef {
*ptr = T(std::forward<U>(item));
}

/*! \brief specify container node */
/// \cond Doxygen_Suppress
using ContainerType = ArrayObj;
static constexpr bool _type_container_is_exact = false;
/// \endcond

private:
static ObjectPtr<ArrayObj> MakeDefaultTupleNode() {
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/ffi/container/variant.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ class Variant : public details::VariantBase<details::all_object_ref_v<V...>> {
using TParent = details::VariantBase<details::all_object_ref_v<V...>>;
static_assert(details::all_storage_enabled_v<V...>,
"All types used in Variant<...> must be compatible with Any");
/// \cond Doxygen_Suppress
static constexpr bool _type_container_is_exact = false;
/// \endcond
/*
* \brief Helper utility to check if the type can be contained in the variant
*/
Expand Down
93 changes: 82 additions & 11 deletions include/tvm/ffi/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ constexpr uint64_t kCombinedRefCountMaskUInt32 = (static_cast<uint64_t>(1) << 32
*/
template <typename TargetType>
TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index);

} // namespace details

/*!
Expand Down Expand Up @@ -747,7 +748,7 @@ class ObjectRef {
* \return The pointer to the requested type.
*/
template <typename ObjectType, typename = std::enable_if_t<std::is_base_of_v<Object, ObjectType>>>
const ObjectType* as() const {
const ObjectType* as() const& {
if (data_ != nullptr && data_->IsInstance<ObjectType>()) {
return static_cast<ObjectType*>(data_.get());
} else {
Expand All @@ -758,27 +759,95 @@ class ObjectRef {
/*!
* \brief Try to downcast the ObjectRef to Optional<T> of the requested type.
*
* The function will return a std::nullopt if the cast or if the pointer is nullptr.
* If the cast fails, returns std::nullopt. If this ObjectRef is null, returns
* a null ref for nullable target refs and std::nullopt for non-nullable targets.
*
* \tparam ObjectRefType the target type, must be a subtype of ObjectRef'
* \tparam ObjectRefType the target type, must be a subtype of ObjectRef.
* \return The optional value of the requested type.
*/
template <typename ObjectRefType,
typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
TVM_FFI_INLINE std::optional<ObjectRefType> as() const {
TVM_FFI_INLINE std::optional<ObjectRefType> as() const& {
if (data_ != nullptr) {
if (data_->IsInstance<typename ObjectRefType::ContainerType>()) {
ObjectRefType ref(UnsafeInit{});
ref.data_ = data_;
return ref;
} else {
return std::nullopt;
// Piggy back to Any TypeTraits for rich ObjectRef check, temp any_data will optimize away.
TVMFFIAny any_data;
any_data.type_index = data_->type_index();
TVM_FFI_UNSAFE_ASSUME(any_data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin);
any_data.zero_padding = 0;
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
any_data.v_obj = reinterpret_cast<TVMFFIObject*>(const_cast<Object*>(data_.get()));
if (TypeTraits<ObjectRefType>::CheckAnyStrict(&any_data)) {
ObjectRefType result(UnsafeInit{});
result.data_ = data_;
return result;
}
return std::nullopt;
}
if constexpr (ObjectRefType::_type_is_nullable) {
return ObjectRefType(UnsafeInit{});
}
return std::nullopt;
}

/*!
* \brief Try to move-downcast the ObjectRef to Optional<T> of the requested type.
*
* If the cast succeeds, moves the internal object pointer to the returned
* ObjectRefType. If the cast fails, returns std::nullopt. If this ObjectRef
* is null, returns a null ref for nullable target refs and std::nullopt for
* non-nullable targets.
*
* \tparam ObjectRefType the target type, must be a subtype of ObjectRef.
* \return The optional value of the requested type.
*/
template <typename ObjectRefType,
typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
TVM_FFI_INLINE std::optional<ObjectRefType> as() && {
if (data_ != nullptr) {
// Piggy back to Any TypeTraits for rich ObjectRef check, temp any_data will optimize away.
TVMFFIAny any_data;
any_data.type_index = data_->type_index();
TVM_FFI_UNSAFE_ASSUME(any_data.type_index >= TypeIndex::kTVMFFIStaticObjectBegin);
any_data.zero_padding = 0;
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data);
any_data.v_obj = reinterpret_cast<TVMFFIObject*>(const_cast<Object*>(data_.get()));
if (TypeTraits<ObjectRefType>::CheckAnyStrict(&any_data)) {
ObjectRefType result(UnsafeInit{});
result.data_ = std::move(data_);
data_ = nullptr;
return result;
}
} else {
return std::nullopt;
}
if constexpr (ObjectRefType::_type_is_nullable) {
return ObjectRefType(UnsafeInit{});
}
return std::nullopt;
}

/*!
* \brief Downcast the ObjectRef to the requested type or throw.
*
* \tparam ObjectRefType the target type, must be a subtype of ObjectRef
* \return The requested value.
*/
template <typename ObjectRefType,
typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
TVM_FFI_INLINE ObjectRefType as_or_throw() const&;

/*!
* \brief Move-downcast the ObjectRef to the requested type or throw.
*
* If the cast succeeds, moves the internal object pointer to the returned
* ObjectRefType.
*
* \tparam ObjectRefType the target type, must be a subtype of ObjectRef
* \return The requested value.
*/
template <typename ObjectRefType,
typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
TVM_FFI_INLINE ObjectRefType as_or_throw() &&;

/*!
* \brief Get the type index of the ObjectRef
* \return The type index of the ObjectRef
Expand All @@ -797,6 +866,8 @@ class ObjectRef {

/*! \brief type indicate the container type. */
using ContainerType = Object;
/*! \brief Whether ContainerType exactly describes the objects this ref can hold. */
static constexpr bool _type_container_is_exact = true;
/*! \brief Whether the reference can point to nullptr */
static constexpr bool _type_is_nullable = true;

Expand Down
1 change: 1 addition & 0 deletions include/tvm/ffi/optional.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ template <typename T>
class Optional<T, std::enable_if_t<use_ptr_based_optional_v<T>>> : public ObjectRef {
public:
using ContainerType = typename T::ContainerType;
static constexpr bool _type_container_is_exact = T::_type_container_is_exact;
Optional() = default;
// NOLINTBEGIN(google-explicit-constructor)
Optional(const Optional<T>& other) : ObjectRef(other) {}
Expand Down
Loading
Loading