Skip to content

Commit d83a888

Browse files
committed
accelerate directlingam with cuda implementation
1 parent 340d60a commit d83a888

File tree

3 files changed

+75
-5
lines changed

3 files changed

+75
-5
lines changed

causallearn/search/FCMBased/lingam/direct_lingam.py

+31-5
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,20 @@
88
from sklearn.utils import check_array
99

1010
from .base import _BaseLiNGAM
11-
11+
try:
12+
from lingam_cuda import causal_order as causal_order_gpu
13+
except ImportError:
14+
pass
1215

1316
class DirectLiNGAM(_BaseLiNGAM):
1417
"""Implementation of DirectLiNGAM Algorithm [1]_ [2]_
1518
1619
References
1720
----------
18-
.. [1] S. Shimizu, T. Inazumi, Y. Sogawa, A. Hyvärinen, Y. Kawahara, T. Washio, P. O. Hoyer and K. Bollen.
21+
.. [1] S. Shimizu, T. Inazumi, Y. Sogawa, A. Hyvärinen, Y. Kawahara, T. Washio, P. O. Hoyer and K. Bollen.
1922
DirectLiNGAM: A direct method for learning a linear non-Gaussian structural equation model. Journal of Machine Learning Research, 12(Apr): 1225--1248, 2011.
20-
.. [2] A. Hyvärinen and S. M. Smith. Pairwise likelihood ratios for estimation of non-Gaussian structural eauation models.
21-
Journal of Machine Learning Research 14:111-152, 2013.
23+
.. [2] A. Hyvärinen and S. M. Smith. Pairwise likelihood ratios for estimation of non-Gaussian structural eauation models.
24+
Journal of Machine Learning Research 14:111-152, 2013.
2225
"""
2326

2427
def __init__(self, random_state=None, prior_knowledge=None, apply_prior_knowledge_softly=False, measure='pwling'):
@@ -38,7 +41,7 @@ def __init__(self, random_state=None, prior_knowledge=None, apply_prior_knowledg
3841
* ``-1`` : No prior background_knowledge is available to know if either of the two cases above (0 or 1) is true.
3942
apply_prior_knowledge_softly : boolean, optional (default=False)
4043
If True, apply prior background_knowledge softly.
41-
measure : {'pwling', 'kernel'}, optional (default='pwling')
44+
measure : {'pwling', 'kernel', 'pwling_fast'}, optional (default='pwling')
4245
Measure to evaluate independence: 'pwling' [2]_ or 'kernel' [1]_.
4346
"""
4447
super().__init__(random_state)
@@ -86,6 +89,8 @@ def fit(self, X):
8689
for _ in range(n_features):
8790
if self._measure == 'kernel':
8891
m = self._search_causal_order_kernel(X_, U)
92+
elif self._measure == "pwling_fast":
93+
m = self._search_causal_order_gpu(X_.astype(np.float64), U.astype(np.int32))
8994
else:
9095
m = self._search_causal_order(X_, U)
9196
for i in U:
@@ -257,3 +262,24 @@ def _search_causal_order_kernel(self, X, U):
257262
Tkernels.append(Tkernel)
258263

259264
return Uc[np.argmin(Tkernels)]
265+
266+
def _search_causal_order_gpu(self, X, U):
267+
"""Accelerated Causal ordering.
268+
Parameters
269+
----------
270+
X : array-like, shape (n_samples, n_features)
271+
Training data, where ``n_samples`` is the number of samples
272+
and ``n_features`` is the number of features.
273+
U: indices of cols in X
274+
Returns
275+
-------
276+
self : object
277+
Returns the instance itself.
278+
mlist: causal ordering
279+
"""
280+
cols = len(U)
281+
rows = len(X)
282+
283+
arr = X[:, np.array(U)]
284+
mlist = causal_order_gpu(arr, rows, cols)
285+
return U[np.argmax(mlist)]

setup.py

+3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
'pydot',
2525
'tqdm'
2626
],
27+
extras_require={
28+
'gpu': ['culingam'] # optional dependency for accelerated lingam. cuda required.
29+
},
2730
url='https://github.com/py-why/causal-learn',
2831
packages=setuptools.find_packages(),
2932
classifiers=[

tests/TestDirectLiNGAMfast.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import sys
2+
3+
sys.path.append("")
4+
import unittest
5+
from pickle import load
6+
7+
import numpy as np
8+
import pandas as pd
9+
import subprocess
10+
11+
from causallearn.search.FCMBased import lingam
12+
13+
def get_cuda_version():
14+
try:
15+
nvcc_version = subprocess.check_output(["nvcc", "--version"]).decode('utf-8')
16+
print("CUDA Version found:\n", nvcc_version)
17+
return True
18+
except Exception as e:
19+
print("CUDA not found or nvcc not in PATH:", e)
20+
return False
21+
22+
class TestDirectLiNGAMFast(unittest.TestCase):
23+
24+
def test_DirectLiNGAM(self):
25+
np.set_printoptions(precision=3, suppress=True)
26+
np.random.seed(100)
27+
x3 = np.random.uniform(size=1000)
28+
x0 = 3.0 * x3 + np.random.uniform(size=1000)
29+
x2 = 6.0 * x3 + np.random.uniform(size=1000)
30+
x1 = 3.0 * x0 + 2.0 * x2 + np.random.uniform(size=1000)
31+
x5 = 4.0 * x0 + np.random.uniform(size=1000)
32+
x4 = 8.0 * x0 - 1.0 * x2 + np.random.uniform(size=1000)
33+
X = pd.DataFrame(np.array([x0, x1, x2, x3, x4, x5]).T, columns=['x0', 'x1', 'x2', 'x3', 'x4', 'x5'])
34+
35+
cuda = get_cuda_version()
36+
if cuda:
37+
model = lingam.DirectLiNGAM()
38+
model.fit(X)
39+
40+
print(model.causal_order_)
41+
print(model.adjacency_matrix_)

0 commit comments

Comments
 (0)