Handle BatchNorm and base_module call #252

Open
antoinedemathelin opened this issue Oct 15, 2024 · 1 comment

Comments

@antoinedemathelin
Contributor

Hi everyone,
I am opening this issue because I noticed that base_module is called separately on the source and target batches. With BatchNorm, this can cause undesired behavior, since the batch mean and variance are then estimated separately for each domain. I propose concatenating source and target, running a single forward pass through base_module, and splitting the two afterwards.

skada/skada/deep/base.py

Lines 421 to 442 in e80e205

source_idx = sample_domain >= 0
X_s = X[source_idx]
X_t = X[~source_idx]
# Pass sample_weight to base_module_
if sample_weight is not None:
    sample_weight_s = sample_weight[source_idx]
    y_pred_s = self.base_module_(X_s, sample_weight=sample_weight_s)
else:
    y_pred_s = self.base_module_(X_s)
if self.layer_name is not None:
    features_s = self.intermediate_layers[self.layer_name]
else:
    features_s = None
if sample_weight is not None:
    sample_weight_t = sample_weight[~source_idx]
    y_pred_t = self.base_module_(X_t, sample_weight=sample_weight_t)
else:
    y_pred_t = self.base_module_(X_t)
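
To make the BatchNorm concern concrete, here is a minimal, framework-free sketch of how batch statistics differ between two separate calls and a single call on the concatenated batch. The `batch_norm` helper and the toy data are hypothetical and only mimic training-mode normalization (no learnable scale or shift, no running averages); this is not skada or PyTorch code.

```python
from statistics import fmean, pvariance


def batch_norm(batch, eps=1e-5):
    """Normalize a 1-D batch with its own mean and biased variance."""
    mean = fmean(batch)
    var = pvariance(batch, mu=mean)
    return [(x - mean) / (var + eps) ** 0.5 for x in batch]


# Toy 1-D features: the target domain is shifted by +10 relative to the source.
X_s = [0.0, 1.0, 2.0, 3.0]
X_t = [10.0, 11.0, 12.0, 13.0]

# Separate calls: each domain is normalized with its own statistics,
# so the shift between the two domains disappears from the outputs.
out_s_separate = batch_norm(X_s)
out_t_separate = batch_norm(X_t)

# Single call on the concatenated batch, then split back by size:
# both domains share one mean and variance, so the shift is preserved.
out_joint = batch_norm(X_s + X_t)
out_s_joint = out_joint[: len(X_s)]
out_t_joint = out_joint[len(X_s):]

print("separate:", fmean(out_s_separate), fmean(out_t_separate))
print("joint:   ", fmean(out_s_joint), fmean(out_t_joint))
```

The sketch only shows that the two strategies produce different normalized outputs. Which one is preferable may depend on the domain adaptation method being used.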

@YanisLalou
Collaborator

You're right about the potential undesired behaviour when using BatchNorm.
However, we can't concatenate source and target, because afterwards we need to fetch the source features and the target features separately from intermediate_layers (I agree we should add comments in the code to explain this).

skada/skada/deep/base.py

Lines 427 to 436 in e80e205

if sample_weight is not None:
    sample_weight_s = sample_weight[source_idx]
    y_pred_s = self.base_module_(X_s, sample_weight=sample_weight_s)
else:
    y_pred_s = self.base_module_(X_s)
if self.layer_name is not None:
    features_s = self.intermediate_layers[self.layer_name]
else:
    features_s = None

skada/skada/deep/base.py

Lines 438 to 447 in e80e205

if sample_weight is not None:
    sample_weight_t = sample_weight[~source_idx]
    y_pred_t = self.base_module_(X_t, sample_weight=sample_weight_t)
else:
    y_pred_t = self.base_module_(X_t)
if self.layer_name is not None:
    features_t = self.intermediate_layers[self.layer_name]
else:
    features_t = None

So if you have other ideas on how to fix this issue, don't hesitate to write them down here.
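
The two excerpts above read intermediate_layers immediately after each forward call. The following framework-free sketch illustrates why that ordering matters, under the assumption that the stored features are overwritten on every forward pass (as a forward hook writing into a dictionary would do). `TinyModule` and the `"hidden"` key are hypothetical stand-ins, not part of skada.

```python
class TinyModule:
    """Hypothetical stand-in for a network whose hook stores a layer output."""

    def __init__(self):
        self.intermediate_layers = {}

    def __call__(self, X):
        features = [2.0 * x for x in X]  # pretend hidden layer
        # Assumption: the stored entry is replaced on every forward call.
        self.intermediate_layers["hidden"] = features
        return [f + 1.0 for f in features]  # pretend output layer


module = TinyModule()
X_s = [1.0, 2.0]
X_t = [10.0, 20.0, 30.0]

y_pred_s = module(X_s)
# Source features must be read before the next forward call replaces them.
features_s = module.intermediate_layers["hidden"]

y_pred_t = module(X_t)
# The entry now holds the target features only.
features_t = module.intermediate_layers["hidden"]

print(features_s)
print(features_t)
```

Under this assumption, after the second call the dictionary no longer contains the source features, which is why each excerpt fetches its features right after its own forward pass.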
