Skip to content

Commit

Permalink
Better last batch logic (#12827)
Browse files Browse the repository at this point in the history
* Better last batch logic
  • Loading branch information
twerkmeister committed Sep 21, 2023
1 parent 6bb02b7 commit 1fdc9a1
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 1 deletion.
1 change: 1 addition & 0 deletions changelog/12827.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improved handling of last batch during DIET and TED training. The last batch is discarded if it contains less than half a batch size of data.
7 changes: 6 additions & 1 deletion rasa/utils/tensorflow/data_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from typing import List, Union, Text, Optional, Any, Tuple, Dict, cast

import logging
Expand Down Expand Up @@ -380,7 +381,11 @@ def __len__(self) -> int:
# data was rebalanced, so need to recalculate number of examples
num_examples = self.model_data.number_of_examples(self._data)
batch_size = self._current_batch_size
return num_examples // batch_size + int(num_examples % batch_size > 0)
# keep last batch only if it has at least half a batch size of examples
last_batch_half_full = num_examples % batch_size >= math.ceil(batch_size / 2)
num_batches = num_examples // batch_size + int(last_batch_half_full)
# Return at least 1 if there is an example
return max(num_batches, int(num_examples > 0))

def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""Gets batch at position `index`.
Expand Down
74 changes: 74 additions & 0 deletions tests/nlu/classifiers/test_diet_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
PREDICTED_CONFIDENCE_KEY,
INTENT_NAME_KEY,
)
from rasa.utils import train_utils
from rasa.utils.tensorflow.constants import (
LOSS_TYPE,
RANDOM_SEED,
Expand Down Expand Up @@ -896,3 +897,76 @@ async def test_sparse_feature_sizes_decreased_incremental_training(
train_load_and_process_diet(
finetune_classifier, pipeline=pipeline, training_data=iter2_path
)


@pytest.mark.timeout(120, func_only=True)
@pytest.mark.parametrize(
"batch_size, expected_num_batches",
# the training dataset has 48 NLU examples
[
(1, 48),
(8, 6),
(15, 3),
(16, 3),
(18, 3),
(20, 2),
(32, 2),
(64, 1),
(128, 1),
(256, 1),
],
)
async def test_dropping_of_last_partial_batch(
batch_size: int,
expected_num_batches: int,
create_diet: Callable[..., DIETClassifier],
train_and_preprocess: Callable[..., Tuple[TrainingData, List[GraphComponent]]],
):
"""test that diets data processing produces the right amount of batches.
We introduced a change to only keep the last incomplete batch if
1. it has more than 50% of examples of batch size
2. or it is the only batch in the epoch
"""

pipeline = [
{"component": WhitespaceTokenizer},
{"component": CountVectorsFeaturizer},
]
diet = create_diet({ENTITY_RECOGNITION: False, RANDOM_SEED: 1, EPOCHS: 1})
# This data set has 48 NLU examples
training_data, loaded_pipeline = train_and_preprocess(
pipeline, training_data="data/test/demo-rasa-no-ents.yml"
)

model_data = diet.preprocess_train_data(training_data)
data_generator, _ = train_utils.create_data_generators(model_data, batch_size, 1)

assert len(data_generator) == expected_num_batches


@pytest.mark.timeout(120, func_only=True)
async def test_dropping_of_last_partial_batch_empty_data(
create_diet: Callable[..., DIETClassifier],
train_and_preprocess: Callable[..., Tuple[TrainingData, List[GraphComponent]]],
):
"""test that diets data processing produces the right amount of batches.
We introduced a change to only keep the last incomplete batch if
1. it has more than 50% of examples of batch size
2. or it is the only batch in the epoch
"""

pipeline = [
{"component": WhitespaceTokenizer},
{"component": CountVectorsFeaturizer},
]
diet = create_diet({ENTITY_RECOGNITION: False, RANDOM_SEED: 1, EPOCHS: 1})
training_data, loaded_pipeline = train_and_preprocess(
pipeline, training_data=TrainingData()
)

model_data = diet.preprocess_train_data(training_data)
data_generator, _ = train_utils.create_data_generators(model_data, 64, 1)

assert len(data_generator) == 0

0 comments on commit 1fdc9a1

Please sign in to comment.