Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Add a parameter base_criterion to deep models #217

Merged
merged 2 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions skada/deep/_adversarial.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def DANN(
reg=1,
domain_classifier=None,
num_features=None,
base_criterion=None,
domain_criterion=None,
**kwargs,
):
Expand Down Expand Up @@ -109,6 +110,9 @@ def DANN(
the feature extractor.
If domain_classifier is None, num_features has to be
provided.
base_criterion : torch criterion (class)
The base criterion used to compute the loss with source
labels. If None, the default is `torch.nn.CrossEntropyLoss`.
domain_criterion : torch criterion (class)
The criterion (loss) used to compute the
DANN loss. If None, a BCELoss is used.
Expand All @@ -127,14 +131,17 @@ def DANN(
)
domain_classifier = DomainClassifier(num_features=num_features)

if base_criterion is None:
base_criterion = torch.nn.CrossEntropyLoss()

net = DomainAwareNet(
module=DomainAwareModule,
module__base_module=module,
module__layer_name=layer_name,
module__domain_classifier=domain_classifier,
iterator_train=DomainBalancedDataLoader,
criterion=DomainAwareCriterion,
criterion__criterion=nn.CrossEntropyLoss(),
criterion__base_criterion=base_criterion,
criterion__reg=reg,
criterion__adapt_criterion=DANNLoss(domain_criterion=domain_criterion),
**kwargs,
Expand Down Expand Up @@ -319,6 +326,7 @@ def CDAN(
domain_classifier=None,
num_features=None,
n_classes=None,
base_criterion=None,
domain_criterion=None,
**kwargs,
):
Expand Down Expand Up @@ -351,6 +359,9 @@ def CDAN(
n_classes : int, default None
Number of output classes.
If domain_classifier is None, n_classes has to be provided.
base_criterion : torch criterion (class)
The base criterion used to compute the loss with source
labels. If None, the default is `torch.nn.CrossEntropyLoss`.
domain_criterion : torch criterion (class)
The criterion (loss) used to compute the
CDAN loss. If None, a BCELoss is used.
Expand All @@ -372,6 +383,9 @@ def CDAN(
num_features = np.min([num_features * n_classes, max_features])
domain_classifier = DomainClassifier(num_features=num_features)

if base_criterion is None:
base_criterion = torch.nn.CrossEntropyLoss()

net = DomainAwareNet(
module=CDANModule,
module__base_module=module,
Expand All @@ -380,7 +394,7 @@ def CDAN(
module__max_features=max_features,
iterator_train=DomainBalancedDataLoader,
criterion=DomainAwareCriterion,
criterion__criterion=nn.CrossEntropyLoss(),
criterion__base_criterion=base_criterion,
criterion__reg=reg,
criterion__adapt_criterion=CDANLoss(domain_criterion=domain_criterion),
**kwargs,
Expand Down
24 changes: 17 additions & 7 deletions skada/deep/_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def forward(
return loss


def DeepCoral(module, layer_name, reg=1, **kwargs):
def DeepCoral(module, layer_name, reg=1, base_criterion=None, **kwargs):
"""DeepCORAL domain adaptation method.

From [12]_.
Expand All @@ -64,20 +64,26 @@ def DeepCoral(module, layer_name, reg=1, **kwargs):
collected during the training for the adaptation.
reg : float, optional (default=1)
The regularization parameter of the covariance estimator.
base_criterion : torch criterion (class)
The base criterion used to compute the loss with source
labels. If None, the default is `torch.nn.CrossEntropyLoss`.

References
----------
.. [12] Baochen Sun and Kate Saenko. Deep coral:
Correlation alignment for deep domain
adaptation. In ECCV Workshops, 2016.
"""
if base_criterion is None:
base_criterion = torch.nn.CrossEntropyLoss()

net = DomainAwareNet(
module=DomainAwareModule,
module__base_module=module,
module__layer_name=layer_name,
iterator_train=DomainBalancedDataLoader,
criterion=DomainAwareCriterion,
criterion__criterion=torch.nn.CrossEntropyLoss(),
criterion__base_criterion=base_criterion,
criterion__reg=reg,
criterion__adapt_criterion=DeepCoralLoss(),
**kwargs,
Expand Down Expand Up @@ -123,7 +129,7 @@ def forward(
return loss


def DAN(module, layer_name, reg=1, sigmas=None, **kwargs):
def DAN(module, layer_name, reg=1, sigmas=None, base_criterion=None, **kwargs):
"""DAN domain adaptation method.

See [14]_.
Expand All @@ -139,22 +145,26 @@ def DAN(module, layer_name, reg=1, sigmas=None, **kwargs):
The regularization parameter of the covariance estimator.
sigmas : array-like, optional (default=None)
The sigmas for the Gaussian kernel.
base_criterion : torch criterion (class)
The base criterion used to compute the loss with source
labels. If None, the default is `torch.nn.CrossEntropyLoss`.

References
----------
.. [14] Mingsheng Long et. al. Learning Transferable
Features with Deep Adaptation Networks.
In ICML, 2015.
"""
if base_criterion is None:
base_criterion = torch.nn.CrossEntropyLoss()

net = DomainAwareNet(
module=DomainAwareModule,
module__base_module=module,
module__layer_name=layer_name,
iterator_train=DomainBalancedDataLoader,
criterion=DomainAwareCriterion(
torch.nn.CrossEntropyLoss(), DANLoss(sigmas=sigmas), reg=reg
),
criterion__criterion=torch.nn.CrossEntropyLoss(),
criterion=DomainAwareCriterion,
criterion__base_criterion=base_criterion,
criterion__reg=reg,
criterion__adapt_criterion=DANLoss(sigmas=sigmas),
**kwargs,
Expand Down
20 changes: 17 additions & 3 deletions skada/deep/_optimal_transport.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Author: Theo Gnassounou <[email protected]>
#
# License: BSD 3-Clause
from torch import nn
import torch

from skada.deep.base import (
BaseDALoss,
Expand Down Expand Up @@ -67,7 +67,15 @@ def forward(
return loss


def DeepJDOT(module, layer_name, reg=1, reg_cl=1, target_criterion=None, **kwargs):
def DeepJDOT(
module,
layer_name,
reg=1,
reg_cl=1,
base_criterion=None,
target_criterion=None,
**kwargs,
):
"""DeepJDOT.

See [13]_.
Expand All @@ -83,6 +91,9 @@ def DeepJDOT(module, layer_name, reg=1, reg_cl=1, target_criterion=None, **kwarg
Regularization parameter.
reg_cl : float, default=1
Class distance term regularization parameter.
base_criterion : torch criterion (class)
The base criterion used to compute the loss with source
labels. If None, the default is `torch.nn.CrossEntropyLoss`.
target_criterion : torch criterion (class)
The uninitialized criterion (loss) used to compute the
DeepJDOT loss. The criterion should support reduction='none'.
Expand All @@ -96,13 +107,16 @@ def DeepJDOT(module, layer_name, reg=1, reg_cl=1, target_criterion=None, **kwarg
15th European Conference on Computer Vision,
September 2018. Springer.
"""
if base_criterion is None:
base_criterion = torch.nn.CrossEntropyLoss()

net = DomainAwareNet(
module=DomainAwareModule,
module__base_module=module,
module__layer_name=layer_name,
iterator_train=DomainBalancedDataLoader,
criterion=DomainAwareCriterion,
criterion__criterion=nn.CrossEntropyLoss(),
criterion__base_criterion=base_criterion,
criterion__adapt_criterion=DeepJDOTLoss(reg_cl, target_criterion),
criterion__reg=reg,
**kwargs,
Expand Down
8 changes: 4 additions & 4 deletions skada/deep/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class DomainAwareCriterion(torch.nn.Module):

Parameters
----------
criterion : torch criterion (class)
base_criterion : torch criterion (class)
The initialized criterion (loss) used to optimize the
module with prediction on source.
adapt_criterion : torch criterion (class)
Expand All @@ -32,9 +32,9 @@ class DomainAwareCriterion(torch.nn.Module):
Regularization parameter.
"""

def __init__(self, criterion, adapt_criterion, reg=1):
def __init__(self, base_criterion, adapt_criterion, reg=1):
super(DomainAwareCriterion, self).__init__()
self.criterion = criterion
self.base_criterion = base_criterion
self.adapt_criterion = adapt_criterion
self.reg = reg

Expand Down Expand Up @@ -73,7 +73,7 @@ def forward(
features_t = features[~source_idx]

# predict
return self.criterion(
return self.base_criterion(
y_pred_s, y_true[source_idx]
) + self.reg * self.adapt_criterion(
y_true[source_idx],
Expand Down
Loading