"""PowerScale Data Loader.

DALI data loader for image datasets: reads encoded image files from a
directory, decodes, resizes and normalizes them with NVIDIA DALI, and exposes
the resulting batches through DALI's PyTorch iterator.
"""
import os
import numpy as np
import argparse
import logging
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.plugin.pytorch import DALIGenericIterator
# Configure the root logger to emit INFO-and-above records, then create the
# module-level logger used by every function in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def _positive_int(text):
    """argparse ``type=`` callable: parse *text* as an int and require it to be >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args(argv=None):
    """Parse the loader's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace with ``data_dir`` (str), ``batch_size`` (int),
        ``image_size`` (int) and ``shuffle`` (bool).

    Raises:
        SystemExit: raised by argparse on invalid input, including a
            non-positive ``--batch_size`` or ``--image_size``.
    """
    parser = argparse.ArgumentParser(description="DALI data loader for image datasets")
    parser.add_argument("--data_dir", type=str, default="/mnt/RDMA/train2017",
                        help="Directory containing images")
    # Sizes must be >= 1: a zero batch size would make the input iterator
    # produce empty batches, and a zero image size is not a valid resize target.
    parser.add_argument("--batch_size", type=_positive_int, default=32,
                        help="Batch size for data loading")
    parser.add_argument("--image_size", type=_positive_int, default=224,
                        help="Size to resize images to")
    parser.add_argument("--shuffle", action="store_true",
                        help="Whether to shuffle the dataset")
    return parser.parse_args(argv)
class ExternalInputIterator:
    """Batch iterator that feeds encoded image files to a two-output external source.

    Each ``next()`` call returns a pair ``(images, labels)``:

    * ``images`` - list of 1-D ``uint8`` arrays holding the raw file bytes;
    * ``labels`` - list of ``int32`` arrays of shape ``(1,)``, one per image.

    The final batch of an epoch may contain fewer than ``batch_size`` samples.
    Files that cannot be read are logged and skipped.

    Args:
        files: Sequence of file paths. It is copied, so shuffling never
            reorders the caller's list.
        batch_size: Maximum number of samples per batch.
        shuffle: When true, the sample order is re-randomized on construction
            and at the start of every epoch (every ``iter()`` call).
        labels: Optional sequence of integer labels parallel to ``files``.
            When omitted every sample gets the placeholder label ``0``.

    Raises:
        ValueError: if ``labels`` is given and its length differs from ``files``.
    """

    def __init__(self, files, batch_size, shuffle=False, labels=None):
        self.files = list(files)
        if labels is None:
            # No class information is available from file paths alone.
            self.labels = [0] * len(self.files)
        else:
            self.labels = list(labels)
            if len(self.labels) != len(self.files):
                raise ValueError(
                    f"labels has {len(self.labels)} entries but files has {len(self.files)}"
                )
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.index = 0
        if self.shuffle:
            self._reshuffle()

    def _reshuffle(self):
        """Apply one random permutation to files and labels so pairs stay aligned."""
        order = np.random.permutation(len(self.files))
        self.files = [self.files[i] for i in order]
        self.labels = [self.labels[i] for i in order]

    def __iter__(self):
        # Restart the epoch from the first sample.
        self.index = 0
        if self.shuffle:
            self._reshuffle()
        return self

    def __next__(self):
        images = []
        labels = []
        # Keep consuming files until the batch is full or the data runs out,
        # so an unreadable file does not shrink the batch.
        while len(images) < self.batch_size and self.index < len(self.files):
            path = self.files[self.index]
            label = self.labels[self.index]
            self.index += 1
            try:
                with open(path, 'rb') as f:
                    encoded = np.frombuffer(f.read(), dtype=np.uint8)
            except OSError as e:
                logger.error(f"Error reading file {path}: {e}")
                continue
            images.append(encoded)
            labels.append(np.array([label], dtype=np.int32))
        if not images:
            # Nothing left (or nothing readable): end the epoch rather than
            # hand an empty batch to the consumer.
            raise StopIteration
        return images, labels
@pipeline_def
def create_pipeline(image_size, source=None):
    """Define the DALI graph: external source -> decode -> resize -> normalize.

    ``batch_size``, ``num_threads`` and ``device_id`` are passed as keyword
    arguments at call time and consumed by the ``@pipeline_def`` decorator.

    Args:
        image_size: Edge length, in pixels, of the square output images.
        source: Iterable that yields ``(encoded_images, labels)`` per batch,
            e.g. an ``ExternalInputIterator``. Required; it is a keyword with
            a default only so existing keyword-style calls keep their shape.

    Returns:
        Tuple ``(images, labels)`` of DALI data nodes: normalized float CHW
        images produced on the GPU and the labels passed through from
        ``source``.

    Raises:
        ValueError: if ``source`` is not provided.
    """
    if source is None:
        raise ValueError("create_pipeline requires source=<batch iterator>")
    jpegs, labels = fn.external_source(source=source, num_outputs=2, device="cpu")
    images = fn.decoders.image(jpegs, device="mixed")
    # One extent per spatial dimension, so every sample matches the crop
    # window below and the batch has a uniform shape.
    images = fn.resize(images, device="gpu", size=(image_size, image_size))
    images = fn.crop_mirror_normalize(
        images,
        device="gpu",
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(image_size, image_size),
        # Per-channel mean/std scaled by 255 because pixels are still in [0, 255].
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
    )
    return images, labels


def main():
    """Build the pipeline for the configured directory and load one batch as a smoke test.

    Raises:
        SystemExit: if the data directory contains no supported image files.
    """
    args = parse_args()
    logger.info(f"Loading images from {args.data_dir}")
    # Sorted so the unshuffled order is reproducible; os.listdir order is arbitrary.
    image_files = [
        os.path.join(args.data_dir, name)
        for name in sorted(os.listdir(args.data_dir))
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    ]
    logger.info(f"Found {len(image_files)} images")
    if not image_files:
        # Fail early with a clear message instead of building a pipeline
        # whose source has nothing to yield.
        raise SystemExit(f"No .jpg/.jpeg/.png images found in {args.data_dir}")

    eii = ExternalInputIterator(image_files, args.batch_size, shuffle=args.shuffle)
    # The iterator is handed to the pipeline explicitly; it is a local of this
    # function and is not visible to create_pipeline otherwise.
    pipe = create_pipeline(
        batch_size=args.batch_size,
        num_threads=4,
        device_id=0,
        image_size=args.image_size,
        source=eii,
    )
    pipe.build()
    dali_iter = DALIGenericIterator(pipe, ['data', 'label'], size=len(image_files))
    for i, data in enumerate(dali_iter):
        logger.info(f"Batch {i}: Data loaded successfully!")
        logger.info(f"Shape of the batch: {data[0]['data'].shape}")
        break  # Just check the first batch
# Run the loader only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()