Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 50 additions & 12 deletions autowrap/CodeGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,10 +567,16 @@ def create_wrapper_for_class(self, r_class: ResolvedClass, out_codes: CodeDict)
# Class documentation (multi-line)
docstring = "Cython implementation of %s\n" % cy_type
docstring += special_class_doc % locals()
if r_class.cpp_decl.annotations.get("wrap-inherits", "") != "":
docstring += " -- Inherits from %s\n" % r_class.cpp_decl.annotations.get(
"wrap-inherits", ""
)
inherit_annot = r_class.cpp_decl.annotations.get("wrap-inherits", [])
if inherit_annot:
# Generate Sphinx RST links for inherited classes
inherit_list = inherit_annot if isinstance(inherit_annot, list) else [inherit_annot]
inherit_links = []
for base_class in inherit_list:
# Extract class name (handle template syntax like "Base[A]")
base_name = base_class.split('[')[0].strip()
inherit_links.append(":py:class:`%s`" % base_name)
docstring += " -- Inherits from %s\n" % ", ".join(inherit_links)

extra_doc = r_class.cpp_decl.annotations.get("wrap-doc", None)
if extra_doc is not None:
Expand Down Expand Up @@ -790,16 +796,42 @@ def create_wrapper_for_class(self, r_class: ResolvedClass, out_codes: CodeDict)

iterators, non_iter_methods = self.filterout_iterators(r_class.methods)

# Separate class-defined methods from inherited methods
inherited_method_bases = getattr(r_class.cpp_decl, 'inherited_method_bases', {})
class_methods = {}
inherited_methods = {}

for name, methods in non_iter_methods.items():
if name == r_class.name:
# Constructor always goes first
codes, stub_code = self.create_wrapper_for_constructor(r_class, methods)
cons_created = True
typestub_code.add(stub_code)
for ci in codes:
class_code.add(ci)
elif name in inherited_method_bases:
inherited_methods[name] = methods
else:
codes, stub_code = self.create_wrapper_for_method(r_class, name, methods)
class_methods[name] = methods

# Generate class-defined methods first (sorted alphabetically)
for name, methods in sorted(class_methods.items()):
codes, stub_code = self.create_wrapper_for_method(r_class, name, methods)
typestub_code.add(stub_code)
for ci in codes:
class_code.add(ci)

# Generate inherited methods grouped together (sorted alphabetically)
if inherited_methods:
for name, methods in sorted(inherited_methods.items()):
base_class_name = inherited_method_bases.get(name, "")
codes, stub_code = self.create_wrapper_for_method(
r_class, name, methods, inherited_from=base_class_name
)
typestub_code.add(stub_code)
for ci in codes:
class_code.add(ci)

has_ops = dict()
for ops in ["==", "!=", "<", "<=", ">", ">="]:
has_op = ("operator%s" % ops) in non_iter_methods
Expand Down Expand Up @@ -881,7 +913,7 @@ def _create_iter_methods(self, iterators, instance_mapping, local_mapping):
return codes, stub_codes

def _create_overloaded_method_decl(
self, py_name, dispatched_m_names, methods, use_return, use_kwargs=False
self, py_name, dispatched_m_names, methods, use_return, use_kwargs=False, inherited_from=None
):
L.info(" create wrapper decl for overloaded method %s" % py_name)

