Skip to content

Commit

Permalink
fix ErrorProp and StopGrad layer testing
Browse files Browse the repository at this point in the history
  • Loading branch information
henrysky committed Jul 20, 2024
1 parent 615373c commit dc20674
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 11 deletions.
14 changes: 6 additions & 8 deletions src/astroNN/nn/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,10 @@ def call(self, inputs, training=None):
if training is None:
training = keras.backend.learning_phase()

noised = keras.random.normal([1], mean=inputs[0], stddev=inputs[1])
output_tensor = keras.ops.where(keras.ops.equal(training, True), inputs[0], noised)
noise = keras.random.normal(inputs[0].shape)
noised_inputs = inputs[0] + noise * inputs[1]

output_tensor = keras.ops.where(keras.ops.equal(training, True), inputs[0], noised_inputs)
output_tensor._uses_learning_phase = True
return output_tensor

Expand All @@ -261,7 +263,7 @@ def get_config(self):
return {**dict(base_config.items()), **config}

def compute_output_shape(self, input_shape):
return input_shape
return input_shape[0]


class FastMCInference:
Expand Down Expand Up @@ -509,11 +511,7 @@ def call(self, inputs, training=None):
"""
batchsize = keras.ops.shape(inputs)[0]
# need to reshape because tf.keras cannot get the Tensor shape correctly from tf.boolean_mask op

boolean_mask = keras.ops.any(keras.ops.not_equal(inputs, self.boolmask), axis=1, keepdims=True)

return keras.ops.reshape(inputs[self.boolmask], [batchsize, self.mask_shape]
)
return keras.ops.reshape(inputs[:, self.boolmask], [batchsize, self.mask_shape])

def get_config(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_StopGrad(self):
model_stopped.compile(optimizer="adam", loss="mse")
# assert error because of no gradient via this layer
self.assertRaises(
ValueError,
RuntimeError,
model_stopped.fit,
random_xdata,
random_ydata,
Expand Down
3 changes: 1 addition & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from importlib import import_module

import numpy as np
import numpy.testing as npt
import keras

import astroNN
from astroNN.config import config_path
from astroNN.models import Cifar10CNN, Galaxy10CNN, MNIST_BCNN
from astroNN.models import load_folder
Expand All @@ -34,7 +34,6 @@ def test_mnist(self):
mnist_test.callbacks = ErrorOnNaN()

mnist_test.train(x_train, y_train)
output_shape = mnist_test.output_shape
pred = mnist_test.test(x_test)
test_num = y_test.shape[0]
assert (np.sum(np.argmax(pred, axis=1) == y_test)) / test_num > 0.9 # assert accurancy
Expand Down

0 comments on commit dc20674

Please sign in to comment.