diff --git a/piff/basis_interp.py b/piff/basis_interp.py index 29cfdc98..f81c36b8 100644 --- a/piff/basis_interp.py +++ b/piff/basis_interp.py @@ -467,8 +467,13 @@ class BasisPolynomial(BasisInterp): [default: ('u','v')] :param max_order: The maximum total order to use for cross terms between keys. [default: None, which uses the maximum value of any individual key's order] + If this is an integer, it applies to all pairs, but you may also specify + a dict mapping pairs of keys to an integer. E.g. {('u','v'):3, + ('u','z'):0, ('v','z'):0}. This sets the maximum order for cross terms + between these pairs. Furthermore, any pairs for which you want to skip + cross terms (max=0) may be omitted from the dict. :param solver: Which solver to use. Solvers available are "scipy", "qr", "jax", - "cpp". See above for details. + "cpp". See above for details. [default: 'scipy'] :param logger: A logger object for logging debug info. [default: None] """ _type_name = 'BasisPolynomial' @@ -519,12 +524,13 @@ def __init__( logger.info("JAX not installed. Reverting to numpy/scipy.") self.solver = "scipy" - if self._max_order<0 or np.any(np.array(self._orders) < 0): + if np.any(np.array(self._orders) < 0): # Exception if we have any requests for negative orders raise ValueError('Negative polynomial order specified') self.kwargs = { 'order' : order, + 'max_order' : max_order, 'keys' : keys, 'solver': solver, } @@ -533,9 +539,38 @@ def __init__( # Start with 1d arrays giving orders in all dimensions ord_ranges = [np.arange(order+1,dtype=int) for order in self._orders] # Nifty trick to produce n-dim array holding total order - #sumorder = np.sum(np.ix_(*ord_ranges)) # This version doesn't work in numpy 1.19 sumorder = np.sum(np.meshgrid(*ord_ranges, indexing='ij'), axis=0) - self._mask = sumorder <= self._max_order + + if isinstance(self._max_order, dict): + # This code is not particularly efficient. Hopefully it doesn't matter. + # Basically set a maxorder for each element in sumorder based on whether it is + # a) a power of a single key. Use the order for that key. + # b) a cross-product of multiple keys. Use it only if it is in the max_order dict. + max_orders = np.zeros_like(sumorder) + + def get_indices(arr, pre=()): + # Get the index tuples of the given multi-dimensional array. + if not isinstance(arr, np.ndarray): + yield pre + else: + for i in range(len(arr)): + yield from get_indices(arr[i], pre + (i,)) + for index in get_indices(sumorder): + for k, order in enumerate(self._orders): + if index[k] > 0 and all(index[j] == 0 for j in range(len(index)) if j != k): + max_orders[index] = order + for keys, order in self._max_order.items(): + kk = [keys.index(key) for key in keys] + ok = True + for k in range(len(self._orders)): + if index[k] > 0 and k not in kk: ok = False + if index[k] == 0 and k in kk: ok = False + if ok: + max_orders[index] = order + else: + max_orders = self._max_order + + self._mask = sumorder <= max_orders def getProperties(self, star): return np.array([star.data[k] for k in self._keys], dtype=float) @@ -562,7 +597,6 @@ def basis(self, star): p[1:] = vals[i] pows1d.append(np.cumprod(p)) # Use trick to produce outer product of all these powers - #pows2d = np.prod(np.ix_(*pows1d)) pows2d = np.prod(np.meshgrid(*pows1d, indexing='ij'), axis=0) # Return linear array of terms making total power constraint return pows2d[self._mask] @@ -599,4 +633,3 @@ def _finish_read(self, reader): data = reader.read_table('solution') assert data is not None self.q = data['q'][0] - diff --git a/piff/convolvepsf.py b/piff/convolvepsf.py index f1a38895..31dea650 100644 --- a/piff/convolvepsf.py +++ b/piff/convolvepsf.py @@ -45,14 +45,16 @@ class ConvolvePSF(PSF): [default: None] :param chisq_thresh: Change in reduced chisq at which iteration will terminate. [default: 0.1] + :param min_iter: Minimum number of iterations to try. [default: 2] :param max_iter: Maximum number of iterations to try. [default: 30] """ _type_name = 'Convolve' - def __init__(self, components, outliers=None, chisq_thresh=0.1, max_iter=30): + def __init__(self, components, outliers=None, chisq_thresh=0.1, min_iter=2, max_iter=30): self.components = components self.outliers = outliers self.chisq_thresh = chisq_thresh + self.min_iter = min_iter self.max_iter = max_iter self.kwargs = { # If components is a list, mark the number of components here for I/O purposes. @@ -60,6 +62,7 @@ def __init__(self, components, outliers=None, chisq_thresh=0.1, max_iter=30): 'components': len(components) if isinstance(components, list) else components, 'outliers': 0, 'chisq_thresh': self.chisq_thresh, + 'min_iter': self.min_iter, 'max_iter': self.max_iter, } self.chisq = 0. diff --git a/piff/psf.py b/piff/psf.py index c668ecd2..1e8e4dd7 100644 --- a/piff/psf.py +++ b/piff/psf.py @@ -436,7 +436,11 @@ def fit(self, stars, wcs, pointing, logger=None, convert_funcs=None, draw_method # Very simple convergence test here: # Note, the lack of abs here means if chisq increases, we also stop. # Also, don't quit if we removed any outliers. - if (iter_nremoved == 0) and (oldchisq > 0) and (oldchisq-chisq < self.chisq_thresh*dof): + if (iter_nremoved == 0 and + oldchisq > 0 and + oldchisq - chisq < self.chisq_thresh * dof and + iteration+1 >= self.min_iter + ): return oldchisq = chisq diff --git a/piff/readers.py b/piff/readers.py index 581c06a8..8d9a2b27 100644 --- a/piff/readers.py +++ b/piff/readers.py @@ -63,7 +63,12 @@ def read_struct(self, name): cols = self._fits[extname].get_colnames() data = self._fits[extname].read() assert len(data) == 1 - return dict([ (col, data[col][0]) for col in cols ]) + struct = dict([ (col, data[col][0]) for col in cols ]) + # If any dicts were converted to str, convert back. + for k,v in struct.items(): + if isinstance(v, str) and v[0] == '{': + struct[k] = eval(v) + return struct def read_table(self, name, metadata=None): """Load a table as a numpy array. diff --git a/piff/simplepsf.py b/piff/simplepsf.py index 31c083dd..18e3e12e 100644 --- a/piff/simplepsf.py +++ b/piff/simplepsf.py @@ -40,15 +40,17 @@ class SimplePSF(PSF): [default: None] :param chisq_thresh: Change in reduced chisq at which iteration will terminate. [default: 0.1] + :param min_iter: Minimum number of iterations to try. [default: 2] :param max_iter: Maximum number of iterations to try. [default: 30] """ _type_name = 'Simple' - def __init__(self, model, interp, outliers=None, chisq_thresh=0.1, max_iter=30): + def __init__(self, model, interp, outliers=None, chisq_thresh=0.1, min_iter=2, max_iter=30): self.model = model self.interp = interp self.outliers = outliers self.chisq_thresh = chisq_thresh + self.min_iter = min_iter self.max_iter = max_iter self.kwargs = { # Use 0 here for things that will get overwritten in _finish_read. @@ -56,6 +58,7 @@ def __init__(self, model, interp, outliers=None, chisq_thresh=0.1, max_iter=30): 'interp': 0, 'outliers': 0, 'chisq_thresh': self.chisq_thresh, + 'min_iter': self.min_iter, 'max_iter': self.max_iter, } self.chisq = 0. diff --git a/piff/sumpsf.py b/piff/sumpsf.py index b64b0d00..6533d738 100644 --- a/piff/sumpsf.py +++ b/piff/sumpsf.py @@ -45,14 +45,16 @@ class SumPSF(PSF): [default: None] :param chisq_thresh: Change in reduced chisq at which iteration will terminate. [default: 0.1] + :param min_iter: Minimum number of iterations to try. [default: 2] :param max_iter: Maximum number of iterations to try. [default: 30] """ _type_name = 'Sum' - def __init__(self, components, outliers=None, chisq_thresh=0.1, max_iter=30): + def __init__(self, components, outliers=None, chisq_thresh=0.1, min_iter=2, max_iter=30): self.components = components self.outliers = outliers self.chisq_thresh = chisq_thresh + self.min_iter = min_iter self.max_iter = max_iter self.kwargs = { # If components is a list, mark the number of components here for I/O purposes. @@ -60,6 +62,7 @@ def __init__(self, components, outliers=None, chisq_thresh=0.1, max_iter=30): 'components': len(components) if isinstance(components, list) else components, 'outliers': 0, 'chisq_thresh': self.chisq_thresh, + 'min_iter': self.min_iter, 'max_iter': self.max_iter, } self.chisq = 0. diff --git a/piff/writers.py b/piff/writers.py index 1c6c2597..f80b418e 100644 --- a/piff/writers.py +++ b/piff/writers.py @@ -72,6 +72,8 @@ def write_struct(self, name, struct): # Don't add values that are None to the table. if value is None: continue + if isinstance(value, dict): + value = repr(value) dt = make_dtype(key, value) value = adjust_value(value, dt) cols.append([value]) diff --git a/tests/test_convolvepsf.py b/tests/test_convolvepsf.py index f90e718a..d9e7394b 100644 --- a/tests/test_convolvepsf.py +++ b/tests/test_convolvepsf.py @@ -135,6 +135,7 @@ def test_trivial_convolve1(): assert psf.chisq_thresh == 0.2 assert psf.max_iter == 10 + assert psf.min_iter == 2 for i, star in enumerate(psf.stars): target = targets[i] diff --git a/tests/test_pixel.py b/tests/test_pixel.py index 3d969046..dce54b38 100644 --- a/tests/test_pixel.py +++ b/tests/test_pixel.py @@ -389,7 +389,6 @@ def test_basis_interp(): np.testing.assert_raises(ValueError, piff.BasisPolynomial, order=[-1,0]) np.testing.assert_raises(ValueError, piff.BasisPolynomial, order=[-4,-1]) np.testing.assert_raises(ValueError, piff.BasisPolynomial, order=-2) - np.testing.assert_raises(ValueError, piff.BasisPolynomial, order=[3,3], max_order=-1) @timer @@ -1728,6 +1727,15 @@ def test_color(): piff.piffify(config) psf = piff.read(psf_file) + # Show that the basis includes cross terms: + np.testing.assert_equal(psf.interp._orders, [2,2,1]) + assert psf.interp._max_order == 2 + s = psf.stars[0] + np.testing.assert_allclose( + psf.interp.basis(s), + [ 1, s['color'], s['v'], s['v']*s['color'], s['v']**2, + s['u'], s['u']*s['color'], s['u']*s['v'], s['u']**2 ]) + for s in psf.stars: orig_stamp = s.image weight = s.weight @@ -1742,6 +1750,35 @@ def test_color(): # Anyway, I think the fit is working, just this test doesn't # seem quite the right thing. + # Repeat without the cross-terms between color and u/v. + config['psf']['interp']['max_order'] = { ('u','v') : 2 } + piff.piffify(config) + psf = piff.read(psf_file) + + # Show that the basis now doesn't include cross terms: + np.testing.assert_equal(psf.interp._orders, [2,2,1]) + assert psf.interp._max_order == { ('u','v') : 2 } + s = psf.stars[0] + print(s['u'], s['v'], s['color']) + print('basis = ',psf.interp.basis(s)) + np.testing.assert_allclose( + psf.interp.basis(s), + [ 1, s['color'], s['v'], s['v']**2, s['u'], s['u']*s['v'], s['u']**2 ]) + + # Still works just as well without the cross terms. + for s in psf.stars: + orig_stamp = s.image + weight = s.weight + offset = s.center_to_offset(s.fit.center) + image = psf.draw(x=s['x'], y=s['y'], color=s['color'], + stamp_size=32, flux=s.fit.flux, offset=offset) + resid = image - orig_stamp + chisq = np.sum(resid.array**2 * weight.array) + dof = np.sum(weight.array != 0) + print('color = ',s['color'],'chisq = ',chisq,'dof = ',dof) + assert chisq < dof * 1.5 + + @timer def test_convert_func(): """Test PixelGrid fitting with a non-trivial convert_func diff --git a/tests/test_simple.py b/tests/test_simple.py index 171939f6..73bd4ea0 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -240,6 +240,7 @@ def test_single_image(): # Check default values of options psf = piff.SimplePSF(model, interp) assert psf.chisq_thresh == 0.1 + assert psf.min_iter == 2 assert psf.max_iter == 30 assert psf.outliers == None assert psf.interp_property_names == ('u','v') @@ -260,6 +261,7 @@ def test_single_image(): 'fastfit': True, 'include_pixel': False}, 'interp' : { 'type' : 'Mean' }, + 'min_iter' : 3, 'max_iter' : 10, 'chisq_thresh' : 0.2, }, @@ -269,8 +271,9 @@ def test_single_image(): # Use a SimplePSF to process the stars data this time. interp = piff.Mean() - psf = piff.SimplePSF(model, interp, max_iter=10, chisq_thresh=0.2) + psf = piff.SimplePSF(model, interp, min_iter=3, max_iter=10, chisq_thresh=0.2) assert psf.chisq_thresh == 0.2 + assert psf.min_iter == 3 assert psf.max_iter == 10 # Error if input has no stars @@ -332,6 +335,7 @@ def test_single_image(): assert psf2.chisq == psf.chisq assert psf2.last_delta_chisq == psf.last_delta_chisq assert psf2.chisq_thresh == psf.chisq_thresh + assert psf2.min_iter == psf.min_iter assert psf2.max_iter == psf.max_iter assert psf2.dof == psf.dof assert psf2.nremoved == psf.nremoved @@ -360,6 +364,7 @@ def test_single_image(): assert psf3.chisq == psf.chisq assert psf3.last_delta_chisq == psf.last_delta_chisq assert psf3.chisq_thresh == psf.chisq_thresh + assert psf3.min_iter == psf.min_iter assert psf3.max_iter == psf.max_iter assert psf3.dof == psf.dof assert psf3.nremoved == psf.nremoved @@ -385,6 +390,7 @@ def test_single_image(): assert psf4.chisq == psf.chisq assert psf4.last_delta_chisq == psf.last_delta_chisq assert psf4.chisq_thresh == psf.chisq_thresh + assert psf4.min_iter == psf.min_iter assert psf4.max_iter == psf.max_iter assert psf4.dof == psf.dof assert psf4.nremoved == psf.nremoved