From 4f1fb5b2ff939b4d5ff3d876c98783c07d8b8a92 Mon Sep 17 00:00:00 2001
From: shibing624
Date: Tue, 10 May 2022 15:00:18 +0800
Subject: [PATCH] update train.

---
 pycorrector/seq2seq/train.py | 25 +++++++++++++++++--------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/pycorrector/seq2seq/train.py b/pycorrector/seq2seq/train.py
index 9a7cd7d5..0439956b 100644
--- a/pycorrector/seq2seq/train.py
+++ b/pycorrector/seq2seq/train.py
@@ -178,7 +178,7 @@ def train_convseq2seq_model(model, train_data, device, loss_fn, optimizer, model
 
 def train(arch, train_path, batch_size, embed_size, hidden_size, dropout, epochs,
           src_vocab_path, trg_vocab_path, model_dir, max_length, use_segment, model_name_or_path):
-    print("device: {}".format(device))
+    logger.info("device: {}".format(device))
     arch = arch.lower()
     os.makedirs(model_dir, exist_ok=True)
     if arch in ['seq2seq', 'convseq2seq']:
@@ -195,9 +195,8 @@ def train(arch, train_path, batch_size, embed_size, hidden_size, dropout, epochs
         id_2_srcs = {v: k for k, v in src_2_ids.items()}
         id_2_trgs = {v: k for k, v in trg_2_ids.items()}
         train_src, train_trg = one_hot(source_texts, target_texts, src_2_ids, trg_2_ids, sort_by_len=True)
-        k = 0
-        print('src:', ' '.join([id_2_srcs[i] for i in train_src[k]]))
-        print('trg:', ' '.join([id_2_trgs[i] for i in train_trg[k]]))
+        logger.debug(f'src: {[id_2_srcs[i] for i in train_src[0]]}')
+        logger.debug(f'trg: {[id_2_trgs[i] for i in train_trg[0]]}')
 
         train_data = gen_examples(train_src, train_trg, batch_size, max_length)
 
@@ -209,7 +208,7 @@ def train(arch, train_path, batch_size, embed_size, hidden_size, dropout, epochs
                             enc_hidden_size=hidden_size,
                             dec_hidden_size=hidden_size,
                             dropout=dropout).to(device)
-            print(model)
+            logger.info(model)
             loss_fn = LanguageModelCriterion().to(device)
             optimizer = torch.optim.Adam(model.parameters())
 
@@ -226,7 +225,7 @@ def train(arch, train_path, batch_size, embed_size, hidden_size, dropout, epochs
                                trg_pad_idx=trg_pad_idx,
                                device=device,
                                max_length=max_length).to(device)
-            print(model)
+            logger.info(model)
             loss_fn = nn.CrossEntropyLoss(ignore_index=trg_pad_idx)
             optimizer = torch.optim.Adam(model.parameters())
 
@@ -257,13 +256,23 @@ def train(arch, train_path, batch_size, embed_size, hidden_size, dropout, epochs
 
         # encoder_name="bert-base-chinese"
         model = Seq2SeqModel("bert", model_name_or_path, model_name_or_path,
                              args=model_args, use_cuda=use_cuda)
-        print('start train bertseq2seq ...')
+        logger.info('start train bertseq2seq ...')
         data = load_bert_data(train_path, use_segment)
+        logger.info(f'load data done, data size: {len(data)}')
+        logger.debug(f'data samples: {data[:10]}')
         train_data, dev_data = train_test_split(data, test_size=0.1, shuffle=True)
         train_df = pd.DataFrame(train_data, columns=['input_text', 'target_text'])
         dev_df = pd.DataFrame(dev_data, columns=['input_text', 'target_text'])
-        model.train_model(train_df, eval_data=dev_df)
+
+        def count_matches(labels, preds):
+            logger.debug(f"labels: {labels[:10]}")
+            logger.debug(f"preds: {preds[:10]}")
+            match = sum([1 if label == pred else 0 for label, pred in zip(labels, preds)])
+            logger.debug(f"match: {match}")
+            return match
+
+        model.train_model(train_df, eval_data=dev_df, matches=count_matches)
     else:
         logger.error('error arch: {}'.format(arch))
         raise ValueError("Model arch choose error. Must use one of seq2seq model.")
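
Note: below is a minimal standalone sketch of the count_matches metric wired in by the last hunk. It assumes a simpletransformers-style Seq2SeqModel.train_model, where extra keyword arguments are treated as metric functions and invoked as fn(labels, preds) against the eval set; the sample sentences are made up for illustration and are not part of the patch.

    # Standalone sketch of the exact-match metric added by this patch.
    # Assumption: train_model() calls each extra metric kwarg as
    # fn(labels, preds) on the evaluation data (simpletransformers-style).
    def count_matches(labels, preds):
        # Count sentence pairs where the prediction equals the reference exactly.
        return sum(1 for label, pred in zip(labels, preds) if label == pred)

    # Hypothetical eval outputs, for illustration only.
    labels = ["少先队员应该为老人让座", "今天天气很好"]
    preds = ["少先队员应该为老人让座", "今天天气真好"]
    print(count_matches(labels, preds))  # -> 1 (only the first pair matches)

Exact match is a coarse but cheap evaluation signal for error correction: it only credits predictions that reproduce the reference sentence verbatim.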