From 4f932eff9facf153cf73daf2d87ea406c5180c60 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 8 Dec 2023 16:33:47 +0100 Subject: [PATCH] reformat using black, update tokens.txt and improve token file handling --- chebai/callbacks/epoch_metrics.py | 98 ++- chebai/callbacks/model_checkpoint.py | 15 +- chebai/cli.py | 13 +- chebai/loggers/custom.py | 44 +- chebai/loss/pretraining.py | 2 +- chebai/models/electra.py | 38 +- .../bin/deepsmiles_token/tokens.txt | 736 ++++++++++++++++++ .../preprocessing/bin/smiles_token/tokens.txt | 111 +++ chebai/preprocessing/collate.py | 4 +- chebai/preprocessing/datasets/base.py | 31 +- chebai/preprocessing/datasets/pubchem.py | 6 +- chebai/preprocessing/reader.py | 5 +- chebai/result/classification.py | 47 +- chebai/result/pretraining.py | 43 +- chebai/trainer/InnerCVTrainer.py | 48 +- demo_process_results.ipynb | 210 ++--- process_results_old_chebi.ipynb | 41 +- setup.py | 2 +- 18 files changed, 1265 insertions(+), 229 deletions(-) diff --git a/chebai/callbacks/epoch_metrics.py b/chebai/callbacks/epoch_metrics.py index 3394cb06..09369da6 100644 --- a/chebai/callbacks/epoch_metrics.py +++ b/chebai/callbacks/epoch_metrics.py @@ -23,40 +23,90 @@ def metric_name(self): def apply_metric(self, target, pred): raise NotImplementedError - def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_epoch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: device = pl_module.device self.train_labels = torch.empty(size=(0,), dtype=torch.int, device=device) self.train_preds = torch.empty(size=(0,), dtype=torch.int, device=device) - def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, - batch: Any, batch_idx: int) -> None: - self.train_labels = torch.concatenate((self.train_labels, outputs['labels'],)) - self.train_preds = torch.concatenate((self.train_preds, outputs['preds'],)) - - def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - pl_module.log(f'train_{self.metric_name}', self.apply_metric(self.train_labels, self.train_preds)) - - def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - self.val_labels = torch.empty(size=(0,), dtype=torch.int, device=pl_module.device) - self.val_preds = torch.empty(size=(0,), dtype=torch.int, device=pl_module.device) - - def on_validation_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, - batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: - self.val_labels = torch.concatenate((self.val_labels, outputs['labels'],)) - self.val_preds = torch.concatenate((self.val_preds, outputs['preds'],)) - - def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - pl_module.log(f'val_{self.metric_name}', self.apply_metric(self.val_labels, self.val_preds)) + def on_train_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + ) -> None: + self.train_labels = torch.concatenate( + ( + self.train_labels, + outputs["labels"], + ) + ) + self.train_preds = torch.concatenate( + ( + self.train_preds, + outputs["preds"], + ) + ) + + def on_train_epoch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + pl_module.log( + f"train_{self.metric_name}", + self.apply_metric(self.train_labels, self.train_preds), + ) + + def 
on_validation_epoch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + self.val_labels = torch.empty( + size=(0,), dtype=torch.int, device=pl_module.device + ) + self.val_preds = torch.empty( + size=(0,), dtype=torch.int, device=pl_module.device + ) + + def on_validation_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + self.val_labels = torch.concatenate( + ( + self.val_labels, + outputs["labels"], + ) + ) + self.val_preds = torch.concatenate( + ( + self.val_preds, + outputs["preds"], + ) + ) + + def on_validation_epoch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + pl_module.log( + f"val_{self.metric_name}", + self.apply_metric(self.val_labels, self.val_preds), + ) class EpochLevelMacroF1(_EpochLevelMetric): - @property def metric_name(self): - return 'ep_macro-f1' + return "ep_macro-f1" def apply_metric(self, target, pred): - f1 = MultilabelF1Score(num_labels=self.num_labels, average='macro') - if target.get_device() != -1: # -1 == CPU + f1 = MultilabelF1Score(num_labels=self.num_labels, average="macro") + if target.get_device() != -1: # -1 == CPU f1 = f1.to(device=target.get_device()) return f1(pred, target) diff --git a/chebai/callbacks/model_checkpoint.py b/chebai/callbacks/model_checkpoint.py index b5740438..70119981 100644 --- a/chebai/callbacks/model_checkpoint.py +++ b/chebai/callbacks/model_checkpoint.py @@ -6,11 +6,14 @@ from lightning.fabric.utilities.cloud_io import _is_dir from lightning.pytorch.utilities.rank_zero import rank_zero_info + class CustomModelCheckpoint(ModelCheckpoint): """Checkpoint class that resolves checkpoint paths s.t. for the CustomLogger, checkpoints get saved to the same directory as the other logs""" - def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None: + def setup( + self, trainer: "Trainer", pl_module: "LightningModule", stage: str + ) -> None: """Same as in parent class, duplicated to be able to call self.__resolve_ckpt_dir""" if self.dirpath is not None: self.dirpath = None @@ -22,7 +25,11 @@ def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None: """Same as in parent class, duplicated because method in parent class is not accessible""" - if self.save_top_k != 0 and _is_dir(self._fs, dirpath, strict=True) and len(self._fs.ls(dirpath)) > 0: + if ( + self.save_top_k != 0 + and _is_dir(self._fs, dirpath, strict=True) + and len(self._fs.ls(dirpath)) > 0 + ): rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH: @@ -36,7 +43,7 @@ def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH: The path gets extended with subdirectory "checkpoints". 
""" - print(f'Resolving checkpoint dir (custom)') + print(f"Resolving checkpoint dir (custom)") if self.dirpath is not None: # short circuit if dirpath was passed to ModelCheckpoint return self.dirpath @@ -57,5 +64,5 @@ def __resolve_ckpt_dir(self, trainer: "Trainer") -> _PATH: # if no loggers, use default_root_dir ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints") - print(f'Now using checkpoint path {ckpt_path}') + print(f"Now using checkpoint path {ckpt_path}") return ckpt_path diff --git a/chebai/cli.py b/chebai/cli.py index d1f24e32..cb0292df 100644 --- a/chebai/cli.py +++ b/chebai/cli.py @@ -5,7 +5,6 @@ class ChebaiCLI(LightningCLI): - def __init__(self, *args, **kwargs): super().__init__(trainer_class=InnerCVTrainer, *args, **kwargs) @@ -16,7 +15,9 @@ def add_arguments_to_parser(self, parser): "model.init_args.out_dim", f"model.init_args.{kind}_metrics.init_args.metrics.{average}-f1.init_args.num_labels", ) - parser.link_arguments("model.init_args.out_dim", "trainer.callbacks.init_args.num_labels") + parser.link_arguments( + "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels" + ) # parser.link_arguments('n_splits', 'data.init_args.inner_k_folds') # doesn't work but I don't know why @staticmethod @@ -28,10 +29,12 @@ def subcommands() -> Dict[str, Set[str]]: "test": {"model", "dataloaders", "datamodule"}, "predict": {"model", "dataloaders", "datamodule"}, "cv_fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"}, - "predict_from_file": {"model"} + "predict_from_file": {"model"}, } def cli(): - r = ChebaiCLI(save_config_kwargs={"config_filename": "lightning_config.yaml"}, - parser_kwargs={"parser_mode": "omegaconf"}) + r = ChebaiCLI( + save_config_kwargs={"config_filename": "lightning_config.yaml"}, + parser_kwargs={"parser_mode": "omegaconf"}, + ) diff --git a/chebai/loggers/custom.py b/chebai/loggers/custom.py index e88b8d42..18da9b09 100644 --- a/chebai/loggers/custom.py +++ b/chebai/loggers/custom.py @@ -10,23 +10,41 @@ class CustomLogger(WandbLogger): """Adds support for custom naming of runs and cross-validation""" - def __init__(self, save_dir: _PATH, name: str = "logs", version: Optional[Union[int, str]] = None, prefix: str = "", - fold: Optional[int] = None, project: Optional[str] = None, entity: Optional[str] = None, - offline: bool = False, - log_model: Union[Literal["all"], bool] = False, **kwargs): + def __init__( + self, + save_dir: _PATH, + name: str = "logs", + version: Optional[Union[int, str]] = None, + prefix: str = "", + fold: Optional[int] = None, + project: Optional[str] = None, + entity: Optional[str] = None, + offline: bool = False, + log_model: Union[Literal["all"], bool] = False, + **kwargs, + ): if version is None: - version = f'{datetime.now():%y%m%d-%H%M}' + version = f"{datetime.now():%y%m%d-%H%M}" self._version = version self._name = name self._fold = fold - super().__init__(name=self.name, save_dir=save_dir, version=None, prefix=prefix, - log_model=log_model, entity=entity, project=project, offline=offline, **kwargs) + super().__init__( + name=self.name, + save_dir=save_dir, + version=None, + prefix=prefix, + log_model=log_model, + entity=entity, + project=project, + offline=offline, + **kwargs, + ) @property def name(self) -> Optional[str]: - name = f'{self._name}_{self.version}' + name = f"{self._name}_{self.version}" if self._fold is not None: - name += f'_fold{self._fold}' + name += f"_fold{self._fold}" return name @property @@ -39,17 +57,19 @@ def root_dir(self) -> Optional[str]: @property def 
log_dir(self) -> str: - version = self.version if isinstance(self.version, str) else f"version_{self.version}" + version = ( + self.version if isinstance(self.version, str) else f"version_{self.version}" + ) if self._fold is None: return os.path.join(self.root_dir, version) - return os.path.join(self.root_dir, version, f'fold_{self._fold}') + return os.path.join(self.root_dir, version, f"fold_{self._fold}") def set_fold(self, fold: int): if fold != self._fold: self._fold = fold # start new experiment wandb.finish() - self._wandb_init['name'] = self.name + self._wandb_init["name"] = self.name self._experiment = None _ = self.experiment diff --git a/chebai/loss/pretraining.py b/chebai/loss/pretraining.py index dee9aa00..d5af9bb1 100644 --- a/chebai/loss/pretraining.py +++ b/chebai/loss/pretraining.py @@ -1,5 +1,6 @@ import torch + class ElectraPreLoss(torch.nn.Module): def __init__(self): super().__init__() @@ -14,4 +15,3 @@ def forward(self, input, target, **loss_kwargs): target=torch.argmax(disc_tar.int(), dim=-1), input=disc_pred ) return gen_loss + disc_loss - diff --git a/chebai/models/electra.py b/chebai/models/electra.py index f61052c9..5d612604 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -53,14 +53,16 @@ def _process_labels_in_batch(self, batch): def forward(self, data, **kwargs): features = data["features"] - features = features.to(self.device).long() # this has been added for selfies, i neither know why it is needed now, nor why it wasnt needed before + features = features.to( + self.device + ).long() # this has been added for selfies, i neither know why it is needed now, nor why it wasnt needed before self.batch_size = batch_size = features.shape[0] max_seq_len = features.shape[1] mask = kwargs["mask"] with torch.no_grad(): dis_tar = ( - torch.rand((batch_size,), device=self.device) * torch.sum(mask, dim=-1) + torch.rand((batch_size,), device=self.device) * torch.sum(mask, dim=-1) ).int() disc_tar_one_hot = torch.eq( torch.arange(max_seq_len, device=self.device)[None, :], dis_tar[:, None] @@ -68,7 +70,7 @@ def forward(self, data, **kwargs): gen_tar = features[disc_tar_one_hot] gen_tar_one_hot = torch.eq( torch.arange(self.generator_config.vocab_size, device=self.device)[ - None, : + None, : ], gen_tar[:, None], ) @@ -101,7 +103,7 @@ def _get_prediction_and_labels(self, batch, labels, output): def filter_dict(d, filter_key): return { - str(k)[len(filter_key):]: v + str(k)[len(filter_key) :]: v for k, v in d.items() if str(k).startswith(filter_key) } @@ -122,10 +124,10 @@ def _process_batch(self, batch, batch_idx): batch_first=True, ) cls_tokens = ( - torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze( - -1 - ) - * CLS_TOKEN + torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze( + -1 + ) + * CLS_TOKEN ) return dict( features=torch.cat((cls_tokens, batch.x), dim=1), @@ -140,7 +142,7 @@ def as_pretrained(self): return self.electra.electra def __init__( - self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs + self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs ): # Remove this property in order to prevent it from being stored as a # hyper parameter @@ -202,7 +204,7 @@ def forward(self, data, **kwargs): try: inp = self.electra.embeddings.forward(data["features"].int()) except RuntimeError as e: - print(f'RuntimeError at forward: {e}') + print(f"RuntimeError at forward: {e}") print(f'data[features]: {data["features"]}') raise Exception inp = self.word_dropout(inp) @@ 
-258,10 +260,10 @@ def _process_batch(self, batch, batch_idx): batch_first=True, ) cls_tokens = ( - torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze( - -1 - ) - * CLS_TOKEN + torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze( + -1 + ) + * CLS_TOKEN ) return dict( features=torch.cat((cls_tokens, batch.x), dim=1), @@ -296,7 +298,7 @@ def __init__(self, cone_dimensions=20, **kwargs): model_dict = torch.load(fin, map_location=self.device) if model_prefix: state_dict = { - str(k)[len(model_prefix):]: v + str(k)[len(model_prefix) :]: v for k, v in model_dict["state_dict"].items() if str(k).startswith(model_prefix) } @@ -357,7 +359,7 @@ def forward(self, data, **kwargs): def softabs(x, eps=0.01): - return (x ** 2 + eps) ** 0.5 - eps ** 0.5 + return (x**2 + eps) ** 0.5 - eps**0.5 def anglify(x): @@ -384,8 +386,8 @@ def in_cone_parts(vectors, cone_axes, cone_arcs): dis = (torch.abs(turn(v, theta_L)) + torch.abs(turn(v, theta_R)) - cone_arc_ang)/(2*pi-cone_arc_ang) return dis """ - a = cone_axes - cone_arcs ** 2 - b = cone_axes + cone_arcs ** 2 + a = cone_axes - cone_arcs**2 + b = cone_axes + cone_arcs**2 bigger_than_a = torch.sigmoid(vectors - a) smaller_than_b = torch.sigmoid(b - vectors) return bigger_than_a * smaller_than_b diff --git a/chebai/preprocessing/bin/deepsmiles_token/tokens.txt b/chebai/preprocessing/bin/deepsmiles_token/tokens.txt index e69de29b..9214efc9 100644 --- a/chebai/preprocessing/bin/deepsmiles_token/tokens.txt +++ b/chebai/preprocessing/bin/deepsmiles_token/tokens.txt @@ -0,0 +1,736 @@ +[*-] +[F-] +. +[K+] +[Cl-] +O +[Mg++] +[Ca++] +[Sr++] +Cl +[201Tl] +[Cs+] +[Mn] +[Pb] +) +C +[C@@H] +c +[nH] +n +5 +6 +[H+] +N +\ +S += +# +s +[N+] +[O-] +[H] +[C@@] +[C@] +10 +14 +9 +[C@H] +13 +11 +4 +F +7 +[NH2+] +3 +[NH+] +18 +Br +15 +/ +[o+] +- +o +[N] +[Ru++] +25 +[n+] +8 +[Br-] +[Na+] +[I-] +[B-] +[Cd++] +[Pb++] +[Cr] +[As] +[Y++] +P +[Ca+2] +[Co+2] +[Fe++] +[Mn++] +[Zn++] +[Li+] +[Au-] +17 +[N-] +12 +[NH3+] +16 +20 +B +[W] +[Mo] +I +[Pd-2] +[S-] +[OH-] +[Te] +[Fe+3] +[Fe] +[Co+3] +[Co+] +[W+4] +[Al+] +[B+] +[Ru+4] +[Sn+4] +[B+3] +[Tl+3] +[U+3] +[*+5] +[Co++] +* +30 +32 +[As+] +[n-] +[S@@] +[Al+3] +[Ca] +[V++] +[Rb+] +[Sn] +[C-] +[Mg--] +22 +[Fe--] +[*] +[Pt] +[2H] +[S+] +[Hg] +[Mg] +[Mg-2] +[NH4+] +[SeH] +[Se] +[Co--] +[Si] +[O] +[Br] +[V] +[Se-] +[Fe-] +[Pd--] +[Ru--] +[Bi-3] +[Ni--] +[Hg--] +[Al-3] +[Mo-4] +[Be-] +[P] +[P-] +[Te-] +[NH-] +19 +21 +24 +[Ni-2] +[nH+] +72 +[cH-] +[Fe+] +[CH-] +[OH+] +[Fe-3] +[N@+] +[N@@+] +[P@] +[Sb+] +[Xe+] +[Ir-3] +[F,Cl,Br,I-] +[O+] +[s+] +[Mo+4] +[Cu+] +[Se+4] +[I] +[n] +[C] +[B] +[N@@] +23 +28 +42 +[S@] +48 +33 +26 +29 +27 +36 +34 +31 +[18F] +[13CH2] +[P--] +[AsH] +[Sb] +[Bi+] +[Bi] +[Ge] +[Ge+] +[He++] +[Ba++] +[Be+2] +[1H] +[3H] +[Al] +[Al-] +[B++] +[P@@] +[203Hg] +44 +43 +35 +56 +52 +51 +40 +41 +80 +76 +75 +68 +67 +60 +59 +[13CH3] +38 +[13C@H] +46 +62 +[Tl] +[Ne] +[Br+] +[P-3] +[PH2] +[SiH] +[Co-3] +[Co] +37 +[Mg+2] +61 +[N@] +[224Ra] +[P+] +[Cu++] +[Ti] +[Ni] +[S+2] +[S+6] +[Fe+2] +[As-3] +[As+3] +[Mo+3] +[Hg+] +[Li] +[Rh] +[Zn] +[131I] +[CH2-] +[Tc] +[Pb--] +[Pb-] +[Zr] +[At] +[Kr] +[220Rn] +[S] +[C-4] +[Mo-3] +[W-4] +[63Cu] +[Ag++] +[Au] +[Rh+3] +[U] +[Gd+3] +[Ce--] +[Eu] +[Cd] +[F+] +[Ni-] +[Fr] +[25Mg] +[Be] +[Sr] +[205Tl] +[10C] +[11C] +[14C] +[33S] +[216Po] +[218Po] +[195Po] +[208Po] +[19F] +[111Cd] +[Os] +[Hs] +[Ds] +[Pr] +[Np] +[Am] +[Bk] +[Md] +[Ag+] +[Cu] +[Y+3] +[Yb+3] +[223Ra] +[Sn++] +[Pt+2] +[I+] +[LiH] +[At-] +[Cr-] +[K-] +[Li-] +[Na-] +[Si-] +[Rb-] +[Fr-] +[Cs-] +[Zn-] +[*-4] 
+[Ge-4] +[*++] +[4He++] +[4He] +[3He++] +[3He] +[Cr++] +[Eu++] +[Ge++] +[O--] +[Os++] +[Pt++] +[U++] +[Be++] +[Gd++] +[La++] +[Se++] +[Ti+4] +[O-2] +[Cu+2] +[S--] +[Mn+2] +[Hg++] +[Ni++] +[Zn+2] +[125I] +[Au-3] +[Pt--] +[Se--] +[Cl] +[NH2-] +[Fe-4] +[Ni-4] +[Os--] +[Fr+] +[Ca+] +[Ba+] +[Mg+] +[Sr+] +[Be+] +[As-] +[Cr+4] +[Cr+5] +[Cr+6] +[Cr+3] +[Cu+3] +[Ni+3] +[Ni+] +[Mn+4] +[Mn+3] +[V+] +[V+3] +[V+4] +[V+5] +[Zn+] +[Mo+6] +[Mo+5] +[Pt+4] +[Pt+3] +[W+6] +[W+5] +[Pb+4] +[Bi+3] +[Sm+3] +[*+] +[1H+] +[3H+] +[2H+] +[Au+] +[In+] +[Si+] +[Tl+] +[*+4] +[Ge+4] +[Os+4] +[Si+4] +[U+4] +[Ce+4] +[Gd+4] +[C+4] +[*+3] +[Sb+3] +[Eu+3] +[Au+3] +[In+3] +[Lu+3] +[Os+3] +[Ru+3] +[Tb+3] +[La+3] +[Ce+3] +[Ho+3] +[Ir+3] +[Ga+3] +[U+5] +[*+6] +[U+6] +[O+6] +[C+] +[Pr+3] +[Cu-2] +[Zn-2] +[Ba+2] +[V+2] +[Cr--] +[Hg-] +[Pt-] +[c-] +[Gd-] +[Fe-2] +[S@+] +[13C] +[Cr-3] +[Re] +[K] +[Na] +[W-] +[W--] +[Cu--] +[Cd--] +[Sn-] +[Rh-3] +[Ir--] +[Rh--] +[Ag--] +[Mo-] +[Ag-3] +[Cr-4] +[Ni-3] +[W-3] +[Ag-] +[Be--] +[Mg-] +[Mo--] +[Ru-] +[N--] +[Sb-] +[Bi-] +[C--] +[Co-4] +39 +[Se+] +[Ru] +[Mo++] +[Ti-2] +[c] +[N@@H+] +[N@H+] +[AsH3+] +[Cl+] +[Te+] +[He+] +[Kr+] +[Ar+] +[Ne+] +[Rn+] +[S@@+] +[CH+] +[OH2+] +[Te--] +[CH] +[NH] +[CH2] +[Ce] +[U+] +[18O] +[O++] +[H-] +[TeH] +[Po] +49 +55 +50 +85 +47 +63 +53 +[14C@] +[13C@@H] +[15NH2] +[13CH] +[SiH3] +[N++] +[PH3+] +[P++] +[As++] +[As--] +[AsH-] +[B--] +[Ge-] +[Sn+] +[Pb+] +[Ba] +[2H-] +[3H-] +[1H-] +[Al--] +[al] +[al--] +[B-3] +0 +2 +[N-3] +p +[As+5] +[F,Cl,Br,I] +[99Tc] +[F] +[197Hg] +64 +45 +[Ir] +[123I] +[cH+] +[se] +[te] +[pH] +[SiH2] +[In] +[Ga] +[He] +[Xe] +[129Xe] +[Ar] +[Rn] +: +[He-] +[Au++] +[75Se] +[Co-2] +[Pd] +[Gd] +[SH+] +[67Ga+3] +[59Fe+3] +[226Ra] +[228Ra] +[KH] +[NaH] +[Te+4] +[P+3] +[P+5] +[S-2] +[S+4] +[Mn+7] +[Ni+2] +[Se-2] +[Se+6] +[Sr+2] +[Cd+2] +[AlH3] +[Rh++] +[13C@@] +[Si--] +[Si-4] +[Sn--] +[207Pb] +[121Sb] +[123Sb] +[9Be] +[6He] +[8He] +[222Rn] +[219Rn] +[Og] +[Mo-5] +[Mn-] +[Re+] +[Os+] +[Ru-3] +[Ru-4] +[Ag] +[197Au] +[V-4] +[V--] +[V-] +[51V] +[Ni+4] +[Th] +[67Zn] +[Ta] +[Ta-] +[N-2] +[39K] +[23Na] +[7Li] +[6Li] +[Rb] +[87Rb] +[85Rb] +[Cs] +[135Cs] +[137Cs] +[43Ca] +[88Sr] +[Ra] +[11B] +[10B] +[27Al] +[26Al] +[28Al] +[203Tl] +[199Tl] +[12C] +[29Si] +[28Si] +[30Si] +[31Si] +[32Si] +[73Ge] +[119Sn] +[120Sn] +[118Sn] +[116Sn] +[117Sn] +[115Sn] +[Fl] +[31P] +[32P] +[33P] +[75As] +[Mc] +[16O] +[17O] +[15O] +[19O] +[32S] +[34S] +[36S] +[35S] +[37S] +[77Se] +[82Se] +[125Te] +[210Po] +[211Po] +[212Po] +[213Po] +[214Po] +[215Po] +[217Po] +[190Po] +[191Po] +[193Po] +[194Po] +[196Po] +[197Po] +[198Po] +[199Po] +[200Po] +[201Po] +[202Po] +[203Po] +[204Po] +[205Po] +[206Po] +[207Po] +[209Po] +[Lv] +[79Br] +[127I] +[129I] +[Ts] +[65Zn] +[66Zn] +[113Cd] +[Cn] +[Hf] +[Rf] +[Nb] +[93Nb] +[Db] +[51Cr] +[95Mo] +[98Mo] +[Sg] +[55Mn] +[Bh] +[57Fe] +[59Co] +[Mt] +[60Ni] +[Rg] +[Sc] +[45Sc] +[Y] +[89Y] +[Lu] +[Lr] +[151Eu] +[139La] +[Nd] +[Pm] +[Sm] +[Tb] +[Dy] +[Ho] +[Er] +[Tm] +[Yb] +[Ac] +[Pa] +[Pu] +[Cm] +[Cf] +[Es] +[Fm] +[No] +[Ti+3] +[Si+2] +[14CH3] +[Nh] +[38S] +[192Po] +[183W] +[La] diff --git a/chebai/preprocessing/bin/smiles_token/tokens.txt b/chebai/preprocessing/bin/smiles_token/tokens.txt index c9b495ab..f4d83ca0 100644 --- a/chebai/preprocessing/bin/smiles_token/tokens.txt +++ b/chebai/preprocessing/bin/smiles_token/tokens.txt @@ -656,3 +656,114 @@ p [Es] [Fm] [No] +[B+3] +[U+3] +[203Hg] +[Pb--] +[Kr] +[220Rn] +[10C] +[11C] +[33S] +[216Po] +[218Po] +[195Po] +[19F] +[Hs] +[Ds] +[Pr] +[Np] +[Am] +[Md] +[Li-] +[Zn-] +[Ni+] +[Pb+4] +[U+4] +[Ga+3] +[U+5] +[Ag-] 
+[C--] +[N@H+] +[AsH3+] +[3H-] +[B-3] +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +96 +97 +98 +99 +0 +[se] +[SH+] +[67Ga+3] +[P+3] +[Se+6] +[6He] +[Mn-] +[Ag] +[197Au] +[Ta-] +[6Li] +[19O] +[194Po] +[Nb] +[45Sc] +[Nd] +[Ti+3] +[14CH3] diff --git a/chebai/preprocessing/collate.py b/chebai/preprocessing/collate.py index 43e045cc..56d6309b 100644 --- a/chebai/preprocessing/collate.py +++ b/chebai/preprocessing/collate.py @@ -30,7 +30,9 @@ def __call__(self, data): *((d["features"], d["labels"], d.get("ident")) for d in data) ) if any(x is not None for x in y): - target_mask_candidates = [[v is not None for v in row] for row in y if row is not None] + target_mask_candidates = [ + [v is not None for v in row] for row in y if row is not None + ] if any(map(any, target_mask_candidates)): loss_kwargs["target_mask"] = torch.tensor(target_mask_candidates) if any(x is None for x in y): diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 04c1c913..abf9a2d5 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -44,9 +44,11 @@ def __init__( self.balance_after_filter = balance_after_filter self.num_workers = num_workers self.chebi_version = chebi_version - assert(type(inner_k_folds) is int) + assert type(inner_k_folds) is int self.inner_k_folds = inner_k_folds - self.use_inner_cross_validation = inner_k_folds > 1 # only use cv if there are at least 2 folds + self.use_inner_cross_validation = ( + inner_k_folds > 1 + ) # only use cv if there are at least 2 folds os.makedirs(self.raw_dir, exist_ok=True) os.makedirs(self.processed_dir, exist_ok=True) @@ -81,8 +83,8 @@ def dataloader(self, kind, **kwargs) -> DataLoader: except NotImplementedError: filename = f"{kind}.pt" dataset = torch.load(os.path.join(self.processed_dir, filename)) - if 'ids' in kwargs: - ids = kwargs.pop('ids') + if "ids" in kwargs: + ids = kwargs.pop("ids") _dataset = [] for i in range(len(dataset)): if i in ids: @@ -131,17 +133,24 @@ def _load_data_from_file(self, path): if d["features"] is not None ] # filter for missing features in resulting data - data = [val for val in data if val['features'] is not None] + data = [val for val in data if val["features"] is not None] return data def train_dataloader(self, *args, **kwargs) -> DataLoader: return self.dataloader( - "train" if not self.use_inner_cross_validation else "train_val", shuffle=True, num_workers=self.num_workers, **kwargs + "train" if not self.use_inner_cross_validation else "train_val", + shuffle=True, + num_workers=self.num_workers, + **kwargs, ) def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - return self.dataloader("validation" if not self.use_inner_cross_validation else "train_val", shuffle=False, **kwargs) + return self.dataloader( + "validation" if not self.use_inner_cross_validation else "train_val", + shuffle=False, + **kwargs, + ) def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: return self.dataloader("test", shuffle=False, **kwargs) @@ -153,7 +162,7 @@ def predict_dataloader( def setup(self, **kwargs): print("Check for processed data in ", self.processed_dir) - print(f'Cross-validation enabled: {self.use_inner_cross_validation}') + print(f"Cross-validation enabled: {self.use_inner_cross_validation}") if any( not 
os.path.isfile(os.path.join(self.processed_dir, f)) for f in self.processed_file_names @@ -161,7 +170,11 @@ def setup(self, **kwargs): self.setup_processed() if self.use_inner_cross_validation: - self.train_val_data = torch.load(os.path.join(self.processed_dir, self.processed_file_names_dict['train_val'])) + self.train_val_data = torch.load( + os.path.join( + self.processed_dir, self.processed_file_names_dict["train_val"] + ) + ) def teardown(self, stage: str) -> None: # cant save hyperparams at setup because logger is not initialised yet diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index ed30bc94..4ed73b6a 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -72,7 +72,7 @@ def download(self): tf.seek(0) with gzip.open(tf, "rb") as f_in: with open( - os.path.join(self.raw_dir, "smiles.txt"), "wb" + os.path.join(self.raw_dir, "smiles.txt"), "wb" ) as f_out: shutil.copyfileobj(f_in, f_out) else: @@ -115,8 +115,8 @@ def processed_file_names(self): def prepare_data(self, *args, **kwargs): print("Check for raw data in", self.raw_dir) if any( - not os.path.isfile(os.path.join(self.raw_dir, f)) - for f in self.raw_file_names + not os.path.isfile(os.path.join(self.raw_dir, f)) + for f in self.raw_file_names ): print("Downloading data. This may take some time...") self.download() diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 4dc4cba1..1fe374c2 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -45,11 +45,15 @@ def name(cls): @property def token_path(self): + """Get token path, create file if it does not exist yet""" if self._token_path is not None: return self._token_path dirname = os.path.dirname(__file__) token_path = os.path.join(dirname, "bin", self.name(), "tokens.txt") os.makedirs(os.path.join(dirname, "bin", self.name()), exist_ok=True) + if not os.path.exists(token_path): + with open(token_path, "x"): + pass return token_path def _read_id(self, raw_data): @@ -110,7 +114,6 @@ def _read_data(self, raw_data): def save_token_cache(self): """write contents of self.cache into tokens.txt""" - dirname = os.path.dirname(__file__) with open(self.token_path, "w") as pk: print(f"saving {len(self.cache)} tokens to {self.token_path}...") print(f"first 10 tokens: {self.cache[:10]}") diff --git a/chebai/result/classification.py b/chebai/result/classification.py index d0487aea..914d1c3e 100644 --- a/chebai/result/classification.py +++ b/chebai/result/classification.py @@ -11,20 +11,33 @@ def visualise_f1(logs_path): - df = pd.read_csv(os.path.join(logs_path, 'metrics.csv')) - df_loss = df.melt(id_vars='epoch', value_vars=['val_ep_macro-f1', 'val_micro-f1', 'train_micro-f1', - 'train_ep_macro-f1']) - lineplt = sns.lineplot(df_loss, x='epoch', y='value', hue='variable') - plt.savefig(os.path.join(logs_path, 'f1_plot.png')) + df = pd.read_csv(os.path.join(logs_path, "metrics.csv")) + df_loss = df.melt( + id_vars="epoch", + value_vars=[ + "val_ep_macro-f1", + "val_micro-f1", + "train_micro-f1", + "train_ep_macro-f1", + ], + ) + lineplt = sns.lineplot(df_loss, x="epoch", y="value", hue="variable") + plt.savefig(os.path.join(logs_path, "f1_plot.png")) plt.show() + # get predictions from model def evaluate_model(logs_base_path, model_filename, data_module): model = electra.Electra.load_from_checkpoint( - os.path.join(logs_base_path, 'best_epoch=85_val_loss=0.0147_val_micro-f1=0.90.ckpt', model_filename)) + os.path.join( + logs_base_path, + 
"best_epoch=85_val_loss=0.0147_val_micro-f1=0.90.ckpt", + model_filename, + ) + ) assert isinstance(model, electra.Electra) collate = data_module.reader.COLLATER() - test_file = 'test.pt' + test_file = "test.pt" data_path = os.path.join(data_module.processed_dir, test_file) data_list = torch.load(data_path) preds_list = [] @@ -32,8 +45,10 @@ def evaluate_model(logs_base_path, model_filename, data_module): for row in tqdm.tqdm(data_list): processable_data = model._process_batch(collate([row]), 0) - model_output = model(processable_data, **processable_data['model_kwargs']) - preds, labels = model._get_prediction_and_labels(processable_data, processable_data["labels"], model_output) + model_output = model(processable_data, **processable_data["model_kwargs"]) + preds, labels = model._get_prediction_and_labels( + processable_data, processable_data["labels"], model_output + ) preds_list.append(preds) labels_list.append(labels) @@ -42,8 +57,12 @@ def evaluate_model(logs_base_path, model_filename, data_module): print(test_preds.shape) print(test_labels.shape) test_loss = ElectraPreLoss() - print(f'Loss on test set: {test_loss(test_preds, test_labels)}') - f1_macro = MultilabelF1Score(test_preds.shape[1], average='macro') - f1_micro = MultilabelF1Score(test_preds.shape[1], average='micro') - print(f'Macro-F1 on test set with {test_preds.shape[1]} classes: {f1_macro(test_preds, test_labels):3f}') - print(f'Micro-F1 on test set with {test_preds.shape[1]} classes: {f1_micro(test_preds, test_labels):3f}') + print(f"Loss on test set: {test_loss(test_preds, test_labels)}") + f1_macro = MultilabelF1Score(test_preds.shape[1], average="macro") + f1_micro = MultilabelF1Score(test_preds.shape[1], average="micro") + print( + f"Macro-F1 on test set with {test_preds.shape[1]} classes: {f1_macro(test_preds, test_labels):3f}" + ) + print( + f"Micro-F1 on test set with {test_preds.shape[1]} classes: {f1_micro(test_preds, test_labels):3f}" + ) diff --git a/chebai/result/pretraining.py b/chebai/result/pretraining.py index e8203cec..aeca2bf2 100644 --- a/chebai/result/pretraining.py +++ b/chebai/result/pretraining.py @@ -10,18 +10,27 @@ def visualise_loss(logs_path): - df = pd.read_csv(os.path.join(logs_path, 'metrics.csv')) - df_loss = df.melt(id_vars='epoch', value_vars=['val_loss_epoch', 'train_loss_epoch']) - lineplt = sns.lineplot(df_loss, x='epoch', y='value', hue='variable') - plt.savefig(os.path.join(logs_path, 'f1_plot.png')) + df = pd.read_csv(os.path.join(logs_path, "metrics.csv")) + df_loss = df.melt( + id_vars="epoch", value_vars=["val_loss_epoch", "train_loss_epoch"] + ) + lineplt = sns.lineplot(df_loss, x="epoch", y="value", hue="variable") + plt.savefig(os.path.join(logs_path, "f1_plot.png")) plt.show() + # get predictions from model def evaluate_model(logs_base_path, model_filename, data_module): - model = electra.ElectraPre.load_from_checkpoint(os.path.join(logs_base_path, 'best_epoch=85_val_loss=0.0147_val_micro-f1=0.90.ckpt', model_filename)) + model = electra.ElectraPre.load_from_checkpoint( + os.path.join( + logs_base_path, + "best_epoch=85_val_loss=0.0147_val_micro-f1=0.90.ckpt", + model_filename, + ) + ) assert isinstance(model, electra.ElectraPre) collate = data_module.reader.COLLATER() - test_file = 'test.pt' + test_file = "test.pt" data_path = os.path.join(data_module.processed_dir, test_file) data_list = torch.load(data_path) preds_list = [] @@ -29,8 +38,10 @@ def evaluate_model(logs_base_path, model_filename, data_module): for row in tqdm.tqdm(data_list): processable_data = 
model._process_batch(collate([row]), 0) - model_output = model(processable_data, **processable_data['model_kwargs']) - preds, labels = model._get_prediction_and_labels(processable_data, processable_data["labels"], model_output) + model_output = model(processable_data, **processable_data["model_kwargs"]) + preds, labels = model._get_prediction_and_labels( + processable_data, processable_data["labels"], model_output + ) preds_list.append(preds) labels_list.append(labels) @@ -39,16 +50,14 @@ def evaluate_model(logs_base_path, model_filename, data_module): print(test_preds.shape) print(test_labels.shape) test_loss = ElectraPreLoss() - print(f'Loss on test set: {test_loss(test_preds, test_labels)}') - #f1_macro = MultilabelF1Score(test_preds.shape[1], average='macro') - #f1_micro = MultilabelF1Score(test_preds.shape[1], average='micro') - #print(f'Macro-F1 on test set with {test_preds.shape[1]} classes: {f1_macro(test_preds, test_labels):3f}') - #print(f'Micro-F1 on test set with {test_preds.shape[1]} classes: {f1_micro(test_preds, test_labels):3f}') + print(f"Loss on test set: {test_loss(test_preds, test_labels)}") + # f1_macro = MultilabelF1Score(test_preds.shape[1], average='macro') + # f1_micro = MultilabelF1Score(test_preds.shape[1], average='micro') + # print(f'Macro-F1 on test set with {test_preds.shape[1]} classes: {f1_macro(test_preds, test_labels):3f}') + # print(f'Micro-F1 on test set with {test_preds.shape[1]} classes: {f1_micro(test_preds, test_labels):3f}') -class PretrainingResultProcessor(ResultProcessor): +class PretrainingResultProcessor(ResultProcessor): @classmethod def _identifier(cls) -> str: - return 'PretrainingResultProcessor' - - + return "PretrainingResultProcessor" diff --git a/chebai/trainer/InnerCVTrainer.py b/chebai/trainer/InnerCVTrainer.py index 0f2ef2e5..2fe2ffb2 100644 --- a/chebai/trainer/InnerCVTrainer.py +++ b/chebai/trainer/InnerCVTrainer.py @@ -26,7 +26,6 @@ class InnerCVTrainer(Trainer): - def __init__(self, *args, **kwargs): self.init_args = args self.init_kwargs = kwargs @@ -44,7 +43,11 @@ def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwar kfold = MultilabelStratifiedKFold(n_splits=n_splits) for fold, (train_ids, val_ids) in enumerate( - kfold.split(datamodule.train_val_data, [data['labels'] for data in datamodule.train_val_data])): + kfold.split( + datamodule.train_val_data, + [data["labels"] for data in datamodule.train_val_data], + ) + ): train_dataloader = datamodule.train_dataloader(ids=train_ids) val_dataloader = datamodule.val_dataloader(ids=val_ids) init_kwargs = self.init_kwargs @@ -52,34 +55,49 @@ def cv_fit(self, datamodule: XYBaseDataModule, n_splits: int = -1, *args, **kwar logger = new_trainer.logger if isinstance(logger, CustomLogger): logger.set_fold(fold) - print(f'Logging this fold at {logger.experiment.dir}') + print(f"Logging this fold at {logger.experiment.dir}") else: - rank_zero_warn(f"Using k-fold cross-validation without an adapted logger class") - new_trainer.fit(train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, *args, **kwargs) - - def predict_from_file(self, model: LightningModule, checkpoint_path: _PATH, input_path: _PATH, - save_to: _PATH='predictions.csv', classes_path: Optional[_PATH] = None): - loaded_model= model.__class__.load_from_checkpoint(checkpoint_path) - with open(input_path, 'r') as input: + rank_zero_warn( + f"Using k-fold cross-validation without an adapted logger class" + ) + new_trainer.fit( + train_dataloaders=train_dataloader, + 
val_dataloaders=val_dataloader, + *args, + **kwargs, + ) + + def predict_from_file( + self, + model: LightningModule, + checkpoint_path: _PATH, + input_path: _PATH, + save_to: _PATH = "predictions.csv", + classes_path: Optional[_PATH] = None, + ): + loaded_model = model.__class__.load_from_checkpoint(checkpoint_path) + with open(input_path, "r") as input: smiles_strings = [inp.strip() for inp in input.readlines()] loaded_model.eval() predictions = self._predict_smiles(loaded_model, smiles_strings) predictions_df = pd.DataFrame(predictions.detach().numpy()) if classes_path is not None: - with open(classes_path, 'r') as f: + with open(classes_path, "r") as f: predictions_df.columns = [cls.strip() for cls in f.readlines()] predictions_df.index = smiles_strings predictions_df.to_csv(save_to) - def _predict_smiles(self, model: LightningModule, smiles: List[str]): reader = ChemDataReader() parsed_smiles = [reader._read_data(s) for s in smiles] x = pad_sequence([torch.tensor(a) for a in parsed_smiles], batch_first=True) - cls_tokens = (torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1) * CLS_TOKEN) + cls_tokens = ( + torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1) + * CLS_TOKEN + ) features = torch.cat((cls_tokens, x), dim=1) - model_output = model({'features': features}) - preds = torch.sigmoid(model_output['logits']) + model_output = model({"features": features}) + preds = torch.sigmoid(model_output["logits"]) print(preds.shape) return preds diff --git a/demo_process_results.ipynb b/demo_process_results.ipynb index 8d1637a6..ee0c1ec9 100644 --- a/demo_process_results.ipynb +++ b/demo_process_results.ipynb @@ -54,41 +54,43 @@ "evalue": "[enforce fail at alloc_cpu.cpp:80] data. DefaultCPUAllocator: not enough memory: you tried to allocate 288800 bytes.", "output_type": "error", "traceback": [ - "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[1;31mRuntimeError\u001B[0m Traceback (most recent call last)", - "Cell \u001B[1;32mIn[2], line 4\u001B[0m\n\u001B[0;32m 2\u001B[0m checkpoint_name \u001B[38;5;241m=\u001B[39m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mbest_epoch=88_val_loss=0.7713_val_micro-f1=0.00.ckpt\u001B[39m\u001B[38;5;124m'\u001B[39m\n\u001B[0;32m 3\u001B[0m eval_pre\u001B[38;5;241m.\u001B[39mvisualise_loss(logs_path)\n\u001B[1;32m----> 4\u001B[0m \u001B[43meval_pre\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mevaluate_model\u001B[49m\u001B[43m(\u001B[49m\u001B[43mlogs_path\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcheckpoint_name\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[1;32m~\\Desktop\\chebai\\chebai\\result\\pretraining.py:36\u001B[0m, in \u001B[0;36mevaluate_model\u001B[1;34m(logs_base_path, model_filename)\u001B[0m\n\u001B[0;32m 34\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m row \u001B[38;5;129;01min\u001B[39;00m tqdm\u001B[38;5;241m.\u001B[39mtqdm(data_list):\n\u001B[0;32m 35\u001B[0m processable_data \u001B[38;5;241m=\u001B[39m model\u001B[38;5;241m.\u001B[39m_process_batch(collate([row]), \u001B[38;5;241m0\u001B[39m)\n\u001B[1;32m---> 36\u001B[0m model_output \u001B[38;5;241m=\u001B[39m \u001B[43mmodel\u001B[49m\u001B[43m(\u001B[49m\u001B[43mprocessable_data\u001B[49m\u001B[43m,\u001B[49m\u001B[43m 
**processable_data['model_kwargs'])\n     37 preds, labels = model._get_prediction_and_labels(processable_data, processable_data[\"labels\"], model_output)\n     38 preds_list.append(preds)\n",
    "File ~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)\n---> 1518 return self._call_impl(*args, **kwargs)\n",
    "File ~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527, in Module._call_impl(self, *args, **kwargs)\n---> 1527 return forward_call(*args, **kwargs)\n",
    "File ~\\Desktop\\chebai\\chebai\\models\\electra.py:91, in ElectraPre.forward(self, data, **kwargs)\n---> 91 disc_out = self.discriminator(features * ~disc_tar_one_hot + replacements[:, None] * disc_tar_one_hot, attention_mask=mask).logits\n",
    "File ~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\models\\electra\\modeling_electra.py:1113, in ElectraForPreTraining.forward(...)\n---> 1113 discriminator_hidden_states = self.electra(input_ids, attention_mask=attention_mask, ...)\n",
    "File ~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\models\\electra\\modeling_electra.py:911, in ElectraModel.forward(...)\n---> 911 hidden_states = self.encoder(hidden_states, attention_mask=extended_attention_mask, head_mask=head_mask, ...)\n",
    "File ~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\models\\electra\\modeling_electra.py:585, in ElectraEncoder.forward(...)\n---> 585 layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\n",
    "File ~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\models\\electra\\modeling_electra.py:474, in ElectraLayer.forward(...)\n---> 474 self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask, output_attentions=output_attentions, past_key_value=self_attn_past_key_value)\n",
    "File ~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527, in Module._call_impl(self, *args, **kwargs)
1524\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[0;32m 1525\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[0;32m 1526\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[1;32m-> 1527\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 1529\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m 1530\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n", - "File \u001B[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\models\\electra\\modeling_electra.py:401\u001B[0m, in \u001B[0;36mElectraAttention.forward\u001B[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001B[0m\n\u001B[0;32m 391\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mforward\u001B[39m(\n\u001B[0;32m 392\u001B[0m \u001B[38;5;28mself\u001B[39m,\n\u001B[0;32m 393\u001B[0m hidden_states: torch\u001B[38;5;241m.\u001B[39mTensor,\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 399\u001B[0m output_attentions: Optional[\u001B[38;5;28mbool\u001B[39m] \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mFalse\u001B[39;00m,\n\u001B[0;32m 400\u001B[0m ) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m Tuple[torch\u001B[38;5;241m.\u001B[39mTensor]:\n\u001B[1;32m--> 401\u001B[0m self_outputs \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mself\u001B[49m\u001B[43m(\u001B[49m\n\u001B[0;32m 402\u001B[0m \u001B[43m \u001B[49m\u001B[43mhidden_states\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 403\u001B[0m \u001B[43m \u001B[49m\u001B[43mattention_mask\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 404\u001B[0m \u001B[43m \u001B[49m\u001B[43mhead_mask\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 405\u001B[0m \u001B[43m \u001B[49m\u001B[43mencoder_hidden_states\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 406\u001B[0m \u001B[43m \u001B[49m\u001B[43mencoder_attention_mask\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 407\u001B[0m \u001B[43m \u001B[49m\u001B[43mpast_key_value\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 408\u001B[0m \u001B[43m \u001B[49m\u001B[43moutput_attentions\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 409\u001B[0m \u001B[43m \u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 410\u001B[0m attention_output \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39moutput(self_outputs[\u001B[38;5;241m0\u001B[39m], hidden_states)\n\u001B[0;32m 411\u001B[0m outputs \u001B[38;5;241m=\u001B[39m (attention_output,) \u001B[38;5;241m+\u001B[39m self_outputs[\u001B[38;5;241m1\u001B[39m:] \u001B[38;5;66;03m# add attentions if 
we output them\u001B[39;00m\n", - "File \u001B[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1516\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[0;32m 1517\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m-> 1518\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_call_impl\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1522\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[0;32m 1523\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[0;32m 1524\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[0;32m 1525\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[0;32m 1526\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[1;32m-> 1527\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mforward_call\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 1529\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m 1530\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n", - "File \u001B[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\models\\electra\\modeling_electra.py:321\u001B[0m, in \u001B[0;36mElectraSelfAttention.forward\u001B[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001B[0m\n\u001B[0;32m 318\u001B[0m relative_position_scores_key \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39meinsum(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mbhrd,lrd->bhlr\u001B[39m\u001B[38;5;124m\"\u001B[39m, key_layer, positional_embedding)\n\u001B[0;32m 319\u001B[0m attention_scores \u001B[38;5;241m=\u001B[39m attention_scores \u001B[38;5;241m+\u001B[39m relative_position_scores_query \u001B[38;5;241m+\u001B[39m 
relative_position_scores_key\n\u001B[1;32m--> 321\u001B[0m attention_scores \u001B[38;5;241m=\u001B[39m \u001B[43mattention_scores\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m/\u001B[39;49m\u001B[43m \u001B[49m\u001B[43mmath\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msqrt\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mattention_head_size\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 322\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m attention_mask \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m 323\u001B[0m \u001B[38;5;66;03m# Apply the attention mask is (precomputed for all layers in ElectraModel forward() function)\u001B[39;00m\n\u001B[0;32m 324\u001B[0m attention_scores \u001B[38;5;241m=\u001B[39m attention_scores \u001B[38;5;241m+\u001B[39m attention_mask\n", - "\u001B[1;31mRuntimeError\u001B[0m: [enforce fail at alloc_cpu.cpp:80] data. DefaultCPUAllocator: not enough memory: you tried to allocate 288800 bytes." + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[2], line 4\u001b[0m\n\u001b[0;32m 2\u001b[0m checkpoint_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbest_epoch=88_val_loss=0.7713_val_micro-f1=0.00.ckpt\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 3\u001b[0m eval_pre\u001b[38;5;241m.\u001b[39mvisualise_loss(logs_path)\n\u001b[1;32m----> 4\u001b[0m \u001b[43meval_pre\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlogs_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcheckpoint_name\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32m~\\Desktop\\chebai\\chebai\\result\\pretraining.py:36\u001b[0m, in \u001b[0;36mevaluate_model\u001b[1;34m(logs_base_path, model_filename)\u001b[0m\n\u001b[0;32m 34\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m tqdm\u001b[38;5;241m.\u001b[39mtqdm(data_list):\n\u001b[0;32m 35\u001b[0m processable_data \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39m_process_batch(collate([row]), \u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m---> 36\u001b[0m model_output \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprocessable_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mprocessable_data\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmodel_kwargs\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 37\u001b[0m preds, labels \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39m_get_prediction_and_labels(processable_data, processable_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m], model_output)\n\u001b[0;32m 38\u001b[0m preds_list\u001b[38;5;241m.\u001b[39mappend(preds)\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, 
\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[1;32m~\\Desktop\\chebai\\chebai\\models\\electra.py:91\u001b[0m, in \u001b[0;36mElectraPre.forward\u001b[1;34m(self, data, **kwargs)\u001b[0m\n\u001b[0;32m 86\u001b[0m random_tokens \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandint(\n\u001b[0;32m 87\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgenerator_config\u001b[38;5;241m.\u001b[39mvocab_size, (batch_size,), device\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice\n\u001b[0;32m 88\u001b[0m )\n\u001b[0;32m 89\u001b[0m replacements \u001b[38;5;241m=\u001b[39m gen_best_guess \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m~\u001b[39mcorrect_mask \u001b[38;5;241m+\u001b[39m random_tokens \u001b[38;5;241m*\u001b[39m correct_mask\n\u001b[1;32m---> 91\u001b[0m disc_out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdiscriminator\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 92\u001b[0m \u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m~\u001b[39;49m\u001b[43mdisc_tar_one_hot\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mreplacements\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mdisc_tar_one_hot\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 93\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 94\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mlogits\n\u001b[0;32m 95\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (raw_gen_out, disc_out), (gen_tar_one_hot, disc_tar_one_hot)\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File 
\u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\models\\electra\\modeling_electra.py:1113\u001b[0m, in \u001b[0;36mElectraForPreTraining.forward\u001b[1;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[0;32m 1078\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 1079\u001b[0m \u001b[38;5;124;03mlabels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\u001b[39;00m\n\u001b[0;32m 1080\u001b[0m \u001b[38;5;124;03m Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see `input_ids` docstring)\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 1109\u001b[0m \u001b[38;5;124;03m[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]\u001b[39;00m\n\u001b[0;32m 1110\u001b[0m \u001b[38;5;124;03m```\"\"\"\u001b[39;00m\n\u001b[0;32m 1111\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m-> 1113\u001b[0m discriminator_hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43melectra\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1114\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1115\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1116\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken_type_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken_type_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1117\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1118\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1119\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1120\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1121\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1122\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 1123\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1124\u001b[0m discriminator_sequence_output \u001b[38;5;241m=\u001b[39m discriminator_hidden_states[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 1126\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdiscriminator_predictions(discriminator_sequence_output)\n", + "File 
\u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\models\\electra\\modeling_electra.py:911\u001b[0m, in \u001b[0;36mElectraModel.forward\u001b[1;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[0;32m 908\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124membeddings_project\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m 909\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membeddings_project(hidden_states)\n\u001b[1;32m--> 911\u001b[0m 
hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 912\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 913\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mextended_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 914\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 915\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 916\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_extended_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 917\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 918\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 919\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 920\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 921\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 922\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 924\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m hidden_states\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m 
(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\models\\electra\\modeling_electra.py:585\u001b[0m, in \u001b[0;36mElectraEncoder.forward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[0;32m 574\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[0;32m 575\u001b[0m layer_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[0;32m 576\u001b[0m hidden_states,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 582\u001b[0m output_attentions,\n\u001b[0;32m 583\u001b[0m )\n\u001b[0;32m 584\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 585\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mlayer_module\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 586\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 587\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 588\u001b[0m \u001b[43m \u001b[49m\u001b[43mlayer_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 589\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 590\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 591\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 592\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 593\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 595\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 596\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\models\\electra\\modeling_electra.py:474\u001b[0m, in \u001b[0;36mElectraLayer.forward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[0;32m 462\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[0;32m 463\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m 464\u001b[0m hidden_states: torch\u001b[38;5;241m.\u001b[39mTensor,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 471\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tuple[torch\u001b[38;5;241m.\u001b[39mTensor]:\n\u001b[0;32m 472\u001b[0m \u001b[38;5;66;03m# decoder uni-directional self-attention cached key/values tuple is at positions 1,2\u001b[39;00m\n\u001b[0;32m 473\u001b[0m self_attn_past_key_value \u001b[38;5;241m=\u001b[39m past_key_value[:\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m past_key_value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m 
\u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m--> 474\u001b[0m self_attention_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattention\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 475\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 476\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 477\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 478\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 479\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mself_attn_past_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 480\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 481\u001b[0m attention_output \u001b[38;5;241m=\u001b[39m self_attention_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 483\u001b[0m \u001b[38;5;66;03m# if decoder, the last output is tuple of self-attn cache\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m 
\u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\models\\electra\\modeling_electra.py:401\u001b[0m, in \u001b[0;36mElectraAttention.forward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[0;32m 391\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[0;32m 392\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m 393\u001b[0m hidden_states: torch\u001b[38;5;241m.\u001b[39mTensor,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 399\u001b[0m output_attentions: Optional[\u001b[38;5;28mbool\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[0;32m 400\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tuple[torch\u001b[38;5;241m.\u001b[39mTensor]:\n\u001b[1;32m--> 401\u001b[0m self_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 402\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 403\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 404\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 405\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 406\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 407\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 408\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 409\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 410\u001b[0m attention_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput(self_outputs[\u001b[38;5;241m0\u001b[39m], hidden_states)\n\u001b[0;32m 411\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (attention_output,) \u001b[38;5;241m+\u001b[39m self_outputs[\u001b[38;5;241m1\u001b[39m:] \u001b[38;5;66;03m# add attentions if we output them\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[1;32m~\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\transformers\\models\\electra\\modeling_electra.py:321\u001b[0m, in \u001b[0;36mElectraSelfAttention.forward\u001b[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[0;32m 318\u001b[0m relative_position_scores_key \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39meinsum(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbhrd,lrd->bhlr\u001b[39m\u001b[38;5;124m\"\u001b[39m, key_layer, positional_embedding)\n\u001b[0;32m 319\u001b[0m attention_scores \u001b[38;5;241m=\u001b[39m attention_scores \u001b[38;5;241m+\u001b[39m relative_position_scores_query \u001b[38;5;241m+\u001b[39m relative_position_scores_key\n\u001b[1;32m--> 321\u001b[0m attention_scores \u001b[38;5;241m=\u001b[39m \u001b[43mattention_scores\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m/\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mmath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msqrt\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattention_head_size\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 322\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attention_mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 323\u001b[0m \u001b[38;5;66;03m# Apply the attention mask is 
(precomputed for all layers in ElectraModel forward() function)\u001b[39;00m\n\u001b[0;32m 324\u001b[0m attention_scores \u001b[38;5;241m=\u001b[39m attention_scores \u001b[38;5;241m+\u001b[39m attention_mask\n", + "\u001b[1;31mRuntimeError\u001b[0m: [enforce fail at alloc_cpu.cpp:80] data. DefaultCPUAllocator: not enough memory: you tried to allocate 288800 bytes." ] } ], "source": [ - "logs_path = os.path.join('logs_server', 'pubchem_pretraining', 'version_6')\n", - "checkpoint_name = 'best_epoch=88_val_loss=0.7713_val_micro-f1=0.00.ckpt'\n", + "logs_path = os.path.join(\"logs_server\", \"pubchem_pretraining\", \"version_6\")\n", + "checkpoint_name = \"best_epoch=88_val_loss=0.7713_val_micro-f1=0.00.ckpt\"\n", "eval_pre.visualise_loss(logs_path)\n", - "eval_pre.evaluate_model(logs_path, checkpoint_name, PubChemDeepSMILES(chebi_version=227))\n", - "#todo: run on server" + "eval_pre.evaluate_model(\n", + " logs_path, checkpoint_name, PubChemDeepSMILES(chebi_version=227)\n", + ")\n", + "# todo: run on server" ], "metadata": { "collapsed": false, @@ -211,11 +213,13 @@ } ], "source": [ - "logs_path = os.path.join('logs_server', 'chebi100_bce_unweighted_deepsmiles', 'version_12')\n", - "#classification.visualise_f1(logs_path)\n", - "df = pd.read_csv(os.path.join(logs_path, 'metrics.csv'))\n", - "df2 = df[~df['val_micro-f1'].isna()]\n", - "df2 = df2[['val_micro-f1', 'epoch']]\n", + "logs_path = os.path.join(\n", + " \"logs_server\", \"chebi100_bce_unweighted_deepsmiles\", \"version_12\"\n", + ")\n", + "# classification.visualise_f1(logs_path)\n", + "df = pd.read_csv(os.path.join(logs_path, \"metrics.csv\"))\n", + "df2 = df[~df[\"val_micro-f1\"].isna()]\n", + "df2 = df2[[\"val_micro-f1\", \"epoch\"]]\n", "print(df2.to_string())" ], "metadata": { @@ -244,13 +248,13 @@ "# check if pretraining datasets overlap\n", "dm = PubChemDeepSMILES()\n", "processed_path = dm.processed_dir\n", - "test_set = torch.load(os.path.join(processed_path, 'test.pt'))\n", - "val_set = torch.load(os.path.join(processed_path, 'validation.pt'))\n", - "train_set = torch.load(os.path.join(processed_path, 'train.pt'))\n", + "test_set = torch.load(os.path.join(processed_path, \"test.pt\"))\n", + "val_set = torch.load(os.path.join(processed_path, \"validation.pt\"))\n", + "train_set = torch.load(os.path.join(processed_path, \"train.pt\"))\n", "print(processed_path)\n", - "test_smiles = [entry['features'] for entry in test_set]\n", - "val_smiles = [entry['features'] for entry in val_set]\n", - "train_smiles = [entry['features'] for entry in train_set]\n", + "test_smiles = [entry[\"features\"] for entry in test_set]\n", + "val_smiles = [entry[\"features\"] for entry in val_set]\n", + "train_smiles = [entry[\"features\"] for entry in train_set]\n", "train_smiles.append(val_smiles[0])\n", "val_smiles_in_test = [smiles for smiles in val_smiles if smiles in test_smiles]\n", "train_smiles_in_val = [smiles for smiles in train_smiles if smiles in val_smiles]\n", @@ -272,16 +276,15 @@ "source": [ "# not used\n", "class CustomResultsProcessor(ResultProcessor):\n", - "\n", " @classmethod\n", " def _identifier(cls) -> str:\n", - " return 'custom_results_processor'\n", + " return \"custom_results_processor\"\n", "\n", " def process_prediction(self, proc_id, features, labels, pred, ident):\n", - " print(f'id: {proc_id}')\n", - " print(f'features: {features}')\n", - " print(f'labels: {labels}')\n", - " print(f'pred: {pred}')" + " print(f\"id: {proc_id}\")\n", + " print(f\"features: {features}\")\n", + " print(f\"labels: {labels}\")\n", + " 
print(f\"pred: {pred}\")" ], "metadata": { "collapsed": false, @@ -295,22 +298,32 @@ "execution_count": 4, "outputs": [], "source": [ - "model_path_v148 = os.path.join('logs', 'chebi100_bce_unweighted', 'version_6', 'checkpoints',\n", - " 'per_epoch=99_val_loss=0.0252_val_micro-f1=0.89.ckpt')\n", - "model_path_v227 = os.path.join('logs', 'chebi100_bce_unweighted', 'version_8', 'checkpoints',\n", - " 'per_epoch=99_val_loss=0.0167_val_micro-f1=0.91.ckpt')\n", - "model_path_v200 = 'electra_c100_bce_unweighted.ckpt'\n", - "model_v148 = Electra.load_from_checkpoint(model_path_v148).to('cpu')\n", - "model_v200 = Electra.load_from_checkpoint(model_path_v200).to('cpu')\n", - "model_v227 = Electra.load_from_checkpoint(model_path_v227).to('cpu')\n", + "model_path_v148 = os.path.join(\n", + " \"logs\",\n", + " \"chebi100_bce_unweighted\",\n", + " \"version_6\",\n", + " \"checkpoints\",\n", + " \"per_epoch=99_val_loss=0.0252_val_micro-f1=0.89.ckpt\",\n", + ")\n", + "model_path_v227 = os.path.join(\n", + " \"logs\",\n", + " \"chebi100_bce_unweighted\",\n", + " \"version_8\",\n", + " \"checkpoints\",\n", + " \"per_epoch=99_val_loss=0.0167_val_micro-f1=0.91.ckpt\",\n", + ")\n", + "model_path_v200 = \"electra_c100_bce_unweighted.ckpt\"\n", + "model_v148 = Electra.load_from_checkpoint(model_path_v148).to(\"cpu\")\n", + "model_v200 = Electra.load_from_checkpoint(model_path_v200).to(\"cpu\")\n", + "model_v227 = Electra.load_from_checkpoint(model_path_v227).to(\"cpu\")\n", "\n", "data_module_v200 = ChEBIOver100()\n", "data_module_v148 = ChEBIOver100(chebi_version_train=148)\n", "data_module_v227 = ChEBIOver100(chebi_version_train=227)\n", - "#dataset = torch.load(data_path)\n", - "#processors = [CustomResultsProcessor()]\n", - "#factory = ResultFactory(model, data_module, processors)\n", - "#factory.execute(data_path)" + "# dataset = torch.load(data_path)\n", + "# processors = [CustomResultsProcessor()]\n", + "# factory = ResultFactory(model, data_module, processors)\n", + "# factory.execute(data_path)" ], "metadata": { "collapsed": false, @@ -324,9 +337,9 @@ "execution_count": 7, "outputs": [], "source": [ - "filename_200 = 'classes.txt'\n", - "filename_148 = f'classes_v148.txt'\n", - "filename_227 = f'classes_v227.txt'\n", + "filename_200 = \"classes.txt\"\n", + "filename_148 = f\"classes_v148.txt\"\n", + "filename_227 = f\"classes_v227.txt\"\n", "with open(os.path.join(data_module_v200.raw_dir, filename_200), \"r\") as file:\n", " v200_classes = file.readlines()\n", "with open(os.path.join(data_module_v148.raw_dir, filename_148), \"r\") as file:\n", @@ -395,8 +408,12 @@ " if v227_class in v200_classes:\n", " common_classes_v227.append(v227_class)\n", "# get filter if a class in v200/v148 is a common class\n", - "common_classes_with_v227_mask_for_v200 = torch.tensor([[c in common_classes_v227 for c in v200_classes]])\n", - "common_classes_with_v200_mask_for_v227 = torch.tensor([[c in common_classes_v227 for c in v227_classes]])" + "common_classes_with_v227_mask_for_v200 = torch.tensor(\n", + " [[c in common_classes_v227 for c in v200_classes]]\n", + ")\n", + "common_classes_with_v200_mask_for_v227 = torch.tensor(\n", + " [[c in common_classes_v227 for c in v227_classes]]\n", + ")" ], "metadata": { "collapsed": false, @@ -418,7 +435,7 @@ } ], "source": [ - "#print(len(common_classes))\n", + "# print(len(common_classes))\n", "print(len(common_classes_v227))" ], "metadata": { @@ -434,22 +451,24 @@ "outputs": [], "source": [ "# (not used)\n", - "#mapping = [-1 if new_class not in orig_classes else 
orig_classes.index(new_class) for new_class in\n", + "# mapping = [-1 if new_class not in orig_classes else orig_classes.index(new_class) for new_class in\n", "# new_classes]\n", - "#input = torch.tensor(np.random.random([1, 854]))\n", + "# input = torch.tensor(np.random.random([1, 854]))\n", "def _apply_mapping(input, index):\n", " orig_ind = mapping[index]\n", " if orig_ind is not None:\n", " return input[0, orig_ind].item()\n", " return None\n", + "\n", + "\n", "# mapping between model outputs / labels for chebi v200 (with 854 classes) and chebi v148 (with 709 classes)\n", "def apply_mapping(mapping: [], input: torch.Tensor):\n", - " input = input.detach().numpy()\n", - " output = np.array(np.zeros((1, len(mapping))))\n", - " for ind, value in enumerate(input[0]):\n", - " if mapping[ind] is not None:\n", - " output[0, mapping[ind]] = value\n", - " return torch.tensor(output)\n" + " input = input.detach().numpy()\n", + " output = np.array(np.zeros((1, len(mapping))))\n", + " for ind, value in enumerate(input[0]):\n", + " if mapping[ind] is not None:\n", + " output[0, mapping[ind]] = value\n", + " return torch.tensor(output)" ], "metadata": { "collapsed": false, @@ -624,21 +643,28 @@ "outputs": [], "source": [ "# get predictions from model\n", - "def evaluate_model(model: ChebaiBaseNet, data_module: XYBaseDataModule, common_classes_mask = None, test_file=None):\n", + "def evaluate_model(\n", + " model: ChebaiBaseNet,\n", + " data_module: XYBaseDataModule,\n", + " common_classes_mask=None,\n", + " test_file=None,\n", + "):\n", " collate = data_module.reader.COLLATER()\n", " if test_file is None:\n", - " test_file = data_module.processed_file_names_dict['test']\n", + " test_file = data_module.processed_file_names_dict[\"test\"]\n", " data_path = os.path.join(data_module.processed_dir, test_file)\n", " data_list = torch.load(data_path)\n", " preds_list = []\n", " labels_list = []\n", - " #if common_classes_mask is not N\n", + " # if common_classes_mask is not N\n", "\n", " for row in tqdm.tqdm(data_list):\n", " processable_data = model._process_batch(collate([row]), 0)\n", " model_output = model(processable_data)\n", " # TODO: collect both masked and unmasked data if possible to avoid running the model twice\n", - " preds, labels = model._get_prediction_and_labels(processable_data, processable_data[\"labels\"], model_output)\n", + " preds, labels = model._get_prediction_and_labels(\n", + " processable_data, processable_data[\"labels\"], model_output\n", + " )\n", " if common_classes_mask is not None:\n", " preds = preds[common_classes_mask]\n", " labels = labels[common_classes_mask]\n", @@ -652,10 +678,14 @@ " test_labels = torch.cat(labels_list)\n", " print(test_preds.shape)\n", " print(test_labels.shape)\n", - " f1_macro = MultilabelF1Score(test_preds.shape[1], average='macro')\n", - " f1_micro = MultilabelF1Score(test_preds.shape[1], average='micro')\n", - " print(f'Macro-F1 on test set with {test_preds.shape[1]} classes: {f1_macro(test_preds, test_labels):3f}')\n", - " print(f'Micro-F1 on test set with {test_preds.shape[1]} classes: {f1_micro(test_preds, test_labels):3f}')" + " f1_macro = MultilabelF1Score(test_preds.shape[1], average=\"macro\")\n", + " f1_micro = MultilabelF1Score(test_preds.shape[1], average=\"micro\")\n", + " print(\n", + " f\"Macro-F1 on test set with {test_preds.shape[1]} classes: {f1_macro(test_preds, test_labels):3f}\"\n", + " )\n", + " print(\n", + " f\"Micro-F1 on test set with {test_preds.shape[1]} classes: {f1_micro(test_preds, test_labels):3f}\"\n", + " )" ], 
"metadata": { "collapsed": false, @@ -687,7 +717,7 @@ } ], "source": [ - "evaluate_model(model_v200, data_module_v200, test_file='test_martin_server.pt')" + "evaluate_model(model_v200, data_module_v200, test_file=\"test_martin_server.pt\")" ], "metadata": { "collapsed": false, @@ -899,10 +929,12 @@ ], "source": [ "# visualize results from csv\n", - "df = pd.read_csv(os.path.join('logs_server', 'pubchem_pretraining', 'version_6', 'metrics.csv'))\n", - "df_loss = df.melt(id_vars='epoch', value_vars=['val_loss_epoch', 'train_loss_epoch'])\n", - "#df_macro = df.melt(id_vars='epoch', value_vars=['train_macro-f1', 'val_macro-f1', ])\n", - "#df_micro = df.melt(id_vars='epoch', value_vars=['train_micro-f1', 'val_micro-f1', ])\n", + "df = pd.read_csv(\n", + " os.path.join(\"logs_server\", \"pubchem_pretraining\", \"version_6\", \"metrics.csv\")\n", + ")\n", + "df_loss = df.melt(id_vars=\"epoch\", value_vars=[\"val_loss_epoch\", \"train_loss_epoch\"])\n", + "# df_macro = df.melt(id_vars='epoch', value_vars=['train_macro-f1', 'val_macro-f1', ])\n", + "# df_micro = df.melt(id_vars='epoch', value_vars=['train_micro-f1', 'val_micro-f1', ])\n", "print(df_loss)" ], "metadata": { @@ -930,7 +962,7 @@ } ], "source": [ - "lineplt = sns.lineplot(df_loss, x='epoch', y='value', hue='variable')\n", + "lineplt = sns.lineplot(df_loss, x=\"epoch\", y=\"value\", hue=\"variable\")\n", "plt.show()" ], "metadata": { @@ -958,7 +990,7 @@ } ], "source": [ - "sns.lineplot(df_micro, x='epoch', y='value', hue='variable')\n", + "sns.lineplot(df_micro, x=\"epoch\", y=\"value\", hue=\"variable\")\n", "plt.show()" ], "metadata": { @@ -983,7 +1015,7 @@ ], "source": [ "# values are not correct\n", - "sns.lineplot(df_macro, x='epoch', y='value', hue='variable')\n", + "sns.lineplot(df_macro, x=\"epoch\", y=\"value\", hue=\"variable\")\n", "plt.show()" ], "metadata": { diff --git a/process_results_old_chebi.ipynb b/process_results_old_chebi.ipynb index e3869723..c8af0860 100644 --- a/process_results_old_chebi.ipynb +++ b/process_results_old_chebi.ipynb @@ -40,7 +40,7 @@ "import torch\n", "import tqdm\n", "\n", - "DEVICE = 'cpu'" + "DEVICE = \"cpu\"" ] }, { @@ -58,8 +58,8 @@ }, "outputs": [], "source": [ - "model_path_v200 = os.path.join('models', 'electra_c100_bce_unweighted.ckpt')\n", - "model_path_v148 = os.path.join('models', 'electra_c100_bce_unweighted_v148.ckpt')\n", + "model_path_v200 = os.path.join(\"models\", \"electra_c100_bce_unweighted.ckpt\")\n", + "model_path_v148 = os.path.join(\"models\", \"electra_c100_bce_unweighted_v148.ckpt\")\n", "\n", "model_v200 = Electra.load_from_checkpoint(model_path_v200).to(DEVICE)\n", "model_v148 = Electra.load_from_checkpoint(model_path_v148).to(DEVICE)\n", @@ -83,8 +83,8 @@ }, "outputs": [], "source": [ - "classes_file_v200 = 'classes.txt'\n", - "classes_file_v148 = f'classes_v148.txt'\n", + "classes_file_v200 = \"classes.txt\"\n", + "classes_file_v148 = f\"classes_v148.txt\"\n", "with open(os.path.join(data_module_v200.raw_dir, classes_file_v200), \"r\") as file:\n", " v200_classes = file.readlines()\n", "with open(os.path.join(data_module_v148.raw_dir, classes_file_v148), \"r\") as file:\n", @@ -141,9 +141,9 @@ } ], "source": [ - "print(f'Number of classes in ChEBI_v148: {len(v148_classes)}')\n", - "print(f'Number of classes in ChEBI_v200: {len(v200_classes)}')\n", - "print(f'Number of classes in both versions: {len(common_classes)}')" + "print(f\"Number of classes in ChEBI_v148: {len(v148_classes)}\")\n", + "print(f\"Number of classes in ChEBI_v200: {len(v200_classes)}\")\n", + 
"print(f\"Number of classes in both versions: {len(common_classes)}\")" ] }, { @@ -157,10 +157,15 @@ }, "outputs": [], "source": [ - "def evaluate_model(model: ChebaiBaseNet, data_module: XYBaseDataModule, common_classes_mask = None, test_file=None):\n", + "def evaluate_model(\n", + " model: ChebaiBaseNet,\n", + " data_module: XYBaseDataModule,\n", + " common_classes_mask=None,\n", + " test_file=None,\n", + "):\n", " collate = data_module.reader.COLLATER()\n", " if test_file is None:\n", - " test_file = data_module.processed_file_names_dict['test']\n", + " test_file = data_module.processed_file_names_dict[\"test\"]\n", " data_path = os.path.join(data_module.processed_dir, test_file)\n", " data_list = torch.load(data_path)\n", " preds_list = []\n", @@ -169,7 +174,9 @@ " for row in tqdm.tqdm(data_list):\n", " processable_data = model._process_batch(collate([row]), 0)\n", " model_output = model(processable_data)\n", - " preds, labels = model._get_prediction_and_labels(processable_data, processable_data[\"labels\"], model_output)\n", + " preds, labels = model._get_prediction_and_labels(\n", + " processable_data, processable_data[\"labels\"], model_output\n", + " )\n", " if common_classes_mask is not None:\n", " preds = preds[common_classes_mask]\n", " labels = labels[common_classes_mask]\n", @@ -183,10 +190,14 @@ " test_labels = torch.cat(labels_list)\n", " print(test_preds.shape)\n", " print(test_labels.shape)\n", - " f1_macro = MultilabelF1Score(test_preds.shape[1], average='macro')\n", - " f1_micro = MultilabelF1Score(test_preds.shape[1], average='micro')\n", - " print(f'Macro-F1 on test set with {test_preds.shape[1]} classes: {f1_macro(test_preds, test_labels):3f}')\n", - " print(f'Micro-F1 on test set with {test_preds.shape[1]} classes: {f1_micro(test_preds, test_labels):3f}')" + " f1_macro = MultilabelF1Score(test_preds.shape[1], average=\"macro\")\n", + " f1_micro = MultilabelF1Score(test_preds.shape[1], average=\"micro\")\n", + " print(\n", + " f\"Macro-F1 on test set with {test_preds.shape[1]} classes: {f1_macro(test_preds, test_labels):3f}\"\n", + " )\n", + " print(\n", + " f\"Micro-F1 on test set with {test_preds.shape[1]} classes: {f1_micro(test_preds, test_labels):3f}\"\n", + " )" ] }, { diff --git a/setup.py b/setup.py index a8582038..44b4f88d 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ "selfies", "lightning", "jsonargparse[signatures]>=4.17.0", - "omegaconf" + "omegaconf", ], extras_require={"dev": ["black", "isort", "pre-commit"]}, )