@@ -19,10 +19,14 @@ def build(
1919) -> Transform:
2020 """Builds BAOA transform.
2121
22- Algorithm from [Leimkuhler and Matthews, 2015 - p271](https://link.springer.com/ok/10.1007/978-3-319-16375-8),
23- where it is called ABAO and is conjugate to BAOAB but uses a single gradient
24- evaluation per iteration:
22+ Algorithm from [Leimkuhler and Matthews, 2015 - p271](https://link.springer.com/ok/10.1007/978-3-319-16375-8).
2523
24+ BAOA is conjugate to BAOAB (in Leimkuhler and Matthews' terminology) but requires
25+ only a single gradient evaluation per iteration.
26+ The two are equivalent when analyzing functions of the parameter trajectory.
27+ Unlike BAOAB, BAOA is not reversible, but since we don't apply Metropolis-Hastings
28+ or momenta reversal, the algorithm remains functionally identical to BAOAB.
29+
2630 \\begin{align}
2731 m_{t+1/2} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}), \\\\
2832 θ_{t+1/2} &= θ_t + (ε / 2) σ^{-2} m_{t+1/2}, \\\\
@@ -32,10 +36,6 @@ def build(
3236
3337 for learning rate $\\epsilon$, temperature $T$, transformed friction $γ = α σ^{-2}$
3438 and transformed noise variance $ζ^2 = T(1 - e^{-2γε})$.
35-
36- The implementation of BAOA instead of BAOAB means that the update is not reversible,
37- but as we don't do any Metropolis-Hastings or momenta reversal the algorithm
38- is functionally equivalent to BAOAB.
3939
4040 Targets $p_T(θ, m) \\propto \\exp( (\\log p(θ) - \\frac{1}{2σ^2} m^Tm) / T)$
4141 with temperature $T$.
@@ -121,21 +121,11 @@ def update(
121121 inplace: bool = False,
122122) -> BAOAState :
123123 """Updates parameters and momenta for BAOA.
124-
125- Algorithm from [Leimkuhler and Matthews, 2015 - p271](https://link.springer.com/ok/10.1007/978-3-319-16375-8),
126- where it is called ABAO and is conjugate to BAOAB but uses a single gradient
127- evaluation per iteration:
128-
129- \\begin{align}
130- m_{t+1/2} &= m_t + ε \\nabla \\log p(θ_t, \\text{batch}), \\\\
131- θ_{t+1/2} &= θ_t + (ε / 2) σ^{-2} m_{t+1/2}, \\\\
132- m_{t+1} &= e^{-h γ} m_{t+1/2} + N(0, ζ^2 σ^2), \\\\
133- θ_{t+1} &= θ_{t+1/2} + (ε / 2) σ^{-2} m_{t+1} \\
134- \\end{align}
135-
136- for learning rate $\\epsilon$, temperature $T$, $γ = α σ^{-2}$
137- and $ζ^2 = T(1 - e^{-2γε})$.
138-
124+
125+ Algorithm from [Leimkuhler and Matthews, 2015 - p271](https://link.springer.com/ok/10.1007/978-3-319-16375-8).
126+
127+ See [build](baoa.md#posteriors.sgmcmc.baoa.build) for more details.
128+
139129 Args:
140130 state: BAOAState containing params and momenta.
141131 batch: Data batch to be send to log_posterior.