Add penalties for min and max box sizes, and penalty for distances
adelmemariani committed Nov 30, 2023
1 parent 6868249 commit 0161e8c
Showing 2 changed files with 31 additions and 11 deletions.
2 changes: 1 addition & 1 deletion chebai/models/base.py
@@ -82,7 +82,7 @@ def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=Fal
         loss_kwargs = dict()
         if self.pass_loss_kwargs:
             loss_kwargs = loss_kwargs_candidates
-        loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
+        loss = self.criterion(loss_data, loss_labels, self, **loss_kwargs)
         d["loss"] = loss
         self.log(
             f"{prefix}loss",
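The change in base.py threads the model itself into the criterion call as a third positional argument, so a loss can inspect parameters the model owns (here, the box corners). A minimal sketch of the resulting calling convention, assuming only torch; the class and variable names below are illustrative, not from the repository:

import torch
import torch.nn as nn

class ModelAwareLoss:
    # Toy stand-in for BoxLoss: receives the model as a third argument.
    def __call__(self, outputs, targets, model, **kwargs):
        bce = nn.functional.binary_cross_entropy_with_logits(outputs, targets)
        # Regularize a parameter owned by the model, which a plain
        # (outputs, targets) loss could never reach.
        return bce + 1e-3 * model.boxes.pow(2).sum()

class ToyBoxModel(nn.Module):
    def __init__(self, num_labels=3, dimensions=2):
        super().__init__()
        self.boxes = nn.Parameter(torch.rand(num_labels, dimensions, 2))
        self.criterion = ModelAwareLoss()

model = ToyBoxModel()
outputs = torch.randn(4, 3)
targets = torch.randint(0, 2, (4, 3)).float()
# Mirrors the updated call in _execute: self.criterion(loss_data, loss_labels, self, **loss_kwargs)
loss = model.criterion(outputs, targets, model)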
40 changes: 30 additions & 10 deletions chebai/models/electra.py
@@ -482,7 +482,9 @@ class ChebiBox(Electra):

     def __init__(self, dimensions=2, hidden_size=4000, **kwargs):
         super().__init__(**kwargs)
+
         self.dimensions = dimensions
+
         self.boxes = nn.Parameter(
             3 - torch.rand((self.config.num_labels, self.dimensions, 2)) * 6
         )
@@ -493,6 +495,8 @@ def __init__(self, dimensions=2, hidden_size=4000, **kwargs):
             nn.Linear(hidden_size, self.dimensions)
         )

+        self.criterion = BoxLoss()
+
     def forward(self, data, **kwargs):
         self.batch_size = data["features"].shape[0]
         inp = self.electra.embeddings.forward(data["features"])
@@ -515,23 +519,39 @@ def forward(self, data, **kwargs):
         return dict(
             boxes=b,
             embedded_points=points,
-            logits=logits,
+            logits=m,
             attentions=electra.attentions,
             target_mask=data.get("target_mask"),
         )


 class BoxLoss(pl.LightningModule):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-    def __call__(self, outputs, targets):
-        d = outputs
-        t = targets
-        theta = 0.4
-        loss = ((torch.sqrt(d) * torch.log(1 + torch.exp(d - theta)) * t) + (torch.log(1 + torch.exp(-d)) * (1 - d) * (1 - t)))
-        scalar_loss = torch.mean(loss)
-        return scalar_loss
+    def __call__(self, outputs, targets, model, **kwargs):
+        boxes = model.boxes
+        corner_1 = boxes[:, :, 0]
+        corner_2 = boxes[:, :, 1]
+        box_sizes_per_dim = torch.abs(corner_1 - corner_2)
+        box_sizes = box_sizes_per_dim.prod(1)
+
+        mask_min_box_size = (box_sizes < 0.2)
+        min_box_size_penalty = torch.sum(box_sizes[mask_min_box_size])
+
+        mask_max_box_size = (box_sizes > 30)
+        max_box_size_penalty = torch.sum(box_sizes[mask_max_box_size])
+
+        theta = 0.004
+        distance_based_penalty = torch.sum(((torch.sqrt(outputs) * (outputs > theta) * targets) + ((outputs <= theta) * (1 - outputs) * (1 - targets))))
+
+        criterion = nn.BCEWithLogitsLoss()
+        bce_loss = criterion(outputs, targets)
+
+        total_loss = bce_loss + min_box_size_penalty
+        #total_loss = bce_loss + min_box_size_penalty + max_box_size_penalty
+        #total_loss = bce_loss + min_box_size_penalty + max_box_size_penalty + distance_based_penalty
+
+        return total_loss

 def softabs(x, eps=0.01):
     return (x ** 2 + eps) ** 0.5 - eps ** 0.5
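The new box-size penalties can be reproduced in isolation. The sketch below uses randomly initialized corners with the same shape and the same 0.2 / 30 thresholds as the diff; the tensor values themselves are made up:

import torch

torch.manual_seed(0)
# Corners shaped like ChebiBox.boxes: (num_labels, dimensions, 2).
boxes = 3 - torch.rand((5, 2, 2)) * 6

corner_1 = boxes[:, :, 0]
corner_2 = boxes[:, :, 1]
box_sizes = torch.abs(corner_1 - corner_2).prod(1)  # per-box area (2D) or volume

# Sum the sizes of boxes that fall below 0.2 or exceed 30, as BoxLoss does.
min_box_size_penalty = torch.sum(box_sizes[box_sizes < 0.2])
max_box_size_penalty = torch.sum(box_sizes[box_sizes > 30])
print(box_sizes, min_box_size_penalty.item(), max_box_size_penalty.item())

Because each term is gated by a boolean mask, gradients from the penalties flow only through the offending boxes; well-sized boxes are untouched.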
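The distance-based term and the BCE term can likewise be exercised on dummy data. Note that in the commit only bce_loss + min_box_size_penalty is active; the richer combinations remain commented out. A sketch under the assumption that the outputs are small non-negative distance-like scores (shapes and theta = 0.004 from the diff, everything else invented):

import torch
import torch.nn as nn

torch.manual_seed(0)
outputs = torch.rand(4, 5) * 0.01             # stand-in for ChebiBox's scores
targets = torch.randint(0, 2, (4, 5)).float()

theta = 0.004
# Positives are penalized (by the sqrt of the score) when the score exceeds
# theta; negatives are penalized (by 1 - score) when it does not.
distance_based_penalty = torch.sum(
    (torch.sqrt(outputs) * (outputs > theta) * targets)
    + ((outputs <= theta) * (1 - outputs) * (1 - targets))
)

bce_loss = nn.BCEWithLogitsLoss()(outputs, targets)
# One of the commented-out combinations, minus the box-size terms.
total_loss = bce_loss + distance_based_penalty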
