Skip to content

Commit fd44e85

Browse files
committed
Propagated numpy 2 fix a bit further. Renamed setup to setup_method in unit tests to avoid deprecation warning
1 parent b603702 commit fd44e85

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

src/linsolve/linsolve.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -571,13 +571,12 @@ def _invert_solve(self, A, y, rcond):
571571
# vectors if b.ndim was equal to a.ndim - 1.
572572
At = A.transpose([2, 1, 0]).conj()
573573
AtA = [np.dot(At[k], A[..., k]) for k in range(y.shape[-1])]
574-
Aty = [np.dot(At[k], y[..., k])[:, None] for k in range(y.shape[-1])]
574+
Aty = [np.dot(At[k], y[..., k])[..., None] for k in range(y.shape[-1])]
575575

576576
# This is slower by about 50%: scipy.linalg.solve(AtA, Aty, 'her')
577577

578578
# But this sometimes errors if singular:
579-
print(len(AtA), len(Aty), AtA[0].shape, Aty[0].shape)
580-
return np.linalg.solve(AtA, Aty).T[0]
579+
return np.linalg.solve(AtA, Aty)[..., 0].T
581580

582581
def _invert_solve_sparse(self, xs_ys_vals, y, rcond):
583582
"""Use linalg.solve to solve a fully constrained (non-degenerate) system of eqs.
@@ -588,7 +587,7 @@ def _invert_solve_sparse(self, xs_ys_vals, y, rcond):
588587
AtA, Aty = self._get_AtA_Aty_sparse(xs_ys_vals, y)
589588
# AtA and Aty don't end up being that sparse, usually, so don't use this:
590589
# --> x = scipy.sparse.linalg.spsolve(AtA, Aty)
591-
return np.linalg.solve(AtA, Aty).T
590+
return np.linalg.solve(AtA, Aty[..., None])[..., 0].T
592591

593592
def _invert_default(self, A, y, rcond):
594593
"""The default inverter, currently 'pinv'."""

tests/test_linsolve.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def test_degen_sol(self):
355355

356356

357357
class TestLinearSolverSparse(TestLinearSolver):
358-
def setup(self):
358+
def setup_method(self):
359359
self.sparse = True
360360
eqs = ["x+y", "x-y"]
361361
x, y = 1, 2
@@ -461,7 +461,7 @@ def test_dtype(self):
461461

462462

463463
class TestLogProductSolverSparse(TestLogProductSolver):
464-
def setup(self):
464+
def setup_method(self):
465465
self.sparse = True
466466

467467

@@ -762,5 +762,5 @@ def test_degen_sol(self):
762762

763763

764764
class TestLinProductSolverSparse(TestLinProductSolver):
765-
def setup(self):
765+
def setup_method(self):
766766
self.sparse = True

0 commit comments

Comments
 (0)