From df48628069dec036bc2287b896ab5f8e57c2e3a8 Mon Sep 17 00:00:00 2001 From: Thomas Werkmeister Date: Tue, 19 Sep 2023 14:29:41 +0200 Subject: [PATCH] Better last batch logic --- rasa/utils/tensorflow/data_generator.py | 6 ++- tests/nlu/classifiers/test_diet_classifier.py | 37 +++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/rasa/utils/tensorflow/data_generator.py b/rasa/utils/tensorflow/data_generator.py index 9157ea7252ca..058e6d96cf24 100644 --- a/rasa/utils/tensorflow/data_generator.py +++ b/rasa/utils/tensorflow/data_generator.py @@ -380,7 +380,11 @@ def __len__(self) -> int: # data was rebalanced, so need to recalculate number of examples num_examples = self.model_data.number_of_examples(self._data) batch_size = self._current_batch_size - return num_examples // batch_size + int(num_examples % batch_size > 0) + # keep last batch only if it has more than half a batch size of examples + last_batch_half_full = num_examples % batch_size > batch_size // 2 + num_batches = num_examples // batch_size + int(last_batch_half_full) + # Return at least 1 if there is an example + return max(num_batches, int(num_examples > 0)) def __getitem__(self, index: int) -> Tuple[Any, Any]: """Gets batch at position `index`. 
diff --git a/tests/nlu/classifiers/test_diet_classifier.py b/tests/nlu/classifiers/test_diet_classifier.py index 52ae4336fb14..49c3691b467c 100644 --- a/tests/nlu/classifiers/test_diet_classifier.py +++ b/tests/nlu/classifiers/test_diet_classifier.py @@ -23,6 +23,7 @@ PREDICTED_CONFIDENCE_KEY, INTENT_NAME_KEY, ) +from rasa.utils import train_utils from rasa.utils.tensorflow.constants import ( LOSS_TYPE, RANDOM_SEED, @@ -966,3 +967,39 @@ async def test_no_bilou_when_entity_recognition_off( diet.train(training_data=training_data) assert all(msg.get(BILOU_ENTITIES) is None for msg in training_data.nlu_examples) + + +@pytest.mark.timeout(120, func_only=True) +@pytest.mark.parametrize( + "batch_size, expected_num_batches", + [(8, 6), (15, 3), (16, 3), (18, 3), (20, 2), (32, 1), (64, 1), (128, 1), (256, 1)], +) +async def test_dropping_of_last_partial_batch( + batch_size: int, + expected_num_batches: int, + create_diet: Callable[..., DIETClassifier], + train_and_preprocess: Callable[..., Tuple[TrainingData, List[GraphComponent]]], +): + """Test that DIET's data processing produces the right number of batches. + + We introduced a change to only keep the last incomplete batch if + 1. it has more than 50% of the examples of a full batch + 2. or it is the only batch in the epoch + """ + + pipeline = [ + {"component": WhitespaceTokenizer}, + {"component": CountVectorsFeaturizer}, + ] + diet = create_diet( + {ENTITY_RECOGNITION: False, RANDOM_SEED: 1, EPOCHS: 1, RUN_EAGERLY: True} + ) + # This data set has 48 NLU examples + training_data, loaded_pipeline = train_and_preprocess( + pipeline, training_data="data/test/demo-rasa-no-ents.yml" + ) + + model_data = diet.preprocess_train_data(training_data) + data_generator, _ = train_utils.create_data_generators(model_data, batch_size, 1) + + assert len(data_generator) == expected_num_batches