Commit
Fix error to load data at the correct position when resuming from a checkpoint
vince62s authored and Thai Chau Truong committed Mar 28, 2024
1 parent c570639 commit 64a246b
Showing 6 changed files with 158 additions and 8 deletions.
13 changes: 11 additions & 2 deletions onmt/inputters/dynamic_iterator.py
@@ -129,6 +129,7 @@ def __init__(
batch_type,
batch_size,
batch_size_multiple,
resume_corpora_info={},
data_type="text",
bucket_size=2048,
bucket_size_init=-1,
@@ -144,6 +145,7 @@ def __init__(
self.transforms = transforms
self.vocabs = vocabs
self.corpora_info = corpora_info
self.resume_corpora_info = resume_corpora_info
self.task = task
self.init_iterators = False
self.batch_size = batch_size
@@ -171,7 +173,8 @@ def __init__(

@classmethod
def from_opt(
cls, corpora, transforms, vocabs, opt, task, copy, device, stride=1, offset=0
cls, corpora, transforms, vocabs, opt, task, copy, device,
resume_corpora_info={}, stride=1, offset=0
):
"""Initilize `DynamicDatasetIter` with options parsed from `opt`."""
corpora_info = {}
@@ -206,6 +209,7 @@ def from_opt(
opt.batch_type,
batch_size,
batch_size_multiple,
resume_corpora_info=resume_corpora_info,
data_type=opt.data_type,
bucket_size=bucket_size,
bucket_size_init=bucket_size_init,
@@ -388,6 +392,7 @@ def build_dynamic_dataset_iter(
vocabs,
copy=False,
task=CorpusTask.TRAIN,
resume_corpora_info={},
stride=1,
offset=0,
src=None,
@@ -412,7 +417,10 @@ def build_dynamic_dataset_iter(
advance to avoid the GPU waiting during the refilling of the bucket.
"""
transforms = make_transforms(opt, transforms_cls, vocabs)
corpora = get_corpora(opt, task, src=src, tgt=tgt, align=align)
corpora = get_corpora(
opt, task, src=src, tgt=tgt, align=align,
resume_corpora_info=resume_corpora_info
)
if corpora is None:
assert task != CorpusTask.TRAIN, "only valid corpus is ignorable."
return None
@@ -442,6 +450,7 @@ def build_dynamic_dataset_iter(
vocabs,
opt,
task,
resume_corpora_info=resume_corpora_info,
copy=copy,
stride=stride,
offset=offset,
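For orientation, a minimal sketch of the mapping that build_dynamic_dataset_iter and DynamicDatasetIter.from_opt now accept through resume_corpora_info. The corpus ids and line numbers below are hypothetical; only the nested "cid_line_number" key is what the iterator code reads.

# Shape of resume_corpora_info as consumed above (values are made up).
# Keys are corpus ids from the `data:` section of the training config.
resume_corpora_info = {
    "corpus_a": {"cid_line_number": 125000},   # hypothetical corpus / last line read
    "corpus_b": {"cid_line_number": 37412},
}

# It is then threaded through the iterator construction, e.g.:
# build_dynamic_dataset_iter(opt, transforms_cls, vocabs, task=CorpusTask.TRAIN,
#                            resume_corpora_info=resume_corpora_info, ...)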
28 changes: 25 additions & 3 deletions onmt/inputters/text_corpus.py
@@ -99,7 +99,7 @@ class ParallelCorpus(object):
"""A parallel corpus file pair that can be loaded to iterate."""

def __init__(
self, name, src, tgt, align=None, n_src_feats=0, src_feats_defaults=None
self, name, src, tgt, align=None, n_src_feats=0, src_feats_defaults=None, resumed_line=0
):
"""Initialize src & tgt side file path."""
self.id = name
@@ -108,6 +108,12 @@ def __init__(
self.align = align
self.n_src_feats = n_src_feats
self.src_feats_defaults = src_feats_defaults
self.resumed_line = resumed_line
self.can_read_file = False

def activate_reading_mode(self, line_index):
if (line_index >= self.resumed_line):
self.can_read_file = True

def load(self, offset=0, stride=1):
"""
@@ -145,13 +151,19 @@ def make_ex(sline, tline, align):
for i, (sline, tline, align) in enumerate(
itertools.zip_longest(fs, ft, fa)
):
self.activate_reading_mode(line_index=i)
if not self.can_read_file:
continue
if (i // stride) % stride == offset:
yield make_ex(sline, tline, align)
else:
with exfile_open(self.src, mode="rb") as fs, exfile_open(
self.tgt, mode="rb"
) as ft, exfile_open(self.align, mode="rb") as fa:
for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)):
self.activate_reading_mode(line_index=i)
if not self.can_read_file:
continue
if (i // stride) % stride == offset:
if tline is not None:
tline = tline.decode("utf-8")
@@ -169,19 +181,28 @@ def __str__(self):
)


def get_corpora(opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None):
def get_corpora(
opts,
task=CorpusTask.TRAIN,
src=None, tgt=None, align=None,
resume_corpora_info={}
):
corpora_dict = {}
if task == CorpusTask.TRAIN:
for corpus_id, corpus_dict in opts.data.items():
if corpus_id != CorpusName.VALID:
if corpus_dict.get("path_txt", None) is None:
resume_line = 0
if (corpus_id in resume_corpora_info):
resume_line = resume_corpora_info[corpus_id]["cid_line_number"]
corpora_dict[corpus_id] = ParallelCorpus(
corpus_id,
corpus_dict["path_src"],
corpus_dict["path_tgt"],
corpus_dict["path_align"],
n_src_feats=opts.n_src_feats,
src_feats_defaults=opts.src_feats_defaults,
resumed_line=resume_line
)
else:
corpora_dict[corpus_id] = BlockwiseCorpus(
@@ -282,7 +303,8 @@ def __iter__(self):


def build_corpora_iters(
corpora, transforms, corpora_info, skip_empty_level="warning", stride=1, offset=0
corpora, transforms, corpora_info,
skip_empty_level="warning", stride=1, offset=0,
):
"""Return `ParallelCorpusIterator` for all corpora defined in opts."""
corpora_iters = dict()
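A self-contained sketch of the skip-until-resume behaviour that activate_reading_mode adds to ParallelCorpus.load (file handles, the resume line and the sharding parameters below are made up; the stride/offset filter mirrors the existing loop):

import itertools

def iter_resumed(src_lines, tgt_lines, resumed_line=0, stride=1, offset=0):
    """Yield (src, tgt) pairs, skipping every pair before `resumed_line`."""
    pairs = itertools.zip_longest(src_lines, tgt_lines)
    for i, (sline, tline) in enumerate(pairs):
        if i < resumed_line:                   # same effect as activate_reading_mode()
            continue
        if (i // stride) % stride == offset:   # unchanged sharding filter
            yield sline, tline

# e.g. resume a reader from line 1000 (hypothetical paths and values):
# examples = iter_resumed(open("train.src"), open("train.tgt"), resumed_line=1000)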
107 changes: 107 additions & 0 deletions onmt/models/model_saver.py
@@ -1,6 +1,7 @@
import os
import torch
import re
import subprocess
from collections import deque
from onmt.utils.logging import logger
from onmt.inputters.inputter import vocabs_to_dict
@@ -12,6 +13,7 @@ def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
save_model_path = os.path.abspath(opt.save_model)
os.makedirs(os.path.dirname(save_model_path), exist_ok=True)

corpora_info_updater = CorpusInfoUpdater(opts=opt)
model_saver = ModelSaver(
opt.save_model,
model,
@@ -21,6 +23,7 @@ def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
opt.keep_checkpoint,
opt.save_format,
device_id,
corpora_info_updater
)
return model_saver

@@ -81,6 +84,97 @@ def fix_key(s):
return checkpoint


def load_corpora_info(opts, checkpoint):
message_resume_from_beginning = \
"The training will resume from the beginning of each corpus."
# Check if resume_from_corpora is True
if not opts.resume_from_corpora:
logger.info(
"No resume from corpora is specified. " + \
message_resume_from_beginning
)
return {}

# Check if the corpus list from the last training
# and in the new training are identical.
checkpoint_corpora = checkpoint.get("data", None)
if (checkpoint_corpora is None):
logger.info(
"Incoherent info: Some corpora in the last training " + \
"and in the new list do not match. " + \
message_resume_from_beginning
)
return {}

checkpoint_corpus_names = [name for name in checkpoint_corpora]
new_corpus_names = [name for name in opts.data]
if (set(checkpoint_corpus_names) != set(new_corpus_names)):
logger.info(
"Incoherent info: Some corpora in the last training " + \
"and in the new list do not match. " + \
message_resume_from_beginning
)
return {}

# For each corpus, check if the last line number to resume
# is smaller than or equal to the number of text lines.
message_incoherent_line_number = "Incoherent info: Some text line numbers " + \
"to resume do not exist or are greater than the total numbers of text lines. " + \
message_resume_from_beginning
corpora_info = {}
for c_name, corpus in checkpoint_corpora.items():
new_corpora_info = {}
if ("cid_line_number" not in corpus):
logger.info(message_incoherent_line_number)
return {}

new_corpora_info["cid_line_number"] = corpus["cid_line_number"]
number_of_text_lines = int(
subprocess.getoutput(
"wc -l " + \
opts.data[c_name]["path_src"] + \
" | awk '{print $1}'"
)
)
if (new_corpora_info["cid_line_number"] > number_of_text_lines-1):
logger.info(message_incoherent_line_number)
return {}

corpora_info[c_name] = new_corpora_info

logger.info(
"The training will resume from the saved text line in each corpus."
)
return corpora_info


class CorpusInfoUpdater(object):
def __init__(
self,
opts=None
):
self.opts = opts

def update_corpus_info_from_batches(self, batches):
# Update the last text line of each corpus
new_corpus_info = {}
for batch in batches:
for c_name, cid_line_number in zip(batch["cid"], batch["cid_line_number"]):
if (c_name not in new_corpus_info):
new_corpus_info[c_name] = cid_line_number
else:
new_corpus_info[c_name] = max(
new_corpus_info[c_name],
cid_line_number
)
for c_name, corpus in self.opts.data.items():
if (c_name in new_corpus_info):
corpus["cid_line_number"] = new_corpus_info[c_name]

def get_corpus_info_dict(self):
return {"data": self.opts.data}


class ModelSaverBase(object):
"""Base class for model saving operations
@@ -99,6 +193,7 @@ def __init__(
keep_checkpoint=-1,
save_format="pytorch",
device_id=0,
corpora_info_updater=None
):
self.base_path = base_path
self.model = model
@@ -109,6 +204,7 @@ def __init__(
self.keep_checkpoint = keep_checkpoint
self.save_format = save_format
self.device_id = device_id
self.corpora_info_updater = corpora_info_updater

if keep_checkpoint > 0:
self.checkpoint_queue = deque([], maxlen=keep_checkpoint)
@@ -171,6 +267,15 @@ def _save(self, step, model):

raise NotImplementedError()

def update_corpora_info(self, batches):
if (self.corpora_info_updater is not None):
self.corpora_info_updater.update_corpus_info_from_batches(batches)

def get_corpora_info_to_save(self):
if (self.corpora_info_updater is not None):
return self.corpora_info_updater.get_corpus_info_dict()
return {}

def _rm_checkpoint(self, name):
"""Remove a checkpoint
@@ -267,6 +372,7 @@ def _save(self, step, model):
"opt": self.model_opt,
"optim": self.optim.state_dict(),
}
checkpoint = {**checkpoint, **self.get_corpora_info_to_save()}
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
ckpt_path = "%s_step_%d.pt" % (self.base_path, step)
@@ -356,6 +462,7 @@ def _st_save(self, step, model):
"opt": self.model_opt,
"optim": self.optim.state_dict(),
}
checkpoint = {**checkpoint, **self.get_corpora_info_to_save()}

if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
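load_corpora_info above only honours the resume request when the checkpoint carries a "data" entry, the corpus ids match the new configuration, and each saved cid_line_number still fits inside the corresponding source file (counted with wc -l). A small sketch of that last check in pure Python, with hypothetical paths and numbers:

def count_lines(path):
    # Pure-Python equivalent of the `wc -l <path> | awk '{print $1}'` call above.
    with open(path, "rb") as f:
        return sum(1 for _ in f)

saved_line = 37412                                   # e.g. checkpoint["data"][c_name]["cid_line_number"]
if saved_line > count_lines("data/corpus_a.src") - 1:
    resume_info = {}                                 # fall back: restart the corpus from the beginning
else:
    resume_info = {"corpus_a": {"cid_line_number": saved_line}}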
7 changes: 7 additions & 0 deletions onmt/opts.py
@@ -1263,6 +1263,13 @@ def _add_train_general_opts(parser):
help="If training from a checkpoint then this is the "
"path to the pretrained model's state_dict.",
)
group.add(
"--resume_from_corpora",
"-resume_from_corpora",
action="store_true",
help="If training from a checkpoint and this is set to True "
" then the data generator will resume from the last line of each corpora.",
)
group.add(
"--reset_optim",
"-reset_optim",
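As a usage note (config and checkpoint paths are placeholders), the new switch only takes effect together with -train_from, e.g.:

onmt_train -config my_config.yaml -train_from models/model_step_20000.pt -resume_from_corpora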
9 changes: 6 additions & 3 deletions onmt/train_single.py
@@ -17,7 +17,7 @@
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.inputters.text_corpus import save_transformed_sample
from onmt.model_builder import build_model
from onmt.models.model_saver import load_checkpoint
from onmt.models.model_saver import load_checkpoint, load_corpora_info
from onmt.utils.optimizers import Optimizer
from onmt.utils.misc import set_random_seed
from onmt.trainer import build_trainer
@@ -80,6 +80,7 @@ def _init_train(opt):
if opt.train_from:
# Load checkpoint if we resume from a previous training.
checkpoint = load_checkpoint(ckpt_path=opt.train_from)
resume_corpora_info = load_corpora_info(opt, checkpoint)
vocabs = dict_to_vocabs(checkpoint["vocab"])
if (
hasattr(checkpoint["opt"], "_all_transform")
@@ -105,8 +106,9 @@ def _init_train(opt):
else:
checkpoint = None
vocabs = prepare_transforms_vocabs(opt, transforms_cls)
resume_corpora_info = {}

return checkpoint, vocabs, transforms_cls
return checkpoint, resume_corpora_info, vocabs, transforms_cls


def configure_process(opt, device_id):
@@ -159,7 +161,7 @@ def main(opt, device_id):

configure_process(opt, device_id)
init_logger(opt.log_file)
checkpoint, vocabs, transforms_cls = _init_train(opt)
checkpoint, resume_corpora_info, vocabs, transforms_cls = _init_train(opt)
model_opt = _get_model_opts(opt, checkpoint=checkpoint)

# Build model.
@@ -211,6 +213,7 @@ def main(opt, device_id):
transforms_cls,
vocabs,
task=CorpusTask.TRAIN,
resume_corpora_info=resume_corpora_info,
copy=opt.copy_attn,
stride=stride,
offset=offset,
2 changes: 2 additions & 0 deletions onmt/trainer.py
@@ -320,6 +320,8 @@ def train(
batches, normalization, total_stats, report_stats
)

self.model_saver.update_corpora_info(batches)

if self.average_decay > 0 and i % self.average_every == 0:
self._update_average(step)

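A small sketch of what update_corpora_info tracks at each training step: the highest cid_line_number seen per corpus id in the current batches (the batch contents below are invented for illustration; the max logic mirrors CorpusInfoUpdater.update_corpus_info_from_batches):

batches = [
    {"cid": ["corpus_a", "corpus_a", "corpus_b"], "cid_line_number": [101, 257, 12]},
    {"cid": ["corpus_b", "corpus_a"], "cid_line_number": [15, 260]},
]

last_line = {}  # corpus id -> highest source line consumed so far
for batch in batches:
    for c_name, line_no in zip(batch["cid"], batch["cid_line_number"]):
        last_line[c_name] = max(last_line.get(c_name, 0), line_no)

# last_line == {"corpus_a": 260, "corpus_b": 15}; these values are written into
# opts.data[c_name]["cid_line_number"] and saved with the next checkpoint.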
