Implement graceful failure when an asset fails to run. #206

Merged: 5 commits, Sep 10, 2023
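In short, the change wraps each asset's benchmark run in llmebench/benchmark.py in a try/except so that a single failing asset is logged and skipped instead of aborting the whole run, while all_results.json is still written for the assets that succeed; the new tests in tests/test_benchmark.py exercise both the failing and non-failing paths. A condensed sketch of the pattern (names taken from the diff below, heavily simplified rather than the verbatim implementation):

```python
# Condensed sketch of the error handling introduced in llmebench/benchmark.py.
# `run_one_asset` stands in for the real per-asset setup and run logic shown
# in the diff below; it is a placeholder, not part of llmebench.
import json
import logging
import traceback


def run_assets_gracefully(assets, all_results_path):
    all_results = {}
    for name, run_one_asset in assets:
        try:
            # Any exception raised while running this asset is contained here.
            all_results[name] = run_one_asset()
        except Exception:
            # Log the failure and move on to the next asset.
            logging.error(f"{name} failed to run")
            traceback.print_exc()

    # Results from the successful assets are still written out.
    with open(all_results_path, "w") as fp:
        json.dump(all_results, fp, ensure_ascii=False)
```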
60 changes: 32 additions & 28 deletions llmebench/benchmark.py
@@ -338,41 +338,45 @@ def main():
         prompt_fn = asset["module"].prompt
         post_process_fn = asset["module"].post_process
 
-        logging.info(f"Running benchmark: {name}")
-        task_benchmark = SingleTaskBenchmark(
-            config,
-            prompt_fn,
-            post_process_fn,
-            cache_dir=args.results_dir / name,
-            ignore_cache=args.ignore_cache,
-            limit=args.limit,
-            n_shots=args.n_shots,
-        )
+        try:
+            logging.info(f"Running benchmark: {name}")
+            task_benchmark = SingleTaskBenchmark(
+                config,
+                prompt_fn,
+                post_process_fn,
+                cache_dir=args.results_dir / name,
+                ignore_cache=args.ignore_cache,
+                limit=args.limit,
+                n_shots=args.n_shots,
+            )
 
-        if task_benchmark.is_zeroshot() and args.n_shots > 0:
-            logging.warning(
-                f"{name}: Skipping because asset is zero shot and --n_shots is non zero"
-            )
-            continue
+            if task_benchmark.is_zeroshot() and args.n_shots > 0:
+                logging.warning(
+                    f"{name}: Skipping because asset is zero shot and --n_shots is non zero"
+                )
+                continue
 
-        if not task_benchmark.is_zeroshot() and args.n_shots == 0:
-            logging.warning(
-                f"{name}: Skipping because asset is few shot and --n_shots is zero"
-            )
-            continue
+            if not task_benchmark.is_zeroshot() and args.n_shots == 0:
+                logging.warning(
+                    f"{name}: Skipping because asset is few shot and --n_shots is zero"
+                )
+                continue
 
-        task_results = task_benchmark.run_benchmark()
-        logging.info(f"{name}: {task_results['evaluation_scores']}")
+            task_results = task_benchmark.run_benchmark()
+            logging.info(f"{name}: {task_results['evaluation_scores']}")
 
-        task_result_path = task_benchmark.cache_dir / "results.json"
+            task_result_path = task_benchmark.cache_dir / "results.json"
 
-        with open(task_result_path, "w") as fp:
-            json.dump(task_results, fp, ensure_ascii=False)
+            with open(task_result_path, "w") as fp:
+                json.dump(task_results, fp, ensure_ascii=False)
 
-        if not task_benchmark.is_zeroshot():
-            name = f"{name}_{task_benchmark.n_shots}"
+            if not task_benchmark.is_zeroshot():
+                name = f"{name}_{task_benchmark.n_shots}"
 
-        all_results[name] = task_results
+            all_results[name] = task_results
+        except Exception as e:
+            logging.error(f"{name} failed to run")
+            traceback.print_exc()
 
     with open(all_results_path, "w") as fp:
         json.dump(all_results, fp, ensure_ascii=False)
199 changes: 198 additions & 1 deletion tests/test_benchmark.py
@@ -1,3 +1,5 @@
 import json
+import sys
+import types
 
 import unittest
@@ -6,13 +8,72 @@
 
 from unittest.mock import MagicMock, patch
 
+import llmebench
+
 from llmebench import Benchmark
+from llmebench.datasets.dataset_base import DatasetBase
+from llmebench.models.model_base import ModelBase
+from llmebench.tasks.task_base import TaskBase
+
+
+class MockDataset(DatasetBase):
+    def metadata():
+        return {}
+
+    def get_data_sample(self):
+        return {"input": "input", "label": "label"}
+
+    def load_data(self, data_path):
+        return [self.get_data_sample() for _ in range(100)]
+
+
+class MockModel(ModelBase):
+    def prompt(self, processed_input):
+        return processed_input
+
+    def summarize_response(self, response):
+        return response
+
+
+class MockTask(TaskBase):
+    def evaluate(self, true_labels, predicted_labels):
+        return {"Accuracy": 1}
 
 
 class MockAsset(object):
     @staticmethod
     def config():
-        return {}
+        return {
+            "dataset": MockDataset,
+            "dataset_args": {},
+            "task": MockTask,
+            "task_args": {},
+            "model": MockModel,
+            "model_args": {},
+            "general_args": {"data_path": "fake/path/to/data"},
+        }
+
+    @staticmethod
+    def prompt(input_sample):
+        return {"prompt": input_sample}
+
+    @staticmethod
+    def post_process(response):
+        return response
+
+
+class MockFailingAsset(MockAsset):
+    def prompt(input_sample):
+        raise Exception("Fail!")
+
+
+class MockMultiConfigAsset(MockAsset):
+    @staticmethod
+    def config():
+        return [
+            {"name": "Subasset 1", "config": MockAsset.config()},
+            {"name": "Subasset 2", "config": MockAsset.config()},
+        ]
 
 
 @patch("llmebench.utils.import_source_file", MagicMock(return_value=MockAsset))
@@ -103,3 +164,139 @@ def test_partial_path(self):
         self.assertEqual(len(assets), 4)
         for asset in assets:
             self.assertIn("unique_prefix/", asset["name"])
+
+
+class TestBenchmarkRunner(unittest.TestCase):
+    def setUp(self):
+        self.benchmark_dir = TemporaryDirectory()
+        self.results_dir = TemporaryDirectory()
+
+    @patch("llmebench.benchmark.Benchmark.find_assets")
+    def test_no_asset_run(self, asset_finder_mock):
+        "Run benchmark with no assets"
+        asset_finder_mock.return_value = []
+
+        testargs = ["llmebench", self.benchmark_dir.name, self.results_dir.name]
+        with patch.object(sys, "argv", testargs):
+            llmebench.benchmark.main()
+
+        with open(Path(self.results_dir.name) / "all_results.json") as fp:
+            results = json.load(fp)
+        self.assertEqual(len(results), 0)
+
+    @patch("llmebench.benchmark.Benchmark.find_assets")
+    def test_single_asset_run(self, asset_finder_mock):
+        "Run benchmark with one asset"
+        asset_finder_mock.return_value = [
+            {
+                "name": "MockAsset 1",
+                "config": MockAsset.config(),
+                "module": MockAsset,
+            }
+        ]
+
+        testargs = ["llmebench", self.benchmark_dir.name, self.results_dir.name]
+        with patch.object(sys, "argv", testargs):
+            llmebench.benchmark.main()
+
+        with open(Path(self.results_dir.name) / "all_results.json") as fp:
+            results = json.load(fp)
+        self.assertEqual(len(results), 1)
+
+    @patch("llmebench.benchmark.Benchmark.find_assets")
+    def test_single_failing_asset_run(self, asset_finder_mock):
+        "Run benchmark with one failing asset"
+        asset_finder_mock.return_value = [
+            {
+                "name": "MockFailingAsset",
+                "config": MockFailingAsset.config(),
+                "module": MockFailingAsset,
+            }
+        ]
+
+        testargs = ["llmebench", self.benchmark_dir.name, self.results_dir.name]
+        with patch.object(sys, "argv", testargs):
+            llmebench.benchmark.main()
+
+        with open(Path(self.results_dir.name) / "all_results.json") as fp:
+            results = json.load(fp)
+        self.assertEqual(len(results), 0)
+
+    @patch("llmebench.benchmark.Benchmark.find_assets")
+    def test_multiple_assets(self, asset_finder_mock):
+        "Run benchmark with multiple assets"
+        asset_finder_mock.return_value = [
+            {
+                "name": "MockAsset 1",
+                "config": MockAsset.config(),
+                "module": MockAsset,
+            },
+            {
+                "name": "MockAsset 2",
+                "config": MockAsset.config(),
+                "module": MockAsset,
+            },
+        ]
+
+        testargs = ["llmebench", self.benchmark_dir.name, self.results_dir.name]
+        with patch.object(sys, "argv", testargs):
+            llmebench.benchmark.main()
+
+        with open(Path(self.results_dir.name) / "all_results.json") as fp:
+            results = json.load(fp)
+        self.assertEqual(len(results), 2)
+
+    @patch("llmebench.benchmark.Benchmark.find_assets")
+    def test_multiple_assets_with_failure(self, asset_finder_mock):
+        "Run benchmark with multiple assets and failing assets"
+        asset_finder_mock.return_value = [
+            {
+                "name": "MockAsset 1",
+                "config": MockAsset.config(),
+                "module": MockAsset,
+            },
+            {
+                "name": "MockFailingAsset 1",
+                "config": MockFailingAsset.config(),
+                "module": MockFailingAsset,
+            },
+            {
+                "name": "MockAsset 2",
+                "config": MockAsset.config(),
+                "module": MockAsset,
+            },
+            {
+                "name": "MockFailingAsset 2",
+                "config": MockFailingAsset.config(),
+                "module": MockFailingAsset,
+            },
+        ]
+
+        testargs = ["llmebench", self.benchmark_dir.name, self.results_dir.name]
+        with patch.object(sys, "argv", testargs):
+            llmebench.benchmark.main()
+
+        with open(Path(self.results_dir.name) / "all_results.json") as fp:
+            results = json.load(fp)
+        self.assertEqual(len(results), 2)
+
+    @patch("llmebench.utils.import_source_file")
+    def test_multi_config_asset(self, asset_importer_mock):
+        "Run benchmark with multiconfig asset"
+
+        # Create dummy asset file
+        (Path(self.benchmark_dir.name) / "sample.py").touch(exist_ok=True)
+
+        asset_importer_mock.return_value = MockMultiConfigAsset
+
+        testargs = ["llmebench", self.benchmark_dir.name, self.results_dir.name]
+        with patch.object(sys, "argv", testargs):
+            llmebench.benchmark.main()
+
+        with open(Path(self.results_dir.name) / "all_results.json") as fp:
+            results = json.load(fp)
+        config = MockMultiConfigAsset.config()
+        self.assertEqual(len(results), len(config))
+
+        for subconfig in config:
+            self.assertIn(f"sample/{subconfig['name']}", results)