Skip to content

Commit

Permalink
Added test for empty data
Browse files Browse the repository at this point in the history
  • Loading branch information
twerkmeister committed Sep 19, 2023
1 parent df48628 commit 8fc686f
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions tests/nlu/classifiers/test_diet_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,3 +1003,32 @@ async def test_dropping_of_last_partial_batch(
data_generator, _ = train_utils.create_data_generators(model_data, batch_size, 1)

assert len(data_generator) == expected_num_batches


@pytest.mark.timeout(120, func_only=True)
async def test_dropping_of_last_partial_batch_empty_data(
create_diet: Callable[..., DIETClassifier],
train_and_preprocess: Callable[..., Tuple[TrainingData, List[GraphComponent]]],
):
"""test that diets data processing produces the right amount of batches.
We introduced a change to only keep the last incomplete batch if
1. it has more than 50% of examples of batch size
2. or it is the only batch in the epoch
"""

pipeline = [
{"component": WhitespaceTokenizer},
{"component": CountVectorsFeaturizer},
]
diet = create_diet(
{ENTITY_RECOGNITION: False, RANDOM_SEED: 1, EPOCHS: 1, RUN_EAGERLY: True}
)
training_data, loaded_pipeline = train_and_preprocess(
pipeline, training_data=TrainingData()
)

model_data = diet.preprocess_train_data(training_data)
data_generator, _ = train_utils.create_data_generators(model_data, 64, 1)

assert len(data_generator) == 0

0 comments on commit 8fc686f

Please sign in to comment.