Skip to content

Commit

Permalink
fix the issue that keras_model.input_names is gone
Browse files Browse the repository at this point in the history
  • Loading branch information
henrysky committed Jan 6, 2024
1 parent 1e51ca1 commit 657044d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions astroNN/models/base_bayesian_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@ def predict(self, input_data, inputs_err=None, batch_size=None):
inputs_err /= self.input_std["input"]

# TODO: better way to handle named input
if "input_err" in self.keras_model.input_names:
if "input_err" in [i.name for i in self.keras_model.inputs]:
input_data = {"input": input_data, "input_err": inputs_err}
else:
input_data = {"input": input_data}
Expand Down Expand Up @@ -1158,7 +1158,7 @@ def evaluate(
norm_input_err = inputs_err / self.input_std["input"]
norm_labels_err = labels_err / self.labels_std["output"]

if "input_err" in self.keras_model.input_names:
if "input_err" in [i.name for i in self.keras_model.inputs]:
norm_data.update(
{"input_err": norm_input_err, "labels_err": norm_labels_err}
)
Expand Down
10 changes: 5 additions & 5 deletions astroNN/models/base_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def pre_training_checklist_child(self, input_data, labels, sample_weight):
): # only compile if there is no keras_model, e.g. fine-tuning does not required
self.compile()

norm_data = self._tensor_dict_sanitize(norm_data, self.keras_model.input_names)
norm_data = self._tensor_dict_sanitize(norm_data, [i.name for i in self.keras_model.inputs])
norm_labels = self._tensor_dict_sanitize(
norm_labels, self.keras_model.output_names
)
Expand Down Expand Up @@ -580,10 +580,10 @@ def predict(self, input_data):
norm_data_remainder.update({name: input_array[name][data_gen_shape:]})

norm_data_main = self._tensor_dict_sanitize(
norm_data_main, self.keras_model.input_names
norm_data_main, [i.name for i in self.keras_model.inputs]
)
norm_data_remainder = self._tensor_dict_sanitize(
norm_data_remainder, self.keras_model.input_names
norm_data_remainder, [i.name for i in self.keras_model.inputs]
)

# Data Generator for prediction
Expand Down Expand Up @@ -635,7 +635,7 @@ def evaluate(self, input_data, labels):
:History: 2018-May-20 - Written - Henry Leung (University of Toronto)
"""
self.has_model_check()
input_data = list_to_dict(self.keras_model.input_names, input_data)
input_data = list_to_dict([i.name for i in self.keras_model.inputs], input_data)
labels = list_to_dict(self.keras_model.output_names, labels)

# check if exists (existing means the model has already been trained (e.g. fine-tuning), so we do not need calculate mean/std again)
Expand All @@ -661,7 +661,7 @@ def evaluate(self, input_data, labels):
norm_data = self.input_normalizer.normalize(input_data, calc=False)
norm_labels = self.labels_normalizer.normalize(labels, calc=False)

norm_data = self._tensor_dict_sanitize(norm_data, self.keras_model.input_names)
norm_data = self._tensor_dict_sanitize(norm_data, [i.name for i in self.keras_model.inputs])
norm_labels = self._tensor_dict_sanitize(
norm_labels, self.keras_model.output_names
)
Expand Down
12 changes: 6 additions & 6 deletions astroNN/models/base_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def pre_training_checklist_child(
): # only compile if there is no keras_model, e.g. fine-tuning does not required
self.compile()

norm_data = self._tensor_dict_sanitize(norm_data, self.keras_model.input_names)
norm_data = self._tensor_dict_sanitize(norm_data, [i.name for i in self.keras_model.inputs])
norm_labels = self._tensor_dict_sanitize(
norm_labels, self.keras_model.output_names
)
Expand Down Expand Up @@ -564,7 +564,7 @@ def fit_on_batch(self, input_data, input_recon_target, sample_weight=None):
input_recon_target, calc=False
)

norm_data = self._tensor_dict_sanitize(norm_data, self.keras_model.input_names)
norm_data = self._tensor_dict_sanitize(norm_data, [i.name for i in self.keras_model.inputs])
norm_labels = self._tensor_dict_sanitize(
norm_labels, self.keras_model.output_names
)
Expand Down Expand Up @@ -670,10 +670,10 @@ def predict(self, input_data):
norm_data_remainder.update({name: input_array[name][data_gen_shape:]})

norm_data_main = self._tensor_dict_sanitize(
norm_data_main, self.keras_model.input_names
norm_data_main, [i.name for i in self.keras_model.inputs]
)
norm_data_remainder = self._tensor_dict_sanitize(
norm_data_remainder, self.keras_model.input_names
norm_data_remainder, [i.name for i in self.keras_model.inputs]
)

# Data Generator for prediction
Expand Down Expand Up @@ -951,7 +951,7 @@ def evaluate(self, input_data, labels):
self.has_model_check()
input_data = {"input": input_data}
labels = {"output": labels}
input_data = list_to_dict(self.keras_model.input_names, input_data)
input_data = list_to_dict([i.name for i in self.keras_model.inputs], input_data)
labels = list_to_dict(self.keras_model.output_names, labels)

# check if exists (existing means the model has already been trained (e.g. fine-tuning), so we do not need calculate mean/std again)
Expand All @@ -977,7 +977,7 @@ def evaluate(self, input_data, labels):
norm_data = self.input_normalizer.normalize(input_data, calc=False)
norm_labels = self.labels_normalizer.normalize(labels, calc=False)

norm_data = self._tensor_dict_sanitize(norm_data, self.keras_model.input_names)
norm_data = self._tensor_dict_sanitize(norm_data, [i.name for i in self.keras_model.inputs])
norm_labels = self._tensor_dict_sanitize(
norm_labels, self.keras_model.output_names
)
Expand Down

0 comments on commit 657044d

Please sign in to comment.