Skip to content

Commit

Permalink
Refactor ImageNet-related code
Browse files Browse the repository at this point in the history
  • Loading branch information
pierre.delaunay committed Jun 6, 2024
1 parent 17d2db5 commit 6774538
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 521 deletions.
4 changes: 0 additions & 4 deletions benchmarks/torchvision/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,6 @@ def _main():
trainbench(args)

def iobench(args):
data_directory = os.environ.get("MILABENCH_DIR_DATA", None)
if args.data is None and data_directory:
args.data = os.path.join(data_directory, "FakeImageNet")

device = accelerator.fetch_device(0)
model = getattr(tvmodels, args.model)()
model.to(device)
Expand Down
61 changes: 8 additions & 53 deletions benchmarks/torchvision_ddp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from voir.smuggle import SmuggleWriter
from giving import give, given
import torchcompat.core as accelerator
from benchmate.dataloader import imagenet_dataloader, dataloader_arguments


def ddp_setup(rank, world_size):
Expand Down Expand Up @@ -114,50 +115,16 @@ def image_transforms():
)
return data_transforms

def prepare_dataloader(dataset: Dataset, args):
    """Wrap *dataset* in a DataLoader driven by a DistributedSampler.

    Shuffling is delegated to the sampler (hence shuffle=False).  When
    --noio is set, worker processes and pinned memory are disabled so the
    loader itself adds no I/O machinery.
    """
    sampler = DistributedSampler(dataset)
    use_io = not args.noio

    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers if use_io else 0,
        pin_memory=use_io,
        shuffle=False,
        sampler=sampler,
    )

class FakeDataset:
    """Synthetic in-memory stand-in for an ImageNet-style dataset.

    Pre-builds 60 * batch_size (image, label) pairs at construction so
    iteration performs no disk I/O; images are random 3x224x224 tensors
    and labels cycle over [0, 1000).
    """

    def __init__(self, args):
        n_samples = 60 * args.batch_size
        self.data = [
            (torch.randn((3, 224, 224)), index % 1000)
            for index in range(n_samples)
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]


def dataset(args):
    """Return the training dataset.

    With --noio an in-memory FakeDataset is used; otherwise an
    ImageFolder rooted at args.data/train, where args.data defaults to
    $MILABENCH_DIR_DATA/FakeImageNet when unset.
    """
    if args.noio:
        return FakeDataset(args)

    default_root = os.environ.get("MILABENCH_DIR_DATA", None)
    if args.data is None and default_root:
        args.data = os.path.join(default_root, "FakeImageNet")

    return datasets.ImageFolder(os.path.join(args.data, "train"), image_transforms())
def prepare_dataloader(args, model, rank, world_size):
    """Build the ImageNet-style dataloader via the shared benchmate helper."""
    loader = imagenet_dataloader(args, model, rank, world_size)
    return loader


def load_train_objs(args):
    """Build the model and optimizer for one DDP worker.

    Returns:
        (model, optimizer) — the dataloader is created separately by
        prepare_dataloader, so no dataset is built or returned here.

    The stale ``train = dataset(args)`` line and the unreachable duplicate
    return have been removed: the caller unpacks exactly two values
    (``model, optimizer = load_train_objs(args)``).
    """
    model = getattr(torchvision_models, args.model)()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return model, optimizer


def worker_main(rank: int, world_size: int, args):
Expand All @@ -166,8 +133,9 @@ def worker_main(rank: int, world_size: int, args):

ddp_setup(rank, world_size)

dataset, model, optimizer = load_train_objs(args)
train_data = prepare_dataloader(dataset, args)
model, optimizer = load_train_objs(args)

train_data = prepare_dataloader(args, model, rank, world_size)

trainer = Trainer(model, train_data, optimizer, rank, world_size)

Expand All @@ -184,7 +152,7 @@ def worker_main(rank: int, world_size: int, args):

def main():
parser = argparse.ArgumentParser(description='simple distributed training job')
parser.add_argument('--batch-size', default=512, type=int, help='Input batch size on each device (default: 32)')
dataloader_arguments(parser)
parser.add_argument(
"--model", type=str, help="torchvision model name", default="resnet50"
)
Expand All @@ -195,26 +163,13 @@ def main():
metavar="N",
help="number of epochs to train (default: 10)",
)
parser.add_argument(
"--num-workers",
type=int,
default=8,
help="number of workers for data loading",
)
parser.add_argument(
"--noio",
action='store_true',
default=False,
help="Disable IO by providing an in memory dataset",
)
parser.add_argument(
"--precision",
type=str,
choices=["fp16", "fp32", "tf32", "tf32-fp16"],
default="fp32",
help="Precision configuration",
)
parser.add_argument("--data", type=str, help="data directory")
args = parser.parse_args()

