diff --git a/src/linsolve/linsolve.py b/src/linsolve/linsolve.py index 550c6c7..6ea89b9 100644 --- a/src/linsolve/linsolve.py +++ b/src/linsolve/linsolve.py @@ -571,13 +571,12 @@ def _invert_solve(self, A, y, rcond): # vectors if b.ndim was equal to a.ndim - 1. At = A.transpose([2, 1, 0]).conj() AtA = [np.dot(At[k], A[..., k]) for k in range(y.shape[-1])] - Aty = [np.dot(At[k], y[..., k])[:, None] for k in range(y.shape[-1])] + Aty = [np.dot(At[k], y[..., k])[..., None] for k in range(y.shape[-1])] # This is slower by about 50%: scipy.linalg.solve(AtA, Aty, 'her') # But this sometimes errors if singular: - print(len(AtA), len(Aty), AtA[0].shape, Aty[0].shape) - return np.linalg.solve(AtA, Aty).T[0] + return np.linalg.solve(AtA, Aty)[..., 0].T def _invert_solve_sparse(self, xs_ys_vals, y, rcond): """Use linalg.solve to solve a fully constrained (non-degenerate) system of eqs. @@ -588,7 +587,7 @@ def _invert_solve_sparse(self, xs_ys_vals, y, rcond): AtA, Aty = self._get_AtA_Aty_sparse(xs_ys_vals, y) # AtA and Aty don't end up being that sparse, usually, so don't use this: # --> x = scipy.sparse.linalg.spsolve(AtA, Aty) - return np.linalg.solve(AtA, Aty).T + return np.linalg.solve(AtA, Aty[..., None])[..., 0].T def _invert_default(self, A, y, rcond): """The default inverter, currently 'pinv'.""" diff --git a/tests/test_linsolve.py b/tests/test_linsolve.py index df14d97..a41bcab 100644 --- a/tests/test_linsolve.py +++ b/tests/test_linsolve.py @@ -355,7 +355,7 @@ def test_degen_sol(self): class TestLinearSolverSparse(TestLinearSolver): - def setup(self): + def setup_method(self): self.sparse = True eqs = ["x+y", "x-y"] x, y = 1, 2 @@ -461,7 +461,7 @@ def test_dtype(self): class TestLogProductSolverSparse(TestLogProductSolver): - def setup(self): + def setup_method(self): self.sparse = True @@ -762,5 +762,5 @@ def test_degen_sol(self): class TestLinProductSolverSparse(TestLinProductSolver): - def setup(self): + def setup_method(self): self.sparse = True