From 0161e8c57db2417360b753508a9fa1d35eeb7185 Mon Sep 17 00:00:00 2001
From: Adel Memariani
Date: Thu, 30 Nov 2023 11:29:30 +0100
Subject: [PATCH] Add penalties for min and max box sizes, and penalty for
 distances

---
 chebai/models/base.py    |  2 +-
 chebai/models/electra.py | 40 ++++++++++++++++++++++++++++++----------
 2 files changed, 31 insertions(+), 11 deletions(-)

diff --git a/chebai/models/base.py b/chebai/models/base.py
index 2dff9684..d4431b25 100644
--- a/chebai/models/base.py
+++ b/chebai/models/base.py
@@ -82,7 +82,7 @@ def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=Fal
         loss_kwargs = dict()
         if self.pass_loss_kwargs:
             loss_kwargs = loss_kwargs_candidates
-        loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
+        loss = self.criterion(loss_data, loss_labels, self, **loss_kwargs)
         d["loss"] = loss
         self.log(
             f"{prefix}loss",
diff --git a/chebai/models/electra.py b/chebai/models/electra.py
index 48fcf6f3..3f6710c7 100644
--- a/chebai/models/electra.py
+++ b/chebai/models/electra.py
@@ -482,7 +482,9 @@ class ChebiBox(Electra):
 
     def __init__(self, dimensions=2, hidden_size=4000, **kwargs):
         super().__init__(**kwargs)
+
         self.dimensions = dimensions
+
         self.boxes = nn.Parameter(
             3 - torch.rand((self.config.num_labels, self.dimensions, 2)) * 6
         )
@@ -493,6 +495,8 @@ def __init__(self, dimensions=2, hidden_size=4000, **kwargs):
             nn.Linear(hidden_size, self.dimensions)
         )
 
+        self.criterion = BoxLoss()
+
     def forward(self, data, **kwargs):
         self.batch_size = data["features"].shape[0]
         inp = self.electra.embeddings.forward(data["features"])
@@ -515,23 +519,39 @@ def forward(self, data, **kwargs):
         return dict(
             boxes=b,
             embedded_points=points,
-            logits=logits,
+            logits=m,
             attentions=electra.attentions,
             target_mask=data.get("target_mask"),
         )
-
-
 class BoxLoss(pl.LightningModule):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
-    def __call__(self, outputs, targets):
-        d = outputs
-        t = targets
-        theta = 0.4
-        loss = ((torch.sqrt(d) * torch.log(1 + torch.exp(d - theta)) * t) + (torch.log(1 + torch.exp(-d)) * (1 - d) * (1 - t)))
-        scalar_loss = torch.mean(loss)
-        return scalar_loss
+    def __call__(self, outputs, targets, model, **kwargs):
+        boxes = model.boxes
+        corner_1 = boxes[:, :, 0]
+        corner_2 = boxes[:, :, 1]
+        box_sizes_per_dim = torch.abs(corner_1 - corner_2)
+        box_sizes = box_sizes_per_dim.prod(1)
+
+        mask_min_box_size = (box_sizes < 0.2)
+        min_box_size_penalty = torch.sum(box_sizes[mask_min_box_size])
+
+        mask_max_box_size = (box_sizes > 30)
+        max_box_size_penalty = torch.sum(box_sizes[mask_max_box_size])
+
+        theta = 0.004
+        distance_based_penalty = torch.sum(((torch.sqrt(outputs) * (outputs > theta) * targets) + ((outputs <= theta) * (1 - outputs) * (1 - targets))))
+
+        criterion = nn.BCEWithLogitsLoss()
+        bce_loss = criterion(outputs, targets)
+
+        total_loss = bce_loss + min_box_size_penalty
+        #total_loss = bce_loss + min_box_size_penalty + max_box_size_penalty
+        #total_loss = bce_loss + min_box_size_penalty + max_box_size_penalty + distance_based_penalty
+
+
+        return total_loss
 
 def softabs(x, eps=0.01):
     return (x ** 2 + eps) ** 0.5 - eps ** 0.5
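
Reviewer note (not part of the patch): below is a minimal standalone sketch of
the new box-size penalties in BoxLoss.__call__, so the thresholds can be
inspected in isolation. The corner layout (num_labels, dimensions, 2), the
initialisation, and the 0.2 / 30 thresholds are taken from the diff above;
num_labels = 5 is an arbitrary toy value assumed here for illustration.

import torch

num_labels, dimensions = 5, 2  # toy values; ChebiBox reads these from its config
# Same initialisation as ChebiBox.__init__: corners drawn uniformly from (-3, 3].
boxes = torch.nn.Parameter(3 - torch.rand((num_labels, dimensions, 2)) * 6)

corner_1 = boxes[:, :, 0]
corner_2 = boxes[:, :, 1]
box_sizes = torch.abs(corner_1 - corner_2).prod(1)  # per-label box volume

# As in the patch, the volumes of offending boxes are summed directly:
# boxes smaller than 0.2 feed the min-size penalty, larger than 30 the max-size one.
min_box_size_penalty = torch.sum(box_sizes[box_sizes < 0.2])
max_box_size_penalty = torch.sum(box_sizes[box_sizes > 30])

print(box_sizes, min_box_size_penalty.item(), max_box_size_penalty.item())

Note that only min_box_size_penalty is actually added to the BCE loss in this
revision; the max-size and distance-based variants are left commented out in
the diff.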