diff --git a/pycorrector/seq2seq/preprocess.py b/pycorrector/seq2seq/preprocess.py
index d07611d8..4d063cfc 100644
--- a/pycorrector/seq2seq/preprocess.py
+++ b/pycorrector/seq2seq/preprocess.py
@@ -5,7 +5,6 @@
 """
 import os
 import sys
-from codecs import open
 from xml.dom import minidom
 
 from sklearn.model_selection import train_test_split
diff --git a/pycorrector/seq2seq/seq2seq_model.py b/pycorrector/seq2seq/seq2seq_model.py
index 2501bfe9..db451b28 100644
--- a/pycorrector/seq2seq/seq2seq_model.py
+++ b/pycorrector/seq2seq/seq2seq_model.py
@@ -309,9 +309,6 @@ def train_model(
         if args:
             self.args.update_from_dict(args)
-        # if self.args.silent:
-        #     show_running_loss = False
-
         if self.args.evaluate_during_training and eval_data is None:
             raise ValueError(
                 "evaluate_during_training is enabled but eval_data is not specified."
             )
@@ -344,12 +341,6 @@
         self.save_model(self.args.output_dir, model=self.model)
 
-        # model_to_save = self.model.module if hasattr(self.model, "module") else self.model
-        # model_to_save.save_pretrained(output_dir)
-        # self.encoder_tokenizer.save_pretrained(output_dir)
-        # self.decoder_tokenizer.save_pretrained(output_dir)
-        # torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
-
         if verbose:
             logger.info(" Training of {} model complete. Saved to {}.".format(self.args.model_name, output_dir))
 
 
diff --git a/pycorrector/seq2seq/seq2seq_utils.py b/pycorrector/seq2seq/seq2seq_utils.py
index 56f85fff..84fad1ae 100644
--- a/pycorrector/seq2seq/seq2seq_utils.py
+++ b/pycorrector/seq2seq/seq2seq_utils.py
@@ -20,7 +20,7 @@
 
 if transformers.__version__ < "4.2.0":
     shift_tokens_right = lambda input_ids, pad_token_id, decoder_start_token_id: _shift_tokens_right(
-        input_ids, pad_token_id)
+        input_ids, pad_token_id, decoder_start_token_id)
 else:
     shift_tokens_right = _shift_tokens_right
 
diff --git a/pycorrector/seq2seq/tf/config.py b/pycorrector/seq2seq/tf/config.py
index 796872db..305a1bcc 100644
--- a/pycorrector/seq2seq/tf/config.py
+++ b/pycorrector/seq2seq/tf/config.py
@@ -2,7 +2,7 @@
 """
 @author:XuMing(xuming624@qq.com)
 @description:
-""" Use CGED corpus
+"""
 import os
 
 pwd_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/pycorrector/seq2seq/tf/data_reader.py b/pycorrector/seq2seq/tf/data_reader.py
index 75de5d8a..3b25dcbb 100644
--- a/pycorrector/seq2seq/tf/data_reader.py
+++ b/pycorrector/seq2seq/tf/data_reader.py
@@ -2,8 +2,7 @@
 """
 @author:XuMing(xuming624@qq.com)
 @description:
-""" Corpus for model
-
+"""
 import sys
 from codecs import open
 from collections import Counter
diff --git a/pycorrector/seq2seq/train.py b/pycorrector/seq2seq/train.py
index 5826bb24..9a7cd7d5 100644
--- a/pycorrector/seq2seq/train.py
+++ b/pycorrector/seq2seq/train.py
@@ -29,6 +29,7 @@
 from pycorrector.utils.logger import logger
 from pycorrector.seq2seq.seq2seq_model import Seq2SeqModel
 
+os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
@@ -262,7 +263,6 @@ def train(arch, train_path, batch_size, embed_size, hidden_size, dropout, epochs
         train_df = pd.DataFrame(train_data, columns=['input_text', 'target_text'])
         dev_df = pd.DataFrame(dev_data, columns=['input_text', 'target_text'])
 
-        os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
         model.train_model(train_df, eval_data=dev_df)
     else:
         logger.error('error arch: {}'.format(arch))
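
Note on the seq2seq_utils.py hunk above: the pre-4.2.0 compatibility lambda accepted a decoder_start_token_id argument but dropped it when delegating to _shift_tokens_right, so the decoder start token never reached the helper; the patch forwards all three arguments. For context, below is a minimal standalone sketch of what the three-argument shift_tokens_right helper in transformers (>= 4.2.0) computes. It is a hypothetical reimplementation for illustration only, not part of the patch:

import torch

def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    # Build decoder_input_ids for teacher forcing: shift the labels one
    # position to the right and prepend the decoder start token.
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # Replace loss-masking placeholders (-100) with real pad tokens.
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted

# Toy token ids for illustration only.
labels = torch.tensor([[8938, 2312, 102, -100]])
print(shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2))
# tensor([[   2, 8938, 2312,  102]])

With the old two-argument call, decoder_start_token_id was silently discarded, so decoder inputs could begin with the wrong token; forwarding it makes the shim behave like the unwrapped helper used in the else branch.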