Commit

exclude some parameters from weight decay
fcogidi committed Aug 12, 2024
1 parent bba0c07 commit 4bc35dc
Showing 1 changed file with 42 additions and 3 deletions.
45 changes: 42 additions & 3 deletions mmlearn/tasks/contrastive_pretraining.py
@@ -1,5 +1,6 @@
"""Contrastive pretraining task."""

import inspect
import itertools
from dataclasses import dataclass
from functools import partial
@@ -502,16 +503,54 @@ def on_test_epoch_end(self) -> None:
         """Compute and log epoch-level metrics at the end of the test epoch."""
         self._on_eval_epoch_end("test")
 
-    def configure_optimizers(self) -> OptimizerLRScheduler:
+    def configure_optimizers(self) -> OptimizerLRScheduler:  # noqa: PLR0912
         """Configure the optimizer and learning rate scheduler."""
         if self.optimizer is None:
             rank_zero_warn(
                 "Optimizer not provided. Training will continue without an optimizer. "
                 "LR scheduler will not be used.",
             )
             return None
-        # TODO: add mechanism to exclude certain parameters from weight decay
-        optimizer = self.optimizer(self.parameters())
+
+        weight_decay: Optional[float] = self.optimizer.keywords.get(
+            "weight_decay", None
+        )
+        if weight_decay is None:  # try getting default value
+            kw_param = inspect.signature(self.optimizer.func).parameters.get(
+                "weight_decay"
+            )
+            if kw_param is not None and kw_param.default != inspect.Parameter.empty:
+                weight_decay = kw_param.default
+
+        parameters = [param for param in self.parameters() if param.requires_grad]
+
+        if weight_decay is not None:
+            decay_params = []
+            no_decay_params = []
+
+            for param in self.parameters():
+                if not param.requires_grad:
+                    continue
+
+                if param.ndim < 2:  # includes all bias and normalization parameters
+                    no_decay_params.append(param)
+                else:
+                    decay_params.append(param)
+
+            parameters = [
+                {
+                    "params": decay_params,
+                    "weight_decay": weight_decay,
+                    "name": "weight_decay_params",
+                },
+                {
+                    "params": no_decay_params,
+                    "weight_decay": 0.0,
+                    "name": "no_weight_decay_params",
+                },
+            ]
+
+        optimizer = self.optimizer(parameters)
         if not isinstance(optimizer, torch.optim.Optimizer):
             raise TypeError(
                 "Expected optimizer to be an instance of `torch.optim.Optimizer`, "
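The weight-decay lookup above relies on self.optimizer being a functools.partial wrapped around a torch optimizer class, which is what the .keywords and .func accesses imply. The standalone sketch below (not part of the repository; the toy model and optimizer_partial are hypothetical placeholders) shows how the lookup and the resulting parameter groups behave:

import inspect
from functools import partial

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(8, 16),
    torch.nn.LayerNorm(16),
    torch.nn.Linear(16, 4),
)

# No explicit weight_decay keyword, so the signature default is used instead.
optimizer_partial = partial(torch.optim.AdamW, lr=1e-3)

weight_decay = optimizer_partial.keywords.get("weight_decay")
if weight_decay is None:  # fall back to the default in the optimizer's signature
    kw_param = inspect.signature(optimizer_partial.func).parameters.get("weight_decay")
    if kw_param is not None and kw_param.default is not inspect.Parameter.empty:
        weight_decay = kw_param.default  # AdamW's default is 1e-2

# 1-D tensors (biases and normalization weights) are excluded from weight decay.
decay_params, no_decay_params = [], []
for param in model.parameters():
    if not param.requires_grad:
        continue
    (no_decay_params if param.ndim < 2 else decay_params).append(param)

optimizer = optimizer_partial(
    [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
)
print(len(decay_params), len(no_decay_params))  # 2 weight matrices, 4 bias/norm tensors

Excluding biases and normalization parameters from weight decay is a common convention; the param.ndim < 2 check is a simple proxy for identifying them without inspecting module types.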
