Skip to content

Commit 3a9c489

Browse files
authored
Merge pull request #60 from HERA-Team/const_term
Added ability for linsolve equations to have purely constant terms, which are just subracted off of the equated value when solving.
2 parents 9f72a28 + a320f22 commit 3a9c489

File tree

3 files changed

+48
-10
lines changed

3 files changed

+48
-10
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# Overview
1010

1111
The solvers in `linsolve` include `LinearSolver`, `LogProductSolver`, and `LinProductSolver`.
12-
`LinearSolver` solves linear equations of the form `'a*x + b*y + c*z'`.
12+
`LinearSolver` solves linear equations of the form `'a*x + b*y + c*z + d'`.
1313
`LogProductSolver` uses logrithms to linearize equations of the form `'x*y*z'`.
1414
`LinProductSolver` uses symbolic Taylor expansion to linearize equations of the
1515
form `'x*y + y*z'`.

src/linsolve/linsolve.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
describing the equation (which is parsed according to python syntax) and each
1313
value is the corresponding "measured" value of that equation. Variable names
1414
in equations are checked against keyword arguments to the solver to determine
15-
if they are provided constants or parameters to be solved for. Parameter anmes
15+
if they are provided constants or parameters to be solved for. Parameter names
1616
and solutions are return are returned as key:value pairs in ls.solve().
1717
Parallel instances of equations can be evaluated by providing measured values
1818
as numpy arrays. Constants can also be arrays that comply with standard numpy
1919
broadcasting rules. Finally, weighting is implemented through an optional wgts
2020
dictionary that parallels the construction of data.
2121
22-
LinearSolver solves linear equations of the form 'a*x + b*y + c*z'.
22+
LinearSolver solves linear equations of the form 'a*x + b*y + c*z + d'.
2323
LogProductSolver uses logrithms to linearize equations of the form 'x*y*z'.
2424
LinProductSolver uses symbolic Taylor expansion to linearize equations of the
2525
form 'x*y + y*z'.
@@ -177,6 +177,7 @@ def __init__(self, val, **kwargs):
177177
self.wgts = kwargs.pop("wgts", np.float32(1.0))
178178
self.has_conj = False
179179
constants = kwargs.pop("constants", kwargs)
180+
self.additive_offset = np.float32(0.0)
180181
self.process_terms(val, constants)
181182

182183
def process_terms(self, terms, constants):
@@ -211,11 +212,17 @@ def order_terms(self, terms):
211212
for L in terms:
212213
L.sort(key=lambda x: get_name(x) in self.prms)
213214
# Validate that each term has exactly 1 unsolved parameter.
215+
final_terms = []
214216
for t in terms:
215-
assert get_name(t[-1]) in self.prms
217+
# Check if this term has no free parameters (i.e. additive constant)
218+
if get_name(t[-1]) not in self.prms:
219+
self.additive_offset += self.eval_consts(t)
220+
continue
221+
# Make sure there is no more than 1 free parameter per term
216222
for ti in t[:-1]:
217223
assert type(ti) is not str or get_name(ti) in self.consts
218-
return terms
224+
final_terms.append(t)
225+
return final_terms
219226

220227
def eval_consts(self, const_list, wgts=np.float32(1.0)):
221228
"""Multiply out constants (and wgts) for placing in matrix."""
@@ -251,6 +258,8 @@ def eval(self, sol):
251258
else:
252259
total *= sol[name]
253260
rv += total
261+
# add back in purely constant terms, which were filtered out of self.terms
262+
rv += self.additive_offset
254263
return rv
255264

256265

@@ -299,7 +308,7 @@ def infer_dtype(values):
299308

300309
class LinearSolver:
301310
def __init__(self, data, wgts={}, sparse=False, **kwargs):
302-
"""Set up a linear system of equations of the form 1*a + 2*b + 3*c = 4.
311+
"""Set up a linear system of equations of the form 1*a + 2*b + 3*c + 4 = 5.
303312
304313
Parameters
305314
----------
@@ -430,7 +439,7 @@ def get_weighted_data(self):
430439
dtype = np.complex64
431440
else:
432441
dtype = np.complex128
433-
d = np.array([self.data[k] for k in self.keys], dtype=dtype)
442+
d = np.array([self.data[k] - eq.additive_offset for k, eq in zip(self.keys, self.eqs)], dtype=dtype)
434443
if len(self.wgts) > 0:
435444
w = np.array([self.wgts[k] for k in self.keys])
436445
w.shape += (1,) * (d.ndim - w.ndim)

