From 33c89da1df9f0026f211e70be99e6587500bd6db Mon Sep 17 00:00:00 2001 From: Benjamin-eecs <benjaminliu.eecs@gmail.com> Date: Sun, 23 Jul 2023 04:10:17 +0800 Subject: [PATCH 01/13] test: init --- tests/requirements.txt | 2 ++ tests/test_alias.py | 66 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/tests/requirements.txt b/tests/requirements.txt index 87c994e1..2e7acde6 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -3,6 +3,8 @@ torch >= 1.13 --requirement ../requirements.txt +git+https://github.com/sail-sg/Adan.git + jax[cpu] >= 0.3; platform_system != 'Windows' jaxopt; platform_system != 'Windows' optax; platform_system != 'Windows' diff --git a/tests/test_alias.py b/tests/test_alias.py index a0a78129..ef38dd99 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -144,6 +144,72 @@ def test_sgd( _set_use_chain_flat(True) +@helpers.parametrize( + dtype=[torch.float64], + lr=[1e-2, 1e-3, 1e-4], + betas=[(0.9, 0.999), (0.95, 0.9995)], + eps=[1e-8], + inplace=[True, False], + weight_decay=[0.0, 1e-2], + maximize=[False, True], + use_accelerated_op=[False, True], + use_chain_flat=[True, False], +) +def test_adan( + dtype: torch.dtype, + lr: float, + betas: tuple[float, float], + eps: float, + inplace: bool, + weight_decay: float, + maximize: bool, + use_accelerated_op: bool, + use_chain_flat: bool, +) -> None: + _set_use_chain_flat(use_chain_flat) + + model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype) + + fmodel, params, buffers = functorch.make_functional_with_buffers(model) + optim = torchopt.adan( + lr, + betas=betas, + eps=eps, + eps_root=0.0, + weight_decay=weight_decay, + maximize=maximize, + use_accelerated_op=use_accelerated_op, + ) + optim_state = optim.init(params) + optim_ref = torch.optim.adan( + model_ref.parameters(), + lr, + betas=betas, + eps=eps, + amsgrad=False, + weight_decay=weight_decay, + maximize=maximize, + ) + + for xs, ys in loader: + xs = xs.to(dtype=dtype) + pred = fmodel(params, buffers, xs) + pred_ref = model_ref(xs) + loss = F.cross_entropy(pred, ys) + loss_ref = F.cross_entropy(pred_ref, ys) + + grads = torch.autograd.grad(loss, params, allow_unused=True) + updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace) + params = torchopt.apply_updates(params, updates, inplace=inplace) + + optim_ref.zero_grad() + loss_ref.backward() + optim_ref.step() + + helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype) + _set_use_chain_flat(True) + + @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], From de1ab8a4074ef5790f4b9f99f94dbbc4dd0879fd Mon Sep 17 00:00:00 2001 From: Benjamin-eecs <benjaminliu.eecs@gmail.com> Date: Sun, 23 Jul 2023 04:10:56 +0800 Subject: [PATCH 02/13] feat: init adan alias --- torchopt/alias/adan.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 torchopt/alias/adan.py diff --git a/torchopt/alias/adan.py b/torchopt/alias/adan.py new file mode 100644 index 00000000..6632c341 --- /dev/null +++ b/torchopt/alias/adan.py @@ -0,0 +1,14 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== From 69c9578c4fb9635650d980bde18074d275bc3344 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs <benjaminliu.eecs@gmail.com> Date: Sun, 23 Jul 2023 04:11:14 +0800 Subject: [PATCH 03/13] feat: init adan optim --- torchopt/optim/adan.py | 91 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 torchopt/optim/adan.py diff --git a/torchopt/optim/adan.py b/torchopt/optim/adan.py new file mode 100644 index 00000000..95b57b66 --- /dev/null +++ b/torchopt/optim/adan.py @@ -0,0 +1,91 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Adan optimizer.""" + +from __future__ import annotations + +from typing import Iterable + +import torch + +from torchopt import alias +from torchopt.optim.base import Optimizer +from torchopt.typing import ScalarOrSchedule + + +__all__ = ['Adan'] + + +class Adan(Optimizer): + """The classic Adan optimizer. + + See Also: + - The functional Adan optimizer: :func:`torchopt.adan`. + - The differentiable meta Adan optimizer: :class:`torchopt.MetaAdan`. + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + params: Iterable[torch.Tensor], + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + max_grad_norm=0.0, + no_prox=False, + *, + eps_root: float = 0.0, + maximize: bool = False, + use_accelerated_op: bool = False, + ) -> None: + r"""Initialize the Adan optimizer. + + Args: + params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what + tensors should be optimized. + lr (float or callable, optional): This is a fixed global scaling factor or a learning + rate scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for computing running averages of + gradient and its square. (default: :const:`(0.9, 0.999)`) + eps (float, optional): A small constant applied to denominator outside of the square + root (as in the Adam paper) to avoid dividing by zero when rescaling. + (default: :const:`1e-8`) + weight_decay (float, optional): Weight decay, add L2 penalty to parameters. + (default: :const:`0.0`) + eps_root (float, optional): A small constant applied to denominator inside the square + root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for + example when computing (meta-)gradients through Adam. 
(default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True` the momentums will be created + with flag ``requires_grad=True``, this flag is often used in Meta-Learning + algorithms. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of + minimizing. (default: :data:`False`) + """ + super().__init__( + params, + alias.adan( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + max_grad_norm=max_grad_norm, + no_prox=no_prox, + eps_root=eps_root, + moment_requires_grad=False, + maximize=maximize, + use_accelerated_op=use_accelerated_op, + ), + ) From a98001dbb98ef5828ecdec3c479ca564ec595045 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs <benjaminliu.eecs@gmail.com> Date: Sun, 23 Jul 2023 04:11:44 +0800 Subject: [PATCH 04/13] feat: init adan transformation --- torchopt/transform/scale_by_adan.py | 283 ++++++++++++++++++++++++++++ 1 file changed, 283 insertions(+) create mode 100644 torchopt/transform/scale_by_adan.py diff --git a/torchopt/transform/scale_by_adan.py b/torchopt/transform/scale_by_adan.py new file mode 100644 index 00000000..18827f4e --- /dev/null +++ b/torchopt/transform/scale_by_adan.py @@ -0,0 +1,283 @@ +# Copyright 2022-2023 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Preset transformations for scaling updates by Adan.""" + +# pylint: disable=invalid-name + +from __future__ import annotations + +from typing import NamedTuple + +import torch + +from torchopt import pytree +from torchopt.base import GradientTransformation +from torchopt.transform.utils import update_moment +from torchopt.typing import OptState, Updates + + +class ScaleByAdanState(NamedTuple): + """State for the Adan algorithm.""" + + count: OptState + mu: Updates + nu: Updates + delta: Updates + grad_tm1: Updates + + +def scale_by_adan( + b1: float = 0.98, + b2: float = 0.92, + b3: float = 0.99, + eps: float = 1e-8, + eps_root: float = 0.0, + moment_requires_grad: bool = False, +) -> GradientTransformation: + """Rescale updates according to the Adan algorithm. + + References: + - Xie et al., 2022: https://arxiv.org/pdf/2208.06677.pdf + + Args: + b1 (float, optional): Decay rate for the exponentially weighted average of gradients. + (default: :const:`0.98`) + b2 (float, optional): Decay rate for the exponentially weighted average of difference of + gradients. + b3: Decay rate for the exponentially weighted average of the squared term. + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + eps_root (float, optional): Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. 
(default: :data:`False`) + + Returns: + An (init_fn, update_fn) tuple. + """ + return _scale_by_adan( + b1=b1, + b2=b2, + b3=b3, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + already_flattened=False, + ) + + +def _scale_by_adan_flat( + b1: float = 0.98, + b2: float = 0.92, + b3: float = 0.99, + eps: float = 1e-8, + eps_root: float = 0.0, + moment_requires_grad: bool = False, +) -> GradientTransformation: + return _scale_by_adan( + b1=b1, + b2=b2, + b3=b3, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + already_flattened=True, + ) + + +def _scale_by_adan( + b1: float = 0.98, + b2: float = 0.92, + b3: float = 0.99, + eps: float = 1e-8, + eps_root: float = 0.0, + moment_requires_grad: bool = False, + *, + already_flattened: bool = False, +) -> GradientTransformation: + # pylint: disable=unneeded-not + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= b1 < 1.0: # pragma: no cover + raise ValueError(f'Invalid beta parameter at index 0: {b1}') + if not 0.0 <= b2 < 1.0: # pragma: no cover + raise ValueError(f'Invalid beta parameter at index 1: {b2}') + if not 0.0 <= b3 < 1.0: + raise ValueError(f'Invalid beta parameter at index 2: {b3}') + # pylint: enable=unneeded-not + + if already_flattened: # noqa: SIM108 + tree_map = tree_map_flat + else: + tree_map = pytree.tree_map # type: ignore[assignment] + + def init_fn(params: Params) -> OptState: + zero = tree_map( # count init + lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(), + params, + ) + mu = tree_map( # first moment + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + nu = tree_map( # second moment + torch.zeros_like, + params, + ) + delta = tree_map( # EWA of Difference of gradients + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), + params, + ) + grad_tm1 = tree_map( + torch.zeros_like, + params, + ) # Previous gradient + return ScaleByAdanState( + count=torch.zeros([], torch.int32), + mu=mu, + nu=nu, + delta=delta, + grad_tm1=grad_tm1, + ) + + def update_fn(updates, state, params=None): + del params + diff = pytree.lax.cond( + state.count != 0, + lambda X, Y: pytree.tree_map(lambda x, y: x - y, X, Y), + lambda X, _: pytree.tree_map(torch.zeros_like, X), + updates, + state.grad_tm1, + ) + + grad_prime = pytree.tree_map(lambda g, d: g + b2 * d, updates, diff) + + mu = update_moment(updates, state.mu, b1, 1) + delta = update_moment(diff, state.delta, b2, 1) + nu = update_moment_per_elem_norm(grad_prime, state.nu, b3, 2) + + count_inc = numerics.safe_int32_increment(state.count) + mu_hat = utils.cast_tree(bias_correction(mu, b1, count_inc), fo_dtype) + delta_hat = utils.cast_tree(bias_correction(delta, b2, count_inc), fo_dtype) + nu_hat = bias_correction(nu, b3, count_inc) + new_updates = pytree.tree_map( + lambda m, d, n: (m + b2 * d) / (torch.sqrt(n + eps_root) + eps), + mu_hat, + delta_hat, + nu_hat, + ) + + return new_updates, ScaleByAdanState( + count=count_inc, + mu=mu, + nu=nu, + delta=delta, + grad_tm1=updates, + ) + + return base.GradientTransformation(init_fn, update_fn) + + +# def scale_by_proximal_adan( +# learning_rate: ScalarOrSchedule, +# weight_decay: float, +# b1: float = 0.98, +# b2: float = 0.92, +# b3: float = 0.99, +# eps_root: float = 1e-8, +# fo_dtype: Optional[Any] = None, +# ) -> base.GradientTransformation: +# """Rescale updates according to the proximal version of the Adan algorithm. 
+# References: +# [Xie et al, 2022](https://arxiv.org/abs/2208.06677) +# Args: +# b1: Decay rate for the exponentially weighted average of gradients. +# b2: Decay rate for the exponentially weighted average of difference of +# gradients. +# b3: Decay rate for the exponentially weighted average of the squared term. +# eps: term added to the denominator to improve numerical stability. +# eps_root: Term added to the denominator inside the square-root to improve +# numerical stability when backpropagating gradients through the rescaling. +# fo_dtype: optional `dtype` to be used for the first order accumulators +# mu and delta; if `None` then the `dtype is inferred from `params` +# and `updates`. +# Returns: +# An (init_fn, update_fn) tuple. +# """ + +# fo_dtype = utils.canonicalize_dtype(fo_dtype) + +# def init_fn(params): +# mu = pytree.tree_map( # First moment +# lambda t: torch.zeros_like(t, dtype=fo_dtype), params) +# nu = pytree.tree_map(torch.zeros_like, params) # Second moment +# delta = pytree.tree_map( # EWA of Difference of gradients +# lambda t: torch.zeros_like(t, dtype=fo_dtype), params) +# grad_tm1 = pytree.tree_map(torch.zeros_like, params) # Previous gradient +# return ScaleByAdanState(count=torch.zeros([], torch.int32), +# mu=mu, nu=nu, delta=delta, grad_tm1=grad_tm1) + +# def update_fn(updates, state, params=None): +# diff = pytree.lax.cond(state.count != 0, +# lambda X, Y: pytree.tree_map(lambda x, y: x - y, X, Y), +# lambda X, _: pytree.tree_map(torch.zeros_like, X), +# updates, state.grad_tm1) + +# grad_prime = pytree.tree_map(lambda g, d: g + b2*d, updates, diff) + +# mu = update_moment(updates, state.mu, b1, 1) +# delta = update_moment(diff, state.delta, b2, 1) +# nu = update_moment_per_elem_norm(grad_prime, state.nu, b3, 2) + +# count_inc = numerics.safe_int32_increment(state.count) +# mu_hat = bias_correction(mu, b1, count_inc) +# delta_hat = bias_correction(delta, b2, count_inc) +# nu_hat = bias_correction(nu, b3, count_inc) + +# if callable(learning_rate): +# lr = learning_rate(state.count) +# else: +# lr = learning_rate + +# learning_rates = pytree.tree_util.tree_map( +# lambda n: lr / torch.sqrt(n + eps_root), nu_hat) + +# # negative scale: gradient descent +# updates = pytree.tree_util.tree_map(lambda scale, m, v: -scale * (m + b2 * v), +# learning_rates, mu_hat, +# delta_hat) + +# decay = 1. / (1. + weight_decay * lr) +# params_new = pytree.tree_util.tree_map(lambda p, u: +# decay * (p + u), params, +# updates) + +# # params_new - params_old +# new_updates = pytree.tree_util.tree_map(lambda new, old: new - old, params_new, +# params) + +# mu_hat = utils.cast_tree(mu_hat, fo_dtype) +# delta_hat = utils.cast_tree(delta_hat, fo_dtype) + +# return new_updates, ScaleByAdanState(count=count_inc, +# mu=mu, nu=nu, delta=delta, +# grad_tm1=updates) + +# return base.GradientTransformation(init_fn, update_fn) + + +scale_by_adan.flat = _scale_by_adan_flat # type: ignore[attr-defined] +scale_by_adan.impl = _scale_by_adan # type: ignore[attr-defined] From 15fdf139e57932b82438398c21f3e6eb2c537c7b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 22 Jul 2023 20:13:44 +0000 Subject: [PATCH 05/13] fix: [pre-commit.ci] auto fixes [...] 
--- torchopt/transform/scale_by_adan.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchopt/transform/scale_by_adan.py b/torchopt/transform/scale_by_adan.py index 18827f4e..9ad0d2df 100644 --- a/torchopt/transform/scale_by_adan.py +++ b/torchopt/transform/scale_by_adan.py @@ -125,15 +125,15 @@ def _scale_by_adan( tree_map = pytree.tree_map # type: ignore[assignment] def init_fn(params: Params) -> OptState: - zero = tree_map( # count init + tree_map( # count init lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(), params, ) - mu = tree_map( # first moment + mu = tree_map( # first moment lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params, ) - nu = tree_map( # second moment + nu = tree_map( # second moment torch.zeros_like, params, ) From 094582af0815d31818571e94114b5bb49c39edb2 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs <benjaminliu.eecs@gmail.com> Date: Sun, 23 Jul 2023 04:19:56 +0800 Subject: [PATCH 06/13] fix: update docstring --- torchopt/optim/adan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchopt/optim/adan.py b/torchopt/optim/adan.py index 95b57b66..7d6449ba 100644 --- a/torchopt/optim/adan.py +++ b/torchopt/optim/adan.py @@ -33,7 +33,7 @@ class Adan(Optimizer): See Also: - The functional Adan optimizer: :func:`torchopt.adan`. - - The differentiable meta Adan optimizer: :class:`torchopt.MetaAdan`. + - The differentiable meta-Adan optimizer: :class:`torchopt.MetaAdan`. """ # pylint: disable-next=too-many-arguments From 53d2bd02f4f7ca2190466e9abff7fde98d53f8f6 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs <benjaminliu.eecs@gmail.com> Date: Sun, 23 Jul 2023 13:51:56 +0800 Subject: [PATCH 07/13] fix: update adan transformation --- torchopt/transform/scale_by_adan.py | 134 +++++++++++++++++++--------- 1 file changed, 94 insertions(+), 40 deletions(-) diff --git a/torchopt/transform/scale_by_adan.py b/torchopt/transform/scale_by_adan.py index 9ad0d2df..71255072 100644 --- a/torchopt/transform/scale_by_adan.py +++ b/torchopt/transform/scale_by_adan.py @@ -24,18 +24,40 @@ from torchopt import pytree from torchopt.base import GradientTransformation -from torchopt.transform.utils import update_moment -from torchopt.typing import OptState, Updates +from torchopt.transform.utils import inc_count, tree_map_flat, update_moment +from torchopt.typing import OptState, Params, Updates + + +__all__ = [ + 'scale_by_adan', +] class ScaleByAdanState(NamedTuple): """State for the Adan algorithm.""" - count: OptState mu: Updates nu: Updates delta: Updates grad_tm1: Updates + count: OptState + + +def _adan_bias_correction( + moment: Updates, + decay: float, + count: OptState, + *, + already_flattened: bool = False, +) -> Updates: + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + + def f(t: torch.Tensor, c: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name + return t.div(1 - pow(decay, c)) + + if already_flattened: + return tree_map_flat(f, moment, count) + return pytree.tree_map(f, moment, count) def scale_by_adan( @@ -55,8 +77,10 @@ def scale_by_adan( b1 (float, optional): Decay rate for the exponentially weighted average of gradients. (default: :const:`0.98`) b2 (float, optional): Decay rate for the exponentially weighted average of difference of - gradients. - b3: Decay rate for the exponentially weighted average of the squared term. + gradients. 
+ (default: :const:`0.92`) + b3 (float, optional): Decay rate for the exponentially weighted average of the squared term. + (default: :const:`0.99`) eps (float, optional): Term added to the denominator to improve numerical stability. (default: :const:`1e-8`) eps_root (float, optional): Term added to the denominator inside the square-root to improve @@ -134,7 +158,7 @@ def init_fn(params: Params) -> OptState: params, ) nu = tree_map( # second moment - torch.zeros_like, + lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params, ) delta = tree_map( # EWA of Difference of gradients @@ -142,53 +166,87 @@ def init_fn(params: Params) -> OptState: params, ) grad_tm1 = tree_map( - torch.zeros_like, + lambda t: torch.zeros_like( + t, + ), params, ) # Previous gradient return ScaleByAdanState( - count=torch.zeros([], torch.int32), mu=mu, nu=nu, delta=delta, grad_tm1=grad_tm1, + count=zero, ) - def update_fn(updates, state, params=None): - del params + def update_fn( + updates: Updates, + state: OptState, + *, + params: Params | None = None, # pylint: disable=unused-argument + inplace: bool = True, + ) -> tuple[Updates, OptState]: diff = pytree.lax.cond( state.count != 0, - lambda X, Y: pytree.tree_map(lambda x, y: x - y, X, Y), - lambda X, _: pytree.tree_map(torch.zeros_like, X), + lambda X, Y: tree_map(lambda x, y: x - y, X, Y), + lambda X, _: tree_map(torch.zeros_like, X), updates, state.grad_tm1, ) - grad_prime = pytree.tree_map(lambda g, d: g + b2 * d, updates, diff) + grad_prime = tree_map(lambda g, d: g + b2 * d, updates, diff) - mu = update_moment(updates, state.mu, b1, 1) - delta = update_moment(diff, state.delta, b2, 1) + mu = update_moment.impl( + updates, + state.mu, + b1, + order=1, + inplace=inplace, + already_flattened=already_flattened, + ) + delta = update_moment.impl( + diff, + state.delta, + b2, + 1, + ) nu = update_moment_per_elem_norm(grad_prime, state.nu, b3, 2) - count_inc = numerics.safe_int32_increment(state.count) - mu_hat = utils.cast_tree(bias_correction(mu, b1, count_inc), fo_dtype) - delta_hat = utils.cast_tree(bias_correction(delta, b2, count_inc), fo_dtype) - nu_hat = bias_correction(nu, b3, count_inc) - new_updates = pytree.tree_map( - lambda m, d, n: (m + b2 * d) / (torch.sqrt(n + eps_root) + eps), - mu_hat, - delta_hat, - nu_hat, - ) + count_inc = inc_count.impl(updates, state.count, already_flattened=already_flattened) # type: ignore[attr-defined] + mu_hat = _adan_bias_correction(mu, b1, count_inc, already_flattened=already_flattened) + delta_hat = _adan_bias_correction(delta, b2, count_inc, already_flattened=already_flattened) + nu_hat = _adan_bias_correction(nu, b3, count_inc, already_flattened=already_flattened) - return new_updates, ScaleByAdanState( - count=count_inc, - mu=mu, - nu=nu, - delta=delta, - grad_tm1=updates, + if inplace: + + def f( + m: torch.Tensor, + d: torch.Tensor, + n: torch.Tensor, + ) -> torch.Tensor: + return (m + b2 * d).div_(torch.sqrt(n + eps_root).add(eps)) + + else: + + def f( + m: torch.Tensor, + d: torch.Tensor, + n: torch.Tensor, + ) -> torch.Tensor: + return (m + b2 * d).div(torch.sqrt(n + eps_root).add(eps)) + + # lambda m, d, n: (m + b2 * d) / (torch.sqrt(n + eps_root) + eps), + updates = pytree.tree_map(f, mu_hat, delta_hat, nu_hat) + + return updates, ScaleByAdanState( + count=count_inc, mu=mu, nu=nu, delta=delta, grad_tm1=updates, ) - return base.GradientTransformation(init_fn, update_fn) + return GradientTransformation(init_fn, update_fn) + + +scale_by_adan.flat = _scale_by_adan_flat # type: 
ignore[attr-defined] +scale_by_adan.impl = _scale_by_adan # type: ignore[attr-defined] # def scale_by_proximal_adan( @@ -243,9 +301,9 @@ def update_fn(updates, state, params=None): # nu = update_moment_per_elem_norm(grad_prime, state.nu, b3, 2) # count_inc = numerics.safe_int32_increment(state.count) -# mu_hat = bias_correction(mu, b1, count_inc) -# delta_hat = bias_correction(delta, b2, count_inc) -# nu_hat = bias_correction(nu, b3, count_inc) +# mu_hat = _adan_bias_correction(mu, b1, count_inc) +# delta_hat = _adan_bias_correction(delta, b2, count_inc) +# nu_hat = _adan_bias_correction(nu, b3, count_inc) # if callable(learning_rate): # lr = learning_rate(state.count) @@ -276,8 +334,4 @@ def update_fn(updates, state, params=None): # mu=mu, nu=nu, delta=delta, # grad_tm1=updates) -# return base.GradientTransformation(init_fn, update_fn) - - -scale_by_adan.flat = _scale_by_adan_flat # type: ignore[attr-defined] -scale_by_adan.impl = _scale_by_adan # type: ignore[attr-defined] +# return GradientTransformation(init_fn, update_fn) From 5723cf6c7ca295934db5b71861996538fc6df994 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 23 Jul 2023 05:52:13 +0000 Subject: [PATCH 08/13] fix: [pre-commit.ci] auto fixes [...] --- torchopt/transform/scale_by_adan.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchopt/transform/scale_by_adan.py b/torchopt/transform/scale_by_adan.py index 71255072..8c1d19bb 100644 --- a/torchopt/transform/scale_by_adan.py +++ b/torchopt/transform/scale_by_adan.py @@ -239,7 +239,11 @@ def f( updates = pytree.tree_map(f, mu_hat, delta_hat, nu_hat) return updates, ScaleByAdanState( - count=count_inc, mu=mu, nu=nu, delta=delta, grad_tm1=updates, + count=count_inc, + mu=mu, + nu=nu, + delta=delta, + grad_tm1=updates, ) return GradientTransformation(init_fn, update_fn) From cb2efb0328afb1a78cf466cb89a496f5284de13e Mon Sep 17 00:00:00 2001 From: Benjamin-eecs <benjaminliu.eecs@gmail.com> Date: Tue, 25 Jul 2023 02:41:40 +0800 Subject: [PATCH 09/13] fix: update requirements --- tests/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index 2e7acde6..38b2b735 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -3,7 +3,7 @@ torch >= 1.13 --requirement ../requirements.txt -git+https://github.com/sail-sg/Adan.git +git+https://github.com/benjamin-eecs/Adan.git jax[cpu] >= 0.3; platform_system != 'Windows' jaxopt; platform_system != 'Windows' From 8e38169f16c0b951795455cb347cc5a58862b66e Mon Sep 17 00:00:00 2001 From: Benjamin-eecs <benjaminliu.eecs@gmail.com> Date: Tue, 25 Jul 2023 02:41:55 +0800 Subject: [PATCH 10/13] fix: update adan transformation --- torchopt/transform/scale_by_adan.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/torchopt/transform/scale_by_adan.py b/torchopt/transform/scale_by_adan.py index 71255072..eff98ca6 100644 --- a/torchopt/transform/scale_by_adan.py +++ b/torchopt/transform/scale_by_adan.py @@ -149,7 +149,7 @@ def _scale_by_adan( tree_map = pytree.tree_map # type: ignore[assignment] def init_fn(params: Params) -> OptState: - tree_map( # count init + zero = tree_map( # count init lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(), params, ) @@ -186,13 +186,10 @@ def update_fn( params: Params | None = None, # pylint: disable=unused-argument inplace: bool = True, ) -> tuple[Updates, OptState]: - diff = 
pytree.lax.cond( - state.count != 0, - lambda X, Y: tree_map(lambda x, y: x - y, X, Y), - lambda X, _: tree_map(torch.zeros_like, X), - updates, - state.grad_tm1, - ) + if state.count != 0: + diff = tree_map(lambda x, y: x - y, updates, state.grad_tm1) + else: + diff = tree_map(torch.zeros_like, updates) grad_prime = tree_map(lambda g, d: g + b2 * d, updates, diff) @@ -204,14 +201,13 @@ def update_fn( inplace=inplace, already_flattened=already_flattened, ) + nu = update_moment_per_elem_norm(grad_prime, state.nu, b3, 2) delta = update_moment.impl( diff, state.delta, b2, 1, ) - nu = update_moment_per_elem_norm(grad_prime, state.nu, b3, 2) - count_inc = inc_count.impl(updates, state.count, already_flattened=already_flattened) # type: ignore[attr-defined] mu_hat = _adan_bias_correction(mu, b1, count_inc, already_flattened=already_flattened) delta_hat = _adan_bias_correction(delta, b2, count_inc, already_flattened=already_flattened) @@ -239,7 +235,11 @@ def f( updates = pytree.tree_map(f, mu_hat, delta_hat, nu_hat) return updates, ScaleByAdanState( - count=count_inc, mu=mu, nu=nu, delta=delta, grad_tm1=updates, + count=count_inc, + mu=mu, + nu=nu, + delta=delta, + grad_tm1=updates, ) return GradientTransformation(init_fn, update_fn) From 4d941452cd168fd616b45e4d5755a821c42304ce Mon Sep 17 00:00:00 2001 From: Benjamin-eecs <benjaminliu.eecs@gmail.com> Date: Tue, 25 Jul 2023 03:30:31 +0800 Subject: [PATCH 11/13] test: update Makefile --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 0f7dd74e..163bba8d 100644 --- a/Makefile +++ b/Makefile @@ -113,7 +113,7 @@ addlicense-install: go-install pytest: test-install cd tests && $(PYTHON) -c 'import $(PROJECT_PATH)' && \ - $(PYTHON) -m pytest --verbose --color=yes --durations=0 \ + $(PYTHON) -m pytest -k "test_adan" --verbose --color=yes --durations=0 \ --cov="$(PROJECT_PATH)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \ $(PYTESTOPTS) . 
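Stripped of the pytree plumbing, the update rule that `scale_by_adan` implements after PATCH 07 and PATCH 10 looks as follows on a single tensor. This is an illustrative sketch only (the `adan_step` / `AdanState` names are hypothetical); it assumes `update_moment` computes `decay * m + (1 - decay) * x**order` as in the other TorchOpt transforms, and it omits the weight-decay and `-lr` steps applied by the surrounding `chain`.

from typing import NamedTuple

import torch


class AdanState(NamedTuple):
    mu: torch.Tensor        # EWA of gradients
    delta: torch.Tensor     # EWA of gradient differences
    nu: torch.Tensor        # EWA of the squared term
    grad_tm1: torch.Tensor  # previous gradient
    count: int              # number of steps taken


def adan_step(grad, state, b1=0.98, b2=0.92, b3=0.99, eps=1e-8, eps_root=0.0):
    # g_t - g_{t-1}; defined as zero on the very first step
    diff = grad - state.grad_tm1 if state.count != 0 else torch.zeros_like(grad)
    grad_prime = grad + b2 * diff

    # exponentially weighted averages
    mu = b1 * state.mu + (1.0 - b1) * grad
    delta = b2 * state.delta + (1.0 - b2) * diff
    nu = b3 * state.nu + (1.0 - b3) * grad_prime ** 2

    # bias correction; a no-op as count grows large
    count = state.count + 1
    mu_hat = mu / (1.0 - b1 ** count)
    delta_hat = delta / (1.0 - b2 ** count)
    nu_hat = nu / (1.0 - b3 ** count)

    # descent direction; weight decay and the -lr scaling happen outside this transform
    update = (mu_hat + b2 * delta_hat) / (torch.sqrt(nu_hat + eps_root) + eps)
    return update, AdanState(mu, delta, nu, grad.clone(), count)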
From 501d1251b25324237e1c746f4fa1a2de76cc6449 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs <benjaminliu.eecs@gmail.com> Date: Tue, 25 Jul 2023 03:31:01 +0800 Subject: [PATCH 12/13] feat: init adan alias --- tests/test_alias.py | 18 +++++-- torchopt/alias/adan.py | 110 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 5 deletions(-) diff --git a/tests/test_alias.py b/tests/test_alias.py index c18e678f..1f5189c1 100644 --- a/tests/test_alias.py +++ b/tests/test_alias.py @@ -21,6 +21,7 @@ import pytest import torch import torch.nn.functional as F +from adan import Adan import helpers import torchopt @@ -204,10 +205,12 @@ def test_adadelta( @helpers.parametrize( dtype=[torch.float64], lr=[1e-2, 1e-3, 1e-4], - betas=[(0.9, 0.999), (0.95, 0.9995)], + betas=[(0.9, 0.999, 0.998), (0.95, 0.9995, 0.9985)], eps=[1e-8], inplace=[True, False], weight_decay=[0.0, 1e-2], + max_grad_norm=[0.0, 1.0], + no_prox=[False, True], maximize=[False, True], use_accelerated_op=[False, True], use_chain_flat=[True, False], @@ -215,10 +218,12 @@ def test_adadelta( def test_adan( dtype: torch.dtype, lr: float, - betas: tuple[float, float], + betas: tuple[float, float, float], eps: float, inplace: bool, weight_decay: float, + max_grad_norm: float, + no_prox: bool, maximize: bool, use_accelerated_op: bool, use_chain_flat: bool, @@ -234,18 +239,21 @@ def test_adan( eps=eps, eps_root=0.0, weight_decay=weight_decay, + max_grad_norm=max_grad_norm, + no_prox=no_prox, maximize=maximize, use_accelerated_op=use_accelerated_op, ) optim_state = optim.init(params) - optim_ref = torch.optim.adan( + optim_ref = Adan( model_ref.parameters(), lr, betas=betas, eps=eps, - amsgrad=False, + eps_root=0.0, weight_decay=weight_decay, - maximize=maximize, + max_grad_norm=max_grad_norm, + no_prox=no_prox, ) for xs, ys in loader: diff --git a/torchopt/alias/adan.py b/torchopt/alias/adan.py index 6632c341..e194ae41 100644 --- a/torchopt/alias/adan.py +++ b/torchopt/alias/adan.py @@ -12,3 +12,113 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Preset :class:`GradientTransformation` for the Adan optimizer.""" + +from __future__ import annotations + +from torchopt.alias.utils import ( + _get_use_chain_flat, + flip_sign_and_add_weight_decay, + scale_by_neg_lr, +) +from torchopt.combine import chain +from torchopt.transform import scale_by_adan +from torchopt.typing import GradientTransformation, ScalarOrSchedule + + +__all__ = ['adan'] + + +# pylint: disable-next=too-many-arguments +def adan( + lr: ScalarOrSchedule = 1e-3, + betas: tuple[float, float, float] = (0.98, 0.92, 0.99), + eps: float = 1e-8, + weight_decay: float = 0.0, + max_grad_norm=0.0, + no_prox=False, + *, + eps_root: float = 0.0, + moment_requires_grad: bool = False, + maximize: bool = False, +) -> GradientTransformation: + """Create a functional version of the adan optimizer. + + adan is an SGD variant with learning rate adaptation. The *learning rate* used for each weight + is computed from estimates of first- and second-order moments of the gradients (using suitable + exponential moving averages). + + References: + - Kingma et al., 2014: https://arxiv.org/abs/1412.6980 + + Args: + lr (float or callable, optional): This is a fixed global scaling factor or a learning rate + scheduler. (default: :const:`1e-3`) + betas (tuple of float, optional): Coefficients used for + first- and second-order moments. 
(default: :const:`(0.98, 0.92, 0.99)`) + eps (float, optional): Term added to the denominator to improve numerical stability. + (default: :const:`1e-8`) + eps_root (float, optional): Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + (default: :const:`0.0`) + weight_decay (float, optional): Weight decay (L2 penalty). + (default: :const:`0.0`) + max_grad_norm (float, optional): Max norm of the gradients. + (default: :const:`0.0`) + no_prox (bool, optional): If :data:`True`, the proximal term is not applied. + (default: :data:`False`) + eps_root (float, optional): A small constant applied to denominator inside the square root + (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example + when computing (meta-)gradients through Adam. (default: :const:`0.0`) + moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag + ``requires_grad = True``. (default: :data:`False`) + maximize (bool, optional): Maximize the params based on the objective, instead of minimizing. + (default: :data:`False`) + + Returns: + The corresponding :class:`GradientTransformation` instance. + + See Also: + The functional optimizer wrapper :class:`torchopt.FuncOptimizer`. + """ + b1, b2, b3 = betas # pylint: disable=invalid-name + # pylint: disable=unneeded-not + if not 0.0 <= max_grad_norm: + raise ValueError(f'Invalid Max grad norm: {max_grad_norm}') + if not (callable(lr) or lr >= 0.0): # pragma: no cover + raise ValueError(f'Invalid learning rate: {lr}') + if not eps >= 0.0: # pragma: no cover + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= b1 < 1.0: # pragma: no cover + raise ValueError(f'Invalid beta parameter at index 0: {b1}') + if not 0.0 <= b2 < 1.0: # pragma: no cover + raise ValueError(f'Invalid beta parameter at index 1: {b2}') + if not 0.0 <= b3 < 1.0: + raise ValueError(f'Invalid beta parameter at index 2: {b3}') + if not weight_decay >= 0.0: # pragma: no cover + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + # pylint: enable=unneeded-not + + chain_fn = chain + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay + adan_scaler_fn = scale_by_adan if no_prox else scale_by_proximal_adan + scale_by_neg_lr_fn = scale_by_neg_lr + + if _get_use_chain_flat(): # default behavior + chain_fn = chain_fn.flat # type: ignore[attr-defined] + flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined] + adan_scaler_fn = adan_scaler_fn.flat # type: ignore[attr-defined] + scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined] + + return chain_fn( + flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize), + adan_scaler_fn( + b1=b1, + b2=b2, + b3=b3, + eps=eps, + eps_root=eps_root, + moment_requires_grad=moment_requires_grad, + ), + scale_by_neg_lr_fn(lr), + ) From e11cced6ae14d89a4156b3cb43bed22e8526649a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Jul 2023 19:31:21 +0000 Subject: [PATCH 13/13] fix: [pre-commit.ci] auto fixes [...] 
---
 torchopt/alias/adan.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchopt/alias/adan.py b/torchopt/alias/adan.py
index e194ae41..a30e2b04 100644
--- a/torchopt/alias/adan.py
+++ b/torchopt/alias/adan.py
@@ -83,7 +83,7 @@ def adan(
     """
     b1, b2, b3 = betas  # pylint: disable=invalid-name
     # pylint: disable=unneeded-not
-    if not 0.0 <= max_grad_norm:
+    if not max_grad_norm >= 0.0:
         raise ValueError(f'Invalid Max grad norm: {max_grad_norm}')
     if not (callable(lr) or lr >= 0.0):  # pragma: no cover
         raise ValueError(f'Invalid learning rate: {lr}')
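With the series applied, the functional Adan optimizer is used like the other TorchOpt aliases. A minimal usage sketch mirroring `tests/test_alias.py` (shapes and hyperparameters are illustrative; it assumes the new alias is re-exported as `torchopt.adan`, which the tests rely on but these patches do not show being wired into the package `__init__`):

import functorch
import torch
import torch.nn.functional as F

import torchopt

model = torch.nn.Linear(4, 2)
fmodel, params = functorch.make_functional(model)

# Build the functional Adan optimizer and initialize its state.
optim = torchopt.adan(lr=1e-3, betas=(0.98, 0.92, 0.99), weight_decay=1e-2)
optim_state = optim.init(params)

xs = torch.randn(8, 4)
ys = torch.randint(0, 2, (8,))

# One optimization step: grads -> updates -> new params.
loss = F.cross_entropy(fmodel(params, xs), ys)
grads = torch.autograd.grad(loss, params)
updates, optim_state = optim.update(grads, optim_state, params=params)
params = torchopt.apply_updates(params, updates)

The `inplace` keyword on `optim.update` and `torchopt.apply_updates` follows the same convention as the other aliases, which is what the `inplace=[True, False]` parametrization in the test exercises.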