diff --git a/llmebench/datasets/dataset_base.py b/llmebench/datasets/dataset_base.py index 1248ec58..0e88e0d6 100644 --- a/llmebench/datasets/dataset_base.py +++ b/llmebench/datasets/dataset_base.py @@ -10,6 +10,42 @@ class DatasetBase(ABC): + """ + Base class for datasets + + Implementations of this class need to implement at least three mandatory methods: + `metadata()`, `get_data_sample()` and `load_data()`. The purpose of objects of + this class is to encapsulate all the subtleties and information for a specific + dataset, and provide a consistent way for the framework to read the dataset. + + Attributes + ---------- + None + + Methods + ------- + metadata(): + Returns metadata for the dataset + + get_data_sample(): + Returns one sample of data. Useful to see the structure of loaded data + + load_data(data_path="", no_labels=False): + Loads data from the given path and returns a list of data samples + + prepare_fewshots(target_data=[], train_data=[], n_shots=1, deduplicate=True): + Returns a generator that provides few-shot samples for every test sample + + Notes + ----- + - Consider overriding `_deduplicate_train_test` to replace the default "input_id" + based de-duplication between train/test + - If the data is not JSON serializable, `_stringify_sample`/`_destringify_sample` + must be re-implemented to provide serialization/deserialization of samples. This is + primarily used for some few-shot sampling methods. 
+ + """ + def __init__(self, **kwargs): pass @@ -17,40 +53,93 @@ def __init__(self, **kwargs): @abstractmethod def metadata(): """ - Must return a dictionary with the following keys: - "citation": str + Returns the dataset's metadata + + Arguments + --------- + None + + Returns + ------- + metadata : dict + The returned dictionary _must_ have the following keys: + "citation" : str bib-formatted citation for the dataset - "language": str|list + "language" : str|list Can be one of: "multilingual" - ["ar", "fr", "en"] # List of supported langauges + ["ar", "fr", "en"] # List of supported languages "ar" # Single supported language Languages should be identified by their IETF language tags - "download_url": str (optional) - URL to data to automatically download if not present + The returned dictionary _can_ have the following additional keys: + "download_url" : str (optional) + URL to data (for automatic downloads) """ pass @abstractmethod def get_data_sample(self): + """ + Returns a single data sample. + + This function is useful to understand the structure of the underlying + data. All loaded samples _must_ match this sample. + + Arguments + --------- + None + + Returns + ------- + sample : dict + _Must_ contain at least two keys "input" and "label". + "input_id" can be specified to help with de-duplication + between train/dev/test data. Can include additional keys. + """ pass @abstractmethod def load_data(self, data_path, no_labels=False): """ - Returns a list of dictionaries, - with at least the following keys: - "input": - "label":