Commit 2749512
fix mc sampling layer
henrysky committed Aug 30, 2024
1 parent 1115c79 commit 2749512
Showing 3 changed files with 27 additions and 17 deletions.
src/astroNN/models/apogee_models.py (4 changes: 2 additions & 2 deletions)
@@ -167,12 +167,12 @@ def model(self):
                 input_tensor.name: input_tensor,
                 labels_err_tensor.name: labels_err_tensor,
             },
-            outputs={output.name: output, variance_output.name: variance_output},
+            outputs={"output": output, "variance_output": variance_output},
         )
         # new astroNN high performance dropout variational inference on GPU expects single output
         model_prediction = Model(
             inputs={input_tensor.name: input_tensor},
-            outputs={output.name: output, variance_output.name: variance_output}
+            outputs={"output": output, "variance_output": variance_output}
         )

         if self.task == "regression":
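Context for the change above: keying the outputs dict with fixed strings rather than output.name makes the keys stable (tensor names can pick up suffixes during graph construction), so downstream code can index predictions by "output" and "variance_output". A minimal sketch of the pattern, assuming Keras 3 and hypothetical layer sizes:

import keras

# toy two-headed model mirroring the pattern above (shapes are hypothetical)
inp = keras.Input(shape=(8,), name="input")
hidden = keras.layers.Dense(16, activation="relu")(inp)
output = keras.layers.Dense(2, name="output")(hidden)
variance_output = keras.layers.Dense(2, name="variance_output")(hidden)

model = keras.Model(
    inputs={"input": inp},
    outputs={"output": output, "variance_output": variance_output},
)

# predictions are now addressable by stable string keys
pred = model({"input": keras.ops.ones((1, 8))})
mu, log_var = pred["output"], pred["variance_output"]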
src/astroNN/models/base_bayesian_cnn.py (19 changes: 13 additions & 6 deletions)
@@ -376,6 +376,14 @@ def compile(
         self.metrics = (
             [mean_absolute_error, mean_error] if not self.metrics else self.metrics
         )
+        if isinstance(self.metrics, list):
+            new_metrics = {}
+            # given outputs [output1, output2] and metrics [metric1, metric2],
+            # apply every metric to every output so each output is evaluated by both
+            for i in self.keras_model.output_names:
+                new_metrics.update({i: self.metrics})
+            self.metrics = new_metrics
+
         self.keras_model.compile(
             optimizer=self.optimizer,
             loss=zeros_loss,
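The inserted block above normalizes a flat metrics list into Keras's per-output dict form, so every named output is evaluated by every metric. A sketch of the equivalent outside the class, reusing the toy model from the earlier sketch (metric names are placeholders):

# a flat list, as a user might pass to compile()
metrics = ["mae", "mse"]

# map every named output to the full metric list, as the block above does
metrics_per_output = {name: metrics for name in model.output_names}
# e.g. {"output": ["mae", "mse"], "variance_output": ["mae", "mse"]}

model.compile(optimizer="adam", loss="mse", metrics=metrics_per_output)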
@@ -462,8 +470,8 @@ def custom_train_step(self, data):
             with backend_framework.GradientTape() as tape:
                 y_pred = self.keras_model(x, training=True)
                 # TODO: deal with sample weights
-                loss = self._output_loss(y_pred[1], x["labels_err"])(
-                    y["output"], y_pred[0]
+                loss = self._output_loss(y_pred["variance_output"], x["labels_err"])(
+                    y["output"], y_pred["output"]
                 )
                 self.keras_model._loss_tracker.update_state(loss)
                 if self.keras_model.optimizer is not None:
@@ -476,7 +484,7 @@
         elif _KERAS_BACKEND == "torch":
             self.keras_model.zero_grad()
             y_pred = self.keras_model(x, training=True)
-            loss = self._output_loss(y_pred[1], x["labels_err"])(y["output"], y_pred[0])
+            loss = self._output_loss(y_pred["variance_output"], x["labels_err"])(y["output"], y_pred["output"])
             loss.sum().backward()
             trainable_weights = [v for v in self.keras_model.trainable_weights]
             gradients = [v.value.grad for v in trainable_weights]
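In both backend branches the loss is now wired by name: _output_loss is bound to the predicted variance head and the known label errors, then applied to the mean head. Schematically, a heteroscedastic robust MSE of the kind astroNN uses has this shape (a hedged sketch, not the library's exact implementation):

import keras

def robust_mse(log_variance, labels_err):
    # the model predicts log(sigma^2); known label uncertainties are
    # folded into the total variance (an assumption of this sketch)
    def loss(y_true, y_pred):
        total_var = keras.ops.exp(log_variance) + keras.ops.square(labels_err)
        return keras.ops.mean(
            0.5 * keras.ops.square(y_true - y_pred) / total_var
            + 0.5 * keras.ops.log(total_var),
            axis=-1,
        )
    return loss

# mirrors the call pattern in the diff:
# loss = robust_mse(y_pred["variance_output"], x["labels_err"])(y["output"], y_pred["output"])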
@@ -503,7 +511,7 @@ def custom_test_step(self, data):

         y_pred = self.keras_model(x, training=False)
         # Updates stateful loss metrics.
-        temploss = self._output_loss(y_pred[1], x["labels_err"])
+        temploss = self._output_loss(y_pred["variance_output"], x["labels_err"])
         # self.keras_model.compiled_loss._losses = temploss
         # self.keras_model.compiled_loss._losses = nest.map_structure(
         #     self.keras_model.compiled_loss._get_loss_object,
@@ -1000,9 +1008,8 @@ def on_epoch_end(self):
                 pbar=pbar,
                 nn_model=self,
             )
-
-        raise ValueError("This is not implemented yet")
+        new = FastMCInference(self.mc_num, self.keras_model_predict).new_mc_model

         result = np.asarray(new.predict(prediction_generator))

         if remainder_shape != 0:  # deal with remainder
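This restores the fast Monte Carlo path that the placeholder ValueError had disabled: the prediction model is wrapped so that predict() runs mc_num stochastic forward passes per batch. Usage as implied by the diff (mc_num, model, and the generator are placeholders):

import numpy as np
from astroNN.nn.layers import FastMCInference

mc_num = 100  # number of Monte Carlo dropout passes (placeholder value)
mc_model = FastMCInference(mc_num, model).new_mc_model

# each batch is run mc_num times with fresh dropout masks, then reduced
# to mean and variance by the layers changed in layers.py below
result = np.asarray(mc_model.predict(prediction_generator))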
src/astroNN/nn/layers.py (21 changes: 12 additions & 9 deletions)
@@ -336,9 +336,9 @@ def loop_fn(i):
             outputs = backend_framework.vectorized_map(loop_fn, self.arange_n)
         elif keras.backend.backend() == "torch":
             # vectorize using torch.vmap
-            outputs = backend_framework.vmap(
+            outputs = keras.ops.stack(backend_framework.vmap(
                 loop_fn, randomness="different", in_dims=0
-            )(self.arange_n)
+            )(self.arange_n), axis=0)
         else:  # fallback to simple for loop
             outputs = keras.ops.stack(
                 [self.layer(inputs) for _ in self.arange_n], axis=0
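The torch branch matters because torch.func.vmap freezes random state by default (randomness="error" raises on random ops), which would give every MC pass an identical dropout mask; randomness="different" draws a fresh mask per vectorized pass, and the added keras.ops.stack normalizes the output layout to match the TF and fallback branches. A standalone sketch of that behaviour, assuming PyTorch 2.x:

import torch

dropout = torch.nn.Dropout(p=0.5)
x = torch.ones(4)

def one_pass(_):
    # each vectorized pass should draw its own dropout mask
    return dropout(x)

# randomness="different": 3 passes, 3 different masks
samples = torch.func.vmap(one_pass, randomness="different")(torch.arange(3))
print(samples.shape)  # torch.Size([3, 4])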
@@ -361,8 +361,7 @@ def __init__(self, name=None, **kwargs):
         super().__init__(name=name, **kwargs)

     # def compute_output_shape(self, input_shape):
-    #     print(input_shape)
-    #     return 2, input_shape[0], input_shape[2:]
+    #     return (2, input_shape[0], input_shape[2:])

     def get_config(self):
         """
@@ -381,11 +380,15 @@ def call(self, inputs, training=None):
         :return: Tensor after applying the layer
         :rtype: tf.Tensor
         """
-        # need to stack because keras can only handle one output
-        mean, var = keras.ops.moments(inputs, axes=0)
-        return keras.ops.stack(
-            (keras.ops.squeeze([mean]), keras.ops.squeeze([var])), axis=-1
-        )
+        outputs = keras.ops.zeros_like(inputs[0, 0])
+        # need to do each input separately
+        for idx, i in enumerate(inputs):
+            # need to stack because keras can only handle one output
+            mean, var = keras.ops.moments(i, axes=0)
+            outputs[idx] = keras.ops.stack(
+                (keras.ops.squeeze([mean]), keras.ops.squeeze([var])), axis=-1
+            )
+        return outputs


class FastMCRepeat(Layer):
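For reference, each loop iteration in the rewritten call above reduces the Monte Carlo axis to first and second moments. A standalone sketch of that reduction for a single output (shapes are hypothetical):

import keras
import numpy as np

# hypothetical stack of 100 MC passes over a batch of 32 with 25 labels
mc_preds = keras.ops.convert_to_tensor(np.random.rand(100, 32, 25))

# reduce the leading MC axis to (mean, variance) per prediction
mean, var = keras.ops.moments(mc_preds, axes=[0])
stacked = keras.ops.stack((mean, var), axis=-1)  # shape: (32, 25, 2)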
