diff --git a/tests/nlu/classifiers/test_diet_classifier.py b/tests/nlu/classifiers/test_diet_classifier.py
index 49c3691b467c..2319a29eb23f 100644
--- a/tests/nlu/classifiers/test_diet_classifier.py
+++ b/tests/nlu/classifiers/test_diet_classifier.py
@@ -1003,3 +1003,32 @@ async def test_dropping_of_last_partial_batch(
 
     data_generator, _ = train_utils.create_data_generators(model_data, batch_size, 1)
     assert len(data_generator) == expected_num_batches
+
+
+@pytest.mark.timeout(120, func_only=True)
+async def test_dropping_of_last_partial_batch_empty_data(
+    create_diet: Callable[..., DIETClassifier],
+    train_and_preprocess: Callable[..., Tuple[TrainingData, List[GraphComponent]]],
+):
+    """Test that DIET's data processing produces the right number of batches.
+
+    We introduced a change to only keep the last incomplete batch if
+    1. it has more than 50% of examples of batch size
+    2. or it is the only batch in the epoch
+    """
+
+    pipeline = [
+        {"component": WhitespaceTokenizer},
+        {"component": CountVectorsFeaturizer},
+    ]
+    diet = create_diet(
+        {ENTITY_RECOGNITION: False, RANDOM_SEED: 1, EPOCHS: 1, RUN_EAGERLY: True}
+    )
+    training_data, loaded_pipeline = train_and_preprocess(
+        pipeline, training_data=TrainingData()
+    )
+
+    model_data = diet.preprocess_train_data(training_data)
+    data_generator, _ = train_utils.create_data_generators(model_data, 64, 1)
+
+    assert len(data_generator) == 0