Skip to content

Commit

Permalink
add gbmf with adjusted slope
Browse files Browse the repository at this point in the history
  • Loading branch information
adelmemariani committed Jan 22, 2024
1 parent 4cacf8f commit c93c6bd
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
2 changes: 1 addition & 1 deletion chebai/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from lightning.pytorch.core.module import LightningModule
import torch

import pickle
from chebai.preprocessing.structures import XYData

logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
Expand Down
34 changes: 27 additions & 7 deletions chebai/models/electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,16 @@ def gbmf(x, l, r, b=6):
c = l + (r - l) / 2
return 1 / (1 + (torch.abs((x - c) / a) ** (2 * b)))

def gbmf_adjusted_slope(x, l, r):
l = l.to(torch.float32)
r = r.to(torch.float32)
segment = torch.abs(r - l)
is_close = torch.allclose(l, r, rtol=1e-2, atol=1e-3)
a = torch.abs(r - l) - 1e-3 if is_close else 0.7 * torch.abs(r - l)
c = l + (r - l) / 2
b = 2 / (1 + torch.exp(-(segment - 5) / 5))
membership = 1 / (1 + (torch.abs((x - c) / a) ** (2 * b)))
return membership

def normal(sigma, mu, x):
v = (x - mu) / sigma
Expand All @@ -273,17 +283,17 @@ class ChebiBoxWithMemberships(ElectraBasedModel):
NAME = "ChebiBoxWithMemberships"

def __init__(
self, membership_method="normal", dimension_aggregation="lukaziewisz", **kwargs
self, **kwargs
):
super().__init__(**kwargs)

self.membership_method = self.config.membership_method
self.dimension_aggregation = self.config.dimension_aggregation

self.in_dim = self.config.hidden_size
self.hidden_dim = self.config.embeddings_to_points_hidden_size
self.out_dim = self.config.embeddings_dimensions
self.boxes = nn.Parameter(
3 - torch.rand((self.config.num_labels, self.out_dim, 2)) * 6
)
self.membership_method = membership_method
self.dimension_aggregation = dimension_aggregation
self.boxes = nn.Parameter(torch.rand((self.config.num_labels, self.out_dim, 2)))

self.embeddings_to_points = nn.Sequential(
nn.Linear(self.in_dim, self.hidden_dim),
Expand All @@ -297,6 +307,12 @@ def _prod_agg(self, memberships, dim=-1):
def _min_agg(self, memberships, dim=-1):
return torch.min(memberships, dim=dim)[0]

def _mean_agg(self, memberships, dim=-1):
return torch.mean(memberships, dim=dim)

def _sum_agg(self, memberships, dim=-1):
return torch.sum(memberships, dim=dim)

def _soft_lukaziewisz_agg(self, memberships, dim=-1, scale=10):
"""
This is a version of the Łukaziewish-T-norm using a modified softplus instead of max
Expand All @@ -314,7 +330,7 @@ def _soft_lukaziewisz_agg(self, memberships, dim=-1, scale=10):
)

def _forward_gbmf_membership(self, points, left_corners, right_corners, **kwargs):
return gbmf(points, left_corners, right_corners)
return gbmf_adjusted_slope(points, left_corners, right_corners)

def _forward_normal_membership(self, points, left_corners, right_corners, **kwargs):
widths = 0.1 * (right_corners - left_corners)
Expand Down Expand Up @@ -344,6 +360,10 @@ def forward(self, data, **kwargs):
aggregated_memberships = self._soft_lukaziewisz_agg(memberships_per_dim)
elif self.dimension_aggregation == "min":
aggregated_memberships = self._min_agg(memberships_per_dim)
elif self.dimension_aggregation == "mean":
aggregated_memberships = self._mean_agg(memberships_per_dim)
elif self.dimension_aggregation == "sum":
aggregated_memberships = self._sum_agg(memberships_per_dim)
else:
raise Exception("Unknown aggregation function:", self.dimension_aggregation)

Expand Down
2 changes: 1 addition & 1 deletion chebai/preprocessing/collect_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,11 @@ def train(train_loader, validation_loader):
checkpoint_callback = ModelCheckpoint(
dirpath=os.path.join(tb_logger.log_dir, "checkpoints"),
filename="{epoch}-{step}-{val_loss:.7f}",
save_top_k=5,
save_last=True,
verbose=True,
monitor="val_loss",
mode="min",
every_n_epochs=1
)
trainer = pl.Trainer(
logger=tb_logger,
Expand Down

0 comments on commit c93c6bd

Please sign in to comment.