Skip to content

Commit

Permalink
Reduce reliance on env var, use args placeholder instead
Browse files Browse the repository at this point in the history
  • Loading branch information
pierre.delaunay committed Jun 6, 2024
1 parent 2c1aafb commit 4334571
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 39 deletions.
5 changes: 5 additions & 0 deletions benchmarks/huggingface/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,8 @@
print(f"Preparing {args.model}")
make_config = models[args.model]
make_config()

# bert dataset
# t5 dataset
# reformer dataset
# whisper dataset
11 changes: 2 additions & 9 deletions benchmarks/super-slomo/prepare.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
#!/usr/bin/env python

import torchvision



def download_celebA():
    """Intentionally a no-op: downloading celebA is currently broken.

    celebA is hosted on Google Drive, which interrupts automated downloads
    with a "can't scan this file for viruses" interstitial page, so the
    fetch fails.  torchvision 0.17.1 may work around this, but that
    version is not available here yet.
    """
    pass

from benchmate.datagen import generate_fakeimagenet

if __name__ == "__main__":
    # Pre-generate the synthetic FakeImageNet dataset so the benchmark
    # run itself does not pay the data-generation cost.
    generate_fakeimagenet()
    # This will download the weights for vgg16 so the benchmark does no
    # network I/O at run time.
    # NOTE(review): `pretrained=True` is deprecated in newer torchvision
    # in favor of `weights=...` — confirm the pinned torchvision version
    # before modernizing this call.
    torchvision.models.vgg16(pretrained=True)
75 changes: 50 additions & 25 deletions benchmarks/super-slomo/slomo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
import torchcompat.core as accelerator

import model
from giving import give
import voir.wrapper
import torchvision.transforms as transforms

from synth import SyntheticData
import dataloader


def main():
Expand Down Expand Up @@ -75,6 +77,12 @@ def main():
action="store_false",
help="do not allow tf32",
)
parser.add_argument(
"--loader",
type=str,
default="synthetic",
help="Dataloader to use",
)

args = parser.parse_args()

Expand All @@ -96,35 +104,52 @@ def main():
validationFlowBackWarp = validationFlowBackWarp.to(device)

###Load Datasets
def load_dataset():
if args.loader == "synthetic":
def igen():
sz = 352
f0 = torch.rand((3, sz, sz)) * 2 - 1
ft = torch.rand((3, sz, sz)) * 2 - 1
f1 = torch.rand((3, sz, sz)) * 2 - 1
return [f0, ft, f1]

def ogen():
return torch.randint(0, 7, ())

trainset = SyntheticData(
n=args.train_batch_size,
repeat=10000,
generators=[igen, ogen]
)

# # Channel wise mean calculated on adobe240-fps training dataset
# mean = [0.429, 0.431, 0.397]
# std = [1, 1, 1]
# normalize = transforms.Normalize(mean=mean,
# std=std)
# transform = transforms.Compose([transforms.ToTensor(), normalize])
return torch.utils.data.DataLoader(
trainset,
batch_size=args.train_batch_size,
num_workers=8
)

# trainset = dataloader.SuperSloMo(root=args.dataset_root + '/train', transform=transform, train=True)
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=False)
# Channel wise mean calculated on adobe240-fps training dataset
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=[0.429, 0.431, 0.397],
std=[1, 1, 1]
)
])

def igen():
sz = 352
f0 = torch.rand((3, sz, sz)) * 2 - 1
ft = torch.rand((3, sz, sz)) * 2 - 1
f1 = torch.rand((3, sz, sz)) * 2 - 1
return [f0, ft, f1]
trainset = dataloader.SuperSloMo(root=args.dataset_root + '/train', transform=transform, train=True)

def ogen():
return torch.randint(0, 7, ())
too_small = []
for i, p in enumerate(trainset.framesPath):
if len(p) < 12:
too_small.append(i)

trainset = SyntheticData(
n=args.train_batch_size, repeat=10000, generators=[igen, ogen]
)
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=args.train_batch_size,
num_workers=8
)
for i in reversed(too_small):
del trainset.framesPath[i]

return torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=False)

trainloader = load_dataset()

###Utils

Expand Down
3 changes: 1 addition & 2 deletions benchmarks/timm/voirfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,9 @@ def instrument_main(ov, options: Config):

import os
import torchcompat.core as accelerator
from voir.wrapper import DataloaderWrapper, Wrapper
from voir.wrapper import Wrapper

from timm.utils.distributed import is_global_primary
from timm.data import create_loader

wrapper = Wrapper(
accelerator.Event,
Expand Down
2 changes: 1 addition & 1 deletion benchmate/benchmate/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def write(args):
offset, outdir, size = args

img = torch.randn(*size)
target = torch.randint(0, 1000, size=(1,), dtype=torch.long)[0]
target = offset % 1000 # torch.randint(0, 1000, size=(1,), dtype=torch.long)[0]
img = transforms.ToPILImage()(img)

class_val = int(target)
Expand Down
20 changes: 19 additions & 1 deletion benchmate/benchmate/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@


import os
from collections import defaultdict


def no_transform(args):
Expand Down Expand Up @@ -30,3 +31,20 @@ def __len__(self):

def __getitem__(self, item):
    # Fetch the raw sample at `item` from the wrapped dataset and apply
    # the configured transform pipeline before returning it.
    return self.transforms(self.dataset[item])



class ImageNetAsFrames:
    """Expose an ImageNet-style directory tree as a mapping of clip id to frames.

    Each directory under ``folder`` is treated as one video "clip" whose id is
    the directory's base name; the files it contains are the clip's frames
    (file names only, sorted for a deterministic frame order).
    """

    def __init__(self, folder) -> None:
        # defaultdict while building: the first frame of a clip creates its list.
        self.clip = defaultdict(list)
        for root, _, files in os.walk(folder):
            # Use basename(normpath(...)) instead of root.split("/")[-1] so the
            # clip id is correct with trailing separators and on Windows.
            clip_id = os.path.basename(os.path.normpath(root))
            video = self.clip[clip_id]
            # os.walk yields files in arbitrary order; sort so the frame
            # sequence is deterministic across runs and filesystems.
            for frame in sorted(files):
                video.append(frame)

    def __getitem__(self, item):
        # Use .get so looking up an unknown clip does NOT insert an empty
        # entry: plain defaultdict indexing would mutate the mapping and
        # grow __len__ as a side effect of failed lookups.  Unknown clips
        # still return [] as before.
        return self.clip.get(item, [])

    def __len__(self):
        # Number of clips discovered (includes the root folder's own entry).
        return len(self.clip)
3 changes: 2 additions & 1 deletion config/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -506,14 +506,15 @@ super-slomo:
- video-interpolation
- unet
- convnet
- noio
definition: ../benchmarks/super-slomo
group: super-slomo
install_group: torch
plan:
method: per_gpu
argv:
--train_batch_size: 32
--dataset_root: /tmp/milabench/cuda/results/data/FakeImageNet
--loader: pytorch

ppo:
inherits: _sb3
Expand Down

0 comments on commit 4334571

Please sign in to comment.