From 3727be9cf9428b074200405eb051d31a5a2cc58f Mon Sep 17 00:00:00 2001
From: NISHIMWE Lydia
Date: Tue, 12 Sep 2023 18:00:05 +0200
Subject: [PATCH] implement character cnn embedding

---
 examples/laser/laser_src/character_cnn.py     | 16 ++++++++++-
 examples/laser/laser_src/laser_transformer.py | 27 ++++++++++++++++---
 .../models/transformer/transformer_encoder.py |  4 +--
 3 files changed, 41 insertions(+), 6 deletions(-)

diff --git a/examples/laser/laser_src/character_cnn.py b/examples/laser/laser_src/character_cnn.py
index 0631dde1fc..00542fef1c 100644
--- a/examples/laser/laser_src/character_cnn.py
+++ b/examples/laser/laser_src/character_cnn.py
@@ -136,8 +136,9 @@ def __eq__(self, other) -> bool:
 
 
 class CharacterIndexer:
-    def __init__(self) -> None:
+    def __init__(self, dictionary=None) -> None:
         self._mapper = CharacterMapper()
+        self.dictionary = dictionary
 
     def tokens_to_indices(self, tokens: List[str]) -> List[List[int]]:
         return [self._mapper.convert_word_to_char_ids(token) for token in tokens]
@@ -160,6 +161,18 @@ def as_padded_tensor(self, batch: List[List[str]], as_tensor=True, maxlen=None)
         else:
             return padded_batch
 
+    def word_ids_to_char_ids(self, batch: torch.Tensor, maxlen=None) -> torch.Tensor:
+        batch_of_words = [ self.dictionary.string(indices).split() for indices in batch ]
+        if maxlen is None:
+            maxlen = max(map(len, batch))
+        batch_indices = [self.tokens_to_indices(words) for words in batch_of_words]
+        padded_batch = torch.LongTensor([
+            pad_sequence_to_length(
+                indices, maxlen,
+                default_value=self._default_value_for_padding)
+            for indices in batch_indices
+        ]).to(batch.device)
+        return padded_batch
 
 class Highway(torch.nn.Module):
     """
@@ -244,6 +257,7 @@ def __init__(self,
             }
         }
         self.output_dim = output_dim
+        self.embedding_dim = output_dim
         self.requires_grad = requires_grad
 
         self._init_weights()
diff --git a/examples/laser/laser_src/laser_transformer.py b/examples/laser/laser_src/laser_transformer.py
index 717d13947a..30caef737c 100644
--- a/examples/laser/laser_src/laser_transformer.py
+++ b/examples/laser/laser_src/laser_transformer.py
@@ -25,11 +25,10 @@
     base_architecture,
 )
 from fairseq.modules import LayerNorm, TransformerDecoderLayer
-from .character_cnn import CharacterCNN
+from .character_cnn import CharacterCNN, CharacterIndexer
 
 logger = logging.getLogger(__name__)
 
-
 @register_model("laser_transformer")
 class LaserTransformerModel(FairseqEncoderDecoderModel):
     """Train Transformer for LASER task
@@ -90,7 +89,7 @@ def load_embed_tokens(dictionary, embed_dim):
             return Embedding(num_embeddings, embed_dim, padding_idx)
 
         if args.encoder_character_embeddings:
-            encoder_embed_tokens = CharacterCNN()
+            encoder_embed_tokens = CharacterCNN(args.encoder_embed_dim)
         else:
             encoder_embed_tokens = load_embed_tokens(
                 task.source_dictionary, args.encoder_embed_dim
@@ -153,11 +152,32 @@ def __init__(self, sentemb_criterion, *args, **kwargs):
                 mean=0,
                 std=namespace.encoder_embed_dim**-0.5,
             )
+        # initialize character indexer
+        self.character_embeddings = namespace.encoder_character_embeddings
+        self.indexer = CharacterIndexer(dictionary) if self.character_embeddings else None
 
     def get_targets(self, sample, net_output):
         """Get targets from either the sample or the net's output."""
         return sample["target"]
 
+    def forward_embedding(
+        self, src_tokens, token_embedding: Optional[torch.Tensor] = None
+    ):
+        if self.character_embeddings:
+            character_src_tokens = self.indexer.word_ids_to_char_ids(src_tokens)
+        # embed tokens and positions
+        if token_embedding is None:
+            token_embedding = self.embed_tokens(character_src_tokens) if self.character_embeddings else self.embed_tokens(src_tokens)
+        x = embed = self.embed_scale * token_embedding
+        if self.embed_positions is not None:
+            x = embed + self.embed_positions(src_tokens)
+        if self.layernorm_embedding is not None:
+            x = self.layernorm_embedding(x)
+        x = self.dropout_module(x)
+        if self.quant_noise is not None:
+            x = self.quant_noise(x)
+        return x, embed
+
     def forward(
         self,
         src_tokens,
@@ -167,6 +187,7 @@ def forward(
         target_language_id=-1,
         dataset_name="",
     ):
+
         encoder_out = super().forward(src_tokens, src_lengths)
 
         x = encoder_out["encoder_out"][0]  # T x B x D
diff --git a/fairseq/models/transformer/transformer_encoder.py b/fairseq/models/transformer/transformer_encoder.py
index bbd21a592e..18c992213a 100644
--- a/fairseq/models/transformer/transformer_encoder.py
+++ b/fairseq/models/transformer/transformer_encoder.py
@@ -28,7 +28,6 @@
 from fairseq.modules.checkpoint_activations import checkpoint_wrapper
 from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
 
-
 # rewrite name for backward compatibility in `make_generation_fast_`
 def module_name_fordropout(module_name: str) -> str:
     if module_name == "TransformerEncoderBase":
@@ -131,7 +130,8 @@ def forward_embedding(
             token_embedding = self.embed_tokens(src_tokens)
         x = embed = self.embed_scale * token_embedding
         if self.embed_positions is not None:
-            x = embed + self.embed_positions(src_tokens)
+            pos_embed = self.embed_positions(src_tokens)
+            x = embed + pos_embed
         if self.layernorm_embedding is not None:
             x = self.layernorm_embedding(x)
         x = self.dropout_module(x)
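
For orientation, here is a minimal usage sketch of the indexing path this patch adds (word ids -> words -> character ids). It is not part of the patch: the import path, the toy dictionary, and the 512 embedding dimension are illustrative assumptions about the fork's setup, not something the diff guarantees.

# Usage sketch (not part of the patch). Assumes the fork's laser_src package is importable
# and that CharacterIndexer / CharacterCNN behave as shown in the hunks above.
import torch
from fairseq.data import Dictionary
from examples.laser.laser_src.character_cnn import CharacterCNN, CharacterIndexer

# Toy word-level dictionary and a (batch=1, seq_len=2) tensor of word ids.
dictionary = Dictionary()
word_ids = dictionary.encode_line(
    "hello world", add_if_not_exist=True, append_eos=False
).long()
src_tokens = word_ids.unsqueeze(0)

# New hook: word ids -> words (via dictionary.string) -> padded character-id tensor
# of shape (batch, seq_len, max_characters_per_token).
indexer = CharacterIndexer(dictionary)
char_ids = indexer.word_ids_to_char_ids(src_tokens)
print(char_ids.shape)

# Inside forward_embedding these character ids are fed to the CharacterCNN that
# build_model installs as encoder_embed_tokens, e.g. CharacterCNN(512) for a 512-dim encoder.
embed_tokens = CharacterCNN(512)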