Skip to content

Commit 3105ad6

Browse files
authored
[MAINT] Fix tests (#53)
1 parent 48f1534 commit 3105ad6

File tree

2 files changed

+33
-31
lines changed

2 files changed

+33
-31
lines changed

meegkit/utils/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def mldivide(A, B):
3939
try:
4040
# Note: we must use overwrite_a=False in order to be able to
4141
# use the fall-back solution below in case a LinAlgError is raised
42-
return linalg.solve(A, B, sym_pos=True, overwrite_a=False)
42+
return linalg.solve(A, B, assume_a='pos', overwrite_a=False)
4343
except linalg.LinAlgError:
4444
# Singular matrix in solving dual problem. Using least-squares
4545
# solution instead.

tests/test_cca.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
def test_cca():
1111
"""Test CCA."""
1212
# Compare results with Matlab
13-
# x = np.random.randn(1000, 11)
14-
# y = np.random.randn(1000, 9)
13+
# x = rng.randn(1000, 11)
14+
# y = rng.randn(1000, 9)
1515
# x = demean(x).squeeze()
1616
# y = demean(y).squeeze()
1717
mat = loadmat('./tests/data/ccadata.mat')
@@ -47,16 +47,16 @@ def test_cca():
4747
def test_cca2():
4848
"""Simulate correlations."""
4949
# import matplotlib.pyplot as plt
50-
51-
x = np.random.randn(10000, 20)
52-
y = np.random.randn(10000, 8)
50+
rng = np.random.RandomState(2022)
51+
x = rng.randn(10000, 20)
52+
y = rng.randn(10000, 8)
5353
y[:, :2] = x[:, :2]
5454
# perfectly correlated
55-
y[:, 2:4] = x[:, 2:4] + np.random.randn(10000, 2)
55+
y[:, 2:4] = x[:, 2:4] + rng.randn(10000, 2)
5656
# 1/2 correlated
57-
y[:, 4:6] = x[:, 4:6] + np.random.randn(10000, 2) * 3
57+
y[:, 4:6] = x[:, 4:6] + rng.randn(10000, 2) * 3
5858
# 1/4 correlated
59-
y[:, 6:8] = np.random.randn(10000, 2)
59+
y[:, 6:8] = rng.randn(10000, 2)
6060
# uncorrelated
6161
[A, B, R] = nt_cca(x, y)
6262

@@ -100,10 +100,11 @@ def test_canoncorr():
100100

101101
def test_correlated():
102102
"""Test x & y perfectly correlated."""
103-
x = np.random.randn(1000, 10)
104-
y = np.random.randn(1000, 10)
103+
rng = np.random.RandomState(2022)
104+
x = rng.randn(1000, 10)
105+
y = rng.randn(1000, 10)
105106

106-
y = x[:, np.random.permutation(10)] # +0.000001*y;
107+
y = x[:, rng.permutation(10)] # +0.000001*y;
107108

108109
[A1, B1, R1] = nt_cca(x, y)
109110

@@ -140,8 +141,9 @@ def test_cca_lags():
140141

141142
def test_cca_crossvalidate():
142143
"""Test CCA with crossvalidation."""
143-
# x = np.random.randn(1000, 11)
144-
# y = np.random.randn(1000, 9)
144+
rng = np.random.RandomState(2023)
145+
# x = rng.randn(1000, 11)
146+
# y = rng.randn(1000, 9)
145147
# xx = [x, x, x]
146148
# yy = [x[:, :9], y, y]
147149

@@ -157,8 +159,8 @@ def test_cca_crossvalidate():
157159

158160
# Create data where 1st comps should be uncorrelated, and 2nd and 3rd comps
159161
# are very correlated
160-
x = np.random.randn(1000, 10)
161-
y = np.random.randn(1000, 10)
162+
x = rng.randn(1000, 10)
163+
y = rng.randn(1000, 10)
162164
xx = [x, x, x]
163165
yy = [y, x, x]
164166
A, B, R = cca_crossvalidate(xx, yy)
@@ -168,17 +170,18 @@ def test_cca_crossvalidate():
168170

169171
def test_cca_crossvalidate_shifts():
170172
"""Test CCA crossvalidation with shifts."""
173+
rng = np.random.RandomState(2021)
171174
n_times, n_trials = 10000, 2
172-
x = np.random.randn(n_times, 20, n_trials)
173-
y = np.random.randn(n_times, 8, n_trials)
175+
x = rng.randn(n_times, 20, n_trials)
176+
y = rng.randn(n_times, 8, n_trials)
174177
# perfectly correlated
175178
y[:, :2, :] = x[:, :2, :]
176179
# 1/2 correlated
177-
y[:, 2:4, :] = x[:, 2:4, :] + np.random.randn(n_times, 2, n_trials)
180+
y[:, 2:4, :] = x[:, 2:4, :] + rng.randn(n_times, 2, n_trials)
178181
# 1/4 correlated
179-
y[:, 4:6, :] = x[:, 4:6, :] + np.random.randn(n_times, 2, n_trials) * 3
182+
y[:, 4:6, :] = x[:, 4:6, :] + rng.randn(n_times, 2, n_trials) * 3
180183
# uncorrelated
181-
y[:, 6:8, :] = np.random.randn(n_times, 2, n_trials)
184+
y[:, 6:8, :] = rng.randn(n_times, 2, n_trials)
182185

183186
xx = multishift(x, -np.arange(1, 4), reshape=True, solution='valid')
184187
yy = multishift(y, -np.arange(1, 4), reshape=True, solution='valid')
@@ -219,15 +222,14 @@ def test_cca_crossvalidate_shifts2():
219222

220223
def test_mcca(show=False):
221224
"""Test multiway CCA."""
222-
np.random.seed(9)
223-
225+
rng = np.random.RandomState(2021)
224226
# We create 3 uncorrelated data sets. There should be no common structure
225227
# between them.
226228

227229
# Build data
228-
x1 = np.random.randn(10000, 10)
229-
x2 = np.random.randn(10000, 10)
230-
x3 = np.random.randn(10000, 10)
230+
x1 = rng.randn(10000, 10)
231+
x2 = rng.randn(10000, 10)
232+
x3 = rng.randn(10000, 10)
231233
x = np.hstack((x1, x2, x3))
232234
C = np.dot(x.T, x)
233235

@@ -263,10 +265,10 @@ def test_mcca(show=False):
263265
# Now Create 3 data sets with some shared parts.
264266

265267
# Build data
266-
x1 = np.random.randn(10000, 5)
267-
x2 = np.random.randn(10000, 5)
268-
x3 = np.random.randn(10000, 5)
269-
x4 = np.random.randn(10000, 5)
268+
x1 = rng.randn(10000, 5)
269+
x2 = rng.randn(10000, 5)
270+
x3 = rng.randn(10000, 5)
271+
x4 = rng.randn(10000, 5)
270272
x = np.hstack((x2, x1, x3, x1, x4, x1))
271273
C = np.dot(x.T, x)
272274

@@ -299,7 +301,7 @@ def test_mcca(show=False):
299301
# cross-correlation plot).
300302

301303
# Build data
302-
x1 = np.random.randn(10000, 10)
304+
x1 = rng.randn(10000, 10)
303305
x = np.hstack((x1, x1, x1))
304306
C = np.dot(x.T, x)
305307

0 commit comments

Comments
 (0)