diff --git a/zshot/linker/linker_regen/linker_regen.py b/zshot/linker/linker_regen/linker_regen.py
index 2538b28..80a729a 100644
--- a/zshot/linker/linker_regen/linker_regen.py
+++ b/zshot/linker/linker_regen/linker_regen.py
@@ -64,7 +64,11 @@ def load_tokenizer(self):
 
     def restrict_decode_vocab(self, _, prefix_beam):
         """ Restrict the possibilities of the Beam search to force the text generation """
-        return self.trie.postfix(prefix_beam.tolist())
+        tokens = self.trie.postfix(prefix_beam.tolist())
+        if not tokens:
+            return [self.tokenizer.eos_token_id]
+
+        return tokens
 
     def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]:
         """
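
Context for the change, as a minimal standalone sketch rather than the zshot implementation: restrict_decode_vocab has the shape of a callback passed to transformers' generate() via prefix_allowed_tokens_fn (batch id, then the current beam's token ids), and returning an empty candidate list for a beam breaks constrained generation, so the fix falls back to the tokenizer's EOS id when the trie has no continuation for the current prefix. The Trie below and the token ids are illustrative, not zshot's Trie.

    # Minimal sketch of the fallback behaviour; Trie, postfix and the ids are illustrative.
    class Trie:
        def __init__(self, sequences):
            self.root = {}
            for seq in sequences:
                node = self.root
                for tok in seq:
                    node = node.setdefault(tok, {})

        def postfix(self, prefix):
            """Return the token ids that may follow `prefix`, or [] if none."""
            node = self.root
            for tok in prefix:
                node = node.get(tok)
                if node is None:
                    return []
            return list(node.keys())


    EOS_TOKEN_ID = 2  # illustrative value, stands in for tokenizer.eos_token_id

    def restrict_decode_vocab(trie, prefix_beam):
        tokens = trie.postfix(prefix_beam)
        if not tokens:
            # Without this fallback the beam would get an empty candidate
            # set and constrained beam search would have nothing to expand.
            return [EOS_TOKEN_ID]
        return tokens


    trie = Trie([[0, 10, 20, 2], [0, 10, 30, 2]])
    print(restrict_decode_vocab(trie, [0, 10]))   # [20, 30]
    print(restrict_decode_vocab(trie, [0, 99]))   # [2]  (EOS fallback)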