Expand Down Expand Up @@ -955,6 +987,8 @@ def _create_overloaded_method_decl(
extra_doc = method.cpp_decl.annotations.get("wrap-doc", None)
if extra_doc is not None:
docstring += "\n" + extra_doc.render(indent=8)
if inherited_from:
docstring += "\n Inherited from :py:class:`%s`." % inherited_from

typestub_code.add(
"""
Expand Down Expand Up @@ -1003,7 +1037,7 @@ def _create_overloaded_method_decl(
)
return method_code, typestub_code

def create_wrapper_for_method(self, cdcl, py_name, methods):
def create_wrapper_for_method(self, cdcl, py_name, methods, inherited_from=None):
if py_name.startswith("operator"):
__, __, op = py_name.partition("operator")
if op in ["!=", "==", "<", "<=", ">", ">="]:
Expand Down Expand Up @@ -1102,7 +1136,7 @@ def create_wrapper_for_method(self, cdcl, py_name, methods):

if len(methods) == 1:
code, typestubs = self.create_wrapper_for_nonoverloaded_method(
cdcl, py_name, methods[0]
cdcl, py_name, methods[0], inherited_from=inherited_from
)
return [code], typestubs
else:
Expand All @@ -1122,12 +1156,12 @@ def create_wrapper_for_method(self, cdcl, py_name, methods):
codes.append(code)

code, typestubs = self._create_overloaded_method_decl(
py_name, dispatched_m_names, methods, True
py_name, dispatched_m_names, methods, True, inherited_from=inherited_from
)
codes.append(code)
return codes, typestubs

def _create_fun_decl_and_input_conversion(self, code, py_name, method, is_free_fun=False):
def _create_fun_decl_and_input_conversion(self, code, py_name, method, is_free_fun=False, inherited_from=None):
"""Creates the function declarations and the input conversion to C++
and the output conversion back to Python.

Expand Down Expand Up @@ -1181,6 +1215,10 @@ def _create_fun_decl_and_input_conversion(self, code, py_name, method, is_free_f
if extra_doc is not None:
docstring += "\n" + extra_doc.render(indent=8)
stubdocstring += "\n" + extra_doc.render(indent=8)

# Add inherited from notation for typestub
if inherited_from:
stubdocstring += "\n Inherited from :py:class:`%s`." % inherited_from

if method.is_static:
code.add(
Expand Down Expand Up @@ -1355,7 +1393,7 @@ def _create_wrapper_for_attribute(self, attribute):
code.add(" return py_result")
return code, stubs

def create_wrapper_for_nonoverloaded_method(self, cdcl, py_name, method):
def create_wrapper_for_nonoverloaded_method(self, cdcl, py_name, method, inherited_from=None):
L.info(" create wrapper for %s ('%s')" % (py_name, method))
meth_code = Code()

Expand All @@ -1364,7 +1402,7 @@ def create_wrapper_for_nonoverloaded_method(self, cdcl, py_name, method):
cleanups,
in_types,
stubs,
) = self._create_fun_decl_and_input_conversion(meth_code, py_name, method)
) = self._create_fun_decl_and_input_conversion(meth_code, py_name, method, inherited_from=inherited_from)

# call wrapped method and convert result value back to python
cpp_name = method.cpp_decl.name
Expand Down
2 changes: 1 addition & 1 deletion autowrap/DeclResolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def _add_inherited_methods(cdcl, super_cld, used_parameters):
) # remove constructors
for method in transformed_methods:
logger.info("attach to %s: %s" % (cdcl.name, method))
cdcl.attach_base_methods(transformed_methods)
cdcl.attach_base_methods(transformed_methods, base_class_name=super_cld.name)
# logger.info("")


Expand Down
8 changes: 7 additions & 1 deletion autowrap/PXDParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,8 @@ def __init__(self, name, template_parameters, methods, attributes, annotations,
self.methods = methods
self.attributes = attributes
self.template_parameters = template_parameters
# Track which methods are inherited and from which base class
self.inherited_method_bases = {} # method_name -> base_class_name

@classmethod
def parseTree(cls, node: Cython.Compiler.Nodes.CppClassNode, lines: Collection[str], pxd_path):
Expand Down Expand Up @@ -492,11 +494,15 @@ def has_method(self, other_decl):
return False
return any(decl.matches(other_decl) for decl in with_same_name)

def attach_base_methods(self, dd):
def attach_base_methods(self, dd, base_class_name=None):
for name, decls in dd.items():
for decl in decls:
if not self.has_method(decl):
self.methods.setdefault(decl.name, []).append(decl)
if base_class_name is not None:
# Track that this method is inherited (use first base if multiple)
if decl.name not in self.inherited_method_bases:
self.inherited_method_bases[decl.name] = base_class_name


class CppAttributeDecl(BaseDecl):
Expand Down
Loading