From 33c89da1df9f0026f211e70be99e6587500bd6db Mon Sep 17 00:00:00 2001
From: Benjamin-eecs <benjaminliu.eecs@gmail.com>
Date: Sun, 23 Jul 2023 04:10:17 +0800
Subject: [PATCH 01/13] test: init

---
 tests/requirements.txt |  2 ++
 tests/test_alias.py    | 66 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 68 insertions(+)

diff --git a/tests/requirements.txt b/tests/requirements.txt
index 87c994e1..2e7acde6 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -3,6 +3,8 @@ torch >= 1.13
 
 --requirement ../requirements.txt
 
+git+https://github.com/sail-sg/Adan.git
+
 jax[cpu] >= 0.3; platform_system != 'Windows'
 jaxopt; platform_system != 'Windows'
 optax; platform_system != 'Windows'
diff --git a/tests/test_alias.py b/tests/test_alias.py
index a0a78129..ef38dd99 100644
--- a/tests/test_alias.py
+++ b/tests/test_alias.py
@@ -144,6 +144,72 @@ def test_sgd(
     _set_use_chain_flat(True)
 
 
+@helpers.parametrize(
+    dtype=[torch.float64],
+    lr=[1e-2, 1e-3, 1e-4],
+    betas=[(0.9, 0.999), (0.95, 0.9995)],
+    eps=[1e-8],
+    inplace=[True, False],
+    weight_decay=[0.0, 1e-2],
+    maximize=[False, True],
+    use_accelerated_op=[False, True],
+    use_chain_flat=[True, False],
+)
+def test_adan(
+    dtype: torch.dtype,
+    lr: float,
+    betas: tuple[float, float],
+    eps: float,
+    inplace: bool,
+    weight_decay: float,
+    maximize: bool,
+    use_accelerated_op: bool,
+    use_chain_flat: bool,
+) -> None:
+    _set_use_chain_flat(use_chain_flat)
+
+    model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)
+
+    fmodel, params, buffers = functorch.make_functional_with_buffers(model)
+    optim = torchopt.adan(
+        lr,
+        betas=betas,
+        eps=eps,
+        eps_root=0.0,
+        weight_decay=weight_decay,
+        maximize=maximize,
+        use_accelerated_op=use_accelerated_op,
+    )
+    optim_state = optim.init(params)
+    optim_ref = torch.optim.adan(
+        model_ref.parameters(),
+        lr,
+        betas=betas,
+        eps=eps,
+        amsgrad=False,
+        weight_decay=weight_decay,
+        maximize=maximize,
+    )
+
+    for xs, ys in loader:
+        xs = xs.to(dtype=dtype)
+        pred = fmodel(params, buffers, xs)
+        pred_ref = model_ref(xs)
+        loss = F.cross_entropy(pred, ys)
+        loss_ref = F.cross_entropy(pred_ref, ys)
+
+        grads = torch.autograd.grad(loss, params, allow_unused=True)
+        updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
+        params = torchopt.apply_updates(params, updates, inplace=inplace)
+
+        optim_ref.zero_grad()
+        loss_ref.backward()
+        optim_ref.step()
+
+    helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype)
+    _set_use_chain_flat(True)
+
+
 @helpers.parametrize(
     dtype=[torch.float64],
     lr=[1e-2, 1e-3, 1e-4],

From de1ab8a4074ef5790f4b9f99f94dbbc4dd0879fd Mon Sep 17 00:00:00 2001
From: Benjamin-eecs <benjaminliu.eecs@gmail.com>
Date: Sun, 23 Jul 2023 04:10:56 +0800
Subject: [PATCH 02/13] feat: init adan alias

---
 torchopt/alias/adan.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)
 create mode 100644 torchopt/alias/adan.py

diff --git a/torchopt/alias/adan.py b/torchopt/alias/adan.py
new file mode 100644
index 00000000..6632c341
--- /dev/null
+++ b/torchopt/alias/adan.py
@@ -0,0 +1,14 @@
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

From 69c9578c4fb9635650d980bde18074d275bc3344 Mon Sep 17 00:00:00 2001
From: Benjamin-eecs <benjaminliu.eecs@gmail.com>
Date: Sun, 23 Jul 2023 04:11:14 +0800
Subject: [PATCH 03/13] feat: init adan optim

---
 torchopt/optim/adan.py | 91 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 91 insertions(+)
 create mode 100644 torchopt/optim/adan.py

diff --git a/torchopt/optim/adan.py b/torchopt/optim/adan.py
new file mode 100644
index 00000000..95b57b66
--- /dev/null
+++ b/torchopt/optim/adan.py
@@ -0,0 +1,91 @@
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Adan optimizer."""
+
+from __future__ import annotations
+
+from typing import Iterable
+
+import torch
+
+from torchopt import alias
+from torchopt.optim.base import Optimizer
+from torchopt.typing import ScalarOrSchedule
+
+
+__all__ = ['Adan']
+
+
+class Adan(Optimizer):
+    """The classic Adan optimizer.
+
+    See Also:
+        - The functional Adan optimizer: :func:`torchopt.adan`.
+        - The differentiable meta Adan optimizer: :class:`torchopt.MetaAdan`.
+    """
+
+    # pylint: disable-next=too-many-arguments
+    def __init__(
+        self,
+        params: Iterable[torch.Tensor],
+        lr: ScalarOrSchedule = 1e-3,
+        betas: tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 0.0,
+        max_grad_norm: float = 0.0,
+        no_prox: bool = False,
+        *,
+        eps_root: float = 0.0,
+        maximize: bool = False,
+        use_accelerated_op: bool = False,
+    ) -> None:
+        r"""Initialize the Adan optimizer.
+
+        Args:
+            params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what
+                tensors should be optimized.
+            lr (float or callable, optional): This is a fixed global scaling factor or a learning
+                rate scheduler. (default: :const:`1e-3`)
+            betas (tuple of float, optional): Coefficients used for computing running averages of
+                gradient and its square. (default: :const:`(0.9, 0.999)`)
+            eps (float, optional): A small constant applied to denominator outside of the square
+                root (as in the Adam paper) to avoid dividing by zero when rescaling.
+                (default: :const:`1e-8`)
+            weight_decay (float, optional): Weight decay, add L2 penalty to parameters.
+                (default: :const:`0.0`)
+            eps_root (float, optional): A small constant applied to denominator inside the square
+                root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for
+                example when computing (meta-)gradients through Adam. (default: :const:`0.0`)
+            max_grad_norm (float, optional): Max norm of the gradients. (default: :const:`0.0`)
+            no_prox (bool, optional): If :data:`True`, the proximal term is not applied.
+                (default: :data:`False`)
+            maximize (bool, optional): Maximize the params based on the objective, instead of
+                minimizing. (default: :data:`False`)
+        """
+        super().__init__(
+            params,
+            alias.adan(
+                lr=lr,
+                betas=betas,
+                eps=eps,
+                weight_decay=weight_decay,
+                max_grad_norm=max_grad_norm,
+                no_prox=no_prox,
+                eps_root=eps_root,
+                moment_requires_grad=False,
+                maximize=maximize,
+                use_accelerated_op=use_accelerated_op,
+            ),
+        )

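For orientation, a minimal sketch of how the wrapper class added above is intended to be used once
the series is complete. It assumes the class ends up re-exported at the package top level as
torchopt.Adan (the patches shown here do not touch torchopt/__init__.py); the toy model, data, and
hyper-parameters are illustrative only.

    import torch
    import torch.nn.functional as F

    import torchopt

    # Toy model and batch (illustrative only).
    model = torch.nn.Linear(8, 2)
    xs, ys = torch.randn(16, 8), torch.randint(0, 2, (16,))

    # Classic OO-style usage of the wrapper defined in torchopt/optim/adan.py.
    optimizer = torchopt.Adan(model.parameters(), lr=1e-3, weight_decay=1e-2)

    loss = F.cross_entropy(model(xs), ys)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
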
From a98001dbb98ef5828ecdec3c479ca564ec595045 Mon Sep 17 00:00:00 2001
From: Benjamin-eecs <benjaminliu.eecs@gmail.com>
Date: Sun, 23 Jul 2023 04:11:44 +0800
Subject: [PATCH 04/13] feat: init adan transformation

---
 torchopt/transform/scale_by_adan.py | 283 ++++++++++++++++++++++++++++
 1 file changed, 283 insertions(+)
 create mode 100644 torchopt/transform/scale_by_adan.py

diff --git a/torchopt/transform/scale_by_adan.py b/torchopt/transform/scale_by_adan.py
new file mode 100644
index 00000000..18827f4e
--- /dev/null
+++ b/torchopt/transform/scale_by_adan.py
@@ -0,0 +1,283 @@
+# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Preset transformations for scaling updates by Adan."""
+
+# pylint: disable=invalid-name
+
+from __future__ import annotations
+
+from typing import NamedTuple
+
+import torch
+
+from torchopt import pytree
+from torchopt.base import GradientTransformation
+from torchopt.transform.utils import update_moment
+from torchopt.typing import OptState, Updates
+
+
+class ScaleByAdanState(NamedTuple):
+    """State for the Adan algorithm."""
+
+    count: OptState
+    mu: Updates
+    nu: Updates
+    delta: Updates
+    grad_tm1: Updates
+
+
+def scale_by_adan(
+    b1: float = 0.98,
+    b2: float = 0.92,
+    b3: float = 0.99,
+    eps: float = 1e-8,
+    eps_root: float = 0.0,
+    moment_requires_grad: bool = False,
+) -> GradientTransformation:
+    """Rescale updates according to the Adan algorithm.
+
+    References:
+        - Xie et al., 2022: https://arxiv.org/pdf/2208.06677.pdf
+
+    Args:
+        b1 (float, optional): Decay rate for the exponentially weighted average of gradients.
+            (default: :const:`0.98`)
+        b2 (float, optional): Decay rate for the exponentially weighted average of difference of
+        gradients.
+        b3: Decay rate for the exponentially weighted average of the squared term.
+        eps (float, optional): Term added to the denominator to improve numerical stability.
+            (default: :const:`1e-8`)
+        eps_root (float, optional): Term added to the denominator inside the square-root to improve
+            numerical stability when backpropagating gradients through the rescaling.
+            (default: :const:`0.0`)
+        moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag
+            ``requires_grad = True``. (default: :data:`False`)
+
+    Returns:
+        An (init_fn, update_fn) tuple.
+    """
+    return _scale_by_adan(
+        b1=b1,
+        b2=b2,
+        b3=b3,
+        eps=eps,
+        eps_root=eps_root,
+        moment_requires_grad=moment_requires_grad,
+        already_flattened=False,
+    )
+
+
+def _scale_by_adan_flat(
+    b1: float = 0.98,
+    b2: float = 0.92,
+    b3: float = 0.99,
+    eps: float = 1e-8,
+    eps_root: float = 0.0,
+    moment_requires_grad: bool = False,
+) -> GradientTransformation:
+    return _scale_by_adan(
+        b1=b1,
+        b2=b2,
+        b3=b3,
+        eps=eps,
+        eps_root=eps_root,
+        moment_requires_grad=moment_requires_grad,
+        already_flattened=True,
+    )
+
+
+def _scale_by_adan(
+    b1: float = 0.98,
+    b2: float = 0.92,
+    b3: float = 0.99,
+    eps: float = 1e-8,
+    eps_root: float = 0.0,
+    moment_requires_grad: bool = False,
+    *,
+    already_flattened: bool = False,
+) -> GradientTransformation:
+    # pylint: disable=unneeded-not
+    if not eps >= 0.0:  # pragma: no cover
+        raise ValueError(f'Invalid epsilon value: {eps}')
+    if not 0.0 <= b1 < 1.0:  # pragma: no cover
+        raise ValueError(f'Invalid beta parameter at index 0: {b1}')
+    if not 0.0 <= b2 < 1.0:  # pragma: no cover
+        raise ValueError(f'Invalid beta parameter at index 1: {b2}')
+    if not 0.0 <= b3 < 1.0:
+        raise ValueError(f'Invalid beta parameter at index 2: {b3}')
+    # pylint: enable=unneeded-not
+
+    if already_flattened:  # noqa: SIM108
+        tree_map = tree_map_flat
+    else:
+        tree_map = pytree.tree_map  # type: ignore[assignment]
+
+    def init_fn(params: Params) -> OptState:
+        zero = tree_map(  # count init
+            lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(),
+            params,
+        )
+        mu = tree_map( # first moment
+            lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
+            params,
+        )
+        nu = tree_map( # second moment
+            torch.zeros_like,
+            params,
+        )
+        delta = tree_map(  # EWA of Difference of gradients
+            lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
+            params,
+        )
+        grad_tm1 = tree_map(
+            torch.zeros_like,
+            params,
+        )  # Previous gradient
+        return ScaleByAdanState(
+            count=torch.zeros([], torch.int32),
+            mu=mu,
+            nu=nu,
+            delta=delta,
+            grad_tm1=grad_tm1,
+        )
+
+    def update_fn(updates, state, params=None):
+        del params
+        diff = pytree.lax.cond(
+            state.count != 0,
+            lambda X, Y: pytree.tree_map(lambda x, y: x - y, X, Y),
+            lambda X, _: pytree.tree_map(torch.zeros_like, X),
+            updates,
+            state.grad_tm1,
+        )
+
+        grad_prime = pytree.tree_map(lambda g, d: g + b2 * d, updates, diff)
+
+        mu = update_moment(updates, state.mu, b1, 1)
+        delta = update_moment(diff, state.delta, b2, 1)
+        nu = update_moment_per_elem_norm(grad_prime, state.nu, b3, 2)
+
+        count_inc = numerics.safe_int32_increment(state.count)
+        mu_hat = utils.cast_tree(bias_correction(mu, b1, count_inc), fo_dtype)
+        delta_hat = utils.cast_tree(bias_correction(delta, b2, count_inc), fo_dtype)
+        nu_hat = bias_correction(nu, b3, count_inc)
+        new_updates = pytree.tree_map(
+            lambda m, d, n: (m + b2 * d) / (torch.sqrt(n + eps_root) + eps),
+            mu_hat,
+            delta_hat,
+            nu_hat,
+        )
+
+        return new_updates, ScaleByAdanState(
+            count=count_inc,
+            mu=mu,
+            nu=nu,
+            delta=delta,
+            grad_tm1=updates,
+        )
+
+    return base.GradientTransformation(init_fn, update_fn)
+
+
+# def scale_by_proximal_adan(
+#     learning_rate: ScalarOrSchedule,
+#     weight_decay: float,
+#     b1: float = 0.98,
+#     b2: float = 0.92,
+#     b3: float = 0.99,
+#     eps_root: float = 1e-8,
+#     fo_dtype: Optional[Any] = None,
+# ) -> base.GradientTransformation:
+#   """Rescale updates according to the proximal version of the Adan algorithm.
+#   References:
+#     [Xie et al, 2022](https://arxiv.org/abs/2208.06677)
+#   Args:
+#     b1: Decay rate for the exponentially weighted average of gradients.
+#     b2: Decay rate for the exponentially weighted average of difference of
+#       gradients.
+#     b3: Decay rate for the exponentially weighted average of the squared term.
+#     eps: term added to the denominator to improve numerical stability.
+#     eps_root: Term added to the denominator inside the square-root to improve
+#       numerical stability when backpropagating gradients through the rescaling.
+#     fo_dtype: optional `dtype` to be used for the first order accumulators
+#       mu and delta; if `None` then the `dtype is inferred from `params`
+#       and `updates`.
+#   Returns:
+#     An (init_fn, update_fn) tuple.
+#   """
+
+#   fo_dtype = utils.canonicalize_dtype(fo_dtype)
+
+#   def init_fn(params):
+#     mu = pytree.tree_map(  # First moment
+#         lambda t: torch.zeros_like(t, dtype=fo_dtype), params)
+#     nu = pytree.tree_map(torch.zeros_like, params)  # Second moment
+#     delta = pytree.tree_map(  # EWA of Difference of gradients
+#         lambda t: torch.zeros_like(t, dtype=fo_dtype), params)
+#     grad_tm1 = pytree.tree_map(torch.zeros_like, params)  # Previous gradient
+#     return ScaleByAdanState(count=torch.zeros([], torch.int32),
+#                             mu=mu, nu=nu, delta=delta, grad_tm1=grad_tm1)
+
+#   def update_fn(updates, state, params=None):
+#     diff = pytree.lax.cond(state.count != 0,
+#                         lambda X, Y: pytree.tree_map(lambda x, y: x - y, X, Y),
+#                         lambda X, _: pytree.tree_map(torch.zeros_like, X),
+#                         updates, state.grad_tm1)
+
+#     grad_prime = pytree.tree_map(lambda g, d: g + b2*d, updates, diff)
+
+#     mu = update_moment(updates, state.mu, b1, 1)
+#     delta = update_moment(diff, state.delta, b2, 1)
+#     nu = update_moment_per_elem_norm(grad_prime, state.nu, b3, 2)
+
+#     count_inc = numerics.safe_int32_increment(state.count)
+#     mu_hat = bias_correction(mu, b1, count_inc)
+#     delta_hat = bias_correction(delta, b2, count_inc)
+#     nu_hat = bias_correction(nu, b3, count_inc)
+
+#     if callable(learning_rate):
+#       lr = learning_rate(state.count)
+#     else:
+#       lr = learning_rate
+
+#     learning_rates = pytree.tree_util.tree_map(
+#       lambda n: lr / torch.sqrt(n + eps_root), nu_hat)
+
+#     # negative scale: gradient descent
+#     updates = pytree.tree_util.tree_map(lambda scale, m, v: -scale * (m + b2 * v),
+#                                      learning_rates, mu_hat,
+#                                      delta_hat)
+
+#     decay = 1. / (1. + weight_decay * lr)
+#     params_new = pytree.tree_util.tree_map(lambda p, u:
+#                                         decay * (p + u), params,
+#                                         updates)
+
+#     # params_new - params_old
+#     new_updates = pytree.tree_util.tree_map(lambda new, old: new - old, params_new,
+#                                      params)
+
+#     mu_hat = utils.cast_tree(mu_hat, fo_dtype)
+#     delta_hat = utils.cast_tree(delta_hat, fo_dtype)
+
+#     return new_updates, ScaleByAdanState(count=count_inc,
+#                                     mu=mu, nu=nu, delta=delta,
+#                                     grad_tm1=updates)
+
+#   return base.GradientTransformation(init_fn, update_fn)
+
+
+scale_by_adan.flat = _scale_by_adan_flat  # type: ignore[attr-defined]
+scale_by_adan.impl = _scale_by_adan  # type: ignore[attr-defined]

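To make the transformation above easier to follow, here is a single-tensor restatement of the
update rule that update_fn implements, written out for one step. It assumes update_moment and
update_moment_per_elem_norm follow the usual EMA convention new = decay * old + (1 - decay) *
g**order and that bias correction divides by 1 - decay**count; the tensor names and toy values are
illustrative, not part of the patch.

    import torch

    b1, b2, b3, eps, eps_root = 0.98, 0.92, 0.99, 1e-8, 0.0
    g, g_prev = torch.randn(4), torch.randn(4)  # current and previous gradients
    m = d = n = torch.zeros(4)                  # mu, delta, nu from ScaleByAdanState
    t = 1                                       # step count after increment

    diff = g - g_prev                           # zero on the very first step
    g_prime = g + b2 * diff                     # gradient fed to the second moment
    m = b1 * m + (1 - b1) * g                   # EMA of gradients (mu)
    d = b2 * d + (1 - b2) * diff                # EMA of gradient differences (delta)
    n = b3 * n + (1 - b3) * g_prime**2          # EMA of squared g_prime (nu)
    m_hat, d_hat, n_hat = m / (1 - b1**t), d / (1 - b2**t), n / (1 - b3**t)  # bias correction
    update = (m_hat + b2 * d_hat) / (torch.sqrt(n_hat + eps_root) + eps)
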
From 15fdf139e57932b82438398c21f3e6eb2c537c7b Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 22 Jul 2023 20:13:44 +0000
Subject: [PATCH 05/13] fix: [pre-commit.ci] auto fixes [...]

---
 torchopt/transform/scale_by_adan.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torchopt/transform/scale_by_adan.py b/torchopt/transform/scale_by_adan.py
index 18827f4e..9ad0d2df 100644
--- a/torchopt/transform/scale_by_adan.py
+++ b/torchopt/transform/scale_by_adan.py
@@ -125,15 +125,15 @@ def _scale_by_adan(
         tree_map = pytree.tree_map  # type: ignore[assignment]
 
     def init_fn(params: Params) -> OptState:
-        zero = tree_map(  # count init
+        tree_map(  # count init
             lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(),
             params,
         )
-        mu = tree_map( # first moment
+        mu = tree_map(  # first moment
             lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
             params,
         )
-        nu = tree_map( # second moment
+        nu = tree_map(  # second moment
             torch.zeros_like,
             params,
         )

From 094582af0815d31818571e94114b5bb49c39edb2 Mon Sep 17 00:00:00 2001
From: Benjamin-eecs <benjaminliu.eecs@gmail.com>
Date: Sun, 23 Jul 2023 04:19:56 +0800
Subject: [PATCH 06/13] fix: update docstring

---
 torchopt/optim/adan.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchopt/optim/adan.py b/torchopt/optim/adan.py
index 95b57b66..7d6449ba 100644
--- a/torchopt/optim/adan.py
+++ b/torchopt/optim/adan.py
@@ -33,7 +33,7 @@ class Adan(Optimizer):
 
     See Also:
         - The functional Adan optimizer: :func:`torchopt.adan`.
-        - The differentiable meta Adan optimizer: :class:`torchopt.MetaAdan`.
+        - The differentiable meta-Adan optimizer: :class:`torchopt.MetaAdan`.
     """
 
     # pylint: disable-next=too-many-arguments

From 53d2bd02f4f7ca2190466e9abff7fde98d53f8f6 Mon Sep 17 00:00:00 2001
From: Benjamin-eecs <benjaminliu.eecs@gmail.com>
Date: Sun, 23 Jul 2023 13:51:56 +0800
Subject: [PATCH 07/13] fix: update adan transformation

---
 torchopt/transform/scale_by_adan.py | 134 +++++++++++++++++++---------
 1 file changed, 94 insertions(+), 40 deletions(-)

diff --git a/torchopt/transform/scale_by_adan.py b/torchopt/transform/scale_by_adan.py
index 9ad0d2df..71255072 100644
--- a/torchopt/transform/scale_by_adan.py
+++ b/torchopt/transform/scale_by_adan.py
@@ -24,18 +24,40 @@
 
 from torchopt import pytree
 from torchopt.base import GradientTransformation
-from torchopt.transform.utils import update_moment
-from torchopt.typing import OptState, Updates
+from torchopt.transform.utils import inc_count, tree_map_flat, update_moment
+from torchopt.typing import OptState, Params, Updates
+
+
+__all__ = [
+    'scale_by_adan',
+]
 
 
 class ScaleByAdanState(NamedTuple):
     """State for the Adan algorithm."""
 
-    count: OptState
     mu: Updates
     nu: Updates
     delta: Updates
     grad_tm1: Updates
+    count: OptState
+
+
+def _adan_bias_correction(
+    moment: Updates,
+    decay: float,
+    count: OptState,
+    *,
+    already_flattened: bool = False,
+) -> Updates:
+    """Perform bias correction. This becomes a no-op as count goes to infinity."""
+
+    def f(t: torch.Tensor, c: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
+        return t.div(1 - pow(decay, c))
+
+    if already_flattened:
+        return tree_map_flat(f, moment, count)
+    return pytree.tree_map(f, moment, count)
 
 
 def scale_by_adan(
@@ -55,8 +77,10 @@ def scale_by_adan(
         b1 (float, optional): Decay rate for the exponentially weighted average of gradients.
             (default: :const:`0.98`)
         b2 (float, optional): Decay rate for the exponentially weighted average of difference of
-        gradients.
-        b3: Decay rate for the exponentially weighted average of the squared term.
+            gradients.
+            (default: :const:`0.92`)
+        b3 (float, optional): Decay rate for the exponentially weighted average of the squared term.
+            (default: :const:`0.99`)
         eps (float, optional): Term added to the denominator to improve numerical stability.
             (default: :const:`1e-8`)
         eps_root (float, optional): Term added to the denominator inside the square-root to improve
@@ -134,7 +158,7 @@ def init_fn(params: Params) -> OptState:
             params,
         )
         nu = tree_map(  # second moment
-            torch.zeros_like,
+            lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad),
             params,
         )
         delta = tree_map(  # EWA of Difference of gradients
@@ -142,53 +166,87 @@ def init_fn(params: Params) -> OptState:
             params,
         )
         grad_tm1 = tree_map(
-            torch.zeros_like,
+            lambda t: torch.zeros_like(
+                t,
+            ),
             params,
         )  # Previous gradient
         return ScaleByAdanState(
-            count=torch.zeros([], torch.int32),
             mu=mu,
             nu=nu,
             delta=delta,
             grad_tm1=grad_tm1,
+            count=zero,
         )
 
-    def update_fn(updates, state, params=None):
-        del params
+    def update_fn(
+        updates: Updates,
+        state: OptState,
+        *,
+        params: Params | None = None,  # pylint: disable=unused-argument
+        inplace: bool = True,
+    ) -> tuple[Updates, OptState]:
         diff = pytree.lax.cond(
             state.count != 0,
-            lambda X, Y: pytree.tree_map(lambda x, y: x - y, X, Y),
-            lambda X, _: pytree.tree_map(torch.zeros_like, X),
+            lambda X, Y: tree_map(lambda x, y: x - y, X, Y),
+            lambda X, _: tree_map(torch.zeros_like, X),
             updates,
             state.grad_tm1,
         )
 
-        grad_prime = pytree.tree_map(lambda g, d: g + b2 * d, updates, diff)
+        grad_prime = tree_map(lambda g, d: g + b2 * d, updates, diff)
 
-        mu = update_moment(updates, state.mu, b1, 1)
-        delta = update_moment(diff, state.delta, b2, 1)
+        mu = update_moment.impl(
+            updates,
+            state.mu,
+            b1,
+            order=1,
+            inplace=inplace,
+            already_flattened=already_flattened,
+        )
+        delta = update_moment.impl(
+            diff,
+            state.delta,
+            b2,
+            1,
+        )
         nu = update_moment_per_elem_norm(grad_prime, state.nu, b3, 2)
 
-        count_inc = numerics.safe_int32_increment(state.count)
-        mu_hat = utils.cast_tree(bias_correction(mu, b1, count_inc), fo_dtype)
-        delta_hat = utils.cast_tree(bias_correction(delta, b2, count_inc), fo_dtype)
-        nu_hat = bias_correction(nu, b3, count_inc)
-        new_updates = pytree.tree_map(
-            lambda m, d, n: (m + b2 * d) / (torch.sqrt(n + eps_root) + eps),
-            mu_hat,
-            delta_hat,
-            nu_hat,
-        )
+        count_inc = inc_count.impl(updates, state.count, already_flattened=already_flattened)  # type: ignore[attr-defined]
+        mu_hat = _adan_bias_correction(mu, b1, count_inc, already_flattened=already_flattened)
+        delta_hat = _adan_bias_correction(delta, b2, count_inc, already_flattened=already_flattened)
+        nu_hat = _adan_bias_correction(nu, b3, count_inc, already_flattened=already_flattened)
 
-        return new_updates, ScaleByAdanState(
-            count=count_inc,
-            mu=mu,
-            nu=nu,
-            delta=delta,
-            grad_tm1=updates,
+        if inplace:
+
+            def f(
+                m: torch.Tensor,
+                d: torch.Tensor,
+                n: torch.Tensor,
+            ) -> torch.Tensor:
+                return (m + b2 * d).div_(torch.sqrt(n + eps_root).add(eps))
+
+        else:
+
+            def f(
+                m: torch.Tensor,
+                d: torch.Tensor,
+                n: torch.Tensor,
+            ) -> torch.Tensor:
+                return (m + b2 * d).div(torch.sqrt(n + eps_root).add(eps))
+
+        # lambda m, d, n: (m + b2 * d) / (torch.sqrt(n + eps_root) + eps),
+        updates = pytree.tree_map(f, mu_hat, delta_hat, nu_hat)
+
+        return updates, ScaleByAdanState(
+            count=count_inc, mu=mu, nu=nu, delta=delta, grad_tm1=updates,
         )
 
-    return base.GradientTransformation(init_fn, update_fn)
+    return GradientTransformation(init_fn, update_fn)
+
+
+scale_by_adan.flat = _scale_by_adan_flat  # type: ignore[attr-defined]
+scale_by_adan.impl = _scale_by_adan  # type: ignore[attr-defined]
 
 
 # def scale_by_proximal_adan(
@@ -243,9 +301,9 @@ def update_fn(updates, state, params=None):
 #     nu = update_moment_per_elem_norm(grad_prime, state.nu, b3, 2)
 
 #     count_inc = numerics.safe_int32_increment(state.count)
-#     mu_hat = bias_correction(mu, b1, count_inc)
-#     delta_hat = bias_correction(delta, b2, count_inc)
-#     nu_hat = bias_correction(nu, b3, count_inc)
+#     mu_hat = _adan_bias_correction(mu, b1, count_inc)
+#     delta_hat = _adan_bias_correction(delta, b2, count_inc)
+#     nu_hat = _adan_bias_correction(nu, b3, count_inc)
 
 #     if callable(learning_rate):
 #       lr = learning_rate(state.count)
@@ -276,8 +334,4 @@ def update_fn(updates, state, params=None):
 #                                     mu=mu, nu=nu, delta=delta,
 #                                     grad_tm1=updates)
 
-#   return base.GradientTransformation(init_fn, update_fn)
-
-
-scale_by_adan.flat = _scale_by_adan_flat  # type: ignore[attr-defined]
-scale_by_adan.impl = _scale_by_adan  # type: ignore[attr-defined]
+#   return GradientTransformation(init_fn, update_fn)

From 5723cf6c7ca295934db5b71861996538fc6df994 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 23 Jul 2023 05:52:13 +0000
Subject: [PATCH 08/13] fix: [pre-commit.ci] auto fixes [...]

---
 torchopt/transform/scale_by_adan.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/torchopt/transform/scale_by_adan.py b/torchopt/transform/scale_by_adan.py
index 71255072..8c1d19bb 100644
--- a/torchopt/transform/scale_by_adan.py
+++ b/torchopt/transform/scale_by_adan.py
@@ -239,7 +239,11 @@ def f(
         updates = pytree.tree_map(f, mu_hat, delta_hat, nu_hat)
 
         return updates, ScaleByAdanState(
-            count=count_inc, mu=mu, nu=nu, delta=delta, grad_tm1=updates,
+            count=count_inc,
+            mu=mu,
+            nu=nu,
+            delta=delta,
+            grad_tm1=updates,
         )
 
     return GradientTransformation(init_fn, update_fn)

From cb2efb0328afb1a78cf466cb89a496f5284de13e Mon Sep 17 00:00:00 2001
From: Benjamin-eecs <benjaminliu.eecs@gmail.com>
Date: Tue, 25 Jul 2023 02:41:40 +0800
Subject: [PATCH 09/13] fix: update requirements

---
 tests/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/requirements.txt b/tests/requirements.txt
index 2e7acde6..38b2b735 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -3,7 +3,7 @@ torch >= 1.13
 
 --requirement ../requirements.txt
 
-git+https://github.com/sail-sg/Adan.git
+git+https://github.com/benjamin-eecs/Adan.git
 
 jax[cpu] >= 0.3; platform_system != 'Windows'
 jaxopt; platform_system != 'Windows'

From 8e38169f16c0b951795455cb347cc5a58862b66e Mon Sep 17 00:00:00 2001
From: Benjamin-eecs <benjaminliu.eecs@gmail.com>
Date: Tue, 25 Jul 2023 02:41:55 +0800
Subject: [PATCH 10/13] fix: update adan transformation

---
 torchopt/transform/scale_by_adan.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/torchopt/transform/scale_by_adan.py b/torchopt/transform/scale_by_adan.py
index 71255072..eff98ca6 100644
--- a/torchopt/transform/scale_by_adan.py
+++ b/torchopt/transform/scale_by_adan.py
@@ -149,7 +149,7 @@ def _scale_by_adan(
         tree_map = pytree.tree_map  # type: ignore[assignment]
 
     def init_fn(params: Params) -> OptState:
-        tree_map(  # count init
+        zero = tree_map(  # count init
             lambda t: torch.zeros(1, dtype=torch.int64, device=t.device).squeeze_(),
             params,
         )
@@ -186,13 +186,10 @@ def update_fn(
         params: Params | None = None,  # pylint: disable=unused-argument
         inplace: bool = True,
     ) -> tuple[Updates, OptState]:
-        diff = pytree.lax.cond(
-            state.count != 0,
-            lambda X, Y: tree_map(lambda x, y: x - y, X, Y),
-            lambda X, _: tree_map(torch.zeros_like, X),
-            updates,
-            state.grad_tm1,
-        )
+        if state.count != 0:
+            diff = tree_map(lambda x, y: x - y, updates, state.grad_tm1)
+        else:
+            diff = tree_map(torch.zeros_like, updates)
 
         grad_prime = tree_map(lambda g, d: g + b2 * d, updates, diff)
 
@@ -204,14 +201,13 @@ def update_fn(
             inplace=inplace,
             already_flattened=already_flattened,
         )
+        nu = update_moment_per_elem_norm(grad_prime, state.nu, b3, 2)
         delta = update_moment.impl(
             diff,
             state.delta,
             b2,
             1,
         )
-        nu = update_moment_per_elem_norm(grad_prime, state.nu, b3, 2)
-
         count_inc = inc_count.impl(updates, state.count, already_flattened=already_flattened)  # type: ignore[attr-defined]
         mu_hat = _adan_bias_correction(mu, b1, count_inc, already_flattened=already_flattened)
         delta_hat = _adan_bias_correction(delta, b2, count_inc, already_flattened=already_flattened)
@@ -239,7 +235,11 @@ def f(
         updates = pytree.tree_map(f, mu_hat, delta_hat, nu_hat)
 
         return updates, ScaleByAdanState(
-            count=count_inc, mu=mu, nu=nu, delta=delta, grad_tm1=updates,
+            count=count_inc,
+            mu=mu,
+            nu=nu,
+            delta=delta,
+            grad_tm1=updates,
         )
 
     return GradientTransformation(init_fn, update_fn)

From 4d941452cd168fd616b45e4d5755a821c42304ce Mon Sep 17 00:00:00 2001
From: Benjamin-eecs <benjaminliu.eecs@gmail.com>
Date: Tue, 25 Jul 2023 03:30:31 +0800
Subject: [PATCH 11/13] test: update Makefile

---
 Makefile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index 0f7dd74e..163bba8d 100644
--- a/Makefile
+++ b/Makefile
@@ -113,7 +113,7 @@ addlicense-install: go-install
 
 pytest: test-install
 	cd tests && $(PYTHON) -c 'import $(PROJECT_PATH)' && \
-	$(PYTHON) -m pytest --verbose --color=yes --durations=0 \
+	$(PYTHON) -m pytest -k "test_adan" --verbose --color=yes --durations=0 \
 		--cov="$(PROJECT_PATH)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \
 		$(PYTESTOPTS) .
 

From 501d1251b25324237e1c746f4fa1a2de76cc6449 Mon Sep 17 00:00:00 2001
From: Benjamin-eecs <benjaminliu.eecs@gmail.com>
Date: Tue, 25 Jul 2023 03:31:01 +0800
Subject: [PATCH 12/13] feat: init adan alias

---
 tests/test_alias.py    |  18 +++++--
 torchopt/alias/adan.py | 110 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 123 insertions(+), 5 deletions(-)

diff --git a/tests/test_alias.py b/tests/test_alias.py
index c18e678f..1f5189c1 100644
--- a/tests/test_alias.py
+++ b/tests/test_alias.py
@@ -21,6 +21,7 @@
 import pytest
 import torch
 import torch.nn.functional as F
+from adan import Adan
 
 import helpers
 import torchopt
@@ -204,10 +205,12 @@ def test_adadelta(
 @helpers.parametrize(
     dtype=[torch.float64],
     lr=[1e-2, 1e-3, 1e-4],
-    betas=[(0.9, 0.999), (0.95, 0.9995)],
+    betas=[(0.9, 0.999, 0.998), (0.95, 0.9995, 0.9985)],
     eps=[1e-8],
     inplace=[True, False],
     weight_decay=[0.0, 1e-2],
+    max_grad_norm=[0.0, 1.0],
+    no_prox=[False, True],
     maximize=[False, True],
     use_accelerated_op=[False, True],
     use_chain_flat=[True, False],
@@ -215,10 +218,12 @@ def test_adadelta(
 def test_adan(
     dtype: torch.dtype,
     lr: float,
-    betas: tuple[float, float],
+    betas: tuple[float, float, float],
     eps: float,
     inplace: bool,
     weight_decay: float,
+    max_grad_norm: float,
+    no_prox: bool,
     maximize: bool,
     use_accelerated_op: bool,
     use_chain_flat: bool,
@@ -234,18 +239,21 @@ def test_adan(
         eps=eps,
         eps_root=0.0,
         weight_decay=weight_decay,
+        max_grad_norm=max_grad_norm,
+        no_prox=no_prox,
         maximize=maximize,
         use_accelerated_op=use_accelerated_op,
     )
     optim_state = optim.init(params)
-    optim_ref = torch.optim.adan(
+    optim_ref = Adan(
         model_ref.parameters(),
         lr,
         betas=betas,
         eps=eps,
-        amsgrad=False,
+        eps_root=0.0,
         weight_decay=weight_decay,
-        maximize=maximize,
+        max_grad_norm=max_grad_norm,
+        no_prox=no_prox,
     )
 
     for xs, ys in loader:
diff --git a/torchopt/alias/adan.py b/torchopt/alias/adan.py
index 6632c341..e194ae41 100644
--- a/torchopt/alias/adan.py
+++ b/torchopt/alias/adan.py
@@ -12,3 +12,113 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+"""Preset :class:`GradientTransformation` for the Adan optimizer."""
+
+from __future__ import annotations
+
+from torchopt.alias.utils import (
+    _get_use_chain_flat,
+    flip_sign_and_add_weight_decay,
+    scale_by_neg_lr,
+)
+from torchopt.combine import chain
+from torchopt.transform import scale_by_adan
+from torchopt.typing import GradientTransformation, ScalarOrSchedule
+
+
+__all__ = ['adan']
+
+
+# pylint: disable-next=too-many-arguments
+def adan(
+    lr: ScalarOrSchedule = 1e-3,
+    betas: tuple[float, float, float] = (0.98, 0.92, 0.99),
+    eps: float = 1e-8,
+    weight_decay: float = 0.0,
+    max_grad_norm: float = 0.0,
+    no_prox: bool = False,
+    *,
+    eps_root: float = 0.0,
+    moment_requires_grad: bool = False,
+    maximize: bool = False,
+) -> GradientTransformation:
+    """Create a functional version of the adan optimizer.
+
+    adan is an SGD variant with learning rate adaptation. The *learning rate* used for each weight
+    is computed from estimates of first- and second-order moments of the gradients (using suitable
+    exponential moving averages).
+
+    References:
+        - Xie et al., 2022: https://arxiv.org/abs/2208.06677
+
+    Args:
+        lr (float or callable, optional): This is a fixed global scaling factor or a learning rate
+            scheduler. (default: :const:`1e-3`)
+        betas (tuple of float, optional): Coefficients used for computing running averages of the
+            gradient, the gradient difference, and the squared term. (default: :const:`(0.98, 0.92, 0.99)`)
+        eps (float, optional): Term added to the denominator to improve numerical stability.
+            (default: :const:`1e-8`)
+        eps_root (float, optional): Term added to the denominator inside the square-root to improve
+            numerical stability when backpropagating gradients through the rescaling.
+            (default: :const:`0.0`)
+        weight_decay (float, optional): Weight decay (L2 penalty).
+            (default: :const:`0.0`)
+        max_grad_norm (float, optional): Max norm of the gradients.
+            (default: :const:`0.0`)
+        no_prox (bool, optional): If :data:`True`, the proximal term is not applied.
+            (default: :data:`False`)
+        eps_root (float, optional): A small constant applied to denominator inside the square root
+            (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example
+            when computing (meta-)gradients through Adam. (default: :const:`0.0`)
+        moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag
+            ``requires_grad = True``. (default: :data:`False`)
+        maximize (bool, optional): Maximize the params based on the objective, instead of minimizing.
+            (default: :data:`False`)
+
+    Returns:
+        The corresponding :class:`GradientTransformation` instance.
+
+    See Also:
+        The functional optimizer wrapper :class:`torchopt.FuncOptimizer`.
+    """
+    b1, b2, b3 = betas  # pylint: disable=invalid-name
+    # pylint: disable=unneeded-not
+    if not 0.0 <= max_grad_norm:
+        raise ValueError(f'Invalid Max grad norm: {max_grad_norm}')
+    if not (callable(lr) or lr >= 0.0):  # pragma: no cover
+        raise ValueError(f'Invalid learning rate: {lr}')
+    if not eps >= 0.0:  # pragma: no cover
+        raise ValueError(f'Invalid epsilon value: {eps}')
+    if not 0.0 <= b1 < 1.0:  # pragma: no cover
+        raise ValueError(f'Invalid beta parameter at index 0: {b1}')
+    if not 0.0 <= b2 < 1.0:  # pragma: no cover
+        raise ValueError(f'Invalid beta parameter at index 1: {b2}')
+    if not 0.0 <= b3 < 1.0:
+        raise ValueError(f'Invalid beta parameter at index 2: {b3}')
+    if not weight_decay >= 0.0:  # pragma: no cover
+        raise ValueError(f'Invalid weight_decay value: {weight_decay}')
+    # pylint: enable=unneeded-not
+
+    chain_fn = chain
+    flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay
+    adan_scaler_fn = scale_by_adan if no_prox else scale_by_proximal_adan
+    scale_by_neg_lr_fn = scale_by_neg_lr
+
+    if _get_use_chain_flat():  # default behavior
+        chain_fn = chain_fn.flat  # type: ignore[attr-defined]
+        flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat  # type: ignore[attr-defined]
+        adan_scaler_fn = adan_scaler_fn.flat  # type: ignore[attr-defined]
+        scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat  # type: ignore[attr-defined]
+
+    return chain_fn(
+        flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize),
+        adan_scaler_fn(
+            b1=b1,
+            b2=b2,
+            b3=b3,
+            eps=eps,
+            eps_root=eps_root,
+            moment_requires_grad=moment_requires_grad,
+        ),
+        scale_by_neg_lr_fn(lr),
+    )

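For completeness, a compact sketch of how the functional alias above is meant to be driven,
mirroring the init/update/apply_updates pattern already used in tests/test_alias.py. It assumes
adan is re-exported as torchopt.adan (as the tests call it); the toy parameter below is
illustrative only, and no_prox=True is used because the proximal branch (scale_by_proximal_adan)
exists only as a commented-out draft at this point in the series.

    import torch

    import torchopt

    params = (torch.randn(4, requires_grad=True),)

    optim = torchopt.adan(1e-3, betas=(0.98, 0.92, 0.99), weight_decay=1e-2, no_prox=True)
    state = optim.init(params)

    loss = (params[0] ** 2).sum()
    grads = torch.autograd.grad(loss, params)
    updates, state = optim.update(grads, state, params=params, inplace=False)
    params = torchopt.apply_updates(params, updates, inplace=False)
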
From e11cced6ae14d89a4156b3cb43bed22e8526649a Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 24 Jul 2023 19:31:21 +0000
Subject: [PATCH 13/13] fix: [pre-commit.ci] auto fixes [...]

---
 torchopt/alias/adan.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchopt/alias/adan.py b/torchopt/alias/adan.py
index e194ae41..a30e2b04 100644
--- a/torchopt/alias/adan.py
+++ b/torchopt/alias/adan.py
@@ -83,7 +83,7 @@ def adan(
     """
     b1, b2, b3 = betas  # pylint: disable=invalid-name
     # pylint: disable=unneeded-not
-    if not 0.0 <= max_grad_norm:
+    if not max_grad_norm >= 0.0:
         raise ValueError(f'Invalid Max grad norm: {max_grad_norm}')
     if not (callable(lr) or lr >= 0.0):  # pragma: no cover
         raise ValueError(f'Invalid learning rate: {lr}')