From fd44e85f49da18480eff046a828e915ffd971fa2 Mon Sep 17 00:00:00 2001 From: Aaron Parsons Date: Wed, 8 Jan 2025 15:37:11 -0800 Subject: [PATCH 1/4] Propagated numpy 2 fix a bit further. Renamed setup to setup_method in unit tests to avoid deprecation warning --- src/linsolve/linsolve.py | 7 +++---- tests/test_linsolve.py | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/linsolve/linsolve.py b/src/linsolve/linsolve.py index 550c6c7..6ea89b9 100644 --- a/src/linsolve/linsolve.py +++ b/src/linsolve/linsolve.py @@ -571,13 +571,12 @@ def _invert_solve(self, A, y, rcond): # vectors if b.ndim was equal to a.ndim - 1. At = A.transpose([2, 1, 0]).conj() AtA = [np.dot(At[k], A[..., k]) for k in range(y.shape[-1])] - Aty = [np.dot(At[k], y[..., k])[:, None] for k in range(y.shape[-1])] + Aty = [np.dot(At[k], y[..., k])[..., None] for k in range(y.shape[-1])] # This is slower by about 50%: scipy.linalg.solve(AtA, Aty, 'her') # But this sometimes errors if singular: - print(len(AtA), len(Aty), AtA[0].shape, Aty[0].shape) - return np.linalg.solve(AtA, Aty).T[0] + return np.linalg.solve(AtA, Aty)[..., 0].T def _invert_solve_sparse(self, xs_ys_vals, y, rcond): """Use linalg.solve to solve a fully constrained (non-degenerate) system of eqs. @@ -588,7 +587,7 @@ def _invert_solve_sparse(self, xs_ys_vals, y, rcond): AtA, Aty = self._get_AtA_Aty_sparse(xs_ys_vals, y) # AtA and Aty don't end up being that sparse, usually, so don't use this: # --> x = scipy.sparse.linalg.spsolve(AtA, Aty) - return np.linalg.solve(AtA, Aty).T + return np.linalg.solve(AtA, Aty[..., None])[..., 0].T def _invert_default(self, A, y, rcond): """The default inverter, currently 'pinv'.""" diff --git a/tests/test_linsolve.py b/tests/test_linsolve.py index df14d97..a41bcab 100644 --- a/tests/test_linsolve.py +++ b/tests/test_linsolve.py @@ -355,7 +355,7 @@ def test_degen_sol(self): class TestLinearSolverSparse(TestLinearSolver): - def setup(self): + def setup_method(self): self.sparse = True eqs = ["x+y", "x-y"] x, y = 1, 2 @@ -461,7 +461,7 @@ def test_dtype(self): class TestLogProductSolverSparse(TestLogProductSolver): - def setup(self): + def setup_method(self): self.sparse = True @@ -762,5 +762,5 @@ def test_degen_sol(self): class TestLinProductSolverSparse(TestLinProductSolver): - def setup(self): + def setup_method(self): self.sparse = True From 816699ea6ad1a42305774f28d5a2faf9550be9b6 Mon Sep 17 00:00:00 2001 From: Aaron Parsons Date: Wed, 8 Jan 2025 16:40:22 -0800 Subject: [PATCH 2/4] Added constant-term functionality to linsolve. --- src/linsolve/linsolve.py | 15 ++++++++++++--- tests/test_linsolve.py | 35 ++++++++++++++++++++++++++++++++--- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/linsolve/linsolve.py b/src/linsolve/linsolve.py index 6ea89b9..5217c7a 100644 --- a/src/linsolve/linsolve.py +++ b/src/linsolve/linsolve.py @@ -177,6 +177,7 @@ def __init__(self, val, **kwargs): self.wgts = kwargs.pop("wgts", np.float32(1.0)) self.has_conj = False constants = kwargs.pop("constants", kwargs) + self.additive_offset = np.float32(0.0) self.process_terms(val, constants) def process_terms(self, terms, constants): @@ -211,11 +212,17 @@ def order_terms(self, terms): for L in terms: L.sort(key=lambda x: get_name(x) in self.prms) # Validate that each term has exactly 1 unsolved parameter. + final_terms = [] for t in terms: - assert get_name(t[-1]) in self.prms + # Check if this term has no free parameters (i.e. additive constant) + if get_name(t[-1]) not in self.prms: + self.additive_offset += self.eval_consts(t) + continue + # Make sure there is no more than 1 free parameter per term for ti in t[:-1]: assert type(ti) is not str or get_name(ti) in self.consts - return terms + final_terms.append(t) + return final_terms def eval_consts(self, const_list, wgts=np.float32(1.0)): """Multiply out constants (and wgts) for placing in matrix.""" @@ -251,6 +258,8 @@ def eval(self, sol): else: total *= sol[name] rv += total + # add back in purely constant terms, which were filtered out of self.terms + rv += self.additive_offset return rv @@ -430,7 +439,7 @@ def get_weighted_data(self): dtype = np.complex64 else: dtype = np.complex128 - d = np.array([self.data[k] for k in self.keys], dtype=dtype) + d = np.array([self.data[k] - eq.additive_offset for k, eq in zip(self.keys, self.eqs)], dtype=dtype) if len(self.wgts) > 0: w = np.array([self.wgts[k] for k in self.keys]) w.shape += (1,) * (d.ndim - w.ndim) diff --git a/tests/test_linsolve.py b/tests/test_linsolve.py index a41bcab..676f4bc 100644 --- a/tests/test_linsolve.py +++ b/tests/test_linsolve.py @@ -124,9 +124,10 @@ def test_term_check(self): terms4 = [["c", "x", "a"], [1, "b", "y"]] with pytest.raises(AssertionError): le.order_terms(terms4) - terms5 = [[1, "a", "b"], [1, "b", "y"]] - with pytest.raises(AssertionError): - le.order_terms(terms5) + terms5 = [["a", "b"], [1, "b", "y"]] + terms = le.order_terms(terms5) + assert len(terms) == 1 + assert le.additive_offset == 8 def test_eval(self): le = linsolve.LinearEquation("a*x-b*y", a=2, b=4) @@ -138,6 +139,9 @@ def test_eval(self): sol = {"x": 3 + 3j * np.ones(10), "y": 7 + 2j * np.ones(10)} ans = np.conj(sol["x"]) - sol["y"] np.testing.assert_equal(ans, le.eval(sol)) + le = linsolve.LinearEquation("a*b+a*x-b*y", a=2, b=4) + sol = {'x': 3, 'y': 7} + assert 2 * 4 + 2 * 3 - 4 * 7 == le.eval(sol) class TestLinearSolver: @@ -276,6 +280,23 @@ def test_eval(self): result = ls.eval(sol, "a*x+b*y") np.testing.assert_almost_equal(3 * 1 + 1 * 2, list(result.values())[0]) + def test_eval_const_term(self): + x, y = 1.0, 2.0 + a, b = 3.0 * np.ones(4), 1.0 + eqs = ["a*b+a*x+y", "a+x+b*y"] + d, w = {}, {} + for eq in eqs: + d[eq], w[eq] = eval(eq) * np.ones(4), np.ones(4) + ls = linsolve.LinearSolver(d, w, a=a, b=b, sparse=self.sparse) + sol = ls.solve() + np.testing.assert_almost_equal(sol["x"], x * np.ones(4, dtype=np.float64)) + np.testing.assert_almost_equal(sol["y"], y * np.ones(4, dtype=np.float64)) + result = ls.eval(sol) + for eq in d: + np.testing.assert_almost_equal(d[eq], result[eq]) + result = ls.eval(sol, "a*b+a*x+b*y") + np.testing.assert_almost_equal(3 * 1 + 3 * 1 + 1 * 2, list(result.values())[0]) + def test_chisq(self): x = 1.0 d = {"x": 1, "a*x": 2} @@ -297,6 +318,14 @@ def test_chisq(self): chisq = ls.chisq(sol) np.testing.assert_almost_equal(sol["x"], 5.0 / 3.0, 6) np.testing.assert_almost_equal(ls.chisq(sol), 1.0 / 3.0) + x = 1.0 + d = {"1*x+1": 3.0, "x": 1.0} + w = {"1*x+1": 1.0, "x": 0.5} + ls = linsolve.LinearSolver(d, wgts=w, sparse=self.sparse) + sol = ls.solve() + chisq = ls.chisq(sol) + np.testing.assert_almost_equal(sol["x"], 5.0 / 3.0, 6) + np.testing.assert_almost_equal(ls.chisq(sol), 1.0 / 3.0) def test_dtypes(self): ls = linsolve.LinearSolver({"x_": 1.0 + 1.0j}, sparse=self.sparse) From a07cf19c01e4f364a173666a0d2c8b225ac28992 Mon Sep 17 00:00:00 2001 From: Aaron Parsons Date: Wed, 8 Jan 2025 18:04:18 -0800 Subject: [PATCH 3/4] Updated docstrings --- src/linsolve/linsolve.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/linsolve/linsolve.py b/src/linsolve/linsolve.py index 5217c7a..2364d74 100644 --- a/src/linsolve/linsolve.py +++ b/src/linsolve/linsolve.py @@ -12,14 +12,14 @@ describing the equation (which is parsed according to python syntax) and each value is the corresponding "measured" value of that equation. Variable names in equations are checked against keyword arguments to the solver to determine -if they are provided constants or parameters to be solved for. Parameter anmes +if they are provided constants or parameters to be solved for. Parameter names and solutions are return are returned as key:value pairs in ls.solve(). Parallel instances of equations can be evaluated by providing measured values as numpy arrays. Constants can also be arrays that comply with standard numpy broadcasting rules. Finally, weighting is implemented through an optional wgts dictionary that parallels the construction of data. -LinearSolver solves linear equations of the form 'a*x + b*y + c*z'. +LinearSolver solves linear equations of the form 'a*x + b*y + c*z + d'. LogProductSolver uses logrithms to linearize equations of the form 'x*y*z'. LinProductSolver uses symbolic Taylor expansion to linearize equations of the form 'x*y + y*z'. @@ -308,7 +308,7 @@ def infer_dtype(values): class LinearSolver: def __init__(self, data, wgts={}, sparse=False, **kwargs): - """Set up a linear system of equations of the form 1*a + 2*b + 3*c = 4. + """Set up a linear system of equations of the form 1*a + 2*b + 3*c + 4 = 5. Parameters ---------- From a320f2204a71c726d2df594973d2da0a46f3db26 Mon Sep 17 00:00:00 2001 From: Aaron Parsons Date: Wed, 8 Jan 2025 18:05:30 -0800 Subject: [PATCH 4/4] Update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6f24f6f..70c380b 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ # Overview The solvers in `linsolve` include `LinearSolver`, `LogProductSolver`, and `LinProductSolver`. -`LinearSolver` solves linear equations of the form `'a*x + b*y + c*z'`. +`LinearSolver` solves linear equations of the form `'a*x + b*y + c*z + d'`. `LogProductSolver` uses logrithms to linearize equations of the form `'x*y*z'`. `LinProductSolver` uses symbolic Taylor expansion to linearize equations of the form `'x*y + y*z'`.