Add documentation for base classes and models #213

Merged 7 commits on Sep 11, 2023
189 changes: 168 additions & 21 deletions llmebench/datasets/dataset_base.py
@@ -10,47 +10,136 @@


class DatasetBase(ABC):
"""
Base class for datasets

Implementations of this class need to implement at least three mandatory methods:
`metadata()`, `get_data_sample()` and `load_data()`. The purpose of objects of
this class is to encapsulate all the subtleties and information for a specific
dataset, and provide a consistent way for the framework to read the dataset.

Attributes
----------
None

Methods
-------
metadata():
Returns metadata for the dataset

get_data_sample():
Returns one sample of data. Useful to see the structure of loaded data

load_data(data_path="", no_labels=False):
Loads data from the given path and returns a list of data samples

prepare_fewshots(target_data=[], train_data=[], n_shots=1, deduplicate=True):
Returns a generator that provides few shot samples for every test sample

Notes
-----
- Consider overriding `_deduplicate_train_test` to replace the default "input_id"
based de-duplication between train/test
- If the data is not JSON serializable, `_stringify_sample`/`_destringify_sample`
must be re-implemented to provide serialization/deserialization of samples. This is
primarily used for some fewshot sampling methods.

"""

def __init__(self, **kwargs):
pass

@staticmethod
@abstractmethod
def metadata():
"""
Must return a dictionary with the following keys:
"citation": str
Returns the dataset's metadata

Arguments
---------
None

Returns
-------
metadata : dict
The returned dictionary _must_ have the following keys:
"citation" : str
bib-formatted citation for the dataset
"language": str|list
"language" : str|list
Can be one of:
"multilingual"
["ar", "fr", "en"] # List of supported langauges
["ar", "fr", "en"] # List of supported languages
"ar" # Single supported language
Languages should be identified by their IETF language tags
"download_url": str (optional)
URL to data to automatically download if not present
The returned dictionary _can_ have the following additional keys:
"download_url" : str (optional)
URL to data (for automatic downloads)
"""
pass
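
For illustration, a subclass's `metadata()` might look like the following
sketch (the citation, languages, and URL are hypothetical):

    @staticmethod
    def metadata():
        return {
            # bib-formatted citation (hypothetical entry)
            "citation": "@article{sample2023, title={A Sample Dataset}, year={2023}}",
            # IETF language tags; "multilingual" or a single tag also work
            "language": ["ar", "en"],
            # optional: enables automatic download when the data is missing
            "download_url": "https://example.com/sample_dataset.zip",
        }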

@abstractmethod
def get_data_sample(self):
"""
Returns a single data sample.

This function is useful to understand the structure of the underlying
data. All loaded samples _must_ match this sample.

Arguments
---------
None

Returns
-------
sample : dict
_Must_ contain at least two keys "input" and "label".
"input_id" can be specified to help with de-duplication
between train/dev/test data. Can include additional keys.
"""
pass
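
As a sketch, a sentiment analysis dataset could implement this as follows
(the extra "input_id" key is optional but helps de-duplication):

    def get_data_sample(self):
        return {
            "input_id": 0,                    # optional
            "input": "An example sentence.",  # mandatory
            "label": "positive",              # mandatory
        }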

@abstractmethod
def load_data(self, data_path, no_labels=False):
"""
Returns a list of dictionaries,
with at least the following keys:
"input": <input-instance>
"label": <label>
The dictionaries can contain other keys as well
which will be saved in the cache
Load data from data_path.

Arguments
---------
data_path : str|list|dict
Path to dataset. Can be a list or dict as well.
no_labels : bool
Whether the data at data_path is an unlabeled split

Returns
-------
data : list
List of dictionaries, where each dictionary is structured like
`get_data_sample()`'s output
"""
pass
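
A hypothetical implementation for a tab-separated file with one
text<TAB>label pair per line (the file format is an assumption made purely
for illustration):

    def load_data(self, data_path, no_labels=False):
        data = []
        with open(data_path, encoding="utf-8") as fp:
            for idx, line in enumerate(fp):
                if no_labels:
                    text, label = line.rstrip("\n"), None
                else:
                    text, label = line.rstrip("\n").split("\t")
                data.append({"input_id": idx, "input": text, "label": label})
        return data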

def deduplicate_train_test(self, train_data, test_data):
def _deduplicate_train_test(self, train_data, test_data):
"""
Filter train data to avoid overlap with test data

The default implementation de-duplicates based on an "input_id"
element in the sample dictionary.

Arguments
---------
train_data : list
Loaded train data
test_data : list
Loaded test data

Returns
-------
filtered_train_data : list
Train data with overlapping test samples removed
"""
if len(test_data) > 0 and "input_id" not in test_data[0]:
logging.warning(
"`input_id` not found in data, no deduplication will be run"
"`input_id` not found in data, no de-duplication will be run"
)
# TODO: Add fallback to input, label deep comparison
return train_data
@@ -65,18 +154,76 @@ def deduplicate_train_test(self, train_data, test_data):

return final_train_data
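
Per the class notes, a subclass whose samples lack "input_id" might override
this with a content-based comparison. A minimal sketch, assuming inputs are
JSON-serializable (`json` is already imported by this module):

    def _deduplicate_train_test(self, train_data, test_data):
        # Compare serialized inputs instead of relying on "input_id"
        seen = {json.dumps(s["input"], ensure_ascii=False) for s in test_data}
        return [
            s for s in train_data
            if json.dumps(s["input"], ensure_ascii=False) not in seen
        ]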

def stringify_sample(self, sample):
def _stringify_sample(self, sample):
"""
Serialize data sample into a string.

Primarily used for some fewshot samplers that work only on strings.
By default uses JSON serialization; if the data is not JSON serializable,
this function must be re-implemented in the implementing class.

Arguments
---------
sample : dict
Input sample, with the same structure as that returned by
`get_data_sample()`

Returns
-------
new_sample : dict
Same as the input sample, except the value associated with the key
"input" must be a string
"""
new_sample = sample.copy()
new_sample["input"] = json.dumps(new_sample["input"], ensure_ascii=False)
return new_sample

def destringify_sample(self, sample):
def _destringify_sample(self, sample):
"""
Deserialize data sample from a string.

Primarily used for some fewshot samplers that work only on strings.
By default uses JSON deserialization; if the data is not JSON deserializable,
this function must be re-implemented in the implementing class.

Arguments
---------
sample : dict
Output of `_stringify_sample()`

Returns
-------
new_sample : dict
Sample with the same structure as that returned by
`get_data_sample()`
"""
new_sample = sample.copy()
new_sample["input"] = json.loads(new_sample["input"])
return new_sample
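
As the docstrings note, inputs that do not survive a JSON round trip require
overriding both methods. A sketch for a hypothetical dataset whose "input" is
a (premise, hypothesis) tuple, which `json.loads` would return as a list:

    def _stringify_sample(self, sample):
        new_sample = sample.copy()
        # assumes neither string contains a tab
        new_sample["input"] = "\t".join(sample["input"])
        return new_sample

    def _destringify_sample(self, sample):
        new_sample = sample.copy()
        new_sample["input"] = tuple(sample["input"].split("\t"))
        return new_sample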

def prepare_fewshots(self, target_data, train_data, n_shots, deduplicate=True):
"""Returns a generator for fewshot samples _per test instance_"""
"""
Returns a generator for fewshot samples _per test instance_

Arguments
---------
target_data : list
Test samples
train_data : list
Train/Dev samples to pick few shot samples from
n_shots : int
Number of samples to pick for each test sample
deduplicate : bool, defaults to True
Whether the training samples should be de-duplicated (w.r.t test
samples).

Returns
-------
fewshot_data : generator
A generator that returns `n_shots` train samples for every
test sample
"""
""""""

# Stringify inputs for few shot
deserialization_required = False
@@ -85,7 +232,7 @@ def prepare_fewshots(self, target_data, train_data, n_shots, deduplicate=True):
"`input` is not a string, JSON stringifying for few shot preparation"
)
deserialization_required = True
train_data = [self.stringify_sample(sample) for sample in train_data]
train_data = [self._stringify_sample(sample) for sample in train_data]

# Remove empty inputs
original_sample_count = len(train_data)
@@ -103,7 +250,7 @@ def prepare_fewshots(self, target_data, train_data, n_shots, deduplicate=True):
# We discovered some datasets had overlap between train and test
if deduplicate:
original_sample_count = len(train_data)
train_data = self.deduplicate_train_test(train_data, target_data)
train_data = self._deduplicate_train_test(train_data, target_data)
filtered_sample_count = len(train_data)
if filtered_sample_count < original_sample_count:
logging.warning(
@@ -122,7 +269,7 @@ def prepare_fewshots(self, target_data, train_data, n_shots, deduplicate=True):
# For each input sample, get few shot examples
for idx, input_sample in enumerate(target_data):
if deserialization_required:
input_sample = self.stringify_sample(input_sample)
input_sample = self._stringify_sample(input_sample)
if len(input_sample["input"].strip()) > 0:
examples = example_selector.select_examples(input_sample)
else:
@@ -134,6 +281,6 @@ def prepare_fewshots(self, target_data, train_data, n_shots, deduplicate=True):

if deserialization_required:
# Deserialize example
examples = [self.destringify_sample(sample) for sample in examples]
examples = [self._destringify_sample(sample) for sample in examples]

yield examples
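
End to end, a hypothetical subclass (`MyDataset` below, with hypothetical
paths) would typically be used like this; the generator yields `n_shots`
training examples aligned with each test sample:

    dataset = MyDataset()
    train_data = dataset.load_data("data/train.tsv")
    test_data = dataset.load_data("data/test.tsv")

    fewshots = dataset.prepare_fewshots(test_data, train_data, n_shots=3)
    for test_sample, examples in zip(test_data, fewshots):
        # `examples` holds up to 3 de-duplicated train samples for this sample
        build_prompt(test_sample, examples)  # hypothetical downstream step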
50 changes: 40 additions & 10 deletions llmebench/models/HuggingFaceInferenceAPI.py
@@ -30,6 +30,8 @@


class HuggingFaceModelLoadingError(Exception):
"""Exception class to capture loading errors"""

def __init__(self, failure_message):
self.failure_message = failure_message

@@ -38,15 +40,21 @@ def __str__(self):


class HuggingFaceInferenceAPIModel(ModelBase):
"""An interface to HuggingFace Inference API

Args:
task_type: one of Summarization, Sentence_Similarity, Text_Generation, Text2Text_Generation, Translation,
Feature_Extraction, Fill_Mask, Question_Answering, Table_Question_Answering, Text_Classification,
Token_Classification, Named_Entity_Recognition, Zero_Shot_Classification, Conversational as found on
HuggingFace model's page
inference_api_url: the URL to the particular model, as found in the Deploy > Inference API menu in the model's page
api_token: HuggingFace API access key (can also be read from enviroment variable HUGGINGFACE_API_TOKEN)
"""
"""
An interface to HuggingFace Inference API

Arguments
---------
task_type : HuggingFaceTaskTypes
One of Summarization, Sentence_Similarity, Text_Generation, Text2Text_Generation, Translation,
Feature_Extraction, Fill_Mask, Question_Answering, Table_Question_Answering, Text_Classification,
Token_Classification, Named_Entity_Recognition, Zero_Shot_Classification, Conversational as found on
the HuggingFace model's page
inference_api_url : str
The URL to the particular model, as found in the Deploy > Inference API menu on the model's page
api_token : str
HuggingFace API access key. If not provided, will be inferred from the environment variable
`HUGGINGFACE_API_TOKEN`
"""

def __init__(self, task_type, inference_api_url, api_token=None, **kwargs):
@@ -64,6 +72,27 @@ def __init__(self, task_type, inference_api_url, api_token=None, **kwargs):
)

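Assuming the import path below (an assumption based on this file's location),
constructing the model might look like:

    from llmebench.models.HuggingFaceInferenceAPI import (
        HuggingFaceInferenceAPIModel,
        HuggingFaceTaskTypes,
    )

    model = HuggingFaceInferenceAPIModel(
        task_type=HuggingFaceTaskTypes.Text_Classification,
        inference_api_url="https://api-inference.huggingface.co/models/ORG/MODEL",
        api_token="hf_...",  # or set HUGGINGFACE_API_TOKEN in the environment
    )
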
def prompt(self, processed_input):
"""
HuggingFace Inference API Implementation

Arguments
---------
processed_input : dict
Must be a dictionary with one key "inputs", the value of which will
depend on the task type. See https://huggingface.co/docs/api-inference/detailed_parameters
for detailed parameters.

Returns
-------
response : dict
Response from the HuggingFace Inference API

Raises
------
HuggingFaceModelLoadingError : Exception
Raised if the model is not yet loaded on HuggingFace.
Retrying after a few seconds is the usual remedy.
"""
headers = {"Authorization": f"Bearer {self.api_token}"}
data = json.dumps(processed_input)
response = requests.request(
@@ -78,7 +107,8 @@ def prompt(self, processed_input):

def summarize_response(self, response):
"""
This method will attempt to interpret the output based on the task type. Otherwise, it returns the response object as is.
This method will attempt to interpret the output based on the task type.
Otherwise, it returns the response object as is.
"""
output_types = {
HuggingFaceTaskTypes.Summarization: str,
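
A sketch of calling `prompt()` with the documented retry-on-loading behavior,
reusing the `model` constructed above (the payload shape assumes a text
classification task):

    import time

    payload = {"inputs": "This framework is easy to extend."}
    while True:
        try:
            response = model.prompt(payload)
            break
        except HuggingFaceModelLoadingError:
            # Model not yet loaded on HuggingFace's side; retry after a pause
            time.sleep(5)

    print(model.summarize_response(response))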