add mean pooling to laser transformer
NISHIMWE Lydia committed Nov 24, 2023
1 parent 6db9f27 commit 50676b5
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions examples/laser/laser_src/laser_transformer.py
@@ -29,6 +29,8 @@
 
 logger = logging.getLogger(__name__)
 
+LASER_EMBED_DIM = 1024
+
 @register_model("laser_transformer")
 class LaserTransformerModel(FairseqEncoderDecoderModel):
     """Train Transformer for LASER task
@@ -66,7 +68,7 @@ def add_args(parser):
         )
         parser.add_argument(
             "--sentemb-criterion",
-            choices=["maxpool", "cls"],
+            choices=["maxpool", "meanpool", "cls"],
             help="How to build sentence embeddings?",
         )
         parser.add_argument(
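
As a quick illustration of the option added above, here is a minimal argparse stand-in; this is not the full fairseq add_args wiring, just the relevant flag in isolation:

import argparse

# Minimal stand-in for the --sentemb-criterion option (illustrative only).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--sentemb-criterion",
    choices=["maxpool", "meanpool", "cls"],
    help="How to build sentence embeddings?",
)
args = parser.parse_args(["--sentemb-criterion", "meanpool"])
print(args.sentemb_criterion)  # meanpool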
@@ -126,7 +128,6 @@ def __init__(self, sentemb_criterion, *args, **kwargs):
         tasks = [
             task.split(":")[0] for task in namespace.student_teacher_config.split(",")
         ]
-        laser_embed_dim = 1024
         # if we have a masking task, then add a linear layer projecting from embed_dim to vocab_size
         if "mask" in tasks:
             self.project_vocabulary = nn.Linear(
@@ -143,9 +144,9 @@ def __init__(self, sentemb_criterion, *args, **kwargs):
             self.activation_fn = utils.get_activation_fn("tanh")
             self.layer_norm = LayerNorm(namespace.encoder_embed_dim)
         # if the embed_dim is different, then add a linear layer projecting from embed_dim to the Laser one (1024)
-        elif namespace.encoder_embed_dim != laser_embed_dim:
+        elif namespace.encoder_embed_dim != LASER_EMBED_DIM:
             self.output_projection = nn.Linear(
-                namespace.encoder_embed_dim, laser_embed_dim, bias=False
+                namespace.encoder_embed_dim, LASER_EMBED_DIM, bias=False
             )
             nn.init.normal_(
                 self.output_projection.weight,
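
Roughly what this branch sets up, as a standalone sketch; the encoder dimension and init std below are assumed values, since the hunk truncates before the init arguments:

import torch
import torch.nn as nn

LASER_EMBED_DIM = 1024
encoder_embed_dim = 512  # assumed student dimension, smaller than LASER's

# Bias-free projection from the encoder's embed dim into LASER's 1024-dim space.
output_projection = nn.Linear(encoder_embed_dim, LASER_EMBED_DIM, bias=False)
nn.init.normal_(output_projection.weight, mean=0, std=encoder_embed_dim ** -0.5)  # assumed scale

x = torch.randn(7, 2, encoder_embed_dim)   # T x B x D encoder states
print(output_projection(x).shape)          # torch.Size([7, 2, 1024])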
@@ -191,6 +192,9 @@ def forward(
         encoder_out = super().forward(src_tokens, src_lengths)
 
         x = encoder_out["encoder_out"][0]  # T x B x D
 
+        if self.output_projection:  # project to LASER's embed dim
+            x = self.output_projection(x)
+
         # project masked tokens only if performing MLM task
         if isinstance(masked_tokens, torch.Tensor):
@@ -204,25 +208,26 @@ def forward(
             return [x]  # MLM criterion takes first element of list as logits
         else:
             # if not MLM task return sentence embedding
 
-            if self.output_projection:
-                x = self.output_projection(x)
-
-            padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1)
-
-            if padding_mask.any() and self.sentemb_criterion == "maxpool":
-                x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x)
+            padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1)
 
             if self.sentemb_criterion == "cls":
                 # determine location of 'cls' e.g. due to left-padding 'cls' may be different each sent
                 cls_indices = src_tokens.eq(self.dictionary.bos()).t()  # T x B
                 sentemb = x[cls_indices, :]
             # Build the sentence embedding by max-pooling over the encoder outputs
             elif self.sentemb_criterion == "maxpool":
+                if padding_mask.any():
+                    x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x)
                 sentemb = x.max(dim=0)[0]
+            # Build the sentence embedding by mean-pooling over the encoder outputs
+            elif self.sentemb_criterion == "meanpool":
+                if padding_mask.any():
+                    x = x.float().masked_fill_(padding_mask, 0.).type_as(x)
+                sentemb = x.sum(dim=0) / (~padding_mask).sum(dim=0)
             else:
                 raise Exception(
-                    "Please provide a sentence embedding option from [cls|maxpool]"
+                    "Please provide a sentence embedding option from [cls|maxpool|meanpool]"
                 )
 
             return {"sentemb": sentemb}  # B x D
