-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathsetup.py
119 lines (109 loc) · 3.88 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""KLUJAX Setup."""
import os
import site
import sys
from glob import glob
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
# SuiteSparse packages bundled with the extension, in build order. Each one
# contributes an Include/ directory and the .c files under its Source/ dir.
_SUITESPARSE_PACKAGES = ("AMD", "COLAMD", "BTF", "KLU")

# Header search paths handed to the compiler for every source file.
include_dirs = [
    "xla",
    os.path.join("pybind11", "include"),
    os.path.join("suitesparse", "SuiteSparse_config"),
]
include_dirs.extend(
    os.path.join("suitesparse", package, "Include")
    for package in _SUITESPARSE_PACKAGES
)

# C translation units from SuiteSparse: the shared config file first, then
# every .c file found (via glob) in each package's Source/ directory.
suitesparse_sources = [
    os.path.join("suitesparse", "SuiteSparse_config", "SuiteSparse_config.c")
]
for _package in _SUITESPARSE_PACKAGES:
    suitesparse_sources.extend(
        glob(os.path.join("suitesparse", _package, "Source", "*.c"))
    )
# Per-platform (extra_compile_args, extra_link_args) for the klujax_cpp
# extension, keyed by sys.platform.
_PLATFORM_FLAGS = {
    # gcc: select C++17 and link the GCC runtimes statically.
    "linux": (["-std=c++17"], ["-static-libgcc", "-static-libstdc++"]),
    # cl: MSVC spelling of the C++17 switch.
    "win32": (["/std:c++17"], []),
    # MacOS: clang.
    "darwin": (["-std=c++17"], []),
}

if sys.platform not in _PLATFORM_FLAGS:
    msg = f"Platform {sys.platform} not supported."
    raise RuntimeError(msg)

_compile_args, _link_args = _PLATFORM_FLAGS[sys.platform]

# The single extension module: klujax.cpp plus all bundled SuiteSparse C files.
extension = Extension(
    name="klujax_cpp",
    sources=["klujax.cpp", *suitesparse_sources],
    include_dirs=include_dirs,
    library_dirs=site.getsitepackages(),
    extra_compile_args=_compile_args,
    extra_link_args=_link_args,
    language="c++",
)
# Custom BuildExt to enable a combined build of C and C++ files on macOS
# (clang). This class also removes some warnings when used on Linux (gcc) and
# Windows (cl), so we use it everywhere.
class BuildExt(build_ext):
    """``build_ext`` that compiles C and C++ sources in two separate passes.

    Flags that select a C++ language standard (e.g. ``-std=c++17`` or
    ``/std:c++17``) are removed from ``Extension.extra_compile_args`` when
    compiling ``.c`` files and kept for every other source. All resulting
    objects (plus ``Extension.extra_objects``) are linked into one shared
    library at the path ``build_ext`` chooses for the extension.
    """

    # HACK: flags starting with these prefixes pick a C++ standard and must
    # not reach the compiler when it is building plain C files.
    _CXX_ONLY_FLAG_PREFIXES = ("-std=c++", "/std:c++")

    def build_extension(self, ext):
        """Compile ``ext``'s C and C++ sources and link them into one module.

        Args:
            ext: the setuptools ``Extension`` to build.
        """
        sources = list(ext.sources)
        # Match the ".c" suffix exactly: a bare "c" suffix would also catch
        # C++ files such as "foo.cc" and strip their C++ standard flag.
        # Sorting keeps the compile/link order deterministic.
        c_sources = sorted(s for s in sources if s.endswith(".c"))
        c_source_set = set(c_sources)
        cpp_sources = sorted(s for s in sources if s not in c_source_set)

        ext_path = self.get_ext_fullpath(ext.name)

        # Same convention as distutils: (name, value) defines a macro and
        # a 1-tuple (name,) undefines it.
        macros = list(ext.define_macros)
        macros.extend((undef,) for undef in ext.undef_macros)

        c_compile_args = [
            flag
            for flag in ext.extra_compile_args
            if not flag.startswith(self._CXX_ONLY_FLAG_PREFIXES)
        ]

        c_objects = self._compile_sources(ext, c_sources, macros, c_compile_args)
        cpp_objects = self._compile_sources(
            ext, cpp_sources, macros, ext.extra_compile_args
        )

        objects = c_objects + cpp_objects
        # Link prebuilt objects/archives too, as the stock build_ext does.
        if ext.extra_objects:
            objects.extend(ext.extra_objects)

        self.compiler.link_shared_object(
            objects,
            ext_path,
            libraries=self.get_libraries(ext),
            library_dirs=ext.library_dirs,
            runtime_library_dirs=ext.runtime_library_dirs,
            extra_postargs=ext.extra_link_args or [],
            export_symbols=self.get_export_symbols(ext),
            debug=self.debug,
            build_temp=self.build_temp,
            # Fall back to auto-detection when the Extension sets no language.
            target_lang=ext.language or self.compiler.detect_language(sources),
        )

    def _compile_sources(self, ext, sources, macros, extra_args):
        """Compile ``sources`` for ``ext`` with ``extra_args``; return object paths."""
        return self.compiler.compile(
            sources,
            output_dir=self.build_temp,
            macros=macros,
            include_dirs=ext.include_dirs,
            debug=self.debug,
            extra_postargs=extra_args,
            depends=ext.depends,
        )
# Register the pure-Python klujax module and the compiled klujax_cpp
# extension, routing build_ext through the mixed C/C++ BuildExt above.
_setup_kwargs = {
    "py_modules": ["klujax"],
    "ext_modules": [extension],
    "cmdclass": {"build_ext": BuildExt},
}
setup(**_setup_kwargs)