From 50676b5202730f20d669fe35f9919d8ac0260466 Mon Sep 17 00:00:00 2001
From: NISHIMWE Lydia
Date: Fri, 24 Nov 2023 17:22:58 +0100
Subject: [PATCH] add mean pooling to laser transformer

---
 examples/laser/laser_src/laser_transformer.py | 29 +++++++++++--------
 1 file changed, 17 insertions(+), 12 deletions(-)

diff --git a/examples/laser/laser_src/laser_transformer.py b/examples/laser/laser_src/laser_transformer.py
index 30caef737c..64b8b83dc2 100644
--- a/examples/laser/laser_src/laser_transformer.py
+++ b/examples/laser/laser_src/laser_transformer.py
@@ -29,6 +29,8 @@
 
 logger = logging.getLogger(__name__)
 
+LASER_EMBED_DIM = 1024
+
 @register_model("laser_transformer")
 class LaserTransformerModel(FairseqEncoderDecoderModel):
     """Train Transformer for LASER task
@@ -66,7 +68,7 @@ def add_args(parser):
         )
         parser.add_argument(
             "--sentemb-criterion",
-            choices=["maxpool", "cls"],
+            choices=["maxpool", "meanpool", "cls"],
             help="How to build sentence embeddings?",
         )
         parser.add_argument(
@@ -126,7 +128,6 @@ def __init__(self, sentemb_criterion, *args, **kwargs):
         tasks = [
             task.split(":")[0] for task in namespace.student_teacher_config.split(",")
         ]
-        laser_embed_dim = 1024
         # if we have a masking task, then add a linear layer projecting from embed_dim to vocab_size
         if "mask" in tasks:
             self.project_vocabulary = nn.Linear(
@@ -143,9 +144,9 @@ def __init__(self, sentemb_criterion, *args, **kwargs):
             self.activation_fn = utils.get_activation_fn("tanh")
             self.layer_norm = LayerNorm(namespace.encoder_embed_dim)
         # if the embed_dim is different, then add a linear layer projecting from embed_dim to the Laser one (1024)
-        elif namespace.encoder_embed_dim != laser_embed_dim:
+        elif namespace.encoder_embed_dim != LASER_EMBED_DIM:
             self.output_projection = nn.Linear(
-                namespace.encoder_embed_dim, laser_embed_dim, bias=False
+                namespace.encoder_embed_dim, LASER_EMBED_DIM, bias=False
             )
             nn.init.normal_(
                 self.output_projection.weight,
@@ -191,6 +192,9 @@ def forward(
         encoder_out = super().forward(src_tokens, src_lengths)
 
         x = encoder_out["encoder_out"][0]  # T x B x D
+
+        if self.output_projection:  # project to LASER's embed dim
+            x = self.output_projection(x)
 
         # project masked tokens only if performing MLM task
         if isinstance(masked_tokens, torch.Tensor):
@@ -204,14 +208,8 @@ def forward(
             return [x]  # MLM criterion takes first element of list as logits
 
         else:  # if not MLM task return sentence embedding
-
-            if self.output_projection:
-                x = self.output_projection(x)
-
-            padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1)
-            if padding_mask.any() and self.sentemb_criterion == "maxpool":
-                x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x)
+            padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1)
 
             if self.sentemb_criterion == "cls":
                 # determine location of 'cls' e.g.
                 # due to left-padding 'cls' may be different each sent
@@ -219,10 +217,17 @@ def forward(
                 sentemb = x[cls_indices, :]
             # Build the sentence embedding by max-pooling over the encoder outputs
             elif self.sentemb_criterion == "maxpool":
+                if padding_mask.any():
+                    x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x)
                 sentemb = x.max(dim=0)[0]
+            # Build the sentence embedding by mean-pooling over the encoder outputs
+            elif self.sentemb_criterion == "meanpool":
+                if padding_mask.any():
+                    x = x.float().masked_fill_(padding_mask, 0.).type_as(x)
+                sentemb = x.sum(dim=0) / (~padding_mask).sum(dim=0)
             else:
                 raise Exception(
-                    "Please provide a sentence embedding option from [cls|maxpool]"
+                    "Please provide a sentence embedding option from [cls|maxpool|meanpool]"
                 )
 
             return {"sentemb": sentemb}  # B x D
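
Note (illustration only, not part of the patch): a minimal standalone sketch of the pooling logic added above, assuming x is the T x B x D encoder output and padding_mask is the T x B x 1 boolean mask built from src_tokens as in the hunks; the shapes and values below are made up for demonstration.

import torch

# Toy shapes: T (time) x B (batch) x D (embed dim); values are arbitrary.
T, B, D = 5, 2, 4
x = torch.randn(T, B, D)
padding_mask = torch.zeros(T, B, 1, dtype=torch.bool)
padding_mask[3:, 0] = True  # sentence 0 has 3 real tokens, the last 2 positions are padding

# meanpool: zero out padded positions, then divide the sum by the true token count
x_mean = x.float().masked_fill(padding_mask, 0.0).type_as(x)
sentemb_mean = x_mean.sum(dim=0) / (~padding_mask).sum(dim=0)  # B x D

# maxpool (for comparison): fill pads with -inf so they can never win the max
x_max = x.float().masked_fill(padding_mask, float("-inf")).type_as(x)
sentemb_max = x_max.max(dim=0)[0]  # B x D

The different fill values are the point of the change: -inf is neutral for max, 0 is neutral for sum, and dividing by (~padding_mask).sum(dim=0) rather than T keeps padded positions from diluting the mean.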