From e536a08a02b7f687e8da873235e4ed7fc0d16100 Mon Sep 17 00:00:00 2001
From: asyatrhl <123384000+asyatrhl@users.noreply.github.com>
Date: Mon, 18 Sep 2023 16:13:18 +0300
Subject: [PATCH] Add Bayer2RGB network (#250)

---
 ai8x.py                                |  15 +
 datasets/imagenet.py                   | 187 +++++++++---
 models/ai85net-bayer2rgbnet.py         |  56 ++++
 notebooks/Bayer2RGB_Evaluation.ipynb   | 375 +++++++++++++++++++++++++
 policies/qat_policy_bayer2rgbnet.yaml  |   3 +
 scripts/evaluate_imagenet_bayer2rgb.sh |   2 +
 scripts/train_bayer2rgb_imagenet.sh    |   2 +
 train_all_models.sh                    |   3 +
 8 files changed, 598 insertions(+), 45 deletions(-)
 create mode 100644 models/ai85net-bayer2rgbnet.py
 create mode 100644 notebooks/Bayer2RGB_Evaluation.ipynb
 create mode 100644 policies/qat_policy_bayer2rgbnet.yaml
 create mode 100755 scripts/evaluate_imagenet_bayer2rgb.sh
 create mode 100755 scripts/train_bayer2rgb_imagenet.sh

diff --git a/ai8x.py b/ai8x.py
index 98c9435f6..f018b9772 100644
--- a/ai8x.py
+++ b/ai8x.py
@@ -1869,3 +1869,18 @@ def _onnx_export_prep(m):
                 setattr(m, attr_str, ScalerONNX())
 
     m.apply(_onnx_export_prep)
+
+
+class bayer_filter:
+    """
+    Apply a Bayer filter (color mosaic) to an RGB image, producing a one-channel raw image
+    """
+    def __call__(self, img):
+        out = torch.zeros(1, img.shape[1], img.shape[2])
+
+        out[0, 0::2, 1::2] = img[2, 0::2, 1::2]  # blue at even rows, odd columns
+        out[0, 0::2, 0::2] = img[1, 0::2, 0::2]  # green at even rows, even columns
+        out[0, 1::2, 1::2] = img[1, 1::2, 1::2]  # green at odd rows, odd columns
+        out[0, 1::2, 0::2] = img[0, 1::2, 0::2]  # red at odd rows, even columns
+
+        return out
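For reference, here is a minimal sketch of what the new transform produces on a constant-color image, assuming a C x H x W float tensor in RGB channel order as delivered by torchvision's ToTensor(); the __call__ body is copied from the hunk above:

import torch

class bayer_filter:
    """Copy of the transform added to ai8x.py above."""
    def __call__(self, img):
        out = torch.zeros(1, img.shape[1], img.shape[2])
        out[0, 0::2, 1::2] = img[2, 0::2, 1::2]
        out[0, 0::2, 0::2] = img[1, 0::2, 0::2]
        out[0, 1::2, 1::2] = img[1, 1::2, 1::2]
        out[0, 1::2, 0::2] = img[0, 1::2, 0::2]
        return out

# Constant-color test image: R=1, G=2, B=3
img = torch.stack([torch.full((4, 4), v) for v in (1.0, 2.0, 3.0)])
print(bayer_filter()(img)[0])
# tensor([[2., 3., 2., 3.],
#         [1., 2., 1., 2.],
#         [2., 3., 2., 3.],
#         [1., 2., 1., 2.]])
# Even rows alternate green/blue, odd rows red/green (a GBRG mosaic).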
diff --git a/datasets/imagenet.py b/datasets/imagenet.py
index a7d5999bc..36b69bcc4 100644
--- a/datasets/imagenet.py
+++ b/datasets/imagenet.py
@@ -27,12 +27,14 @@
 import os
 
 import torchvision
+from torch.utils.data import Dataset
 from torchvision import transforms
 
 import ai8x
 
 
-def imagenet_get_datasets(data, load_train=True, load_test=True, input_size=112, folder=False):
+def imagenet_get_datasets(data, load_train=True, load_test=True,
+                          input_size=112, folder=False, augment_data=True):
     """
     Load the ImageNet 2012 Classification dataset.
@@ -47,54 +49,101 @@ def imagenet_get_datasets(data, load_train=True, load_test=True, input_size=112,
     """
     (data_dir, args) = data
 
-    if load_train:
-        train_transform = transforms.Compose([
-            transforms.RandomResizedCrop(input_size),
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
-            ai8x.normalize(args=args),
-        ])
-
-        if not folder:
-            train_dataset = torchvision.datasets.ImageNet(
-                os.path.join(data_dir, 'ImageNet'),
-                split='train',
-                transform=train_transform,
-            )
+    if augment_data:
+        if load_train:
+            train_transform = transforms.Compose([
+                transforms.RandomResizedCrop(input_size),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+                ai8x.normalize(args=args),
+            ])
+
+            if not folder:
+                train_dataset = torchvision.datasets.ImageNet(
+                    os.path.join(data_dir, 'ImageNet'),
+                    split='train',
+                    transform=train_transform,
+                )
+            else:
+                train_dataset = torchvision.datasets.ImageFolder(
+                    os.path.join(data_dir, 'ImageNet', 'train'),
+                    transform=train_transform,
+                )
         else:
-            train_dataset = torchvision.datasets.ImageFolder(
-                os.path.join(data_dir, 'ImageNet', 'train'),
-                transform=train_transform,
-            )
-    else:
-        train_dataset = None
-
-    if load_test:
-        test_transform = transforms.Compose([
-            transforms.Resize(int(input_size / 0.875)),
-            transforms.CenterCrop(input_size),
-            transforms.ToTensor(),
-            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
-            ai8x.normalize(args=args),
-        ])
-
-        if not folder:
-            test_dataset = torchvision.datasets.ImageNet(
-                os.path.join(data_dir, 'ImageNet'),
-                split='val',
-                transform=test_transform,
-            )
+            train_dataset = None
+
+        if load_test:
+            test_transform = transforms.Compose([
+                transforms.Resize(int(input_size / 0.875)),
+                transforms.CenterCrop(input_size),
+                transforms.ToTensor(),
+                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+                ai8x.normalize(args=args),
+            ])
+
+            if not folder:
+                test_dataset = torchvision.datasets.ImageNet(
+                    os.path.join(data_dir, 'ImageNet'),
+                    split='val',
+                    transform=test_transform,
+                )
+            else:
+                test_dataset = torchvision.datasets.ImageFolder(
+                    os.path.join(data_dir, 'ImageNet', 'val'),
+                    transform=test_transform,
+                )
+
+            if args.truncate_testset:
+                test_dataset.data = test_dataset.data[:1]
         else:
-            test_dataset = torchvision.datasets.ImageFolder(
-                os.path.join(data_dir, 'ImageNet', 'val'),
-                transform=test_transform,
-            )
+            test_dataset = None
 
-        if args.truncate_testset:
-            test_dataset.data = test_dataset.data[:1]  # type: ignore # .data exists
     else:
-        test_dataset = None
+        if load_train:
+            train_transform = transforms.Compose([
+                transforms.RandomResizedCrop(input_size),
+                transforms.ToTensor(),
+                ai8x.normalize(args=args),
+            ])
+
+            if not folder:
+                train_dataset = torchvision.datasets.ImageNet(
+                    os.path.join(data_dir, 'ImageNet'),
+                    split='train',
+                    transform=train_transform,
+                )
+            else:
+                train_dataset = torchvision.datasets.ImageFolder(
+                    os.path.join(data_dir, 'ImageNet', 'train'),
+                    transform=train_transform,
+                )
+        else:
+            train_dataset = None
+
+        if load_test:
+            test_transform = transforms.Compose([
+                transforms.RandomResizedCrop(input_size),
+                transforms.ToTensor(),
+                ai8x.normalize(args=args),
+            ])
+
+            if not folder:
+                test_dataset = torchvision.datasets.ImageNet(
+                    os.path.join(data_dir, 'ImageNet'),
+                    split='val',
+                    transform=test_transform,
+                )
+            else:
+                test_dataset = torchvision.datasets.ImageFolder(
+                    os.path.join(data_dir, 'ImageNet', 'val'),
+                    transform=test_transform,
+                )
+
+            if args.truncate_testset:
+                test_dataset.data = test_dataset.data[:1]
+        else:
+            test_dataset = None
 
     return train_dataset, test_dataset
@@ -108,6 +157,48 @@ def imagenetfolder_get_datasets(data, load_train=True, load_test=True, input_siz
     """
     return imagenet_get_datasets(data, load_train, load_test, input_size, folder=True)
 
 
+class Bayer_Dataset_Adapter(Dataset):
+    """
+    Generate Bayer-filtered images from the RGB images of the wrapped dataset,
+    fold the filtered data, and return the original RGB image as the target.
+    """
+    def __init__(self, dataset, fold_ratio):
+        self.dataset = dataset
+        self.fold_ratio = fold_ratio
+        self.transform = transforms.Compose([ai8x.bayer_filter(),
+                                             ai8x.fold(fold_ratio=fold_ratio),
+                                             ])
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        image = self.dataset[idx][0]
+        data = self.transform(image)
+        return data, image
+
+
+def imagenet_bayer_fold_2_get_dataset(data, load_train=True, load_test=True, fold_ratio=2):
+    """
+    Load the ImageNet 2012 Classification dataset and adapt it for the
+    debayerization network: raw (Bayer) images are derived from the RGB
+    images, and the RGB images themselves serve as the targets.
+    """
+
+    train_dataset, test_dataset = imagenet_get_datasets(
+        data, load_train, load_test, input_size=128, augment_data=False
+    )
+
+    if load_train:
+        train_dataset = Bayer_Dataset_Adapter(train_dataset, fold_ratio=fold_ratio)
+
+    if load_test:
+        test_dataset = Bayer_Dataset_Adapter(test_dataset, fold_ratio=fold_ratio)
+
+    return train_dataset, test_dataset
+
+
 datasets = [
     {
         'name': 'ImageNet',
@@ -121,4 +212,10 @@ def imagenetfolder_get_datasets(data, load_train=True, load_test=True, input_siz
         'output': list(map(str, range(50))),
         'loader': imagenetfolder_get_datasets,
     },
+    {
+        'name': 'ImageNet_Bayer',
+        'input': (4, 64, 64),
+        'output': ('rgb'),
+        'loader': imagenet_bayer_fold_2_get_dataset,
+    }
 ]
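To make the data flow concrete, here is a hedged sketch of what the new ImageNet_Bayer loader returns per sample; the data path and the Args stand-in are placeholders, and ai8x.fold is the repository's existing folding (space-to-depth) transform used by the adapter above:

from datasets import imagenet

class Args:  # stand-in for the argparse namespace that train.py normally passes in
    act_mode_8bit = False
    truncate_testset = False

_, test_set = imagenet.imagenet_bayer_fold_2_get_dataset(
    ('/path/to/data', Args()), load_train=False, load_test=True, fold_ratio=2)

data, target = test_set[0]
print(data.shape)    # torch.Size([4, 64, 64]): 128x128 Bayer mosaic folded 2x2
print(target.shape)  # torch.Size([3, 128, 128]): the original RGB image (the target)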
diff --git a/models/ai85net-bayer2rgbnet.py b/models/ai85net-bayer2rgbnet.py
new file mode 100644
index 000000000..cb87e5f90
--- /dev/null
+++ b/models/ai85net-bayer2rgbnet.py
@@ -0,0 +1,56 @@
+###################################################################################################
+#
+# Copyright © 2023 Analog Devices, Inc. All Rights Reserved.
+# This software is proprietary and confidential to Analog Devices, Inc. and its licensors.
+#
+###################################################################################################
+"""
+Bayer to RGB network for AI85
+"""
+from torch import nn
+
+import ai8x
+
+
+class bayer2rgbnet(nn.Module):
+    """
+    Bayer to RGB Network Model
+    """
+    def __init__(
+            self,
+            num_classes=None,  # pylint: disable=unused-argument
+            num_channels=4,
+            dimensions=(64, 64),  # pylint: disable=unused-argument
+            bias=False,
+            **kwargs):  # pylint: disable=unused-argument
+
+        super().__init__()
+
+        self.l1 = ai8x.Conv2d(num_channels, 3, kernel_size=1, padding=0, bias=bias)
+        self.l2 = ai8x.ConvTranspose2d(3, 3, kernel_size=3, padding=1, stride=2, bias=bias)
+        self.l3 = ai8x.Conv2d(3, 3, kernel_size=3, padding=1, bias=bias)
+
+    def forward(self, x):
+        """Forward prop"""
+        x = self.l1(x)
+        x = self.l2(x)
+        x = self.l3(x)
+
+        return x
+
+
+def b2rgb(pretrained: bool = False, **kwargs):
+    """
+    Constructs a bayer2rgbnet model
+    """
+    assert not pretrained
+    return bayer2rgbnet(**kwargs)
+
+
+models = [
+    {
+        'name': 'bayer2rgbnet',
+        'min_input': 1,
+        'dim': 2,
+    },
+]
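A quick shape check of the three layers, sketched with plain torch equivalents; this assumes ai8x.ConvTranspose2d behaves like nn.ConvTranspose2d with output_padding=1, i.e. an exact 2x upsample on this hardware:

import torch
from torch import nn

l1 = nn.Conv2d(4, 3, kernel_size=1, padding=0, bias=False)
l2 = nn.ConvTranspose2d(3, 3, kernel_size=3, padding=1, stride=2,
                        output_padding=1, bias=False)
l3 = nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=False)

x = torch.randn(1, 4, 64, 64)  # folded Bayer input
print(l1(x).shape)             # torch.Size([1, 3, 64, 64])
print(l2(l1(x)).shape)         # torch.Size([1, 3, 128, 128])
print(l3(l2(l1(x))).shape)     # torch.Size([1, 3, 128, 128]), matches the RGB target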
diff --git a/notebooks/Bayer2RGB_Evaluation.ipynb b/notebooks/Bayer2RGB_Evaluation.ipynb
new file mode 100644
index 000000000..24059aec0
--- /dev/null
+++ b/notebooks/Bayer2RGB_Evaluation.ipynb
@@ -0,0 +1,375 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "c97a7c94",
+   "metadata": {},
+   "source": [
+    "# Bayer2RGB Usage Demonstration For Other Networks\n",
+    "\n",
+    "This notebook uses the Bayer2RGB model and its quantized checkpoint to debayer Bayer-filtered images, and compares the \"ai87net-imagenet-effnetv2\" classification accuracy obtained with Bayer2RGB debayering against a bilinear-interpolation baseline and the original RGB images."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a65b4955",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "###################################################################################################\n",
+    "#\n",
+    "# Copyright © 2023 Analog Devices, Inc. All Rights Reserved.\n",
+    "# This software is proprietary and confidential to Analog Devices, Inc. and its licensors.\n",
+    "#\n",
+    "###################################################################################################\n",
+    "import importlib\n",
+    "import os\n",
+    "import sys\n",
+    "\n",
+    "import cv2\n",
+    "import matplotlib.pyplot as plt\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "from torch.utils import data\n",
+    "\n",
+    "%matplotlib inline\n",
+    "\n",
+    "sys.path.append(os.path.dirname(os.getcwd()))\n",
+    "sys.path.append(os.path.join(os.path.dirname(os.getcwd()), 'models'))\n",
+    "\n",
+    "import ai8x\n",
+    "import parse_qat_yaml\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d336ad92",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sys.path.append(os.path.join(os.getcwd(), './models/'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "582c27d8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "b2rgb = importlib.import_module(\"ai85net-bayer2rgbnet\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "943067bc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Args:\n",
+    "    def __init__(self, act_mode_8bit):\n",
+    "        self.act_mode_8bit = act_mode_8bit\n",
+    "        self.truncate_testset = False"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8fc25a42",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "act_mode_8bit = True  # For evaluation mode, input/output range: -128, 127\n",
+    "\n",
+    "test_batch_size = 1\n",
+    "\n",
+    "args = Args(act_mode_8bit=act_mode_8bit)\n",
+    "\n",
+    "checkpoint_path_b2rgb = \"../../ai8x-synthesis/trained/ai85-b2rgb-qat8-q.pth.tar\"\n",
+    "\n",
+    "qat_yaml_file_used_in_training_b2rgb = '../policies/qat_policy_bayer2rgbnet.yaml'\n",
+    "\n",
+    "ai_device = 87\n",
+    "round_avg = True"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5336f275",
+   "metadata": {},
+   "source": [
+    "# ImageNet"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6d3ac0fe",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasets import imagenet\n",
+    "test_model = importlib.import_module('models.ai87net-imagenet-effnetv2')\n",
+    "data_path = '/data_ssd/'\n",
+    "checkpoint_path = \"../../ai8x-synthesis/trained/ai87-imagenet-effnet2-q.pth.tar\"\n",
+    "qat_yaml_file_used_in_training = '../policies/qat_policy_imagenet.yaml'\n",
+    "\n",
+    "# Dataset used for bilinear interpolation\n",
+    "_, test_set_inter = imagenet.imagenet_bayer_fold_2_get_dataset((data_path, args), load_train=False, load_test=True, fold_ratio=1)\n",
+    "\n",
+    "# Dataset used for Bayer2RGB debayerization\n",
+    "_, test_set = imagenet.imagenet_bayer_fold_2_get_dataset((data_path, args), load_train=False, load_test=True, fold_ratio=2)\n",
+    "\n",
+    "# Original dataset\n",
+    "_, test_set_original = imagenet.imagenet_get_datasets((data_path, args), load_train=False, load_test=True)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "acad3fd6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_dataloader_inter = data.DataLoader(test_set_inter, batch_size=test_batch_size, shuffle=False)\n",
+    "test_dataloader = data.DataLoader(test_set, batch_size=test_batch_size, shuffle=False)\n",
+    "test_dataloader_original = data.DataLoader(test_set_original, batch_size=test_batch_size, shuffle=False)\n",
+    "print(len(test_dataloader))\n",
+    "print(len(test_dataloader.dataset))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3cc02416",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
+    "\n",
+    "qat_policy_bayer2rgb = parse_qat_yaml.parse(qat_yaml_file_used_in_training_b2rgb)\n",
+    "qat_policy = parse_qat_yaml.parse(qat_yaml_file_used_in_training)\n",
+    "\n",
+    "ai8x.set_device(device=ai_device, simulate=act_mode_8bit, round_avg=round_avg)\n",
+    "\n",
+    "model_bayer2rgb = b2rgb.bayer2rgbnet().to(device)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "aa8e4c3a",
+   "metadata": {},
+   "source": [
+    "Instantiate the classification model to be evaluated."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1c308ad5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = test_model.AI87ImageNetEfficientNetV2(bias=True).to(device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1579d9d7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.state_dict()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "26533322",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_bayer2rgb.state_dict()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "37758e6b",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "# fuse the BN parameters into conv layers before Quantization Aware Training (QAT)\n",
+    "ai8x.fuse_bn_layers(model_bayer2rgb)\n",
+    "ai8x.fuse_bn_layers(model)\n",
+    "\n",
+    "# switch model from unquantized to quantized for QAT\n",
+    "ai8x.initiate_qat(model_bayer2rgb, qat_policy_bayer2rgb)\n",
+    "ai8x.initiate_qat(model, qat_policy)\n",
+    "\n",
+    "checkpoint_b2rgb = torch.load(checkpoint_path_b2rgb, map_location=device)\n",
+    "checkpoint = torch.load(checkpoint_path, map_location=device)\n",
+    "\n",
+    "model_bayer2rgb.load_state_dict(checkpoint_b2rgb['state_dict'], strict=False)\n",
+    "model.load_state_dict(checkpoint['state_dict'], strict=False)\n",
+    "\n",
+    "ai8x.update_model(model_bayer2rgb)\n",
+    "model_bayer2rgb = model_bayer2rgb.to(device)\n",
+    "ai8x.update_model(model)\n",
+    "model = model.to(device)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "79f325e1",
+   "metadata": {},
+   "source": [
+    "# Bayer-to-RGB + AI87ImageNetEfficientNetV2 Model\n",
+    "The Bayer2RGB model runs before AI87ImageNetEfficientNetV2 to reconstruct RGB images from the Bayer-filtered inputs; the classifier is then evaluated on the reconstructions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "46d513cb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_bayer2rgb.eval()\n",
+    "model.eval()\n",
+    "correct = 0\n",
+    "total = 0\n",
+    "with torch.no_grad():\n",
+    "    for (image1, _), (_, label2) in zip(test_dataloader, test_dataloader_original):\n",
+    "        image = image1.to(device)\n",
+    "\n",
+    "        primer_out = model_bayer2rgb(image)\n",
+    "\n",
+    "        model_out = model(primer_out)\n",
+    "        result = np.argmax(model_out.cpu())\n",
+    "\n",
+    "        total += 1\n",
+    "        if label2 == result:\n",
+    "            correct += 1\n",
+    "        if total % 1000 == 0:\n",
+    "            print(\"running accuracy:\", correct / total)\n",
+    "\n",
+    "print(\"accuracy:\", correct / len(test_set))"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "abc140e4",
+   "metadata": {},
+   "source": [
+    "# Model\n",
+    "The original RGB dataset is used to evaluate the AI87ImageNetEfficientNetV2 model on its own."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "173f9d5d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.eval()\n",
+    "correct = 0\n",
+    "total = 0\n",
+    "with torch.no_grad():\n",
+    "    for image, label in test_dataloader_original:\n",
+    "        image = image.to(device)\n",
+    "        model_out = model(image)\n",
+    "        result = np.argmax(model_out.cpu())\n",
+    "\n",
+    "        total += 1\n",
+    "        if label == result:\n",
+    "            correct += 1\n",
+    "        if total % 1000 == 0:\n",
+    "            print(\"running accuracy:\", correct / total)\n",
+    "\n",
+    "print(\"accuracy:\", correct / len(test_set))"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "bea82bf8",
+   "metadata": {},
+   "source": [
+    "# Bilinear Interpolation + Model\n",
+    "Bilinear interpolation (OpenCV demosaicing) runs before AI87ImageNetEfficientNetV2 to reconstruct RGB images from the Bayer-filtered inputs; the classifier is then evaluated on the reconstructions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b5fe86ec",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.eval()\n",
+    "correct = 0\n",
+    "total = 0\n",
+    "with torch.no_grad():\n",
+    "    for (image1, _), (_, label2) in zip(test_dataloader_inter, test_dataloader_original):\n",
+    "        image = image1.to(device)\n",
+    "\n",
+    "        img = (128 + (image[0].cpu().detach().numpy().transpose(1, 2, 0))).astype(np.uint8)\n",
+    "        img = cv2.cvtColor(img, cv2.COLOR_BayerGR2RGB)\n",
+    "\n",
+    "        out_tensor = torch.Tensor(((img.transpose(2, 0, 1).astype(np.float32)) / 128 - 1)).to(device)\n",
+    "        out_tensor = out_tensor.unsqueeze(0)\n",
+    "        model_out = model(out_tensor)\n",
+    "        result = np.argmax(model_out.cpu())\n",
+    "\n",
+    "        total += 1\n",
+    "        if label2 == result:\n",
+    "            correct += 1\n",
+    "        if total % 1000 == 0:\n",
+    "            print(\"running accuracy:\", correct / total)\n",
+    "\n",
+    "print(\"accuracy:\", correct / len(test_set))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.11"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "570feb405e2e27c949193ac68f46852414290d515b0ba6e5d90d076ed2284471"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/policies/qat_policy_bayer2rgbnet.yaml b/policies/qat_policy_bayer2rgbnet.yaml
new file mode 100644
index 000000000..d4b48c041
--- /dev/null
+++ b/policies/qat_policy_bayer2rgbnet.yaml
@@ -0,0 +1,3 @@
+---
+start_epoch: 5
+weight_bits: 8
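The policy is consumed by the same parse_qat_yaml helper the notebook uses; as a sketch, it simply tells the training loop to switch the model to 8-bit quantization-aware training at epoch 5 (the exact return format of parse() is an assumption here, and the path is relative to the repository root):

import parse_qat_yaml

policy = parse_qat_yaml.parse('policies/qat_policy_bayer2rgbnet.yaml')
print(policy)  # expected: {'start_epoch': 5, 'weight_bits': 8}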
"$@" diff --git a/train_all_models.sh b/train_all_models.sh index 3bd8df73a..5b23dcac5 100755 --- a/train_all_models.sh +++ b/train_all_models.sh @@ -50,3 +50,6 @@ scripts/train_svhn_tinierssd.sh "$@" echo "-----------------------------" echo "Training Tiny SSD face detection model" scripts/train_facedet_tinierssd.sh "$@" +echo "-----------------------------" +echo "Training Bayer2RGB debayerization model" +scripts/train_bayer2rgb_imagenet.sh "$@"