From 31f7993e036c6b27b464d0c1409a457d4b0c89b3 Mon Sep 17 00:00:00 2001
From: shibing624
Date: Fri, 13 Dec 2024 21:59:21 +0800
Subject: [PATCH] fix https://github.com/shibing624/pycorrector/issues/534

---
 pycorrector/proper_corrector.py | 40 ++++++++++++++++++++++++++++-----
 1 file changed, 35 insertions(+), 5 deletions(-)

diff --git a/pycorrector/proper_corrector.py b/pycorrector/proper_corrector.py
index fb04f250..0d597464 100644
--- a/pycorrector/proper_corrector.py
+++ b/pycorrector/proper_corrector.py
@@ -6,7 +6,6 @@
 import os
 from codecs import open
 from typing import List
-
 import pypinyin
 from loguru import logger
 
@@ -14,13 +13,13 @@
 from pycorrector.utils.ngram_util import NgramUtil
 from pycorrector.utils.text_utils import is_chinese_char
 from pycorrector.utils.tokenizer import segment, split_text_into_sentences_by_symbol
+from collections import defaultdict
 
 pwd_path = os.path.abspath(os.path.dirname(__file__))
-
 # 五笔笔画字典
 stroke_path = os.path.join(pwd_path, 'data/stroke.txt')
-# 专名词典,包括成语、俗语、专业领域词等 format: 词语
-proper_name_path = os.path.join(pwd_path, 'data/proper_name.txt')
+# 专名词典,包括成语、俗语、专业领域词等 format: 词语, 可以自定义
+default_proper_name_path = os.path.join(pwd_path, 'data/proper_name.txt')
 
 
 def load_set_file(path):
@@ -60,10 +59,35 @@ def load_dict_file(path):
     return result
 
 
+class TrieNode:
+    def __init__(self):
+        self.children = defaultdict(TrieNode)
+        self.is_end_of_word = False
+
+
+class Trie:
+    def __init__(self):
+        self.root = TrieNode()
+
+    def insert(self, word):
+        node = self.root
+        for char in word:
+            node = node.children[char]
+        node.is_end_of_word = True
+
+    def search(self, word):
+        node = self.root
+        for char in word:
+            if char not in node.children:
+                return False
+            node = node.children[char]
+        return node.is_end_of_word
+
+
 class ProperCorrector:
     def __init__(
             self,
-            proper_name_path=proper_name_path,
+            proper_name_path=default_proper_name_path,
             stroke_path=stroke_path,
     ):
         self.name = 'ProperCorrector'
@@ -71,6 +95,9 @@ def __init__(
         self.proper_names = load_set_file(proper_name_path)
         # stroke, 笔划字典 format: 字:笔划,如:万,笔划是横(h),折(z),撇(p),组合起来是:hzp
         self.stroke_dict = load_dict_file(stroke_path)
+        self.trie = Trie()
+        for name in self.proper_names:
+            self.trie.insert(name)
 
     def get_stroke(self, char):
         """
@@ -95,6 +122,7 @@ def is_near_stroke_char(self, char1, char2, stroke_threshold=0.8):
     def get_char_stroke_similarity_score(self, char1, char2):
         """
         获取字符的字形相似度
+
         Args:
             char1:
             char2:
@@ -253,6 +281,8 @@ def correct(
             # 词长度过滤
             ngrams = [i for i in ngrams if min_word_length <= len(i) <= max_word_length]
             for cur_item in ngrams:
+                if self.trie.search(cur_item):
+                    continue
                 for name in self.proper_names:
                     if self.get_word_similarity_score(cur_item, name) > sim_threshold:
                         if cur_item != name: