fix: support hf transformers with cls_token_id and sep_token_id set to None #346

Merged 1 commit on Nov 22, 2024

changelog.md: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 ### Fixed

 - Fix `join_thread` missing attribute in `SimpleQueue` when cleaning a multiprocessing executor
+- Support huggingface transformers that do not set `cls_token_id` and `sep_token_id` (we now also look for these tokens in the `special_tokens_map` and `vocab` mappings)

 ## v0.14.0 (2024-11-14)

edsnlp/pipes/trainable/embeddings/transformer/transformer.py: 14 additions & 8 deletions
@@ -180,6 +180,16 @@ def __init__(
         self.max_tokens_per_device = max_tokens_per_device
         self._mem_per_unit = None
         self.span_getter = span_getter
+        self.cls_token_id = self.tokenizer.cls_token_id
+        self.sep_token_id = self.tokenizer.sep_token_id
+        if self.cls_token_id is None:
+            [self.cls_token_id] = self.tokenizer.convert_tokens_to_ids(
+                [self.tokenizer.special_tokens_map["bos_token"]]
+            )
+        if self.sep_token_id is None:
+            [self.sep_token_id] = self.tokenizer.convert_tokens_to_ids(
+                [self.tokenizer.special_tokens_map["eos_token"]]
+            )

         if new_tokens:
             self.tokenizer.add_tokens(sorted(set(t[1] for t in new_tokens)))
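For context, a minimal sketch of the fallback added above, assuming a tokenizer such as GPT-2's that defines `bos_token`/`eos_token` but leaves `cls_token_id` and `sep_token_id` unset (the model name is only illustrative):

```python
from transformers import AutoTokenizer

# GPT-2's tokenizer has bos/eos tokens but no CLS/SEP tokens,
# so both ids below are None.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(tokenizer.cls_token_id, tokenizer.sep_token_id)  # None None

# Same fallback as in the diff: resolve the ids via `special_tokens_map`.
cls_token_id = tokenizer.cls_token_id
sep_token_id = tokenizer.sep_token_id
if cls_token_id is None:
    [cls_token_id] = tokenizer.convert_tokens_to_ids(
        [tokenizer.special_tokens_map["bos_token"]]
    )
if sep_token_id is None:
    [sep_token_id] = tokenizer.convert_tokens_to_ids(
        [tokenizer.special_tokens_map["eos_token"]]
    )
print(cls_token_id, sep_token_id)  # e.g. 50256 50256 for GPT-2
```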
@@ -364,11 +374,9 @@ def collate(self, batch):
             sample_word_lengths,
             sample_word_tokens,
         ):
-            prompt_input_ids = [self.tokenizer.cls_token_id]
+            prompt_input_ids = [self.cls_token_id]
             if span_prompt_input_ids:
-                prompt_input_ids.extend(
-                    [*span_prompt_input_ids, self.tokenizer.sep_token_id]
-                )
+                prompt_input_ids.extend([*span_prompt_input_ids, self.sep_token_id])
             windows_offsets = list(
                 range(0, max(len(span_text_input_ids) - overlap, 1), stride)
             )
@@ -379,9 +387,7 @@
                     offset : offset + self.window
                 ]
                 window_input_ids = (
-                    prompt_input_ids
-                    + window_text_input_ids
-                    + [self.tokenizer.sep_token_id]
+                    prompt_input_ids + window_text_input_ids + [self.sep_token_id]
                 )
                 left_overlap = overlap // 2 if offset > 0 else 0
                 right_overlap = (
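In other words, each window now reuses the ids resolved in `__init__` instead of the possibly-None tokenizer attributes. A rough sketch of how a window is assembled, with made-up token ids:

```python
# Hypothetical ids, purely to illustrate the assembly above.
cls_token_id, sep_token_id = 0, 2
span_prompt_input_ids = [11, 12]         # optional prompt wordpieces
window_text_input_ids = [101, 102, 103]  # wordpieces of the current window

prompt_input_ids = [cls_token_id]
if span_prompt_input_ids:
    prompt_input_ids.extend([*span_prompt_input_ids, sep_token_id])

window_input_ids = prompt_input_ids + window_text_input_ids + [sep_token_id]
print(window_input_ids)  # [0, 11, 12, 2, 101, 102, 103, 2]
```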
@@ -523,7 +529,7 @@ def forward(self, batch: TransformerBatchInput) -> TransformerBatchOutput:
         # )
         word_embeddings = torch.nn.functional.embedding_bag(
             input=batch["word_indices"],
-            weight=wordpiece_embeddings.view(-1, wordpiece_embeddings.size(2)),
+            weight=wordpiece_embeddings.reshape(-1, wordpiece_embeddings.size(2)),
             offsets=batch["word_offsets"],
         )
         word_embeddings[batch["empty_word_indices"]] = self.empty_word_embedding
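The `view` to `reshape` change is presumably about contiguity: `Tensor.view` requires a memory layout compatible with the requested shape, while `reshape` falls back to a copy when needed. A small sketch of the pooling step (shapes and names are illustrative, not the pipe's actual batch):

```python
import torch

# Illustrative shapes: 2 windows, 3 wordpieces each, hidden size 4.
wordpiece_embeddings = torch.randn(2, 3, 4)

# reshape() also works on non-contiguous tensors (e.g. after slicing or
# transposing), where view() would raise a RuntimeError.
flat = wordpiece_embeddings.reshape(-1, wordpiece_embeddings.size(2))  # (6, 4)

# embedding_bag pools the wordpiece rows of each word (mean pooling by default):
# word 0 = rows 0-1, word 1 = rows 2-3, word 2 = rows 4-5.
word_indices = torch.arange(6)
word_offsets = torch.tensor([0, 2, 4])
word_embeddings = torch.nn.functional.embedding_bag(
    input=word_indices,
    weight=flat,
    offsets=word_offsets,
)
print(word_embeddings.shape)  # torch.Size([3, 4])
```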