Skip to content

Commit

Permalink
added a call method to separately handle the tokenization before enco…
Browse files Browse the repository at this point in the history
…ding
  • Loading branch information
CaptainVee committed Sep 18, 2023
1 parent af224c6 commit 2ac3362
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
15 changes: 8 additions & 7 deletions laser_encoders/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def initialize_encoder(
model_dir: str = None,
spm: bool = True,
laser: str = None,
tokenize: bool = None,
tokenize: bool = False,
):
downloader = LaserModelDownloader(model_dir)
if laser is not None:
Expand Down Expand Up @@ -147,17 +147,18 @@ def initialize_encoder(

model_dir = downloader.model_dir
model_path = os.path.join(model_dir, f"{file_path}.pt")
spm_path = os.path.join(model_dir, f"{file_path}.cvocab")
spm_vocab = os.path.join(model_dir, f"{file_path}.cvocab")
spm_model = None
if not os.path.exists(spm_vocab):
# if there is no cvocab for the laser3 lang use laser2 cvocab
spm_vocab = os.path.join(model_dir, "laser2.cvocab")
if tokenize:
spm_model = os.path.join(model_dir, f"{file_path}.spm")
if not os.path.exists(spm_model):
spm_model = os.path.join(model_dir, "laser2.spm")

if not os.path.exists(spm_path):
# if there is no cvocab for the laser3 lang use laser2 cvocab
spm_path = os.path.join(model_dir, "laser2.cvocab")
spm_model = os.path.join(model_dir, "laser2.spm")
return SentenceEncoder(
model_path=model_path, spm_vocab=spm_path, spm_model=spm_model
model_path=model_path, spm_vocab=spm_vocab, spm_model=spm_model
)


Expand Down
12 changes: 8 additions & 4 deletions laser_encoders/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def __init__(
if verbose:
logger.info(f"loading encoder: {model_path}")
self.spm_model = spm_model
if self.spm_model:
self.tokenizer = LaserTokenizer(spm_model=Path(self.spm_model))

self.use_cuda = torch.cuda.is_available() and not cpu
self.max_sentences = max_sentences
self.max_tokens = max_tokens
Expand Down Expand Up @@ -88,6 +91,11 @@ def __init__(
self.encoder.eval()
self.sort_kind = sort_kind

def __call__(self, sentences):
    """Encode *sentences*, tokenizing first when an SPM model is configured.

    When ``self.spm_model`` is set, the input is passed through
    ``self.tokenizer`` before encoding; otherwise the sentences are
    assumed to be pre-tokenized and forwarded as-is.
    NOTE(review): assumes ``self.tokenizer`` accepts the same sentence
    container that ``encode_sentences`` does — confirm against
    LaserTokenizer's interface.
    """
    prepared = self.tokenizer(sentences) if self.spm_model else sentences
    return self.encode_sentences(prepared)

def _process_batch(self, batch):
tokens = batch.tokens
lengths = batch.lengths
Expand Down Expand Up @@ -153,10 +161,6 @@ def batch(tokens, lengths, indices):
yield batch(batch_tokens, batch_lengths, batch_indices)

def encode_sentences(self, sentences):
if self.spm_model:
tokenizer = LaserTokenizer(spm_model=Path(self.spm_model))
sentences = tokenizer(sentences)

indices = []
results = []
for batch, batch_indices in self._make_batches(sentences):
Expand Down

0 comments on commit 2ac3362

Please sign in to comment.