implement character cnn embedding
NISHIMWE Lydia committed Sep 12, 2023
1 parent 361a1ab commit 3727be9
Showing 3 changed files with 41 additions and 6 deletions.
16 changes: 15 additions & 1 deletion examples/laser/laser_src/character_cnn.py
@@ -136,8 +136,9 @@ def __eq__(self, other) -> bool:


class CharacterIndexer:
    def __init__(self) -> None:
    def __init__(self, dictionary=None) -> None:
        self._mapper = CharacterMapper()
        self.dictionary = dictionary

    def tokens_to_indices(self, tokens: List[str]) -> List[List[int]]:
        return [self._mapper.convert_word_to_char_ids(token) for token in tokens]
@@ -160,6 +161,18 @@ def as_padded_tensor(self, batch: List[List[str]], as_tensor=True, maxlen=None)
        else:
            return padded_batch

    def word_ids_to_char_ids(self, batch: torch.Tensor, maxlen=None) -> torch.Tensor:
        batch_of_words = [ self.dictionary.string(indices).split() for indices in batch ]
        if maxlen is None:
            maxlen = max(map(len, batch))
        batch_indices = [self.tokens_to_indices(words) for words in batch_of_words]
        padded_batch = torch.LongTensor([
            pad_sequence_to_length(
                indices, maxlen,
                default_value=self._default_value_for_padding)
            for indices in batch_indices
        ]).to(batch.device)
        return padded_batch

class Highway(torch.nn.Module):
    """
@@ -244,6 +257,7 @@ def __init__(self,
            }
        }
        self.output_dim = output_dim
        self.embedding_dim = output_dim
        self.requires_grad = requires_grad

        self._init_weights()
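A minimal usage sketch (not part of the commit) of the new `word_ids_to_char_ids` helper, assuming a fairseq checkout that contains this commit. The class and method names come from the diff above; the toy dictionary, the sentence, and the import path (it simply mirrors the file location) are assumptions for illustration.

```python
# Sketch only: exercise CharacterIndexer.word_ids_to_char_ids on a toy batch.
import torch
from fairseq.data import Dictionary
from examples.laser.laser_src.character_cnn import CharacterIndexer

# Build a tiny word-level dictionary (contents are illustrative).
dictionary = Dictionary()
for word in ["hello", "world"]:
    dictionary.add_symbol(word)

# Encode one sentence as word IDs and add a batch dimension: (1, T).
src_tokens = dictionary.encode_line(
    "hello world", add_if_not_exist=False
).long().unsqueeze(0)

indexer = CharacterIndexer(dictionary)
char_ids = indexer.word_ids_to_char_ids(src_tokens)
# One row of character IDs per word position: (batch, T, max_word_length).
print(char_ids.shape)
```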
27 changes: 24 additions & 3 deletions examples/laser/laser_src/laser_transformer.py
@@ -25,11 +25,10 @@
    base_architecture,
)
from fairseq.modules import LayerNorm, TransformerDecoderLayer
from .character_cnn import CharacterCNN
from .character_cnn import CharacterCNN, CharacterIndexer

logger = logging.getLogger(__name__)


@register_model("laser_transformer")
class LaserTransformerModel(FairseqEncoderDecoderModel):
"""Train Transformer for LASER task
@@ -90,7 +89,7 @@ def load_embed_tokens(dictionary, embed_dim):
            return Embedding(num_embeddings, embed_dim, padding_idx)

        if args.encoder_character_embeddings:
            encoder_embed_tokens = CharacterCNN()
            encoder_embed_tokens = CharacterCNN(args.encoder_embed_dim)
        else:
            encoder_embed_tokens = load_embed_tokens(
                task.source_dictionary, args.encoder_embed_dim
@@ -153,11 +152,32 @@ def __init__(self, sentemb_criterion, *args, **kwargs):
            mean=0,
            std=namespace.encoder_embed_dim**-0.5,
        )
        # initialize character indexer
        self.character_embeddings = namespace.encoder_character_embeddings
        self.indexer = CharacterIndexer(dictionary) if self.character_embeddings else None

    def get_targets(self, sample, net_output):
        """Get targets from either the sample or the net's output."""
        return sample["target"]

    def forward_embedding(
        self, src_tokens, token_embedding: Optional[torch.Tensor] = None
    ):
        if self.character_embeddings:
            character_src_tokens = self.indexer.word_ids_to_char_ids(src_tokens)
        # embed tokens and positions
        if token_embedding is None:
            token_embedding = self.embed_tokens(character_src_tokens) if self.character_embeddings else self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        if self.quant_noise is not None:
            x = self.quant_noise(x)
        return x, embed

    def forward(
        self,
        src_tokens,
@@ -167,6 +187,7 @@ def forward(
        target_language_id=-1,
        dataset_name="",
    ):

        encoder_out = super().forward(src_tokens, src_lengths)

        x = encoder_out["encoder_out"][0] # T x B x D
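For reference, the character path through the new `forward_embedding` can be read as the standalone sketch below (not from the commit). The argument names stand in for the encoder attributes set up in `build_model` and `__init__`; anything not shown in the diff above is an assumption.

```python
# Sketch only: the character-CNN branch of forward_embedding, written standalone.
from typing import Optional
import torch


def embed_with_character_cnn(
    src_tokens: torch.Tensor,   # (batch, seq_len) word IDs from the source dictionary
    indexer,                    # CharacterIndexer built with that dictionary
    embed_tokens,               # CharacterCNN(args.encoder_embed_dim)
    embed_positions=None,       # optional positional embedding module
    embed_scale: float = 1.0,
):
    # Re-index every word as a fixed-length row of character IDs:
    # (batch, seq_len) -> (batch, seq_len, max_word_length).
    character_src_tokens = indexer.word_ids_to_char_ids(src_tokens)
    # The CharacterCNN collapses each character row into one word vector:
    # (batch, seq_len, max_word_length) -> (batch, seq_len, embed_dim).
    token_embedding = embed_tokens(character_src_tokens)
    x = embed = embed_scale * token_embedding
    if embed_positions is not None:
        # Positions are still indexed by the original word IDs, not the character IDs.
        x = embed + embed_positions(src_tokens)
    return x, embed
```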
4 changes: 2 additions & 2 deletions fairseq/models/transformer/transformer_encoder.py
@@ -28,7 +28,6 @@
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_


# rewrite name for backward compatibility in `make_generation_fast_`
def module_name_fordropout(module_name: str) -> str:
    if module_name == "TransformerEncoderBase":
@@ -131,7 +130,8 @@ def forward_embedding(
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
            pos_embed = self.embed_positions(src_tokens)
            x = embed + pos_embed
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
