Fix type hints for functools.partial parameters
fcogidi committed Aug 12, 2024 · 1 parent 76e8c1b · commit bba0c07
Showing 1 changed file with 68 additions and 37 deletions.
105 changes: 68 additions & 37 deletions mmlearn/tasks/contrastive_pretraining.py
@@ -2,7 +2,8 @@

import itertools
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Tuple, Union
from functools import partial
from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union

import lightning as L # noqa: N812
import torch
@@ -48,17 +49,15 @@ class AuxiliaryTaskSpec:
"""Specification for an auxiliary task to run alongside the main task."""

modality: str
task: Callable[[nn.Module], L.LightningModule]
task: Any # `functools.partial[L.LightningModule]` expected
loss_weight: float = 1.0


@dataclass
class EvaluationSpec:
"""Specification for an evaluation task."""

task: (
Any # NOTE: hydra/omegaconf does not support custom types in structured configs
)
task: Any # `EvaluationHooks` expected
run_on_validation: bool = True
run_on_test: bool = True
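# Illustrative sketch (not part of this commit): under the updated hints,
# `AuxiliaryTaskSpec.task` is a `functools.partial` over a LightningModule
# class and `EvaluationSpec.task` is an `EvaluationHooks` instance.
# `SomeAuxiliaryTask` and `SomeEvaluationTask` are placeholder names:
#
#     aux_spec = AuxiliaryTaskSpec(
#         modality="text",
#         task=partial(SomeAuxiliaryTask, hidden_dim=768),
#         loss_weight=0.5,
#     )
#     eval_spec = EvaluationSpec(task=SomeEvaluationTask(), run_on_test=False)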

@@ -99,21 +98,30 @@ class ContrastivePretraining(L.LightningModule):
modality_module_mapping : Dict[str, ModuleKeySpec], optional, default=None
A dictionary mapping modalities to encoders, heads, and postprocessors.
Useful for reusing the same instance of a module across multiple modalities.
optimizer : torch.optim.Optimizer, optional, default=None
The optimizer to use.
lr_scheduler : Union[Dict[str, torch.optim.lr_scheduler.LRScheduler], torch.optim.lr_scheduler.LRScheduler], optional, default=None
The learning rate scheduler to use.
optimizer : partial[torch.optim.Optimizer], optional, default=None
The optimizer to use for training. This is expected to be a partial function,
created using `functools.partial`, that takes the model parameters as the
only required argument. If not provided, training will continue without an
optimizer.
lr_scheduler : Union[Dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]], partial[torch.optim.lr_scheduler.LRScheduler]], optional, default=None
The learning rate scheduler to use for training. This can be a partial function
that takes the optimizer as the only required argument or a dictionary with
a `scheduler` key that specifies the scheduler and an optional `extras` key
that specifies additional arguments to pass to the scheduler. If not provided,
the learning rate will not be adjusted during training.
loss : CLIPLoss, optional, default=None
The loss function to use.
modality_loss_pairs : List[LossPairSpec], optional, default=None
A list of pairs of modalities to compute the contrastive loss between and
the weight to apply to each pair.
auxiliary_tasks : Dict[str, AuxiliaryTaskSpec], optional, default=None
Auxiliary tasks to run alongside the main contrastive pretraining task.
The auxiliary task module is expected to be a partially-initialized
instance of a `LightningModule`, such that an initialized encoder can be
passed as the first argument to the module. The `modality` parameter
specifies the modality of the encoder to use for the auxiliary task.
The auxiliary task module is expected to be a partially-initialized instance
of a `LightningModule` created using `functools.partial`, such that an
initialized encoder can be passed as the only argument. The `modality`
parameter specifies the modality of the encoder to use for the auxiliary task.
The `loss_weight` parameter specifies the weight to apply to the auxiliary
task loss.
log_auxiliary_tasks_loss : bool, optional, default=False
Whether to log the loss of auxiliary tasks to the main logger.
compute_validation_loss : bool, optional, default=True
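A minimal sketch of the call pattern the updated hints describe (illustrative
only; `encoders` is assumed to be a mapping of modality names to `nn.Module`
instances, and all hyperparameter values are placeholders):

from functools import partial

import torch

task = ContrastivePretraining(
    encoders=encoders,
    optimizer=partial(torch.optim.AdamW, lr=5e-4, weight_decay=0.01),
    lr_scheduler={
        "scheduler": partial(
            torch.optim.lr_scheduler.CosineAnnealingLR, T_max=100
        ),
        "extras": {"interval": "epoch"},
    },
)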
@@ -135,11 +143,11 @@ def __init__( # noqa: PLR0912, PLR0915
Dict[str, Union[nn.Module, Dict[str, nn.Module]]]
] = None,
modality_module_mapping: Optional[Dict[str, ModuleKeySpec]] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
optimizer: Optional[partial[torch.optim.Optimizer]] = None,
lr_scheduler: Optional[
Union[
Dict[str, torch.optim.lr_scheduler.LRScheduler],
torch.optim.lr_scheduler.LRScheduler,
Dict[str, Union[partial[torch.optim.lr_scheduler.LRScheduler], Any]],
partial[torch.optim.lr_scheduler.LRScheduler],
]
] = None,
loss: Optional[CLIPLoss] = None,
@@ -260,16 +268,19 @@ def __init__( # noqa: PLR0912, PLR0915
self.aux_task_specs = auxiliary_tasks or {}
self.auxiliary_tasks: Dict[str, L.LightningModule] = {}
for task_name, task_spec in self.aux_task_specs.items():
try:
aux_task_modality = Modalities.get_modality(task_spec.modality)
self.auxiliary_tasks[task_name] = task_spec.task(
self.encoders[aux_task_modality]
)
except KeyError as exc:
if not Modalities.has_modality(task_spec.modality):
raise ValueError(
f"Found unsupported modality `{task_spec.modality}` in the auxiliary tasks. "
f"Available modalities are {self._available_modalities}."
) from exc
)
if not isinstance(task_spec.task, partial):
raise TypeError(
f"Expected auxiliary task to be a partial function, but got {type(task_spec.task)}."
)

self.auxiliary_tasks[task_name] = task_spec.task(
self.encoders[Modalities.get_modality(task_spec.modality)]
)
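# With e.g. task=partial(SomeAuxiliaryTask, hidden_dim=768), where
# SomeAuxiliaryTask is a placeholder name, the call above resolves to
# SomeAuxiliaryTask(encoder, hidden_dim=768).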

if loss is None and (compute_validation_loss or compute_test_loss):
raise ValueError(
@@ -282,6 +293,14 @@ def __init__( # noqa: PLR0912, PLR0915
self.log_auxiliary_tasks_loss = log_auxiliary_tasks_loss
self.compute_validation_loss = compute_validation_loss
self.compute_test_loss = compute_test_loss

if evaluation_tasks is not None:
for eval_task_spec in evaluation_tasks.values():
if not isinstance(eval_task_spec.task, EvaluationHooks):
raise TypeError(
f"Expected {eval_task_spec.task} to be an instance of `EvaluationHooks` "
f"but got {type(eval_task_spec.task)}."
)
self.evaluation_tasks = evaluation_tasks

def encode(
@@ -491,26 +510,43 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
"LR scheduler will not be used.",
)
return None

optimizer: torch.optim.Optimizer = self.optimizer(params=self.parameters())
# TODO: add mechanism to exclude certain parameters from weight decay
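# With e.g. optimizer=partial(torch.optim.AdamW, lr=1e-3), the call below
# resolves to torch.optim.AdamW(self.parameters(), lr=1e-3).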
optimizer = self.optimizer(self.parameters())
if not isinstance(optimizer, torch.optim.Optimizer):
raise TypeError(
"Expected optimizer to be an instance of `torch.optim.Optimizer`, "
f"but got {type(optimizer)}.",
)

if self.lr_scheduler is not None:
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.LRScheduler):
return [optimizer], [self.lr_scheduler(optimizer)]
if isinstance(self.lr_scheduler, dict):
if "scheduler" not in self.lr_scheduler:
raise ValueError(
"Expected 'scheduler' key in the learning rate scheduler dictionary.",
)

if isinstance(self.lr_scheduler, dict) and "scheduler" in self.lr_scheduler:
lr_scheduler = self.lr_scheduler["scheduler"](optimizer=optimizer)
lr_scheduler = self.lr_scheduler["scheduler"](optimizer)
if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):
raise TypeError(
"Expected scheduler to be an instance of "
f"`torch.optim.lr_scheduler.LRScheduler`, but got {type(lr_scheduler)}.",
"Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, "
f"but got {type(lr_scheduler)}.",
)
lr_scheduler_dict = {"scheduler": lr_scheduler}
lr_scheduler_dict: Dict[
str, Union[torch.optim.lr_scheduler.LRScheduler, Any]
] = {"scheduler": lr_scheduler}

if self.lr_scheduler.get("extras"):
lr_scheduler_dict.update(self.lr_scheduler["extras"])
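# e.g. extras={"interval": "step"} merges into the dict, giving
# {"scheduler": lr_scheduler, "interval": "step"}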
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}

lr_scheduler = self.lr_scheduler(optimizer)
if not isinstance(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):
raise TypeError(
"Expected scheduler to be an instance of `torch.optim.lr_scheduler.LRScheduler`, "
f"but got {type(lr_scheduler)}.",
)
return [optimizer], [lr_scheduler]

return optimizer
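# For reference (a sketch, not prescribed by this diff): all three return
# shapes above are valid Lightning `configure_optimizers` outputs:
#     return optimizer                                # optimizer only
#     return [optimizer], [lr_scheduler]              # with bare scheduler
#     return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}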

def _on_eval_epoch_start(self, eval_type: Literal["val", "test"]) -> None:
@@ -520,11 +556,6 @@ def _on_eval_epoch_start(self, eval_type: Literal["val", "test"]) -> None:
if (eval_type == "val" and task_spec.run_on_validation) or (
eval_type == "test" and task_spec.run_on_test
):
if not isinstance(task_spec.task, EvaluationHooks):
raise TypeError(
f"Expected {task_spec.task} to be an instance of "
f"`EvaluationHooks` but got {task_spec.task.__bases__}."
)
task_spec.task.on_evaluation_epoch_start(self)

def _shared_eval_step(
