Merge pull request #6 from ChEB-AI/features-sfluegel
Various features
Showing 50 changed files with 4,897 additions and 1,993 deletions.
@@ -1,4 +1,5 @@

import os

import torch

MODULE_PATH = os.path.abspath(os.path.dirname(__file__))
Empty file.
@@ -0,0 +1,49 @@
import torch
import torchmetrics


def custom_reduce_fx(input):
    print(f"called reduce (device: {input.device})")
    return torch.sum(input, dim=0)


class MacroF1(torchmetrics.Metric):
    def __init__(self, num_labels, dist_sync_on_step=False, threshold=0.5):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.add_state(
            "true_positives",
            default=torch.zeros(num_labels, dtype=torch.int),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "positive_predictions",
            default=torch.zeros(num_labels, dtype=torch.int),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "positive_labels",
            default=torch.zeros(num_labels, dtype=torch.int),
            dist_reduce_fx="sum",
        )
        self.threshold = threshold

    def update(self, preds: torch.Tensor, labels: torch.Tensor):
        tps = torch.sum(
            torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0
        )
        self.true_positives += tps
        self.positive_predictions += torch.sum(preds > self.threshold, dim=0)
        self.positive_labels += torch.sum(labels, dim=0)

    def compute(self):
        # ignore classes without positive labels;
        # classes with positive labels but no positive predictions get a precision of
        # "nan" (0 divided by 0), which propagates to classwise_f1 and is then turned into 0
        mask = self.positive_labels != 0
        precision = self.true_positives[mask] / self.positive_predictions[mask]
        recall = self.true_positives[mask] / self.positive_labels[mask]
        classwise_f1 = 2 * precision * recall / (precision + recall)
        # if (precision and recall are 0) or (precision is nan), set f1 to 0
        classwise_f1 = classwise_f1.nan_to_num()
        return torch.mean(classwise_f1)
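
For a quick sanity check, here is a hypothetical usage sketch of MacroF1 (the shapes, values, and integer label dtype are illustrative assumptions, not part of the commit):

import torch

# hypothetical smoke test: two samples, three labels, default threshold of 0.5
metric = MacroF1(num_labels=3)
preds = torch.tensor([[0.9, 0.2, 0.7], [0.1, 0.8, 0.6]])
labels = torch.tensor([[1, 0, 1], [0, 1, 0]])  # integer labels
metric.update(preds, labels)
print(metric.compute())  # mean F1 over classes with at least one positive label

Note that update() expects integer (or boolean) labels here, since positive_labels is accumulated into an integer state.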
@@ -0,0 +1,63 @@
import os

from lightning.fabric.utilities.cloud_io import _is_dir
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.utilities.rank_zero import rank_zero_info
from lightning_utilities.core.rank_zero import rank_zero_warn


class CustomModelCheckpoint(ModelCheckpoint):
    """Checkpoint class that resolves checkpoint paths so that, when using the
    CustomLogger, checkpoints are saved to the same directory as the other logs."""

    def setup(
        self, trainer: "Trainer", pl_module: "LightningModule", stage: str
    ) -> None:
        """Same as in parent class, duplicated to be able to call self.__resolve_ckpt_dir"""
        if self.dirpath is not None:
            # reset dirpath so that it is resolved again instead of short-circuiting below
            self.dirpath = None
        dirpath = self.__resolve_ckpt_dir(trainer)
        dirpath = trainer.strategy.broadcast(dirpath)
        self.dirpath = dirpath
        if trainer.is_global_zero and stage == "fit":
            self.__warn_if_dir_not_empty(self.dirpath)

    def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
        """Same as in parent class, duplicated because the parent's method is name-mangled
        and therefore not accessible from this subclass"""
        if (
            self.save_top_k != 0
            and _is_dir(self._fs, dirpath, strict=True)
            and len(self._fs.ls(dirpath)) > 0
        ):
            rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

    def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH:
        """Overwritten for compatibility with wandb -> saves checkpoints in the same dir as the wandb logs"""
        rank_zero_info("Resolving checkpoint dir (custom)")
        if self.dirpath is not None:
            # short circuit if dirpath was passed to ModelCheckpoint
            return self.dirpath
        if len(trainer.loggers) > 0:
            if trainer.loggers[0].save_dir is not None:
                save_dir = trainer.loggers[0].save_dir
            else:
                save_dir = trainer.default_root_dir
            name = trainer.loggers[0].name
            version = trainer.loggers[0].version
            version = version if isinstance(version, str) else f"version_{version}"
            logger = trainer.loggers[0]
            if isinstance(logger, WandbLogger) and isinstance(
                logger.experiment.dir, str
            ):
                ckpt_path = os.path.join(logger.experiment.dir, "checkpoints")
            else:
                ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
        else:
            # if there are no loggers, use default_root_dir
            ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")

        rank_zero_info(f"Now using checkpoint path {ckpt_path}")
        return ckpt_path
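
As a hedged illustration (not part of the commit; argument values are placeholders), the callback could be wired into a trainer roughly like this:

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger

# hypothetical wiring; project/save_dir/monitor values are placeholders
logger = WandbLogger(project="chebai", save_dir="logs")
ckpt = CustomModelCheckpoint(monitor="val_loss", save_top_k=3)
trainer = Trainer(logger=logger, callbacks=[ckpt])
# at the start of fit(), setup() re-resolves ckpt.dirpath, so checkpoints land in
# <wandb run dir>/checkpoints instead of Lightning's default location

The key design point is that dirpath is reset and re-resolved on every setup() call, so the checkpoint directory follows whatever run directory wandb has created.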
@@ -1,15 +1,39 @@
-from lightning.pytorch.cli import LightningCLI
+from typing import Dict, Set
+
+from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
+
+from chebai.trainer.CustomTrainer import CustomTrainer


 class ChebaiCLI(LightningCLI):
-    def add_arguments_to_parser(self, parser):
+    def __init__(self, *args, **kwargs):
+        super().__init__(trainer_class=CustomTrainer, *args, **kwargs)
+
+    def add_arguments_to_parser(self, parser: LightningArgumentParser):
         for kind in ("train", "val", "test"):
             for average in ("micro", "macro"):
                 parser.link_arguments(
                     "model.init_args.out_dim",
                     f"model.init_args.{kind}_metrics.init_args.metrics.{average}-f1.init_args.num_labels",
                 )
         parser.link_arguments(
             "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
         )
+
+    @staticmethod
+    def subcommands() -> Dict[str, Set[str]]:
+        """Defines the list of available subcommands and the arguments to skip."""
+        return {
+            "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
+            "validate": {"model", "dataloaders", "datamodule"},
+            "test": {"model", "dataloaders", "datamodule"},
+            "predict": {"model", "dataloaders", "datamodule"},
+            "predict_from_file": {"model"},
+        }


 def cli():
-    r = ChebaiCLI(save_config_callback=None, parser_kwargs={"parser_mode": "omegaconf"})
+    r = ChebaiCLI(
+        save_config_kwargs={"config_filename": "lightning_config.yaml"},
+        parser_kwargs={"parser_mode": "omegaconf"},
+    )
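
For orientation, a hypothetical entry-point usage (the module path and config file name are assumptions, not taken from the commit):

# hypothetical usage; "chebai.cli" and the config path are assumed names
# roughly equivalent to running: python -m chebai fit --config configs/example.yml
from chebai.cli import cli

if __name__ == "__main__":
    cli()  # LightningCLI parses the subcommand and omegaconf-style YAML config from argv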
Empty file.
@@ -0,0 +1,83 @@
from datetime import datetime
from typing import Literal, Optional, Union
import os

from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
import wandb


class CustomLogger(WandbLogger):
    """Adds support for custom naming of runs and cross-validation"""

    def __init__(
        self,
        save_dir: _PATH,
        name: str = "logs",
        version: Optional[Union[int, str]] = None,
        prefix: str = "",
        fold: Optional[int] = None,
        project: Optional[str] = None,
        entity: Optional[str] = None,
        offline: bool = False,
        log_model: Union[Literal["all"], bool] = False,
        **kwargs,
    ):
        if version is None:
            version = f"{datetime.now():%y%m%d-%H%M}"
        self._version = version
        self._name = name
        self._fold = fold
        super().__init__(
            name=self.name,
            save_dir=save_dir,
            version=None,
            prefix=prefix,
            log_model=log_model,
            entity=entity,
            project=project,
            offline=offline,
            **kwargs,
        )

    @property
    def name(self) -> Optional[str]:
        name = f"{self._name}_{self.version}"
        if self._fold is not None:
            name += f"_fold{self._fold}"
        return name

    @property
    def version(self) -> Optional[str]:
        return self._version

    @property
    def root_dir(self) -> Optional[str]:
        return os.path.join(self.save_dir, self.name)

    @property
    def log_dir(self) -> str:
        version = (
            self.version if isinstance(self.version, str) else f"version_{self.version}"
        )
        if self._fold is None:
            return os.path.join(self.root_dir, version)
        return os.path.join(self.root_dir, version, f"fold_{self._fold}")

    def set_fold(self, fold: int):
        if fold != self._fold:
            self._fold = fold
            # start new experiment
            wandb.finish()
            self._wandb_init["name"] = self.name
            self._experiment = None
            _ = self.experiment

    @property
    def fold(self):
        return self._fold

    def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
        # don't save checkpoint as wandb artifact
        pass
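
To make the fold mechanics concrete, here is a hypothetical cross-validation loop (names and values are illustrative, assuming the CustomLogger defined above):

# hypothetical k-fold loop; save_dir/name/project values are placeholders
logger = CustomLogger(save_dir="logs", name="cv_experiment", project="chebai")
for fold in range(5):
    # finishes the previous wandb run and starts "<name>_<version>_fold<fold>"
    logger.set_fold(fold)
    # ... build a Trainer with this logger and fit on this fold's split ...

Each call to set_fold() with a new fold index closes the current wandb run and lazily initializes a fresh one the next time self.experiment is accessed.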