Skip to content

Commit d4bf98f

Browse files
committed
Allow python class as a base class
1 parent d068f9d commit d4bf98f

File tree

2 files changed

+44
-13
lines changed

2 files changed

+44
-13
lines changed

src/nb_func.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -250,12 +250,8 @@ PyObject *nb_func_new(const void *in_) noexcept {
250250
instead, hide the parent's overloads in this case */
251251
if (fp->scope != f->scope)
252252
Py_CLEAR(func_prev);
253-
} else if (name_cstr[0] == '_') {
254-
Py_CLEAR(func_prev);
255253
} else {
256-
check(false,
257-
"nb::detail::nb_func_new(\"%s\"): cannot overload "
258-
"existing non-function object of the same name!", name_cstr);
254+
Py_CLEAR(func_prev);
259255
}
260256
} else {
261257
PyErr_Clear();

src/nb_type.cpp

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -725,10 +725,10 @@ void *type_get_slot(PyTypeObject *t, int slot_id) {
725725
#endif
726726

727727
static PyObject *nb_type_from_metaclass(PyTypeObject *meta, PyObject *mod,
728-
PyType_Spec *spec) {
728+
PyType_Spec *spec, PyObject *bases = nullptr) {
729729
#if NB_TYPE_FROM_METACLASS_IMPL == 0
730730
// Life is good, PyType_FromMetaclass() is available
731-
return PyType_FromMetaclass(meta, mod, spec, nullptr);
731+
return PyType_FromMetaclass(meta, mod, spec, bases);
732732
#else
733733
/* The fallback code below emulates PyType_FromMetaclass() on Python prior
734734
to version 3.12. It requires access to CPython-internal structures, which
@@ -1066,6 +1066,40 @@ static PyObject *nb_type_vectorcall(PyObject *self, PyObject *const *args_in,
10661066
}
10671067
}
10681068

1069+
/// Call __init__subclass__ of the parent class. This function assumes that the
1070+
/// passed type object has a parent class.
1071+
/// Copied and adapted from the following function in CPython
1072+
/// https://github.com/python/cpython/blob/89c220b93c06059f623e2d232bd54f49be1be22d/Objects/typeobject.c#L11848
1073+
static bool call_init_subclass(PyObject *type) {
1074+
// Call super(type, type) which returns a proxy for the parent class
1075+
PyObject *super_args[2] = {type, type};
1076+
PyObject *super = PyObject_Vectorcall((PyObject *)&PySuper_Type, super_args, 2, NULL);
1077+
1078+
// Try and get the `__init_subclass__` function from the base class
1079+
PyObject *func = PyObject_GetAttrString(super, "__init_subclass__");
1080+
Py_DECREF(super);
1081+
if (func == NULL)
1082+
return true;
1083+
1084+
// The base class might be written in C and not implement the vectorcall
1085+
// protocol. Use the PyObject_Call which is guaranteed to always work.
1086+
// This calls the base class __init_subclass__ with empty args and kwargs.
1087+
PyObject *args = PyTuple_New(0);
1088+
PyObject *kwargs = PyDict_New();
1089+
assert(args != NULL);
1090+
assert(kwargs != NULL);
1091+
1092+
PyObject *result = PyObject_Call(func, args, kwargs);
1093+
Py_DECREF(func);
1094+
Py_DECREF(args);
1095+
Py_DECREF(kwargs);
1096+
if (result == NULL)
1097+
return false;
1098+
1099+
Py_DECREF(result);
1100+
return true;
1101+
}
1102+
10691103
/// Called when a C++ type is bound via nb::class_<>
10701104
PyObject *nb_type_new(const type_init_data *t) noexcept {
10711105
bool has_doc = t->flags & (uint32_t) type_init_flags::has_doc,
@@ -1154,10 +1188,6 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
11541188
generic_base = true;
11551189
}
11561190
#endif
1157-
1158-
check(nb_type_check(base),
1159-
"nanobind::detail::nb_type_new(\"%s\"): base type is not a "
1160-
"nanobind type!", t_name);
11611191
} else if (has_base) {
11621192
lock_internals guard(internals_);
11631193
nb_type_map_slow::iterator it2 = internals_->type_c2p_slow.find(t->base);
@@ -1168,7 +1198,7 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
11681198
}
11691199

11701200
type_data *tb = nullptr;
1171-
if (base) {
1201+
if (base != nullptr && nb_type_check(base)) {
11721202
// Check if the base type already has dynamic attributes
11731203
tb = nb_type_data((PyTypeObject *) base);
11741204
if (tb->flags & (uint32_t) type_flags::has_dynamic_attr)
@@ -1335,7 +1365,7 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
13351365

13361366
PyTypeObject *metaclass = nb_type_tp(has_supplement ? t->supplement : 0);
13371367

1338-
PyObject *result = nb_type_from_metaclass(metaclass, mod, &spec);
1368+
PyObject *result = nb_type_from_metaclass(metaclass, mod, &spec, base);
13391369
if (!result) {
13401370
python_error err;
13411371
check(false,
@@ -1409,6 +1439,11 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
14091439
free((char *) t_name);
14101440
}
14111441

1442+
if (base != nullptr) {
1443+
bool init_sub_res = call_init_subclass(result);
1444+
assert(init_sub_res);
1445+
}
1446+
14121447
#if PY_VERSION_HEX >= 0x03090000
14131448
if (generic_base)
14141449
setattr(result, "__orig_bases__", make_tuple(handle(t->base_py)));

0 commit comments

Comments
 (0)