From 433457173c9e040dc6772f06ee99110adc77ff71 Mon Sep 17 00:00:00 2001 From: "pierre.delaunay" Date: Thu, 6 Jun 2024 16:03:46 -0400 Subject: [PATCH] Reduce reliance on env var, use args placeholder instead --- benchmarks/huggingface/prepare.py | 5 ++ benchmarks/super-slomo/prepare.py | 11 +--- benchmarks/super-slomo/slomo/train.py | 75 ++++++++++++++++++--------- benchmarks/timm/voirfile.py | 3 +- benchmate/benchmate/datagen.py | 2 +- benchmate/benchmate/dataset.py | 20 ++++++- config/base.yaml | 3 +- 7 files changed, 80 insertions(+), 39 deletions(-) diff --git a/benchmarks/huggingface/prepare.py b/benchmarks/huggingface/prepare.py index fdcc73fcd..d1bdaf280 100755 --- a/benchmarks/huggingface/prepare.py +++ b/benchmarks/huggingface/prepare.py @@ -8,3 +8,8 @@ print(f"Preparing {args.model}") make_config = models[args.model] make_config() + + # bert dataset + # t5 dataset + # reformer dataset + # whisper dataset \ No newline at end of file diff --git a/benchmarks/super-slomo/prepare.py b/benchmarks/super-slomo/prepare.py index 781e71d8e..a7e5a2f4c 100755 --- a/benchmarks/super-slomo/prepare.py +++ b/benchmarks/super-slomo/prepare.py @@ -1,16 +1,9 @@ #!/usr/bin/env python import torchvision - - - -def download_celebA(): - # celebA use Google drive, and google drive wants to tell us that - # they cant scan for virus so the download fails - # torchvision 0.17.1 might solve this issue though but we dont have it - pass - +from benchmate.datagen import generate_fakeimagenet if __name__ == "__main__": # This will download the weights for vgg16 + generate_fakeimagenet() torchvision.models.vgg16(pretrained=True) diff --git a/benchmarks/super-slomo/slomo/train.py b/benchmarks/super-slomo/slomo/train.py index c4a4493fa..431102b42 100644 --- a/benchmarks/super-slomo/slomo/train.py +++ b/benchmarks/super-slomo/slomo/train.py @@ -10,9 +10,11 @@ import torchcompat.core as accelerator import model -from giving import give import voir.wrapper +import torchvision.transforms as transforms + from synth import SyntheticData +import dataloader def main(): @@ -75,6 +77,12 @@ def main(): action="store_false", help="do not allow tf32", ) + parser.add_argument( + "--loader", + type=str, + default="synthetic", + help="Dataloader to use", + ) args = parser.parse_args() @@ -96,35 +104,52 @@ def main(): validationFlowBackWarp = validationFlowBackWarp.to(device) ###Load Datasets + def load_dataset(): + if args.loader == "synthetic": + def igen(): + sz = 352 + f0 = torch.rand((3, sz, sz)) * 2 - 1 + ft = torch.rand((3, sz, sz)) * 2 - 1 + f1 = torch.rand((3, sz, sz)) * 2 - 1 + return [f0, ft, f1] + + def ogen(): + return torch.randint(0, 7, ()) + + trainset = SyntheticData( + n=args.train_batch_size, + repeat=10000, + generators=[igen, ogen] + ) - # # Channel wise mean calculated on adobe240-fps training dataset - # mean = [0.429, 0.431, 0.397] - # std = [1, 1, 1] - # normalize = transforms.Normalize(mean=mean, - # std=std) - # transform = transforms.Compose([transforms.ToTensor(), normalize]) + return torch.utils.data.DataLoader( + trainset, + batch_size=args.train_batch_size, + num_workers=8 + ) - # trainset = dataloader.SuperSloMo(root=args.dataset_root + '/train', transform=transform, train=True) - # trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=False) + # Channel wise mean calculated on adobe240-fps training dataset + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.429, 0.431, 0.397], + std=[1, 1, 1] + ) + ]) - def igen(): - sz = 352 - f0 = torch.rand((3, sz, sz)) * 2 - 1 - ft = torch.rand((3, sz, sz)) * 2 - 1 - f1 = torch.rand((3, sz, sz)) * 2 - 1 - return [f0, ft, f1] + trainset = dataloader.SuperSloMo(root=args.dataset_root + '/train', transform=transform, train=True) - def ogen(): - return torch.randint(0, 7, ()) + too_small = [] + for i, p in enumerate(trainset.framesPath): + if len(p) < 12: + too_small.append(i) - trainset = SyntheticData( - n=args.train_batch_size, repeat=10000, generators=[igen, ogen] - ) - trainloader = torch.utils.data.DataLoader( - trainset, - batch_size=args.train_batch_size, - num_workers=8 - ) + for i in reversed(too_small): + del trainset.framesPath[i] + + return torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=False) + + trainloader = load_dataset() ###Utils diff --git a/benchmarks/timm/voirfile.py b/benchmarks/timm/voirfile.py index 975446ba7..6058adef8 100644 --- a/benchmarks/timm/voirfile.py +++ b/benchmarks/timm/voirfile.py @@ -32,10 +32,9 @@ def instrument_main(ov, options: Config): import os import torchcompat.core as accelerator - from voir.wrapper import DataloaderWrapper, Wrapper + from voir.wrapper import Wrapper from timm.utils.distributed import is_global_primary - from timm.data import create_loader wrapper = Wrapper( accelerator.Event, diff --git a/benchmate/benchmate/datagen.py b/benchmate/benchmate/datagen.py index 7299a07f8..d3eacecf3 100644 --- a/benchmate/benchmate/datagen.py +++ b/benchmate/benchmate/datagen.py @@ -22,7 +22,7 @@ def write(args): offset, outdir, size = args img = torch.randn(*size) - target = torch.randint(0, 1000, size=(1,), dtype=torch.long)[0] + target = offset % 1000 # torch.randint(0, 1000, size=(1,), dtype=torch.long)[0] img = transforms.ToPILImage()(img) class_val = int(target) diff --git a/benchmate/benchmate/dataset.py b/benchmate/benchmate/dataset.py index 3c4b9d053..89ac90e4c 100644 --- a/benchmate/benchmate/dataset.py +++ b/benchmate/benchmate/dataset.py @@ -1,5 +1,6 @@ - +import os +from collections import defaultdict def no_transform(args): @@ -30,3 +31,20 @@ def __len__(self): def __getitem__(self, item): return self.transforms(self.dataset[item]) + + + +class ImageNetAsFrames: + def __init__(self, folder) -> None: + self.clip = defaultdict(list) + for root, _, files in os.walk(folder): + clip_id = root.split("/")[-1] + video = self.clip[clip_id] + for frame in files: + video.append(frame) + + def __getitem__(self, item): + return self.clip[item] + + def __len__(self): + return len(self.clip) diff --git a/config/base.yaml b/config/base.yaml index ece7820ff..e41960dc7 100644 --- a/config/base.yaml +++ b/config/base.yaml @@ -506,7 +506,6 @@ super-slomo: - video-interpolation - unet - convnet - - noio definition: ../benchmarks/super-slomo group: super-slomo install_group: torch @@ -514,6 +513,8 @@ super-slomo: method: per_gpu argv: --train_batch_size: 32 + --dataset_root: /tmp/milabench/cuda/results/data/FakeImageNet + --loader: pytorch ppo: inherits: _sb3