Skip to content

Commit

Permalink
fix FastMCInference not quite precise because of the use of moments f…
Browse files Browse the repository at this point in the history
…unction
  • Loading branch information
henrysky committed Sep 2, 2024
1 parent 9936171 commit 20ba3ab
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 34 deletions.
10 changes: 5 additions & 5 deletions src/astroNN/nn/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,11 +341,11 @@ def compute_output_shape(self, input_shape):
def call(self, inputs, training=None, mask=None):
def loop_fn(i):
return self.layer(inputs)


# vectorizing operation depends on backend
if keras.backend.backend() == "tensorflow":
outputs = backend_framework.vectorized_map(loop_fn, self.arange_n)
elif keras.backend.backend() == "torch":
# vectorize using torch.vmap
outputs = backend_framework.vmap(
loop_fn, randomness="different", in_dims=0
)(self.arange_n)
Expand Down Expand Up @@ -397,17 +397,17 @@ def call(self, inputs, training=None):
if isinstance(inputs, dict):
outputs = {}
for key, value in inputs.items():
mean, var = keras.ops.moments(value, axes=0)
mean, var = keras.ops.mean(value, axis=0), keras.ops.var(value, axis=0)
outputs[key] = keras.ops.stack((mean, var), axis=-1)
return outputs
elif isinstance(inputs, list):
outputs = []
for value in inputs:
mean, var = keras.ops.moments(value, axes=0)
mean, var = keras.ops.mean(value, axis=0), keras.ops.var(value, axis=0)
outputs.append(keras.ops.stack((mean, var), axis=-1))
return outputs
else: # just a tensor
mean, var = keras.ops.moments(inputs, axes=0)
mean, var = keras.ops.mean(inputs, axis=0), keras.ops.var(inputs, axis=0)
return keras.ops.stack((mean, var), axis=-1)


Expand Down
102 changes: 73 additions & 29 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,20 @@ def test_SpatialDropout1D(random_data):
npt.assert_equal(np.any(np.not_equal(x, y)), True)


def test_SpatialDropout12D(random_data):
random_xdata, random_xdata_err, random_ydata, random_ydata_err = random_data
def test_SpatialDropout12D(mnist_data):
random_xdata, random_ydata, _, _ = mnist_data

input = keras.layers.Input(shape=[28, 28, 1])
conv = keras.layers.Conv2D(
kernel_initializer="he_normal", padding="same", filters=2, kernel_size=16
)(input)
dropout = astroNN.nn.layers.MCSpatialDropout2D(0.2)(conv)
flattened = keras.layers.Flatten()(dropout)
output = keras.layers.Dense(25)(flattened)
output = keras.layers.Dense(10, activation="softmax")(flattened)
model = keras.models.Model(inputs=input, outputs=output)
model.compile(optimizer="sgd", loss="mse")
model.compile(optimizer="sgd", loss="categorical_crossentropy")

model.fit(random_xdata, random_ydata, batch_size=128)
model.fit(random_xdata, random_ydata)

# make sure dropout is on even in testing phase
x = model.predict(random_xdata)
Expand Down Expand Up @@ -127,25 +127,27 @@ def test_StopGrad(random_data):
model = keras.models.Model(inputs=input, outputs=output)
model_stopped = keras.models.Model(inputs=input, outputs=stopped_output)
model.compile(optimizer="adam", loss="mse")

model_stopped.compile(optimizer="adam", loss="mse")
# assert error because of no gradient via this layer
# RuntimeError is raised with PyTorch backend
# ValueError is raised with TensorFlow backend
with pytest.raises((RuntimeError, ValueError)):
model_stopped.fit(random_xdata, random_ydata, batch_size=128, epochs=1)
model_stopped.fit(random_xdata, random_ydata)

x = model.predict(random_xdata)
y = model_stopped.predict(random_xdata)
npt.assert_almost_equal(x, y) # make sure StopGrad does not change any result
# make sure StopGrad does not change any result when predicting
npt.assert_almost_equal(
model.predict(random_xdata), model_stopped.predict(random_xdata), err_msg="StopGrad layer should not change result when predicting"
)

