Implement generic data loaders for csv/tsv/jsonl files #244

Merged · 1 commit · Oct 28, 2023
106 changes: 106 additions & 0 deletions llmebench/datasets/CSV.py
@@ -0,0 +1,106 @@
import csv

from pathlib import Path

from llmebench.datasets.dataset_base import DatasetBase


class CSVDataset(DatasetBase):
    """
    Generic CSV dataset loader

    This data loader provides a way to load local csv/tsv datasets from disk. Assets
    using this loader *must* provide a `custom_test_split`, which can be a relative
    path which will be resolved relative to `data_dir`, or an absolute path. Similarly,
    `custom_train_split` must also be provided for few shot assets.

    Attributes
    ----------
    data_dir : str
        Base path of data containing all datasets. Defaults to "data" in the current
        working directory.
    column_mapping : dict
        Mapping defining which of the columns in the loaded csv are "input" and "label".
        The supplied dict must contain mappings for "input" and "label", and may contain
        other mappings (such as "input_id"). Column mappings can be `int`'s, which are
        used directly as column indices, or `str`'s, which are looked up in the header
        row to find the corresponding column indices.
    has_header : bool, defaults to True
        Whether the file has a header. If column_mapping specifies column names as
        `str`, this must be True.
    delimiter : str, defaults to ','
        Delimiter for the csv reader
    encoding : str, defaults to 'utf-8'
        Encoding to use when opening the file
    """

    def __init__(
        self, column_mapping, has_header=True, delimiter=",", encoding="utf-8", **kwargs
    ):
        # Check for required keys in column_mapping
        assert "input" in column_mapping
        assert "label" in column_mapping
        self.column_mapping = column_mapping

        self.has_header = has_header
        self.delimiter = delimiter
        self.encoding = encoding

        super(CSVDataset, self).__init__(**kwargs)

    @staticmethod
    def metadata():
        return {"generic": True}

    @staticmethod
    def get_data_sample():
        return {"input": "Test Input", "label": "0"}

    def load_data(self, data_split, no_labels=False):
        if not isinstance(data_split, Path):
            data_split = Path(data_split)

        if not data_split.is_absolute():
            data_split = f":data_dir:{data_split}"

        data_path = self.resolve_path(data_split)

        data = []

        with open(data_path, "r", encoding=self.encoding) as csvfile:
            csv_reader = csv.reader(csvfile, delimiter=self.delimiter)

            header = None
            if self.has_header:
                header = next(csv_reader)

            column_index_mapping = {}
            for sample_key, column_ref in self.column_mapping.items():
                if isinstance(column_ref, int):
                    column_index_mapping[sample_key] = column_ref
                elif isinstance(column_ref, str):
                    assert (
                        header is not None
                    ), "CSV Loader: file must have header if column_mapping uses `str` values"
                    # `list.index` raises ValueError on a missing column, so check
                    # membership first to produce a meaningful error message
                    assert (
                        column_ref in header
                    ), f"CSV Loader: {column_ref} not found in data"
                    column_index_mapping[sample_key] = header.index(column_ref)
                else:
                    raise Exception(
                        "CSV Loader: column_mapping must use `int` or `str` values"
                    )

            for row in csv_reader:
                processed_sample = {}
                for sample_key, column_idx in column_index_mapping.items():
                    processed_sample[sample_key] = row[column_idx]
                data.append(processed_sample)

        return data

    @classmethod
    def download_dataset(cls, data_dir, download_url=None, default_url=None):
        # Generic dataset loaders do not refer to a specific dataset to download
        pass
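
As a usage sketch (not part of this diff): the loader is pointed at a local file via a split path that resolves against `data_dir`. The file name and column names below are hypothetical, and passing `data_dir` as a keyword assumes `DatasetBase` accepts it as described in the docstring.

    from llmebench.datasets import CSVDataset

    # Hypothetical file data/sentiment/test.csv with header: id,text,sentiment
    dataset = CSVDataset(
        column_mapping={"input": "text", "label": "sentiment", "input_id": "id"},
        data_dir="data",  # assumed to be consumed by DatasetBase, per the docstring
    )

    # Relative paths are resolved against data_dir via resolve_path()
    samples = dataset.load_data("sentiment/test.csv")
    # -> [{"input": "...", "label": "...", "input_id": "..."}, ...]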
70 changes: 70 additions & 0 deletions llmebench/datasets/JSONL.py
@@ -0,0 +1,70 @@
import json

from pathlib import Path

from llmebench.datasets.dataset_base import DatasetBase


class JSONLDataset(DatasetBase):
    """
    Generic jsonl dataset loader

    This data loader provides a way to load local jsonl datasets from disk. Each line
    of the jsonl file must be a valid json object.
    Assets using this loader *must* provide a `custom_test_split`, which can be a
    relative path which will be resolved relative to `data_dir`, or an absolute path.
    Similarly, `custom_train_split` must also be provided for few shot assets.

    Attributes
    ----------
    data_dir : str
        Base path of data containing all datasets. Defaults to "data" in the current
        working directory.
    column_mapping : dict
        Mapping defining which of the keys in the loaded json are "input" and "label".
        The supplied dict must contain mappings for "input" and "label", and may contain
        other mappings (such as "input_id").
    """

    def __init__(self, column_mapping, **kwargs):
        # Check for required keys in column_mapping
        assert "input" in column_mapping
        assert "label" in column_mapping
        self.column_mapping = column_mapping

        super(JSONLDataset, self).__init__(**kwargs)

    @staticmethod
    def metadata():
        return {"generic": True}

    @staticmethod
    def get_data_sample():
        return {"input": "Test Input", "label": "0"}

    def load_data(self, data_split, no_labels=False):
        if not isinstance(data_split, Path):
            data_split = Path(data_split)

        if not data_split.is_absolute():
            data_split = f":data_dir:{data_split}"

        data_path = self.resolve_path(data_split)

        data = []

        with open(data_path, "r") as jsonl_file:
            for line in jsonl_file:
                sample = json.loads(line)

                processed_sample = {}
                for sample_key, column_key in self.column_mapping.items():
                    processed_sample[sample_key] = sample[column_key]
                data.append(processed_sample)

        return data

    @classmethod
    def download_dataset(cls, data_dir, download_url=None, default_url=None):
        # Generic dataset loaders do not refer to a specific dataset to download
        pass
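
A similar sketch for the jsonl loader, again with a hypothetical file and key names. Each line is parsed independently with `json.loads`, so every key referenced in `column_mapping` must be present in every object.

    from llmebench.datasets import JSONLDataset

    # Hypothetical file data/qa/test.jsonl, one object per line:
    #   {"question": "...", "answer": "...", "qid": "..."}
    dataset = JSONLDataset(
        column_mapping={"input": "question", "label": "answer", "input_id": "qid"},
    )
    samples = dataset.load_data("qa/test.jsonl")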
34 changes: 34 additions & 0 deletions llmebench/datasets/TSV.py
@@ -0,0 +1,34 @@
from llmebench.datasets.CSV import CSVDataset


class TSVDataset(CSVDataset):
    """
    Generic TSV dataset loader

    This data loader provides a way to load local tsv datasets from disk. Assets
    using this loader *must* provide a `custom_test_split`, which can be a relative
    path which will be resolved relative to `data_dir`, or an absolute path. Similarly,
    `custom_train_split` must also be provided for few shot assets.

    Attributes
    ----------
    data_dir : str
        Base path of data containing all datasets. Defaults to "data" in the current
        working directory.
    column_mapping : dict
        Mapping defining which of the columns in the loaded tsv are "input" and "label".
        The supplied dict must contain mappings for "input" and "label", and may contain
        other mappings (such as "input_id"). Column mappings can be `int`'s, which are
        used directly as column indices, or `str`'s, which are looked up in the header
        row to find the corresponding column indices.
    has_header : bool, defaults to True
        Whether the file has a header. If column_mapping specifies column names as
        `str`, this must be True.
    encoding : str, defaults to 'utf-8'
        Encoding to use when opening the file
    """

    def __init__(self, **kwargs):
        kwargs["delimiter"] = "\t"

        super(TSVDataset, self).__init__(**kwargs)
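
Since `TSVDataset` only pins the delimiter, all remaining `CSVDataset` options pass through unchanged. A sketch using positional (`int`) column references with no header row, again with hypothetical file contents:

    from llmebench.datasets import TSVDataset

    # Hypothetical headerless file data/ner/test.tsv:
    # column 0 holds the input text, column 1 the label
    dataset = TSVDataset(
        column_mapping={"input": 0, "label": 1},
        has_header=False,
    )
    samples = dataset.load_data("ner/test.tsv")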
3 changes: 3 additions & 0 deletions llmebench/datasets/__init__.py
@@ -15,13 +15,15 @@
from .BanglaSentiment import BanglaSentimentDataset
from .BibleMaghrebiDiacritization import BibleMaghrebiDiacritizationDataset
from .COVID19Factuality import COVID19FactualityDataset
from .CSV import CSVDataset
from .CT22Attentionworthy import CT22AttentionworthyDataset
from .CT22Checkworthiness import CT22CheckworthinessDataset
from .CT22Claim import CT22ClaimDataset
from .CT22Harmful import CT22HarmfulDataset
from .CT23Subjectivity import CT23SubjectivityDataset
from .Emotion import EmotionDataset
from .HuggingFace import HuggingFaceDataset
from .JSONL import JSONLDataset
from .Location import LocationDataset
from .MGBWords import MGBWordsDataset
from .MLQA import MLQADataset
@@ -40,6 +42,7 @@
from .SemEval23T3Propaganda import SemEval23T3PropagandaDataset
from .Spam import SpamDataset
from .STSQ2Q import STSQ2QDataset
from .TSV import TSVDataset
from .TyDiQA import TyDiQADataset
from .UnifiedFCFactuality import UnifiedFCFactualityDataset
from .UnifiedFCStance import UnifiedFCStanceDataset