Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/linsolve/linsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,13 +571,12 @@ def _invert_solve(self, A, y, rcond):
# vectors if b.ndim was equal to a.ndim - 1.
At = A.transpose([2, 1, 0]).conj()
AtA = [np.dot(At[k], A[..., k]) for k in range(y.shape[-1])]
Aty = [np.dot(At[k], y[..., k])[:, None] for k in range(y.shape[-1])]
Aty = [np.dot(At[k], y[..., k])[..., None] for k in range(y.shape[-1])]

# This is slower by about 50%: scipy.linalg.solve(AtA, Aty, 'her')

# But this sometimes errors if singular:
print(len(AtA), len(Aty), AtA[0].shape, Aty[0].shape)
return np.linalg.solve(AtA, Aty).T[0]
return np.linalg.solve(AtA, Aty)[..., 0].T

def _invert_solve_sparse(self, xs_ys_vals, y, rcond):
"""Use linalg.solve to solve a fully constrained (non-degenerate) system of eqs.
Expand All @@ -588,7 +587,7 @@ def _invert_solve_sparse(self, xs_ys_vals, y, rcond):
AtA, Aty = self._get_AtA_Aty_sparse(xs_ys_vals, y)
# AtA and Aty don't end up being that sparse, usually, so don't use this:
# --> x = scipy.sparse.linalg.spsolve(AtA, Aty)
return np.linalg.solve(AtA, Aty).T
return np.linalg.solve(AtA, Aty[..., None])[..., 0].T

def _invert_default(self, A, y, rcond):
"""The default inverter, currently 'pinv'."""
Expand Down
6 changes: 3 additions & 3 deletions tests/test_linsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def test_degen_sol(self):


class TestLinearSolverSparse(TestLinearSolver):
def setup(self):
def setup_method(self):
self.sparse = True
eqs = ["x+y", "x-y"]
x, y = 1, 2
Expand Down Expand Up @@ -461,7 +461,7 @@ def test_dtype(self):


class TestLogProductSolverSparse(TestLogProductSolver):
def setup(self):
def setup_method(self):
self.sparse = True


Expand Down Expand Up @@ -762,5 +762,5 @@ def test_degen_sol(self):


class TestLinProductSolverSparse(TestLinProductSolver):
def setup(self):
def setup_method(self):
self.sparse = True
Loading