diff --git a/examples/seq2seq_demo.py b/examples/seq2seq_demo.py index 2f024c17..78c3a240 100644 --- a/examples/seq2seq_demo.py +++ b/examples/seq2seq_demo.py @@ -34,7 +34,7 @@ def main(): ) parser.add_argument("--model_dir", default="output/bertseq2seq/", type=str, help="Dir for model save.") parser.add_argument("--arch", - default="bertseq2seq", type=str, + default="convseq2seq", type=str, help="The name of the task to train selected in the list: " + ", ".join( ['seq2seq', 'convseq2seq', 'bertseq2seq']), ) @@ -52,7 +52,7 @@ def main(): parser.add_argument("--embed_size", default=128, type=int, help="Embedding size.") parser.add_argument("--hidden_size", default=128, type=int, help="Hidden size.") parser.add_argument("--dropout", default=0.25, type=float, help="Dropout rate.") - parser.add_argument("--epochs", default=10, type=int, help="Epoch num.") + parser.add_argument("--epochs", default=200, type=int, help="Epoch num.") args = parser.parse_args() print(args) diff --git a/pycorrector/seq2seq/README.md b/pycorrector/seq2seq/README.md index 3143c6c1..4d51b664 100644 --- a/pycorrector/seq2seq/README.md +++ b/pycorrector/seq2seq/README.md @@ -22,7 +22,7 @@ tensorboardX ## Demo -- bertseq2seq demo +- convseq2seq demo 示例[seq2seq_demo.py](../../examples/seq2seq_demo.py) ``` @@ -93,3 +93,7 @@ predict: 王天华开心地一直说话。 ``` python preprocess.py ``` + +### release models + +基于SIGHAN2015数据集训练的seq2seq和convseq2seq模型,已经release到github,通过[github models]()获取。 diff --git a/pycorrector/seq2seq/seq2seq_model.py b/pycorrector/seq2seq/seq2seq_model.py index db451b28..08189b95 100644 --- a/pycorrector/seq2seq/seq2seq_model.py +++ b/pycorrector/seq2seq/seq2seq_model.py @@ -520,7 +520,7 @@ def train( if args.model_name and os.path.exists(args.model_name): try: - # set global_step to gobal_step of last saved checkpoint from model path + # set global_step to global_step of last saved checkpoint from model path checkpoint_suffix = args.model_name.split("/")[-1].split("-") if len(checkpoint_suffix) > 2: checkpoint_suffix = checkpoint_suffix[1] diff --git a/pycorrector/version.py b/pycorrector/version.py index 61998de9..f36fb9ee 100644 --- a/pycorrector/version.py +++ b/pycorrector/version.py @@ -3,4 +3,4 @@ @author:XuMing(xuming624@qq.com) @description: version """ -__version__ = '0.4.4' +__version__ = '0.4.5'