Merge pull request #367 from oddkiva/enh-learn-pytorch
ENH: resnet50, homography as a differentiable block in PyTorch
oddkiva authored Jan 5, 2024
2 parents a01eb47 + 38c7345 commit 0b22919
Showing 12 changed files with 432 additions and 108 deletions.
6 changes: 3 additions & 3 deletions .gitignore
@@ -5,8 +5,6 @@ doc/source/xml/
html/
latex/

-config.toml
-
# Model weights.
*.weights
*.onnx
@@ -21,13 +19,15 @@ CMakeLists.txt.user
*.pyc
*.pyo
*.coverage
+config.toml

# Rmarkdown
.RData
.Rhistory
-sara_book.log
_book/
_bookdown_files/
+sara_book.log
+sara.svg

# cache folder
.cache
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -10,7 +10,7 @@ set(CMAKE_CXX_EXTENSIONS OFF)

# Set the version major and minor manually.
set(DO_Sara_VERSION_MAJOR 1)
-set(DO_Sara_VERSION_MINOR 7)
+set(DO_Sara_VERSION_MINOR 11)

if(APPLE)
set(CMAKE_MACOSX_RPATH ON)
2 changes: 1 addition & 1 deletion python/oddkiva/__init__.py
@@ -5,4 +5,4 @@
CONFIG_FILE_PATH = pathlib.Path(__file__).parent / 'config.toml'
with open(CONFIG_FILE_PATH, 'rb') as f:
    CONFIG = tomllib.load(f)
-DATA_DIR_PATH = CONFIG['data']['path']
+DATA_DIR_PATH = pathlib.Path(CONFIG['data']['path'])
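A side effect of wrapping the configured path in pathlib.Path is that callers can join file names with the / operator instead of string concatenation. A minimal sketch (the file name is purely illustrative):

```python
from oddkiva import DATA_DIR_PATH

# Path arithmetic now works directly on the configured data directory.
dog_image_path = DATA_DIR_PATH / 'dog.jpg'  # hypothetical file name
print(dog_image_path, dog_image_path.exists())
```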
13 changes: 13 additions & 0 deletions python/oddkiva/brahma/torch/__init__.py
@@ -0,0 +1,13 @@
import logging

import torch

logging.basicConfig(level=logging.DEBUG)


DEFAULT_DEVICE = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
logging.info(f"Default device selected as: {DEFAULT_DEVICE}")
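The test and example scripts added in this commit import DEFAULT_DEVICE to place tensors and modules on the best available backend. A minimal usage sketch:

```python
import torch

from oddkiva.brahma.torch import DEFAULT_DEVICE

# Create the input directly on the selected device and move the model to it.
x = torch.zeros((1, 3, 32, 32), device=DEFAULT_DEVICE)
model = torch.nn.Conv2d(3, 8, 3, padding=1).to(DEFAULT_DEVICE)
y = model(x)
print(y.device)
```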
116 changes: 116 additions & 0 deletions python/oddkiva/brahma/torch/classification/resnet50.py
@@ -0,0 +1,116 @@
import torch.nn as nn


class ConvBNA(nn.Module):

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 stride: int, batch_normalize: bool, activation: str, id: int):
        super(ConvBNA, self).__init__()
        self.layers = nn.Sequential()

        pad_size = (kernel_size - 1) // 2

        # Add the convolutional layer
        conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride,
            padding=pad_size,
            bias=True,
            padding_mode='zeros'  # Let's be explicit about the padding value
        )
        self.layers.add_module(f'conv_{id}', conv)
        if batch_normalize:
            self.layers.add_module(f'batch_norm_{id}',
                                   nn.BatchNorm2d(out_channels))

        # Add the activation layer
        if activation == 'leaky':
            activation_fn = nn.LeakyReLU(0.1, inplace=True)
        elif activation == 'relu':
            activation_fn = nn.ReLU(inplace=True)
        elif activation == 'mish':
            activation_fn = nn.Mish()
        elif activation == 'linear':
            activation_fn = nn.Identity()
        elif activation == 'logistic':
            activation_fn = nn.Sigmoid()
        else:
            raise ValueError(f'No convolutional activation named {activation}')
        self.layers.add_module(f'{activation}_{id}', activation_fn)

    def forward(self, x):
        return self.layers.forward(x)


class ResidualBottleneckBlock(nn.Module):

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int = 1,
                 activation: str = 'relu'):
        super().__init__()
        # Bottleneck path: 1x1 (reduce), 3x3, 1x1 (expand channels by 4).
        self.convs = nn.Sequential(
            ConvBNA(in_channels, out_channels, 1, stride, True, activation, 0),
            ConvBNA(out_channels, out_channels, 3, 1, True, activation, 1),
            ConvBNA(out_channels, out_channels * (2 ** 2), 1, 1, True, activation, 2),
        )

        # Projection shortcut so that both paths have matching shapes.
        self.shortcut = ConvBNA(in_channels, out_channels * (2 ** 2), 1, stride,
                                False, 'linear', 0)

        # Add the activation layer
        if activation == 'leaky':
            self.activation = nn.LeakyReLU(0.1, inplace=True)
        elif activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'mish':
            self.activation = nn.Mish()
        elif activation == 'linear':
            self.activation = nn.Identity()
        elif activation == 'logistic':
            self.activation = nn.Sigmoid()
        else:
            raise ValueError(f'No convolutional activation named {activation}')

    def forward(self, x):
        return self.activation(self.convs.forward(x) + self.shortcut(x))


class ResNet50(nn.Module):

    def __init__(self):
        super().__init__()
        self.blocks = nn.Sequential(
            ConvBNA(3, 64, 7, 2, True, 'relu', 0),
            nn.AvgPool2d(3, 2),
            # P0
            nn.Sequential(
                ResidualBottleneckBlock(64, 64, 1, 'relu'),
                ResidualBottleneckBlock(256, 64, 1, 'relu'),
                ResidualBottleneckBlock(256, 64, 1, 'relu'),
            ),
            # P1
            nn.Sequential(
                ResidualBottleneckBlock(256, 128, 2, 'relu'),
                ResidualBottleneckBlock(512, 128, 1, 'relu'),
                ResidualBottleneckBlock(512, 128, 1, 'relu'),
                ResidualBottleneckBlock(512, 128, 1, 'relu'),
            ),
            # P2
            nn.Sequential(
                ResidualBottleneckBlock(512, 256, 2, 'relu'),
                ResidualBottleneckBlock(1024, 256, 1, 'relu'),
                ResidualBottleneckBlock(1024, 256, 1, 'relu'),
                ResidualBottleneckBlock(1024, 256, 1, 'relu'),
                ResidualBottleneckBlock(1024, 256, 1, 'relu'),
                ResidualBottleneckBlock(1024, 256, 1, 'relu'),
            ),
            # P3
            nn.Sequential(
                ResidualBottleneckBlock(1024, 512, 2, 'relu'),
                ResidualBottleneckBlock(2048, 512, 1, 'relu'),
                ResidualBottleneckBlock(2048, 512, 1, 'relu'),
            ),
        )

    def forward(self, x):
        return self.blocks.forward(x)
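As written, ResNet50 stops at the 2048-channel feature map (the test below checks a 1x2048x8x8 output for a 256x256 input); the commit does not add a classification head. If one were needed, global average pooling followed by a linear layer is the usual choice. A minimal sketch, not part of this commit (the 1000-class output size is only an illustration):

```python
import torch
import torch.nn as nn

from oddkiva.brahma.torch.classification.resnet50 import ResNet50

# Hypothetical head on top of the 2048-channel feature map.
backbone = ResNet50()
head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(2048, 1000))

x = torch.zeros((1, 3, 256, 256))
logits = head(backbone(x))  # shape (1, 1000)
```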
38 changes: 38 additions & 0 deletions python/oddkiva/brahma/torch/classification/test/test_resnet50.py
@@ -0,0 +1,38 @@
import numpy as np

import torch

import oddkiva.brahma.torch.classification.resnet50 as R
from oddkiva.brahma.torch import DEFAULT_DEVICE


def test_conv_bn_activation_block():
    x_np = np.arange(9).reshape(1, 1, 3, 3).astype(np.float32)
    x = torch.tensor(x_np, device=DEFAULT_DEVICE)

    conv_bn_act = R.ConvBNA(1, 64, 3, 1, True, 'relu', 0).to(DEFAULT_DEVICE)
    y = conv_bn_act.forward(x)

    conv = conv_bn_act.layers[0]
    bn = conv_bn_act.layers[1]
    assert y.shape == (1, 64, 3, 3)


def test_residual_bottleneck_block():
    x_np = np.arange(9).reshape(1, 1, 3, 3).astype(np.float32)
    x = torch.tensor(x_np, device=DEFAULT_DEVICE)

    block = R.ResidualBottleneckBlock(1, 8, 2).to(DEFAULT_DEVICE)
    y = block.forward(x)

    assert y.shape == (1, 32, 2, 2)


def test_resnet50():
    x_np = np.zeros((1, 3, 256, 256)).astype(np.float32)
    x = torch.tensor(x_np, device=DEFAULT_DEVICE)

    resnet50 = R.ResNet50().to(DEFAULT_DEVICE)
    y = resnet50.forward(x)

    assert y.shape == (1, 2048, 8, 8)
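The (1, 2048, 8, 8) assertion can be checked by hand with the usual convolution output-size formula, floor((n + 2p - k) / s) + 1, applied to the strided layers only (every other layer preserves the spatial size). A quick sketch of that bookkeeping (the conv_out helper is just for illustration):

```python
# Spatial size bookkeeping for a 256x256 input, matching the assertion above.
def conv_out(n, k, s, p):
    return (n + 2 * p - k) // s + 1

n = conv_out(256, 7, 2, 3)  # stem conv, pad 3        -> 128
n = conv_out(n, 3, 2, 0)    # AvgPool2d(3, 2), no pad -> 63
n = conv_out(n, 1, 2, 0)    # P1 downsampling block   -> 32
n = conv_out(n, 1, 2, 0)    # P2 downsampling block   -> 16
n = conv_out(n, 1, 2, 0)    # P3 downsampling block   -> 8
print(n)  # 8
```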
45 changes: 45 additions & 0 deletions python/oddkiva/brahma/torch/dataset.py
@@ -0,0 +1,45 @@
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor()
)

labels = [
    "T-shirt",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle Boot",
]

figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(cols * rows):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i + 1)
    plt.title(labels[label])
    plt.axis('off')
    plt.imshow(img.squeeze(), cmap='gray')
plt.show()
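This script only downloads FashionMNIST and visualizes random samples; it does not build data loaders. For training, the two datasets would typically be wrapped in a torch DataLoader. A minimal sketch continuing from the script above (the batch size is arbitrary):

```python
from torch.utils.data import DataLoader

# Hypothetical loaders over the datasets defined above.
train_loader = DataLoader(training_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

images, targets = next(iter(train_loader))
print(images.shape, targets.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])
```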
51 changes: 51 additions & 0 deletions (new file, path not shown)
@@ -0,0 +1,51 @@
from pathlib import Path

from PIL import Image

import numpy as np

import matplotlib.pyplot as plt

import torch
import torchvision.transforms.v2 as v2

import oddkiva.brahma.torch.image_processing.warp as W
from oddkiva.brahma.torch import DEFAULT_DEVICE


def rotation(theta):
    return np.array([[np.cos(theta), -np.sin(theta), 0],
                     [np.sin(theta),  np.cos(theta), 0],
                     [            0,              0, 1]])


THIS_FILE = __file__
SARA_SOURCE_DIR_PATH = Path(THIS_FILE[:THIS_FILE.find('sara') + len('sara')])
SARA_DATA_DIR_PATH = SARA_SOURCE_DIR_PATH / 'data'
DOG_IMAGE_PATH = SARA_DATA_DIR_PATH / 'dog.jpg'
assert DOG_IMAGE_PATH.exists()

# Image format converters.
to_float_chw = v2.Compose([v2.ToImage(),
                           v2.ToDtype(torch.float32, scale=True)])
to_uint8_hwc = v2.Compose([v2.ToDtype(torch.uint8, scale=True),
                           v2.ToPILImage()])


# Image input
image = to_float_chw(Image.open(DOG_IMAGE_PATH)).to(DEFAULT_DEVICE)
image = image[None, :]

# Geometric transform input.
R = torch.Tensor(rotation(np.pi / 6))

# Differential geometry block
H = W.Homography()
H.homography.data = R
H = H.to(DEFAULT_DEVICE)

image_warped = H.forward(image)
image_warped_hwc = to_uint8_hwc(image_warped[0])

plt.imshow(image_warped_hwc)
plt.show()
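The point of exposing the homography as an nn.Module ("differentiable block" in the PR title) is that its 3x3 matrix can in principle be fitted by gradient descent. Assuming W.Homography registers that matrix as a learnable parameter, as the H.homography.data assignment above suggests, a minimal fitting loop might look like the following sketch (dummy images, learning rate, and iteration count are illustrative only):

```python
import torch

import oddkiva.brahma.torch.image_processing.warp as W
from oddkiva.brahma.torch import DEFAULT_DEVICE

# Dummy source/target images, just to exercise the optimization loop.
image = torch.rand(1, 3, 64, 64, device=DEFAULT_DEVICE)
image_target = torch.rand(1, 3, 64, 64, device=DEFAULT_DEVICE)

H = W.Homography().to(DEFAULT_DEVICE)
optimizer = torch.optim.SGD(H.parameters(), lr=1e-3)

for _ in range(100):
    optimizer.zero_grad()
    # Warp the source image with the current homography and compare to the target.
    loss = torch.nn.functional.mse_loss(H(image), image_target)
    loss.backward()
    optimizer.step()
```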
50 changes: 50 additions & 0 deletions (new file, path not shown)
@@ -0,0 +1,50 @@
import torch

import oddkiva.brahma.torch.image_processing.warp as W


def test_enumerate_coords():
    coords = W.enumerate_coords(3, 4)
    assert torch.equal(
        coords,
        torch.Tensor([
            [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
            [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],
        ]))


def test_filter_coords():
    coords = torch.Tensor([[0, 1],
                           [1, 2]])

    w, h = 3, 4
    x = torch.zeros((h, w))

    ixs = (coords[1, :] * w + coords[0, :]).int()
    x.flatten()[ixs] = 1

    assert torch.equal(
        x,
        torch.Tensor([
            [0, 0, 0],
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 0]
        ]))


def test_bilinear_interpolation_2d():
    values = torch.Tensor([[0., 1.],
                           [2., 3.],
                           [4., 5.]])

    x = torch.Tensor([0.5, 0.5, 0.5])
    y = torch.Tensor([0.5, 1.5, 1.5])
    coords = torch.stack((x, y))

    interp_values, valid_coords = W.bilinear_interpolation_2d(values, coords)
    assert torch.equal(
        interp_values,
        torch.Tensor([1.5, 3.5, 3.5])
    )
    assert torch.equal(
        valid_coords,
        torch.stack((x, y))
    )
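The expected values in test_bilinear_interpolation_2d follow from the standard bilinear formula with the coords = (x, y), values[y, x] indexing convention the test implies: at (0.5, 0.5) the four neighbours 0, 1, 2, 3 each get weight 0.25, giving 1.5, and (2 + 3 + 4 + 5) / 4 = 3.5 at (0.5, 1.5). A quick hand-rolled check of the same arithmetic:

```python
import torch

values = torch.Tensor([[0., 1.],
                       [2., 3.],
                       [4., 5.]])

def bilinear(values, x, y):
    # Interpolate between the four integer neighbours of (x, y), values indexed [y, x].
    x0, y0 = int(x), int(y)
    ax, ay = x - x0, y - y0
    return ((1 - ax) * (1 - ay) * values[y0, x0] +
            ax * (1 - ay) * values[y0, x0 + 1] +
            (1 - ax) * ay * values[y0 + 1, x0] +
            ax * ay * values[y0 + 1, x0 + 1])

print(bilinear(values, 0.5, 0.5))  # tensor(1.5000)
print(bilinear(values, 0.5, 1.5))  # tensor(3.5000)
```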