Merge pull request #527 from smartmark-pro/feature-better-bart
feat: add two open-source error correction models from HillZhang1999
shibing624 authored Oct 28, 2024
2 parents d537032 + 1be289a commit 598b5cc
Showing 11 changed files with 455 additions and 2 deletions.
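
The PR wires two new BART-based correctors into the package. A minimal usage sketch, pieced together from the new exports in pycorrector/__init__.py and the tests added below (illustrative only, not part of the diff; exact corrections depend on the downloaded model weights):

from pycorrector import MuCGECBartCorrector, NaSGECBartCorrector

# MuCGEC BART corrector, backed by a modelscope pipeline
m = MuCGECBartCorrector()  # default model: damo/nlp_bart_text-error-correction_chinese
res = m.correct_batch(["今天,我非常开开心。"])
print(res[0]["target"], res[0]["errors"])

# NaSGEC BART corrector, backed by transformers BartForConditionalGeneration
n = NaSGECBartCorrector()  # default model: HillZhang/real_learner_bart_CGEC
print(n.correct("北京是中国的都。我每天大约喝5次水左右。"))
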
61 changes: 60 additions & 1 deletion README.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pycorrector/__init__.py
@@ -16,6 +16,8 @@
from pycorrector.proper_corrector import ProperCorrector
from pycorrector.seq2seq.conv_seq2seq_corrector import ConvSeq2SeqCorrector
from pycorrector.t5.t5_corrector import T5Corrector
from pycorrector.mucgec_bart.mucgec_bart_corrector import MuCGECBartCorrector
from pycorrector.nasgec_bart.nasgec_bart_corrector import NaSGECBartCorrector
from pycorrector.utils import text_utils, tokenizer, io_utils, math_utils, evaluate_utils
from pycorrector.utils.evaluate_utils import eval_model_batch
from pycorrector.utils.get_file import get_file
0 changes: 0 additions & 0 deletions pycorrector/mucgec_bart/__init__.py
Empty file.
70 changes: 70 additions & 0 deletions pycorrector/mucgec_bart/monkey_pack.py
@@ -0,0 +1,70 @@
import torch
from modelscope.pipelines import Pipeline
from typing import Any, Dict, List
from modelscope.utils.constant import Frameworks
from modelscope.utils.device import device_placement

# Fix for the batched-inference issue: post-process each sample of a batch individually
def _process_batch(self, input: List, batch_size,
**kwargs) -> Dict[str, Any]:
preprocess_params = kwargs.get('preprocess_params')
forward_params = kwargs.get('forward_params')
postprocess_params = kwargs.get('postprocess_params')

# batch data
output_list = []
for i in range(0, len(input), batch_size):
end = min(i + batch_size, len(input))
real_batch_size = end - i
        preprocessed_list = [
            self.preprocess(sample, **preprocess_params) for sample in input[i:end]
        ]

with device_placement(self.framework, self.device_name):
if self.framework == Frameworks.torch:
with torch.no_grad():
batched_out = self._batch(preprocessed_list)
if self._auto_collate:
batched_out = self._collate_fn(batched_out)
batched_out = self.forward(batched_out,
**forward_params)
else:
batched_out = self._batch(preprocessed_list)
batched_out = self.forward(batched_out, **forward_params)
        model_name = kwargs.get("model_name")
        if model_name == "batch_correct":
for batch_idx in range(real_batch_size):
out = {}
for k, element in batched_out.items():
if element is not None:
if isinstance(element, (tuple, list)):
out[k] = element[batch_idx]
else:
out[k] = element[batch_idx:batch_idx + 1]
out = self.postprocess(out, **postprocess_params)
self._check_output(out)
output_list.append(out)
else:
for batch_idx in range(real_batch_size):
out = {}
for k, element in batched_out.items():
if element is not None:
if isinstance(element, (tuple, list)):
if isinstance(element[0], torch.Tensor):
out[k] = type(element)(
e[batch_idx:batch_idx + 1]
for e in element)
else:
# Compatible with traditional pipelines
out[k] = element[batch_idx]
else:
out[k] = element[batch_idx:batch_idx + 1]
out = self.postprocess(out, **postprocess_params)
self._check_output(out)
output_list.append(out)

return output_list


Pipeline._process_batch = _process_batch
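
The module above monkey-patches modelscope's Pipeline._process_batch: when the pipeline is called with model_name="batch_correct", each sample is sliced out of the batched outputs and post-processed individually; any other call falls back to the original per-element handling. A rough sketch of how the patched path gets exercised, mirroring what MuCGECBartCorrector.correct_batch does below (assumes the modelscope model can be downloaded):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import pycorrector.mucgec_bart.monkey_pack  # importing the module applies the patch

p = pipeline(Tasks.text_error_correction, model="damo/nlp_bart_text-error-correction_chinese")
# Extra kwargs are forwarded to _process_batch, so model_name selects the patched branch
outputs = p(["今天,我非常开开心。", "我每天大约喝5次水左右。"],
            batch_size=2, model_name="batch_correct")
print([o["output"] for o in outputs])
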
91 changes: 91 additions & 0 deletions pycorrector/mucgec_bart/mucgec_bart_corrector.py
@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-
import os
import time
from typing import List

import torch
from loguru import logger
from tqdm import tqdm


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import sys
sys.path.append('../..')
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from pycorrector.mucgec_bart.monkey_pack import Pipeline
from pycorrector.utils.sentence_utils import long_sentence_split
import difflib


class MuCGECBartCorrector:
def __init__(self, model_name_or_path: str = "damo/nlp_bart_text-error-correction_chinese"):
t1 = time.time()
self.model = pipeline(Tasks.text_error_correction, model=model_name_or_path)
logger.debug("Device: {}".format(device))
logger.debug('Loaded mucgec bart correction model: %s, spend: %.3f s.' % (model_name_or_path, time.time() - t1))

def _predict(self, sentences, batch_size=32, max_length=128, silent=True):
raise NotImplementedError


def correct_batch(self, sentences: List[str], max_length: int = 128, batch_size: int = 32, silent: bool = True, ignore_function=None):
"""
批量句子纠错
:param sentences: list[str], sentence list
:param max_length: int, max length of each sentence
:param batch_size: int, bz
:param silent: bool, show log
:param ignore_function: function, 自定义一个函数可以指定跳过某类错误, 无需训练模型
:return: list of dict, {'source': 'src', 'target': 'trg', 'errors': [(error_word, correct_word, position), ...]}
"""
result = self.model(sentences, batch_size=batch_size, model_name="batch_correct")
start_idx = 0
n = len(sentences)
data = []
result = [r["output"] for r in result]
        for i in range(n):
            a, b = sentences[i], result[i]
            if len(a) == 0 or len(b) == 0 or a == "\n":
                # pass empty or newline-only sentences through unchanged
                start_idx += len(a)
                data.append({"source": a, "target": b, "errors": []})
                continue
            s = difflib.SequenceMatcher(None, a, b)
            errors = []
            offset = 0
            for tag, i1, i2, j1, j2 in s.get_opcodes():
                if tag != "equal":
                    e = [a[i1:i2], b[j1 + offset:j2 + offset], i1]
                    if ignore_function and ignore_function(e):
                        # not treated as an error, so restore the source span and adjust the offset
                        b = b[:j1 + offset] + a[i1:i2] + b[j2 + offset:]
                        offset += i2 - i1 - j2 + j1
                        continue

                    errors.append(tuple(e))
            data.append({"source": a, "target": b, "errors": errors})
        return data


def correct(self, sentence: str, **kwargs):
"""长句改为短句, 可直接调用长文本"""
sentences = long_sentence_split(sentence, max_length=kwargs.pop("max_length", 128), period=kwargs.pop("period", None), comma=kwargs.pop("comma", None))
batch_results = self.correct_batch(sentences, **kwargs)
source, target, errors = "", "", []
for sr in batch_results:
ll = len(source)
source += sr["source"]
target += sr["target"]
for e in sr["errors"]:
                # keep the sentence-local position as a 4th element, then shift the main position into the merged text
                e = list(e)
                e.append(e[-1])
                e[2] += ll
errors.append(tuple(e))
return {"source": source, "target": target, "errors": errors, "sentences": batch_results}






0 changes: 0 additions & 0 deletions pycorrector/nasgec_bart/__init__.py
Empty file.
96 changes: 96 additions & 0 deletions pycorrector/nasgec_bart/nasgec_bart_corrector.py
@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
import os
import time
from typing import List

import torch
from loguru import logger
from tqdm import tqdm


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import sys
sys.path.append('../..')
from transformers import BertTokenizer, BartForConditionalGeneration, Text2TextGenerationPipeline
from pycorrector.utils.sentence_utils import long_sentence_split
import difflib


class NaSGECBartCorrector:
def __init__(self, model_name_or_path: str = "HillZhang/real_learner_bart_CGEC"):
# https://github.com/HillZhang1999/NaSGEC
t1 = time.time()
self.tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
self.model = BartForConditionalGeneration.from_pretrained(model_name_or_path)
logger.debug("Device: {}".format(device))
logger.debug('Loaded nasgec bart correction model: %s, spend: %.3f s.' % (model_name_or_path, time.time() - t1))

def _predict(self, sentences, batch_size=32, max_length=128, silent=True):
raise NotImplementedError


def correct_batch(self, sentences: List[str], max_length: int = 128, batch_size: int = 32, silent: bool = True, ignore_function=None):
"""
批量句子纠错
:param sentences: list[str], sentence list
:param max_length: int, max length of each sentence
:param batch_size: int, bz
:param silent: bool, show log
:param ignore_function: function, 自定义一个函数可以指定跳过某类错误, 无需训练模型
:return: list of dict, {'source': 'src', 'target': 'trg', 'errors': [(error_word, correct_word, position), ...]}
"""
encoded_input = self.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
if "token_type_ids" in encoded_input:
del encoded_input["token_type_ids"]
output = self.model.generate(**encoded_input)
result = self.tokenizer.batch_decode(output, skip_special_tokens=True)
start_idx = 0
n = len(sentences)
data = []
        result = [r.replace(" ", "") for r in result]
        if not silent:
            logger.debug(result)
        for i in range(n):
            a, b = sentences[i], result[i]
            if len(a) == 0 or len(b) == 0 or a == "\n":
                # pass empty or newline-only sentences through unchanged
                start_idx += len(a)
                data.append({"source": a, "target": b, "errors": []})
                continue
            s = difflib.SequenceMatcher(None, a, b)
            errors = []
            offset = 0
            for tag, i1, i2, j1, j2 in s.get_opcodes():
                if tag != "equal":
                    e = [a[i1:i2], b[j1 + offset:j2 + offset], i1]
                    if ignore_function and ignore_function(e):
                        # not treated as an error, so restore the source span and adjust the offset
                        b = b[:j1 + offset] + a[i1:i2] + b[j2 + offset:]
                        offset += i2 - i1 - j2 + j1
                        continue

                    errors.append(tuple(e))
            data.append({"source": a, "target": b, "errors": errors})
        return data


def correct(self, sentence: str, **kwargs):
"""长句改为短句, 可直接调用长文本"""
sentences = long_sentence_split(sentence, max_length=kwargs.pop("max_length", 128), period=kwargs.pop("period", None), comma=kwargs.pop("comma", None))
batch_results = self.correct_batch(sentences, **kwargs)
source, target, errors = "", "", []
for sr in batch_results:
ll = len(source)
source += sr["source"]
target += sr["target"]
for e in sr["errors"]:
                # keep the sentence-local position as a 4th element, then shift the main position into the merged text
                e = list(e)
                e.append(e[-1])
                e[2] += ll
errors.append(tuple(e))
return {"source": source, "target": target, "errors": errors, "sentences": batch_results}






70 changes: 70 additions & 0 deletions pycorrector/utils/sentence_utils.py
@@ -0,0 +1,70 @@
import re

default_period = set(["。", "……", "!", "?", "?", "\n",])
default_comma = set([",", ","])


def is_not_chinese_error(e):
    """Ignore errors whose source span is not made up entirely of Chinese characters (an empty span also counts)."""
    text = e[0]
    if len(text) == 0:
        return True
    chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]')
    for char in text:
        if not chinese_char_pattern.match(char):
            return True
    return False


def long_sentence_split(text, max_length=128, period=None, comma=None):
"""
先按照 period切分再按照 comma切分, 最后减少句子数量再合并
"""
if period is None:
period = default_period
if comma is None:
comma = default_comma

def same_split(text, max_length=128):
"""
等长切分
"""
sentences = []
for i in range(0, len(text), max_length):
sentences.append(text[i:i + max_length])
return sentences

def get_sentences_by_punc(text, punc, max_length):
n, last = len(text), 0
sentences = []
if n <= max_length:
sentences.append(text)
else:
for i in range(n):
if text[i] in punc:
sentences.extend(same_split(text[last:i + 1], max_length=max_length))
last = i + 1
if last < n:
sentences.extend(same_split(text[last:], max_length=max_length))
return sentences

sentences = []

n, last = len(text), 0
for i in range(n):
if text[i] in period:
sentences.extend(get_sentences_by_punc(text[last:i + 1], comma, max_length=max_length))
last = i + 1
if last < n:
sentences.extend(get_sentences_by_punc(text[last:], comma, max_length=max_length))

new = []
cur = ""
for s in sentences:
if len(cur)+len(s)>max_length:
new.append(cur)
cur = ""
cur += s
if len(cur)>0:
new.append(cur)
return new
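
long_sentence_split cuts the text at period-level punctuation, re-splits any piece longer than max_length at commas (falling back to fixed-length chunks), and then greedily merges adjacent pieces while they still fit under max_length; every piece stays within the limit and concatenating the pieces restores the original text. A quick sketch of that behaviour (illustrative input; any Chinese text works):

from pycorrector.utils.sentence_utils import long_sentence_split

text = "今天天气很好,我们去公园散步。" * 10
pieces = long_sentence_split(text, max_length=30)
assert all(len(p) <= 30 for p in pieces)   # each piece respects the limit
assert "".join(pieces) == text             # no characters are lost or reordered
print(len(pieces), pieces[0])
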
5 changes: 4 additions & 1 deletion requirements.txt
@@ -6,4 +6,7 @@ numpy
pandas
six
loguru
pyahocorasick
modelscope==1.16.0
fairseq==0.12.2
31 changes: 31 additions & 0 deletions tests/test_mucgec_bart.py
@@ -0,0 +1,31 @@
import sys
import unittest

sys.path.append('..')
from pycorrector import MuCGECBartCorrector
from pycorrector.utils.sentence_utils import is_not_chinese_error


m = MuCGECBartCorrector()


class MyTestCase(unittest.TestCase):
def test1(self):
sents = ["北京是中国的都。", "他说:”我最爱的运动是打蓝球“", "我每天大约喝5次水左右。", "今天,我非常开开心。"]
res = m.correct_batch(sents)

self.assertEqual(res[0]['target'], '北京是中国的首都。')
self.assertEqual(res[1]['target'], '他说:“我最爱的运动是打篮球”')
self.assertEqual(res[2]['target'], '我每天大约喝5杯水左右。')
self.assertEqual(res[3]['target'], '今天,我非常开心。')


def test2(self):
long_text = "在一个充满生活热闹和忙碌的城市中,有一个年轻人名叫李华。他生活在北京,这座充满着现代化建筑和繁忙街道的都市。每天,他都要穿行在拥挤的人群中,追逐着自己的梦想和生活节奏。\n\n李华从小就听祖辈讲述关于福气和努力的故事。他相信,“这洋的话,下一年的福气来到自己身上”。因此,尽管每天都很忙碌,他总是尽力保持乐观和积极。\n\n某天早晨,李华骑着自行车准备去上班。北京的交通总是非常繁忙,尤其是在早高峰时段。他经过一个交通路口,看到至少两个交警正在维持交通秩序。这些交警穿着整齐的制服,手势有序而又果断,让整个路口的车辆有条不紊地行驶着。这让李华想起了他父亲曾经告诫过他的话:“在拥挤的时间里,为了让人们遵守交通规则,至少要派两个警察或者交通管理者。”\n\n李华心中感慨万千,他想要在自己的生活中也如此积极地影响他人。他虽然只是一名普通的白领,却希望能够通过自己的努力和行动,为这座城市的安全与和谐贡献一份力量。\n\n随着时间的推移,中国的经济不断发展,北京的建设也日益繁荣。李华所在的公司也因为他的努力和创新精神而蓬勃发展。他喜欢打篮球,每周都会和朋友们一起去运动场,放松身心。他也十分重视健康,每天都保持适量的饮水量,大约喝五次左右。\n\n今天,李华觉得格外开心。他意识到,自己虽然只是一个普通人,却通过日复一日的努力,终于在生活中找到了属于自己的那份福气。他明白了祖辈们口中的那句话的含义——“这洋的话,下一年的福气来到自己身上”,并且深信不疑。\n\n在这个充满希望和机遇的时代里,李华将继续努力工作,为自己的梦想而奋斗,也希望能够在这座城市中留下自己的一份足迹,为他人带来更多的希望和正能量。\n\n这就是李华的故事,一个在现代城市中追寻梦想和福气的普通青年。"
result = m.correct(long_text, ignore_function=is_not_chinese_error)
for e in result["errors"]:
self.assertEqual(result["source"][e[2]], e[0])


if __name__ == '__main__':
unittest.main()