tests/test_linsolve.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,10 @@ def test_term_check(self):
124124
terms4 = [["c", "x", "a"], [1, "b", "y"]]
125125
with pytest.raises(AssertionError):
126126
le.order_terms(terms4)
127-
terms5 = [[1, "a", "b"], [1, "b", "y"]]
128-
with pytest.raises(AssertionError):
129-
le.order_terms(terms5)
127+
terms5 = [["a", "b"], [1, "b", "y"]]
128+
terms = le.order_terms(terms5)
129+
assert len(terms) == 1
130+
assert le.additive_offset == 8
130131

131132
def test_eval(self):
132133
le = linsolve.LinearEquation("a*x-b*y", a=2, b=4)
@@ -138,6 +139,9 @@ def test_eval(self):
138139
sol = {"x": 3 + 3j * np.ones(10), "y": 7 + 2j * np.ones(10)}
139140
ans = np.conj(sol["x"]) - sol["y"]
140141
np.testing.assert_equal(ans, le.eval(sol))
142+
le = linsolve.LinearEquation("a*b+a*x-b*y", a=2, b=4)
143+
sol = {'x': 3, 'y': 7}
144+
assert 2 * 4 + 2 * 3 - 4 * 7 == le.eval(sol)
141145

142146

143147
class TestLinearSolver:
@@ -276,6 +280,23 @@ def test_eval(self):
276280
result = ls.eval(sol, "a*x+b*y")
277281
np.testing.assert_almost_equal(3 * 1 + 1 * 2, list(result.values())[0])
278282

283+
def test_eval_const_term(self):
284+
x, y = 1.0, 2.0
285+
a, b = 3.0 * np.ones(4), 1.0
286+
eqs = ["a*b+a*x+y", "a+x+b*y"]
287+
d, w = {}, {}
288+
for eq in eqs:
289+
d[eq], w[eq] = eval(eq) * np.ones(4), np.ones(4)
290+
ls = linsolve.LinearSolver(d, w, a=a, b=b, sparse=self.sparse)
291+
sol = ls.solve()
292+
np.testing.assert_almost_equal(sol["x"], x * np.ones(4, dtype=np.float64))
293+
np.testing.assert_almost_equal(sol["y"], y * np.ones(4, dtype=np.float64))
294+
result = ls.eval(sol)
295+
for eq in d:
296+
np.testing.assert_almost_equal(d[eq], result[eq])
297+
result = ls.eval(sol, "a*b+a*x+b*y")
298+
np.testing.assert_almost_equal(3 * 1 + 3 * 1 + 1 * 2, list(result.values())[0])
299+
279300
def test_chisq(self):
280301
x = 1.0
281302
d = {"x": 1, "a*x": 2}
@@ -297,6 +318,14 @@ def test_chisq(self):
297318
chisq = ls.chisq(sol)
298319
np.testing.assert_almost_equal(sol["x"], 5.0 / 3.0, 6)
299320
np.testing.assert_almost_equal(ls.chisq(sol), 1.0 / 3.0)
321+
x = 1.0
322+
d = {"1*x+1": 3.0, "x": 1.0}
323+
w = {"1*x+1": 1.0, "x": 0.5}
324+
ls = linsolve.LinearSolver(d, wgts=w, sparse=self.sparse)
325+
sol = ls.solve()
326+
chisq = ls.chisq(sol)
327+
np.testing.assert_almost_equal(sol["x"], 5.0 / 3.0, 6)
328+
np.testing.assert_almost_equal(ls.chisq(sol), 1.0 / 3.0)
300329

301330
def test_dtypes(self):
302331
ls = linsolve.LinearSolver({"x_": 1.0 + 1.0j}, sparse=self.sparse)

0 commit comments

Comments
 (0)