fix mc sampling layer
henrysky committed Sep 1, 2024
1 parent 1115c79 commit ca3ade7
Showing 4 changed files with 61 additions and 27 deletions.
src/astroNN/models/apogee_models.py (4 changes: 2 additions & 2 deletions)
@@ -167,12 +167,12 @@ def model(self):
input_tensor.name: input_tensor,
labels_err_tensor.name: labels_err_tensor,
},
-    outputs={output.name: output, variance_output.name: variance_output},
+    outputs={"output": output, "variance_output": variance_output},
)
# new astroNN high performance dropout variational inference on GPU expects single output
model_prediction = Model(
inputs={input_tensor.name: input_tensor},
-    outputs={output.name: output, variance_output.name: variance_output}
+    outputs={"output": output, "variance_output": variance_output}
)

if self.task == "regression":
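Note on the two hunks above: they swap the tensors' .name attributes for fixed string keys in the outputs dict. In Keras 3 the dict keys (not the auto-generated tensor names, which Keras may rewrite) become the model's output names and are used to route losses, metrics, and the dict returned by a forward pass. A minimal sketch under that assumption, with illustrative layer sizes:

    import keras

    inp = keras.layers.Input(shape=(8,), name="input")
    hidden = keras.layers.Dense(16, activation="relu")(inp)
    out = keras.layers.Dense(1, name="output")(hidden)
    var = keras.layers.Dense(1, name="variance_output")(hidden)

    # the dict keys become the model's output names
    model = keras.Model(
        inputs={"input": inp},
        outputs={"output": out, "variance_output": var},
    )
    preds = model({"input": keras.ops.ones((4, 8))})
    # preds is a dict: {"output": ..., "variance_output": ...}
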
src/astroNN/models/base_bayesian_cnn.py (18 changes: 12 additions & 6 deletions)
@@ -376,6 +376,14 @@ def compile(
self.metrics = (
[mean_absolute_error, mean_error] if not self.metrics else self.metrics
)
+if isinstance(self.metrics, list):
+    new_metrics = {}
+    # apply every metric to every output: e.g. with outputs [output1, output2]
+    # and metrics [metric1, metric2], each output is evaluated by both metrics
+    for i in self.keras_model.output_names:
+        new_metrics.update({i: self.metrics})
+    self.metrics = new_metrics

self.keras_model.compile(
optimizer=self.optimizer,
loss=zeros_loss,
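The added block above normalizes a flat metric list into a per-output dict, since with named outputs Keras expects metrics keyed by output name. A short sketch of the resulting mapping, reusing a two-output model like the one sketched earlier (names illustrative):

    metrics = ["mae", "mse"]
    # one full metric list per named output, mirroring the compile() block above
    per_output_metrics = {name: metrics for name in model.output_names}
    # {"output": ["mae", "mse"], "variance_output": ["mae", "mse"]}
    model.compile(optimizer="adam", loss="mse", metrics=per_output_metrics)
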
@@ -462,8 +470,8 @@ def custom_train_step(self, data):
with backend_framework.GradientTape() as tape:
y_pred = self.keras_model(x, training=True)
# TODO: deal with sample weights
-loss = self._output_loss(y_pred[1], x["labels_err"])(
-    y["output"], y_pred[0]
+loss = self._output_loss(y_pred["variance_output"], x["labels_err"])(
+    y["output"], y_pred["output"]
)
self.keras_model._loss_tracker.update_state(loss)
if self.keras_model.optimizer is not None:
@@ -476,7 +484,7 @@
elif _KERAS_BACKEND == "torch":
self.keras_model.zero_grad()
y_pred = self.keras_model(x, training=True)
-loss = self._output_loss(y_pred[1], x["labels_err"])(y["output"], y_pred[0])
+loss = self._output_loss(y_pred["variance_output"], x["labels_err"])(y["output"], y_pred["output"])
loss.sum().backward()
trainable_weights = [v for v in self.keras_model.trainable_weights]
gradients = [v.value.grad for v in trainable_weights]
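Once the model returns a dict, the positional lookups y_pred[0] / y_pred[1] no longer work, hence the keyed access in both backend branches above. The loss factory closes the predicted log-variance and the known label errors over a robust objective; a simplified standalone sketch of that idea (robust_gaussian_nll is a hypothetical name, not astroNN's exact implementation):

    import keras

    def robust_gaussian_nll(y_true, y_pred, log_var, labels_err):
        # total variance = predicted variance plus known label uncertainty
        total_var = keras.ops.exp(log_var) + keras.ops.square(labels_err)
        return keras.ops.mean(
            0.5 * keras.ops.square(y_true - y_pred) / total_var
            + 0.5 * keras.ops.log(total_var)
        )
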
@@ -503,7 +511,7 @@ def custom_test_step(self, data):

y_pred = self.keras_model(x, training=False)
# Updates stateful loss metrics.
-temploss = self._output_loss(y_pred[1], x["labels_err"])
+temploss = self._output_loss(y_pred["variance_output"], x["labels_err"])
# self.keras_model.compiled_loss._losses = temploss
# self.keras_model.compiled_loss._losses = nest.map_structure(
# self.keras_model.compiled_loss._get_loss_object,
@@ -1000,9 +1008,7 @@ def on_epoch_end(self):
pbar=pbar,
nn_model=self,
)
-
new = FastMCInference(self.mc_num, self.keras_model_predict).new_mc_model
-
result = np.asarray(new.predict(prediction_generator))

if remainder_shape != 0: # deal with remainder
src/astroNN/nn/layers.py (64 changes: 46 additions & 18 deletions)
@@ -279,7 +279,6 @@ class FastMCInference:
:History:
| 2018-Apr-13 - Written - Henry Leung (University of Toronto)
| 2021-Apr-14 - Updated - Henry Leung (University of Toronto)
"""

def __init__(self, n, model, **kwargs):
@@ -293,10 +292,10 @@ def __init__(self, n, model, **kwargs):
self.meanvar_layer = FastMCInferenceMeanVar()

new_input = keras.layers.Input(shape=(self.model.input_shape[1:]), name="input")
-self.mc_model = keras.models.Model(
-    inputs=self.model.inputs, outputs=self.model.outputs
-)
-self.fast_mc_layer = FastMCInferenceV2_internal(self.mc_model, self.n)
+# self.mc_model = keras.models.Model(
+#     inputs=self.model.inputs, outputs=self.model.outputs
+# )
+self.fast_mc_layer = FastMCInferenceV2_internal(self.model, self.n)

mc = self.meanvar_layer(self.fast_mc_layer(new_input))
self.new_mc_model = keras.models.Model(inputs=new_input, outputs=mc)
@@ -326,7 +325,15 @@ def build(self, input_shape):
self.built = True

def compute_output_shape(self, input_shape):
-    return self.layer.output_shape
+    layer_output_shape = self.layer.compute_output_shape(input_shape)
+    if isinstance(layer_output_shape, list):
+        # a list of shapes: prepend self.n to each shape
+        return [tuple([self.n] + list(shape)) for shape in layer_output_shape]
+    elif isinstance(layer_output_shape, dict):
+        # a dict of shapes: prepend self.n to each shape
+        return {key: tuple([self.n] + list(shape)) for key, shape in layer_output_shape.items()}
+    else:
+        return (self.n,) + layer_output_shape
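A worked example of the shape bookkeeping above, assuming n=10 MC samples and a wrapped model with two 25-unit outputs (numbers illustrative):

    n = 10
    layer_output_shape = {"output": (None, 25), "variance_output": (None, 25)}
    mc_shapes = {key: tuple([n] + list(shape)) for key, shape in layer_output_shape.items()}
    # {"output": (10, None, 25), "variance_output": (10, None, 25)}
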

def call(self, inputs, training=None, mask=None):
def loop_fn(i):
@@ -340,10 +347,8 @@ def loop_fn(i):
loop_fn, randomness="different", in_dims=0
)(self.arange_n)
else: # fallback to simple for loop
-    outputs = keras.ops.stack(
-        [self.layer(inputs) for _ in self.arange_n], axis=0
-    )
-return outputs
+    outputs = [self.layer(inputs) for _ in self.arange_n]
+return outputs  # outputs can be a tensor or a dict of tensors
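Returning a plain list here defers the stacking to the consumer, which matters because each stochastic pass may now be a dict of tensors rather than a single tensor. A sketch of how such a list can be collapsed back into tensors with an explicit MC axis (model, x, and n are assumed to be defined; each call draws fresh dropout masks):

    import keras

    mc_runs = [model(x) for _ in range(n)]  # n stochastic forward passes
    if isinstance(mc_runs[0], dict):
        # stack per output name: each value becomes (n, batch, ...)
        stacked = {key: keras.ops.stack([run[key] for run in mc_runs], axis=0)
                   for key in mc_runs[0]}
    else:
        stacked = keras.ops.stack(mc_runs, axis=0)  # (n, batch, ...)
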


class FastMCInferenceMeanVar(Layer):
@@ -360,9 +365,14 @@ class FastMCInferenceMeanVar(Layer):
def __init__(self, name=None, **kwargs):
super().__init__(name=name, **kwargs)

-# def compute_output_shape(self, input_shape):
-#     print(input_shape)
-#     return 2, input_shape[0], input_shape[2:]
+def compute_output_shape(self, input_shape):
+    # the first axis is the number of MC samples: drop it, append 2 for mean and var
+    if isinstance(input_shape, list):
+        return [shape[1:] + (2,) for shape in input_shape]
+    elif isinstance(input_shape, dict):
+        return {key: shape[1:] + (2,) for key, shape in input_shape.items()}
+    else:
+        return input_shape[1:] + (2,)

def get_config(self):
"""
@@ -381,11 +391,29 @@ def call(self, inputs, training=None):
:return: Tensor after applying the layer
:rtype: tf.Tensor
"""
-# need to stack because keras can only handle one output
-mean, var = keras.ops.moments(inputs, axes=0)
-return keras.ops.stack(
-    (keras.ops.squeeze([mean]), keras.ops.squeeze([var])), axis=-1
-)
+if isinstance(inputs, dict):
+    outputs = {}
+    for key, value in inputs.items():
+        mean, var = keras.ops.moments(value, axes=0)
+        outputs[key] = keras.ops.stack(
+            (keras.ops.squeeze([mean]), keras.ops.squeeze([var])), axis=-1
+        )
+    return outputs
+elif isinstance(inputs, list):
+    outputs = []
+    for value in inputs:
+        mean, var = keras.ops.moments(value, axes=0)
+        outputs.append(
+            keras.ops.stack(
+                (keras.ops.squeeze([mean]), keras.ops.squeeze([var])), axis=-1
+            )
+        )
+    return outputs
+else:  # just a tensor
+    mean, var = keras.ops.moments(inputs, axes=0)
+    return keras.ops.stack(
+        (keras.ops.squeeze([mean]), keras.ops.squeeze([var])), axis=-1
+    )
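A quick NumPy check of what the plain-tensor branch computes: mean and variance over the MC axis, stacked along a new trailing axis (shapes illustrative):

    import numpy as np

    samples = np.random.randn(10, 4, 25)      # (n_mc, batch, units)
    mean, var = samples.mean(axis=0), samples.var(axis=0)
    stacked = np.stack((mean, var), axis=-1)  # (batch, units, 2)
    # consistent with compute_output_shape above: input_shape[1:] + (2,)
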


class FastMCRepeat(Layer):
tests/test_layers.py (2 changes: 1 addition & 1 deletion)
@@ -222,7 +222,7 @@ def test_FastMCInference():
model = Model(inputs=input, outputs=output)
model.compile(optimizer="sgd", loss="mse")

-model.fit(random_xdata, random_ydata, batch_size=128)
+model.fit(random_xdata, random_ydata, batch_size=64)

acc_model = FastMCInference(10, model).new_mc_model
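The wrapped model then returns mean and variance together in one predict call; a sketch of consuming it, assuming a single tensor output (the trailing axis of size 2 holds mean and variance, per FastMCInferenceMeanVar above):

    preds = acc_model.predict(random_xdata)  # shape (batch, output_dim, 2)
    mc_mean = preds[..., 0]  # average over the 10 stochastic passes
    mc_var = preds[..., 1]   # spread across the 10 stochastic passes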
