diff --git a/benchmarks/torchvision/prepare.py b/benchmarks/torchvision/prepare.py index c39429311..d16dab2c5 100755 --- a/benchmarks/torchvision/prepare.py +++ b/benchmarks/torchvision/prepare.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import argparse -from collection import defaultdict +from collections import defaultdict import multiprocessing import os from pathlib import Path @@ -48,13 +48,10 @@ def generate(image_size, n, outdir, start = 0): def count_images(path): - count = defaultdict(0) + count = defaultdict(int) for root, _, files in tqdm(os.walk(path)): - try: - _, split, _ = root.split('/') - count[split] += len(files) - except: - pass + split = root.split('/')[-2] + count[split] += len(files) return count @@ -71,7 +68,7 @@ def generate_sets(root, sets, shape): if current_count < count: print(f"Generating {split} (current {current_count}) (target: {count})") - generate(shape, count, os.path.join(root, split), start=current_count) + generate(shape, count - current_count, os.path.join(root, split), start=current_count) with open(sentinel, "w") as fp: json.dump(sets, fp)