Skip to content

Commit

Permalink
Merge pull request #5 from flairNLP/improving_data_preprocessing
Browse files Browse the repository at this point in the history
Improving the dataset preprocessing
  • Loading branch information
lukasgarbas authored Nov 30, 2024
2 parents 86676c5 + ca6c286 commit 1b416b3
Show file tree
Hide file tree
Showing 3 changed files with 329 additions and 319 deletions.
58 changes: 39 additions & 19 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import torch
from datasets import Dataset
from datasets import Dataset, load_dataset
from transformer_ranker.datacleaner import DatasetCleaner


Expand All @@ -19,33 +19,27 @@
("text classification", "SetFit/rte", 0.05),
]


@pytest.mark.parametrize("task_type,dataset_name,downsampling_ratio", test_datasets)
def test_datacleaner(task_type, dataset_name, downsampling_ratio):

preprocessor = DatasetCleaner(dataset_downsample=downsampling_ratio)
dataset = preprocessor.prepare_dataset(dataset_name)

# Test dataset preprocessing
assert isinstance(dataset, Dataset), f"Dataset '{dataset_name}' is not a valid Dataset object"
dataset = load_dataset(dataset_name, trust_remote_code=True)
datacleaner = DatasetCleaner(dataset_downsample=downsampling_ratio)
texts, labels, task_category = datacleaner.prepare_dataset(dataset)

assert preprocessor.task_type == task_type, (
f"Task type mismatch: expected '{task_type}', got '{preprocessor.task_type}'"
assert task_category == task_type, (
f"Task type mismatch: expected '{task_type}', got '{task_category}'"
f"in dataset '{dataset_name}'"
)

# Make sure text and label columns were found
assert preprocessor.text_column is not None, f"Text column not found in dataset {dataset_name}"
assert preprocessor.label_column is not None, f"Label column not found in dataset {dataset_name}"

# Test texts in the text column
sentences = preprocessor.prepare_sentences(dataset)
assert isinstance(sentences, list) and len(sentences) > 0, (
assert isinstance(texts, list) and len(texts) > 0, (
"Sentences/tokens list is empty in dataset %s", dataset_name
)

# Ensure the sentences are in the correct format (str for text-classification, List[str] for token-level)
# Ensure the sentences are in the correct format
# (str for text-classification, List[str] for token-level)
if task_type == "text classification":
for sentence in sentences:
for sentence in texts:
assert isinstance(sentence, str), (
f"Incorrect sentence type in dataset '{dataset_name}', all expected to be str "
f"but some sentences have different type ({type(sentence)})."
Expand All @@ -55,7 +49,7 @@ def test_datacleaner(task_type, dataset_name, downsampling_ratio):
assert sentence != "", f"Empty sentence found in dataset {dataset_name}"

elif task_type == "token classification":
for sentence in sentences:
for sentence in texts:
# For token classification, make sure there are no empty lists of tokens
assert len(sentence) >= 0, f"Empty token list found in dataset {dataset_name}"

Expand All @@ -69,6 +63,32 @@ def test_datacleaner(task_type, dataset_name, downsampling_ratio):
raise KeyError(msg)

# Test the label column in each dataset
labels = preprocessor.prepare_labels(dataset)
assert isinstance(labels, torch.Tensor) and labels.size(0) > 0, "Labels tensor is empty"
assert (labels >= 0).all(), f"Negative label found in dataset {dataset_name}"


def test_simple_dataset():
    """Exercise DatasetCleaner on a small in-memory dataset.

    Three configurations are checked in turn: downsampling to half of the
    rows, label encoding with row cleanup disabled, and label encoding with
    row cleanup enabled. Along the way the test verifies that the original
    dataset keeps its six rows and its label column.
    """
    expected_raw_labels = ["X", "Y", "Z", "X", "Y", "Z"]
    original_dataset = Dataset.from_dict(
        {
            "text": ["", "This is a complete sentence.", "b", "c", "d", "e"],
            "label": ["X", "Y", "Z", "X", "Y", "Z"],
            "something_else": list(range(6)),
        }
    )

    # Downsample to 50% without cleanup: three of the six rows remain.
    half_cleaner = DatasetCleaner(dataset_downsample=0.5, cleanup_rows=False)
    texts, labels, _ = half_cleaner.prepare_dataset(original_dataset)

    assert len(original_dataset) == 6
    assert len(texts) == 3  # after downsampling
    assert len(labels) == 3

    # No downsampling, no cleanup: every row is kept and the string labels
    # come back as integer ids.
    plain_cleaner = DatasetCleaner(cleanup_rows=False)
    texts, labels, _ = plain_cleaner.prepare_dataset(original_dataset)

    assert original_dataset["label"] == expected_raw_labels
    expected_ids = torch.tensor([0, 1, 2, 0, 1, 2])
    assert torch.equal(labels, expected_ids)

    # With cleanup enabled, one row is removed from the processed output;
    # the expected ids below correspond to the first row being dropped.
    strict_cleaner = DatasetCleaner(cleanup_rows=True)
    texts, labels, _ = strict_cleaner.prepare_dataset(original_dataset)

    assert original_dataset["label"] == expected_raw_labels
    expected_ids = torch.tensor([1, 2, 0, 1, 2])
    assert torch.equal(labels, expected_ids)
Loading

0 comments on commit 1b416b3

Please sign in to comment.