diff --git a/changelog.md b/changelog.md
index eb813e7c4..3f3df11cc 100644
--- a/changelog.md
+++ b/changelog.md
@@ -5,6 +5,7 @@
 ### Fixed
 
 - Fix `join_thread` missing attribute in `SimpleQueue` when cleaning a multiprocessing executor
+- Support huggingface transformers that do not set `cls_token_id` and `sep_token_id` (we now also look for these tokens in the `special_tokens_map` and `vocab` mappings)
 
 ## v0.14.0 (2024-11-14)
diff --git a/edsnlp/pipes/trainable/embeddings/transformer/transformer.py b/edsnlp/pipes/trainable/embeddings/transformer/transformer.py
index 71889bfdf..71b338702 100644
--- a/edsnlp/pipes/trainable/embeddings/transformer/transformer.py
+++ b/edsnlp/pipes/trainable/embeddings/transformer/transformer.py
@@ -180,6 +180,16 @@ def __init__(
         self.max_tokens_per_device = max_tokens_per_device
         self._mem_per_unit = None
         self.span_getter = span_getter
+        self.cls_token_id = self.tokenizer.cls_token_id
+        self.sep_token_id = self.tokenizer.sep_token_id
+        if self.cls_token_id is None:
+            [self.cls_token_id] = self.tokenizer.convert_tokens_to_ids(
+                [self.tokenizer.special_tokens_map["bos_token"]]
+            )
+        if self.sep_token_id is None:
+            [self.sep_token_id] = self.tokenizer.convert_tokens_to_ids(
+                [self.tokenizer.special_tokens_map["eos_token"]]
+            )
 
         if new_tokens:
             self.tokenizer.add_tokens(sorted(set(t[1] for t in new_tokens)))
@@ -364,11 +374,9 @@ def collate(self, batch):
             sample_word_lengths,
             sample_word_tokens,
         ):
-            prompt_input_ids = [self.tokenizer.cls_token_id]
+            prompt_input_ids = [self.cls_token_id]
             if span_prompt_input_ids:
-                prompt_input_ids.extend(
-                    [*span_prompt_input_ids, self.tokenizer.sep_token_id]
-                )
+                prompt_input_ids.extend([*span_prompt_input_ids, self.sep_token_id])
             windows_offsets = list(
                 range(0, max(len(span_text_input_ids) - overlap, 1), stride)
             )
@@ -379,9 +387,7 @@
                     offset : offset + self.window
                 ]
                 window_input_ids = (
-                    prompt_input_ids
-                    + window_text_input_ids
-                    + [self.tokenizer.sep_token_id]
+                    prompt_input_ids + window_text_input_ids + [self.sep_token_id]
                 )
                 left_overlap = overlap // 2 if offset > 0 else 0
                 right_overlap = (
@@ -523,7 +529,7 @@ def forward(self, batch: TransformerBatchInput) -> TransformerBatchOutput:
         # )
         word_embeddings = torch.nn.functional.embedding_bag(
             input=batch["word_indices"],
-            weight=wordpiece_embeddings.view(-1, wordpiece_embeddings.size(2)),
+            weight=wordpiece_embeddings.reshape(-1, wordpiece_embeddings.size(2)),
             offsets=batch["word_offsets"],
         )
         word_embeddings[batch["empty_word_indices"]] = self.empty_word_embedding
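
Note on the `__init__` hunk (illustration, not part of the patch): the sketch below reproduces the new fallback outside of edsnlp. `gpt2` is an assumed example checkpoint; its tokenizer sets neither `cls_token_id` nor `sep_token_id` but does expose `bos_token`/`eos_token` in `special_tokens_map`, so both fallback branches run. `convert_tokens_to_ids` is the vocab lookup the changelog entry refers to.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # gpt2 defines no CLS/SEP tokens, so both attributes are None here
    cls_token_id = tokenizer.cls_token_id
    sep_token_id = tokenizer.sep_token_id

    # Same fallback as the patch: take the bos/eos entries from
    # special_tokens_map and resolve them through the vocab
    if cls_token_id is None:
        [cls_token_id] = tokenizer.convert_tokens_to_ids(
            [tokenizer.special_tokens_map["bos_token"]]
        )
    if sep_token_id is None:
        [sep_token_id] = tokenizer.convert_tokens_to_ids(
            [tokenizer.special_tokens_map["eos_token"]]
        )

    print(cls_token_id, sep_token_id)  # 50256 50256, i.e. <|endoftext|> twice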
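
Note on the `forward` hunk: `.view` requires the tensor's strides to be compatible with the requested shape and raises a RuntimeError on a non-contiguous tensor, while `.reshape` falls back to a copy when needed; presumably `wordpiece_embeddings` can arrive non-contiguous here. A minimal standalone demonstration:

    import torch

    x = torch.arange(6).reshape(2, 3).t()  # transposing makes x non-contiguous
    try:
        x.view(-1)  # fails: strides are incompatible with a flat view
    except RuntimeError as err:
        print(err)
    print(x.reshape(-1))  # reshape copies when needed, so this succeeds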