make word module accept whitespace
omukazu committed Apr 13, 2024
1 parent a104b51 commit 010dd8a
Showing 6 changed files with 19 additions and 95 deletions.
2 changes: 1 addition & 1 deletion configs/word_module.debug.yaml
@@ -66,7 +66,7 @@ compile: ${oc.env:COMPILE,false}
ignore_hparams_on_save: false

# constants
special_tokens: ["[著者]", "[読者]", "[不特定:人]", "[不特定:物]", "[NULL]", "[NA]", "[ROOT]"]
special_tokens: ["[著者]", "[読者]", "[不特定:人]", "[不特定:物]", "[NULL]", "[NA]", "[ROOT]", " "]
hparams_to_ignore_on_save:
- project
- work_dir
4 changes: 2 additions & 2 deletions configs/word_module.yaml
@@ -1,4 +1,4 @@
defaults:
defaults:
- base
- callbacks: [word_module_writer, early_stopping, lr_monitor, model_checkpoint, model_summary, progress_bar]
- datamodule: word
@@ -66,7 +66,7 @@ compile: ${oc.env:COMPILE,false}
ignore_hparams_on_save: false

# constants
special_tokens: ["[著者]", "[読者]", "[不特定:人]", "[不特定:物]", "[NULL]", "[NA]", "[ROOT]"]
special_tokens: ["[著者]", "[読者]", "[不特定:人]", "[不特定:物]", "[NULL]", "[NA]", "[ROOT]", " "]
hparams_to_ignore_on_save:
- project
- work_dir
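The config change above registers a bare space alongside the existing special tokens. Below is a minimal sketch of the intent, assuming the list is ultimately handed to a Hugging Face tokenizer as additional special tokens; the model name is only an example and not necessarily kwja's default.

# Sketch only: shows why " " is listed with the special tokens.
# Assumption: kwja registers these strings on the tokenizer so a whitespace
# morpheme keeps a token of its own instead of being normalized away.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ku-nlp/deberta-v2-base-japanese")  # example model
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["[著者]", "[読者]", "[NULL]", "[NA]", "[ROOT]", " "]}
)

words = ["今日", "は", " ", "晴れ"]  # a whitespace morpheme in the middle
encoding = tokenizer(words, is_split_into_words=True, add_special_tokens=False)
print(encoding.tokens())  # with " " registered, the space word should surface as its own token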
21 changes: 5 additions & 16 deletions src/kwja/datamodule/datasets/word.py
@@ -1,7 +1,7 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
from typing import Dict, List, Optional, Union

from cohesion_tools.extractors import BridgingExtractor, CoreferenceExtractor, PasExtractor
from cohesion_tools.extractors.base import BaseExtractor
@@ -25,7 +25,6 @@
NE_TAGS,
POS_TAGS,
RESOURCE_PATH,
SPLIT_INTO_WORDS_MODEL_NAMES,
SUBPOS_TAGS,
WORD_FEATURES,
CohesionTask,
@@ -78,20 +77,14 @@ def __init__(
) -> None:
super().__init__(tokenizer, max_seq_length)
self.path = Path(path)
if tokenizer.name_or_path in SPLIT_INTO_WORDS_MODEL_NAMES:
self.tokenizer_input_format: Literal["words", "text"] = "words"
else:
self.tokenizer_input_format = "text"
super(BaseDataset, self).__init__(self.path, tokenizer, max_seq_length, document_split_stride)
# some tags are not annotated in editorial articles
self.skip_cohesion_ne_discourse = self.path.parts[-2] == "kyoto_ed"

# ---------- reading prediction ----------
reading_resource_path = RESOURCE_PATH / "reading_prediction"
self.reading2reading_id = get_reading2reading_id(reading_resource_path / "vocab.txt")
self.reading_aligner = ReadingAligner(
self.tokenizer, self.tokenizer_input_format, KanjiDic(str(reading_resource_path / "kanjidic"))
)
self.reading_aligner = ReadingAligner(self.tokenizer, KanjiDic(str(reading_resource_path / "kanjidic")))

# ---------- cohesion analysis ----------
self.cohesion_tasks: List[CohesionTask] = [task for task in CohesionTask if task.value in cohesion_tasks]
@@ -116,7 +109,7 @@ def __init__(
self.restrict_cohesion_target: bool = restrict_cohesion_target

# ---------- dependency parsing & cohesion analysis ----------
self.special_tokens: List[str] = list(special_tokens)
self.special_tokens: List[str] = [st for st in special_tokens if st != " "]
self.special_encoding: Encoding = self.tokenizer(
self.special_tokens,
add_special_tokens=False,
@@ -134,22 +127,18 @@ def __init__(

def _get_tokenized_len(self, document_or_sentence: Union[Document, Sentence]) -> int:
tokenizer_input: Union[List[str], str] = [m.text for m in document_or_sentence.morphemes]
if self.tokenizer_input_format == "text":
tokenizer_input = " ".join(tokenizer_input)
return len(self.tokenizer.tokenize(tokenizer_input, is_split_into_words=self.tokenizer_input_format == "words"))
return len(self.tokenizer.tokenize(tokenizer_input, is_split_into_words=True))

def _load_examples(self, doc_id2document: Dict[str, Document]) -> List[WordExample]:
examples = []
example_id = 0
for document in track(doc_id2document.values(), description="Loading examples"):
tokenizer_input: Union[List[str], str] = [m.text for m in document.morphemes]
if self.tokenizer_input_format == "text":
tokenizer_input = " ".join(tokenizer_input)
encoding: Encoding = self.tokenizer(
tokenizer_input,
padding=PaddingStrategy.DO_NOT_PAD,
truncation=False,
is_split_into_words=self.tokenizer_input_format == "words",
is_split_into_words=True,
).encodings[0]
if len(encoding.ids) > self.max_seq_length - len(self.special_tokens):
continue
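With the words/text branching removed, word.py now always feeds the tokenizer a pre-split morpheme list. The standalone sketch below illustrates the simplified path; the helper name and model are illustrative only, and " " is excluded from the appended special-token block exactly as in the diff.

# Sketch of the simplified tokenization path after this change; not the kwja class itself.
from typing import List
from transformers import AutoTokenizer, PreTrainedTokenizerBase

def tokenized_len(tokenizer: PreTrainedTokenizerBase, morphemes: List[str]) -> int:
    # mirrors _get_tokenized_len: no tokenizer_input_format branching,
    # the morpheme list is always passed pre-split
    return len(tokenizer.tokenize(morphemes, is_split_into_words=True))

tokenizer = AutoTokenizer.from_pretrained("ku-nlp/deberta-v2-base-japanese")  # example model
special_tokens = ["[著者]", "[読者]", "[NULL]", "[NA]", "[ROOT]", " "]
trailing_specials = [st for st in special_tokens if st != " "]  # " " is not appended to the input
morphemes = ["今日", "は", " ", "晴れ"]
print(tokenized_len(tokenizer, morphemes), trailing_specials)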
19 changes: 5 additions & 14 deletions src/kwja/datamodule/datasets/word_inference.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
from typing import Dict, List, Optional, Union

from cohesion_tools.extractors import BridgingExtractor, CoreferenceExtractor, PasExtractor
from cohesion_tools.extractors.base import BaseExtractor
@@ -15,7 +15,7 @@
from kwja.datamodule.datasets.base import BaseDataset, FullAnnotatedDocumentLoaderMixin
from kwja.datamodule.datasets.word import WordModuleFeatures
from kwja.datamodule.examples import SpecialTokenIndexer, WordInferenceExample
from kwja.utils.constants import SPLIT_INTO_WORDS_MODEL_NAMES, CohesionTask
from kwja.utils.constants import CohesionTask
from kwja.utils.logging_util import track
from kwja.utils.sub_document import extract_target_sentences

@@ -46,11 +46,6 @@ def __init__(
# do_predict_after_train
documents = []

if tokenizer.name_or_path in SPLIT_INTO_WORDS_MODEL_NAMES:
self.tokenizer_input_format: Literal["words", "text"] = "words"
else:
self.tokenizer_input_format = "text"

super(BaseDataset, self).__init__(documents, tokenizer, max_seq_length, document_split_stride)
# ---------- cohesion analysis ----------
self.cohesion_tasks: List[CohesionTask] = [task for task in CohesionTask if task.value in cohesion_tasks]
@@ -75,7 +70,7 @@ def __init__(
self.restrict_cohesion_target: bool = restrict_cohesion_target

# ---------- dependency parsing & cohesion analysis ----------
self.special_tokens: List[str] = list(special_tokens)
self.special_tokens: List[str] = [st for st in special_tokens if st != " "]
self.special_encoding: Encoding = self.tokenizer(
self.special_tokens,
add_special_tokens=False,
@@ -88,22 +83,18 @@

def _get_tokenized_len(self, document_or_sentence: Union[Document, Sentence]) -> int:
tokenizer_input: Union[List[str], str] = [m.text for m in document_or_sentence.morphemes]
if self.tokenizer_input_format == "text":
tokenizer_input = " ".join(tokenizer_input)
return len(self.tokenizer.tokenize(tokenizer_input, is_split_into_words=self.tokenizer_input_format == "words"))
return len(self.tokenizer.tokenize(tokenizer_input, is_split_into_words=True))

def _load_examples(self, doc_id2document: Dict[str, Document]) -> List[WordInferenceExample]:
examples = []
example_id = 0
for document in track(doc_id2document.values(), description="Loading examples"):
tokenizer_input: Union[List[str], str] = [m.text for m in document.morphemes]
if self.tokenizer_input_format == "text":
tokenizer_input = " ".join(tokenizer_input)
encoding: Encoding = self.tokenizer(
tokenizer_input,
padding=PaddingStrategy.DO_NOT_PAD,
truncation=False,
is_split_into_words=self.tokenizer_input_format == "words",
is_split_into_words=True,
).encodings[0]
if len(encoding.ids) > self.max_seq_length - len(self.special_tokens):
continue
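The inference dataset mirrors the same change, and the check that decides whether a document fits keeps its original form. A small worked example of that budget, with an assumed max_seq_length of 512:

# Worked example of the skip condition in _load_examples (values assumed):
# a document is kept only if its subword count fits next to the trailing special tokens.
max_seq_length = 512
special_tokens = ["[著者]", "[読者]", "[不特定:人]", "[不特定:物]", "[NULL]", "[NA]", "[ROOT]", " "]
trailing_specials = [st for st in special_tokens if st != " "]  # " " filtered out, as above

budget = max_seq_length - len(trailing_specials)  # 512 - 7 = 505 subword positions
num_subwords = 812  # hypothetical long document
if num_subwords > budget:
    print("document skipped")  # corresponds to the `continue` in the diff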
8 changes: 0 additions & 8 deletions src/kwja/utils/constants.py
@@ -16,14 +16,6 @@
RESOURCE_PATH = resource_files(kwja) / "resource"


# ---------- word (inference) dataset ----------
SPLIT_INTO_WORDS_MODEL_NAMES = [
"nlp-waseda/roberta-base-japanese",
"nlp-waseda/roberta-large-japanese",
"nlp-waseda/roberta-large-japanese-seq512",
]


# ---------- typo module ----------
TYPO_CORR_OP_TAG2TOKEN = {
"K": "<k>",
60 changes: 6 additions & 54 deletions src/kwja/utils/reading_prediction.py
@@ -5,11 +5,11 @@
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from rhoknp import Document, Morpheme
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from rhoknp import Morpheme
from transformers import PreTrainedTokenizerBase

from kwja.utils.constants import (
CHOON_SET,
@@ -42,21 +42,14 @@ def get_reading2reading_id(path: Path) -> Dict[str, int]:
class ReadingAligner:
kana_re = re.compile("^[\u3041-\u30FF]+$")

def __init__(
self, tokenizer: PreTrainedTokenizerBase, tokenizer_input_format: Literal["words", "text"], kanji_dic: KanjiDic
) -> None:
def __init__(self, tokenizer: PreTrainedTokenizerBase, kanji_dic: KanjiDic) -> None:
self.tokenizer = tokenizer
self.tokenizer_input_format = tokenizer_input_format
self.kanji_dic = kanji_dic

def align(self, morphemes: List[Morpheme]) -> List[str]:
# assumption: morphemes are never combined
tokenizer_input: Union[List[str], str] = [m.text for m in morphemes]
if self.tokenizer_input_format == "text":
tokenizer_input = " ".join(tokenizer_input)
encoding = self.tokenizer(
tokenizer_input, add_special_tokens=False, is_split_into_words=self.tokenizer_input_format == "words"
).encodings[0]
encoding = self.tokenizer(tokenizer_input, add_special_tokens=False, is_split_into_words=True).encodings[0]
word_id2subwords = defaultdict(list)
for token_id, word_id in enumerate(encoding.word_ids):
word_id2subwords[word_id].append(self.tokenizer.decode(encoding.ids[token_id]))
@@ -293,46 +286,5 @@ def get_word_level_readings(readings: List[str], tokens: List[str], subword_map:
if item:
ret.append(item)
elif any(flags):
ret.append("")
ret.append("_")
return ret


def main():
from argparse import ArgumentParser
from collections import Counter
from pathlib import Path

from kwja.utils.constants import SPLIT_INTO_WORDS_MODEL_NAMES

parser = ArgumentParser()
parser.add_argument("-m", "--model-name-or-path", type=str, help="model_name_or_path")
parser.add_argument("-k", "--kanjidic", type=str, help="path to file")
parser.add_argument("-i", "--input", type=str, help="path to input dir")
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
if args.model_name_or_path in SPLIT_INTO_WORDS_MODEL_NAMES:
tokenizer_input_format: Literal["words", "text"] = "words"
else:
tokenizer_input_format = "text"
kanjidic = KanjiDic(args.kanjidic)
reading_aligner = ReadingAligner(tokenizer, tokenizer_input_format, kanjidic)

reading_counter: Dict[str, int] = Counter()
for path in Path(args.input).glob("**/*.knp"):
logger.info(f"processing {path}")
with path.open() as f:
document = Document.from_knp(f.read())
try:
for reading in reading_aligner.align(document.morphemes):
reading_counter[reading] += 1
except ValueError:
logger.warning(f"skip {document.doc_id} for an error")
for subreading, count in sorted(
sorted(reading_counter.items(), key=lambda pair: pair[0]), key=lambda pair: pair[1], reverse=True
):
print(f"{subreading}\t{count}")


if __name__ == "__main__":
main()
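In reading_prediction.py the aligner likewise drops the text-input path and always relies on word ids from a pre-split encoding, and the standalone main entry point that depended on SPLIT_INTO_WORDS_MODEL_NAMES is removed. Note also that get_word_level_readings now emits "_" instead of an empty string, presumably so tokens without a recoverable reading keep a placeholder. Below is a short sketch of the subword-to-word grouping the remaining code relies on; the model name is an example.

# Sketch of the word_id grouping used by ReadingAligner.align; not the full aligner.
from collections import defaultdict
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ku-nlp/deberta-v2-base-japanese")  # example model
morphemes = ["勉強", "し", "ます"]
encoding = tokenizer(morphemes, add_special_tokens=False, is_split_into_words=True).encodings[0]

word_id2subwords = defaultdict(list)
for token_id, word_id in enumerate(encoding.word_ids):
    # each subword is mapped back to the morpheme it came from
    word_id2subwords[word_id].append(tokenizer.decode(encoding.ids[token_id]))
print(dict(word_id2subwords))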
