Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 71 additions & 28 deletions autowrap/CodeGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,10 +568,27 @@ def create_wrapper_for_class(self, r_class: ResolvedClass, out_codes: CodeDict)
# Class documentation (multi-line)
docstring = "Cython implementation of %s\n" % cy_type
docstring += special_class_doc % locals()
if r_class.cpp_decl.annotations.get("wrap-inherits", "") != "":
docstring += " -- Inherits from %s\n" % r_class.cpp_decl.annotations.get(
"wrap-inherits", ""
)
inherit_annot = r_class.cpp_decl.annotations.get("wrap-inherits", [])
if inherit_annot:
# Normalize inherit_annot to a list
if isinstance(inherit_annot, str):
inherit_list = [inherit_annot]
elif isinstance(inherit_annot, list):
inherit_list = inherit_annot
else:
raise ValueError(
f"wrap-inherits annotation must be a string or list, got {type(inherit_annot).__name__}"
)
# Generate Sphinx RST links for inherited classes
inherit_links = []
for base_class in inherit_list:
if not isinstance(base_class, str) or not base_class:
continue # Skip empty or invalid entries
# Extract class name (handle template syntax like "Base[A]")
base_name = base_class.split('[')[0].strip()
inherit_links.append(":py:class:`%s`" % base_name)
if inherit_links:
docstring += " -- Inherits from %s\n" % ", ".join(inherit_links)

extra_doc = r_class.cpp_decl.annotations.get("wrap-doc", None)
if extra_doc is not None:
Expand Down Expand Up @@ -791,16 +808,42 @@ def create_wrapper_for_class(self, r_class: ResolvedClass, out_codes: CodeDict)

iterators, non_iter_methods = self.filterout_iterators(r_class.methods)

# Separate class-defined methods from inherited methods
inherited_method_bases = getattr(r_class.cpp_decl, 'inherited_method_bases', {})
class_methods = {}
inherited_methods = {}

for name, methods in non_iter_methods.items():
if name == r_class.name:
# Constructor always goes first
codes, stub_code = self.create_wrapper_for_constructor(r_class, methods)
cons_created = True
typestub_code.add(stub_code)
for ci in codes:
class_code.add(ci)
elif name in inherited_method_bases:
inherited_methods[name] = methods
else:
codes, stub_code = self.create_wrapper_for_method(r_class, name, methods)
class_methods[name] = methods

# Generate class-defined methods first (sorted alphabetically)
for name, methods in sorted(class_methods.items()):
codes, stub_code = self.create_wrapper_for_method(r_class, name, methods)
typestub_code.add(stub_code)
for ci in codes:
class_code.add(ci)

# Generate inherited methods grouped together (sorted alphabetically)
if inherited_methods:
for name, methods in sorted(inherited_methods.items()):
base_class_name = inherited_method_bases.get(name, "")
codes, stub_code = self.create_wrapper_for_method(
r_class, name, methods, inherited_from=base_class_name
)
typestub_code.add(stub_code)
for ci in codes:
class_code.add(ci)

has_ops = dict()
for ops in ["==", "!=", "<", "<=", ">", ">="]:
has_op = ("operator%s" % ops) in non_iter_methods
Expand Down Expand Up @@ -882,7 +925,7 @@ def _create_iter_methods(self, iterators, instance_mapping, local_mapping):
return codes, stub_codes

def _create_overloaded_method_decl(
self, py_name, dispatched_m_names, methods, use_return, use_kwargs=False
self, py_name, dispatched_m_names, methods, use_return, use_kwargs=False, inherited_from=None
):
L.info(" create wrapper decl for overloaded method %s" % py_name)

