
Commit

update train.
shibing624 committed May 10, 2022
1 parent 9cff54d commit 4f1fb5b
Showing 1 changed file with 17 additions and 8 deletions: pycorrector/seq2seq/train.py
@@ -178,7 +178,7 @@ def train_convseq2seq_model(model, train_data, device, loss_fn, optimizer, model

 def train(arch, train_path, batch_size, embed_size, hidden_size, dropout, epochs,
           src_vocab_path, trg_vocab_path, model_dir, max_length, use_segment, model_name_or_path):
-    print("device: {}".format(device))
+    logger.info("device: {}".format(device))
     arch = arch.lower()
     os.makedirs(model_dir, exist_ok=True)
     if arch in ['seq2seq', 'convseq2seq']:
@@ -195,9 +195,8 @@ def train(arch, train_path, batch_size, embed_size, hidden_size, dropout, epochs
         id_2_trgs = {v: k for k, v in trg_2_ids.items()}
         train_src, train_trg = one_hot(source_texts, target_texts, src_2_ids, trg_2_ids, sort_by_len=True)

-        k = 0
-        print('src:', ' '.join([id_2_srcs[i] for i in train_src[k]]))
-        print('trg:', ' '.join([id_2_trgs[i] for i in train_trg[k]]))
+        logger.debug(f'src: {[id_2_srcs[i] for i in train_src[0]]}')
+        logger.debug(f'trg: {[id_2_trgs[i] for i in train_trg[0]]}')

         train_data = gen_examples(train_src, train_trg, batch_size, max_length)

@@ -209,7 +208,7 @@ def train(arch, train_path, batch_size, embed_size, hidden_size, dropout, epochs
                             enc_hidden_size=hidden_size,
                             dec_hidden_size=hidden_size,
                             dropout=dropout).to(device)
-            print(model)
+            logger.info(model)
             loss_fn = LanguageModelCriterion().to(device)
             optimizer = torch.optim.Adam(model.parameters())

@@ -226,7 +225,7 @@ def train(arch, train_path, batch_size, embed_size, hidden_size, dropout, epochs
                                 trg_pad_idx=trg_pad_idx,
                                 device=device,
                                 max_length=max_length).to(device)
-            print(model)
+            logger.info(model)
             loss_fn = nn.CrossEntropyLoss(ignore_index=trg_pad_idx)
             optimizer = torch.optim.Adam(model.parameters())

@@ -257,13 +256,23 @@ def train(arch, train_path, batch_size, embed_size, hidden_size, dropout, epochs
         # encoder_name="bert-base-chinese"
         model = Seq2SeqModel("bert", model_name_or_path, model_name_or_path, args=model_args, use_cuda=use_cuda)

-        print('start train bertseq2seq ...')
+        logger.info('start train bertseq2seq ...')
         data = load_bert_data(train_path, use_segment)
+        logger.info(f'load data done, data size: {len(data)}')
+        logger.debug(f'data samples: {data[:10]}')
         train_data, dev_data = train_test_split(data, test_size=0.1, shuffle=True)

         train_df = pd.DataFrame(train_data, columns=['input_text', 'target_text'])
         dev_df = pd.DataFrame(dev_data, columns=['input_text', 'target_text'])
-        model.train_model(train_df, eval_data=dev_df)
+
+        def count_matches(labels, preds):
+            logger.debug(f"labels: {labels[:10]}")
+            logger.debug(f"preds: {preds[:10]}")
+            match = sum([1 if label == pred else 0 for label, pred in zip(labels, preds)])
+            logger.debug(f"match: {match}")
+            return match
+
+        model.train_model(train_df, eval_data=dev_df, matches=count_matches)
     else:
         logger.error('error arch: {}'.format(arch))
         raise ValueError("Model arch choose error. Must use one of seq2seq model.")
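
Note on the logging change: the diff replaces print calls with leveled logging, using logger.info for one-off milestones (device, model summary, training start) and logger.debug for data samples that only matter when diagnosing the input pipeline. A minimal self-contained sketch of that pattern, assuming Python's standard logging module (the actual logger object in train.py is imported elsewhere and may be configured differently):

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)  # stand-in for train.py's logger

device = "cuda"                       # hypothetical values for the demo
train_src_sample = ["少", "先", "队", "员"]

logger.info("device: {}".format(device))  # milestone: visible at INFO and below
logger.debug(f"src: {train_src_sample}")  # sample dump: only visible at DEBUG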
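The bertseq2seq branch also gains dataset-size logging around the existing 90/10 split: train_test_split carves off 10% for evaluation, and both halves are wrapped in DataFrames with the input_text/target_text columns the model expects. A toy sketch of that split-and-wrap step (sklearn plus pandas; the sentence pairs are made up for illustration):

import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical (input_text, target_text) pairs standing in for load_bert_data() output.
data = [("少先队员因该为老人让坐", "少先队员应该为老人让座"),
        ("他是一个很好的人", "他是一个很好的人")] * 10

train_data, dev_data = train_test_split(data, test_size=0.1, shuffle=True)
train_df = pd.DataFrame(train_data, columns=['input_text', 'target_text'])
dev_df = pd.DataFrame(dev_data, columns=['input_text', 'target_text'])
print(len(train_df), len(dev_df))  # 18 2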

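The main functional addition is count_matches, handed to model.train_model via the matches keyword. This follows the simpletransformers-style convention (which pycorrector's Seq2SeqModel appears to adopt) of treating extra keyword arguments to train_model as named evaluation metrics of the form f(labels, preds). A standalone sketch of what the metric computes, with the debug logging stripped and toy data added:

def count_matches(labels, preds):
    # Exact whole-string matches between references and predictions,
    # as in the function added by this commit (logging omitted).
    return sum(1 if label == pred else 0 for label, pred in zip(labels, preds))

labels = ["少先队员应该为老人让座", "他是一个好人"]
preds = ["少先队员应该为老人让座", "他是一个坏人"]
print(count_matches(labels, preds))  # 1: only the first pair matches exactly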