Skip to content

Commit

Permalink
Better last batch logic
Browse files Browse the repository at this point in the history
  • Loading branch information
twerkmeister committed Sep 19, 2023
1 parent 1634e56 commit df48628
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
6 changes: 5 additions & 1 deletion rasa/utils/tensorflow/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,11 @@ def __len__(self) -> int:
# data was rebalanced, so need to recalculate number of examples
num_examples = self.model_data.number_of_examples(self._data)
batch_size = self._current_batch_size
return num_examples // batch_size + int(num_examples % batch_size > 0)
# keep last batch only if it has more than half a batch size of examples
last_batch_half_full = num_examples % batch_size > batch_size // 2
num_batches = num_examples // batch_size + int(last_batch_half_full)
# Return at least 1 if there is an example
return max(num_batches, int(num_examples > 0))

def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""Gets batch at position `index`.
Expand Down
37 changes: 37 additions & 0 deletions tests/nlu/classifiers/test_diet_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
PREDICTED_CONFIDENCE_KEY,
INTENT_NAME_KEY,
)
from rasa.utils import train_utils
from rasa.utils.tensorflow.constants import (
LOSS_TYPE,
RANDOM_SEED,
Expand Down Expand Up @@ -966,3 +967,39 @@ async def test_no_bilou_when_entity_recognition_off(
diet.train(training_data=training_data)

assert all(msg.get(BILOU_ENTITIES) is None for msg in training_data.nlu_examples)


@pytest.mark.timeout(120, func_only=True)
@pytest.mark.parametrize(
"batch_size, expected_num_batches",
[(8, 6), (15, 3), (16, 3), (18, 3), (20, 2), (32, 1), (64, 1), (128, 1), (256, 1)],
)
async def test_dropping_of_last_partial_batch(
batch_size: int,
expected_num_batches: int,
create_diet: Callable[..., DIETClassifier],
train_and_preprocess: Callable[..., Tuple[TrainingData, List[GraphComponent]]],
):
"""test that diets data processing produces the right amount of batches.
We introduced a change to only keep the last incomplete batch if
1. it has more than 50% of examples of batch size
2. or it is the only batch in the epoch
"""

pipeline = [
{"component": WhitespaceTokenizer},
{"component": CountVectorsFeaturizer},
]
diet = create_diet(
{ENTITY_RECOGNITION: False, RANDOM_SEED: 1, EPOCHS: 1, RUN_EAGERLY: True}
)
# This data set has 48 NLU examples
training_data, loaded_pipeline = train_and_preprocess(
pipeline, training_data="data/test/demo-rasa-no-ents.yml"
)

model_data = diet.preprocess_train_data(training_data)
data_generator, _ = train_utils.create_data_generators(model_data, batch_size, 1)

assert len(data_generator) == expected_num_batches

0 comments on commit df48628

Please sign in to comment.