diff --git a/aucmedi/neural_network/architectures/image/__init__.py b/aucmedi/neural_network/architectures/image/__init__.py index 6de1b5fc..6e1687d6 100644 --- a/aucmedi/neural_network/architectures/image/__init__.py +++ b/aucmedi/neural_network/architectures/image/__init__.py @@ -49,6 +49,9 @@ from aucmedi.neural_network.architectures.image.resnet50v2 import ResNet50V2 from aucmedi.neural_network.architectures.image.resnet101v2 import ResNet101V2 from aucmedi.neural_network.architectures.image.resnet152v2 import ResNet152V2 +# ResNeXt +from aucmedi.neural_network.architectures.image.resnext50 import ResNeXt50 +from aucmedi.neural_network.architectures.image.resnext101 import ResNeXt101 # MobileNet from aucmedi.neural_network.architectures.image.mobilenet import MobileNet from aucmedi.neural_network.architectures.image.mobilenetv2 import MobileNetV2 @@ -83,6 +86,8 @@ "ResNet50V2": ResNet50V2, "ResNet101V2": ResNet101V2, "ResNet152V2": ResNet152V2, + "ResNeXt50": ResNeXt50, + "ResNeXt101": ResNeXt101, "DenseNet121": DenseNet121, "DenseNet169": DenseNet169, "DenseNet201": DenseNet201, @@ -170,6 +175,8 @@ "ResNet50V2": "tf", "ResNet101V2": "tf", "ResNet152V2": "tf", + "ResNeXt50": "torch", + "ResNeXt101": "torch", "DenseNet121": "torch", "DenseNet169": "torch", "DenseNet201": "torch", diff --git a/tests/test_architectures_image.py b/tests/test_architectures_image.py index 64a1a5df..f2336460 100644 --- a/tests/test_architectures_image.py +++ b/tests/test_architectures_image.py @@ -229,6 +229,52 @@ def test_ResNet152V2(self): self.assertTrue(supported_standardize_mode["ResNet152V2"] == "tf") self.assertTrue(sdm_global["2D.ResNet152V2"] == "tf") + # -------------------------------------------------# + # Architecture: ResNeXt50 # + # -------------------------------------------------# + def test_ResNeXt50(self): + arch = ResNeXt50(Classifier(n_labels=4), channels=1, + input_shape=(32, 32)) + model = NeuralNetwork(n_labels=4, channels=1, architecture=arch, + batch_queue_size=1) + model.predict(self.datagen_GRAY) + arch = ResNeXt50(Classifier(n_labels=4), channels=3, + input_shape=(32, 32)) + model = NeuralNetwork(n_labels=4, channels=3, architecture=arch, + batch_queue_size=1) + model.predict(self.datagen_RGB) + model = NeuralNetwork(n_labels=4, channels=3, architecture="2D.ResNeXt50", + batch_queue_size=1, input_shape=(32, 32)) + try: + model.model.summary() + except: + raise Exception() + self.assertTrue(supported_standardize_mode["ResNeXt50"] == "torch") + self.assertTrue(sdm_global["2D.ResNeXt50"] == "torch") + + # -------------------------------------------------# + # Architecture: ResNeXt101 # + # -------------------------------------------------# + def test_ResNeXt101(self): + arch = ResNeXt101(Classifier(n_labels=4), channels=1, + input_shape=(32, 32)) + model = NeuralNetwork(n_labels=4, channels=1, architecture=arch, + batch_queue_size=1) + model.predict(self.datagen_GRAY) + arch = ResNeXt101(Classifier(n_labels=4), channels=3, + input_shape=(32, 32)) + model = NeuralNetwork(n_labels=4, channels=3, architecture=arch, + batch_queue_size=1) + model.predict(self.datagen_RGB) + model = NeuralNetwork(n_labels=4, channels=3, architecture="2D.ResNeXt101", + batch_queue_size=1, input_shape=(32, 32)) + try: + model.model.summary() + except: + raise Exception() + self.assertTrue(supported_standardize_mode["ResNeXt101"] == "torch") + self.assertTrue(sdm_global["2D.ResNeXt101"] == "torch") + #-------------------------------------------------# # Architecture: DenseNet121 # #-------------------------------------------------#