-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloaders.py
94 lines (73 loc) · 2.98 KB
/
dataloaders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
from utils import get_device
from torchvision import transforms, datasets
import os
import cv2
BATCH_SIZE = 512
image_size = 32

# Per-channel mean/std used to normalise CIFAR-10 RGB tensors (shared by the
# train and validation pipelines below).
# NOTE(review): these std values are the widely copied ones; the per-pixel std
# of CIFAR-10 is often quoted as roughly (0.247, 0.243, 0.261) — verify. Kept
# unchanged so existing checkpoints see identically normalised inputs.
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.2023, 0.1994, 0.2010]

# Training pipeline: resize, random horizontal flip augmentation, tensor
# conversion, normalisation.
train_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Evaluation pipeline: same as training but without random augmentation.
val_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# NOTE: these run at import time and download the dataset if it is missing.
# NOTE(review): train and test use different roots, so the CIFAR-10 archive is
# downloaded twice; a shared root would avoid that (not changed here to keep
# the existing on-disk layout).
cifar_train = datasets.CIFAR10(
    root="datasets/cifar_train", train=True, transform=train_transform, download=True)
cifar_test = datasets.CIFAR10(
    root="datasets/cifar_test", train=False, transform=val_transform, download=True)

# drop_last=True is fine for training (reshuffled every epoch, keeps batch
# sizes uniform).
cifar_train_loader = DataLoader(
    cifar_train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# Evaluation must see every test sample: with drop_last=True the final partial
# batch (10000 % 512 = 272 images) would be silently skipped, biasing metrics.
cifar_test_loader = DataLoader(
    cifar_test, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
# Custom dataset class
class CustomCIFAR10Dataset(Dataset):
    """Image-folder dataset whose class label is encoded in each file name.

    Every regular file directly inside ``image_dir`` is one sample. The label is
    the first CIFAR-10 class name (checked in ``label_map`` order,
    case-insensitively) that occurs as a substring of the file name, e.g.
    ``"00042_Cat.png"`` -> 3.

    Args:
        image_dir: Directory containing the image files.
        transform: Optional callable applied to the RGB image array produced by
            OpenCV before it is returned.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # Sorted so the index -> file mapping is reproducible (os.listdir order
        # is arbitrary); sub-directories are skipped since they are not images.
        self.image_files = sorted(
            name
            for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )
        self.label_map = {
            'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,
            'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9
        }

    def __len__(self):
        return len(self.image_files)

    def _label_from_name(self, image_name):
        """Return the class index encoded in *image_name*.

        Raises:
            ValueError: If no class name from ``label_map`` occurs in the name.
        """
        lowered = image_name.lower()  # match case-insensitively
        for label_name, label in self.label_map.items():
            if label_name in lowered:
                return label
        raise ValueError(f"Image name {image_name} does not match any label.")

    def __getitem__(self, idx):
        """Load sample *idx* and return ``(image, label)``.

        Raises:
            ValueError: If the file name does not contain a known class name.
            OSError: If OpenCV cannot decode the file as an image.
        """
        image_name = self.image_files[idx]
        image_path = os.path.join(self.image_dir, image_name)
        # Resolve the label first so a mis-named file fails before any I/O.
        label = self._label_from_name(image_name)
        image = cv2.imread(image_path)
        # cv2.imread reports failure by returning None instead of raising;
        # without this check cvtColor fails with an opaque assertion error.
        if image is None:
            raise OSError(f"Failed to read image file: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB
        if self.transform:
            image = self.transform(image)
        return image, label
# Preprocessing for CustomCIFAR10Dataset samples. ToPILImage comes first because
# the dataset hands over NumPy arrays decoded by OpenCV; the remaining steps
# resize to image_size x image_size (CIFAR-10's 32x32), apply a random
# horizontal flip, convert to a tensor and apply CIFAR-10 normalisation.
custom_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.4914, 0.4822, 0.4465],
            [0.2023, 0.1994, 0.2010],
        ),
    ]
)
# Create the custom dataset and its data loader.
# The image directory was previously hard-coded as an empty placeholder, and
# os.listdir("") raises FileNotFoundError, so importing this module failed
# outright (taking the CIFAR-10 loaders with it). The directory is now read
# from the CUSTOM_CIFAR10_DIR environment variable. When that is unset or not
# an existing directory, both objects are None instead of crashing the import.
# NOTE(review): an existing but empty directory will still make DataLoader
# (shuffle=True) raise, and fewer than BATCH_SIZE images yields zero batches
# because of drop_last=True.
CUSTOM_IMAGE_DIR = os.environ.get("CUSTOM_CIFAR10_DIR", "")
if os.path.isdir(CUSTOM_IMAGE_DIR):
    custom_train_dataset = CustomCIFAR10Dataset(
        image_dir=CUSTOM_IMAGE_DIR, transform=custom_transform)
    custom_train_loader = DataLoader(
        custom_train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
else:
    custom_train_dataset = None
    custom_train_loader = None