Skip to content

Commit

Permalink
ResNext50 and 101 added to __init__.py of neural_network/architecture…
Browse files Browse the repository at this point in the history
…s/image, as well as test_architectures_image.py
  • Loading branch information
Dennis182 committed Apr 30, 2024
1 parent a155af8 commit d6c339d
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
7 changes: 7 additions & 0 deletions aucmedi/neural_network/architectures/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
from aucmedi.neural_network.architectures.image.resnet50v2 import ResNet50V2
from aucmedi.neural_network.architectures.image.resnet101v2 import ResNet101V2
from aucmedi.neural_network.architectures.image.resnet152v2 import ResNet152V2
# ResNeXt
from aucmedi.neural_network.architectures.image.resnext50 import ResNeXt50
from aucmedi.neural_network.architectures.image.resnext101 import ResNeXt101
# MobileNet
from aucmedi.neural_network.architectures.image.mobilenet import MobileNet
from aucmedi.neural_network.architectures.image.mobilenetv2 import MobileNetV2
Expand Down Expand Up @@ -83,6 +86,8 @@
"ResNet50V2": ResNet50V2,
"ResNet101V2": ResNet101V2,
"ResNet152V2": ResNet152V2,
"ResNeXt50": ResNeXt50,
"ResNeXt101": ResNeXt101,
"DenseNet121": DenseNet121,
"DenseNet169": DenseNet169,
"DenseNet201": DenseNet201,
Expand Down Expand Up @@ -170,6 +175,8 @@
"ResNet50V2": "tf",
"ResNet101V2": "tf",
"ResNet152V2": "tf",
"ResNeXt50": "torch",
"ResNeXt101": "torch",
"DenseNet121": "torch",
"DenseNet169": "torch",
"DenseNet201": "torch",
Expand Down
46 changes: 46 additions & 0 deletions tests/test_architectures_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,52 @@ def test_ResNet152V2(self):
self.assertTrue(supported_standardize_mode["ResNet152V2"] == "tf")
self.assertTrue(sdm_global["2D.ResNet152V2"] == "tf")

# -------------------------------------------------#
# Architecture: ResNeXt50 #
# -------------------------------------------------#
def test_ResNeXt50(self):
arch = ResNeXt50(Classifier(n_labels=4), channels=1,
input_shape=(32, 32))
model = NeuralNetwork(n_labels=4, channels=1, architecture=arch,
batch_queue_size=1)
model.predict(self.datagen_GRAY)
arch = ResNeXt50(Classifier(n_labels=4), channels=3,
input_shape=(32, 32))
model = NeuralNetwork(n_labels=4, channels=3, architecture=arch,
batch_queue_size=1)
model.predict(self.datagen_RGB)
model = NeuralNetwork(n_labels=4, channels=3, architecture="2D.ResNeXt50",
batch_queue_size=1, input_shape=(32, 32))
try:
model.model.summary()
except:
raise Exception()
self.assertTrue(supported_standardize_mode["ResNeXt50"] == "torch")
self.assertTrue(sdm_global["2D.ResNeXt50"] == "torch")

# -------------------------------------------------#
# Architecture: ResNeXt101 #
# -------------------------------------------------#
def test_ResNeXt101(self):
arch = ResNeXt101(Classifier(n_labels=4), channels=1,
input_shape=(32, 32))
model = NeuralNetwork(n_labels=4, channels=1, architecture=arch,
batch_queue_size=1)
model.predict(self.datagen_GRAY)
arch = ResNeXt101(Classifier(n_labels=4), channels=3,
input_shape=(32, 32))
model = NeuralNetwork(n_labels=4, channels=3, architecture=arch,
batch_queue_size=1)
model.predict(self.datagen_RGB)
model = NeuralNetwork(n_labels=4, channels=3, architecture="2D.ResNeXt101",
batch_queue_size=1, input_shape=(32, 32))
try:
model.model.summary()
except:
raise Exception()
self.assertTrue(supported_standardize_mode["ResNeXt101"] == "torch")
self.assertTrue(sdm_global["2D.ResNeXt101"] == "torch")

#-------------------------------------------------#
# Architecture: DenseNet121 #
#-------------------------------------------------#
Expand Down

0 comments on commit d6c339d

Please sign in to comment.