From 0f2390047b7c68753ec1ce58a7bbb6d076f3bed5 Mon Sep 17 00:00:00 2001 From: "pierre.delaunay" Date: Fri, 31 May 2024 16:43:57 -0400 Subject: [PATCH] - --- benchmarks/torchvision/main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/benchmarks/torchvision/main.py b/benchmarks/torchvision/main.py index f7c08c570..fac6d45eb 100644 --- a/benchmarks/torchvision/main.py +++ b/benchmarks/torchvision/main.py @@ -313,8 +313,11 @@ def iobench(args): if args.data is None and data_directory: args.data = os.path.join(data_directory, "FakeImageNet") - loader = dataloader(args) device = accelerator.fetch_device(0) + model = getattr(tvmodels, args.model)() + model.to(device) + + loader = dataloader(args, model) dtype = float_dtype(args.precision) for _ in range(args.epochs):