From 9f396026271b81bde13ae0608102e7ea16f596a5 Mon Sep 17 00:00:00 2001
From: MGlauer
Date: Wed, 24 Jan 2024 14:36:59 +0100
Subject: [PATCH] Implement alternate box model

---
 chebai/loss/boxes.py     | 36 ++++++++++++++++++++++++++++++
 chebai/models/electra.py | 48 +++++++++++++++++++++++++++++++++++++++-
 2 files changed, 83 insertions(+), 1 deletion(-)
 create mode 100644 chebai/loss/boxes.py

diff --git a/chebai/loss/boxes.py b/chebai/loss/boxes.py
new file mode 100644
index 00000000..671267ca
--- /dev/null
+++ b/chebai/loss/boxes.py
@@ -0,0 +1,36 @@
+import torch
+
+class BoxLoss(torch.nn.Module):
+    def __init__(
+        self, base_loss: torch.nn.Module = None
+    ):
+        super().__init__()
+        self.base_loss = base_loss
+
+    def forward(self, input, target, **kwargs):
+        b = input["boxes"]
+        points = input["embedded_points"]
+        target = target.float().unsqueeze(-1)
+        l, lind = torch.min(b, dim=-1)
+        r, rind = torch.max(b, dim=-1)
+
+        widths = r - l
+
+        l += 0.1 * widths
+        r -= 0.1 * widths
+        inside = ((l < points) * (points < r)).float()
+        closer_to_l_than_to_r = (torch.abs(l - points) < torch.abs(r - points)).float()
+        fn_per_dim = (1 - inside) * target
+        fp_per_dim = inside * (1 - target)
+        diff = torch.abs(fp_per_dim - fn_per_dim)
+        return self.base_loss(diff * closer_to_l_than_to_r * points, diff * closer_to_l_than_to_r * l) + self.base_loss(
+            diff * (1 - closer_to_l_than_to_r) * points, diff * (1 - closer_to_l_than_to_r) * r)
+
+    def _calculate_implication_loss(self, l, r):
+        capped_difference = torch.relu(l - r)
+        return torch.mean(
+            torch.sum(
+                (torch.softmax(capped_difference, dim=-1) * capped_difference), dim=-1
+            ),
+            dim=0,
+        )
diff --git a/chebai/models/electra.py b/chebai/models/electra.py
index 41e3dc82..767639dd 100644
--- a/chebai/models/electra.py
+++ b/chebai/models/electra.py
@@ -96,6 +96,8 @@ def filter_dict(d, filter_key):
 
 
 class ElectraBasedModel(ChebaiBaseNet):
+    NAME = "ElectraBase"
+
     def _process_batch(self, batch, batch_idx):
         model_kwargs = dict()
         loss_kwargs = batch.additional_fields["loss_kwargs"]
@@ -127,7 +129,11 @@ def __init__(
         # Remove this property in order to prevent it from being stored as a
         # hyper parameter
+        super().__init__(**kwargs)
+
+
+
         if config is None:
             config = dict()
         if not "num_labels" in config and self.out_dim is not None:
@@ -150,6 +156,9 @@ def __init__(
         else:
             self.electra = ElectraModel(config=self.config)
 
+        if self.out_dim is None:
+            self.out_dim = self.electra.config.hidden_size
+
     def _process_for_loss(self, model_output, labels, loss_kwargs):
         kwargs_copy = dict(loss_kwargs)
         mask = kwargs_copy.pop("target_mask", None)
@@ -190,7 +199,6 @@ def forward(self, data, **kwargs):
             target_mask=data.get("target_mask"),
         )
 
-
 class Electra(ElectraBasedModel):
     NAME = "Electra"
@@ -389,6 +397,44 @@ def forward(self, outputs, targets, **kwargs):
 
         return total_loss
 
+class CrispBoxClassifier(ElectraBasedModel):
+    NAME = "CrispBox"
+
+    def __init__(self, box_dimensions=3, **kwargs):
+        super().__init__(**kwargs)
+
+        self.point_embedding = nn.Linear(self.config.hidden_size, box_dimensions)
+
+        self.num_boxes = kwargs["out_dim"]
+        b = torch.randn((self.num_boxes, box_dimensions, 2))
+        self.boxes = nn.Parameter(b, requires_grad=True)
+
+    def forward(self, x, **kwargs):
+        d = super().forward(x, **kwargs)
+        points = self.point_embedding(d["output"]).unsqueeze(1)
+        b = self.boxes.unsqueeze(0)
+        l, lind = torch.min(b, dim=-1)
+        r, rind = torch.max(b, dim=-1)
+        inside = torch.all((l < points) * (points < r), dim=-1).float()
+        return dict(
+            boxes=b,
+            embedded_points=points,
+            output=inside,
+            attentions=d["attentions"],
+            target_mask=d["target_mask"],
+        )
+
+    def _process_for_loss(self, model_output, labels, loss_kwargs):
+        kwargs_copy = dict(loss_kwargs)
+        mask = kwargs_copy.pop("target_mask", None)
+        if mask is not None:
+            d = model_output["output"] * mask - 100 * ~mask
+        else:
+            d = model_output["output"]
+        if labels is not None:
+            labels = labels.float()
+        model_output["output"] = d
+        return model_output, labels, kwargs_copy
 
 class ConeElectra(ChebaiBaseNet):
     NAME = "ConeElectra"
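
For reviewers, a minimal sketch of how the new BoxLoss could be exercised against tensors shaped like the CrispBoxClassifier output above. In every dimension where the crisp membership disagrees with the label, the loss compares the embedded point with the nearest edge of the (10%-shrunken) box. The batch sizes, the random tensors, and the MSELoss base loss below are illustrative assumptions, not part of this patch.

import torch

from chebai.loss.boxes import BoxLoss

# Shapes mirror CrispBoxClassifier.forward (assumed sizes): each box stores two
# corner values per dimension, each sample is embedded as one point.
batch_size, num_boxes, box_dimensions = 4, 10, 3

boxes = torch.randn(1, num_boxes, box_dimensions, 2)   # like self.boxes.unsqueeze(0)
points = torch.randn(batch_size, 1, box_dimensions)    # like embedded_points
model_output = {"boxes": boxes, "embedded_points": points}

# Multi-label 0/1 targets, one column per box/class.
target = torch.randint(0, 2, (batch_size, num_boxes))

# base_loss must be supplied (the default None would fail at call time);
# MSELoss is only an illustrative choice here.
criterion = BoxLoss(base_loss=torch.nn.MSELoss())
loss = criterion(model_output, target)
print(loss.item())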