world_size = accelerator.device_count()
Expand Down
57 changes: 2 additions & 55 deletions benchmarks/torchvision_ddp/prepare.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,6 @@
#!/usr/bin/env python

import multiprocessing
import os
import shutil
from pathlib import Path

from tqdm import tqdm


def write(args):
    """Render one fake ImageNet sample to disk.

    *args* is an (image_size, offset, count, outdir) tuple; the sample is
    saved as <outdir>/<class>/<offset>.jpeg, creating the class directory
    on demand.
    """
    from torchvision.datasets import FakeData

    image_size, offset, count, outdir = args

    dataset = FakeData(
        size=count, image_size=image_size, num_classes=1000, random_offset=offset
    )
    image, label = next(iter(dataset))

    class_dir = os.path.join(outdir, str(int(label)))
    os.makedirs(class_dir, exist_ok=True)
    image.save(os.path.join(class_dir, f"{offset}.jpeg"))


def generate(image_size, n, outdir):
    """Generate *n* fake images into *outdir* using a process pool.

    Args:
        image_size: image shape tuple forwarded to write().
        n: number of images to generate.
        outdir: destination directory (per-class subdirs created by write()).
    """
    p_count = min(multiprocessing.cpu_count(), 8)
    # Context-manage the pool so workers are reliably cleaned up even if a
    # task raises; the original never closed or joined the pool (leak).
    with multiprocessing.Pool(p_count) as pool:
        for _ in tqdm(
            pool.imap_unordered(write, ((image_size, i, n, outdir) for i in range(n))),
            total=n,
        ):
            pass


def generate_sets(root, sets, shape):
    """Generate each dataset split under *root*, skipping completed work.

    Args:
        root: destination directory (str or Path).
        sets: mapping of split name -> number of images (e.g. {"train": 4096}).
        shape: image shape forwarded to generate().

    A ``done`` sentinel file marks a completed generation; if it exists the
    function returns immediately.  A partial (unmarked) tree is deleted and
    regenerated from scratch.
    """
    root = Path(root)
    sentinel = root / "done"
    if sentinel.exists():
        print(f"{root} was already generated")
        return
    if root.exists():
        print(f"{root} exists but is not marked complete; deleting")
        # Bug fix: pathlib.Path has no .rm() method — the original raised
        # AttributeError here; shutil.rmtree removes the partial tree.
        shutil.rmtree(root)
    for name, n in sets.items():
        print(f"Generating {name}")
        generate(shape, n, os.path.join(root, name))
    sentinel.touch()

from benchmate.datagen import generate_fakeimagenet

if __name__ == "__main__":
    # Fake-ImageNet generation is delegated entirely to the shared
    # benchmate helper; the local MILABENCH_DIR_DATA/generate_sets path is
    # superseded (keeping both would generate the data twice).
    # NOTE(review): presumably generate_fakeimagenet() reads
    # MILABENCH_DIR_DATA itself — confirm against benchmate.datagen.
    generate_fakeimagenet()
21 changes: 0 additions & 21 deletions benchmate/benchmate/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,6 @@
from tqdm import tqdm



class FakeInMemoryDataset:
    """Dataset fully materialized in memory at construction time.

    ``producer(i)`` is invoked once per index in
    [0, batch_size * batch_count) and the results are cached, so
    __getitem__ does no per-access work.
    """

    def __init__(self, producer, batch_size, batch_count):
        total = batch_size * batch_count
        self.data = [producer(index) for index in range(total)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]


class FakeImageClassification(FakeInMemoryDataset):
    """In-memory fake image-classification dataset.

    Samples are (random tensor of *shape*, label) pairs with labels
    cycling over [0, 1000).
    """

    def __init__(self, shape, batch_size, batch_count):
        super().__init__(
            lambda i: (torch.randn(shape), i % 1000), batch_size, batch_count
        )



def write(args):
import torch
import torchvision.transforms as transforms
Expand Down
Loading

0 comments on commit 6774538

Please sign in to comment.