Skip to content

Commit

Permalink
fix failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
henrysky committed Sep 7, 2024
1 parent 1352dba commit b23aeb5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
10 changes: 8 additions & 2 deletions src/astroNN/models/base_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,8 @@ def custom_train_step(self, data):
else:
z_mean, z_log_var, z = encoder_output
y_pred = self.keras_decoder(z, training=True)
# TODO: should not need to be squeezed everytime
y, y_pred = keras.ops.squeeze(y), keras.ops.squeeze(y_pred)
reconstruction_loss = self.loss(y, y_pred, sample_weight=sample_weight)
kl_loss = -0.5 * (
1
Expand All @@ -313,6 +315,8 @@ def custom_train_step(self, data):
else:
z_mean, z_log_var, z = encoder_output
y_pred = self.keras_decoder(z, training=True)
# TODO: should not need to be squeezed everytime
y, y_pred = keras.ops.squeeze(y), keras.ops.squeeze(y_pred)
reconstruction_loss = self.loss(y, y_pred, sample_weight=sample_weight)
kl_loss = -0.5 * (
1
Expand All @@ -335,8 +339,8 @@ def custom_train_step(self, data):
# self.keras_model.compiled_metrics.update_state(y, y_pred, sample_weight)

for i in self.keras_model.metrics[1:]:
i.update_state(y, y_pred)

# TODO: properly fix this
i.update_state(keras.ops.zeros_like(y), keras.ops.zeros_like(y_pred))
return self.keras_model.get_metrics_result()

def custom_test_step(self, data):
Expand All @@ -350,6 +354,8 @@ def custom_test_step(self, data):
else:
z_mean, z_log_var, z = encoder_output
y_pred = self.keras_decoder(z, training=False)
# TODO: should not need to be squeezed everytime
y, y_pred = keras.ops.squeeze(y), keras.ops.squeeze(y_pred)
reconstruction_loss = self.loss(y, y_pred, sample_weight=sample_weight)
kl_loss = -0.5 * (
1 + z_log_var - keras.ops.square(z_mean) - keras.ops.exp(z_log_var)
Expand Down
4 changes: 2 additions & 2 deletions src/astroNN/nn/utilities/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, mode=None, verbose=2):
def mode_checker(self, data):
if type(data) is not dict:
dict_flag = False
data = {"Temp": data}
data = {"Temp": data.astype(np.float32)}
self.mean_labels = {"Temp": self.mean_labels}
self.std_labels = {"Temp": self.std_labels}
else:
Expand Down Expand Up @@ -121,7 +121,7 @@ def mode_checker(self, data):
self.std_labels.update({name: np.array([255.0])})
else:
raise ValueError(f"Unknown Mode -> {self.normalization_mode[name]}")
master_data.update({name: data_array})
master_data.update({name: data_array.astype(np.float32)})

return master_data, dict_flag

Expand Down

0 comments on commit b23aeb5

Please sign in to comment.