diff --git a/tests/nlu/classifiers/test_diet_classifier.py b/tests/nlu/classifiers/test_diet_classifier.py index 2319a29eb23f..06b5d5481f20 100644 --- a/tests/nlu/classifiers/test_diet_classifier.py +++ b/tests/nlu/classifiers/test_diet_classifier.py @@ -972,7 +972,19 @@ async def test_no_bilou_when_entity_recognition_off( @pytest.mark.timeout(120, func_only=True) @pytest.mark.parametrize( "batch_size, expected_num_batches", - [(8, 6), (15, 3), (16, 3), (18, 3), (20, 2), (32, 1), (64, 1), (128, 1), (256, 1)], + # the training dataset has 48 NLU examples + [ + (1, 48), + (8, 6), + (15, 3), + (16, 3), + (18, 3), + (20, 2), + (32, 1), + (64, 1), + (128, 1), + (256, 1), + ], ) async def test_dropping_of_last_partial_batch( batch_size: int,