Skip to content

Commit

Permalink
Add Bayer2RGB network (#250)
Browse files Browse the repository at this point in the history
  • Loading branch information
asyatrhl authored Sep 18, 2023
1 parent ae0a475 commit e536a08
Show file tree
Hide file tree
Showing 8 changed files with 598 additions and 45 deletions.
15 changes: 15 additions & 0 deletions ai8x.py
Original file line number Diff line number Diff line change
Expand Up @@ -1869,3 +1869,18 @@ def _onnx_export_prep(m):
setattr(m, attr_str, ScalerONNX())

m.apply(_onnx_export_prep)


class bayer_filter:
"""
Implement bayer filter to rgb images
"""
def __call__(self, img):
out = torch.zeros(1, img.shape[1], img.shape[2])

out[0, 0::2, 1::2] = img[2, 0::2, 1::2]
out[0, 0::2, 0::2] = img[1, 0::2, 0::2]
out[0, 1::2, 1::2] = img[1, 1::2, 1::2]
out[0, 1::2, 0::2] = img[0, 1::2, 0::2]

return out
187 changes: 142 additions & 45 deletions datasets/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@
import os

import torchvision
from torch.utils.data import Dataset
from torchvision import transforms

import ai8x


def imagenet_get_datasets(data, load_train=True, load_test=True, input_size=112, folder=False):
def imagenet_get_datasets(data, load_train=True, load_test=True,
input_size=112, folder=False, augment_data=True):
"""
Load the ImageNet 2012 Classification dataset.
Expand All @@ -47,54 +49,101 @@ def imagenet_get_datasets(data, load_train=True, load_test=True, input_size=112,
"""
(data_dir, args) = data

if load_train:
train_transform = transforms.Compose([
transforms.RandomResizedCrop(input_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
ai8x.normalize(args=args),
])

if not folder:
train_dataset = torchvision.datasets.ImageNet(
os.path.join(data_dir, 'ImageNet'),
split='train',
transform=train_transform,
)
if augment_data:
if load_train:
train_transform = transforms.Compose([
transforms.RandomResizedCrop(input_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
ai8x.normalize(args=args),
])

if not folder:
train_dataset = torchvision.datasets.ImageNet(
os.path.join(data_dir, 'ImageNet'),
split='train',
transform=train_transform,
)
else:
train_dataset = torchvision.datasets.ImageFolder(
os.path.join(data_dir, 'ImageNet', 'train'),
transform=train_transform,
)
else:
train_dataset = torchvision.datasets.ImageFolder(
os.path.join(data_dir, 'ImageNet', 'train'),
transform=train_transform,
)
else:
train_dataset = None

if load_test:
test_transform = transforms.Compose([
transforms.Resize(int(input_size / 0.875)),
transforms.CenterCrop(input_size),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
ai8x.normalize(args=args),
])

if not folder:
test_dataset = torchvision.datasets.ImageNet(
os.path.join(data_dir, 'ImageNet'),
split='val',
transform=test_transform,
)
train_dataset = None

if load_test:
test_transform = transforms.Compose([
transforms.Resize(int(input_size / 0.875)),
transforms.CenterCrop(input_size),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
ai8x.normalize(args=args),
])

if not folder:
test_dataset = torchvision.datasets.ImageNet(
os.path.join(data_dir, 'ImageNet'),
split='val',
transform=test_transform,
)
else:
test_dataset = torchvision.datasets.ImageFolder(
os.path.join(data_dir, 'ImageNet', 'val'),
transform=test_transform,
)

if args.truncate_testset:
test_dataset.data = test_dataset.data[:1]
else:
test_dataset = torchvision.datasets.ImageFolder(
os.path.join(data_dir, 'ImageNet', 'val'),
transform=test_transform,
)
test_dataset = None

if args.truncate_testset:
test_dataset.data = test_dataset.data[:1] # type: ignore # .data exists
else:
test_dataset = None
if load_train:
train_transform = transforms.Compose([
transforms.RandomResizedCrop(input_size),
transforms.ToTensor(),
ai8x.normalize(args=args),
])

if not folder:
train_dataset = torchvision.datasets.ImageNet(
os.path.join(data_dir, 'ImageNet'),
split='train',
transform=train_transform,
)
else:
train_dataset = torchvision.datasets.ImageFolder(
os.path.join(data_dir, 'ImageNet', 'train'),
transform=train_transform,
)
else:
train_dataset = None

if load_test:
test_transform = transforms.Compose([
transforms.RandomResizedCrop(input_size),
transforms.ToTensor(),
ai8x.normalize(args=args),
])

if not folder:
test_dataset = torchvision.datasets.ImageNet(
os.path.join(data_dir, 'ImageNet'),
split='val',
transform=test_transform,
)
else:
test_dataset = torchvision.datasets.ImageFolder(
os.path.join(data_dir, 'ImageNet', 'val'),
transform=test_transform,
)

if args.truncate_testset:
test_dataset.data = test_dataset.data[:1]
else:
test_dataset = None

return train_dataset, test_dataset

Expand All @@ -108,6 +157,48 @@ def imagenetfolder_get_datasets(data, load_train=True, load_test=True, input_siz
return imagenet_get_datasets(data, load_train, load_test, input_size, folder=True)


class Bayer_Dataset_Adapter(Dataset):
"""
Implement the transforms to generate bayer filtered images from RGB images,
and fold the input data.
Change the target data as the input images.
"""
def __init__(self, dataset, fold_ratio):
self.dataset = dataset
self.fold_ratio = fold_ratio
self.transform = transforms.Compose([ai8x.bayer_filter(),
ai8x.fold(fold_ratio=fold_ratio),
])

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
image = self.dataset[idx][0]
data = self.transform(image)
return data, image


def imagenet_bayer_fold_2_get_dataset(data, load_train=True, load_test=True, fold_ratio=2):
"""
Load the ImageNet 2012 Classification dataset using ImageNet.
This function is used to modify the image dataset for debayerization network.
Obtain raw images from RGB images.
"""

train_dataset, test_dataset = imagenet_get_datasets(
data, load_train, load_test, input_size=128, augment_data=False
)

if load_train:
train_dataset = Bayer_Dataset_Adapter(train_dataset, fold_ratio=fold_ratio)

if load_test:
test_dataset = Bayer_Dataset_Adapter(test_dataset, fold_ratio=fold_ratio)

return train_dataset, test_dataset


datasets = [
{
'name': 'ImageNet',
Expand All @@ -121,4 +212,10 @@ def imagenetfolder_get_datasets(data, load_train=True, load_test=True, input_siz
'output': list(map(str, range(50))),
'loader': imagenetfolder_get_datasets,
},
{
'name': 'ImageNet_Bayer',
'input': (4, 64, 64),
'output': ('rgb'),
'loader': imagenet_bayer_fold_2_get_dataset,
}
]
56 changes: 56 additions & 0 deletions models/ai85net-bayer2rgbnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
###################################################################################################
#
# Copyright © 2023 Analog Devices, Inc. All Rights Reserved.
# This software is proprietary and confidential to Analog Devices, Inc. and its licensors.
#
###################################################################################################
"""
Bayer to Rgb network for AI85
"""
from torch import nn

import ai8x


class bayer2rgbnet(nn.Module):
"""
Bayer to RGB Network Model
"""
def __init__(
self,
num_classes=None, # pylint: disable=unused-argument
num_channels=4,
dimensions=(64, 64), # pylint: disable=unused-argument
bias=False,
**kwargs): # pylint: disable=unused-argument

super().__init__()

self.l1 = ai8x.Conv2d(num_channels, 3, kernel_size=1, padding=0, bias=bias)
self.l2 = ai8x.ConvTranspose2d(3, 3, kernel_size=3, padding=1, stride=2, bias=bias)
self.l3 = ai8x.Conv2d(3, 3, kernel_size=3, padding=1, bias=bias)

def forward(self, x):
"""Forward prop"""
x = self.l1(x)
x = self.l2(x)
x = self.l3(x)

return x


def b2rgb(pretrained: bool = False, **kwargs):
"""
Constructs a bayer2rgbnet model
"""
assert not pretrained
return bayer2rgbnet(**kwargs)


models = [
{
'name': 'bayer2rgbnet',
'min_input': 1,
'dim': 2,
},
]
Loading

0 comments on commit e536a08

Please sign in to comment.