Skip to content

Commit

Permalink
update train dataset.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed May 10, 2022
1 parent 077e33f commit 9cff54d
Show file tree
Hide file tree
Showing 6 changed files with 4 additions and 15 deletions.
1 change: 0 additions & 1 deletion pycorrector/seq2seq/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
"""
import os
import sys
from codecs import open
from xml.dom import minidom

from sklearn.model_selection import train_test_split
Expand Down
9 changes: 0 additions & 9 deletions pycorrector/seq2seq/seq2seq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,6 @@ def train_model(
if args:
self.args.update_from_dict(args)

# if self.args.silent:
# show_running_loss = False

if self.args.evaluate_during_training and eval_data is None:
raise ValueError(
"evaluate_during_training is enabled but eval_data is not specified."
Expand Down Expand Up @@ -344,12 +341,6 @@ def train_model(

self.save_model(self.args.output_dir, model=self.model)

# model_to_save = self.model.module if hasattr(self.model, "module") else self.model
# model_to_save.save_pretrained(output_dir)
# self.encoder_tokenizer.save_pretrained(output_dir)
# self.decoder_tokenizer.save_pretrained(output_dir)
# torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

if verbose:
logger.info(" Training of {} model complete. Saved to {}.".format(self.args.model_name, output_dir))

Expand Down
2 changes: 1 addition & 1 deletion pycorrector/seq2seq/seq2seq_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

if transformers.__version__ < "4.2.0":
shift_tokens_right = lambda input_ids, pad_token_id, decoder_start_token_id: _shift_tokens_right(
input_ids, pad_token_id)
input_ids, pad_token_id, decoder_start_token_id)
else:
shift_tokens_right = _shift_tokens_right

Expand Down
2 changes: 1 addition & 1 deletion pycorrector/seq2seq/tf/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""
@author:XuMing([email protected])
@description:
""" Use CGED corpus
"""
import os

pwd_path = os.path.abspath(os.path.dirname(__file__))
Expand Down
3 changes: 1 addition & 2 deletions pycorrector/seq2seq/tf/data_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
"""
@author:XuMing([email protected])
@description:
""" Corpus for model

"""
import sys
from codecs import open
from collections import Counter
Expand Down
2 changes: 1 addition & 1 deletion pycorrector/seq2seq/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pycorrector.utils.logger import logger
from pycorrector.seq2seq.seq2seq_model import Seq2SeqModel

os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Expand Down Expand Up @@ -262,7 +263,6 @@ def train(arch, train_path, batch_size, embed_size, hidden_size, dropout, epochs

train_df = pd.DataFrame(train_data, columns=['input_text', 'target_text'])
dev_df = pd.DataFrame(dev_data, columns=['input_text', 'target_text'])
os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
model.train_model(train_df, eval_data=dev_df)
else:
logger.error('error arch: {}'.format(arch))
Expand Down

0 comments on commit 9cff54d

Please sign in to comment.