[TO-REVIEW] Add multi-domain Monge alignment and JCPOT target shift method (#180)

* add per-domain split

* better tests for multi_domains

* first shot at multi-Monge alignment

* add test

* add example

* update mapping to new API

* add proper references

* add JCPOT

* add stuff

* add jcpot tests

* update doc and add JCPOT

* example of JCPOT in the gallery

* fix typo in example
rflamary authored Jul 19, 2024
1 parent fb26ff5 commit ce1d60c
Showing 10 changed files with 849 additions and 6 deletions.
7 changes: 7 additions & 0 deletions README.md
@@ -228,3 +228,10 @@ The library is distributed under the 3-Clause BSD license.
[27] S. Si, D. Tao and B. Geng. In IEEE Transactions on Knowledge and Data Engineering, (2010) [Bregman Divergence-Based Regularization for Transfer Subspace Learning](https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=4118b4fc7d61068b9b448fd499876d139baeec81)

[28] Solomon, J., Rustamov, R., Guibas, L., & Butscher, A. (2014, January). [Wasserstein propagation for semi-supervised learning](https://proceedings.mlr.press/v32/solomon14.pdf). In International Conference on Machine Learning (pp. 306-314). PMLR.

[29] Montesuma, Eduardo Fernandes, and Fred Maurice Ngole Mboula. ["Wasserstein barycenter for multi-source domain adaptation."](https://openaccess.thecvf.com/content/CVPR2021/papers/Montesuma_Wasserstein_Barycenter_for_Multi-Source_Domain_Adaptation_CVPR_2021_paper.pdf) In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp. 16785-16793. 2021.

[30] Gnassounou, Theo, Rémi Flamary, and Alexandre Gramfort. ["Convolution Monge Mapping Normalization for learning on sleep data."](https://proceedings.neurips.cc/paper_files/paper/2023/file/21718991f6acf19a42376b5c7a8668c5-Paper-Conference.pdf) Advances in Neural Information Processing Systems 36 (2024).

[31] Redko, Ievgen, Nicolas Courty, Rémi Flamary, and Devis Tuia. ["Optimal transport for multi-source domain adaptation under target shift."](https://proceedings.mlr.press/v89/redko19a/redko19a.pdf) In The 22nd International Conference on Artificial Intelligence and Statistics, pp. 849-858. PMLR, 2019.

101 changes: 100 additions & 1 deletion examples/methods/plot_label_prop_da.py
@@ -31,8 +31,14 @@
from sklearn.kernel_ridge import KernelRidge
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import mean_squared_error
from sklearn.svm import SVC

from skada import OTLabelPropAdapter, make_da_pipeline, source_target_split
from skada import (
JCPOTLabelPropAdapter,
OTLabelPropAdapter,
make_da_pipeline,
source_target_split,
)
from skada.datasets import make_shifted_datasets

# %%
@@ -330,3 +336,96 @@

plt.title("Propagated labels data")
plt.axis(ax)


# %%
# Generate a target shift classification dataset and plot it
# -----------------------------------------------------------
#
# We generate a simple 2D target shift dataset.

X, y, sample_domain = make_shifted_datasets(
n_samples_source=20,
n_samples_target=20,
shift="target_shift",
noise=0.2,
random_state=42,
)


Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)


plt.figure(5, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
ax = plt.axis()

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target data")
plt.axis(ax)


# %%
# Train with LabelProp and JCPOT + classifier
# --------------------------------------------
#
# On this target shift dataset, the label propagation method does not work
# well because it matches source and target samples from different classes.
# JCPOT is more robust to this kind of shift because it also estimates the
# class proportions in the target domain.
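

# %%
# As a quick sanity check (a small sketch using only numpy), we compare the
# empirical class proportions of the source and target domains; under target
# shift they differ, which is precisely what JCPOT re-estimates.

for name, labels in [("Source", ys), ("Target", yt)]:
    classes, counts = np.unique(labels, return_counts=True)
    props = np.round(counts / len(labels), 2)
    print(f"{name} class proportions: {dict(zip(classes, props))}")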


clf = make_da_pipeline(OTLabelPropAdapter(), SVC())
clf.fit(X, y, sample_domain=sample_domain)

clf_jcpot = make_da_pipeline(JCPOTLabelPropAdapter(reg=0.1), SVC())
clf_jcpot.fit(X, y, sample_domain=sample_domain)


yt_pred = clf.predict(Xt)
acc_t = (yt_pred == yt).mean()

print(f"LabelProp Accuracy on target: {acc_t:.2f}")


yt_pred = clf_jcpot.predict(Xt)
acc_t_jcpot = (yt_pred == yt).mean()

print(f"JCPOT Accuracy on target: {acc_t_jcpot:.2f}")

XX, YY = np.meshgrid(np.linspace(ax[0], ax[1], 100), np.linspace(ax[2], ax[3], 100))
Z = clf.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)
Z_jcpot = clf_jcpot.predict(np.c_[XX.ravel(), YY.ravel()]).reshape(XX.shape)


plt.figure(7, (10, 5))


plt.subplot(1, 2, 1)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Prediction")
plt.imshow(
Z,
extent=(ax[0], ax[1], ax[2], ax[3]),
origin="lower",
alpha=0.5,
cmap="tab10",
vmax=9,
)
plt.title(f"LabelProp reglog on target (ACC={acc_t:.2f})")
plt.axis(ax)

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Prediction")
plt.imshow(
Z_jcpot,
extent=(ax[0], ax[1], ax[2], ax[3]),
origin="lower",
alpha=0.5,
cmap="tab10",
vmax=9,
)
plt.title(f"JCPOT reglog on target (ACC={acc_s_jcpot:.2f})")
plt.axis(ax)
259 changes: 259 additions & 0 deletions examples/methods/plot_monge_alignment_da.py
@@ -0,0 +1,259 @@
"""
Multi-domain Linear Monge Alignment
===================================
This example illustrates the use of the MultiLinearMongeAlignmentAdapter
"""

# Author: Remi Flamary
#
# License: BSD 3-Clause
# sphinx_gallery_thumbnail_number = 4

# %% Imports
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LogisticRegression

from skada import (
MultiLinearMongeAlignmentAdapter,
make_da_pipeline,
source_target_split,
)
from skada.datasets import make_shifted_datasets

# %%
# Generate a concept drift classification dataset and plot it
# ------------------------------------------------------------
#
# We generate a simple 2D concept drift dataset.

X, y, sample_domain = make_shifted_datasets(
n_samples_source=20,
n_samples_target=20,
shift="concept_drift",
noise=0.2,
label="multiclass",
random_state=42,
)


Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)


plt.figure(1, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
ax = plt.axis()

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target data")
plt.axis(ax)

# %%
# Fit the Monge alignment adapter
# -------------------------------
#
# We fit the MultiLinearMongeAlignmentAdapter on the source and target
# domains and plot the adapted data: after adaptation, both domains are
# mapped to a common space where they overlap.


clf = MultiLinearMongeAlignmentAdapter()
clf.fit(X, sample_domain=sample_domain)

X_adapt = clf.transform(X, sample_domain=sample_domain, allow_source=True)
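

# %%
# As a quick check (a small sketch, assuming the adapter matches the first
# and second moments of each domain), the per-domain means and standard
# deviations should be nearly equal after alignment.

for d in np.unique(sample_domain):
    Xd = X_adapt[sample_domain == d]
    print(
        f"domain {d:+.0f}: mean={np.round(Xd.mean(axis=0), 2)},"
        f" std={np.round(Xd.std(axis=0), 2)}"
    )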


plt.figure(2, (10, 3))
plt.subplot(1, 3, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
ax = plt.axis()

plt.subplot(1, 3, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target data")
plt.axis(ax)

plt.subplot(1, 3, 3)
plt.scatter(
X_adapt[sample_domain >= 0, 0],
X_adapt[sample_domain >= 0, 1],
c=y[sample_domain >= 0],
marker="o",
cmap="tab10",
vmax=9,
label="Source",
alpha=0.5,
)
plt.scatter(
X_adapt[sample_domain < 0, 0],
X_adapt[sample_domain < 0, 1],
c=y[sample_domain < 0],
marker="x",
cmap="tab10",
vmax=9,
label="Target",
alpha=1,
)
plt.legend()
plt.title("Adapted data")


# %%
# Train a classifier on adapted data
# ----------------------------------

clf = make_da_pipeline(
MultiLinearMongeAlignmentAdapter(),
LogisticRegression(),
)

clf.fit(X, y, sample_domain=sample_domain)

print(
"Average accuracy on all domains:",
clf.score(X, y, sample_domain=sample_domain, allow_source=True),
)
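
# %%
# For comparison (a minimal sketch), the same classifier without the
# alignment step gives a baseline score on the raw features.

clf_noadapt = make_da_pipeline(LogisticRegression())
clf_noadapt.fit(X, y, sample_domain=sample_domain)

print(
    "Average accuracy without adaptation:",
    clf_noadapt.score(X, y, sample_domain=sample_domain, allow_source=True),
)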

# %% Multi-source and multi-target data


def get_multidomain_data(
n_samples_source=100,
n_samples_target=100,
noise=0.1,
    random_state=0,
n_sources=3,
n_targets=2,
):
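    """Generate a dataset with several source and several target domains.

    Additional domains are drawn by shifting the base concept drift dataset
    with random means and scales; nonnegative entries of ``sample_domain``
    mark source domains and negative entries mark target domains.
    """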
np.random.seed(random_state)
X, y, sample_domain = make_shifted_datasets(
n_samples_source=n_samples_source,
n_samples_target=n_samples_target,
noise=noise,
shift="concept_drift",
label="multiclass",
random_state=random_state,
)
for ns in range(n_sources - 1):
Xi, yi, sample_domaini = make_shifted_datasets(
n_samples_source=n_samples_source,
n_samples_target=n_samples_target,
noise=noise,
shift="concept_drift",
label="multiclass",
random_state=random_state + ns,
mean=np.random.randn(2),
sigma=np.random.rand(2) * 0.5 + 0.5,
)
Xs, Xt, ys, yt = source_target_split(Xi, yi, sample_domain=sample_domaini)
X = np.vstack([X, Xt])
y = np.hstack([y, yt])
sample_domain = np.hstack([sample_domain, np.ones(Xt.shape[0]) * (ns + 2)])

for nt in range(n_targets - 1):
Xi, yi, sample_domaini = make_shifted_datasets(
n_samples_source=n_samples_source,
n_samples_target=n_samples_target,
noise=noise,
shift="concept_drift",
label="multiclass",
random_state=random_state + nt + 42,
mean=np.random.randn(2),
sigma=np.random.rand(2) * 0.5 + 0.5,
)
Xs, Xt, ys, yt = source_target_split(Xi, yi, sample_domain=sample_domaini)
X = np.vstack([X, Xt])
y = np.hstack([y, yt])
sample_domain = np.hstack([sample_domain, -np.ones(Xt.shape[0]) * (nt + 1)])

return X, y, sample_domain


X, y, sample_domain = get_multidomain_data(
n_samples_source=50,
n_samples_target=50,
noise=0.1,
random_state=43,
n_sources=3,
n_targets=2,
)

Xs, Xt, ys, yt = source_target_split(X, y, sample_domain=sample_domain)


plt.figure(3, (10, 5))
plt.subplot(1, 2, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
ax = plt.axis()

plt.subplot(1, 2, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target domains")
plt.axis(ax)


# %%
clf = MultiLinearMongeAlignmentAdapter()
clf.fit(X, sample_domain=sample_domain)

X_adapt = clf.transform(X, sample_domain=sample_domain, allow_source=True)


plt.figure(4, (10, 3))
plt.subplot(1, 3, 1)
plt.scatter(Xs[:, 0], Xs[:, 1], c=ys, cmap="tab10", vmax=9, label="Source")
plt.title("Source data")
ax = plt.axis()

plt.subplot(1, 3, 2)
plt.scatter(Xt[:, 0], Xt[:, 1], c=yt, cmap="tab10", vmax=9, label="Target")
plt.title("Target data")
plt.axis(ax)

plt.subplot(1, 3, 3)
plt.scatter(
X_adapt[sample_domain >= 0, 0],
X_adapt[sample_domain >= 0, 1],
c=y[sample_domain >= 0],
marker="o",
cmap="tab10",
vmax=9,
label="Source",
alpha=0.5,
)
plt.scatter(
X_adapt[sample_domain < 0, 0],
X_adapt[sample_domain < 0, 1],
c=y[sample_domain < 0],
marker="x",
cmap="tab10",
vmax=9,
label="Target",
alpha=1,
)
plt.legend()
plt.axis(ax)
plt.title("Adapted data")

# %%
# Train a classifier on adapted data
# ----------------------------------

clf = make_da_pipeline(
MultiLinearMongeAlignmentAdapter(),
LogisticRegression(),
)

clf.fit(X, y, sample_domain=sample_domain)

print(
"Average accuracy on all domains:",
clf.score(X, y, sample_domain=sample_domain, allow_source=True),
)
