Skip to content

Commit

Permalink
🔍 mypy: Use a compatible syntax for multimethod
Browse files Browse the repository at this point in the history
  • Loading branch information
mikegerber committed Jan 9, 2024
1 parent 8166435 commit ad316ae
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 20 deletions.
10 changes: 4 additions & 6 deletions src/dinglehopper/character_error_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,15 @@ def character_error_rate_n(
# XXX Should we really count newlines here?


@multimethod
def character_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
@character_error_rate_n.register
def _(reference: str, compared: str) -> Tuple[float, int]:
seq1 = list(grapheme_clusters(unicodedata.normalize("NFC", reference)))
seq2 = list(grapheme_clusters(unicodedata.normalize("NFC", compared)))
return character_error_rate_n(seq1, seq2)


@multimethod
def character_error_rate_n(
reference: ExtractedText, compared: ExtractedText
) -> Tuple[float, int]:
@character_error_rate_n.register
def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
return character_error_rate_n(
reference.grapheme_clusters, compared.grapheme_clusters
)
Expand Down
8 changes: 4 additions & 4 deletions src/dinglehopper/edit_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def distance(seq1: List[str], seq2: List[str]):
return Levenshtein.distance(seq1, seq2)


@multimethod
def distance(s1: str, s2: str):
@distance.register
def _(s1: str, s2: str):
"""Compute the Levenshtein edit distance between two Unicode strings
Note that this is different from levenshtein() as this function knows about Unicode
Expand All @@ -32,8 +32,8 @@ def distance(s1: str, s2: str):
return Levenshtein.distance(seq1, seq2)


@multimethod
def distance(s1: ExtractedText, s2: ExtractedText):
@distance.register
def _(s1: ExtractedText, s2: ExtractedText):
return Levenshtein.distance(s1.grapheme_clusters, s2.grapheme_clusters)


Expand Down
18 changes: 8 additions & 10 deletions src/dinglehopper/word_error_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def unwanted(c):
yield word


@multimethod
def words(s: ExtractedText):
@words.register
def _(s: ExtractedText):
return words(s.text)


Expand All @@ -70,8 +70,8 @@ def words_normalized(s: str):
return words(unicodedata.normalize("NFC", s))


@multimethod
def words_normalized(s: ExtractedText):
@words_normalized.register
def _(s: ExtractedText):
return words_normalized(s.text)


Expand All @@ -82,15 +82,13 @@ def word_error_rate_n(reference: str, compared: str) -> Tuple[float, int]:
return word_error_rate_n(reference_seq, compared_seq)


@multimethod
def word_error_rate_n(
reference: ExtractedText, compared: ExtractedText
) -> Tuple[float, int]:
@word_error_rate_n.register
def _(reference: ExtractedText, compared: ExtractedText) -> Tuple[float, int]:
return word_error_rate_n(reference.text, compared.text)


@multimethod
def word_error_rate_n(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
@word_error_rate_n.register
def _(reference: Iterable, compared: Iterable) -> Tuple[float, int]:
reference_seq = list(reference)
compared_seq = list(compared)

Expand Down

0 comments on commit ad316ae

Please sign in to comment.