-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
19 changed files
with
1,943 additions
and
128 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ | |
|
||
setup( | ||
name="cnn_framework", | ||
version="0.0.2", | ||
version="0.0.3", | ||
author="Thomas Bonte", | ||
author_email="[email protected]", | ||
description="CNN framework", | ||
|
Empty file.
Binary file added
BIN
+168 Bytes
src/cnn_framework/dummy_segmentation/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file added
BIN
+1.84 KB
src/cnn_framework/dummy_segmentation/__pycache__/data_set.cpython-39.pyc
Binary file not shown.
Binary file added
BIN
+659 Bytes
src/cnn_framework/dummy_segmentation/__pycache__/model.cpython-39.pyc
Binary file not shown.
Binary file added
BIN
+839 Bytes
src/cnn_framework/dummy_segmentation/__pycache__/model_params.cpython-39.pyc
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import albumentations as A | ||
|
||
from ..utils.data_sets.DatasetOutput import DatasetOutput | ||
from ..utils.enum import ProjectMethods | ||
from ..utils.data_sets.AbstractDataSet import AbstractDataSet, DataSource | ||
|
||
|
||
class DummyDataSet(AbstractDataSet): | ||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
|
||
# Data sources | ||
self.input_data_source = DataSource( | ||
[self.data_manager.get_microscopy_image_path], | ||
[(ProjectMethods.Channel, ([0, 1, 2], 2))], | ||
) | ||
|
||
# First channel is always 255 | ||
self.output_data_source = DataSource( | ||
[self.data_manager.get_microscopy_image_path], | ||
[(ProjectMethods.Channel, ([0], 2))], | ||
) | ||
|
||
def set_transforms(self): | ||
height, width = self.params.input_dimensions.to_tuple(False) | ||
if self.is_train: | ||
self.transforms = A.Compose( | ||
[ | ||
A.Normalize( | ||
self.mean_std["mean"], std=self.mean_std["std"], max_pixel_value=1 | ||
), | ||
A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=0, p=1), | ||
A.CenterCrop(height=height, width=width, p=1), | ||
A.Rotate(border_mode=0), | ||
A.HorizontalFlip(), | ||
A.VerticalFlip(), | ||
] | ||
) | ||
else: | ||
self.transforms = A.Compose( | ||
[ | ||
A.Normalize( | ||
self.mean_std["mean"], std=self.mean_std["std"], max_pixel_value=1 | ||
), | ||
A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=0, p=1), | ||
A.CenterCrop(height=height, width=width, p=1), | ||
] | ||
) | ||
|
||
def generate_raw_images(self, filename): | ||
return DatasetOutput( | ||
input=self.input_data_source.get_image(filename, axis_to_merge=2), | ||
target_image=self.output_data_source.get_image(filename, axis_to_merge=2), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import segmentation_models_pytorch as smp | ||
|
||
|
||
class UNet(smp.Unet): | ||
def __init__(self, nb_classes, nb_input_channels): | ||
super().__init__( | ||
encoder_name="resnet18", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7 | ||
encoder_weights="imagenet", # use `imagenet` pre-trained weights for encoder initialization | ||
in_channels=nb_input_channels, # model input channels (1 for gray-scale images, 3 for RGB, etc.) | ||
classes=nb_classes, # model output channels (number of classes in your dataset), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from ..utils.ModelParams import ModelParams | ||
from ..utils.dimensions import Dimensions | ||
|
||
|
||
class DummyModelParams(ModelParams): | ||
""" | ||
Segmentation model params. | ||
""" | ||
|
||
def __init__(self): | ||
super().__init__("dummy_segmentation") | ||
|
||
self.input_dimensions = Dimensions(height=128, width=128) | ||
self.learning_rate = 1e-4 | ||
|
||
self.out_channels = 1 | ||
self.nb_modalities = 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import torch | ||
|
||
from ..utils.DataManagers import DefaultDataManager | ||
from ..utils.data_loader_generators.DataLoaderGenerator import DataLoaderGenerator | ||
from ..utils.metrics import PCC | ||
from ..utils.model_managers.ModelManager import ModelManager | ||
from ..utils.parsers.CnnParser import CnnParser | ||
|
||
from .data_set import DummyDataSet | ||
from .model_params import DummyModelParams | ||
from .model import UNet | ||
|
||
|
||
def main(params): | ||
# Data loading | ||
loader_generator = DataLoaderGenerator(params, DummyDataSet, DefaultDataManager) | ||
_, _, test_dl = loader_generator.generate_data_loader() | ||
|
||
# Model definition | ||
# Load pretrained model | ||
model = UNet( | ||
nb_classes=params.out_channels, | ||
nb_input_channels=params.nb_modalities * params.nb_stacks_per_modality, | ||
) | ||
model.load_state_dict(torch.load(params.model_load_path)) | ||
|
||
manager = ModelManager(model, params, PCC) | ||
|
||
manager.predict(test_dl) | ||
|
||
manager.write_useful_information() | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = CnnParser() | ||
args = parser.arguments_parser.parse_args() | ||
|
||
parameters = DummyModelParams() | ||
parameters.update(args) | ||
|
||
main(parameters) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from torch import optim | ||
from torch import nn | ||
|
||
from .data_set import DummyDataSet | ||
from .model_params import DummyModelParams | ||
from .model import UNet | ||
|
||
from ..utils.parsers.CnnParser import CnnParser | ||
from ..utils.data_loader_generators.DataLoaderGenerator import DataLoaderGenerator | ||
from ..utils.model_managers.ModelManager import ModelManager | ||
from ..utils.DataManagers import DefaultDataManager | ||
from ..utils.metrics import PCC | ||
|
||
|
||
def main(params): | ||
loader_generator = DataLoaderGenerator(params, DummyDataSet, DefaultDataManager) | ||
train_dl, val_dl, test_dl = loader_generator.generate_data_loader() | ||
|
||
# Load pretrained model | ||
model = UNet( | ||
nb_classes=params.out_channels, | ||
nb_input_channels=params.nb_modalities * params.nb_stacks_per_modality, | ||
) | ||
manager = ModelManager(model, params, PCC) | ||
|
||
optimizer = optim.Adam( | ||
model.parameters(), | ||
lr=float(params.learning_rate), | ||
betas=(params.beta1, params.beta2), | ||
) # define the optimization | ||
loss_function = nn.L1Loss() | ||
manager.fit(train_dl, val_dl, optimizer, loss_function) | ||
|
||
manager.predict(test_dl) | ||
|
||
manager.write_useful_information() | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = CnnParser() | ||
args = parser.arguments_parser.parse_args() | ||
|
||
parameters = DummyModelParams() | ||
parameters.update(args) | ||
|
||
main(parameters) |
Binary file modified
BIN
+0 Bytes
(100%)
src/cnn_framework/dummy_vae_model/__pycache__/model_params.cpython-39.pyc
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file modified
BIN
+36 Bytes
(100%)
src/cnn_framework/utils/__pycache__/metrics.cpython-39.pyc
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file added
BIN
+163 Bytes
src/cnn_framework/utils/parsers/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Oops, something went wrong.