Skip to content

Commit 7bff7a8

Browse files
committed
Make the same change for GSObjectModel
1 parent 43b24bf commit 7bff7a8

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

piff/gsobject_model.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ class GSObjectModel(Model):
5151
:param init: Initialization method. [default: None, which uses hsm unless a PSF
5252
class specifies a different default.]
5353
:param fit_flux: If True, the PSF model will include the flux value. This is useful when
54-
this model is an element of a Sum composite PSF. [default: False]
54+
this model is an element of a Sum composite PSF. [default: False,
55+
unless init=='zero', in which case it is automatically True.]
5556
:param scipy_kwargs: Optional kwargs to pass to scipy.optimize.least_squares [default: None]
5657
:param logger: A logger object for logging debug info. [default: None]
5758
"""
@@ -382,11 +383,20 @@ def initialize(self, star, logger=None, default_init=None):
382383
flux_scaling = None
383384
if init == 'zero':
384385
flux_scaling = 1.e-6
386+
# Also, make sure fit_flux=True. Otherwise, this won't work properly.
387+
if not self._fit_flux:
388+
logger.info('Setting fit_flux=True, since init=zero')
389+
self._fit_flux = True
390+
self.kwargs['fit_flux'] = True
385391
elif init == 'delta':
386392
size *= 1.e-6
387393
elif isinstance(init, tuple):
388394
flux_scaling, size_scaling = init
389395
size *= size_scaling
396+
if not self._fit_flux:
397+
logger.info(f'Setting fit_flux=True, since init={init}')
398+
self._fit_flux = True
399+
self.kwargs['fit_flux'] = True
390400
elif init.startswith('(') and init.endswith(')'):
391401
flux_scaling, size_scaling = eval(init)
392402
size *= size_scaling
@@ -405,8 +415,8 @@ def initialize(self, star, logger=None, default_init=None):
405415
params = [flux_scaling] + params
406416
params_var = [0.] + params_var
407417
else:
408-
if flux_scaling is not None:
409-
raise ValueError("{} initialization requires fit_flux=True".format(init))
418+
# This should have been guaranteed above.
419+
assert flux_scaling is None
410420
params = np.array(params)
411421
params_var = np.array(params_var)
412422

tests/test_gsobject_model.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -284,12 +284,13 @@ def test_simple():
284284
np.testing.assert_allclose(fit.center[0], du, rtol=0, atol=1e-2)
285285
np.testing.assert_allclose(fit.center[1], dv, rtol=0, atol=1e-2)
286286

287-
# init=zero required fit_flux
287+
# init=zero forces fit_flux=True
288288
config['model']['fit_flux'] = False
289289
model = piff.Model.process(config['model'], logger)
290290
psf1 = piff.SimplePSF(model, None)
291-
with np.testing.assert_raises(ValueError):
292-
model.initialize(fiducial_star)
291+
with CaptureLog() as cl:
292+
model.initialize(fiducial_star, logger=cl.logger)
293+
assert "Setting fit_flux=True" in cl.output
293294

294295
print('Initializing with delta')
295296
config['model']['init'] = 'delta'
@@ -377,11 +378,13 @@ def test_simple():
377378
np.testing.assert_allclose(fit.center[0], du, rtol=0, atol=1e-2)
378379
np.testing.assert_allclose(fit.center[1], dv, rtol=0, atol=1e-2)
379380

381+
# tuple init also forces fit_flux=True.
380382
config['model']['fit_flux'] = False
381383
model = piff.Model.process(config['model'], logger)
382384
psf1 = piff.SimplePSF(model, None)
383-
with np.testing.assert_raises(ValueError):
384-
model.initialize(fiducial_star)
385+
with CaptureLog() as cl:
386+
model.initialize(fiducial_star, logger=cl.logger)
387+
assert "Setting fit_flux=True" in cl.output
385388

386389
# Invalid init method raises an error
387390
config['model']['init'] = 'invalid'

0 commit comments

Comments
 (0)