Skip to content

Commit

Permalink
Add segmentation tutorial and files
Browse files Browse the repository at this point in the history
  • Loading branch information
15bonte committed Sep 4, 2023
1 parent 8fb09e4 commit 5dba592
Show file tree
Hide file tree
Showing 19 changed files with 1,943 additions and 128 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# CNN framework

Run CNN models for classification, regression, VAE, contrastive learning with any data set.
Run CNN models for classification, regression, segmentation, VAE, contrastive learning with any data set.

## Installation

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

setup(
name="cnn_framework",
version="0.0.2",
version="0.0.3",
author="Thomas Bonte",
author_email="[email protected]",
description="CNN framework",
Expand Down
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
54 changes: 54 additions & 0 deletions src/cnn_framework/dummy_segmentation/data_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import albumentations as A

from ..utils.data_sets.DatasetOutput import DatasetOutput
from ..utils.enum import ProjectMethods
from ..utils.data_sets.AbstractDataSet import AbstractDataSet, DataSource


class DummyDataSet(AbstractDataSet):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Data sources
self.input_data_source = DataSource(
[self.data_manager.get_microscopy_image_path],
[(ProjectMethods.Channel, ([0, 1, 2], 2))],
)

# First channel is always 255
self.output_data_source = DataSource(
[self.data_manager.get_microscopy_image_path],
[(ProjectMethods.Channel, ([0], 2))],
)

def set_transforms(self):
height, width = self.params.input_dimensions.to_tuple(False)
if self.is_train:
self.transforms = A.Compose(
[
A.Normalize(
self.mean_std["mean"], std=self.mean_std["std"], max_pixel_value=1
),
A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=0, p=1),
A.CenterCrop(height=height, width=width, p=1),
A.Rotate(border_mode=0),
A.HorizontalFlip(),
A.VerticalFlip(),
]
)
else:
self.transforms = A.Compose(
[
A.Normalize(
self.mean_std["mean"], std=self.mean_std["std"], max_pixel_value=1
),
A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=0, p=1),
A.CenterCrop(height=height, width=width, p=1),
]
)

def generate_raw_images(self, filename):
return DatasetOutput(
input=self.input_data_source.get_image(filename, axis_to_merge=2),
target_image=self.output_data_source.get_image(filename, axis_to_merge=2),
)
11 changes: 11 additions & 0 deletions src/cnn_framework/dummy_segmentation/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import segmentation_models_pytorch as smp


class UNet(smp.Unet):
def __init__(self, nb_classes, nb_input_channels):
super().__init__(
encoder_name="resnet18", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
encoder_weights="imagenet", # use `imagenet` pre-trained weights for encoder initialization
in_channels=nb_input_channels, # model input channels (1 for gray-scale images, 3 for RGB, etc.)
classes=nb_classes, # model output channels (number of classes in your dataset),
)
17 changes: 17 additions & 0 deletions src/cnn_framework/dummy_segmentation/model_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from ..utils.ModelParams import ModelParams
from ..utils.dimensions import Dimensions


class DummyModelParams(ModelParams):
"""
Segmentation model params.
"""

def __init__(self):
super().__init__("dummy_segmentation")

self.input_dimensions = Dimensions(height=128, width=128)
self.learning_rate = 1e-4

self.out_channels = 1
self.nb_modalities = 3
41 changes: 41 additions & 0 deletions src/cnn_framework/dummy_segmentation/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch

from ..utils.DataManagers import DefaultDataManager
from ..utils.data_loader_generators.DataLoaderGenerator import DataLoaderGenerator
from ..utils.metrics import PCC
from ..utils.model_managers.ModelManager import ModelManager
from ..utils.parsers.CnnParser import CnnParser

from .data_set import DummyDataSet
from .model_params import DummyModelParams
from .model import UNet


def main(params):
# Data loading
loader_generator = DataLoaderGenerator(params, DummyDataSet, DefaultDataManager)
_, _, test_dl = loader_generator.generate_data_loader()

# Model definition
# Load pretrained model
model = UNet(
nb_classes=params.out_channels,
nb_input_channels=params.nb_modalities * params.nb_stacks_per_modality,
)
model.load_state_dict(torch.load(params.model_load_path))

manager = ModelManager(model, params, PCC)

manager.predict(test_dl)

manager.write_useful_information()


if __name__ == "__main__":
parser = CnnParser()
args = parser.arguments_parser.parse_args()

parameters = DummyModelParams()
parameters.update(args)

main(parameters)
46 changes: 46 additions & 0 deletions src/cnn_framework/dummy_segmentation/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from torch import optim
from torch import nn

from .data_set import DummyDataSet
from .model_params import DummyModelParams
from .model import UNet

from ..utils.parsers.CnnParser import CnnParser
from ..utils.data_loader_generators.DataLoaderGenerator import DataLoaderGenerator
from ..utils.model_managers.ModelManager import ModelManager
from ..utils.DataManagers import DefaultDataManager
from ..utils.metrics import PCC


def main(params):
loader_generator = DataLoaderGenerator(params, DummyDataSet, DefaultDataManager)
train_dl, val_dl, test_dl = loader_generator.generate_data_loader()

# Load pretrained model
model = UNet(
nb_classes=params.out_channels,
nb_input_channels=params.nb_modalities * params.nb_stacks_per_modality,
)
manager = ModelManager(model, params, PCC)

optimizer = optim.Adam(
model.parameters(),
lr=float(params.learning_rate),
betas=(params.beta1, params.beta2),
) # define the optimization
loss_function = nn.L1Loss()
manager.fit(train_dl, val_dl, optimizer, loss_function)

manager.predict(test_dl)

manager.write_useful_information()


if __name__ == "__main__":
parser = CnnParser()
args = parser.arguments_parser.parse_args()

parameters = DummyModelParams()
parameters.update(args)

main(parameters)
Binary file not shown.
2 changes: 1 addition & 1 deletion src/cnn_framework/dummy_vae_model/model_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self):

self.input_dimensions = Dimensions(height=128, width=128)

self.num_epochs = 10
self.num_epochs = 30

self.nb_modalities = 3 # RGB or grayscale
self.nb_stacks_per_modality = 1
Expand Down
Binary file modified src/cnn_framework/utils/__pycache__/metrics.cpython-39.pyc
Binary file not shown.
4 changes: 2 additions & 2 deletions src/cnn_framework/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_score(self):
return self.metric.compute().item(), None

def reset(self):
self.metric = MeanAveragePrecision()
self.metric = MeanAveragePrecision().to(self.device)


class PCC(AbstractMetric):
Expand All @@ -57,7 +57,7 @@ def get_score(self):
return self.metric.compute().item(), None

def reset(self):
self.metric = PearsonCorrCoef()
self.metric = PearsonCorrCoef().to(self.device)


class IoU(AbstractMetric):
Expand Down
Binary file not shown.
Loading

0 comments on commit 5dba592

Please sign in to comment.