
Commit

reformat using black, update tokens.txt and improve token file handling
sfluegel committed Dec 8, 2023
1 parent b498f00 commit 4f932ef
Showing 18 changed files with 1,265 additions and 229 deletions.
98 changes: 74 additions & 24 deletions chebai/callbacks/epoch_metrics.py
@@ -23,40 +23,90 @@ def metric_name(self):
def apply_metric(self, target, pred):
raise NotImplementedError

def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
def on_train_epoch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
) -> None:
device = pl_module.device
self.train_labels = torch.empty(size=(0,), dtype=torch.int, device=device)
self.train_preds = torch.empty(size=(0,), dtype=torch.int, device=device)

def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT,
batch: Any, batch_idx: int) -> None:
self.train_labels = torch.concatenate((self.train_labels, outputs['labels'],))
self.train_preds = torch.concatenate((self.train_preds, outputs['preds'],))

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
pl_module.log(f'train_{self.metric_name}', self.apply_metric(self.train_labels, self.train_preds))

def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.val_labels = torch.empty(size=(0,), dtype=torch.int, device=pl_module.device)
self.val_preds = torch.empty(size=(0,), dtype=torch.int, device=pl_module.device)

def on_validation_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT,
batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
self.val_labels = torch.concatenate((self.val_labels, outputs['labels'],))
self.val_preds = torch.concatenate((self.val_preds, outputs['preds'],))

def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
pl_module.log(f'val_{self.metric_name}', self.apply_metric(self.val_labels, self.val_preds))
def on_train_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
) -> None:
self.train_labels = torch.concatenate(
(
self.train_labels,
outputs["labels"],
)
)
self.train_preds = torch.concatenate(
(
self.train_preds,
outputs["preds"],
)
)

def on_train_epoch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
) -> None:
pl_module.log(
f"train_{self.metric_name}",
self.apply_metric(self.train_labels, self.train_preds),
)

def on_validation_epoch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
) -> None:
self.val_labels = torch.empty(
size=(0,), dtype=torch.int, device=pl_module.device
)
self.val_preds = torch.empty(
size=(0,), dtype=torch.int, device=pl_module.device
)

def on_validation_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
self.val_labels = torch.concatenate(
(
self.val_labels,
outputs["labels"],
)
)
self.val_preds = torch.concatenate(
(
self.val_preds,
outputs["preds"],
)
)

def on_validation_epoch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
) -> None:
pl_module.log(
f"val_{self.metric_name}",
self.apply_metric(self.val_labels, self.val_preds),
)


class EpochLevelMacroF1(_EpochLevelMetric):

@property
def metric_name(self):
return 'ep_macro-f1'
return "ep_macro-f1"

def apply_metric(self, target, pred):
f1 = MultilabelF1Score(num_labels=self.num_labels, average='macro')
if target.get_device() != -1: # -1 == CPU
f1 = MultilabelF1Score(num_labels=self.num_labels, average="macro")
if target.get_device() != -1: # -1 == CPU
f1 = f1.to(device=target.get_device())
return f1(pred, target)
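
The callback above accumulates labels and predictions over a whole epoch and logs a macro-F1 score at epoch end. A minimal usage sketch, not part of this commit; the EpochLevelMacroF1(num_labels=...) constructor argument is assumed from the self.num_labels usage here and from the CLI argument link to model.init_args.out_dim further down:

import lightning.pytorch as pl
from chebai.callbacks.epoch_metrics import EpochLevelMacroF1

# Attach the epoch-level metric callback to a Trainer. The LightningModule's
# step outputs must be dicts containing "labels" and "preds" tensors, since the
# callback concatenates outputs["labels"] and outputs["preds"] after every batch.
trainer = pl.Trainer(
    max_epochs=10,
    callbacks=[EpochLevelMacroF1(num_labels=1400)],  # hypothetical label count
)
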
15 changes: 11 additions & 4 deletions chebai/callbacks/model_checkpoint.py
@@ -6,11 +6,14 @@
from lightning.fabric.utilities.cloud_io import _is_dir
from lightning.pytorch.utilities.rank_zero import rank_zero_info


class CustomModelCheckpoint(ModelCheckpoint):
"""Checkpoint class that resolves checkpoint paths s.t. for the CustomLogger, checkpoints get saved to the
same directory as the other logs"""

def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
def setup(
self, trainer: "Trainer", pl_module: "LightningModule", stage: str
) -> None:
"""Same as in parent class, duplicated to be able to call self.__resolve_ckpt_dir"""
if self.dirpath is not None:
self.dirpath = None
@@ -22,7 +25,11 @@ def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) ->

def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
"""Same as in parent class, duplicated because method in parent class is not accessible"""
if self.save_top_k != 0 and _is_dir(self._fs, dirpath, strict=True) and len(self._fs.ls(dirpath)) > 0:
if (
self.save_top_k != 0
and _is_dir(self._fs, dirpath, strict=True)
and len(self._fs.ls(dirpath)) > 0
):
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:
@@ -36,7 +43,7 @@ def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:
The path gets extended with subdirectory "checkpoints".
"""
print(f'Resolving checkpoint dir (custom)')
print(f"Resolving checkpoint dir (custom)")
if self.dirpath is not None:
# short circuit if dirpath was passed to ModelCheckpoint
return self.dirpath
Expand All @@ -57,5 +64,5 @@ def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:
# if no loggers, use default_root_dir
ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")

print(f'Now using checkpoint path {ckpt_path}')
print(f"Now using checkpoint path {ckpt_path}")
return ckpt_path
13 changes: 8 additions & 5 deletions chebai/cli.py
@@ -5,7 +5,6 @@


class ChebaiCLI(LightningCLI):

def __init__(self, *args, **kwargs):
super().__init__(trainer_class=InnerCVTrainer, *args, **kwargs)

@@ -16,7 +15,9 @@ def add_arguments_to_parser(self, parser):
"model.init_args.out_dim",
f"model.init_args.{kind}_metrics.init_args.metrics.{average}-f1.init_args.num_labels",
)
parser.link_arguments("model.init_args.out_dim", "trainer.callbacks.init_args.num_labels")
parser.link_arguments(
"model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
)
# parser.link_arguments('n_splits', 'data.init_args.inner_k_folds') # doesn't work but I don't know why

@staticmethod
@@ -28,10 +29,12 @@ def subcommands() -> Dict[str, Set[str]]:
"test": {"model", "dataloaders", "datamodule"},
"predict": {"model", "dataloaders", "datamodule"},
"cv_fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
"predict_from_file": {"model"}
"predict_from_file": {"model"},
}


def cli():
r = ChebaiCLI(save_config_kwargs={"config_filename": "lightning_config.yaml"},
parser_kwargs={"parser_mode": "omegaconf"})
r = ChebaiCLI(
save_config_kwargs={"config_filename": "lightning_config.yaml"},
parser_kwargs={"parser_mode": "omegaconf"},
)
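
The linked arguments mean the label count only has to be set once (as model.init_args.out_dim); it is propagated to the F1 metrics and to the callback's num_labels. A minimal programmatic launch sketch; the script name, subcommand, and config path are placeholders:

from chebai.cli import cli

if __name__ == "__main__":
    # e.g. `python run.py fit --config configs/my_config.yml`; custom
    # subcommands such as `cv_fit` and `predict_from_file` are those listed
    # in ChebaiCLI.subcommands above.
    cli()
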
44 changes: 32 additions & 12 deletions chebai/loggers/custom.py
@@ -10,23 +10,41 @@
class CustomLogger(WandbLogger):
"""Adds support for custom naming of runs and cross-validation"""

def __init__(self, save_dir: _PATH, name: str = "logs", version: Optional[Union[int, str]] = None, prefix: str = "",
fold: Optional[int] = None, project: Optional[str] = None, entity: Optional[str] = None,
offline: bool = False,
log_model: Union[Literal["all"], bool] = False, **kwargs):
def __init__(
self,
save_dir: _PATH,
name: str = "logs",
version: Optional[Union[int, str]] = None,
prefix: str = "",
fold: Optional[int] = None,
project: Optional[str] = None,
entity: Optional[str] = None,
offline: bool = False,
log_model: Union[Literal["all"], bool] = False,
**kwargs,
):
if version is None:
version = f'{datetime.now():%y%m%d-%H%M}'
version = f"{datetime.now():%y%m%d-%H%M}"
self._version = version
self._name = name
self._fold = fold
super().__init__(name=self.name, save_dir=save_dir, version=None, prefix=prefix,
log_model=log_model, entity=entity, project=project, offline=offline, **kwargs)
super().__init__(
name=self.name,
save_dir=save_dir,
version=None,
prefix=prefix,
log_model=log_model,
entity=entity,
project=project,
offline=offline,
**kwargs,
)

@property
def name(self) -> Optional[str]:
name = f'{self._name}_{self.version}'
name = f"{self._name}_{self.version}"
if self._fold is not None:
name += f'_fold{self._fold}'
name += f"_fold{self._fold}"
return name

@property
@@ -39,17 +57,19 @@ def root_dir(self) -> Optional[str]:

@property
def log_dir(self) -> str:
version = self.version if isinstance(self.version, str) else f"version_{self.version}"
version = (
self.version if isinstance(self.version, str) else f"version_{self.version}"
)
if self._fold is None:
return os.path.join(self.root_dir, version)
return os.path.join(self.root_dir, version, f'fold_{self._fold}')
return os.path.join(self.root_dir, version, f"fold_{self._fold}")

def set_fold(self, fold: int):
if fold != self._fold:
self._fold = fold
# start new experiment
wandb.finish()
self._wandb_init['name'] = self.name
self._wandb_init["name"] = self.name
self._experiment = None
_ = self.experiment

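A sketch of how the logger might be constructed for a cross-validation run; the keyword names come from the __init__ signature above, the concrete values are placeholders:

from chebai.loggers.custom import CustomLogger

logger = CustomLogger(
    save_dir="logs",
    name="electra_pretrain",  # run name; version defaults to a yymmdd-HHMM timestamp
    project="chebai",         # placeholder W&B project
    fold=0,
)
# Between folds, set_fold finishes the current W&B run and starts a new one
# whose name gets the suffix "_fold<k>".
logger.set_fold(1)
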
2 changes: 1 addition & 1 deletion chebai/loss/pretraining.py
@@ -1,5 +1,6 @@
import torch


class ElectraPreLoss(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -14,4 +15,3 @@ def forward(self, input, target, **loss_kwargs):
target=torch.argmax(disc_tar.int(), dim=-1), input=disc_pred
)
return gen_loss + disc_loss

38 changes: 20 additions & 18 deletions chebai/models/electra.py
@@ -53,22 +53,24 @@ def _process_labels_in_batch(self, batch):

def forward(self, data, **kwargs):
features = data["features"]
features = features.to(self.device).long() # this has been added for selfies, i neither know why it is needed now, nor why it wasnt needed before
features = features.to(
self.device
).long() # this has been added for selfies, i neither know why it is needed now, nor why it wasnt needed before
self.batch_size = batch_size = features.shape[0]
max_seq_len = features.shape[1]

mask = kwargs["mask"]
with torch.no_grad():
dis_tar = (
torch.rand((batch_size,), device=self.device) * torch.sum(mask, dim=-1)
torch.rand((batch_size,), device=self.device) * torch.sum(mask, dim=-1)
).int()
disc_tar_one_hot = torch.eq(
torch.arange(max_seq_len, device=self.device)[None, :], dis_tar[:, None]
)
gen_tar = features[disc_tar_one_hot]
gen_tar_one_hot = torch.eq(
torch.arange(self.generator_config.vocab_size, device=self.device)[
None, :
None, :
],
gen_tar[:, None],
)
@@ -101,7 +103,7 @@ def _get_prediction_and_labels(self, batch, labels, output):

def filter_dict(d, filter_key):
return {
str(k)[len(filter_key):]: v
str(k)[len(filter_key) :]: v
for k, v in d.items()
if str(k).startswith(filter_key)
}
@@ -122,10 +124,10 @@ def _process_batch(self, batch, batch_idx):
batch_first=True,
)
cls_tokens = (
torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze(
-1
)
* CLS_TOKEN
torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze(
-1
)
* CLS_TOKEN
)
return dict(
features=torch.cat((cls_tokens, batch.x), dim=1),
@@ -140,7 +142,7 @@ def as_pretrained(self):
return self.electra.electra

def __init__(
self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs
self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs
):
# Remove this property in order to prevent it from being stored as a
# hyper parameter
@@ -202,7 +204,7 @@ def forward(self, data, **kwargs):
try:
inp = self.electra.embeddings.forward(data["features"].int())
except RuntimeError as e:
print(f'RuntimeError at forward: {e}')
print(f"RuntimeError at forward: {e}")
print(f'data[features]: {data["features"]}')
raise Exception
inp = self.word_dropout(inp)
@@ -258,10 +260,10 @@ def _process_batch(self, batch, batch_idx):
batch_first=True,
)
cls_tokens = (
torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze(
-1
)
* CLS_TOKEN
torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze(
-1
)
* CLS_TOKEN
)
return dict(
features=torch.cat((cls_tokens, batch.x), dim=1),
@@ -296,7 +298,7 @@ def __init__(self, cone_dimensions=20, **kwargs):
model_dict = torch.load(fin, map_location=self.device)
if model_prefix:
state_dict = {
str(k)[len(model_prefix):]: v
str(k)[len(model_prefix) :]: v
for k, v in model_dict["state_dict"].items()
if str(k).startswith(model_prefix)
}
@@ -357,7 +359,7 @@ def forward(self, data, **kwargs):


def softabs(x, eps=0.01):
return (x ** 2 + eps) ** 0.5 - eps ** 0.5
return (x**2 + eps) ** 0.5 - eps**0.5


def anglify(x):
@@ -384,8 +386,8 @@ def in_cone_parts(vectors, cone_axes, cone_arcs):
dis = (torch.abs(turn(v, theta_L)) + torch.abs(turn(v, theta_R)) - cone_arc_ang)/(2*pi-cone_arc_ang)
return dis
"""
a = cone_axes - cone_arcs ** 2
b = cone_axes + cone_arcs ** 2
a = cone_axes - cone_arcs**2
b = cone_axes + cone_arcs**2
bigger_than_a = torch.sigmoid(vectors - a)
smaller_than_b = torch.sigmoid(b - vectors)
return bigger_than_a * smaller_than_b
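One note on the reformatted power operators: softabs is a smooth, differentiable stand-in for torch.abs that underestimates |x| by at most sqrt(eps) ≈ 0.1 at the default eps=0.01. A quick numeric check, not part of the commit:

import torch
from chebai.models.electra import softabs

x = torch.tensor([-2.0, 0.0, 2.0])
print(softabs(x))    # tensor([1.9025, 0.0000, 1.9025])
print(torch.abs(x))  # tensor([2., 0., 2.])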
