feat: Adan optimizer integration #181

Draft · wants to merge 16 commits into base: main
Changes from all commits
2 changes: 1 addition & 1 deletion Makefile
@@ -113,7 +113,7 @@ addlicense-install: go-install

pytest: test-install
cd tests && $(PYTHON) -c 'import $(PROJECT_PATH)' && \
$(PYTHON) -m pytest --verbose --color=yes --durations=0 \
$(PYTHON) -m pytest -k "test_adan" --verbose --color=yes --durations=0 \
--cov="$(PROJECT_PATH)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \
$(PYTESTOPTS) .

2 changes: 2 additions & 0 deletions tests/requirements.txt
@@ -3,6 +3,8 @@ torch >= 1.13

--requirement ../requirements.txt

git+https://github.com/benjamin-eecs/Adan.git

jax[cpu] >= 0.3; platform_system != 'Windows'
jaxopt; platform_system != 'Windows'
optax; platform_system != 'Windows'
74 changes: 74 additions & 0 deletions tests/test_alias.py
@@ -21,6 +21,7 @@
import pytest
import torch
import torch.nn.functional as F
from adan import Adan

import helpers
import torchopt
@@ -201,6 +202,79 @@ def test_adadelta(
_set_use_chain_flat(True)


@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
betas=[(0.9, 0.999, 0.998), (0.95, 0.9995, 0.9985)],
eps=[1e-8],
inplace=[True, False],
weight_decay=[0.0, 1e-2],
max_grad_norm=[0.0, 1.0],
no_prox=[False, True],
maximize=[False, True],
use_accelerated_op=[False, True],
use_chain_flat=[True, False],
)
def test_adan(
dtype: torch.dtype,
lr: float,
betas: tuple[float, float, float],
eps: float,
inplace: bool,
weight_decay: float,
max_grad_norm: float,
no_prox: bool,
maximize: bool,
use_accelerated_op: bool,
use_chain_flat: bool,
) -> None:
_set_use_chain_flat(use_chain_flat)

model, model_ref, model_base, loader = helpers.get_models(device='cpu', dtype=dtype)

fmodel, params, buffers = functorch.make_functional_with_buffers(model)
optim = torchopt.adan(
lr,
betas=betas,
eps=eps,
eps_root=0.0,
weight_decay=weight_decay,
max_grad_norm=max_grad_norm,
no_prox=no_prox,
maximize=maximize,
use_accelerated_op=use_accelerated_op,
)
optim_state = optim.init(params)
optim_ref = Adan(
model_ref.parameters(),
lr,
betas=betas,
eps=eps,
eps_root=0.0,
weight_decay=weight_decay,
max_grad_norm=max_grad_norm,
no_prox=no_prox,
)

for xs, ys in loader:
xs = xs.to(dtype=dtype)
pred = fmodel(params, buffers, xs)
pred_ref = model_ref(xs)
loss = F.cross_entropy(pred, ys)
loss_ref = F.cross_entropy(pred_ref, ys)

grads = torch.autograd.grad(loss, params, allow_unused=True)
updates, optim_state = optim.update(grads, optim_state, params=params, inplace=inplace)
params = torchopt.apply_updates(params, updates, inplace=inplace)

optim_ref.zero_grad()
loss_ref.backward()
optim_ref.step()

helpers.assert_model_all_close((params, buffers), model_ref, model_base, dtype=dtype)
_set_use_chain_flat(True)


@helpers.parametrize(
dtype=[torch.float64],
lr=[1e-2, 1e-3, 1e-4],
124 changes: 124 additions & 0 deletions torchopt/alias/adan.py
@@ -0,0 +1,124 @@
# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preset :class:`GradientTransformation` for the Adan optimizer."""

from __future__ import annotations

from torchopt.alias.utils import (
_get_use_chain_flat,
flip_sign_and_add_weight_decay,
scale_by_neg_lr,
)
from torchopt.combine import chain
from torchopt.transform import scale_by_adan, scale_by_proximal_adan
from torchopt.typing import GradientTransformation, ScalarOrSchedule


__all__ = ['adan']


# pylint: disable-next=too-many-arguments
def adan(
lr: ScalarOrSchedule = 1e-3,
betas: tuple[float, float, float] = (0.98, 0.92, 0.99),
eps: float = 1e-8,
weight_decay: float = 0.0,
max_grad_norm: float = 0.0,
no_prox: bool = False,
*,
eps_root: float = 0.0,
moment_requires_grad: bool = False,
maximize: bool = False,
use_accelerated_op: bool = False,
) -> GradientTransformation:
"""Create a functional version of the adan optimizer.

adan is an SGD variant with learning rate adaptation. The *learning rate* used for each weight
is computed from estimates of first- and second-order moments of the gradients (using suitable
exponential moving averages).

References:
- Kingma et al., 2014: https://arxiv.org/abs/1412.6980

Args:
lr (float or callable, optional): This is a fixed global scaling factor or a learning rate
scheduler. (default: :const:`1e-3`)
betas (tuple of float, optional): Coefficients used for
first- and second-order moments. (default: :const:`(0.98, 0.92, 0.99)`)
eps (float, optional): Term added to the denominator to improve numerical stability.
(default: :const:`1e-8`)
eps_root (float, optional): Term added to the denominator inside the square-root to improve
numerical stability when backpropagating gradients through the rescaling.
(default: :const:`0.0`)
weight_decay (float, optional): Weight decay (L2 penalty).
(default: :const:`0.0`)
max_grad_norm (float, optional): Max norm of the gradients.
(default: :const:`0.0`)
no_prox (bool, optional): If :data:`True`, the proximal term is not applied.
(default: :data:`False`)
eps_root (float, optional): A small constant applied to denominator inside the square root
(as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example
when computing (meta-)gradients through Adam. (default: :const:`0.0`)
moment_requires_grad (bool, optional): If :data:`True`, states will be created with flag
``requires_grad = True``. (default: :data:`False`)
maximize (bool, optional): Maximize the params based on the objective, instead of minimizing.
(default: :data:`False`)

Returns:
The corresponding :class:`GradientTransformation` instance.

See Also:
The functional optimizer wrapper :class:`torchopt.FuncOptimizer`.
"""
b1, b2, b3 = betas # pylint: disable=invalid-name
# pylint: disable=unneeded-not
if not max_grad_norm >= 0.0:
raise ValueError(f'Invalid max_grad_norm value: {max_grad_norm}')
if not (callable(lr) or lr >= 0.0): # pragma: no cover
raise ValueError(f'Invalid learning rate: {lr}')
if not eps >= 0.0: # pragma: no cover
raise ValueError(f'Invalid epsilon value: {eps}')
if not 0.0 <= b1 < 1.0: # pragma: no cover
raise ValueError(f'Invalid beta parameter at index 0: {b1}')
if not 0.0 <= b2 < 1.0: # pragma: no cover
raise ValueError(f'Invalid beta parameter at index 1: {b2}')
if not 0.0 <= b3 < 1.0:
raise ValueError(f'Invalid beta parameter at index 2: {b3}')
if not weight_decay >= 0.0: # pragma: no cover
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
# pylint: enable=unneeded-not

chain_fn = chain
flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay
adan_scaler_fn = scale_by_adan if no_prox else scale_by_proximal_adan
scale_by_neg_lr_fn = scale_by_neg_lr

if _get_use_chain_flat(): # default behavior
chain_fn = chain_fn.flat # type: ignore[attr-defined]
flip_sign_and_add_weight_decay_fn = flip_sign_and_add_weight_decay_fn.flat # type: ignore[attr-defined]
adan_scaler_fn = adan_scaler_fn.flat # type: ignore[attr-defined]
scale_by_neg_lr_fn = scale_by_neg_lr_fn.flat # type: ignore[attr-defined]

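# NOTE: ``max_grad_norm`` is validated above but is not applied anywhere in the transformation
# chain below; gradient clipping would need to be added as a separate transformation.
# ``use_accelerated_op`` is currently accepted only for API parity with ``torchopt.adam`` and
# does not select a fused operator here.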
return chain_fn(
flip_sign_and_add_weight_decay_fn(weight_decay=weight_decay, maximize=maximize),
adan_scaler_fn(
b1=b1,
b2=b2,
b3=b3,
eps=eps,
eps_root=eps_root,
moment_requires_grad=moment_requires_grad,
),
scale_by_neg_lr_fn(lr),
)
91 changes: 91 additions & 0 deletions torchopt/optim/adan.py
@@ -0,0 +1,91 @@
# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adan optimizer."""

from __future__ import annotations

from typing import Iterable

import torch

from torchopt import alias
from torchopt.optim.base import Optimizer
from torchopt.typing import ScalarOrSchedule


__all__ = ['Adan']


class Adan(Optimizer):
"""The classic Adan optimizer.

See Also:
- The functional Adan optimizer: :func:`torchopt.adan`.
- The differentiable meta-Adan optimizer: :class:`torchopt.MetaAdan`.
"""

# pylint: disable-next=too-many-arguments
def __init__(
self,
params: Iterable[torch.Tensor],
lr: ScalarOrSchedule = 1e-3,
betas: tuple[float, float, float] = (0.98, 0.92, 0.99),
eps: float = 1e-8,
weight_decay: float = 0.0,
max_grad_norm: float = 0.0,
no_prox: bool = False,
*,
eps_root: float = 0.0,
maximize: bool = False,
use_accelerated_op: bool = False,
) -> None:
r"""Initialize the Adan optimizer.

Args:
params (iterable of Tensor): An iterable of :class:`torch.Tensor`\s. Specifies what
tensors should be optimized.
lr (float or callable, optional): This is a fixed global scaling factor or a learning
rate scheduler. (default: :const:`1e-3`)
betas (tuple of float, optional): Coefficients used for computing running averages of the
gradient, the gradient difference, and the squared gradient.
(default: :const:`(0.98, 0.92, 0.99)`)
eps (float, optional): A small constant applied to the denominator outside of the square
root (as in the Adam paper) to avoid dividing by zero when rescaling.
(default: :const:`1e-8`)
weight_decay (float, optional): Weight decay, add L2 penalty to parameters.
(default: :const:`0.0`)
max_grad_norm (float, optional): Max norm of the gradients. A value of :const:`0.0` disables
gradient clipping. (default: :const:`0.0`)
no_prox (bool, optional): If :data:`True`, apply weight decay without the proximal update.
(default: :data:`False`)
eps_root (float, optional): A small constant applied to the denominator inside the square
root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for
example when computing (meta-)gradients through Adan. (default: :const:`0.0`)
maximize (bool, optional): Maximize the params based on the objective, instead of
minimizing. (default: :data:`False`)
use_accelerated_op (bool, optional): If :data:`True`, use the fused accelerated operator
when it is available. (default: :data:`False`)
"""
super().__init__(
params,
alias.adan(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
max_grad_norm=max_grad_norm,
no_prox=no_prox,
eps_root=eps_root,
moment_requires_grad=False,
maximize=maximize,
use_accelerated_op=use_accelerated_op,
),
)
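Reviewer note: a minimal usage sketch of the two entry points added in this PR, mirroring the pattern of `torchopt.adam` / `torchopt.Adam`. The `torchopt.adan` and `torchopt.Adan` names and the three-element `betas` come from the diff above, assuming both are exported at the torchopt top level as the tests and docstrings suggest; the toy model, data, and hyperparameter values are made up for illustration.

```python
# Illustrative sketch only: the toy model, data, and hyperparameters are hypothetical.
import torch
import torch.nn.functional as F

import torchopt

model = torch.nn.Linear(8, 2)
xs, ys = torch.randn(32, 8), torch.randint(0, 2, (32,))

# Class-based API: a torch.optim-style optimizer built on the functional alias.
optimizer = torchopt.Adan(
    model.parameters(),
    lr=1e-3,
    betas=(0.98, 0.92, 0.99),
    weight_decay=1e-2,
)
loss = F.cross_entropy(model(xs), ys)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Functional API: a GradientTransformation driven via init / update / apply_updates.
params = tuple(model.parameters())
optim = torchopt.adan(lr=1e-3, betas=(0.98, 0.92, 0.99))
opt_state = optim.init(params)
grads = torch.autograd.grad(F.cross_entropy(model(xs), ys), params)
updates, opt_state = optim.update(grads, opt_state, params=params, inplace=False)
new_params = torchopt.apply_updates(params, updates, inplace=False)
```

The functional path is the same one exercised by the new `test_adan` in `tests/test_alias.py`, which compares it against the reference `adan.Adan` implementation.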