Reorder preprocessing steps
lukasgarbas committed Nov 29, 2024
1 parent 9e5651f commit ca6c286
Showing 1 changed file with 52 additions and 56 deletions.
108 changes: 52 additions & 56 deletions transformer_ranker/datacleaner.py
@@ -38,11 +38,11 @@ class DatasetCleaner:
def prepare_dataset(
self, dataset: Union[str, Dataset, DatasetDict]
) -> tuple[Union[list[str], list[list[str]]], torch.Tensor, TaskCategory]:
"""Prepare texts and labels, and assign a task category.
"""Prepare texts and labels, and assign task category.
Downsample the dataset, find text and label columns, create label map,
preprocess labels, pre-tokenize, clean rows, merge columns.
Returns: (processed texts, label tensor, task category)
Downsample dataset, find text and label columns, create label map,
preprocess labels, pre-tokenize, clean rows, merge text pair columns.
Returns: (processed texts, label tensor, task category)
"""

# Verify dataset type
@@ -55,43 +55,45 @@ def prepare_dataset(

# Find or set the text field
text_column = self.text_column if self.text_column \
else self._find_column("Text column", dataset)
else self._find_column(dataset, "text column")

# Find or set the label field
label_column = self.label_column if self.label_column \
else self._find_column("Label column", dataset)
else self._find_column(dataset, "label column")

# Find or set the task_category
task_category = self.task_category if self.task_category \
else self._find_task_category(label_column, dataset)

# Set or create a label map for classification
label_map, dataset = (self.label_map, dataset) if self.label_map \
else self._create_label_map(label_column, dataset)
else self._find_task_category(dataset, label_column)

# Combine text pair columns with a separator token
if self.text_pair_column:
dataset = self._merge_text_pairs(text_column, self.text_pair_column, dataset)
dataset = self._merge_text_pairs(dataset, text_column, self.text_pair_column)
text_column = f"{text_column}+{self.text_pair_column}"

# Remove unused columns
dataset = dataset.select_columns([text_column, label_column])

# Downsample to a given ratio
if self.dataset_downsample:
dataset = self._downsample(self.dataset_downsample, dataset)
dataset = self._downsample(dataset, self.dataset_downsample)

# Clean noisy or empty rows
if self.cleanup_rows:
dataset = self._cleanup_rows(text_column, label_column, dataset)
dataset = self._cleanup_rows(dataset, text_column, label_column)

# Optional pre-tokenization
if self.tokenize and isinstance(dataset[text_column][0], str):
dataset = self._whitespace_tokenize(text_column, dataset)
dataset = self._whitespace_tokenize(dataset, text_column)

# Set or create a label map for classification
label_map = self.label_map
if task_category in (TaskCategory.TOKEN_CLASSIFICATION, TaskCategory.TEXT_CLASSIFICATION):
dataset, label_map = (dataset, label_map) if label_map \
else self._create_label_map(dataset, label_column)

# Handle BIO encoding for token classification
if self.remove_bio_encoding and task_category == TaskCategory.TOKEN_CLASSIFICATION:
dataset, label_map = self._remove_bio_encoding(dataset, label_column, label_map)
# Remove BIO encoding for token classification
if task_category == TaskCategory.TOKEN_CLASSIFICATION and self.remove_bio_encoding:
dataset, label_map = self._remove_bio_encoding(dataset, label_column, label_map)

# Prepare all texts
texts = dataset[text_column]
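
The main behavioral change in this hunk: label-map creation now runs after downsampling, row cleanup, and pre-tokenization, and is guarded so it only applies to classification task categories. A minimal sketch of the string-to-integer mapping that _create_label_map performs, using a toy datasets.Dataset (the helper names below are illustrative, not the project's API):

from datasets import Dataset

toy = Dataset.from_dict({"text": ["good", "bad"], "label": ["positive", "negative"]})
label_names = sorted(set(toy["label"]))
label_map = {name: idx for idx, name in enumerate(label_names)}
toy = toy.map(lambda row: {"label": label_map[row["label"]]})
print(label_map)     # {'negative': 0, 'positive': 1}
print(toy["label"])  # [1, 0]
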
@@ -102,7 +104,7 @@ def prepare_dataset(
labels = [word_label for labels in dataset[label_column] for word_label in labels]
labels = torch.tensor(labels)

# Log some preprocessed dataset info
# Log dataset info
self._log_dataset_info(
text_column, label_column, label_map, task_category,
self.dataset_downsample, dataset_size=len(dataset)
@@ -111,14 +113,14 @@
return texts, labels, task_category
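
For token-level tasks, the per-sentence label lists are flattened into one 1-D tensor before being returned (the list comprehension above). A small runnable sketch of that flattening, assuming integer labels:

import torch

sentence_labels = [[0, 1, 0], [2, 0]]  # one list of word labels per sentence
flat = [word_label for labels in sentence_labels for word_label in labels]
print(torch.tensor(flat))  # tensor([0, 1, 0, 2, 0])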

@staticmethod
def _find_column(column_role: str, dataset: Dataset) -> str:
def _find_column(dataset: Dataset, column_role: str) -> str:
"""Find text and label columns using common keywords."""
common_names: dict = {
'Text column': [
'text column': [
"text", "sentence", "token", "tweet", "document", "paragraph", "description",
"comment", "utterance", "question", "story", "context", "passage",
],
"Label column": [
"label column": [
"label", "ner_tag", "named_entities", "entities", "tag", "target", "category",
"class", "sentiment", "polarity", "emotion", "rating", "stance",
]
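
The keyword lists above drive a simple substring match over the dataset's column names. A sketch of how such a lookup can work; the helper below is illustrative, since the matching code itself sits outside the visible hunk:

from typing import Optional

def find_column_sketch(column_names: list[str], keywords: list[str]) -> Optional[str]:
    # Return the first column whose name contains one of the keywords
    for column in column_names:
        if any(keyword in column.lower() for keyword in keywords):
            return column
    return None

print(find_column_sketch(["idx", "sentence", "ner_tags"], ["text", "sentence", "token"]))
# sentence
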
@@ -138,8 +140,8 @@ def _find_column(column_role: str, dataset: Dataset) -> str:
return found_column

@staticmethod
def _merge_text_pairs(text_column: str, text_pair_column: str, dataset: Dataset) -> Dataset:
"""Concatenate text pairs into a single string using separator token"""
def _merge_text_pairs(dataset: Dataset, text_column: str, text_pair_column: str) -> Dataset:
"""Concatenate text pairs into single column using sep token"""
if text_pair_column not in dataset.column_names:
raise ValueError(
f"Text pair column name '{text_pair_column}' can not be found in the dataset. "
@@ -158,13 +160,13 @@ def _merge_text_pairs(text_column: str, text_pair_column: str, dataset: Dataset) -> Dataset:
return dataset

@staticmethod
def _find_task_category(label_column: str, dataset: Dataset) -> TaskCategory:
"""Determine task category based on the label column's data type."""
def _find_task_category(dataset: Dataset, label_column: str) -> TaskCategory:
"""Assign task category based on label type."""
label_to_task_category = {
int: TaskCategory.TEXT_CLASSIFICATION, # text classification labels can be integers
str: TaskCategory.TEXT_CLASSIFICATION, # or strings e.g. "positive"
list: TaskCategory.TOKEN_CLASSIFICATION, # token-level tasks have a list of labels
float: TaskCategory.TEXT_REGRESSION, # regression tasks have floats
int: TaskCategory.TEXT_CLASSIFICATION,
str: TaskCategory.TEXT_CLASSIFICATION,
list: TaskCategory.TOKEN_CLASSIFICATION,
float: TaskCategory.TEXT_REGRESSION,
}

label_types = list(set(type(label) for label in dataset[label_column]))
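
The mapping above turns task detection into a type probe: collect the Python types found in the label column and look the single type up. A pure-Python sketch of the same idea (the task names are illustrative stand-ins for the TaskCategory enum):

def infer_task_sketch(labels):
    type_to_task = {int: "text classification", str: "text classification",
                    list: "token classification", float: "text regression"}
    label_types = set(type(label) for label in labels)
    if len(label_types) != 1:
        raise ValueError(f"Mixed label types: {label_types}")
    return type_to_task[label_types.pop()]

print(infer_task_sketch([0.5, 1.0, 0.0]))  # text regression
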
@@ -184,31 +186,34 @@ def _find_task_category(label_column: str, dataset: Dataset) -> TaskCategory:
f"Supported label types for are {list(label_to_task_category.keys())}."
)

@staticmethod
def _downsample(dataset: Dataset, ratio: float) -> Dataset:
"""Reduce dataset size to given ratio."""
return dataset.shuffle(seed=42).select(range(int(len(dataset) * ratio)))

@staticmethod
def _cleanup_rows(
text_column: str, label_column: str, dataset: Dataset
dataset: Dataset, text_column: str, label_column: str
) -> Dataset:
"""Filter out entries with empty or noisy texts and labels."""
def is_valid_entry(dataset_row) -> bool:
text, label = dataset_row[text_column], dataset_row[label_column]

# Remove empty entries
if not text or label is None:
return False

if not isinstance(text, list):
text = [text]

# Remove sentences with characters unsupported by most tokenizers
bad_characters = ["\uFE0F"] # emoji variation symbol '\uFE0F'
if any(char in t for t in text for char in bad_characters):
return False

if not isinstance(label, list):
label = [label]

# Remove negative labels from classification datasets
if any(isinstance(word_label, int) and word_label < 0 for word_label in label):
# Remove "-1" labels sometimes used for unlabeled text
if any(word_label == -1 for word_label in label):
return False

return True
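
Row cleanup boils down to one Dataset.filter call with a predicate like is_valid_entry. A compact, runnable sketch that drops empty texts and the "-1" placeholder labels:

from datasets import Dataset

noisy = Dataset.from_dict({"text": ["ok", "", "unlabeled"], "label": [1, 0, -1]})

def is_valid(row) -> bool:
    return bool(row["text"]) and row["label"] != -1

clean = noisy.filter(is_valid)
print(clean["text"])  # ['ok']
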
@@ -217,8 +222,8 @@ def is_valid_entry(dataset_row) -> bool:
return dataset

@staticmethod
def _map_string_labels_to_integers(label_column, dataset) -> tuple[Dataset, dict[str, int]]:
"""Converts string labels to integers and store the label map"""
def _map_string_labels_to_integers(dataset, label_column) -> tuple[Dataset, dict[str, int]]:
"""Convert string labels to integers and retain label map."""
label_names = sorted(set(dataset[label_column]))
label_map = {label_name: idx for idx, label_name in enumerate(label_names)}

@@ -235,20 +240,17 @@ def label_to_id(label):
return dataset, label_map

@staticmethod
def _create_label_map(label_column: str, dataset: Dataset) -> tuple[dict[str, int], Dataset]:
"""Find feature names to create a label map in a hf dataset.
Convert label column to integers if needed."""
def _create_label_map(dataset: Dataset, label_column: str) -> tuple[Dataset, dict[str, int]]:
"""Find feature names to create label map, convert labels to integers if needed."""
label_names = getattr(
getattr(dataset.features[label_column], "feature", None), "names", None
) or getattr(dataset.features[label_column], "names", None)

if not label_names:
label_names = sorted(
{
label_names = sorted(set(
label for sublist in dataset[label_column]
for label in (sublist if isinstance(sublist, list) else [sublist])
}
)
))
label_names = [str(label) for label in label_names]

label_map = {label: idx for idx, label in enumerate(label_names)}
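
The getattr chain covers both label layouts in Hugging Face datasets: token labels are typically Sequence(ClassLabel), whose names live on .feature.names, while text labels are a plain ClassLabel with names on .names. A sketch of reading the names for the plain case:

from datasets import ClassLabel, Dataset, Features, Value

features = Features({"text": Value("string"),
                     "label": ClassLabel(names=["negative", "positive"])})
ds = Dataset.from_dict({"text": ["good"], "label": [1]}, features=features)

feature = ds.features["label"]
names = getattr(getattr(feature, "feature", None), "names", None) \
    or getattr(feature, "names", None)
print(names)  # ['negative', 'positive']
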
@@ -264,23 +266,17 @@ def _create_label_map(label_column: str, dataset: Dataset) -> tuple[dict[str, int], Dataset]:
},
desc="Converting string labels to integers"
)
return label_map, dataset

@staticmethod
def _downsample(ratio: float, dataset: Dataset) -> Dataset:
"""Reduce the dataset to a chosen ratio."""
return dataset.shuffle(seed=42).select(range(int(len(dataset) * ratio)))
return dataset, label_map

@staticmethod
def _remove_bio_encoding(
dataset: Dataset, label_column: str, label_map: dict[str, int]
) -> tuple[Dataset, dict[str, int]]:
"""Remove BIO prefixes for NER labels and create a new label map."""
"""Remove BIO prefixes for ner labels and create new label map."""
labels = [label.split("-")[-1] for label in label_map.keys()]
unique_labels = list(dict.fromkeys(labels))
new_label_map = {label: idx for idx, label in enumerate(unique_labels)}

# Check if label map was changed
if label_map == new_label_map:
logger.warning(
"Could not remove BIO encoding. Pass your own label map "
@@ -301,7 +297,7 @@ def _remove_bio_encoding(
return dataset, new_label_map

@staticmethod
def _whitespace_tokenize(text_column: str, dataset: Dataset) -> Dataset:
def _whitespace_tokenize(dataset: Dataset, text_column: str) -> Dataset:
"""Tokenize using Whitespace"""
tokenizer = Whitespace()

@@ -318,18 +314,18 @@ def _log_dataset_info(
text_column, label_column, label_map, task_category, downsample_ratio, dataset_size
) -> None:
"""Log information about preprocessed dataset"""
# Basic dataset configuration
# Some details about dataset
logger.info(
f"Dataset Info - Text Column: {text_column}, Label Column: {label_column}, "
f"Task Category: {task_category}, Dataset Size: {dataset_size} texts"
)

# Show the down-sampled size
# Show dataset size
if downsample_ratio and downsample_ratio < 1.0:
logger.info(
f"Dataset has been downsampled to {int(downsample_ratio * 100)}% of original size."
)

# Log the label map
# And the label map
if task_category != TaskCategory.TEXT_REGRESSION:
logger.info(f"Label Map: {label_map}")
