diff --git a/boplay/acq_funs/gamma_distribution.py b/boplay/acq_funs/gamma_distribution.py
index 4daf46f..c190082 100644
--- a/boplay/acq_funs/gamma_distribution.py
+++ b/boplay/acq_funs/gamma_distribution.py
@@ -33,7 +33,7 @@ def estimate_gamma_params(
     s = np.log(x_mean) - np.log(x).mean(axis=1)
     for _ in range(max_iters):
         grad = (np.log(k) - digamma(k) - s) / (1.0 / k - polygamma(1, k) + 1e-8)
-        grad += k**2 * wd
+        grad += (k - 1) ** 2 * wd
         k -= grad * lr
         k = k.clip(min=k_min, max=k_max)
     theta = x_mean / k
diff --git a/boplay/acq_funs/ves_base.py b/boplay/acq_funs/ves_base.py
index 99ee3e7..b284fc0 100644
--- a/boplay/acq_funs/ves_base.py
+++ b/boplay/acq_funs/ves_base.py
@@ -50,12 +50,13 @@ def optimize_adam(
         max_iters: int, the maximum number of iterations
         tol: float, the tolerance for the optimization
         lr: float, the learning rate for the optimization
+        wd: float, weight decay centered at 1

     Returns:
         theta: pt.Tensor, shape (n_x, 4)
         final_loss: float, the optimized loss
     """
-    opt = pt.optim.Adam([theta], lr=lr, amsgrad=True, weight_decay=wd)
+    opt = pt.optim.Adam([theta], lr=lr, amsgrad=True, weight_decay=0)

     prev_loss = float("inf")
     L = loss_fn(theta)
@@ -67,6 +68,11 @@ def optimize_adam(
         L = loss_fn(theta)
         L.backward()
         opt.step()
+
+        # Apply custom weight decay centered at 1
+        if wd > 0:
+            with pt.no_grad():
+                theta.data -= wd * (theta.data - 1)

         # Early stopping
         if abs(prev_loss - L.item()) < tol:
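
Note (not part of the diff): the first hunk changes the penalty term added to the Newton step so that it vanishes at k = 1 rather than k = 0, matching the "centered at 1" decay introduced in ves_base.py. For context, below is a minimal standalone sketch of the underlying Newton iteration for the gamma shape parameter without the penalty; the function name fit_gamma_shape and the closed-form initial guess are illustrative assumptions, not code from the repo.

# Standalone sketch, assuming x has shape (n_series, n_samples).
import numpy as np
from scipy.special import digamma, polygamma

def fit_gamma_shape(x, max_iters=100, lr=1.0, k_min=1e-3, k_max=1e3):
    x_mean = x.mean(axis=1)
    s = np.log(x_mean) - np.log(x).mean(axis=1)              # s >= 0 by Jensen's inequality
    k = (3 - s + np.sqrt((s - 3) ** 2 + 24 * s)) / (12 * s)  # common closed-form initial guess
    for _ in range(max_iters):
        # Newton step for the MLE condition log(k) - digamma(k) = s
        step = (np.log(k) - digamma(k) - s) / (1.0 / k - polygamma(1, k) + 1e-8)
        k = np.clip(k - lr * step, k_min, k_max)
    theta = x_mean / k                                        # scale = mean / shape
    return k, theta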
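Note (not part of the diff): Adam's built-in weight_decay shrinks parameters toward 0, which is why the second file sets weight_decay=0 and instead applies a decoupled decay toward 1 after opt.step(). A minimal sketch of that pattern follows, assuming torch imported as pt; the toy quadratic loss and tensor shape are illustrative only.

# Standalone sketch of decoupled weight decay centred at 1, not part of the PR.
import torch as pt

theta = pt.ones(8, 4, requires_grad=True)   # illustrative (n_x, 4) parameter tensor
opt = pt.optim.Adam([theta], lr=1e-2, amsgrad=True, weight_decay=0)
wd = 1e-3

def loss_fn(t):                             # toy objective, stands in for the real loss
    return ((t - 2.0) ** 2).sum()

for _ in range(200):
    opt.zero_grad()
    loss = loss_fn(theta)
    loss.backward()
    opt.step()
    if wd > 0:
        with pt.no_grad():
            theta -= wd * (theta - 1.0)     # pull toward 1 rather than toward 0

Passing wd to Adam's weight_decay argument would instead add wd * theta to the gradient, an L2 penalty toward 0, which is the behaviour the patch removes.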