diff --git a/src/kwja/modules/components/logits_processor.py b/src/kwja/modules/components/logits_processor.py
index a6aa54fb..5abb9024 100644
--- a/src/kwja/modules/components/logits_processor.py
+++ b/src/kwja/modules/components/logits_processor.py
@@ -82,11 +82,11 @@ def __init__(
         self.is_finished: List[bool] = [False] * len(batch_surfs)
 
-    def __call__(self, batch_prev_input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
+    def __call__(self, batch_prev_input_ids: torch.Tensor, batch_logits: torch.Tensor) -> torch.Tensor:
         # False means -inf
-        mask = self.get_mask(batch_prev_input_ids, logits)
-        masked_logits = mask_logits(logits, mask, mask_value=-float("inf"))
-        return masked_logits
+        mask = self.get_mask(batch_prev_input_ids, batch_logits)
+        batch_masked_logits = mask_logits(batch_logits, mask, mask_value=-float("inf"))
+        return batch_masked_logits
 
     def get_mask(
         self,
@@ -99,32 +99,22 @@ def get_mask(
             mask[:, self.vocab[SURF_TOKEN]] = True
             return mask
 
-        _, vocab_size = batch_logits.size()
-
         batch_masks = []
-        for i, prev_input_ids in enumerate(batch_prev_input_ids.tolist()):
+        for i, (prev_input_ids, logits) in enumerate(zip(batch_prev_input_ids.tolist(), batch_logits)):
             batch_idx = i // self.num_beams
             if self.is_finished[batch_idx]:
-                batch_masks.append(torch.ones(vocab_size, device=batch_logits.device, dtype=torch.bool))
+                batch_masks.append(torch.zeros_like(logits, dtype=torch.bool))
                 continue
             target_property: TargetProperty = self._get_target_property(prev_input_ids)
             if target_property.surf is True:
-                surf_mask = torch.zeros(vocab_size, device=batch_logits.device, dtype=torch.bool)
-                self._set_surf_mask(surf_mask, prev_input_ids, batch_idx)
-                batch_masks.append(surf_mask)
+                batch_masks.append(self._get_surf_mask(prev_input_ids, logits, batch_idx))
             elif target_property.reading is True:
-                reading_mask = torch.zeros(vocab_size, device=batch_logits.device, dtype=torch.bool)
-                self._set_reading_mask(reading_mask, prev_input_ids)
-                batch_masks.append(reading_mask)
+                batch_masks.append(self._get_reading_mask(prev_input_ids, logits))
             elif target_property.lemma is True:
-                lemma_mask = torch.ones(vocab_size, device=batch_logits.device, dtype=torch.bool)
-                self._set_lemma_mask(lemma_mask, prev_input_ids)
-                batch_masks.append(lemma_mask)
+                batch_masks.append(self._get_lemma_mask(prev_input_ids, logits))
             elif target_property.canon is True:
-                canon_mask = torch.ones(vocab_size, device=batch_logits.device, dtype=torch.bool)
-                self._set_canon_mask(canon_mask, prev_input_ids, batch_idx)
-                batch_masks.append(canon_mask)
+                batch_masks.append(self._get_canon_mask(prev_input_ids, logits, batch_idx))
 
         return torch.stack(batch_masks)
 
@@ -145,19 +135,24 @@ def _get_target_property(self, prev_input_ids: List[int]) -> TargetProperty:
                 break
         return target_property
 
-    def _set_surf_mask(self, mask: torch.Tensor, prev_input_ids: List[int], batch_idx: int) -> None:
+    def _get_surf_mask(self, prev_input_ids: List[int], logits: torch.Tensor, batch_idx: int) -> torch.Tensor:
+        surf_mask = torch.zeros_like(logits, dtype=torch.bool)
         if ungenerated_surf := self._get_ungenerated_surf(prev_input_ids, self.batch_surfs[batch_idx]):
-            mask[self._get_permitted_token_ids(ungenerated_surf)] = True
+            surf_mask[self._get_permitted_token_ids(ungenerated_surf)] = True
         else:
-            mask[self.vocab[READING_TOKEN]] = True
+            surf_mask[self.vocab[READING_TOKEN]] = True
+        return surf_mask
 
-    def _set_reading_mask(self, mask: torch.Tensor, prev_input_ids: List[int]) -> None:
+    def _get_reading_mask(self, prev_input_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
+        reading_mask = torch.zeros_like(logits, dtype=torch.bool)
         if prev_input_ids[-1] == self.vocab[READING_TOKEN]:
-            mask[self.reading_candidate_token_ids] = True
+            reading_mask[self.reading_candidate_token_ids] = True
         else:
-            mask[self.reading_candidate_token_ids + [self.vocab[LEMMA_TOKEN]]] = True
+            reading_mask[self.reading_candidate_token_ids + [self.vocab[LEMMA_TOKEN]]] = True
+        return reading_mask
 
-    def _set_lemma_mask(self, mask: torch.Tensor, prev_input_ids: List[int]) -> None:
+    def _get_lemma_mask(self, prev_input_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
+        lemma_mask = torch.ones_like(logits, dtype=torch.bool)
         prohibited_token_ids = [
             self.tokenizer.pad_token_id,
             self.tokenizer.eos_token_id,
@@ -169,9 +164,11 @@ def _set_lemma_mask(self, mask: torch.Tensor, prev_input_ids: List[int]) -> None:
         ]
         if prev_input_ids[-1] == self.vocab[LEMMA_TOKEN]:
             prohibited_token_ids.append(self.vocab[CANON_TOKEN])
-        mask[prohibited_token_ids] = False
+        lemma_mask[prohibited_token_ids] = False
+        return lemma_mask
 
-    def _set_canon_mask(self, mask: torch.Tensor, prev_input_ids: List[int], batch_idx: int) -> None:
+    def _get_canon_mask(self, prev_input_ids: List[int], logits: torch.Tensor, batch_idx: int) -> torch.Tensor:
+        canon_mask = torch.ones_like(logits, dtype=torch.bool)
         prohibited_token_ids = [
             self.tokenizer.pad_token_id,
             self.vocab[READING_TOKEN],
@@ -183,11 +180,13 @@ def _set_canon_mask(self, mask: torch.Tensor, prev_input_ids: List[int], batch_idx: int) -> None:
             prohibited_token_ids += [self.tokenizer.eos_token_id, self.vocab[SURF_TOKEN]]
         else:
             if prev_input_ids.count(self.vocab[READING_TOKEN]) < len(self.batch_surfs[batch_idx]):
-                prohibited_token_ids.append(self.tokenizer.eos_token_id)
+                prohibited_token_ids += [self.tokenizer.eos_token_id, self.vocab[NO_CANON_TOKEN]]
             else:
-                prohibited_token_ids.append(self.vocab[SURF_TOKEN])
-                self.is_finished[batch_idx] = True
-        mask[prohibited_token_ids] = False
+                prohibited_token_ids += [self.vocab[SURF_TOKEN], self.vocab[NO_CANON_TOKEN]]
+            if logits.argmax().item() == self.tokenizer.eos_token_id:
+                self.is_finished[batch_idx] = True
+        canon_mask[prohibited_token_ids] = False
+        return canon_mask
 
     def _get_ungenerated_surf(self, prev_input_ids: List[int], surfs: List[str]) -> str:
         decoded: str = self.tokenizer.decode(prev_input_ids)
diff --git a/tests/data/modules/permitted_tokens.json b/tests/data/modules/permitted_tokens.json
index 7bf20cd7..838e4c48 100644
--- a/tests/data/modules/permitted_tokens.json
+++ b/tests/data/modules/permitted_tokens.json
@@ -120,12 +120,12 @@
     "t5": {
       "input_tokens": ["", "", "計算", "", "▁けい", "さん", "", "計算", "", "計算"],
       "permitted_tokens": [],
-      "prohibited_tokens": ["", "", "", "", "", ""]
+      "prohibited_tokens": ["", "", "", "", "", "", ""]
     },
     "mt5": {
       "input_tokens": ["", "", "計算", "", "けい", "さん", "", "計算", "", "計算"],
       "permitted_tokens": [],
-      "prohibited_tokens": ["", "", "", "", "", ""]
+      "prohibited_tokens": ["", "", "", "", "", "", ""]
     },
     "target_property": "canon"
   },
@@ -172,12 +172,12 @@
     "t5": {
       "input_tokens": ["", "", "計算", "", "けい", "さん", "", "計算", "", "計算", "/", "けい", "さん"],
       "permitted_tokens": [],
-      "prohibited_tokens": []
+      "prohibited_tokens": ["", "", "", "", "", "", ""]
     },
     "mt5": {
       "input_tokens": ["", "", "計算", "", "けい", "さん", "", "計算", "", "計算", "/", "けい", "さん"],
       "permitted_tokens": [],
-      "prohibited_tokens": []
+      "prohibited_tokens": ["", "", "", "", "", "", ""]
     },
     "target_property": "canon"
   }
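
Note: everything in this diff rests on one masking convention, stated in the `__call__` comment: positions where the boolean mask is False are filled with -inf before the next token is chosen. Below is a minimal, self-contained sketch of that convention; `mask_logits_sketch` is a hypothetical stand-in for kwja's `mask_logits` helper (not its actual implementation), and the token ids are made up. It also shows the two mask idioms the refactored helpers use: allow-masks built with `torch.zeros_like` (as in `_get_surf_mask` / `_get_reading_mask`) and block-masks built with `torch.ones_like` (as in `_get_lemma_mask` / `_get_canon_mask`).

```python
import torch


def mask_logits_sketch(logits: torch.Tensor, mask: torch.Tensor, mask_value: float) -> torch.Tensor:
    # Wherever the boolean mask is False, replace the logit with mask_value
    # (-inf here), so softmax assigns those tokens zero probability.
    return logits.masked_fill(~mask, mask_value)


vocab_size = 8
batch_logits = torch.randn(2, vocab_size)  # (batch_size * num_beams, vocab_size)

# Allow-mask idiom: start all-False, then switch on the permitted token ids.
allow_mask = torch.zeros_like(batch_logits, dtype=torch.bool)
allow_mask[:, [2, 5]] = True  # hypothetical permitted token ids

# Block-mask idiom: start all-True, then switch off the prohibited token ids.
block_mask = torch.ones_like(batch_logits, dtype=torch.bool)
block_mask[:, [0, 1]] = False  # hypothetical prohibited token ids

masked = mask_logits_sketch(batch_logits, allow_mask, -float("inf"))
assert torch.isinf(masked[:, 0]).all()  # token 0 is not permitted, so it is -inf
```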
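For orientation, here is a hedged sketch of how a processor with this `__call__(input_ids, scores)` signature plugs into HuggingFace beam search. The toy `AllowListProcessor` and the `t5-small` checkpoint are placeholders, not kwja code. During beam search `scores` has shape `(batch_size * num_beams, vocab_size)`, which is why `get_mask` above recovers the source example index as `i // self.num_beams`.

```python
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    LogitsProcessor,
    LogitsProcessorList,
)


class AllowListProcessor(LogitsProcessor):
    """Toy stand-in: permit only the given token ids; everything else -> -inf."""

    def __init__(self, permitted_token_ids):
        self.permitted_token_ids = permitted_token_ids

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # Same convention as the diff: build a boolean mask, then fill
        # False positions with -inf so those tokens can never be chosen.
        mask = torch.zeros_like(scores, dtype=torch.bool)
        mask[:, self.permitted_token_ids] = True
        return scores.masked_fill(~mask, -float("inf"))


tokenizer = AutoTokenizer.from_pretrained("t5-small")  # placeholder checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer(["計算"], return_tensors="pt")
# Rows of scores are laid out beam-major per example during beam search,
# matching the batch_idx = i // num_beams bookkeeping in get_mask.
outputs = model.generate(
    **inputs,
    num_beams=2,
    max_new_tokens=8,
    logits_processor=LogitsProcessorList(
        [AllowListProcessor([tokenizer.eos_token_id, tokenizer.pad_token_id])]
    ),
)
print(tokenizer.batch_decode(outputs))
```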