diff --git a/chebai/models/base.py b/chebai/models/base.py index 3a149832..cd6ab347 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -4,7 +4,7 @@ from lightning.pytorch.core.module import LightningModule import torch - +import pickle from chebai.preprocessing.structures import XYData logging.getLogger("pysmiles").setLevel(logging.CRITICAL) diff --git a/chebai/models/electra.py b/chebai/models/electra.py index ed63581e..41e3dc82 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -263,6 +263,16 @@ def gbmf(x, l, r, b=6): c = l + (r - l) / 2 return 1 / (1 + (torch.abs((x - c) / a) ** (2 * b))) +def gbmf_adjusted_slope(x, l, r): + l = l.to(torch.float32) + r = r.to(torch.float32) + segment = torch.abs(r - l) + is_close = torch.allclose(l, r, rtol=1e-2, atol=1e-3) + a = torch.abs(r - l) - 1e-3 if is_close else 0.7 * torch.abs(r - l) + c = l + (r - l) / 2 + b = 2 / (1 + torch.exp(-(segment - 5) / 5)) + membership = 1 / (1 + (torch.abs((x - c) / a) ** (2 * b))) + return membership def normal(sigma, mu, x): v = (x - mu) / sigma @@ -273,17 +283,17 @@ class ChebiBoxWithMemberships(ElectraBasedModel): NAME = "ChebiBoxWithMemberships" def __init__( - self, membership_method="normal", dimension_aggregation="lukaziewisz", **kwargs + self, **kwargs ): super().__init__(**kwargs) + + self.membership_method = self.config.membership_method + self.dimension_aggregation = self.config.dimension_aggregation + self.in_dim = self.config.hidden_size self.hidden_dim = self.config.embeddings_to_points_hidden_size self.out_dim = self.config.embeddings_dimensions - self.boxes = nn.Parameter( - 3 - torch.rand((self.config.num_labels, self.out_dim, 2)) * 6 - ) - self.membership_method = membership_method - self.dimension_aggregation = dimension_aggregation + self.boxes = nn.Parameter(torch.rand((self.config.num_labels, self.out_dim, 2))) self.embeddings_to_points = nn.Sequential( nn.Linear(self.in_dim, self.hidden_dim), @@ -297,6 +307,12 @@ def _prod_agg(self, memberships, dim=-1): def _min_agg(self, memberships, dim=-1): return torch.min(memberships, dim=dim)[0] + def _mean_agg(self, memberships, dim=-1): + return torch.mean(memberships, dim=dim) + + def _sum_agg(self, memberships, dim=-1): + return torch.sum(memberships, dim=dim) + def _soft_lukaziewisz_agg(self, memberships, dim=-1, scale=10): """ This is a version of the Ɓukaziewish-T-norm using a modified softplus instead of max @@ -314,7 +330,7 @@ def _soft_lukaziewisz_agg(self, memberships, dim=-1, scale=10): ) def _forward_gbmf_membership(self, points, left_corners, right_corners, **kwargs): - return gbmf(points, left_corners, right_corners) + return gbmf_adjusted_slope(points, left_corners, right_corners) def _forward_normal_membership(self, points, left_corners, right_corners, **kwargs): widths = 0.1 * (right_corners - left_corners) @@ -344,6 +360,10 @@ def forward(self, data, **kwargs): aggregated_memberships = self._soft_lukaziewisz_agg(memberships_per_dim) elif self.dimension_aggregation == "min": aggregated_memberships = self._min_agg(memberships_per_dim) + elif self.dimension_aggregation == "mean": + aggregated_memberships = self._mean_agg(memberships_per_dim) + elif self.dimension_aggregation == "sum": + aggregated_memberships = self._sum_agg(memberships_per_dim) else: raise Exception("Unknown aggregation function:", self.dimension_aggregation) diff --git a/chebai/preprocessing/collect_all.py b/chebai/preprocessing/collect_all.py index f82ce71c..53b01381 100644 --- a/chebai/preprocessing/collect_all.py +++ b/chebai/preprocessing/collect_all.py @@ -191,11 +191,11 @@ def train(train_loader, validation_loader): checkpoint_callback = ModelCheckpoint( dirpath=os.path.join(tb_logger.log_dir, "checkpoints"), filename="{epoch}-{step}-{val_loss:.7f}", - save_top_k=5, save_last=True, verbose=True, monitor="val_loss", mode="min", + every_n_epochs=1 ) trainer = pl.Trainer( logger=tb_logger,