eliminate the escape of hydra variable interpolation: ${}
nobu-g committed Feb 1, 2024
1 parent bdb34b7 commit 1f11d76
Showing 7 changed files with 55 additions and 15 deletions.
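Background: KWJA previously escaped "${...}" in user input because the raw texts were routed through an OmegaConf config, where that pattern is parsed as variable interpolation. A minimal sketch of the OmegaConf behavior being worked around (illustrative config keys, not code from this repository):

from omegaconf import OmegaConf

# "${day}" inside a config value is resolved against the config itself,
# so user text stored in a config node can be silently rewritten on access.
cfg = OmegaConf.create({"day": 15, "texts": ["今日は${day}日です"]})
print(cfg.texts[0])  # -> 今日は15日です: the user's literal "${day}" is lost

This commit instead passes the path of the input file through the config, so the raw text never enters OmegaConf and no escaping is needed.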
configs/datamodule/predict/char_inference.yaml (1 addition, 0 deletions)
@@ -5,3 +5,4 @@ defaults:
 _target_: kwja.datamodule.datasets.CharInferenceDataset
 texts: []
 doc_id_prefix: null
+raw_text_file: null
configs/datamodule/predict/seq2seq_inference.yaml (1 addition, 0 deletions)
@@ -5,3 +5,4 @@ defaults:
 _target_: kwja.datamodule.datasets.Seq2SeqInferenceDataset
 texts: []
 doc_id_prefix: null
+raw_text_file: null
src/kwja/cli/cli.py (2 additions, 8 deletions)
@@ -1,6 +1,5 @@
 import logging
 import os
-import re
 import sys
 from abc import ABC
 from enum import Enum
@@ -29,7 +28,6 @@
 
 filter_logs(environment="production")
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-OMEGACONF_VARIABLE_INTERPOLATION = re.compile(r"\$(?P<variable>\{.*?})")
 logging.basicConfig(format="")
 
 logger = logging.getLogger("kwja_cli")
@@ -102,8 +100,7 @@ def _load_module(self) -> pl.LightningModule:
 
     def _load_datamodule(self, input_file: Path) -> DataModule:
         assert self.module is not None
-        with input_file.open() as f:
-            self.module.hparams.datamodule.predict.texts = list(_chunk_by_document(f, self.input_format))
+        self.module.hparams.datamodule.predict.raw_text_file = input_file
         datamodule = DataModule(cfg=self.module.hparams.datamodule)
         datamodule.setup(stage=TrainerFn.PREDICTING)
         return datamodule
@@ -128,8 +125,7 @@ def _load_module(self) -> pl.LightningModule:
 
     def _load_datamodule(self, input_file: Path) -> DataModule:
         assert self.module is not None
-        with input_file.open() as f:
-            self.module.hparams.datamodule.predict.texts = list(_chunk_by_document(f, self.input_format))
+        self.module.hparams.datamodule.predict.raw_text_file = input_file
         datamodule = DataModule(cfg=self.module.hparams.datamodule)
         datamodule.setup(stage=TrainerFn.PREDICTING)
         return datamodule
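Both _load_datamodule methods now assign the input file path instead of the chunked texts. A hedged sketch of the failure mode this avoids (key names illustrative, not from the repo): assigning raw text into an OmegaConf node defers interpolation to access time, where an undefined "${...}" reference raises instead of passing the text through verbatim.

from omegaconf import OmegaConf
from omegaconf.errors import InterpolationResolutionError

cfg = OmegaConf.create({"predict": {"texts": []}})
cfg.predict.texts = ["今日は${day}日です"]  # raw user text enters the config
try:
    _ = cfg.predict.texts[0]  # interpolation is resolved here
except InterpolationResolutionError as e:
    print(f"cannot resolve user text: {e}")

A file path is very unlikely to contain "${...}", so storing it in the config sidesteps the problem.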
@@ -248,8 +244,6 @@ def normalize_text(text: str) -> str:
     normalized = normalize("NFKC", text)
     # escape several symbols and delete control characters
     normalized = normalized.translate(TRANSLATION_TABLE)
-    # prevent hydra.utils.instantiate from interpolating the string "${...}"
-    normalized = OMEGACONF_VARIABLE_INTERPOLATION.sub(r"$␣\g<variable>", normalized)
     # if normalized != text:
     #     typer.echo(f"apply normalization ({text} -> {normalized})", err=True)
     return normalized
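For reference, the deleted workaround rewrote every "${...}" occurrence by inserting "␣" between "$" and "{" so that OmegaConf would no longer recognize the pattern, at the cost of mangling the user's text. The removed logic, reproduced standalone:

import re

OMEGACONF_VARIABLE_INTERPOLATION = re.compile(r"\$(?P<variable>\{.*?})")
escaped = OMEGACONF_VARIABLE_INTERPOLATION.sub(r"$␣\g<variable>", "今日は${day}日です")
print(escaped)  # -> 今日は$␣{day}日です

With the file-path approach, this substitution and the corrupted output it produced are gone.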
src/kwja/datamodule/datasets/char_inference.py (12 additions, 2 deletions)
@@ -1,4 +1,5 @@
 import logging
+from pathlib import Path
 from typing import Dict, List, Optional
 
 from omegaconf import ListConfig
@@ -8,7 +9,7 @@
 
 from kwja.datamodule.datasets.base import BaseDataset, FullAnnotatedDocumentLoaderMixin
 from kwja.datamodule.datasets.char import CharModuleFeatures
-from kwja.datamodule.datasets.utils import add_doc_ids, add_sent_ids, create_documents_from_raw_texts
+from kwja.datamodule.datasets.utils import add_doc_ids, add_sent_ids, chunk_by_document, create_documents_from_raw_texts
 from kwja.datamodule.examples import CharInferenceExample
 from kwja.utils.logging_util import track
 
@@ -22,10 +23,19 @@ def __init__(
         tokenizer: PreTrainedTokenizerBase,
         max_seq_length: int,
         doc_id_prefix: Optional[str] = None,
+        raw_text_file: Optional[Path] = None,
         **_,
     ) -> None:
         super().__init__(tokenizer, max_seq_length)
-        documents = create_documents_from_raw_texts(texts)
+        documents: List[Document]
+        if len(texts) > 0:
+            documents = create_documents_from_raw_texts(texts)
+        elif raw_text_file is not None:
+            with raw_text_file.open() as f:
+                documents = create_documents_from_raw_texts(chunk_by_document(f))
+        else:
+            documents = []
+
         add_doc_ids(documents, doc_id_prefix)
         documents = self._add_tentative_sentence_boundary(documents)
         super(BaseDataset, self).__init__(documents, tokenizer, max_seq_length, -1)  # document_split_stride must be -1
src/kwja/datamodule/datasets/typo_inference.py (12 additions, 3 deletions)
@@ -1,5 +1,6 @@
 from collections import defaultdict
-from typing import Dict, List, Tuple
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
 
 from omegaconf import ListConfig
 from rhoknp import Document
@@ -8,7 +9,7 @@
 
 from kwja.datamodule.datasets.base import BaseDataset
 from kwja.datamodule.datasets.typo import TypoModuleFeatures
-from kwja.datamodule.datasets.utils import create_documents_from_raw_texts
+from kwja.datamodule.datasets.utils import chunk_by_document, create_documents_from_raw_texts
 from kwja.datamodule.examples import TypoInferenceExample
 from kwja.utils.constants import DUMMY_TOKEN
 from kwja.utils.logging_util import track
@@ -20,9 +21,17 @@ def __init__(
         texts: ListConfig,
         tokenizer: PreTrainedTokenizerBase,
         max_seq_length: int,
+        raw_text_file: Optional[Path] = None,
     ) -> None:
         super().__init__(tokenizer, max_seq_length)
-        documents: List[Document] = create_documents_from_raw_texts(texts)
+        documents: List[Document]
+        if len(texts) > 0:
+            documents = create_documents_from_raw_texts(texts)
+        elif raw_text_file is not None:
+            with raw_text_file.open() as f:
+                documents = create_documents_from_raw_texts(chunk_by_document(f))
+        else:
+            documents = []
         self.examples: List[TypoInferenceExample] = []
         self.stash: Dict[int, List[Tuple[str, str]]] = defaultdict(list)
         example_id = 0
src/kwja/datamodule/datasets/utils.py (14 additions, 2 deletions)
@@ -1,10 +1,10 @@
 from datetime import datetime
-from typing import List, Optional, Sequence
+from typing import Iterable, Iterator, List, Optional, Sequence, TextIO
 
 from rhoknp import Document
 
 
-def create_documents_from_raw_texts(texts: Sequence[str]) -> List[Document]:
+def create_documents_from_raw_texts(texts: Iterable[str]) -> List[Document]:
     documents: List[Document] = []
     for text in texts:
         raw_text = ""
@@ -37,3 +37,15 @@ def add_sent_ids(documents: Sequence[Document]) -> None:
         for idx, sentence in enumerate(document.sentences):
             if sentence.sent_id == "":
                 sentence.sent_id = f"{document.doc_id}-{idx:0{sent_id_width}}"
+
+
+def chunk_by_document(f: TextIO) -> Iterator[str]:
+    buff: str = ""
+    for line in f:
+        if line.strip() == "EOD":
+            yield buff.rstrip()
+            buff = ""
+        else:
+            buff += line
+    if buff.rstrip() != "":
+        yield buff.rstrip()
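A small usage sketch of the new helper (input constructed inline for illustration): documents are separated by lines consisting of "EOD", and a trailing document without a final "EOD" is still yielded.

import io

stream = io.StringIO("一文目です。\nEOD\n二文目です。\n三文目です。\n")
print(list(chunk_by_document(stream)))
# -> ['一文目です。', '二文目です。\n三文目です。']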
tests/cli/test_cli.py (13 additions, 0 deletions)
@@ -108,6 +108,19 @@ def test_normalization_and_typo_module(text: str, output: str):
     assert ret.stdout == output
 
 
+@pytest.mark.parametrize(
+    ("text", "output"),
+    [
+        ("今日は${day}日です", "今日は${day}日です"),
+    ],
+)
+def test_normalization_and_char_module(text: str, output: str):
+    ret = runner.invoke(app, args=["--model-size", "tiny", "--tasks", "char", "--text", text])
+    assert ret.exception is None
+    restored_output = "".join([line for line in ret.stdout.splitlines() if not line.startswith("#")]).replace(" ", "")
+    assert restored_output == output.replace(" ", "")
+
+
 def test_file_input():
     with tempfile.NamedTemporaryFile("wt") as f:
         f.write(
