Skip to content

Commit

Permalink
use same slope for all boxes
Browse files Browse the repository at this point in the history
  • Loading branch information
adelmemariani committed Jan 12, 2024
1 parent 31bf586 commit 6ab31eb
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions chebai/models/electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,11 +529,12 @@ def forward(self, data, **kwargs):
center = torch.mean(torch.stack([l, r]), dim=0)
width = 0.6 * (r - l)
slope = torch.sqrt(torch.abs(r - l))
slope = torch.tensor(3)

membership = 1 / (1 + ((torch.abs(p - center) / width) ** (2 * slope)))
m = torch.prod(membership, dim=-1)
m2 = torch.mean(membership, dim=-1)

m3 = torch.min(membership, dim=-1)

product_of_membership = 1.0
for i in range(self.batch_size):
Expand All @@ -547,7 +548,7 @@ def forward(self, data, **kwargs):
return dict(
boxes=b,
embedded_points=points,
logits=logits,
logits=m2,
attentions=electra.attentions,
target_mask=data.get("target_mask"),
)
Expand Down

0 comments on commit 6ab31eb

Please sign in to comment.