Skip to content

Commit

Permalink
Add resnet50 benchmark (#443) (#451)
Browse files Browse the repository at this point in the history
  • Loading branch information
kiya00 authored Sep 3, 2024
1 parent 3c6e9d1 commit 29ffb43
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
33 changes: 33 additions & 0 deletions thunder/benchmarks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2825,6 +2825,39 @@ def foo(a, m, v, w, b, training):
return foo


class ResNet50Benchmark(Benchmark, metaclass=UserFacingBenchmarkMeta):
    """Benchmark built around torchvision's ``resnet50`` model.

    ``make_batch`` produces one input tensor of shape
    ``(batch_size, *input_shape)`` and ``fn`` returns the model, created by
    calling ``resnet50()`` with no arguments and moved to the configured
    device and dtype.
    """

    def __init__(
        self,
        batch_size: int,
        input_shape: tuple[int, int, int],
        device: str = "cuda",
        dtype: dtypes.dtype = thunder.float32,
        requires_grad: bool = False,
    ) -> None:
        """Store the benchmark configuration.

        Args:
            batch_size: Leading dimension prepended to ``input_shape``.
            input_shape: Per-sample shape; the typical input image size of
                ResNet50 is ``(3, 224, 224)``.
            device: Device used for both the input tensor and the model.
            dtype: Dtype of the benchmark; it is converted once with
                ``ltorch.to_torch_dtype`` and the result kept in ``tdtype``.
            requires_grad: Applied to the input tensor and to the model.
        """
        super().__init__()

        # Where the benchmark runs.
        self.device: str = device
        self.devices: list[str] = [device]

        # Both the given dtype and its torch counterpart are kept around.
        self.dtype: dtypes.dtype = dtype
        self.tdtype: torch.dtype = ltorch.to_torch_dtype(dtype)

        self.requires_grad: bool = requires_grad

        # Full input shape: the batch dimension followed by the sample shape.
        self.shape: tuple[int, int, int, int] = (batch_size,) + input_shape

    def make_batch(self) -> tuple[list, dict]:
        """Return ``(args, kwargs)`` for a single call of the model.

        ``args`` holds one tensor of shape ``self.shape``; ``kwargs`` is empty.
        """
        image_batch = make_tensor(
            self.shape,
            device=self.device,
            dtype=self.tdtype,
            requires_grad=self.requires_grad,
        )
        return (image_batch,), {}

    def fn(self) -> Callable:
        """Build the ResNet50 model configured for this benchmark."""
        # Local import: torchvision is loaded only when fn() is called.
        from torchvision.models import resnet50

        net = resnet50().to(device=self.device, dtype=self.tdtype)
        return net.requires_grad_(self.requires_grad)


# TODO Add descriptions to the executors when listed, and list them alphabetically
# TODO Allow querying benchmark for details
# TODO Allow specifying benchmark arguments
Expand Down
25 changes: 25 additions & 0 deletions thunder/benchmarks/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
NanoGPTCrossEntropyBenchmark,
LitGPTGeluBenchmark,
NanoGPTLayerNormBenchmark,
ResNet50Benchmark,
thunder_apex_executor,
thunder_apex_nvfuser_executor,
thunder_cudnn_executor,
Expand Down Expand Up @@ -721,3 +722,27 @@ def test_interpreter_nanogpt_gpt2_fwd(benchmark, executor: Callable):
fn = executor(bench.fn())

benchmark(fn, *args, **kwargs)


#
# vision benchmarks
#


# Sample command to run this benchmark:
# pytest thunder/benchmarks/targets.py -k "test_resnet50" --benchmark-group-by='param:compute_type'
@pytest.mark.parametrize(
    "executor,",
    executors,
    ids=executors_ids,
)
@parametrize_compute_type
def test_resnet50(benchmark, executor: Callable, compute_type: ComputeType):
    """Benchmark ResNet50 for every executor and compute type.

    Uses a fixed configuration: batch size 64, 3x224x224 inputs, bfloat16
    on ``cuda:0``.
    """
    # Whether gradients are needed is derived from the compute type.
    needs_grad = is_requires_grad(compute_type)
    bench = ResNet50Benchmark(
        batch_size=64,
        input_shape=(3, 224, 224),
        device="cuda:0",
        dtype=torch.bfloat16,
        requires_grad=needs_grad,
    )

    # Inputs are created before the model is built and wrapped by the executor.
    call_args, call_kwargs = bench.make_batch()
    wrapped_model = executor(bench.fn())

    benchmark_for_compute_type(compute_type, benchmark, wrapped_model, call_args, call_kwargs)

0 comments on commit 29ffb43

Please sign in to comment.