Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UNet PyTorch #11

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Dataset from Kaggle
continunet/network/data
2 changes: 2 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Kaggle dataset slug ("owner/dataset-name") passed to the Kaggle API downloader.
KAGGLE_DATASET = "harrietstewart/continunet"
# Local directory the training data is unpacked into (git-ignored).
TRAIN_DATASET_PATH = "continunet/network/data"
Empty file added continunet/network/__init__.py
Empty file.
27 changes: 27 additions & 0 deletions continunet/network/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Train the ContinUNet model on the SDC1 dataset, publicly available via the Kaggle API."""

import os
from kaggle.api.kaggle_api_extended import KaggleApi

from config import KAGGLE_DATASET, TRAIN_DATASET_PATH

class ApiCaller:
    """Thin wrapper around the Kaggle API for downloading datasets."""

    def __init__(self):
        self.api = KaggleApi()
        # KaggleApi reads credentials from ~/.kaggle/kaggle.json (or the
        # KAGGLE_USERNAME/KAGGLE_KEY environment variables); authenticate()
        # must be called before any API request, otherwise requests fail.
        self.api.authenticate()

    def download_dataset(self, dataset_name: str, save_path: str) -> "ApiCaller":
        """Download ``dataset_name`` from Kaggle and unzip it into ``save_path``.

        :param dataset_name: Kaggle dataset slug, e.g. "owner/dataset-name".
        :param save_path: local directory to unpack the dataset into
            (created if it does not exist).
        :return: self, to allow fluent chaining.
        """
        os.makedirs(save_path, exist_ok=True)
        self.api.dataset_download_files(dataset_name, path=save_path, unzip=True)
        return self


class UnetTrainer:
    """Acquires the training data for the UNet via the Kaggle API."""

    def __init__(self):
        self.data_api = ApiCaller()

    def get_data(self) -> "UnetTrainer":
        """Download the training dataset if it is not already present locally.

        :return: self, to allow fluent chaining.
        """
        # Check the LOCAL save path for existing data — the original checked
        # the Kaggle slug (KAGGLE_DATASET), which is never a local path, so
        # the dataset was re-downloaded on every call.
        if not os.path.exists(TRAIN_DATASET_PATH):
            # Argument order is (dataset slug, local save path); the original
            # passed them swapped, asking Kaggle for a dataset named after
            # the local directory.
            self.data_api.download_dataset(KAGGLE_DATASET, TRAIN_DATASET_PATH)
        return self
89 changes: 89 additions & 0 deletions continunet/network/unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""UNet model for semantic segmentation."""

import torch
import torch.nn as nn


class DoubleConvolutionalBlock(nn.Module):
    """Two consecutive 3x3 conv -> batch-norm -> ReLU stages.

    The first convolution maps ``in_channels`` to ``out_channels``; the
    second keeps ``out_channels``. Padding of 1 preserves spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConvolutionalBlock, self).__init__()
        # Build both conv/BN/ReLU stages in a loop; only the first stage's
        # input channel count differs.
        stages = []
        channels = in_channels
        for _ in range(2):
            stages.append(nn.Conv2d(channels, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
            channels = out_channels
        self.double_convolutional_block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolutional stages to ``x`` and return the result."""
        return self.double_convolutional_block(x)


class DownsamplingBlock(nn.Module):
    """One encoder step: 2x2 max-pool, double convolution, then dropout.

    Halves the spatial dimensions and maps ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.1):
        super(DownsamplingBlock, self).__init__()
        pool = nn.MaxPool2d(kernel_size=2, stride=2)
        convolutions = DoubleConvolutionalBlock(in_channels, out_channels)
        dropout = nn.Dropout2d(dropout_rate)
        self.downsampling_block = nn.Sequential(pool, convolutions, dropout)

    def forward(self, x):
        """Downsample ``x`` by a factor of two and convolve."""
        return self.downsampling_block(x)


class UpsamplingBlock(nn.Module):
    """One decoder step: transposed convolution, skip concatenation, then
    double convolution and dropout.

    The transposed convolution halves the channel count and doubles the
    spatial size; concatenating the matching encoder output restores
    ``in_channels`` channels, which the double convolutional block then
    reduces to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.1):
        super(UpsamplingBlock, self).__init__()
        # kernel_size=2 with stride=2 doubles H and W exactly, matching the
        # encoder output's spatial size. The original kernel_size=3 yields
        # (H-1)*2 + 3 = 2H+1, which breaks the skip concatenation.
        self.upsample = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=2, stride=2
        )
        # After concatenation the tensor has out_channels + out_channels ==
        # in_channels channels — the input the double block expects. The
        # original ran the double block inside a Sequential BEFORE the
        # concatenation, feeding it only out_channels channels (a runtime
        # channel-mismatch error).
        self.double_convolutional_block = DoubleConvolutionalBlock(in_channels, out_channels)
        self.dropout = nn.Dropout2d(dropout_rate)

    def forward(self, decoder_input, encoder_output):
        """Upsample ``decoder_input``, merge the ``encoder_output`` skip
        connection along the channel dimension, and convolve."""
        x = self.upsample(decoder_input)
        x = torch.cat([x, encoder_output], dim=1)
        x = self.double_convolutional_block(x)
        return self.dropout(x)


class UNet(nn.Module):
    """UNet model for semantic segmentation.

    A five-level encoder (64 -> 1024 channels) feeds a four-level decoder
    via skip connections; a 1x1 convolution and sigmoid produce the
    per-pixel output map.
    """

    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        self.encoder = nn.ModuleList(
            [
                DoubleConvolutionalBlock(in_channels, 64),
                DownsamplingBlock(64, 128),
                DownsamplingBlock(128, 256),
                DownsamplingBlock(256, 512),
                DownsamplingBlock(512, 1024),
            ]
        )
        self.decoder = nn.ModuleList(
            [
                UpsamplingBlock(1024, 512),
                UpsamplingBlock(512, 256),
                UpsamplingBlock(256, 128),
                UpsamplingBlock(128, 64),
            ]
        )
        # 1x1 convolution collapses the final 64 feature channels to the
        # requested number of output channels.
        self.output = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Run ``x`` through the encoder, decoder (with skip connections),
        and output head; return sigmoid-activated predictions."""
        skips = []
        for encoder_block in self.encoder:
            x = encoder_block(x)
            skips.append(x)
        # The deepest output (the bottleneck) is the decoder's input, not a
        # skip connection, so it is excluded; remaining skips are consumed
        # deepest-first to match the decoder order.
        for decoder_block, skip in zip(self.decoder, reversed(skips[:-1])):
            x = decoder_block(x, skip)
        return torch.sigmoid(self.output(x))
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ requires-python = ">=3.8"
dependencies = [
"astropy>=6.0",
"numpy>=1.26",
"torch>=2.2",
]

[project.optional-dependencies]
Expand All @@ -27,6 +28,7 @@ dev = [
"pre-commit",
"pytest",
"ipython",
"kaggle",
]
ci = [
"twine",
Expand Down