Implement alternate box model
MGlauer committed Jan 24, 2024
1 parent c93c6bd commit 9f39602
Showing 2 changed files with 83 additions and 1 deletion.
36 changes: 36 additions & 0 deletions chebai/loss/boxes.py
@@ -0,0 +1,36 @@
import torch


class BoxLoss(torch.nn.Module):
    """Box-membership loss: for every wrongly classified dimension, the point
    embedding is compared against the nearest edge of the class box."""

    def __init__(self, base_loss: torch.nn.Module = None):
        super().__init__()
        # Pointwise base loss (e.g. MSE) between points and box edges; must be
        # provided before the first forward call.
        self.base_loss = base_loss

    def forward(self, input, target, **kwargs):
        b = input["boxes"]                      # e.g. (1, num_boxes, dims, 2)
        points = input["embedded_points"]       # e.g. (batch, 1, dims)
        target = target.float().unsqueeze(-1)   # (batch, num_boxes, 1)
        # Lower and upper box edges per dimension.
        l, lind = torch.min(b, dim=-1)
        r, rind = torch.max(b, dim=-1)

        widths = r - l

        # Shrink each box by 10% on both sides before testing membership.
        l = l + 0.1 * widths
        r = r - 0.1 * widths
        inside = ((l < points) * (points < r)).float()
        closer_to_l_than_to_r = (
            torch.abs(l - points) < torch.abs(r - points)
        ).float()
        # Per-dimension false negatives (outside, but labelled positive) and
        # false positives (inside, but labelled negative).
        fn_per_dim = (1 - inside) * target
        fp_per_dim = inside * (1 - target)
        diff = torch.abs(fp_per_dim - fn_per_dim)
        # Compare each wrongly classified point against its nearest box edge.
        return self.base_loss(
            diff * closer_to_l_than_to_r * points,
            diff * closer_to_l_than_to_r * l,
        ) + self.base_loss(
            diff * (1 - closer_to_l_than_to_r) * points,
            diff * (1 - closer_to_l_than_to_r) * r,
        )

    def _calculate_implication_loss(self, l, r):
        # Softmax-weighted penalty on dimensions where the lower edge l
        # exceeds the upper edge r (not called from forward in this commit).
        capped_difference = torch.relu(l - r)
        return torch.mean(
            torch.sum(
                torch.softmax(capped_difference, dim=-1) * capped_difference,
                dim=-1,
            ),
            dim=0,
        )
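A minimal usage sketch for the new loss (not part of this commit): the MSE base loss and the tensor shapes are assumptions, chosen to match the output dictionary of the CrispBoxClassifier added below.

import torch

from chebai.loss.boxes import BoxLoss

# Illustrative shapes only; during training these tensors come from the model.
batch_size, num_boxes, dims = 4, 10, 3
model_output = {
    "boxes": torch.randn(1, num_boxes, dims, 2),           # two corners per dimension
    "embedded_points": torch.randn(batch_size, 1, dims),   # one point per sample
}
labels = torch.randint(0, 2, (batch_size, num_boxes))       # multi-label targets

criterion = BoxLoss(base_loss=torch.nn.MSELoss())           # base loss choice is an assumption
loss = criterion(model_output, labels)
print(loss.item())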
48 changes: 47 additions & 1 deletion chebai/models/electra.py
@@ -96,6 +96,8 @@ def filter_dict(d, filter_key):


class ElectraBasedModel(ChebaiBaseNet):
    NAME = "ElectraBase"

    def _process_batch(self, batch, batch_idx):
        model_kwargs = dict()
        loss_kwargs = batch.additional_fields["loss_kwargs"]
@@ -127,7 +129,11 @@ def __init__(
        # Remove this property in order to prevent it from being stored as a
        # hyperparameter

        super().__init__(**kwargs)

        if config is None:
            config = dict()
        if "num_labels" not in config and self.out_dim is not None:
@@ -150,6 +156,9 @@ def __init__(
        else:
            self.electra = ElectraModel(config=self.config)

        if self.out_dim is None:
            self.out_dim = self.electra.config.hidden_size

    def _process_for_loss(self, model_output, labels, loss_kwargs):
        kwargs_copy = dict(loss_kwargs)
        mask = kwargs_copy.pop("target_mask", None)
@@ -190,7 +199,6 @@ def forward(self, data, **kwargs):
            target_mask=data.get("target_mask"),
        )


class Electra(ElectraBasedModel):
    NAME = "Electra"
@@ -389,6 +397,44 @@ def forward(self, outputs, targets, **kwargs):

        return total_loss

class CrispBoxClassifier(ElectraBasedModel):
    NAME = "CrispBox"

    def __init__(self, box_dimensions=3, **kwargs):
        super().__init__(**kwargs)

        # Projects the Electra output into the low-dimensional box space.
        self.point_embedding = nn.Linear(self.config.hidden_size, box_dimensions)

        # One box per class, stored as two corner values per dimension.
        self.num_boxes = kwargs["out_dim"]
        b = torch.randn((self.num_boxes, box_dimensions, 2))
        self.boxes = nn.Parameter(b, requires_grad=True)

    def forward(self, x, **kwargs):
        d = super().forward(x, **kwargs)
        # (batch, 1, box_dimensions): one embedded point per sample.
        points = self.point_embedding(d["output"]).unsqueeze(1)
        # (1, num_boxes, box_dimensions, 2): broadcast the boxes over the batch.
        b = self.boxes.unsqueeze(0)
        l, lind = torch.min(b, dim=-1)
        r, rind = torch.max(b, dim=-1)
        # A sample belongs to a class iff its point lies strictly inside the
        # corresponding box along every dimension.
        inside = torch.all((l < points) * (points < r), dim=-1).float()
        return dict(
            boxes=b,
            embedded_points=points,
            output=inside,
            attentions=d["attentions"],
            target_mask=d["target_mask"],
        )

    def _process_for_loss(self, model_output, labels, loss_kwargs):
        kwargs_copy = dict(loss_kwargs)
        mask = kwargs_copy.pop("target_mask", None)
        if mask is not None:
            # Push masked-out classes to a large negative value so they do not
            # contribute positives to the loss.
            d = model_output["output"] * mask - 100 * ~mask
        else:
            d = model_output["output"]
        if labels is not None:
            labels = labels.float()
        model_output["output"] = d
        return model_output, labels, kwargs_copy

class ConeElectra(ChebaiBaseNet):
    NAME = "ConeElectra"
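For reference, a standalone sketch of the crisp membership test performed in CrispBoxClassifier.forward (not part of this commit; the numbers are made up, and in the model the boxes are learned parameters and the points come from the Electra encoder).

import torch

# Two boxes in 3 dimensions, each stored as two corner values per dimension.
boxes = torch.tensor([
    [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],   # unit cube
    [[2.0, 3.0], [2.0, 3.0], [2.0, 3.0]],   # cube covering [2, 3]^3
])
points = torch.tensor([[0.5, 0.5, 0.5]]).unsqueeze(1)  # (batch=1, 1, dims)

b = boxes.unsqueeze(0)        # (1, num_boxes, dims, 2)
l, _ = torch.min(b, dim=-1)   # lower edge per box and dimension
r, _ = torch.max(b, dim=-1)   # upper edge per box and dimension
inside = torch.all((l < points) * (points < r), dim=-1).float()
print(inside)                 # tensor([[1., 0.]]) -> member of box 0 only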
