Skip to content

Commit

Permalink
Automatically discover all .pyx files for cython compilation, as well…
Browse files Browse the repository at this point in the history
… as detect if pyx files that uses numpy to add extra compilation args for those files.
  • Loading branch information
jfranmatheu committed Feb 12, 2025
1 parent 76a5dd4 commit da0e9f4
Showing 1 changed file with 50 additions and 26 deletions.
76 changes: 50 additions & 26 deletions cy_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,27 @@
import subprocess


numpy_extra_compile_args = {
"include_dirs": [np.get_include()],
"define_macros": [('NPY_NO_DEPRECATED_API', 'NPY_1_7_API_VERSION')]
}


def build_for_architecture(arch):
"""Build extensions for a specific architecture"""
print(f"Building for architecture: {arch}")

# Base compiler flags for all platforms
compiler_flags = {
'Windows': ['/O2'],
'Darwin': ['-O3'], # macOS
'Linux': ['-O3']
}

# Get base optimization flag for current platform
extra_compile_args = compiler_flags.get(platform.system(), ['-O3'])
extra_link_args = []

# Add architecture flags only for macOS
if platform.system() == 'Darwin' and arch:
# Let ARCHFLAGS environment variable handle the architecture
Expand All @@ -36,27 +42,45 @@ def build_for_architecture(arch):
'extra_link_args': extra_link_args,
'language': 'c++'
}

np_ext_kwargs = {
"include_dirs": [np.get_include()],
"define_macros": [('NPY_NO_DEPRECATED_API', 'NPY_1_7_API_VERSION')],
}

# Define extension modules
ext_modules = [
Extension(
f"retopoflow.cy.rfmesh_visibility",
sources=["retopoflow/cy/rfmesh_visibility.pyx"],
**shared_ext_kwargs,
**np_ext_kwargs
),
Extension(
f"retopoflow.cy.bmesh_utils",
sources=["retopoflow/cy/bmesh_utils.pyx"],
**shared_ext_kwargs
)
]


# Function to check if a .pyx file uses numpy.
def uses_numpy(file_path):
with open(file_path, 'r') as f:
# Only check first 50 lines where imports typically are
for i, line in enumerate(f):
if i > 50: # Stop after checking first 50 lines
break
if any(numpy_import in line for numpy_import in [
'cimport numpy',
'import numpy',
'from numpy'
]):
return True
return False

# Automatically discover all .pyx files.
cy_dir = "retopoflow/cy"
ext_modules = []
for file in os.listdir(cy_dir):
if file.endswith('.pyx'):
module_name = f"retopoflow.cy.{file[:-4]}" # Remove .pyx extension.
file_path = os.path.join(cy_dir, file)

# Build extension kwargs.
ext_kwargs = {**shared_ext_kwargs}

# Add numpy kwargs only if the file uses numpy.
if uses_numpy(file_path):
ext_kwargs.update(numpy_extra_compile_args)

ext_modules.append(
Extension(
module_name,
sources=[file_path],
**ext_kwargs
)
)

# Build extensions
setup(
name="retopoflow",
Expand All @@ -80,11 +104,11 @@ def build_for_architecture(arch):

def main():
system = platform.system()

if system == 'Darwin': # macOS
# Check if specific architecture is requested
target_arch = os.environ.get('TARGET_ARCH')

if target_arch:
# Build for specific architecture
build_for_architecture(target_arch)
Expand Down

0 comments on commit da0e9f4

Please sign in to comment.