Skip to content
62 changes: 50 additions & 12 deletions autowrap/CodeGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,10 +567,16 @@ def create_wrapper_for_class(self, r_class: ResolvedClass, out_codes: CodeDict)
# Class documentation (multi-line)
docstring = "Cython implementation of %s\n" % cy_type
docstring += special_class_doc % locals()
if r_class.cpp_decl.annotations.get("wrap-inherits", "") != "":
docstring += " -- Inherits from %s\n" % r_class.cpp_decl.annotations.get(
"wrap-inherits", ""
)
inherit_annot = r_class.cpp_decl.annotations.get("wrap-inherits", [])
if inherit_annot:
# Generate Sphinx RST links for inherited classes
inherit_list = inherit_annot if isinstance(inherit_annot, list) else [inherit_annot]
inherit_links = []
for base_class in inherit_list:
# Extract class name (handle template syntax like "Base[A]")
base_name = base_class.split('[')[0].strip()
inherit_links.append(":py:class:`%s`" % base_name)
docstring += " -- Inherits from %s\n" % ", ".join(inherit_links)

extra_doc = r_class.cpp_decl.annotations.get("wrap-doc", None)
if extra_doc is not None:
Expand Down Expand Up @@ -790,16 +796,42 @@ def create_wrapper_for_class(self, r_class: ResolvedClass, out_codes: CodeDict)

iterators, non_iter_methods = self.filterout_iterators(r_class.methods)

# Separate class-defined methods from inherited methods
inherited_method_bases = getattr(r_class.cpp_decl, 'inherited_method_bases', {})
class_methods = {}
inherited_methods = {}

for name, methods in non_iter_methods.items():
if name == r_class.name:
# Constructor always goes first
codes, stub_code = self.create_wrapper_for_constructor(r_class, methods)
cons_created = True
typestub_code.add(stub_code)
for ci in codes:
class_code.add(ci)
elif name in inherited_method_bases:
inherited_methods[name] = methods
else:
codes, stub_code = self.create_wrapper_for_method(r_class, name, methods)
class_methods[name] = methods

# Generate class-defined methods first (sorted alphabetically)
for name, methods in sorted(class_methods.items()):
codes, stub_code = self.create_wrapper_for_method(r_class, name, methods)
typestub_code.add(stub_code)
for ci in codes:
class_code.add(ci)

# Generate inherited methods grouped together (sorted alphabetically)
if inherited_methods:
for name, methods in sorted(inherited_methods.items()):
base_class_name = inherited_method_bases.get(name, "")
codes, stub_code = self.create_wrapper_for_method(
r_class, name, methods, inherited_from=base_class_name
)
typestub_code.add(stub_code)
for ci in codes:
class_code.add(ci)

has_ops = dict()
for ops in ["==", "!=", "<", "<=", ">", ">="]:
has_op = ("operator%s" % ops) in non_iter_methods
Expand Down Expand Up @@ -881,7 +913,7 @@ def _create_iter_methods(self, iterators, instance_mapping, local_mapping):
return codes, stub_codes

def _create_overloaded_method_decl(
self, py_name, dispatched_m_names, methods, use_return, use_kwargs=False
self, py_name, dispatched_m_names, methods, use_return, use_kwargs=False, inherited_from=None
):
L.info(" create wrapper decl for overloaded method %s" % py_name)

