Skip to content

Commit

Permalink
update pre-commit hooks and reformat code
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu-g committed Feb 1, 2024
1 parent a95d045 commit bda5217
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 25 deletions.
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,28 @@ repos:
- id: check-yaml
- id: check-toml
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.11.0
rev: 24.1.1
hooks:
- id: black
- repo: https://github.com/PyCQA/flake8
rev: 6.1.0
rev: 7.0.0
hooks:
- id: flake8
additional_dependencies: [Flake8-pyproject]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.0
rev: v1.8.0
hooks:
- id: mypy
additional_dependencies:
- rhoknp==1.6.0
- hydra-core==1.3.2
- torch==2.1.1
- torchmetrics==1.2.0
- transformers==4.34.1
- torch==2.2.0
- torchmetrics==1.3.0
- transformers==4.36.2
- tokenizers
- wandb
- typer
Expand Down
8 changes: 5 additions & 3 deletions src/kwja/metrics/char.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,11 @@ def _convert_doc_id2sentences_into_documents(
) -> List[Document]:
# Build documents that do not have clauses, phrases, or base phrases, but morphemes only
return [
Document.from_jumanpp("".join(s.to_jumanpp() for s in ss))
if from_sentences is False
else Document.from_sentences(ss)
(
Document.from_jumanpp("".join(s.to_jumanpp() for s in ss))
if from_sentences is False
else Document.from_sentences(ss)
)
for ss in doc_id2sentences.values()
]

Expand Down
28 changes: 15 additions & 13 deletions src/kwja/modules/seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,19 +88,21 @@ def predict_step(self, batch: Any) -> Dict[str, Any]:
generations = self.encoder_decoder.generate(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
logits_processor=LogitsProcessorList(
[
ForcedLogitsProcessor(
surfs=batch["surfs"],
num_beams=self.hparams.decoding.num_beams,
tokenizer=self.tokenizer,
reading_candidates=self.reading_candidates,
char2tokens=self.char2tokens,
),
]
)
if self.use_forced_decoding
else None,
logits_processor=(
LogitsProcessorList(
[
ForcedLogitsProcessor(
surfs=batch["surfs"],
num_beams=self.hparams.decoding.num_beams,
tokenizer=self.tokenizer,
reading_candidates=self.reading_candidates,
char2tokens=self.char2tokens,
),
]
)
if self.use_forced_decoding
else None
),
**self.hparams.decoding,
)
if isinstance(generations, torch.Tensor):
Expand Down
10 changes: 9 additions & 1 deletion src/kwja/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,15 @@ class CohesionTask(Enum):
"ヤ列基本推量形": 11,
"ヤ列基本省略推量形": 12,
},
"助動詞そうだ型": {"*": 0, "語幹": 1, "基本形": 2, "ダ列タ系連用テ形": 3, "デアル列基本形": 4, "デス列基本形": 5, "デス列音便基本形": 6},
"助動詞そうだ型": {
"*": 0,
"語幹": 1,
"基本形": 2,
"ダ列タ系連用テ形": 3,
"デアル列基本形": 4,
"デス列基本形": 5,
"デス列音便基本形": 6,
},
"助動詞く型": {"*": 0, "語幹": 1, "基本形": 2, "基本連用形": 3, "文語連体形": 4, "文語未然形": 5},
"動詞性接尾辞ます型": {
"*": 0,
Expand Down
5 changes: 4 additions & 1 deletion tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ def test_sanity():
for knp_text in chunk_by_document(io.StringIO(ret.stdout)):
documents.append(Document.from_knp(knp_text))
assert len(documents) == 2
assert documents[0].text == "KWJAは日本語の統合解析ツールです。汎用言語モデルを利用し、様々な言語解析を統一的な方法で解いています。"
assert (
documents[0].text
== "KWJAは日本語の統合解析ツールです。汎用言語モデルを利用し、様々な言語解析を統一的な方法で解いています。"
)
assert documents[1].text == (
"計算機による言語理解を実現するためには、計算機に常識・世界知識を与える必要があります。10年前にはこれは非常に難しい問題でしたが、"
+ "近年の計算機パワー、計算機ネットワークの飛躍的進展によって計算機が超大規模テキストを取り扱えるようになり、そこから常識を"
Expand Down

0 comments on commit bda5217

Please sign in to comment.