Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Implement configuration option for fewshot embedding model #241

Merged
merged 1 commit into the base branch on
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ def config():
"class_labels": ["Positive", "Negative", "Neutral", "Mixed"],
"max_tries": 3,
},
"general_args": {
"fewshot": {"embedding_model_name": "distiluse-base-multilingual-cased-v1"}
},
}


Expand Down
35 changes: 25 additions & 10 deletions llmebench/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@ def __init__(
self.name = name

# Pipeline components
dataset_args = config.get("dataset_args", {})
if "data_dir" not in dataset_args:
dataset_args["data_dir"] = data_dir
self.data_dir = dataset_args["data_dir"]
self.dataset = config["dataset"](**dataset_args)
self.dataset_args = config.get("dataset_args", {})
if "data_dir" not in self.dataset_args:
self.dataset_args["data_dir"] = data_dir
self.data_dir = self.dataset_args["data_dir"]
self.dataset_cls = config["dataset"]

task_args = config.get("task_args", {})
self.task = config["task"](dataset=self.dataset, **task_args)
self.task_args = config.get("task_args", {})
self.task_cls = config["task"]

model_args = config.get("model_args", {})
self.model = config["model"](**model_args)
self.model_args = config.get("model_args", {})
self.model_cls = config["model"]

general_args = config.get("general_args", {})

Expand All @@ -62,20 +62,29 @@ def __init__(
if utils.is_fewshot_asset(config, prompt_fn):
self.zeroshot = False
self.deduplicate = True
self.fewshot_embedding_model_name = None
self.train_data_paths = utils.get_data_paths(config, "train")

assert len(self.data_paths) == len(
self.train_data_paths
), "A train split must be provided for every test split being run"
if "fewshot" in general_args:
self.deduplicate = general_args["fewshot"].get("deduplicate", True)
self.fewshot_embedding_model_name = general_args["fewshot"].get(
"embedding_model_name", None
)

self.limit = limit
self.n_shots = n_shots

def is_zeroshot(self):
    """Report whether this benchmark asset is configured for zero-shot runs.

    Returns
    -------
    bool
        The ``zeroshot`` flag decided during initialization (False when a
        few-shot prompt function was detected for the asset).
    """
    zero_shot_flag = self.zeroshot
    return zero_shot_flag

def initialize_pipeline(self):
    """Materialize the dataset, task, and model pipeline components.

    The constructor only records each component's class and argument dict
    (``dataset_cls``/``dataset_args``, ``task_cls``/``task_args``,
    ``model_cls``/``model_args``); this method performs the deferred
    instantiation. The dataset is built first because the task constructor
    receives it as its ``dataset`` keyword argument.
    """
    dataset = self.dataset_cls(**self.dataset_args)
    task = self.task_cls(dataset=dataset, **self.task_args)
    model = self.model_cls(**self.model_args)
    self.dataset, self.task, self.model = dataset, task, model

def run_pipeline(
self,
sample_key,
Expand Down Expand Up @@ -142,6 +151,8 @@ def run_pipeline(
return cache_payload, summarized_payload

def run_benchmark(self, dry_run=False):
self.initialize_pipeline()

base_name = self.name
base_cache_dir = self.cache_dir

Expand Down Expand Up @@ -179,7 +190,11 @@ def run_benchmark(self, dry_run=False):
train_data = self.dataset.load_data(train_data_path)

few_shots_data = self.dataset.prepare_fewshots(
data, train_data, self.n_shots, deduplicate=self.deduplicate
data,
train_data,
self.n_shots,
embedding_model_name=self.fewshot_embedding_model_name,
deduplicate=self.deduplicate,
)

true_labels = []
Expand Down
21 changes: 15 additions & 6 deletions llmebench/datasets/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,14 @@ def _destringify_sample(self, sample):
new_sample["input"] = json.loads(new_sample["input"])
return new_sample

def prepare_fewshots(self, target_data, train_data, n_shots, deduplicate=True):
def prepare_fewshots(
self,
target_data,
train_data,
n_shots,
embedding_model_name=None,
deduplicate=True,
):
"""
Returns a generator for fewshot samples _per test instance_

Expand All @@ -246,6 +253,9 @@ def prepare_fewshots(self, target_data, train_data, n_shots, deduplicate=True):
Train/Dev samples to pick few shot samples from
n_shots : int
Number of samples to pick for each test sample
embedding_model_name : str
The model to use for extracting embeddings to use for similarity computation.
Defaults to 'distiluse-base-multilingual-cased-v1'
deduplicate : bool, defaults to True
Whether the training samples should be de-duplicated (w.r.t test
samples).
Expand All @@ -256,7 +266,9 @@ def prepare_fewshots(self, target_data, train_data, n_shots, deduplicate=True):
A generator that returns `n_shots` train samples for every
test sample
"""
""""""

if embedding_model_name is None:
embedding_model_name = "distiluse-base-multilingual-cased-v1"

# Stringify inputs for few shot
deserialization_required = False
Expand Down Expand Up @@ -291,10 +303,7 @@ def prepare_fewshots(self, target_data, train_data, n_shots, deduplicate=True):
)

# TODO: MaxMarginalRelevanceExampleSelector should be generalized
# TODO: Need to handle not str inputs
embedding_model = HuggingFaceEmbeddings(
model_name="distiluse-base-multilingual-cased-v1"
)
embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_name)
example_selector = MaxMarginalRelevanceExampleSelector.from_examples(
train_data, embedding_model, FAISS, input_keys=["input"], k=n_shots
)
Expand Down
Loading