Expand Down Expand Up @@ -955,6 +987,8 @@ def _create_overloaded_method_decl(
extra_doc = method.cpp_decl.annotations.get("wrap-doc", None)
if extra_doc is not None:
docstring += "\n" + extra_doc.render(indent=8)
if inherited_from:
docstring += "\n Inherited from :py:class:`%s`." % inherited_from

typestub_code.add(
"""
Expand Down Expand Up @@ -1003,7 +1037,7 @@ def _create_overloaded_method_decl(
)
return method_code, typestub_code

def create_wrapper_for_method(self, cdcl, py_name, methods):
def create_wrapper_for_method(self, cdcl, py_name, methods, inherited_from=None):
if py_name.startswith("operator"):
__, __, op = py_name.partition("operator")
if op in ["!=", "==", "<", "<=", ">", ">="]:
Expand Down Expand Up @@ -1102,7 +1136,7 @@ def create_wrapper_for_method(self, cdcl, py_name, methods):

if len(methods) == 1:
code, typestubs = self.create_wrapper_for_nonoverloaded_method(
cdcl, py_name, methods[0]
cdcl, py_name, methods[0], inherited_from=inherited_from
)
return [code], typestubs
else:
Expand All @@ -1122,12 +1156,12 @@ def create_wrapper_for_method(self, cdcl, py_name, methods):
codes.append(code)

code, typestubs = self._create_overloaded_method_decl(
py_name, dispatched_m_names, methods, True
py_name, dispatched_m_names, methods, True, inherited_from=inherited_from
)
codes.append(code)
return codes, typestubs

def _create_fun_decl_and_input_conversion(self, code, py_name, method, is_free_fun=False):
def _create_fun_decl_and_input_conversion(self, code, py_name, method, is_free_fun=False, inherited_from=None):
"""Creates the function declarations and the input conversion to C++
and the output conversion back to Python.

Expand Down Expand Up @@ -1181,6 +1215,10 @@ def _create_fun_decl_and_input_conversion(self, code, py_name, method, is_free_f
if extra_doc is not None:
docstring += "\n" + extra_doc.render(indent=8)
stubdocstring += "\n" + extra_doc.render(indent=8)

# Add inherited from notation for typestub
if inherited_from:
stubdocstring += "\n Inherited from :py:class:`%s`." % inherited_from

if method.is_static:
code.add(
Expand Down Expand Up @@ -1355,7 +1393,7 @@ def _create_wrapper_for_attribute(self, attribute):
code.add(" return py_result")
return code, stubs

def create_wrapper_for_nonoverloaded_method(self, cdcl, py_name, method):
def create_wrapper_for_nonoverloaded_method(self, cdcl, py_name, method, inherited_from=None):
L.info(" create wrapper for %s ('%s')" % (py_name, method))
meth_code = Code()

Expand All @@ -1364,7 +1402,7 @@ def create_wrapper_for_nonoverloaded_method(self, cdcl, py_name, method):
cleanups,
in_types,
stubs,
) = self._create_fun_decl_and_input_conversion(meth_code, py_name, method)
) = self._create_fun_decl_and_input_conversion(meth_code, py_name, method, inherited_from=inherited_from)

# call wrapped method and convert result value back to python
cpp_name = method.cpp_decl.name
Expand Down
2 changes: 1 addition & 1 deletion autowrap/DeclResolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def _add_inherited_methods(cdcl, super_cld, used_parameters):
) # remove constructors
for method in transformed_methods:
logger.info("attach to %s: %s" % (cdcl.name, method))
cdcl.attach_base_methods(transformed_methods)
cdcl.attach_base_methods(transformed_methods, base_class_name=super_cld.name)
# logger.info("")


Expand Down
8 changes: 7 additions & 1 deletion autowrap/PXDParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,8 @@ def __init__(self, name, template_parameters, methods, attributes, annotations,
self.methods = methods
self.attributes = attributes
self.template_parameters = template_parameters
# Track which methods are inherited and from which base class
self.inherited_method_bases = {} # method_name -> base_class_name

@classmethod
def parseTree(cls, node: Cython.Compiler.Nodes.CppClassNode, lines: Collection[str], pxd_path):
Expand Down Expand Up @@ -492,11 +494,15 @@ def has_method(self, other_decl):
return False
return any(decl.matches(other_decl) for decl in with_same_name)

def attach_base_methods(self, dd):
def attach_base_methods(self, dd, base_class_name=None):
for name, decls in dd.items():
for decl in decls:
if not self.has_method(decl):
self.methods.setdefault(decl.name, []).append(decl)
if base_class_name is not None:
# Track that this method is inherited (use first base if multiple)
if decl.name not in self.inherited_method_bases:
self.inherited_method_bases[decl.name] = base_class_name


class CppAttributeDecl(BaseDecl):
Expand Down
Loading