
Commit

revert batching mode (#2479)
* revert batching mode
vince62s authored Sep 26, 2023
1 parent e6a4412 · commit d61467e
Showing 2 changed files with 19 additions and 17 deletions.
34 changes: 18 additions & 16 deletions onmt/inputters/dynamic_iterator.py

@@ -279,27 +279,30 @@ def batch_size_fn(nbsents, maxlen):
             else:
                 raise ValueError(f"Invalid argument batch_type={batch_type}")
 
-        minibatch, maxlen, size_so_far, seen = [], 0, 0, []
+        def max_src_tgt(ex):
+            """return the max tokens btw src and tgt in the sequence."""
+            if ex["tgt"]:
+                return max(len(ex["src"]["src_ids"]), len(ex["tgt"]["tgt_ids"]))
+            return len(ex["src"]["src_ids"])
+
+        minibatch, maxlen, size_so_far, seen = [], 0, 0, set()
         for ex in data:
-            if (ex["src"]["src"] not in seen) or (self.task != CorpusTask.TRAIN):
-                seen.append(ex["src"]["src"])
+            src = ex["src"]["src"]
+            if src not in seen or (self.task != CorpusTask.TRAIN):
+                seen.add(src)
                 minibatch.append(ex)
                 nbsents = len(minibatch)
-                maxlen = max(text_sort_key(ex), maxlen)
+                maxlen = max(max_src_tgt(ex), maxlen)
                 size_so_far = batch_size_fn(nbsents, maxlen)
                 if size_so_far >= batch_size:
-                    overflowed = 0
-                    if size_so_far > batch_size:
-                        overflowed += 1
+                    overflowed = 1 if size_so_far > batch_size else 0
                     if batch_size_multiple > 1:
-                        overflowed += (
-                            len(minibatch) - overflowed
-                        ) % batch_size_multiple
+                        overflowed += (nbsents - overflowed) % batch_size_multiple
                     if overflowed == 0:
                         yield minibatch
-                        minibatch, maxlen, size_so_far, seen = [], 0, 0, []
+                        minibatch, maxlen, size_so_far, seen = [], 0, 0, set()
                     else:
-                        if overflowed == len(minibatch):
+                        if overflowed == nbsents:
                             logger.warning(
                                 "The batch will be filled until we reach"
                                 " %d, its size may exceed %d tokens"
@@ -308,10 +311,9 @@ def batch_size_fn(nbsents, maxlen):
                         else:
                             yield minibatch[:-overflowed]
                             minibatch = minibatch[-overflowed:]
-                            maxlen, size_so_far, seen = 0, 0, []
-                            for i, ex in enumerate(minibatch):
-                                maxlen = max(text_sort_key(ex), maxlen)
-                                size_so_far = batch_size_fn(i + 1, maxlen)
+                            maxlen = max([max_src_tgt(ex) for ex in minibatch])
+                            size_so_far = batch_size_fn(len(minibatch), maxlen)
+                            seen = set()
 
         if minibatch:
             yield minibatch
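Note: the batching control flow restored above is easier to follow outside of diff form. Below is a minimal, runnable sketch of the same token-budget batching with overflow trimming. The flat token-list examples, the standalone batch_iter signature, and the toy numbers are illustrative assumptions, not project code; the duplicate-source filtering (the seen set) and the warning for the all-overflow case are left out for brevity.

# Illustrative sketch of the token-budget batching shown in the diff above.
# Simplified: examples are plain token lists, no source deduplication, no logging.
def batch_iter(data, batch_size, batch_type="tokens", batch_size_multiple=1):
    """Yield minibatches whose cost (in sentences or tokens) stays near
    batch_size, pushing any overflow into the next batch."""

    def batch_size_fn(nbsents, maxlen):
        if batch_type == "sents":
            return nbsents
        elif batch_type == "tokens":
            return nbsents * maxlen
        else:
            raise ValueError(f"Invalid argument batch_type={batch_type}")

    minibatch, maxlen = [], 0
    for ex in data:
        minibatch.append(ex)
        maxlen = max(len(ex), maxlen)
        size_so_far = batch_size_fn(len(minibatch), maxlen)
        if size_so_far >= batch_size:
            # One example overflows if we went strictly past the budget; more may
            # be held back so the yielded batch length is a clean multiple.
            overflowed = 1 if size_so_far > batch_size else 0
            if batch_size_multiple > 1:
                overflowed += (len(minibatch) - overflowed) % batch_size_multiple
            if overflowed == 0:
                yield minibatch
                minibatch, maxlen = [], 0
            elif overflowed < len(minibatch):
                # Keep the overflowed tail for the next batch and recompute maxlen.
                yield minibatch[:-overflowed]
                minibatch = minibatch[-overflowed:]
                maxlen = max(len(ex) for ex in minibatch)
            # If everything overflowed, keep filling (the real code logs a warning).
    if minibatch:
        yield minibatch

# Toy usage: roughly 12 tokens per batch, batch lengths in multiples of 2.
examples = [["tok"] * n for n in (3, 4, 5, 2, 6, 1, 4)]
for batch in batch_iter(examples, batch_size=12, batch_size_multiple=2):
    print([len(ex) for ex in batch])  # [3, 4] / [5, 2] / [6, 1] / [4]
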
2 changes: 1 addition & 1 deletion onmt/inputters/text_utils.py

@@ -59,7 +59,7 @@ def append_features_to_text(text, features):
 def text_sort_key(ex):
     """Sort using the number of tokens in the sequence."""
     if ex["tgt"]:
-        return max(len(ex["src"]["src_ids"]), len(ex["tgt"]["tgt_ids"]))
+        return len(ex["src"]["src_ids"]), len(ex["tgt"]["tgt_ids"])
     return len(ex["src"]["src_ids"])


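Note: the text_sort_key change above replaces a scalar key (the longer of the two sides) with a (src_len, tgt_len) tuple, so examples are ordered by source length first and target length second. The snippet below only illustrates how the two keys order toy examples; the dict layout mirrors the ex structure from the diff and is not taken from the project.

# Toy examples shaped like the ex dicts above: (src_len, tgt_len) = (3, 9), (5, 2), (4, 4).
examples = [
    {"src": {"src_ids": [0] * 3}, "tgt": {"tgt_ids": [0] * 9}},
    {"src": {"src_ids": [0] * 5}, "tgt": {"tgt_ids": [0] * 2}},
    {"src": {"src_ids": [0] * 4}, "tgt": {"tgt_ids": [0] * 4}},
]

def key_tuple(ex):  # added line: order by src length, then tgt length
    return len(ex["src"]["src_ids"]), len(ex["tgt"]["tgt_ids"])

def key_max(ex):  # removed line: order by the longer of the two sides
    return max(len(ex["src"]["src_ids"]), len(ex["tgt"]["tgt_ids"]))

print([key_tuple(ex) for ex in sorted(examples, key=key_tuple)])  # [(3, 9), (4, 4), (5, 2)]
print([key_tuple(ex) for ex in sorted(examples, key=key_max)])    # [(4, 4), (5, 2), (3, 9)]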