# # =================test weight equals================= #
input2 = keras.layers.Input(shape=[7514])
dense1 = keras.layers.Dense(100, name="normaldense")(input2)
dense2 = keras.layers.Dense(25, name="wanted_dense")(input2)
# =================test weight equals================= #
input = keras.layers.Input(shape=[7514])
dense1 = keras.layers.Dense(100, name="normal_dense")(input)
dense2 = keras.layers.Dense(32, name="wanted_dense")(input)
dense2_stopped = astroNN.nn.layers.StopGrad(name="stopgrad", always_on=True)(dense2)
output2 = keras.layers.Dense(25, name="wanted_dense2")(keras.layers.concatenate([dense1, dense2_stopped]))
model2 = keras.models.Model(inputs=[input2], outputs=[output2])
output = keras.layers.Dense(25, name="wanted_dense2")(
keras.layers.concatenate([dense1, dense2_stopped])
)
model2 = keras.models.Model(inputs=[input], outputs=[output])
model2.compile(
optimizer=keras.optimizers.SGD(learning_rate=0.1),
loss="mse",
Expand Down Expand Up @@ -190,12 +192,24 @@ def test_FastMCInference(random_data):
original_weights = model.get_weights()
acc_model = astroNN.nn.layers.FastMCInference(10, model).transformed_model
# make sure accelerated model has no effect on deterministic model weights
npt.assert_equal(original_weights, acc_model.get_weights(), err_msg="FastMCInference layer should not change weights")
npt.assert_equal(
original_weights,
acc_model.get_weights(),
err_msg="FastMCInference layer should not change weights",
)
x = acc_model.predict(random_xdata)
# make sure the shape is correct, 100 samples, 25 outputs, 2 columns (mean and variance)
npt.assert_equal(x.shape, (100, 25, 2), err_msg="FastMCInference layer should return 2 columns in the last axis (mean and variance)")
npt.assert_equal(
x.shape,
(100, 25, 2),
err_msg="FastMCInference layer should return 2 columns in the last axis (mean and variance)",
)
# make sure accelerated model has no variance (within numerical precision) on deterministic model prediction
npt.assert_almost_equal(np.max(x[:, :, 1]), 0., err_msg="FastMCInference layer should return 0 variance for deterministic model")
npt.assert_almost_equal(
np.max(x[:, :, 1]),
0.0,
err_msg="FastMCInference layer should return 0 variance for deterministic model",
)

# ======== Simple Keras sequential Model, one input one output ======== #
smodel = keras.models.Sequential()
Expand All @@ -208,25 +222,45 @@ def test_FastMCInference(random_data):
acc_smodel = astroNN.nn.layers.FastMCInference(10, smodel).transformed_model
x = acc_smodel.predict(random_xdata)
# make sure the shape is correct, 100 samples, 10 outputs, 2 columns (mean and variance)
npt.assert_equal(x.shape, (100, 10, 2), err_msg="FastMCInference layer should return 2 columns in the last axis (mean and variance)")
npt.assert_equal(
x.shape,
(100, 10, 2),
err_msg="FastMCInference layer should return 2 columns in the last axis (mean and variance)",
)
# make sure accelerated model has no variance (within numerical precision) on deterministic model prediction
npt.assert_almost_equal(np.max(x[:, :, 1]), 0., err_msg="FastMCInference layer should return 0 variance for deterministic model")
npt.assert_almost_equal(
np.max(x[:, :, 1]),
0.0,
err_msg="FastMCInference layer should return 0 variance for deterministic model",
)

# ======== Complex Keras functional Model, one input multiple output ======== #
input = keras.layers.Input(shape=[7514])
dense = keras.layers.Dense(100)(input)
output1 = keras.layers.Dense(4, name="output1")(dense)
output2 = keras.layers.Dense(8, name="output2")(dense)
model = keras.models.Model(inputs=input, outputs={"output1": output1, "output2": output2})
model = keras.models.Model(
inputs=input, outputs={"output1": output1, "output2": output2}
)
model.compile(optimizer="sgd", loss="mse", metrics=["mse"])
acc_model = astroNN.nn.layers.FastMCInference(10, model).transformed_model
x = acc_model.predict(random_xdata)
# make sure the shape is correct
assert isinstance(x, dict), "Output from FastMCInference layer should be a dictionary if model has multiple outputs"
npt.assert_equal(x["output1"].shape, (100, 4, 2), err_msg="FastMCInference layer return errorous shape for model with multiple outputs")
npt.assert_equal(x["output2"].shape, (100, 8, 2), err_msg="FastMCInference layer return errorous shape for model with multiple outputs")
assert isinstance(
x, dict
), "Output from FastMCInference layer should be a dictionary if model has multiple outputs"
npt.assert_equal(
x["output1"].shape,
(100, 4, 2),
err_msg="FastMCInference layer return errorous shape for model with multiple outputs",
)
npt.assert_equal(
x["output2"].shape,
(100, 8, 2),
err_msg="FastMCInference layer return errorous shape for model with multiple outputs",
)

# ======== Simple Keras functiona; Model with randomness ======== #
# ======== Simple Keras functional Model with randomness ======== #
input = keras.layers.Input(shape=[7514])
dense = keras.layers.Dense(100)(input)
dropout = astroNN.nn.layers.MCDropout(0.5)(dense)
Expand All @@ -236,12 +270,22 @@ def test_FastMCInference(random_data):
original_weights = model.get_weights()
acc_model = astroNN.nn.layers.FastMCInference(10, model).transformed_model
# make sure accelerated model has no effect on stochastic model weights
npt.assert_equal(original_weights, acc_model.get_weights(), err_msg="FastMCInference layer should not change weights")
npt.assert_equal(
original_weights,
acc_model.get_weights(),
err_msg="FastMCInference layer should not change weights",
)
x = acc_model.predict(random_xdata)
# make sure the shape is correct, 100 samples, 25 outputs, 2 columns (mean and variance)
npt.assert_equal(x.shape, (100, 25, 2), err_msg="FastMCInference layer should return 2 columns in the last axis (mean and variance)")
npt.assert_equal(
x.shape,
(100, 25, 2),
err_msg="FastMCInference layer should return 2 columns in the last axis (mean and variance)",
)
# make sure accelerated model has variance because of dropout
assert np.median(x[:, :, 1]) > 1.0, "FastMCInference layer should return some degree of variances for stochastic model"
assert (
np.median(x[:, :, 1]) > 1.0
), "FastMCInference layer should return some degree of variances for stochastic model"

# assert error raised for things other than keras model
with pytest.raises(TypeError):
Expand Down

0 comments on commit 20ba3ab

Please sign in to comment.