Skip to content

Commit

Permalink
improve FastMCInference tests
Browse files Browse the repository at this point in the history
  • Loading branch information
henrysky committed Sep 2, 2024
1 parent f990a98 commit 9936171
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 158 deletions.
4 changes: 2 additions & 2 deletions src/astroNN/models/base_bayesian_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ def predict(self, input_data, inputs_err=None, batch_size=None):
pbar=pbar,
)

new = FastMCInference(self.mc_num, self.keras_model_predict).new_mc_model
new = FastMCInference(self.mc_num, self.keras_model_predict).transformed_model

result = new.predict(prediction_generator, verbose=0)

Expand Down Expand Up @@ -1013,7 +1013,7 @@ def on_epoch_end(self):
pbar=pbar,
nn_model=self,
)
new = FastMCInference(self.mc_num, self.keras_model_predict).new_mc_model
new = FastMCInference(self.mc_num, self.keras_model_predict).transformed_model
result = np.asarray(new.predict(prediction_generator))

if remainder_shape != 0: # deal with remainder
Expand Down
21 changes: 8 additions & 13 deletions src/astroNN/nn/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def __init__(self, n, model, **kwargs):
self.fast_mc_layer = FastMCInferenceV2_internal(self.model, self.n)

mc = self.meanvar_layer(self.fast_mc_layer(new_input))
self.new_mc_model = keras.models.Model(inputs=new_input, outputs=mc)
self.transformed_model = keras.models.Model(inputs=new_input, outputs=mc)

def get_config(self):
"""
Expand Down Expand Up @@ -331,7 +331,10 @@ def compute_output_shape(self, input_shape):
return [tuple([self.n] + list(shape)) for shape in layer_output_shape]
elif isinstance(layer_output_shape, dict):
# if it is a dict of shape, then add self.n in front of each shape
return {key: tuple([self.n] + list(shape)) for key, shape in layer_output_shape.items()}
return {
key: tuple([self.n] + list(shape))
for key, shape in layer_output_shape.items()
}
else:
return (self.n,) + layer_output_shape

Expand Down Expand Up @@ -395,25 +398,17 @@ def call(self, inputs, training=None):
outputs = {}
for key, value in inputs.items():
mean, var = keras.ops.moments(value, axes=0)
outputs[key] = keras.ops.stack(
(keras.ops.squeeze([mean]), keras.ops.squeeze([var])), axis=-1
)
outputs[key] = keras.ops.stack((mean, var), axis=-1)
return outputs
elif isinstance(inputs, list):
outputs = []
for value in inputs:
mean, var = keras.ops.moments(value, axes=0)
outputs.append(
keras.ops.stack(
(keras.ops.squeeze([mean]), keras.ops.squeeze([var])), axis=-1
)
)
outputs.append(keras.ops.stack((mean, var), axis=-1))
return outputs
else: # just a tensor
mean, var = keras.ops.moments(inputs, axes=0)
return keras.ops.stack(
(keras.ops.squeeze([mean]), keras.ops.squeeze([var])), axis=-1
)
return keras.ops.stack((mean, var), axis=-1)


class FastMCRepeat(Layer):
Expand Down
Loading

0 comments on commit 9936171

Please sign in to comment.