-
Notifications
You must be signed in to change notification settings - Fork 7
Open
Description
Running the following test produces a very large number of errors:
def test_deriv_params_moffat_with_trunc():
    """Compare autodiff and finite-difference derivatives of a drawn profile.

    For each half-light radius in ``val``, draw a Gaussian on a 48x48 image
    (pixel scale 0.2, FFT size pinned to 64 via ``GSParams``) and take the
    square of pixel (24, 24). The derivative of that quantity with respect to
    the half-light radius is computed two ways:

    * with ``jax.grad``, vectorized over ``val`` via ``jax.vmap`` and jitted;
    * with a central finite difference of step ``eps``, vectorized the same
      way so every entry is evaluated at a single scalar half-light radius.

    The two results must agree to an absolute tolerance of 1e-6.

    NOTE(review): despite the test name, the profile drawn is a Gaussian and
    the ``trunc=`` keyword is commented out, so ``trunc`` is threaded through
    but currently unused -- presumably kept so a truncated profile can be
    swapped back in. Confirm intent.
    """
    val = jnp.array([2.0, 3.0])
    trunc = 20.0
    eps = 1e-5

    def _run(val_, trunc):
        # Squared value of pixel (24, 24) of the drawn image.
        image = jgs.Gaussian(
            half_light_radius=val_,
            # trunc=trunc,
            gsparams=jgs.GSParams(minimum_fft_size=64, maximum_fft_size=64),
        ).drawImage(nx=48, ny=48, scale=0.2)
        return image.array[24, 24] ** 2

    # Gradient w.r.t. the half-light radius, mapped over ``val`` while the
    # same ``trunc`` is broadcast to every call.
    gfunc = jax.jit(jax.vmap(jax.grad(_run), in_axes=(0, None)))
    gval = gfunc(val, trunc)

    # The finite-difference reference must also be evaluated one half-light
    # radius at a time. Calling ``_run`` on the whole array would hand the
    # profile a length-2 radius instead of producing one value per entry,
    # so the result could not be compared element-wise with ``gval``.
    run_each = jax.vmap(_run, in_axes=(0, None))
    gfdiff = (
        run_each(val + eps, trunc) - run_each(val - eps, trunc)
    ) / (2.0 * eps)

    np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=1e-6)
Metadata
Assignees
Labels
No labels