Merge pull request #211 from ku-nlp/add-whitespace-token
Add whitespace token
omukazu authored Jul 8, 2024
2 parents 049c232 + e9a45d0 commit b7f5474
Showing 38 changed files with 808 additions and 989 deletions.
1 change: 1 addition & 0 deletions configs/char_module.debug.yaml
@@ -41,6 +41,7 @@ compile: ${oc.env:COMPILE,false}
 ignore_hparams_on_save: false

 # constants
+special_tokens: [" "]
 hparams_to_ignore_on_save:
 - project
 - work_dir
1 change: 1 addition & 0 deletions configs/char_module.yaml
@@ -41,6 +41,7 @@ compile: ${oc.env:COMPILE,false}
 ignore_hparams_on_save: false

 # constants
+special_tokens: [" "]
 hparams_to_ignore_on_save:
 - project
 - work_dir
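
Registering a bare space as a special token keeps the tokenizer from splitting or normalizing it away. A minimal sketch of the effect with transformers, assuming any compatible pretrained tokenizer (the checkpoint name below is only an illustration):

from transformers import AutoTokenizer

# Mirrors the config change above: the single-space string is registered
# as an additional special token, so it survives tokenization as one
# atomic token instead of being stripped by normalization.
tokenizer = AutoTokenizer.from_pretrained(
    "ku-nlp/deberta-v2-base-japanese-char-wwm",  # illustrative checkpoint
    additional_special_tokens=[" "],
)
print(tokenizer.tokenize("foo bar"))  # the space should surface as its own token
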
1 change: 1 addition & 0 deletions configs/datamodule/base/char.yaml
@@ -4,5 +4,6 @@ denormalize_probability: ${denormalize_probability}
 tokenizer:
   _target_: transformers.AutoTokenizer.from_pretrained
   pretrained_model_name_or_path: ${encoder.pretrained_model_name_or_path}
+  additional_special_tokens: ${special_tokens}
   do_word_tokenize: false
   _convert_: all
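
This tokenizer node is built through Hydra: instantiate resolves _target_ and calls transformers.AutoTokenizer.from_pretrained with the remaining keys as keyword arguments, while _convert_: all hands native Python containers (not OmegaConf ones) to the callee. A rough sketch of the mechanism with the interpolations spelled out by hand (values illustrative):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Stand-in for the resolved config above; ${encoder...} and
# ${special_tokens} are written out literally for illustration.
cfg = OmegaConf.create(
    {
        "_target_": "transformers.AutoTokenizer.from_pretrained",
        "pretrained_model_name_or_path": "ku-nlp/deberta-v2-base-japanese-char-wwm",
        "additional_special_tokens": [" "],
        "do_word_tokenize": False,
        "_convert_": "all",  # pass plain lists/dicts to from_pretrained
    }
)
tokenizer = instantiate(cfg)
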
4 changes: 2 additions & 2 deletions configs/seq2seq_module.debug.yaml
@@ -16,10 +16,10 @@ do_predict_after_train: false
 checkpoint_path: ""

 # For decoding settings
-use_forced_decoding: true
+use_surf_forced_decoding: true
 decoding:
   max_length: ${max_tgt_length}
-  num_beams: 3
+  num_beams: 2

 # set monitor and mode for early_stopping and model_checkpoint
 monitor: valid/seq2seq_loss
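
The renamed flag suggests decoding constrained to agree with the input surface string. KWJA's own mechanism lives in its seq2seq module and is not shown in this diff; purely as a generic illustration, transformers exposes prefix_allowed_tokens_fn on generate(), which restricts each decoding step to a caller-approved token set:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")  # illustrative
model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")

# Generic stand-in for surface-forced decoding, not KWJA's implementation:
# allow only the tokens of the target surface string, plus end-of-sequence.
allowed = tokenizer("見出し語", add_special_tokens=False).input_ids + [tokenizer.eos_token_id]

def prefix_allowed_tokens_fn(batch_id, input_ids):
    return allowed

inputs = tokenizer("入力文", return_tensors="pt")  # "an input sentence"
outputs = model.generate(**inputs, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, num_beams=2)
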
4 changes: 2 additions & 2 deletions configs/seq2seq_module.yaml
@@ -16,10 +16,10 @@ do_predict_after_train: false
 checkpoint_path: ""

 # For decoding settings
-use_forced_decoding: true
+use_surf_forced_decoding: true
 decoding:
   max_length: ${max_tgt_length}
-  num_beams: 3
+  num_beams: 2

 # set monitor and mode for early_stopping and model_checkpoint
 monitor: valid/seq2seq_loss
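
The decoding block maps onto transformers generation arguments: max_length bounds the target sequence and num_beams (now 2 instead of 3) sets the beam-search width, trading a little search quality for speed. A hedged sketch of how such settings are typically consumed (checkpoint and input are illustrative, not KWJA's actual seq2seq checkpoint):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")  # illustrative
model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")

inputs = tokenizer("解析対象の文", return_tensors="pt")  # "a sentence to analyze"
# Corresponds to the decoding settings above: max_length comes from
# ${max_tgt_length}; beam search is narrowed to two beams.
outputs = model.generate(**inputs, max_length=512, num_beams=2)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
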
2 changes: 1 addition & 1 deletion configs/typo_module.debug.yaml
@@ -35,7 +35,7 @@ compile: ${oc.env:COMPILE,false}
 ignore_hparams_on_save: false

 # constants
-special_tokens: ["<k>", "<d>", "<_>", "<dummy>"]
+special_tokens: ["<k>", "<d>", "<_>", "<dummy>", " "]
 hparams_to_ignore_on_save:
 - project
 - work_dir
2 changes: 1 addition & 1 deletion configs/typo_module.yaml
@@ -36,7 +36,7 @@ compile: ${oc.env:COMPILE,false}
 ignore_hparams_on_save: false

 # constants
-special_tokens: ["<k>", "<d>", "<_>", "<dummy>"]
+special_tokens: ["<k>", "<d>", "<_>", "<dummy>", " "]
 hparams_to_ignore_on_save:
 - project
 - work_dir
2 changes: 1 addition & 1 deletion configs/word_module.debug.yaml
@@ -66,7 +66,7 @@ compile: ${oc.env:COMPILE,false}
 ignore_hparams_on_save: false

 # constants
-special_tokens: ["[著者]", "[読者]", "[不特定:人]", "[不特定:物]", "[NULL]", "[NA]", "[ROOT]"]
+special_tokens: ["[著者]", "[読者]", "[不特定:人]", "[不特定:物]", "[NULL]", "[NA]", "[ROOT]", " "]
 hparams_to_ignore_on_save:
 - project
 - work_dir
2 changes: 1 addition & 1 deletion configs/word_module.yaml
@@ -66,7 +66,7 @@ compile: ${oc.env:COMPILE,false}
 ignore_hparams_on_save: false

 # constants
-special_tokens: ["[著者]", "[読者]", "[不特定:人]", "[不特定:物]", "[NULL]", "[NA]", "[ROOT]"]
+special_tokens: ["[著者]", "[読者]", "[不特定:人]", "[不特定:物]", "[NULL]", "[NA]", "[ROOT]", " "]
 hparams_to_ignore_on_save:
 - project
 - work_dir
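
A practical footnote on appending " " to special_tokens for the typo and word modules: whenever a tokenizer gains tokens absent from the pretrained vocabulary, the encoder's embedding matrix has to grow to match the new vocabulary size. A hedged sketch of the usual transformers idiom (checkpoint name illustrative):

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ku-nlp/deberta-v2-base-japanese")  # illustrative
num_added = tokenizer.add_special_tokens(
    {"additional_special_tokens": ["[NULL]", "[NA]", "[ROOT]", " "]}
)

model = AutoModel.from_pretrained("ku-nlp/deberta-v2-base-japanese")
if num_added > 0:
    # Grow the embedding table so the ids of the new special tokens are valid.
    model.resize_token_embeddings(len(tokenizer))
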
4 changes: 2 additions & 2 deletions scripts/benchmark.sh
@@ -106,10 +106,10 @@ echo "Juman++ & KNP (NER + Dependency parsing + PAS analysis)"
 grep "# S-ID:" "$WORK_DIR/benchmark.knp_ne_anaphora.knp" | grep -cv "ERROR:" >> "$WORK_DIR/count.txt"

 echo "KWJA (typo_module)"
-poetry run python ./scripts/analyze.py module=typo datamodule.predict.raw_input_file="$INPUT" checkpoint_path="$TYPO_MODULE" devices="$DEVICE" max_batches_per_device="$TYPO_BATCH_SIZE" > "$WORK_DIR/benchmark.kwja.txt" 2>> "$WORK_DIR/benchmark.stderr"
+poetry run python ./scripts/analyze.py module=typo checkpoint_path="$TYPO_MODULE" devices="$DEVICE" max_batches_per_device="$TYPO_BATCH_SIZE" +datamodule.predict.raw_input_file="$INPUT" > "$WORK_DIR/benchmark.kwja.txt" 2>> "$WORK_DIR/benchmark.stderr"

 echo "KWJA (char_module)"
-poetry run python ./scripts/analyze.py module=char datamodule.predict.raw_input_file="$INPUT" checkpoint_path="$CHAR_MODULE" devices="$DEVICE" max_batches_per_device="$CHAR_BATCH_SIZE" > "$WORK_DIR/benchmark.kwja.juman" 2>> "$WORK_DIR/benchmark.stderr"
+poetry run python ./scripts/analyze.py module=char checkpoint_path="$CHAR_MODULE" devices="$DEVICE" max_batches_per_device="$CHAR_BATCH_SIZE" +datamodule.predict.raw_input_file="$INPUT" > "$WORK_DIR/benchmark.kwja.juman" 2>> "$WORK_DIR/benchmark.stderr"
 grep "# S-ID:" "$WORK_DIR/benchmark.kwja.juman" | cut -f -3 -d "-" | uniq | wc -l >> "$WORK_DIR/count.txt"

 echo "KWJA (word_module)"
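
The substantive change here is the override syntax: a plain key=value override modifies a key that already exists in the composed config, while a +-prefixed override adds a key that does not. Since datamodule.predict.raw_input_file is no longer defined in the base configs, it must now be added with +. The same semantics can be exercised from Python; a small sketch, assuming it runs from the repository root (otherwise config_path needs adjusting):

from hydra import compose, initialize

# "+datamodule.predict.raw_input_file=..." appends a key the base config
# does not define; without the leading "+", Hydra rejects the override.
with initialize(config_path="configs", version_base=None):
    cfg = compose(
        config_name="typo_module",
        overrides=["+datamodule.predict.raw_input_file=input.txt"],
    )
print(cfg.datamodule.predict.raw_input_file)  # -> input.txt
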
44 changes: 44 additions & 0 deletions scripts/preprocessors/preprocess_reading.py
@@ -0,0 +1,44 @@
import logging
from argparse import ArgumentParser
from collections import Counter
from pathlib import Path
from typing import Dict

from rhoknp import Document
from transformers import AutoTokenizer

from kwja.utils.kanjidic import KanjiDic
from kwja.utils.reading_prediction import ReadingAligner

logger = logging.getLogger(__name__)


def main():
    parser = ArgumentParser()
    parser.add_argument("-m", "--model-name-or-path", type=str, help="model_name_or_path")
    parser.add_argument("-k", "--kanji-dic", type=str, help="path to kanji dic file")
    parser.add_argument("-i", "--in-dir", type=Path, help="path to input directory")
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    kanji_dic = KanjiDic(args.kanji_dic)
    reading_aligner = ReadingAligner(tokenizer, kanji_dic)

    # Count aligned sub-readings over every KNP file under the input directory.
    reading_counter: Dict[str, int] = Counter()
    for path in args.in_dir.glob("**/*.knp"):
        logger.info(f"processing {path}")
        with path.open() as f:
            document = Document.from_knp(f.read())
        try:
            for reading in reading_aligner.align(document.morphemes):
                reading_counter[reading] += 1
        except ValueError:
            logger.warning(f"skip {document.doc_id} for an error")
    # Two stable sorts: first by sub-reading (ascending), then by count
    # (descending), so ties on count come out in lexicographic order.
    for subreading, count in sorted(
        sorted(reading_counter.items(), key=lambda pair: pair[0]), key=lambda pair: pair[1], reverse=True
    ):
        print(f"{subreading}\t{count}")


if __name__ == "__main__":
    main()
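
A likely invocation, from the repository root (placeholders are illustrative): poetry run python scripts/preprocessors/preprocess_reading.py -m <model-name-or-path> -k <path-to-kanjidic> -i <dir-with-knp-files>. The script prints one sub-reading and its frequency per line, tab-separated, most frequent first.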