Benchmark Runner
bbengfort committed Dec 20, 2024
1 parent d90e05f · commit 1bf033c
Showing 6 changed files with 327 additions and 11 deletions.
69 changes: 60 additions & 9 deletions construe/__main__.py
@@ -9,6 +9,7 @@
from .version import get_version
from .exceptions import DeviceError

from .models.path import get_model_home

from .datasets.path import get_data_home
from .datasets.loaders import cleanup_all_datasets
@@ -23,22 +24,39 @@
    download_nsfw,
)


from .whisper import Whisper
from .basic import BasicBenchmark
from .whisper import WhisperBenchmark
from .moondream import MoonDreamBenchmark

from .benchmark import BenchmarkRunner


CONTEXT_SETTINGS = {
"help_option_names": ["-h", "--help"],
}

DATASETS = [
"all", "dialects", "lowlight", "reddit", "movies", "essays", "aegis", "nsfw",
"all",
"dialects",
"lowlight",
"reddit",
"movies",
"essays",
"aegis",
"nsfw",
]


@click.group(context_settings=CONTEXT_SETTINGS)
@click.version_option(get_version(), message="%(prog)s v%(version)s")
@click.option(
    "-o",
    "--out",
    default="construe.json",
    type=str,
    help="specify the path to write the benchmark results to",
)
@click.option(
    "-d",
    "--device",
@@ -54,21 +72,51 @@
    envvar=["CONSTRUE_ENV", "ENV"],
    help="name of the experimental environment for comparison (default is hostname)",
)
@click.option(
    "-c",
    "--count",
    default=1,
    type=int,
    help="specify the number of times to run each benchmark",
)
@click.option(
    "-D",
    "--datadir",
    default=None,
    envvar="CONSTRUE_DATA",
    help="specify the location to download datasets to",
)
@click.option(
    "-M",
    "--modeldir",
    default=None,
    envvar="CONSTRUE_MODELS",
    help="specify the location to download models to",
)
@click.option(
    "-S",
    "--sample/--no-sample",
    default=True,
    help="use sample dataset instead of full dataset for benchmark",
)
@click.option(
    "-C",
    "--cleanup/--no-cleanup",
    default=True,
    help="cleanup all downloaded datasets after the benchmark is run",
)
@click.pass_context
def main(ctx, env=None, device=None, datadir=None, cleanup=True):
def main(
    ctx,
    out=None,
    env=None,
    device=None,
    count=1,
    datadir=None,
    modeldir=None,
    sample=True,
    cleanup=True,
):
    """
    A utility for executing inferencing benchmarks.
    """
@@ -78,15 +126,19 @@ def main(ctx, env=None, device=None, datadir=None, cleanup=True):
    except RuntimeError as e:
        raise DeviceError(e)

    click.echo(f"using torch.device(\"{device}\")")
    click.echo(f'using torch.device("{device}")')

    if env is None:
        env = platform.node()

    ctx.ensure_object(dict)
    ctx.obj["out"] = out
    ctx.obj["device"] = device
    ctx.obj["env"] = env
    ctx.obj["n_runs"] = count
    ctx.obj["data_home"] = get_data_home(datadir)
    ctx.obj["model_home"] = get_model_home(modeldir)
    ctx.obj["use_sample"] = sample
    ctx.obj["cleanup"] = cleanup


@@ -144,11 +196,10 @@ def whisper(ctx, **kwargs):
    """
    Executes audio-to-text inferencing benchmarks.
    """
    kwargs["env"] = ctx.obj["env"]
    benchmark = WhisperBenchmark(**kwargs)
    benchmark.before()
    benchmark.run()
    benchmark.after()
    out = ctx.obj.pop("out")
    runner = BenchmarkRunner(benchmarks=[Whisper], **ctx.obj)
    runner.run()
    runner.save(out)


@main.command()
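The new -o/--out, -c/--count, -M/--modeldir, and -S/--sample options above all flow into ctx.obj, and the whisper command now hands that shared context to a BenchmarkRunner instead of driving a single benchmark by hand. A quick way to exercise that wiring is click's bundled CliRunner; the test below is a hypothetical sketch (not part of this commit) and only checks that the options register, without running any benchmark.

# Hypothetical smoke test for the CLI wiring above, using click's built-in
# testing utilities. Not part of this commit; names are illustrative.
from click.testing import CliRunner

from construe.__main__ import main


def test_main_help_lists_new_options():
    runner = CliRunner()
    # --help only exercises option registration; no benchmark is executed
    result = runner.invoke(main, ["--help"])
    assert result.exit_code == 0
    for flag in ("--out", "--count", "--modeldir", "--sample"):
        assert flag in result.output
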
6 changes: 6 additions & 0 deletions construe/benchmark/__init__.py
@@ -0,0 +1,6 @@
"""
Manages benchmark execution in the construe library.
"""

from .base import Benchmark
from .runner import BenchmarkRunner
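
The __init__.py above re-exports BenchmarkRunner from a runner module that is not expanded in this view. Going only by its call sites in __main__.py (the keyword arguments built up in ctx.obj, plus run() and save(out)), a minimal sketch of the interface could look like the following; everything beyond those call sites, including the timing loop and the result format, is an assumption and the committed implementation may differ.

# Hypothetical sketch of a BenchmarkRunner with the interface used in
# __main__.py; inferred from the call sites only, not the committed code.
import json
import time


class BenchmarkRunner:

    def __init__(self, benchmarks, env=None, device=None, n_runs=1,
                 data_home=None, model_home=None, use_sample=True,
                 cleanup=True, **options):
        self.benchmarks = benchmarks
        self.env = env
        self.device = device
        self.n_runs = n_runs
        self.data_home = data_home
        self.model_home = model_home
        self.use_sample = use_sample
        self.cleanup = cleanup
        self.options = options
        self.results = []

    def run(self):
        # Instantiate each Benchmark class, then time preprocessing and
        # inference for every instance over n_runs repetitions.
        for bench_cls in self.benchmarks:
            benchmark = bench_cls(
                data_home=self.data_home,
                model_home=self.model_home,
                use_sample=self.use_sample,
            )
            benchmark.before()
            try:
                for _ in range(self.n_runs):
                    for instance in benchmark.instances():
                        start = time.perf_counter()
                        instance = benchmark.preprocess(instance)
                        preprocess_latency = time.perf_counter() - start

                        start = time.perf_counter()
                        benchmark.inference(instance)
                        inference_latency = time.perf_counter() - start

                        self.results.append({
                            "benchmark": benchmark.description,
                            "env": self.env,
                            "device": self.device,
                            "preprocess_latency": preprocess_latency,
                            "inference_latency": inference_latency,
                        })
            finally:
                benchmark.after(cleanup=self.cleanup)

    def save(self, path):
        # Write the collected measurements to the path given by -o/--out.
        with open(path, "w") as f:
            json.dump(self.results, f, indent=2)

In this sketch the try/finally ensures each benchmark's after() is called even if a run fails, so downloaded models and datasets are still cleaned up when --cleanup is set.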
91 changes: 91 additions & 0 deletions construe/benchmark/base.py
@@ -0,0 +1,91 @@
"""
Defines the base class for all Benchmarks.
"""

import abc

from ..models import get_model_home
from ..datasets import get_data_home

from typing import Any, Generator, Dict, Union


class Benchmark(abc.ABC):
    """
    All benchmarks must subclass this class to ensure all properties and methods are
    correctly set for generic benchmarks to be run correctly.
    """

    def __init__(self, **kwargs):
        self._data_home = get_data_home(kwargs.pop("data_home", None))
        self._model_home = get_model_home(kwargs.pop("model_home", None))
        self._use_sample = kwargs.pop("use_sample", True)
        self._options = kwargs

    @property
    def data_home(self) -> str:
        if hasattr(self, "_data_home"):
            return self._data_home
        return get_data_home()

    @property
    def model_home(self) -> str:
        if hasattr(self, "_model_home"):
            return self._model_home
        return get_model_home()

    @property
    def use_sample(self) -> bool:
        return getattr(self, "_use_sample", True)

    @property
    def metadata(self) -> Dict:
        return getattr(self, "_metadata", None)

    @property
    def options(self) -> Union[Dict, None]:
        return getattr(self, "_options", {})

    @property
    @abc.abstractmethod
    def description(self):
        pass

    @abc.abstractmethod
    def before(self):
        """
        This method is called before the benchmark runs and should cause it to
        set up any datasets and models needed for the benchmark to run.
        """
        pass

    @abc.abstractmethod
    def after(self, cleanup: bool = True):
        """
        This method is called after the benchmark is run; if cleanup is True the
        class should delete any cached datasets or models.
        """
        pass

    @abc.abstractmethod
    def instances(self) -> Generator[Any, None, None]:
        """
        This method should yield all instances in the dataset at least once.
        """
        pass

    @abc.abstractmethod
    def preprocess(self, instance: Any) -> Any:
        """
        Any preprocessing that must be performed on an instance is handled with this
        method. This method is measured for latency and memory usage as well.
        """
        pass

    @abc.abstractmethod
    def inference(self, instance: Any) -> Any:
        """
        This represents the primary inference of the benchmark and is measured for
        latency and memory usage to add to the metrics.
        """
        pass
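
The abstract methods above define the lifecycle a runner drives: before() stages data and models, instances() yields the dataset, preprocess() and inference() are the measured steps, and after() cleans up. A toy subclass that satisfies this contract is sketched below; the class name, task, and data are invented for illustration and do not correspond to any benchmark in this commit.

# Illustrative only: a toy subclass of Benchmark showing how the abstract
# interface is meant to be implemented. Names and data are invented.
from typing import Any, Generator

from construe.benchmark import Benchmark


class ReverseTextBenchmark(Benchmark):

    @property
    def description(self):
        return "reverses short strings as a stand-in for model inference"

    def before(self):
        # A real benchmark would download its dataset and model into
        # self.data_home and self.model_home here.
        self.corpus = ["hello world", "construe", "benchmark runner"]

    def after(self, cleanup: bool = True):
        # A real benchmark would remove cached datasets and models when
        # cleanup is True.
        if cleanup:
            self.corpus = []

    def instances(self) -> Generator[Any, None, None]:
        # Yield every instance in the dataset at least once.
        yield from self.corpus

    def preprocess(self, instance: Any) -> Any:
        # Measured separately from inference by the runner.
        return instance.strip().lower()

    def inference(self, instance: Any) -> Any:
        # Stand-in for an actual model call.
        return instance[::-1]

Because the members above are declared abstract, a subclass that omits any of them cannot be instantiated, which surfaces an incomplete benchmark immediately rather than partway through a run.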