-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #367 from oddkiva/enh-learn-pytorch
ENH: resnet50, homography as a differentiable block in PyTorch
- Loading branch information
Showing
12 changed files
with
432 additions
and
108 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import logging | ||
|
||
import torch | ||
|
||
logging.basicConfig(level=logging.DEBUG) | ||
|
||
|
||
DEFAULT_DEVICE = ( | ||
"cuda" if torch.cuda.is_available() | ||
else "mps" if torch.backends.mps.is_available() | ||
else "cpu" | ||
) | ||
logging.info(f"Default device selected as: {DEFAULT_DEVICE}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import torch.nn as nn | ||
|
||
|
||
class ConvBNA(nn.Module): | ||
|
||
def __init__(self, in_channels: int, out_channels: int, kernel_size: int, | ||
stride: int, batch_normalize: bool, activation: str, id: int): | ||
super(ConvBNA, self).__init__() | ||
self.layers = nn.Sequential() | ||
|
||
pad_size = (kernel_size - 1) // 2 | ||
|
||
# Add the convolutional layer | ||
conv = nn.Conv2d( | ||
in_channels, out_channels, kernel_size, stride, | ||
padding=pad_size, | ||
bias=True, | ||
padding_mode='zeros' # Let's be explicit about the padding value | ||
) | ||
self.layers.add_module(f'conv_{id}', conv) | ||
if batch_normalize: | ||
self.layers.add_module(f'batch_norm_{id}', nn.BatchNorm2d(out_channels)) | ||
|
||
# Add the activation layer | ||
if activation == 'leaky': | ||
activation_fn = nn.LeakyReLU(0.1, inplace=True) | ||
elif activation == 'relu': | ||
activation_fn = nn.ReLU(inplace=True) | ||
elif activation == 'mish': | ||
activation_fn = nn.Mish() | ||
elif activation == 'linear': | ||
activation_fn = nn.Identity(inplace=True) | ||
elif activation == 'logistic': | ||
activation_fn = nn.Sigmoid() | ||
else: | ||
raise ValueError(f'No convolutional activation named {activation}') | ||
self.layers.add_module(f'{activation}_{id}', activation_fn); | ||
|
||
def forward(self, x): | ||
return self.layers.forward(x) | ||
|
||
|
||
class ResidualBottleneckBlock(nn.Module): | ||
|
||
def __init__(self, | ||
in_channels: int, | ||
out_channels: int, | ||
stride: int = 1, | ||
activation: str = 'relu'): | ||
super().__init__() | ||
self.convs = nn.Sequential( | ||
ConvBNA(in_channels, out_channels, 1, stride, True, activation, 0), | ||
ConvBNA(out_channels, out_channels, 3, 1, True, activation, 1), | ||
ConvBNA(out_channels, out_channels * (2 ** 2), 1, 1, True, activation, 2), | ||
) | ||
|
||
self.shortcut = ConvBNA(in_channels, out_channels * (2 ** 2), 1, stride, | ||
False, 'linear', 0) | ||
|
||
# Add the activation layer | ||
if activation == 'leaky': | ||
self.activation = nn.LeakyReLU(0.1, inplace=True) | ||
elif activation == 'relu': | ||
self.activation = nn.ReLU(inplace=True) | ||
elif activation == 'mish': | ||
self.activation = nn.Mish() | ||
elif activation == 'linear': | ||
self.activation = nn.Identity(inplace=True) | ||
elif activation == 'logistic': | ||
self.activation = nn.Sigmoid() | ||
else: | ||
raise ValueError(f'No convolutional activation named {activation}') | ||
|
||
def forward(self, x): | ||
return self.activation(self.convs.forward(x) + self.shortcut(x)) | ||
|
||
|
||
class ResNet50(nn.Module): | ||
|
||
def __init__(self): | ||
super().__init__() | ||
self.blocks = nn.Sequential( | ||
ConvBNA(3, 64, 7, 2, True, 'relu', 0), | ||
nn.AvgPool2d(3, 2), | ||
# P0 | ||
nn.Sequential( | ||
ResidualBottleneckBlock(64, 64, 1, 'relu'), | ||
ResidualBottleneckBlock(256, 64, 1, 'relu'), | ||
ResidualBottleneckBlock(256, 64, 1, 'relu'), | ||
), | ||
# P1 | ||
nn.Sequential( | ||
ResidualBottleneckBlock(256, 128, 2, 'relu'), | ||
ResidualBottleneckBlock(512, 128, 1, 'relu'), | ||
ResidualBottleneckBlock(512, 128, 1, 'relu'), | ||
ResidualBottleneckBlock(512, 128, 1, 'relu'), | ||
), | ||
# P2 | ||
nn.Sequential( | ||
ResidualBottleneckBlock(512, 256, 2, 'relu'), | ||
ResidualBottleneckBlock(1024, 256, 1, 'relu'), | ||
ResidualBottleneckBlock(1024, 256, 1, 'relu'), | ||
ResidualBottleneckBlock(1024, 256, 1, 'relu'), | ||
ResidualBottleneckBlock(1024, 256, 1, 'relu'), | ||
ResidualBottleneckBlock(1024, 256, 1, 'relu'), | ||
), | ||
# P3 | ||
nn.Sequential( | ||
ResidualBottleneckBlock(1024, 512, 2, 'relu'), | ||
ResidualBottleneckBlock(2048, 512, 1, 'relu'), | ||
ResidualBottleneckBlock(2048, 512, 1, 'relu'), | ||
), | ||
) | ||
|
||
def forward(self, x): | ||
return self.blocks.forward(x) |
38 changes: 38 additions & 0 deletions
38
python/oddkiva/brahma/torch/classification/test/test_resnet50.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import numpy as np | ||
|
||
import torch | ||
|
||
import oddkiva.brahma.torch.classification.resnet50 as R | ||
from oddkiva.brahma.torch import DEFAULT_DEVICE | ||
|
||
|
||
def test_conv_bn_activation_block(): | ||
x_np = np.arange(9).reshape(1, 1, 3, 3).astype(np.float32) | ||
x = torch.tensor(x_np, device=DEFAULT_DEVICE) | ||
|
||
conv_bn_act = R.ConvBNA(1, 64, 3, 1, True, 'relu', 0).to(DEFAULT_DEVICE) | ||
y = conv_bn_act.forward(x) | ||
|
||
conv = conv_bn_act.layers[0] | ||
bn = conv_bn_act.layers[1] | ||
assert y.shape == (1, 64, 3, 3) | ||
|
||
|
||
def test_residual_bottleneck_block(): | ||
x_np = np.arange(9).reshape(1, 1, 3, 3).astype(np.float32) | ||
x = torch.tensor(x_np, device=DEFAULT_DEVICE) | ||
|
||
block = R.ResidualBottleneckBlock(1, 8, 2).to(DEFAULT_DEVICE) | ||
y = block.forward(x) | ||
|
||
assert y.shape == (1, 32, 2, 2) | ||
|
||
|
||
def test_resnet50(): | ||
x_np = np.zeros((1, 3, 256, 256)).astype(np.float32) | ||
x = torch.tensor(x_np, device=DEFAULT_DEVICE) | ||
|
||
resnet50 = R.ResNet50().to(DEFAULT_DEVICE) | ||
y = resnet50.forward(x) | ||
|
||
assert y.shape == (1, 2048, 8, 8) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import torch | ||
from torch.utils.data import Dataset | ||
from torchvision import datasets | ||
from torchvision.transforms import ToTensor | ||
|
||
import matplotlib.pyplot as plt | ||
|
||
|
||
training_data = datasets.FashionMNIST( | ||
root='data', | ||
train=True, | ||
download=True, | ||
transform=ToTensor() | ||
) | ||
|
||
test_data = datasets.FashionMNIST( | ||
root='data', | ||
train=False, | ||
download=True, | ||
transform=ToTensor() | ||
) | ||
|
||
labels = [ | ||
"T-shirt", | ||
"Trouser", | ||
"Pullover", | ||
"Dress", | ||
"Coat", | ||
"Sandal", | ||
"Shirt", | ||
"Sneaker", | ||
"Bag", | ||
"Ankle Boot", | ||
] | ||
|
||
figure = plt.figure(figsize=(8, 8)) | ||
cols, rows = 3, 3 | ||
for i in range(cols * rows): | ||
sample_idx = torch.randint(len(training_data), size=(1,)).item() | ||
img, label = training_data[sample_idx] | ||
figure.add_subplot(rows, cols, i + 1) | ||
plt.title(labels[label]) | ||
plt.axis('off') | ||
plt.imshow(img.squeeze(), cmap='gray') | ||
plt.show() |
51 changes: 51 additions & 0 deletions
51
python/oddkiva/brahma/torch/image_processing/examples/image_warp_example.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from pathlib import Path | ||
|
||
from PIL import Image | ||
|
||
import numpy as np | ||
|
||
import matplotlib.pyplot as plt | ||
|
||
import torch | ||
import torchvision.transforms.v2 as v2 | ||
|
||
import oddkiva.brahma.torch.image_processing.warp as W | ||
from oddkiva.brahma.torch import DEFAULT_DEVICE | ||
|
||
|
||
def rotation(theta): | ||
return np.array([[np.cos(theta), -np.sin(theta), 0], | ||
[np.sin(theta), np.cos(theta), 0], | ||
[ 0, 0, 1]]) | ||
|
||
|
||
THIS_FILE = __file__ | ||
SARA_SOURCE_DIR_PATH = Path(THIS_FILE[:THIS_FILE.find('sara') + len('sara')]) | ||
SARA_DATA_DIR_PATH = SARA_SOURCE_DIR_PATH / 'data' | ||
DOG_IMAGE_PATH = SARA_DATA_DIR_PATH / 'dog.jpg' | ||
assert DOG_IMAGE_PATH.exists() | ||
|
||
# Image format converters. | ||
to_float_chw = v2.Compose([v2.ToImage(), | ||
v2.ToDtype(torch.float32, scale=True)]) | ||
to_uint8_hwc = v2.Compose([v2.ToDtype(torch.uint8, scale=True), | ||
v2.ToPILImage()]) | ||
|
||
|
||
# Image input | ||
image = to_float_chw(Image.open(DOG_IMAGE_PATH)).to(DEFAULT_DEVICE) | ||
image = image[None, :] | ||
|
||
# Geometric transform input. | ||
R = torch.Tensor(rotation(np.pi / 6)) | ||
|
||
# Differential geometry block | ||
H = W.Homography() | ||
H.homography.data = R | ||
H = H.to(DEFAULT_DEVICE) | ||
|
||
image_warped = H.forward(image) | ||
image_warped_hwc = to_uint8_hwc(image_warped[0]) | ||
|
||
plt.imshow(image_warped_hwc) | ||
plt.show() |
50 changes: 50 additions & 0 deletions
50
python/oddkiva/brahma/torch/image_processing/test/test_image_processing.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import torch | ||
|
||
import oddkiva.brahma.torch.image_processing.warp as W | ||
|
||
def test_enumerate_coords(): | ||
coords = W.enumerate_coords(3, 4) | ||
assert torch.equal( | ||
coords, | ||
torch.Tensor([ | ||
[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2], | ||
[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], | ||
])) | ||
|
||
def test_filter_coords(): | ||
coords = torch.Tensor([[0, 1], | ||
[1, 2]]) | ||
|
||
w, h = 3, 4 | ||
x = torch.zeros((h, w)) | ||
|
||
ixs = (coords[1,:] * w + coords[0, :]).int() | ||
x.flatten()[ixs] = 1 | ||
|
||
assert torch.equal( | ||
x, | ||
torch.Tensor([ | ||
[0, 0, 0], | ||
[1, 0, 0], | ||
[0, 1, 0], | ||
[0, 0, 0] | ||
])) | ||
|
||
def test_bilinear_interpolation_2d(): | ||
values = torch.Tensor([[0., 1.], | ||
[2., 3.], | ||
[4., 5.]]) | ||
|
||
x = torch.Tensor([0.5, 0.5, 0.5]) | ||
y = torch.Tensor([0.5, 1.5, 1.5]) | ||
coords = torch.stack((x, y)) | ||
|
||
interp_values, valid_coords = W.bilinear_interpolation_2d(values, coords) | ||
assert torch.equal( | ||
interp_values, | ||
torch.Tensor([1.5, 3.5, 3.5]) | ||
) | ||
assert torch.equal( | ||
valid_coords, | ||
torch.stack((x, y)) | ||
) |
Oops, something went wrong.