Merge pull request #527 from smartmark-pro/feature-better-bart
feat: add two open-source error correction models from HillZhang1999
Showing 11 changed files with 455 additions and 2 deletions.
Empty file.
pycorrector/mucgec_bart/monkey_pack.py
@@ -0,0 +1,70 @@
import torch
from typing import Any, Dict, List

from modelscope.pipelines import Pipeline
from modelscope.utils.constant import Frameworks
from modelscope.utils.device import device_placement


# Monkey patch for ModelScope batch inference: when model_name == "batch_correct",
# list-valued outputs are indexed per sample instead of being sliced as tensors.
def _process_batch(self, input: List, batch_size,
                   **kwargs) -> Dict[str, Any]:
    preprocess_params = kwargs.get('preprocess_params')
    forward_params = kwargs.get('forward_params')
    postprocess_params = kwargs.get('postprocess_params')

    # batch data
    output_list = []
    for i in range(0, len(input), batch_size):
        end = min(i + batch_size, len(input))
        real_batch_size = end - i
        preprocessed_list = [
            self.preprocess(sample, **preprocess_params) for sample in input[i:end]
        ]

        with device_placement(self.framework, self.device_name):
            if self.framework == Frameworks.torch:
                with torch.no_grad():
                    batched_out = self._batch(preprocessed_list)
                    if self._auto_collate:
                        batched_out = self._collate_fn(batched_out)
                    batched_out = self.forward(batched_out,
                                               **forward_params)
            else:
                batched_out = self._batch(preprocessed_list)
                batched_out = self.forward(batched_out, **forward_params)

        model_name = kwargs.get("model_name")
        if model_name == "batch_correct":
            for batch_idx in range(real_batch_size):
                out = {}
                for k, element in batched_out.items():
                    if element is not None:
                        if isinstance(element, (tuple, list)):
                            # One decoded result per sample: take it by index.
                            out[k] = element[batch_idx]
                        else:
                            out[k] = element[batch_idx:batch_idx + 1]
                out = self.postprocess(out, **postprocess_params)
                self._check_output(out)
                output_list.append(out)
        else:
            for batch_idx in range(real_batch_size):
                out = {}
                for k, element in batched_out.items():
                    if element is not None:
                        if isinstance(element, (tuple, list)):
                            if isinstance(element[0], torch.Tensor):
                                out[k] = type(element)(
                                    e[batch_idx:batch_idx + 1]
                                    for e in element)
                            else:
                                # Compatible with traditional pipelines
                                out[k] = element[batch_idx]
                        else:
                            out[k] = element[batch_idx:batch_idx + 1]
                out = self.postprocess(out, **postprocess_params)
                self._check_output(out)
                output_list.append(out)

    return output_list


# Replace the method on the class itself so every Pipeline instance is patched.
Pipeline._process_batch = _process_batch
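
The last line rebinds Pipeline._process_batch on the class, so the patch applies to any pipeline built once this module has been imported. A minimal sketch of driving the patched batch path directly (model id taken from the corrector below; outputs illustrative):

# a minimal sketch, assuming modelscope==1.16.0 is installed
from pycorrector.mucgec_bart import monkey_pack  # importing applies the patch
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

corrector = pipeline(Tasks.text_error_correction,
                     model="damo/nlp_bart_text-error-correction_chinese")
# model_name="batch_correct" routes through the patched per-sample branch
outputs = corrector(["今天,我非常开开心。", "北京是中国的都。"],
                    batch_size=2, model_name="batch_correct")
print([o["output"] for o in outputs])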
pycorrector/mucgec_bart/mucgec_bart_corrector.py
@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-
import difflib
import os
import sys
import time
from typing import List

import torch
from loguru import logger

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
sys.path.append('../..')
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from pycorrector.mucgec_bart.monkey_pack import Pipeline
from pycorrector.utils.sentence_utils import long_sentence_split


class MuCGECBartCorrector:
    def __init__(self, model_name_or_path: str = "damo/nlp_bart_text-error-correction_chinese"):
        t1 = time.time()
        self.model = pipeline(Tasks.text_error_correction, model=model_name_or_path)
        logger.debug("Device: {}".format(device))
        logger.debug('Loaded mucgec bart correction model: %s, spend: %.3f s.' % (model_name_or_path, time.time() - t1))

    def _predict(self, sentences, batch_size=32, max_length=128, silent=True):
        raise NotImplementedError

    def correct_batch(self, sentences: List[str], max_length: int = 128, batch_size: int = 32,
                      silent: bool = True, ignore_function=None):
        """
        Batch sentence correction.
        :param sentences: list[str], sentence list
        :param max_length: int, max length of each sentence
        :param batch_size: int, batch size
        :param silent: bool, whether to suppress logs
        :param ignore_function: function, a user-defined predicate that marks certain error
            types to be skipped, without retraining the model
        :return: list of dict, {'source': 'src', 'target': 'trg', 'errors': [(error_word, correct_word, position), ...]}
        """
        result = self.model(sentences, batch_size=batch_size, model_name="batch_correct")
        result = [r["output"] for r in result]
        data = []
        for i in range(len(sentences)):
            a, b = sentences[i], result[i]
            if len(a) == 0 or len(b) == 0 or a == "\n":
                # Keep empty/newline-only inputs unchanged so outputs stay aligned
                # with inputs (the original `return` here dropped all results).
                data.append({"source": a, "target": a, "errors": []})
                continue
            s = difflib.SequenceMatcher(None, a, b)
            errors = []
            offset = 0
            for tag, i1, i2, j1, j2 in s.get_opcodes():
                if tag != "equal":
                    e = [a[i1:i2], b[j1 + offset:j2 + offset], i1]
                    if ignore_function and ignore_function(e):
                        # Not treated as an error: revert this span of the target to
                        # the source text and track the resulting length change.
                        b = b[:j1 + offset] + a[i1:i2] + b[j2 + offset:]
                        offset += i2 - i1 - j2 + j1
                        continue

                    errors.append(tuple(e))
            data.append({"source": a, "target": b, "errors": errors})
        return data

    def correct(self, sentence: str, **kwargs):
        """Split a long text into short sentences, so long texts can be corrected directly."""
        sentences = long_sentence_split(sentence, max_length=kwargs.pop("max_length", 128),
                                        period=kwargs.pop("period", None), comma=kwargs.pop("comma", None))
        batch_results = self.correct_batch(sentences, **kwargs)
        source, target, errors = "", "", []
        for sr in batch_results:
            ll = len(source)
            source += sr["source"]
            target += sr["target"]
            for e in sr["errors"]:
                # Shift the error position to its offset in the full text; keep the
                # sentence-local position as the last tuple element.
                e = list(e)
                e.append(e[-1])
                e[2] += ll
                errors.append(tuple(e))
        return {"source": source, "target": target, "errors": errors, "sentences": batch_results}
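
A minimal usage sketch for the class above (default model id from the constructor; the corrected strings match the expected outputs in the unit test later in this commit, and the errors shown are illustrative):

from pycorrector import MuCGECBartCorrector

m = MuCGECBartCorrector()  # downloads "damo/nlp_bart_text-error-correction_chinese" on first run
res = m.correct_batch(["今天,我非常开开心。"])
print(res[0]["target"])  # 今天,我非常开心。
print(res[0]["errors"])  # e.g. [('开', '', 6)] -- (error_word, correct_word, position)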
Empty file.
pycorrector/nasgec_bart/nasgec_bart_corrector.py
@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
import difflib
import os
import sys
import time
from typing import List

import torch
from loguru import logger

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
sys.path.append('../..')
from transformers import BertTokenizer, BartForConditionalGeneration
from pycorrector.utils.sentence_utils import long_sentence_split


class NaSGECBartCorrector:
    def __init__(self, model_name_or_path: str = "HillZhang/real_learner_bart_CGEC"):
        # https://github.com/HillZhang1999/NaSGEC
        t1 = time.time()
        self.tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
        self.model = BartForConditionalGeneration.from_pretrained(model_name_or_path)
        logger.debug("Device: {}".format(device))
        logger.debug('Loaded nasgec bart correction model: %s, spend: %.3f s.' % (model_name_or_path, time.time() - t1))

    def _predict(self, sentences, batch_size=32, max_length=128, silent=True):
        raise NotImplementedError

    def correct_batch(self, sentences: List[str], max_length: int = 128, batch_size: int = 32,
                      silent: bool = True, ignore_function=None):
        """
        Batch sentence correction.
        :param sentences: list[str], sentence list
        :param max_length: int, max length of each sentence
        :param batch_size: int, batch size
        :param silent: bool, whether to suppress logs
        :param ignore_function: function, a user-defined predicate that marks certain error
            types to be skipped, without retraining the model
        :return: list of dict, {'source': 'src', 'target': 'trg', 'errors': [(error_word, correct_word, position), ...]}
        """
        encoded_input = self.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
        if "token_type_ids" in encoded_input:
            del encoded_input["token_type_ids"]
        output = self.model.generate(**encoded_input)
        result = self.tokenizer.batch_decode(output, skip_special_tokens=True)
        # The BERT tokenizer decodes with spaces between tokens; drop them for Chinese text.
        result = [r.replace(" ", "") for r in result]
        data = []
        for i in range(len(sentences)):
            a, b = sentences[i], result[i]
            if len(a) == 0 or len(b) == 0 or a == "\n":
                # Keep empty/newline-only inputs unchanged so outputs stay aligned
                # with inputs (the original `return` here dropped all results).
                data.append({"source": a, "target": a, "errors": []})
                continue
            s = difflib.SequenceMatcher(None, a, b)
            errors = []
            offset = 0
            for tag, i1, i2, j1, j2 in s.get_opcodes():
                if tag != "equal":
                    e = [a[i1:i2], b[j1 + offset:j2 + offset], i1]
                    if ignore_function and ignore_function(e):
                        # Not treated as an error: revert this span of the target to
                        # the source text and track the resulting length change.
                        b = b[:j1 + offset] + a[i1:i2] + b[j2 + offset:]
                        offset += i2 - i1 - j2 + j1
                        continue

                    errors.append(tuple(e))
            data.append({"source": a, "target": b, "errors": errors})
        return data

    def correct(self, sentence: str, **kwargs):
        """Split a long text into short sentences, so long texts can be corrected directly."""
        sentences = long_sentence_split(sentence, max_length=kwargs.pop("max_length", 128),
                                        period=kwargs.pop("period", None), comma=kwargs.pop("comma", None))
        batch_results = self.correct_batch(sentences, **kwargs)
        source, target, errors = "", "", []
        for sr in batch_results:
            ll = len(source)
            source += sr["source"]
            target += sr["target"]
            for e in sr["errors"]:
                # Shift the error position to its offset in the full text; keep the
                # sentence-local position as the last tuple element.
                e = list(e)
                e.append(e[-1])
                e[2] += ll
                errors.append(tuple(e))
        return {"source": source, "target": target, "errors": errors, "sentences": batch_results}
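
A minimal usage sketch (assuming the Hugging Face checkpoint "HillZhang/real_learner_bart_CGEC" is reachable; the import path mirrors the mucgec_bart package layout and is an assumption):

# hypothetical import path, following the mucgec_bart package layout
from pycorrector.nasgec_bart.nasgec_bart_corrector import NaSGECBartCorrector

m = NaSGECBartCorrector()
res = m.correct_batch(["北京是中国的都。"])
print(res[0]["source"], "->", res[0]["target"])  # e.g. 北京是中国的都。 -> 北京是中国的首都。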
pycorrector/utils/sentence_utils.py
@@ -0,0 +1,70 @@
import re

default_period = set(["。", "……", "!", "?", "?", "\n"])
default_comma = set([",", ","])

chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]')


def is_not_chinese_error(e):
    """Return True if the error span is not pure Chinese text; such errors are ignored."""
    text = e[0]
    if len(text) == 0:
        return True
    for char in text:
        if not chinese_char_pattern.match(char):
            return True
    return False


def long_sentence_split(text, max_length=128, period=None, comma=None):
    """
    Split on period-level punctuation first, then on comma-level punctuation,
    and finally merge adjacent pieces to reduce the number of sentences.
    """
    if period is None:
        period = default_period
    if comma is None:
        comma = default_comma

    def same_split(text, max_length=128):
        """Split into fixed-length chunks."""
        sentences = []
        for i in range(0, len(text), max_length):
            sentences.append(text[i:i + max_length])
        return sentences

    def get_sentences_by_punc(text, punc, max_length):
        n, last = len(text), 0
        sentences = []
        if n <= max_length:
            sentences.append(text)
        else:
            for i in range(n):
                if text[i] in punc:
                    sentences.extend(same_split(text[last:i + 1], max_length=max_length))
                    last = i + 1
            if last < n:
                sentences.extend(same_split(text[last:], max_length=max_length))
        return sentences

    sentences = []

    n, last = len(text), 0
    for i in range(n):
        if text[i] in period:
            sentences.extend(get_sentences_by_punc(text[last:i + 1], comma, max_length=max_length))
            last = i + 1
    if last < n:
        sentences.extend(get_sentences_by_punc(text[last:], comma, max_length=max_length))

    # Greedily merge adjacent pieces while they still fit within max_length.
    new = []
    cur = ""
    for s in sentences:
        if len(cur) + len(s) > max_length:
            new.append(cur)
            cur = ""
        cur += s
    if len(cur) > 0:
        new.append(cur)
    return new
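
A small sketch of the splitter's guarantees (under the default punctuation sets above): every piece fits within max_length, and the pieces concatenate back to the original text, which is what lets correct() rebuild the full source string.

from pycorrector.utils.sentence_utils import long_sentence_split

text = "第一句话。第二句还带一个逗号,然后继续。" * 20
pieces = long_sentence_split(text, max_length=128)
assert all(len(p) <= 128 for p in pieces)
assert "".join(pieces) == text  # the split is lossless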
requirements.txt
@@ -6,4 +6,7 @@ numpy
pandas
six
loguru
pyahocorasick
modelscope==1.16.0
fairseq==0.12.2
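
As a quick check, the two new pins can be installed on their own (a sketch; modelscope resolves its own torch-compatible dependencies):

pip install modelscope==1.16.0 fairseq==0.12.2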
@@ -0,0 +1,31 @@
import sys
import unittest

sys.path.append('..')
from pycorrector import MuCGECBartCorrector
from pycorrector.utils.sentence_utils import is_not_chinese_error

m = MuCGECBartCorrector()


class MyTestCase(unittest.TestCase):
    def test1(self):
        sents = ["北京是中国的都。", "他说:”我最爱的运动是打蓝球“", "我每天大约喝5次水左右。", "今天,我非常开开心。"]
        res = m.correct_batch(sents)

        self.assertEqual(res[0]['target'], '北京是中国的首都。')
        self.assertEqual(res[1]['target'], '他说:“我最爱的运动是打篮球”')
        self.assertEqual(res[2]['target'], '我每天大约喝5杯水左右。')
        self.assertEqual(res[3]['target'], '今天,我非常开心。')

    def test2(self):
        long_text = "在一个充满生活热闹和忙碌的城市中,有一个年轻人名叫李华。他生活在北京,这座充满着现代化建筑和繁忙街道的都市。每天,他都要穿行在拥挤的人群中,追逐着自己的梦想和生活节奏。\n\n李华从小就听祖辈讲述关于福气和努力的故事。他相信,“这洋的话,下一年的福气来到自己身上”。因此,尽管每天都很忙碌,他总是尽力保持乐观和积极。\n\n某天早晨,李华骑着自行车准备去上班。北京的交通总是非常繁忙,尤其是在早高峰时段。他经过一个交通路口,看到至少两个交警正在维持交通秩序。这些交警穿着整齐的制服,手势有序而又果断,让整个路口的车辆有条不紊地行驶着。这让李华想起了他父亲曾经告诫过他的话:“在拥挤的时间里,为了让人们遵守交通规则,至少要派两个警察或者交通管理者。”\n\n李华心中感慨万千,他想要在自己的生活中也如此积极地影响他人。他虽然只是一名普通的白领,却希望能够通过自己的努力和行动,为这座城市的安全与和谐贡献一份力量。\n\n随着时间的推移,中国的经济不断发展,北京的建设也日益繁荣。李华所在的公司也因为他的努力和创新精神而蓬勃发展。他喜欢打篮球,每周都会和朋友们一起去运动场,放松身心。他也十分重视健康,每天都保持适量的饮水量,大约喝五次左右。\n\n今天,李华觉得格外开心。他意识到,自己虽然只是一个普通人,却通过日复一日的努力,终于在生活中找到了属于自己的那份福气。他明白了祖辈们口中的那句话的含义——“这洋的话,下一年的福气来到自己身上”,并且深信不疑。\n\n在这个充满希望和机遇的时代里,李华将继续努力工作,为自己的梦想而奋斗,也希望能够在这座城市中留下自己的一份足迹,为他人带来更多的希望和正能量。\n\n这就是李华的故事,一个在现代城市中追寻梦想和福气的普通青年。"
        result = m.correct(long_text, ignore_function=is_not_chinese_error)
        for e in result["errors"]:
            # Every reported global error position must line up with the source text.
            self.assertEqual(result["source"][e[2]], e[0])


if __name__ == '__main__':
    unittest.main()
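
Because the corrector is constructed at module scope, the model loads as soon as the file is imported; running the file directly then executes both tests (the file name below is a placeholder, since the diff view does not show it):

python mucgec_bart_corrector_test.py  # placeholder file name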