-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathssl_online.py
198 lines (165 loc) · 6.72 KB
/
ssl_online.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
from contextlib import contextmanager
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.utilities import rank_zero_warn
from torch import Tensor, nn
from torch.nn import functional as F
from torch.optim import Optimizer
from torchmetrics.functional import accuracy
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
from pl_bolts.utils.stability import under_review
import ray
from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.tune.schedulers import ASHAScheduler
@under_review()
class SSLOnlineEvaluator(Callback):  # pragma: no cover
    """Attaches a MLP for fine-tuning using the standard self-supervised protocol.

    The MLP head is trained by this callback on the representations returned by the
    LightningModule, using its own Adam optimizer; the representations themselves are
    computed under ``torch.no_grad()``, so no gradient from the head reaches the module.

    Example::

        # your datamodule must have 2 attributes
        dm = DataModule()
        dm.num_classes = ... # the num of classes in the datamodule
        dm.name = ... # name of the datamodule (e.g. ImageNet, STL10, CIFAR10)

        # your model must have 1 attribute
        model = Model()
        model.z_dim = ... # the representation dim

        online_eval = SSLOnlineEvaluator(
            z_dim=model.z_dim
        )
    """

    def __init__(
        self,
        z_dim: int,
        drop_p: float = 0.2,
        hidden_dim: Optional[int] = None,
        num_classes: Optional[int] = None,
        dataset: Optional[str] = None,
        isTune: bool = False,
    ):
        """
        Args:
            z_dim: Representation dimension
            drop_p: Dropout probability
            hidden_dim: Hidden dimension for the fine-tune MLP
            num_classes: Number of target classes. When ``None`` it is read from
                ``trainer.datamodule.num_classes`` in :meth:`setup`.
            dataset: Dataset name. When ``None`` it is read from ``trainer.datamodule.name``
                in :meth:`setup`. The value ``"stl10"`` changes how batches are unpacked
                (see :meth:`to_device`).
            isTune: When true, validation metrics are additionally reported through
                ``ray.air.session.report``.
        """
        super().__init__()

        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.drop_p = drop_p

        # Both are created lazily in ``on_fit_start`` once the target device is known.
        self.optimizer: Optional[Optimizer] = None
        self.online_evaluator: Optional[SSLEvaluator] = None

        self.num_classes: Optional[int] = num_classes
        self.dataset: Optional[str] = dataset
        self.isTune = isTune

        # Filled by ``load_state_dict`` and consumed in ``on_fit_start``.
        self._recovered_callback_state: Optional[Dict[str, Any]] = None

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
        """Fill in ``num_classes`` / ``dataset`` from the datamodule when not given explicitly."""
        if self.num_classes is None:
            self.num_classes = trainer.datamodule.num_classes
        if self.dataset is None:
            self.dataset = trainer.datamodule.name

    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Build the evaluator head and its optimizer, then restore any recovered state."""
        # must move to device after setup, as during setup, pl_module is still on cpu
        self.online_evaluator = SSLEvaluator(
            n_input=self.z_dim,
            n_classes=self.num_classes,
            p=self.drop_p,
            n_hidden=self.hidden_dim,
        ).to(pl_module.device)

        # switch for PL compatibility reasons
        accel = (
            trainer.accelerator_connector
            if hasattr(trainer, "accelerator_connector")
            else trainer._accelerator_connector
        )
        if accel.is_distributed:
            if accel.use_ddp:
                from torch.nn.parallel import DistributedDataParallel as DDP

                self.online_evaluator = DDP(self.online_evaluator, device_ids=[pl_module.device])
            elif accel.use_dp:
                from torch.nn.parallel import DataParallel as DP

                self.online_evaluator = DP(self.online_evaluator, device_ids=[pl_module.device])
            else:
                rank_zero_warn(
                    "Does not support this type of distributed accelerator. The online evaluator will not sync."
                )

        self.optimizer = torch.optim.Adam(self.online_evaluator.parameters(), lr=1e-4)

        # State is loaded only after the (possibly wrapped) head and optimizer exist.
        if self._recovered_callback_state is not None:
            self.online_evaluator.load_state_dict(self._recovered_callback_state["state_dict"])
            self.optimizer.load_state_dict(self._recovered_callback_state["optimizer_state"])

    def to_device(self, batch: Sequence, device: Union[str, torch.device]) -> Tuple[Tensor, Tensor]:
        """Extract the online-eval input and its labels from ``batch`` and move both to ``device``.

        Args:
            batch: ``(inputs, y)``; for the ``"stl10"`` dataset the pair is taken from ``batch[1]``.
            device: Target device.

        Returns:
            ``(x, y)`` where ``x`` is the last element of ``inputs``.
        """
        # get the labeled batch
        if self.dataset == "stl10":
            labeled_batch = batch[1]
            batch = labeled_batch

        inputs, y = batch

        # last input is for online eval
        x = inputs[-1]
        x = x.to(device)
        y = y.to(device)

        return x, y

    def shared_step(
        self,
        pl_module: LightningModule,
        batch: Sequence,
    ) -> Tuple[Tensor, Tensor]:
        """Run the head on the module's representations.

        Returns:
            ``(acc, mlp_loss)``: accuracy of the head's predictions and its cross-entropy loss.
        """
        # The module is only a feature extractor here: eval mode, no gradient tracking.
        with torch.no_grad():
            with set_training(pl_module, False):
                x, y = self.to_device(batch, pl_module.device)
                representations = pl_module(x).flatten(start_dim=1)

        # forward pass (outside ``no_grad`` so the head's loss can be back-propagated)
        mlp_logits = self.online_evaluator(representations)  # type: ignore[operator]
        mlp_loss = F.cross_entropy(mlp_logits, y)

        acc = accuracy(mlp_logits.softmax(-1), y)

        return acc, mlp_loss

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
    ) -> None:
        """Take one optimizer step on the head and log the per-step train metrics."""
        train_acc, mlp_loss = self.shared_step(pl_module, batch)

        # update finetune weights
        mlp_loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        pl_module.log("online_train_acc", train_acc, on_step=True, on_epoch=False)
        pl_module.log("online_train_loss", mlp_loss, on_step=True, on_epoch=False)

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Evaluate the head on a validation batch and log epoch-level metrics.

        ``dataloader_idx`` defaults to ``0`` so the hook can also be invoked by callers
        that omit it when there is a single validation dataloader.
        """
        # The head is not a submodule of ``pl_module``, so nothing else switches it to
        # eval mode; do it here so train-only layers (e.g. the dropout configured via
        # ``drop_p``) do not perturb the validation metrics. The previous mode is restored.
        with set_training(self.online_evaluator, False):
            val_acc, mlp_loss = self.shared_step(pl_module, batch)

        if self.isTune:
            # NOTE(review): this reports once per validation *batch*, not per epoch;
            # confirm that this is the cadence the Tune scheduler is expected to see.
            session.report({"online_val_acc": val_acc.item(), "online_val_loss": mlp_loss.item()})

        pl_module.log("online_val_acc", val_acc, on_step=False, on_epoch=True, sync_dist=True)
        pl_module.log("online_val_loss", mlp_loss, on_step=False, on_epoch=True, sync_dist=True)

    def state_dict(self) -> dict:
        """Return the head's weights and the optimizer state for checkpointing."""
        return {"state_dict": self.online_evaluator.state_dict(), "optimizer_state": self.optimizer.state_dict()}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Stash ``state_dict``; it is applied in ``on_fit_start`` once the head exists."""
        self._recovered_callback_state = state_dict
@under_review()
@contextmanager
def set_training(module: nn.Module, mode: bool):
    """Temporarily force ``module`` into training or evaluation mode.

    The mode the module had on entry is put back when the ``with`` block is left,
    whether it finishes normally or raises.

    Args:
        module: module whose training flag is switched
        mode: ``True`` selects training mode, ``False`` selects evaluation mode.

    Yields:
        The same ``module``, now in the requested mode.
    """
    previous_mode = module.training
    try:
        module.train(mode)
        yield module
    finally:
        # Always hand the module back in the state the caller had it in.
        module.train(previous_mode)