Skip to content

Commit

Permalink
🔍 mypy: Use an almost strict mypy configuration, and fix any issues
Browse files Browse the repository at this point in the history
  • Loading branch information
mikegerber committed Jan 10, 2024
1 parent ad316ae commit 483e809
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 41 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ repos:
hooks:
- additional_dependencies:
- types-setuptools
- types-lxml
- numpy # for numpy plugin
id: mypy

- repo: https://gitlab.com/vojko.pribudic/pre-commit-update
Expand Down
11 changes: 11 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,20 @@ markers = [


[tool.mypy]
plugins = ["numpy.typing.mypy_plugin"]

ignore_missing_imports = true


strict = true

disallow_subclassing_any = false
# ❗ error: Class cannot subclass "Processor" (has type "Any")
disallow_any_generics = false
disallow_untyped_defs = false
disallow_untyped_calls = false


[tool.ruff]
select = ["E", "F", "I"]
ignore = [
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ ruff
pytest-ruff

mypy
types-lxml
types-setuptools
pytest-mypy
3 changes: 1 addition & 2 deletions src/dinglehopper/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from typing import Optional

from rapidfuzz.distance import Levenshtein

from .edit_distance import grapheme_clusters
from uniseg.graphemecluster import grapheme_clusters


def align(t1, t2):
Expand Down
13 changes: 9 additions & 4 deletions src/dinglehopper/character_error_rate.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import unicodedata
from typing import List, Tuple
from typing import List, Tuple, TypeVar

from multimethod import multimethod
from uniseg.graphemecluster import grapheme_clusters

from .edit_distance import distance
from .extracted_text import ExtractedText

T = TypeVar("T")


@multimethod
def character_error_rate_n(
Expand Down Expand Up @@ -34,21 +36,24 @@ def character_error_rate_n(
def _(reference: str, compared: str) -> Tuple[float, int]:
seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", reference)))
seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", compared)))
return character_error_rate_n(seq1, seq2)
cer, n = character_error_rate_n(seq1, seq2)
return cer, n


@character_error_rate_n.register
def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
return character_error_rate_n(
cer, n = character_error_rate_n(
reference.grapheme_clusters, compared.grapheme_clusters
)
return cer, n


def character_error_rate(reference, compared) -> float:
def character_error_rate(reference: T, compared: T) -> float:
"""
Compute character error rate.
:return: character error rate
"""
cer: float
cer, _ = character_error_rate_n(reference, compared)
return cer
11 changes: 6 additions & 5 deletions src/dinglehopper/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from collections import Counter
from typing import List

import click
from jinja2 import Environment, FileSystemLoader
Expand Down Expand Up @@ -76,7 +77,7 @@ def format_thing(t, css_classes=None, id_=None):
if o is not None:
o_pos += len(o)

found_differences = dict(Counter(elem for elem in found_differences))
counted_differences = dict(Counter(elem for elem in found_differences))

return (
"""
Expand All @@ -87,7 +88,7 @@ def format_thing(t, css_classes=None, id_=None):
""".format(
gtx, ocrx
),
found_differences,
counted_differences,
)


Expand All @@ -113,7 +114,7 @@ def process(
metrics: bool = True,
differences: bool = False,
textequiv_level: str = "region",
):
) -> None:
"""Check OCR result against GT.
The @click decorators change the signature of the decorated functions, so we keep
Expand All @@ -122,8 +123,8 @@ def process(

gt_text = extract(gt, textequiv_level=textequiv_level)
ocr_text = extract(ocr, textequiv_level=textequiv_level)
gt_words: list = list(words_normalized(gt_text))
ocr_words: list = list(words_normalized(ocr_text))
gt_words: List[str] = list(words_normalized(gt_text))
ocr_words: List[str] = list(words_normalized(ocr_text))

assert isinstance(gt_text, ExtractedText)
assert isinstance(ocr_text, ExtractedText)
Expand Down
5 changes: 3 additions & 2 deletions src/dinglehopper/cli_summarize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
from typing import Dict

import click
from jinja2 import Environment, FileSystemLoader
Expand All @@ -13,8 +14,8 @@ def process(reports_folder, occurrences_threshold=1):
wer_list = []
cer_sum = 0
wer_sum = 0
diff_c = {}
diff_w = {}
diff_c: Dict[str, int] = {}
diff_w: Dict[str, int] = {}

for report in os.listdir(reports_folder):
if report.endswith(".json"):
Expand Down
6 changes: 3 additions & 3 deletions src/dinglehopper/edit_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


@multimethod
def distance(seq1: List[str], seq2: List[str]):
def distance(seq1: List[str], seq2: List[str]) -> int:
"""Compute the Levenshtein edit distance between two lists of grapheme clusters.
This assumes that the grapheme clusters are already normalized.
Expand All @@ -20,7 +20,7 @@ def distance(seq1: List[str], seq2: List[str]):


@distance.register
def _(s1: str, s2: str):
def _(s1: str, s2: str) -> int:
"""Compute the Levenshtein edit distance between two Unicode strings
Note that this is different from levenshtein() as this function knows about Unicode
Expand All @@ -33,7 +33,7 @@ def _(s1: str, s2: str):


@distance.register
def _(s1: ExtractedText, s2: ExtractedText):
def _(s1: ExtractedText, s2: ExtractedText) -> int:
return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters)


Expand Down
16 changes: 10 additions & 6 deletions src/dinglehopper/extracted_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unicodedata
from contextlib import suppress
from itertools import repeat
from typing import List, Optional
from typing import Any, Dict, List, Optional

import attr
import numpy as np
Expand Down Expand Up @@ -173,10 +173,11 @@ def are_valid_grapheme_clusters(self, _, value):
normalization = attr.ib(converter=Normalization, default=Normalization.NFC_SBB)

@property
def text(self):
def text(self) -> str:
if self._text is not None:
return self._text
else:
assert self.joiner is not None and self.segments is not None
return self.joiner.join(s.text for s in self.segments)

@functools.cached_property
Expand All @@ -186,6 +187,7 @@ def _joiner_grapheme_cluster(self):
This property is cached.
"""

assert self.joiner is not None
if len(self.joiner) > 0:
joiner_grapheme_cluster = list(grapheme_clusters(self.joiner))
assert len(joiner_grapheme_cluster) == 1 # see joiner's check above
Expand All @@ -203,6 +205,7 @@ def grapheme_clusters(self):
else:
# TODO Test with text extracted at glyph level (joiner == "")
clusters = []
assert self.segments is not None
for seg in self.segments:
clusters += seg.grapheme_clusters + self._joiner_grapheme_cluster
clusters = clusters[:-1]
Expand All @@ -218,6 +221,7 @@ def segment_id_for_pos(self, pos):
else:
# Recurse
segment_id_for_pos = []
assert self.joiner is not None and self.segments is not None
for s in self.segments:
seg_ids = [s.segment_id_for_pos(i) for i in range(len(s.text))]
segment_id_for_pos.extend(seg_ids)
Expand Down Expand Up @@ -280,7 +284,7 @@ def invert_dict(d):
return {v: k for k, v in d.items()}


def get_textequiv_unicode(text_segment, nsmap) -> str:
def get_textequiv_unicode(text_segment: Any, nsmap: Dict[str, str]) -> str:
"""Get the TextEquiv/Unicode text of the given PAGE text element."""
segment_id = text_segment.attrib["id"]
textequivs = text_segment.findall("./page:TextEquiv", namespaces=nsmap)
Expand All @@ -304,7 +308,7 @@ def get_first_textequiv(textequivs, segment_id):
if np.any(~nan_mask):
if np.any(nan_mask):
log.warning("TextEquiv without index in %s.", segment_id)
index = np.nanargmin(indices)
index = int(np.nanargmin(indices))
else:
# try ordering by conf
confidences = np.array([get_attr(te, "conf") for te in textequivs], dtype=float)
Expand All @@ -313,15 +317,15 @@ def get_first_textequiv(textequivs, segment_id):
"No index attributes, use 'conf' attribute to sort TextEquiv in %s.",
segment_id,
)
index = np.nanargmax(confidences)
index = int(np.nanargmax(confidences))
else:
# fallback to first entry in case of neither index or conf present
log.warning("No index attributes, use first TextEquiv in %s.", segment_id)
index = 0
return textequivs[index]


def get_attr(te, attr_name) -> float:
def get_attr(te: Any, attr_name: str) -> float:
"""Extract the attribute for the given name.
Note: currently only handles numeric values!
Expand Down
21 changes: 14 additions & 7 deletions src/dinglehopper/ocr_files.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import sys
from typing import Iterator
from typing import Dict, Iterator, Optional

import chardet
from lxml import etree as ET
Expand All @@ -10,11 +10,11 @@
from .extracted_text import ExtractedText, normalize_sbb


def alto_namespace(tree: ET.ElementTree) -> str:
def alto_namespace(tree: ET._ElementTree) -> Optional[str]:
"""Return the ALTO namespace used in the given ElementTree.
This relies on the assumption that, in any given ALTO file, the root element has the
local name "alto". We do not check if the files uses any valid ALTO namespace.
local name "alto". We do not check if the file uses any valid ALTO namespace.
"""
root_name = ET.QName(tree.getroot().tag)
if root_name.localname == "alto":
Expand All @@ -23,8 +23,15 @@ def alto_namespace(tree: ET.ElementTree) -> str:
raise ValueError("Not an ALTO tree")


def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]:
nsmap = {"alto": alto_namespace(tree)}
def alto_nsmap(tree: ET._ElementTree) -> Dict[str, str]:
alto_ns = alto_namespace(tree)
if alto_ns is None:
raise ValueError("Could not determine ALTO namespace")
return {"alto": alto_ns}


def alto_extract_lines(tree: ET._ElementTree) -> Iterator[ExtractedText]:
nsmap = alto_nsmap(tree)
for line in tree.iterfind(".//alto:TextLine", namespaces=nsmap):
line_id = line.attrib.get("ID")
line_text = " ".join(
Expand All @@ -37,7 +44,7 @@ def alto_extract_lines(tree: ET.ElementTree) -> Iterator[ExtractedText]:
# FIXME hardcoded SBB normalization


def alto_extract(tree: ET.ElementTree) -> ExtractedText:
def alto_extract(tree: ET._ElementTree) -> ExtractedText:
"""Extract text from the given ALTO ElementTree."""
return ExtractedText(None, list(alto_extract_lines(tree)), "\n", None, None)

Expand Down Expand Up @@ -98,7 +105,7 @@ def extract_texts_from_reading_order_group(group, tree, nsmap, textequiv_level):
if ET.QName(group.tag).localname in ["OrderedGroup", "OrderedGroupIndexed"]:
ro_children = list(group)

ro_children = filter(lambda child: "index" in child.attrib.keys(), ro_children)
ro_children = [child for child in ro_children if "index" in child.attrib.keys()]
ro_children = sorted(ro_children, key=lambda child: int(child.attrib["index"]))
elif ET.QName(group.tag).localname in ["UnorderedGroup", "UnorderedGroupIndexed"]:
ro_children = list(group)
Expand Down
Loading

0 comments on commit 483e809

Please sign in to comment.