diff --git a/chebai/models/electra.py b/chebai/models/electra.py index 3a2d1ec9..5800acc4 100644 --- a/chebai/models/electra.py +++ b/chebai/models/electra.py @@ -529,11 +529,12 @@ def forward(self, data, **kwargs): center = torch.mean(torch.stack([l, r]), dim=0) width = 0.6 * (r - l) slope = torch.sqrt(torch.abs(r - l)) + slope = torch.tensor(3) membership = 1 / (1 + ((torch.abs(p - center) / width) ** (2 * slope))) m = torch.prod(membership, dim=-1) m2 = torch.mean(membership, dim=-1) - + m3 = torch.min(membership, dim=-1) product_of_membership = 1.0 for i in range(self.batch_size): @@ -547,7 +548,7 @@ def forward(self, data, **kwargs): return dict( boxes=b, embedded_points=points, - logits=logits, + logits=m2, attentions=electra.attentions, target_mask=data.get("target_mask"), )