Expand Down Expand Up @@ -956,6 +999,8 @@ def _create_overloaded_method_decl(
extra_doc = method.cpp_decl.annotations.get("wrap-doc", None)
if extra_doc is not None:
docstring += "\n" + extra_doc.render(indent=8)
if inherited_from:
docstring += "\n Inherited from :py:class:`%s`." % inherited_from

typestub_code.add(
"""
Expand Down Expand Up @@ -1004,7 +1049,7 @@ def _create_overloaded_method_decl(
)
return method_code, typestub_code

def create_wrapper_for_method(self, cdcl, py_name, methods):
def create_wrapper_for_method(self, cdcl, py_name, methods, inherited_from=None):
if py_name.startswith("operator"):
__, __, op = py_name.partition("operator")
if op in ["!=", "==", "<", "<=", ">", ">="]:
Expand Down Expand Up @@ -1103,7 +1148,7 @@ def create_wrapper_for_method(self, cdcl, py_name, methods):

if len(methods) == 1:
code, typestubs = self.create_wrapper_for_nonoverloaded_method(
cdcl, py_name, methods[0]
cdcl, py_name, methods[0], inherited_from=inherited_from
)
return [code], typestubs
else:
Expand All @@ -1123,12 +1168,12 @@ def create_wrapper_for_method(self, cdcl, py_name, methods):
codes.append(code)

code, typestubs = self._create_overloaded_method_decl(
py_name, dispatched_m_names, methods, True
py_name, dispatched_m_names, methods, True, inherited_from=inherited_from
)
codes.append(code)
return codes, typestubs

def _create_fun_decl_and_input_conversion(self, code, py_name, method, is_free_fun=False):
def _create_fun_decl_and_input_conversion(self, code, py_name, method, is_free_fun=False, inherited_from=None):
"""Creates the function declarations and the input conversion to C++
and the output conversion back to Python.

Expand Down Expand Up @@ -1182,6 +1227,10 @@ def _create_fun_decl_and_input_conversion(self, code, py_name, method, is_free_f
if extra_doc is not None:
docstring += "\n" + extra_doc.render(indent=8)
stubdocstring += "\n" + extra_doc.render(indent=8)

# Add inherited from notation for typestub
if inherited_from:
stubdocstring += "\n Inherited from :py:class:`%s`." % inherited_from

if method.is_static:
code.add(
Expand Down Expand Up @@ -1356,7 +1405,7 @@ def _create_wrapper_for_attribute(self, attribute):
code.add(" return py_result")
return code, stubs

def create_wrapper_for_nonoverloaded_method(self, cdcl, py_name, method):
def create_wrapper_for_nonoverloaded_method(self, cdcl, py_name, method, inherited_from=None):
L.info(" create wrapper for %s ('%s')" % (py_name, method))
meth_code = Code()

Expand All @@ -1365,7 +1414,7 @@ def create_wrapper_for_nonoverloaded_method(self, cdcl, py_name, method):
cleanups,
in_types,
stubs,
) = self._create_fun_decl_and_input_conversion(meth_code, py_name, method)
) = self._create_fun_decl_and_input_conversion(meth_code, py_name, method, inherited_from=inherited_from)

# call wrapped method and convert result value back to python
cpp_name = method.cpp_decl.name
Expand Down Expand Up @@ -1812,31 +1861,25 @@ def create_special_setitem_method(self, mdcl):
is_integral = self._is_integral_type(ctype_in)
size_guard = mdcl.cpp_decl.annotations.get("wrap-upper-limit")

if is_integral:
meth_code.add(
"""
# Generate method signature (same for all key types)
meth_code.add(
"""
|def __setitem__(self, $in_t_cy key, $res_t_base value):
| \"\"\"$docstring\"\"\"
""",
locals(),
)
if size_guard:
meth_code.add(
"""
locals(),
)

# Apply bounds checking only for integral types with size guard
if is_integral and size_guard:
meth_code.add(
"""
| cdef int _idx = $call_arg
| if _idx < 0:
| raise IndexError("invalid index %d" % _idx)
| if _idx >= self.inst.get().$size_guard:
| raise IndexError("invalid index %d" % _idx)
""",
locals(),
)
else:
meth_code.add(
"""
|def __setitem__(self, $in_t_cy key, $res_t_base value):
| \"\"\"$docstring\"\"\"
""",
locals(),
)

Expand Down
2 changes: 1 addition & 1 deletion autowrap/DeclResolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def _add_inherited_methods(cdcl, super_cld, used_parameters):
) # remove constructors
for method in transformed_methods:
logger.info("attach to %s: %s" % (cdcl.name, method))
cdcl.attach_base_methods(transformed_methods)
cdcl.attach_base_methods(transformed_methods, base_class_name=super_cld.name)
# logger.info("")


Expand Down
8 changes: 7 additions & 1 deletion autowrap/PXDParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,8 @@ def __init__(self, name, template_parameters, methods, attributes, annotations,
self.methods = methods
self.attributes = attributes
self.template_parameters = template_parameters
# Track which methods are inherited and from which base class
self.inherited_method_bases = {} # method_name -> base_class_name

@classmethod
def parseTree(cls, node: Cython.Compiler.Nodes.CppClassNode, lines: Collection[str], pxd_path):
Expand Down Expand Up @@ -492,11 +494,15 @@ def has_method(self, other_decl):
return False
return any(decl.matches(other_decl) for decl in with_same_name)

def attach_base_methods(self, dd):
def attach_base_methods(self, dd, base_class_name=None):
for name, decls in dd.items():
for decl in decls:
if not self.has_method(decl):
self.methods.setdefault(decl.name, []).append(decl)
if base_class_name is not None:
# Track that this method is inherited (use first base if multiple)
if decl.name not in self.inherited_method_bases:
self.inherited_method_bases[decl.name] = base_class_name


class CppAttributeDecl(BaseDecl):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_code_generator_libcpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def test_libcpp():
assert len(libcpp.LibCppTest.__doc__) == 214
assert len(libcpp.LibCppTest.twist.__doc__) == 111
assert len(libcpp.LibCppTest.gett.__doc__) == 72
assert len(libcpp.ABS_Impl1.__doc__) == 90
# Length changed due to Sphinx RST syntax (:py:class:`AbstractBaseClass` instead of AbstractBaseClass)
assert len(libcpp.ABS_Impl1.__doc__) == 98

sub_libcpp_copy_constructors(libcpp)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_full_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,8 @@ def test_full_lib(tmpdir):
# Check doc string
assert "Inherits from" in moduleB.Bklass.__doc__
assert "some doc!" in moduleB.Bklass.__doc__
assert len(moduleB.Bklass.__doc__) == 93, len(moduleB.Bklass.__doc__)
# Length changed due to Sphinx RST syntax (:py:class:`A_second` instead of A_second)
assert len(moduleB.Bklass.__doc__) == 101, len(moduleB.Bklass.__doc__)

Bsecond = moduleB.B_second(8)
Dsecond = moduleCD.D_second(11)
Expand Down