Skip to content

Commit 90c435a

Browse files
committed
Add dict option for max_order to allow finer control of which cross terms to include
1 parent aa46f3a commit 90c435a

File tree

2 files changed

+77
-7
lines changed

2 files changed

+77
-7
lines changed

piff/basis_interp.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -463,8 +463,13 @@ class BasisPolynomial(BasisInterp):
463463
[default: ('u','v')]
464464
:param max_order: The maximum total order to use for cross terms between keys.
465465
[default: None, which uses the maximum value of any individual key's order]
466+
If this is an integer, it applies to all pairs, but you may also specify
467+
a dict mapping pairs of keys to an integer. E.g. {('u','v'):3,
468+
('u','z'):0, ('v','z'):0}. This sets the maximum order for cross terms
469+
between these pairs. Furthermore, any pairs for which you want to skip
470+
cross terms (max=0) may be omitted from the dict.
466471
:param solver: Which solver to use. Solvers available are "scipy", "qr", "jax",
467-
"cpp". See above for details.
472+
"cpp". See above for details. [default: 'scipy']
468473
:param logger: A logger object for logging debug info. [default: None]
469474
"""
470475
_type_name = 'BasisPolynomial'
@@ -515,12 +520,13 @@ def __init__(
515520
logger.warning("JAX not installed. Reverting to numpy/scipy.")
516521
self.solver = "scipy"
517522

518-
if self._max_order<0 or np.any(np.array(self._orders) < 0):
523+
if np.any(np.array(self._orders) < 0):
519524
# Exception if we have any requests for negative orders
520525
raise ValueError('Negative polynomial order specified')
521526

522527
self.kwargs = {
523528
'order' : order,
529+
'max_order' : max_order,
524530
'keys' : keys,
525531
'solver': solver,
526532
}
@@ -529,9 +535,38 @@ def __init__(
529535
# Start with 1d arrays giving orders in all dimensions
530536
ord_ranges = [np.arange(order+1,dtype=int) for order in self._orders]
531537
# Nifty trick to produce n-dim array holding total order
532-
#sumorder = np.sum(np.ix_(*ord_ranges)) # This version doesn't work in numpy 1.19
533538
sumorder = np.sum(np.meshgrid(*ord_ranges, indexing='ij'), axis=0)
534-
self._mask = sumorder <= self._max_order
539+
540+
if isinstance(self._max_order, dict):
541+
# This code is not particularly efficient. Hopefully it doesn't matter.
542+
# Basically set a maxorder for each element in sumorder based on whether it is
543+
# a) a power of a single key. Use the order for that key.
544+
# b) a cross-product of multiple keys. Use it only if it is in the max_order dict.
545+
max_orders = np.zeros_like(sumorder)
546+
547+
def get_indices(arr, pre=()):
548+
# Get the index tuples of the given multi-dimensional array.
549+
if not isinstance(arr, np.ndarray):
550+
yield pre
551+
else:
552+
for i in range(len(arr)):
553+
yield from get_indices(arr[i], pre + (i,))
554+
for index in get_indices(sumorder):
555+
for k, order in enumerate(self._orders):
556+
if index[k] > 0 and all(index[j] == 0 for j in range(len(index)) if j != k):
557+
max_orders[index] = order
558+
for keys, order in self._max_order.items():
559+
kk = [keys.index(key) for key in keys]
560+
ok = True
561+
for k in range(len(self._orders)):
562+
if index[k] > 0 and k not in kk: ok = False
563+
if index[k] == 0 and k in kk: ok = False
564+
if ok:
565+
max_orders[index] = order
566+
else:
567+
max_orders = self._max_order
568+
569+
self._mask = sumorder <= max_orders
535570

536571
def getProperties(self, star):
537572
return np.array([star.data[k] for k in self._keys], dtype=float)
@@ -558,7 +593,6 @@ def basis(self, star):
558593
p[1:] = vals[i]
559594
pows1d.append(np.cumprod(p))
560595
# Use trick to produce outer product of all these powers
561-
#pows2d = np.prod(np.ix_(*pows1d))
562596
pows2d = np.prod(np.meshgrid(*pows1d, indexing='ij'), axis=0)
563597
# Return linear array of terms making total power constraint
564598
return pows2d[self._mask]
@@ -595,4 +629,3 @@ def _finish_read(self, reader):
595629
data = reader.read_table('solution')
596630
assert data is not None
597631
self.q = data['q'][0]
598-

tests/test_pixel.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,6 @@ def test_basis_interp():
389389
np.testing.assert_raises(ValueError, piff.BasisPolynomial, order=[-1,0])
390390
np.testing.assert_raises(ValueError, piff.BasisPolynomial, order=[-4,-1])
391391
np.testing.assert_raises(ValueError, piff.BasisPolynomial, order=-2)
392-
np.testing.assert_raises(ValueError, piff.BasisPolynomial, order=[3,3], max_order=-1)
393392

394393

395394
@timer
@@ -1728,6 +1727,15 @@ def test_color():
17281727
piff.piffify(config)
17291728
psf = piff.read(psf_file)
17301729

1730+
# Show that the basis includes cross terms:
1731+
np.testing.assert_equal(psf.interp._orders, [2,2,1])
1732+
assert psf.interp._max_order == 2
1733+
s = psf.stars[0]
1734+
np.testing.assert_allclose(
1735+
psf.interp.basis(s),
1736+
[ 1, s['color'], s['v'], s['v']*s['color'], s['v']**2,
1737+
s['u'], s['u']*s['color'], s['u']*s['v'], s['u']**2 ])
1738+
17311739
for s in psf.stars:
17321740
orig_stamp = s.image
17331741
weight = s.weight
@@ -1742,6 +1750,35 @@ def test_color():
17421750
# Anyway, I think the fit is working, just this test doesn't
17431751
# seem quite the right thing.
17441752

1753+
# Repeat without the cross-terms between color and u/v.
1754+
config['psf']['interp']['max_order'] = { ('u','v') : 2 }
1755+
piff.piffify(config)
1756+
psf = piff.read(psf_file)
1757+
1758+
# Show that the basis now doesn't include cross terms:
1759+
np.testing.assert_equal(psf.interp._orders, [2,2,1])
1760+
assert psf.interp._max_order == { ('u','v') : 2 }
1761+
s = psf.stars[0]
1762+
print(s['u'], s['v'], s['color'])
1763+
print('basis = ',psf.interp.basis(s))
1764+
np.testing.assert_allclose(
1765+
psf.interp.basis(s),
1766+
[ 1, s['color'], s['v'], s['v']**2, s['u'], s['u']*s['v'], s['u']**2 ])
1767+
1768+
# Still works just as well without the cross terms.
1769+
for s in psf.stars:
1770+
orig_stamp = s.image
1771+
weight = s.weight
1772+
offset = s.center_to_offset(s.fit.center)
1773+
image = psf.draw(x=s['x'], y=s['y'], color=s['color'],
1774+
stamp_size=32, flux=s.fit.flux, offset=offset)
1775+
resid = image - orig_stamp
1776+
chisq = np.sum(resid.array**2 * weight.array)
1777+
dof = np.sum(weight.array != 0)
1778+
print('color = ',s['color'],'chisq = ',chisq,'dof = ',dof)
1779+
assert chisq < dof * 1.5
1780+
1781+
17451782
@timer
17461783
def test_convert_func():
17471784
"""Test PixelGrid fitting with a non-trivial convert_func

0 commit comments

Comments
 (0)