Skip to content

Commit

Permalink
tweak
Browse files Browse the repository at this point in the history
  • Loading branch information
omukazu committed Jun 26, 2024
1 parent be64207 commit 758dd42
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 37 deletions.
65 changes: 32 additions & 33 deletions src/kwja/modules/components/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ def __init__(

self.is_finished: List[bool] = [False] * len(batch_surfs)

def __call__(self, batch_prev_input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
def __call__(self, batch_prev_input_ids: torch.Tensor, batch_logits: torch.Tensor) -> torch.Tensor:
# Falseならば-inf
mask = self.get_mask(batch_prev_input_ids, logits)
masked_logits = mask_logits(logits, mask, mask_value=-float("inf"))
return masked_logits
mask = self.get_mask(batch_prev_input_ids, batch_logits)
batch_masked_logits = mask_logits(batch_logits, mask, mask_value=-float("inf"))
return batch_masked_logits

def get_mask(
self,
Expand All @@ -99,32 +99,22 @@ def get_mask(
mask[:, self.vocab[SURF_TOKEN]] = True
return mask

_, vocab_size = batch_logits.size()

batch_masks = []
for i, prev_input_ids in enumerate(batch_prev_input_ids.tolist()):
for i, (prev_input_ids, logits) in enumerate(zip(batch_prev_input_ids.tolist(), batch_logits)):
batch_idx = i // self.num_beams
if self.is_finished[batch_idx]:
batch_masks.append(torch.ones(vocab_size, device=batch_logits.device, dtype=torch.bool))
batch_masks.append(torch.zeros_like(logits, dtype=torch.bool))
continue

target_property: TargetProperty = self._get_target_property(prev_input_ids)
if target_property.surf is True:
surf_mask = torch.zeros(vocab_size, device=batch_logits.device, dtype=torch.bool)
self._set_surf_mask(surf_mask, prev_input_ids, batch_idx)
batch_masks.append(surf_mask)
batch_masks.append(self._get_surf_mask(prev_input_ids, logits, batch_idx))
elif target_property.reading is True:
reading_mask = torch.zeros(vocab_size, device=batch_logits.device, dtype=torch.bool)
self._set_reading_mask(reading_mask, prev_input_ids)
batch_masks.append(reading_mask)
batch_masks.append(self._get_reading_mask(prev_input_ids, logits))
elif target_property.lemma is True:
lemma_mask = torch.ones(vocab_size, device=batch_logits.device, dtype=torch.bool)
self._set_lemma_mask(lemma_mask, prev_input_ids)
batch_masks.append(lemma_mask)
batch_masks.append(self._get_lemma_mask(prev_input_ids, logits))
elif target_property.canon is True:
canon_mask = torch.ones(vocab_size, device=batch_logits.device, dtype=torch.bool)
self._set_canon_mask(canon_mask, prev_input_ids, batch_idx)
batch_masks.append(canon_mask)
batch_masks.append(self._get_canon_mask(prev_input_ids, logits, batch_idx))

return torch.stack(batch_masks)

Expand All @@ -145,19 +135,24 @@ def _get_target_property(self, prev_input_ids: List[int]) -> TargetProperty:
break
return target_property

def _set_surf_mask(self, mask: torch.Tensor, prev_input_ids: List[int], batch_idx: int) -> None:
def _get_surf_mask(self, prev_input_ids: List[int], logits: torch.Tensor, batch_idx: int) -> torch.Tensor:
surf_mask = torch.zeros_like(logits, dtype=torch.bool)
if ungenerated_surf := self._get_ungenerated_surf(prev_input_ids, self.batch_surfs[batch_idx]):
mask[self._get_permitted_token_ids(ungenerated_surf)] = True
surf_mask[self._get_permitted_token_ids(ungenerated_surf)] = True
else:
mask[self.vocab[READING_TOKEN]] = True
surf_mask[self.vocab[READING_TOKEN]] = True
return surf_mask

def _set_reading_mask(self, mask: torch.Tensor, prev_input_ids: List[int]) -> None:
def _get_reading_mask(self, prev_input_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
reading_mask = torch.zeros_like(logits, dtype=torch.bool)
if prev_input_ids[-1] == self.vocab[READING_TOKEN]:
mask[self.reading_candidate_token_ids] = True
reading_mask[self.reading_candidate_token_ids] = True
else:
mask[self.reading_candidate_token_ids + [self.vocab[LEMMA_TOKEN]]] = True
reading_mask[self.reading_candidate_token_ids + [self.vocab[LEMMA_TOKEN]]] = True
return reading_mask

def _set_lemma_mask(self, mask: torch.Tensor, prev_input_ids: List[int]) -> None:
def _get_lemma_mask(self, prev_input_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
lemma_mask = torch.ones_like(logits, dtype=torch.bool)
prohibited_token_ids = [
self.tokenizer.pad_token_id,
self.tokenizer.eos_token_id,
Expand All @@ -169,9 +164,11 @@ def _set_lemma_mask(self, mask: torch.Tensor, prev_input_ids: List[int]) -> None
]
if prev_input_ids[-1] == self.vocab[LEMMA_TOKEN]:
prohibited_token_ids.append(self.vocab[CANON_TOKEN])
mask[prohibited_token_ids] = False
lemma_mask[prohibited_token_ids] = False
return lemma_mask

def _set_canon_mask(self, mask: torch.Tensor, prev_input_ids: List[int], batch_idx: int) -> None:
def _get_canon_mask(self, prev_input_ids: List[int], logits: torch.Tensor, batch_idx: int) -> torch.Tensor:
canon_mask = torch.ones_like(logits, dtype=torch.bool)
prohibited_token_ids = [
self.tokenizer.pad_token_id,
self.vocab[READING_TOKEN],
Expand All @@ -183,11 +180,13 @@ def _set_canon_mask(self, mask: torch.Tensor, prev_input_ids: List[int], batch_i
prohibited_token_ids += [self.tokenizer.eos_token_id, self.vocab[SURF_TOKEN]]
else:
if prev_input_ids.count(self.vocab[READING_TOKEN]) < len(self.batch_surfs[batch_idx]):
prohibited_token_ids.append(self.tokenizer.eos_token_id)
prohibited_token_ids += [self.tokenizer.eos_token_id, self.vocab[NO_CANON_TOKEN]]
else:
prohibited_token_ids.append(self.vocab[SURF_TOKEN])
self.is_finished[batch_idx] = True
mask[prohibited_token_ids] = False
prohibited_token_ids += [self.vocab[SURF_TOKEN], self.vocab[NO_CANON_TOKEN]]
if logits.argmax().item() == self.tokenizer.eos_token_id:
self.is_finished[batch_idx] = True
canon_mask[prohibited_token_ids] = False
return canon_mask

def _get_ungenerated_surf(self, prev_input_ids: List[int], surfs: List[str]) -> str:
decoded: str = self.tokenizer.decode(prev_input_ids)
Expand Down
8 changes: 4 additions & 4 deletions tests/data/modules/permitted_tokens.json
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@
"t5": {
"input_tokens": ["<pad>", "<extra_id_0>", "計算", "<extra_id_1>", "▁けい", "さん", "<extra_id_2>", "計算", "<extra_id_3>", "計算"],
"permitted_tokens": [],
"prohibited_tokens": ["<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_5>", "<pad>", "</s>"]
"prohibited_tokens": ["<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<pad>", "</s>"]
},
"mt5": {
"input_tokens": ["<pad>", "<extra_id_0>", "計算", "<extra_id_1>", "けい", "さん", "<extra_id_2>", "計算", "<extra_id_3>", "計算"],
"permitted_tokens": [],
"prohibited_tokens": ["<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_5>", "<pad>", "</s>"]
"prohibited_tokens": ["<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<pad>", "</s>"]
},
"target_property": "canon"
},
Expand Down Expand Up @@ -172,12 +172,12 @@
"t5": {
"input_tokens": ["<pad>", "<extra_id_0>", "計算", "<extra_id_1>", "けい", "さん", "<extra_id_2>", "計算", "<extra_id_3>", "計算", "/", "けい", "さん"],
"permitted_tokens": [],
"prohibited_tokens": []
"prohibited_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<pad>"]
},
"mt5": {
"input_tokens": ["<pad>", "<extra_id_0>", "計算", "<extra_id_1>", "けい", "さん", "<extra_id_2>", "計算", "<extra_id_3>", "計算", "/", "けい", "さん"],
"permitted_tokens": [],
"prohibited_tokens": []
"prohibited_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<pad>"]
},
"target_property": "canon"
}
Expand Down

0 comments on commit 758dd42

Please sign in to comment.