Implement graceful failure when an asset fails to run. (#206)
Currently, a failing asset will cause the benchmark to stop executing. This commit relaxes this behavior by logging the error for the misbehaving asset and continuing to run the other assets. The commit also adds tests for the benchmark runner.

* Implement graceful failure when an asset fails to run.

Currently, a failing asset will cause the benchmark to stop executing.
This commit relaxes this behavior by logging the error for the misbehaving
asset and continuing to run the other assets.

* Add tests for benchmark runner

* Add test for multiconfig assets
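
In essence, the change wraps each asset's benchmark run in a try/except so that an exception from one asset is logged and the loop moves on, rather than aborting the entire run. Below is a simplified sketch of that pattern; the names run_all_assets and run_single_asset are illustrative stand-ins, not from the repository, and the actual change is in the diff to llmebench/benchmark.py further down.

import logging
import traceback


def run_all_assets(assets, run_single_asset):
    # assets: iterable of (name, asset) pairs; run_single_asset: a callable that
    # sets up and runs one asset (standing in for SingleTaskBenchmark here).
    all_results = {}
    for name, asset in assets:
        try:
            all_results[name] = run_single_asset(asset)
        except Exception:
            # Log the failure and continue with the remaining assets.
            logging.error(f"{name} failed to run")
            traceback.print_exc()
    return all_results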
fdalvi authored Sep 10, 2023
1 parent 469b1c7 commit 421839f
Showing 2 changed files with 230 additions and 29 deletions.
60 changes: 32 additions & 28 deletions llmebench/benchmark.py
@@ -338,41 +338,45 @@ def main():
         prompt_fn = asset["module"].prompt
         post_process_fn = asset["module"].post_process

-        logging.info(f"Running benchmark: {name}")
-        task_benchmark = SingleTaskBenchmark(
-            config,
-            prompt_fn,
-            post_process_fn,
-            cache_dir=args.results_dir / name,
-            ignore_cache=args.ignore_cache,
-            limit=args.limit,
-            n_shots=args.n_shots,
-        )
+        try:
+            logging.info(f"Running benchmark: {name}")
+            task_benchmark = SingleTaskBenchmark(
+                config,
+                prompt_fn,
+                post_process_fn,
+                cache_dir=args.results_dir / name,
+                ignore_cache=args.ignore_cache,
+                limit=args.limit,
+                n_shots=args.n_shots,
+            )

-        if task_benchmark.is_zeroshot() and args.n_shots > 0:
-            logging.warning(
-                f"{name}: Skipping because asset is zero shot and --n_shots is non zero"
-            )
-            continue
+            if task_benchmark.is_zeroshot() and args.n_shots > 0:
+                logging.warning(
+                    f"{name}: Skipping because asset is zero shot and --n_shots is non zero"
+                )
+                continue

-        if not task_benchmark.is_zeroshot() and args.n_shots == 0:
-            logging.warning(
-                f"{name}: Skipping because asset is few shot and --n_shots is zero"
-            )
-            continue
+            if not task_benchmark.is_zeroshot() and args.n_shots == 0:
+                logging.warning(
+                    f"{name}: Skipping because asset is few shot and --n_shots is zero"
+                )
+                continue

-        task_results = task_benchmark.run_benchmark()
-        logging.info(f"{name}: {task_results['evaluation_scores']}")
+            task_results = task_benchmark.run_benchmark()
+            logging.info(f"{name}: {task_results['evaluation_scores']}")

-        task_result_path = task_benchmark.cache_dir / "results.json"
+            task_result_path = task_benchmark.cache_dir / "results.json"

-        with open(task_result_path, "w") as fp:
-            json.dump(task_results, fp, ensure_ascii=False)
+            with open(task_result_path, "w") as fp:
+                json.dump(task_results, fp, ensure_ascii=False)

-        if not task_benchmark.is_zeroshot():
-            name = f"{name}_{task_benchmark.n_shots}"
+            if not task_benchmark.is_zeroshot():
+                name = f"{name}_{task_benchmark.n_shots}"

-        all_results[name] = task_results
+            all_results[name] = task_results
+        except Exception as e:
+            logging.error(f"{name} failed to run")
+            traceback.print_exc()

     with open(all_results_path, "w") as fp:
         json.dump(all_results, fp, ensure_ascii=False)
199 changes: 198 additions & 1 deletion tests/test_benchmark.py
@@ -1,3 +1,5 @@
 import json
+import sys
+import types

 import unittest
@@ -6,13 +8,72 @@

 from unittest.mock import MagicMock, patch

+import llmebench
+
 from llmebench import Benchmark
+from llmebench.datasets.dataset_base import DatasetBase
+from llmebench.models.model_base import ModelBase
+from llmebench.tasks.task_base import TaskBase
+
+
+class MockDataset(DatasetBase):
+    def metadata():
+        return {}
+
+    def get_data_sample(self):
+        return {"input": "input", "label": "label"}
+
+    def load_data(self, data_path):
+        return [self.get_data_sample() for _ in range(100)]
+
+
+class MockModel(ModelBase):
+    def prompt(self, processed_input):
+        return processed_input
+
+    def summarize_response(self, response):
+        return response
+
+
+class MockTask(TaskBase):
+    def evaluate(self, true_labels, predicted_labels):
+        return {"Accuracy": 1}


 class MockAsset(object):
     @staticmethod
     def config():
-        return {}
+        return {
+            "dataset": MockDataset,
+            "dataset_args": {},
+            "task": MockTask,
+            "task_args": {},
+            "model": MockModel,
+            "model_args": {},
+            "general_args": {"data_path": "fake/path/to/data"},
+        }
+
+    @staticmethod
+    def prompt(input_sample):
+        return {"prompt": input_sample}
+
+    @staticmethod
+    def post_process(response):
+        return response
+
+
+class MockFailingAsset(MockAsset):
+    def prompt(input_sample):
+        raise Exception("Fail!")
+
+
+class MockMultiConfigAsset(MockAsset):
+    @staticmethod
+    def config():
+        return [
+            {"name": "Subasset 1", "config": MockAsset.config()},
+            {"name": "Subasset 2", "config": MockAsset.config()},
+        ]


 @patch("llmebench.utils.import_source_file", MagicMock(return_value=MockAsset))
@@ -103,3 +164,139 @@ def test_partial_path(self):
         self.assertEqual(len(assets), 4)
         for asset in assets:
             self.assertIn("unique_prefix/", asset["name"])
+
+
+class TestBenchmarkRunner(unittest.TestCase):
+    def setUp(self):
+        self.benchmark_dir = TemporaryDirectory()
+        self.results_dir = TemporaryDirectory()
+
+    @patch("llmebench.benchmark.Benchmark.find_assets")
+    def test_no_asset_run(self, asset_finder_mock):
+        "Run benchmark with no assets"
+        asset_finder_mock.return_value = []
+
+        testargs = ["llmebench", self.benchmark_dir.name, self.results_dir.name]
+        with patch.object(sys, "argv", testargs):
+            llmebench.benchmark.main()
+
+        with open(Path(self.results_dir.name) / "all_results.json") as fp:
+            results = json.load(fp)
+            self.assertEqual(len(results), 0)
+
+    @patch("llmebench.benchmark.Benchmark.find_assets")
+    def test_single_asset_run(self, asset_finder_mock):
+        "Run benchmark with one asset"
+        asset_finder_mock.return_value = [
+            {
+                "name": "MockAsset 1",
+                "config": MockAsset.config(),
+                "module": MockAsset,
+            }
+        ]
+
+        testargs = ["llmebench", self.benchmark_dir.name, self.results_dir.name]
+        with patch.object(sys, "argv", testargs):
+            llmebench.benchmark.main()
+
+        with open(Path(self.results_dir.name) / "all_results.json") as fp:
+            results = json.load(fp)
+            self.assertEqual(len(results), 1)
+
+    @patch("llmebench.benchmark.Benchmark.find_assets")
+    def test_single_failing_asset_run(self, asset_finder_mock):
+        "Run benchmark with one failing asset"
+        asset_finder_mock.return_value = [
+            {
+                "name": "MockFailingAsset",
+                "config": MockFailingAsset.config(),
+                "module": MockFailingAsset,
+            }
+        ]
+
+        testargs = ["llmebench", self.benchmark_dir.name, self.results_dir.name]
+        with patch.object(sys, "argv", testargs):
+            llmebench.benchmark.main()
+
+        with open(Path(self.results_dir.name) / "all_results.json") as fp:
+            results = json.load(fp)
+            self.assertEqual(len(results), 0)
+
+    @patch("llmebench.benchmark.Benchmark.find_assets")
+    def test_multiple_assets(self, asset_finder_mock):
+        "Run benchmark with multiple assets"
+        asset_finder_mock.return_value = [
+            {
+                "name": "MockAsset 1",
+                "config": MockAsset.config(),
+                "module": MockAsset,
+            },
+            {
+                "name": "MockAsset 2",
+                "config": MockAsset.config(),
+                "module": MockAsset,
+            },
+        ]
+
+        testargs = ["llmebench", self.benchmark_dir.name, self.results_dir.name]
+        with patch.object(sys, "argv", testargs):
+            llmebench.benchmark.main()
+
+        with open(Path(self.results_dir.name) / "all_results.json") as fp:
+            results = json.load(fp)
+            self.assertEqual(len(results), 2)
+
+    @patch("llmebench.benchmark.Benchmark.find_assets")
+    def test_multiple_assets_with_failure(self, asset_finder_mock):
+        "Run benchmark with multiple assets and failing assets"
+        asset_finder_mock.return_value = [
+            {
+                "name": "MockAsset 1",
+                "config": MockAsset.config(),
+                "module": MockAsset,
+            },
+            {
+                "name": "MockFailingAsset 1",
+                "config": MockFailingAsset.config(),
+                "module": MockFailingAsset,
+            },
+            {
+                "name": "MockAsset 2",
+                "config": MockAsset.config(),
+                "module": MockAsset,
+            },
+            {
+                "name": "MockFailingAsset 2",
+                "config": MockFailingAsset.config(),
+                "module": MockFailingAsset,
+            },
+        ]
+
+        testargs = ["llmebench", self.benchmark_dir.name, self.results_dir.name]
+        with patch.object(sys, "argv", testargs):
+            llmebench.benchmark.main()
+
+        with open(Path(self.results_dir.name) / "all_results.json") as fp:
+            results = json.load(fp)
+            self.assertEqual(len(results), 2)
+
+    @patch("llmebench.utils.import_source_file")
+    def test_multi_config_asset(self, asset_importer_mock):
+        "Run benchmark with multiconfig asset"
+
+        # Create dummy asset file
+        (Path(self.benchmark_dir.name) / "sample.py").touch(exist_ok=True)
+
+        asset_importer_mock.return_value = MockMultiConfigAsset
+
+        testargs = ["llmebench", self.benchmark_dir.name, self.results_dir.name]
+        with patch.object(sys, "argv", testargs):
+            llmebench.benchmark.main()
+
+        with open(Path(self.results_dir.name) / "all_results.json") as fp:
+            results = json.load(fp)
+            config = MockMultiConfigAsset.config()
+            self.assertEqual(len(results), len(config))
+
+            for subconfig in config:
+                self.assertIn(f"sample/{subconfig['name']}", results)
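For reference, the new tests should be runnable with the standard-library test runner from the repository root; the exact command is an assumption about the project layout, not part of this commit:

python -m unittest discover -s tests -p "test_benchmark.py" -v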