Skip to content

Commit

Permalink
Merge pull request #136 from cinemascience/mlfix
Browse files Browse the repository at this point in the history
fixes for tf2.17
  • Loading branch information
dhrogers authored Oct 9, 2024
2 parents d9181ed + 9646067 commit 57725c3
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
Binary file modified data/MNIST_models/TF/mnist_tf.h5
Binary file not shown.
6 changes: 3 additions & 3 deletions pycinema/filters/MLTFPredictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def _update(self):
modelList = self.inputs.trainedModel.get()
# get required input properties from first model
model = modelList[0]
width = model.layers[0].input_shape[1]
height = model.layers[0].input_shape[2]
channels = model.layers[0].input_shape[3]
width = model.input_shape[1]
height = model.input_shape[2]
channels = model.input_shape[3]
if channels == 1:
gray_req = True
else: #channels == 3 or 4
Expand Down
4 changes: 2 additions & 2 deletions pycinema/filters/MLTFReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _update(self):
models = []
# if the path directly points to a TF model
if modelPath.endswith(".h5") or modelPath.endswith(".keras"):
model = tf.keras.models.load_model(modelPath)
model = tf.keras.models.load_model(modelPath, compile=False)
models.append(model)

else:
Expand All @@ -59,7 +59,7 @@ def _update(self):
for row in table:
parent = os.path.dirname(modelPath) + "/"
filePath = os.path.join(parent, row[1])
model = tf.keras.models.load_model(filePath)
model = tf.keras.models.load_model(filePath, compile=False)
models[int(row[0])] = model

#check if training configuration exists, if not give error
Expand Down

0 comments on commit 57725c3

Please sign in to comment.