From 1f11d76e1b161db4cb8dc75fbddcf178b62b0358 Mon Sep 17 00:00:00 2001
From: nobu-g
Date: Thu, 1 Feb 2024 19:59:58 +0900
Subject: [PATCH] eliminate the escaping of hydra variable interpolation: ${}

---
 configs/datamodule/predict/char_inference.yaml |  1 +
 .../datamodule/predict/seq2seq_inference.yaml  |  1 +
 src/kwja/cli/cli.py                            | 10 ++--------
 src/kwja/datamodule/datasets/char_inference.py | 14 ++++++++++++--
 src/kwja/datamodule/datasets/typo_inference.py | 15 ++++++++++++---
 src/kwja/datamodule/datasets/utils.py          | 16 ++++++++++++++--
 tests/cli/test_cli.py                          | 13 +++++++++++++
 7 files changed, 55 insertions(+), 15 deletions(-)

diff --git a/configs/datamodule/predict/char_inference.yaml b/configs/datamodule/predict/char_inference.yaml
index c22ebe5c..f165f36e 100644
--- a/configs/datamodule/predict/char_inference.yaml
+++ b/configs/datamodule/predict/char_inference.yaml
@@ -5,3 +5,4 @@ defaults:
 _target_: kwja.datamodule.datasets.CharInferenceDataset
 texts: []
 doc_id_prefix: null
+raw_text_file: null
diff --git a/configs/datamodule/predict/seq2seq_inference.yaml b/configs/datamodule/predict/seq2seq_inference.yaml
index 6ff6d232..1c8c8ae9 100644
--- a/configs/datamodule/predict/seq2seq_inference.yaml
+++ b/configs/datamodule/predict/seq2seq_inference.yaml
@@ -5,3 +5,4 @@ defaults:
 _target_: kwja.datamodule.datasets.Seq2SeqInferenceDataset
 texts: []
 doc_id_prefix: null
+raw_text_file: null
diff --git a/src/kwja/cli/cli.py b/src/kwja/cli/cli.py
index e77f4098..56dc00da 100644
--- a/src/kwja/cli/cli.py
+++ b/src/kwja/cli/cli.py
@@ -1,6 +1,5 @@
 import logging
 import os
-import re
 import sys
 from abc import ABC
 from enum import Enum
@@ -29,7 +28,6 @@
 filter_logs(environment="production")
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-OMEGACONF_VARIABLE_INTERPOLATION = re.compile(r"\$(?P<variable>\{.*?})")
 logging.basicConfig(format="")
 logger = logging.getLogger("kwja_cli")
@@ -102,8 +100,7 @@ def _load_module(self) -> pl.LightningModule:
 
     def _load_datamodule(self, input_file: Path) -> DataModule:
         assert self.module is not None
-        with input_file.open() as f:
-            self.module.hparams.datamodule.predict.texts = list(_chunk_by_document(f, self.input_format))
+        self.module.hparams.datamodule.predict.raw_text_file = input_file
         datamodule = DataModule(cfg=self.module.hparams.datamodule)
         datamodule.setup(stage=TrainerFn.PREDICTING)
         return datamodule
@@ -128,8 +125,7 @@ def _load_module(self) -> pl.LightningModule:
 
     def _load_datamodule(self, input_file: Path) -> DataModule:
         assert self.module is not None
-        with input_file.open() as f:
-            self.module.hparams.datamodule.predict.texts = list(_chunk_by_document(f, self.input_format))
+        self.module.hparams.datamodule.predict.raw_text_file = input_file
         datamodule = DataModule(cfg=self.module.hparams.datamodule)
         datamodule.setup(stage=TrainerFn.PREDICTING)
         return datamodule
@@ -248,8 +244,6 @@ def normalize_text(text: str) -> str:
     normalized = normalize("NFKC", text)
     # escape several symbols and delete control characters
     normalized = normalized.translate(TRANSLATION_TABLE)
-    # prevent hydra.utils.instantiate from interpolating the string "${...}"
-    normalized = OMEGACONF_VARIABLE_INTERPOLATION.sub(r"$␣\g<variable>", normalized)
     # if normalized != text:
     #     typer.echo(f"apply normalization ({text} -> {normalized})", err=True)
     return normalized
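Note: the escaping deleted above was only needed because the CLI used to write
raw user text into the model's Hydra config (hparams.datamodule.predict.texts),
where OmegaConf treats "${...}" as variable interpolation. Storing only the
input file path avoids resolution entirely. A minimal sketch of the old failure
mode (illustration only, not part of the patch; assumes omegaconf 2.1+):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"texts": ["今日は${day}日です"]})
    # Element access triggers interpolation of "${day}"; since no "day" key
    # exists in the config, omegaconf raises an interpolation error.
    try:
        print(cfg.texts[0])
    except Exception as e:
        print(type(e).__name__)  # InterpolationKeyError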
diff --git a/src/kwja/datamodule/datasets/char_inference.py b/src/kwja/datamodule/datasets/char_inference.py
index 1e7a065c..23c0f6dc 100644
--- a/src/kwja/datamodule/datasets/char_inference.py
+++ b/src/kwja/datamodule/datasets/char_inference.py
@@ -1,4 +1,5 @@
 import logging
+from pathlib import Path
 from typing import Dict, List, Optional
 
 from omegaconf import ListConfig
@@ -8,7 +9,7 @@
 
 from kwja.datamodule.datasets.base import BaseDataset, FullAnnotatedDocumentLoaderMixin
 from kwja.datamodule.datasets.char import CharModuleFeatures
-from kwja.datamodule.datasets.utils import add_doc_ids, add_sent_ids, create_documents_from_raw_texts
+from kwja.datamodule.datasets.utils import add_doc_ids, add_sent_ids, chunk_by_document, create_documents_from_raw_texts
 from kwja.datamodule.examples import CharInferenceExample
 from kwja.utils.logging_util import track
 
@@ -22,10 +23,19 @@ def __init__(
         tokenizer: PreTrainedTokenizerBase,
         max_seq_length: int,
         doc_id_prefix: Optional[str] = None,
+        raw_text_file: Optional[Path] = None,
         **_,
     ) -> None:
         super().__init__(tokenizer, max_seq_length)
-        documents = create_documents_from_raw_texts(texts)
+        documents: List[Document]
+        if len(texts) > 0:
+            documents = create_documents_from_raw_texts(texts)
+        elif raw_text_file is not None:
+            with raw_text_file.open() as f:
+                documents = create_documents_from_raw_texts(chunk_by_document(f))
+        else:
+            documents = []
+
         add_doc_ids(documents, doc_id_prefix)
         documents = self._add_tentative_sentence_boundary(documents)
         super(BaseDataset, self).__init__(documents, tokenizer, max_seq_length, -1)  # document_split_stride must be -1
diff --git a/src/kwja/datamodule/datasets/typo_inference.py b/src/kwja/datamodule/datasets/typo_inference.py
index 6e7c081a..11f86679 100644
--- a/src/kwja/datamodule/datasets/typo_inference.py
+++ b/src/kwja/datamodule/datasets/typo_inference.py
@@ -1,5 +1,6 @@
 from collections import defaultdict
-from typing import Dict, List, Tuple
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
 
 from omegaconf import ListConfig
 from rhoknp import Document
@@ -8,7 +9,7 @@
 
 from kwja.datamodule.datasets.base import BaseDataset
 from kwja.datamodule.datasets.typo import TypoModuleFeatures
-from kwja.datamodule.datasets.utils import create_documents_from_raw_texts
+from kwja.datamodule.datasets.utils import chunk_by_document, create_documents_from_raw_texts
 from kwja.datamodule.examples import TypoInferenceExample
 from kwja.utils.constants import DUMMY_TOKEN
 from kwja.utils.logging_util import track
@@ -20,9 +21,17 @@ def __init__(
         texts: ListConfig,
         tokenizer: PreTrainedTokenizerBase,
         max_seq_length: int,
+        raw_text_file: Optional[Path] = None,
     ) -> None:
         super().__init__(tokenizer, max_seq_length)
-        documents: List[Document] = create_documents_from_raw_texts(texts)
+        documents: List[Document]
+        if len(texts) > 0:
+            documents = create_documents_from_raw_texts(texts)
+        elif raw_text_file is not None:
+            with raw_text_file.open() as f:
+                documents = create_documents_from_raw_texts(chunk_by_document(f))
+        else:
+            documents = []
         self.examples: List[TypoInferenceExample] = []
         self.stash: Dict[int, List[Tuple[str, str]]] = defaultdict(list)
         example_id = 0
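Note: both inference datasets now resolve their input the same way: inline
texts from the config take precedence, otherwise documents are read from
raw_text_file and split on "EOD" delimiter lines, and with neither the dataset
is empty. Keeping texts ahead of raw_text_file leaves the existing --text code
path untouched. Roughly how the new keyword gets exercised (illustration only;
the tokenizer name is a stand-in, substitute whichever model the typo module
actually loads, and the import path assumes the package re-exports the dataset
as the configs' _target_ values suggest):

    from pathlib import Path

    from omegaconf import ListConfig
    from transformers import AutoTokenizer

    from kwja.datamodule.datasets import TypoInferenceDataset

    tokenizer = AutoTokenizer.from_pretrained("ku-nlp/deberta-v2-base-japanese-char-wwm")  # stand-in
    dataset = TypoInferenceDataset(
        texts=ListConfig([]),             # empty, so the raw_text_file fallback kicks in
        tokenizer=tokenizer,
        max_seq_length=512,
        raw_text_file=Path("input.txt"),  # documents separated by "EOD" lines
    )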
diff --git a/src/kwja/datamodule/datasets/utils.py b/src/kwja/datamodule/datasets/utils.py
index bd2be401..03568d62 100644
--- a/src/kwja/datamodule/datasets/utils.py
+++ b/src/kwja/datamodule/datasets/utils.py
@@ -1,10 +1,10 @@
 from datetime import datetime
-from typing import List, Optional, Sequence
+from typing import Iterable, Iterator, List, Optional, Sequence, TextIO
 
 from rhoknp import Document
 
 
-def create_documents_from_raw_texts(texts: Sequence[str]) -> List[Document]:
+def create_documents_from_raw_texts(texts: Iterable[str]) -> List[Document]:
     documents: List[Document] = []
     for text in texts:
         raw_text = ""
@@ -37,3 +37,15 @@ def add_sent_ids(documents: Sequence[Document]) -> None:
         for idx, sentence in enumerate(document.sentences):
             if sentence.sent_id == "":
                 sentence.sent_id = f"{document.doc_id}-{idx:0{sent_id_width}}"
+
+
+def chunk_by_document(f: TextIO) -> Iterator[str]:
+    buff: str = ""
+    for line in f:
+        if line.strip() == "EOD":
+            yield buff.rstrip()
+            buff = ""
+        else:
+            buff += line
+    if buff.rstrip() != "":
+        yield buff.rstrip()
diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py
index d55c0b60..cfa24fb8 100644
--- a/tests/cli/test_cli.py
+++ b/tests/cli/test_cli.py
@@ -108,6 +108,19 @@ def test_normalization_and_typo_module(text: str, output: str):
     assert ret.stdout == output
 
 
+@pytest.mark.parametrize(
+    ("text", "output"),
+    [
+        ("今日は${day}日です", "今日は${day}日です"),
+    ],
+)
+def test_normalization_and_char_module(text: str, output: str):
+    ret = runner.invoke(app, args=["--model-size", "tiny", "--tasks", "char", "--text", text])
+    assert ret.exception is None
+    restored_output = "".join([line for line in ret.stdout.splitlines() if not line.startswith("#")]).replace(" ", "")
+    assert restored_output == output.replace(" ", "")
+
+
 def test_file_input():
     with tempfile.NamedTemporaryFile("wt") as f:
         f.write(
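Note: chunk_by_document() replaces the CLI-side _chunk_by_document helper that
the deleted cli.py lines used to call. It yields one chunk per document,
splitting the stream on "EOD" delimiter lines, and still yields a trailing
document that lacks a closing "EOD". A quick round trip (illustration only,
assuming kwja with this patch is importable):

    import io

    from kwja.datamodule.datasets.utils import chunk_by_document

    stream = io.StringIO("文書1の1文目\n文書1の2文目\nEOD\n文書2の1文目\n")
    print(list(chunk_by_document(stream)))
    # -> ['文書1の1文目\n文書1の2文目', '文書2の1文目']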