Skip to content

Commit

Permalink
remove neuralODE testing
Browse files Browse the repository at this point in the history
  • Loading branch information
henrysky committed Jan 4, 2024
1 parent 8cec78f commit 3c6a14b
Show file tree
Hide file tree
Showing 10 changed files with 913 additions and 931 deletions.
6 changes: 0 additions & 6 deletions astroNN/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,6 @@ def load_folder(folder=None):
except:
metrics = metrics_raw

sample_weight_mode = (
training_config["sample_weight_mode"]
if hasattr(training_config, "sample_weight_mode")
else None
)
loss_weights = training_config["loss_weights"]
weighted_metrics = None

Expand All @@ -336,7 +331,6 @@ def load_folder(folder=None):
metrics=metrics,
weighted_metrics=weighted_metrics,
loss_weights=loss_weights,
sample_weight_mode=sample_weight_mode,
)

# set weights
Expand Down
30 changes: 13 additions & 17 deletions astroNN/models/base_bayesian_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
from tqdm import tqdm
import keras as tfk
import keras
from astroNN.config import MAGIC_NUMBER, MULTIPROCESS_FLAG
from astroNN.config import _astroNN_MODEL_NAME
from astroNN.datasets import H5Loader
Expand Down Expand Up @@ -39,12 +39,11 @@

from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.util import nest

regularizers = tfk.regularizers
ReduceLROnPlateau = tfk.callbacks.ReduceLROnPlateau
Adam = tfk.optimizers.Adam
regularizers = keras.regularizers
ReduceLROnPlateau = keras.callbacks.ReduceLROnPlateau
Adam = keras.optimizers.Adam


class BayesianCNNDataGenerator(GeneratorMaster):
Expand Down Expand Up @@ -322,7 +321,6 @@ def compile(
metrics=None,
weighted_metrics=None,
loss_weights=None,
sample_weight_mode=None,
):
if optimizer is not None:
self.optimizer = optimizer
Expand Down Expand Up @@ -387,7 +385,6 @@ def compile(
loss=zeros_loss,
metrics=self.metrics,
weighted_metrics=weighted_metrics,
sample_weight_mode=sample_weight_mode,
)
elif self.task == "classification":
self.metrics = [categorical_accuracy] if not self.metrics else self.metrics
Expand All @@ -396,7 +393,6 @@ def compile(
loss=zeros_loss,
metrics={"output": self.metrics},
weighted_metrics=weighted_metrics,
sample_weight_mode=sample_weight_mode,
)
elif self.task == "binary_classification":
self.metrics = [binary_accuracy] if not self.metrics else self.metrics
Expand All @@ -405,7 +401,6 @@ def compile(
loss=zeros_loss,
metrics={"output": self.metrics},
weighted_metrics=weighted_metrics,
sample_weight_mode=sample_weight_mode,
)

# inject custom training step if needed
Expand All @@ -426,7 +421,7 @@ def compile(
return None

def recompile(
self, weighted_metrics=None, loss_weights=None, sample_weight_mode=None
self, weighted_metrics=None, loss_weights=None
):
"""
To be used when you need to recompile a already existing model
Expand All @@ -441,7 +436,6 @@ def recompile(
loss=zeros_loss,
metrics=self.metrics,
weighted_metrics=weighted_metrics,
sample_weight_mode=sample_weight_mode,
)
elif self.task == "classification":
self.metrics = [categorical_accuracy] if not self.metrics else self.metrics
Expand All @@ -450,7 +444,6 @@ def recompile(
loss=zeros_loss,
metrics={"output": self.metrics},
weighted_metrics=weighted_metrics,
sample_weight_mode=sample_weight_mode,
)
elif self.task == "binary_classification":
self.metrics = [binary_accuracy] if not self.metrics else self.metrics
Expand All @@ -459,7 +452,6 @@ def recompile(
loss=zeros_loss,
metrics={"output": self.metrics},
weighted_metrics=weighted_metrics,
sample_weight_mode=sample_weight_mode,
)

def custom_train_step(self, data):
Expand All @@ -469,8 +461,10 @@ def custom_train_step(self, data):
:param data:
:return:
"""
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
if len(data) == 3:
x, y, sample_weight = data
else:
x, y = data

# Run forward pass.
with tf.GradientTape() as tape:
Expand Down Expand Up @@ -505,8 +499,10 @@ def custom_train_step(self, data):
return return_metrics

def custom_test_step(self, data):
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
if len(data) == 3:
x, y, sample_weight = data
else:
x, y = data

y_pred = self.keras_model(x, training=False)
# Updates stateful loss metrics.
Expand Down
17 changes: 7 additions & 10 deletions astroNN/models/base_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
from tqdm import tqdm
import keras as tfk
import keras
from astroNN.config import MULTIPROCESS_FLAG
from astroNN.config import _astroNN_MODEL_NAME
from astroNN.models.base_master_nn import NeuralNetMaster
Expand All @@ -19,12 +19,12 @@
from astroNN.shared.warnings import deprecated, deprecated_copy_signature
from sklearn.model_selection import train_test_split

regularizers = tfk.regularizers
regularizers = keras.regularizers
ReduceLROnPlateau, EarlyStopping = (
tfk.callbacks.ReduceLROnPlateau,
tfk.callbacks.EarlyStopping,
keras.callbacks.ReduceLROnPlateau,
keras.callbacks.EarlyStopping,
)
Adam = tfk.optimizers.Adam
Adam = keras.optimizers.Adam


class CNNDataGenerator(GeneratorMaster):
Expand All @@ -37,7 +37,7 @@ class CNNDataGenerator(GeneratorMaster):
:type shuffle: bool
:param data: List of data to NN
:type data: list
:param manual_reset: Whether need to reset the generator manually, usually it is handled by tensorflow
:param manual_reset: Whether need to reset the generator manually, usually it is handled by Keras
:type manual_reset: bool
:param sample_weight: Sample weights (if any)
:type sample_weight: Union([NoneType, ndarray])
Expand Down Expand Up @@ -103,7 +103,7 @@ class CNNPredDataGenerator(GeneratorMaster):
:type shuffle: bool
:param data: List of data to NN
:type data: list
:param manual_reset: Whether need to reset the generator manually, usually it is handled by tensorflow
:param manual_reset: Whether need to reset the generator manually, usually it is handled by Keras
:type manual_reset: bool
:param pbar: tqdm progress bar
:type pbar: obj
Expand Down Expand Up @@ -197,7 +197,6 @@ def compile(
metrics=None,
weighted_metrics=None,
loss_weights=None,
sample_weight_mode=None,
):
if optimizer is not None:
self.optimizer = optimizer
Expand Down Expand Up @@ -238,7 +237,6 @@ def compile(
metrics=self.metrics,
weighted_metrics=weighted_metrics,
loss_weights=loss_weights,
sample_weight_mode=sample_weight_mode,
)

# inject custom training step if needed
Expand All @@ -263,7 +261,6 @@ def recompile(
loss=None,
weighted_metrics=None,
loss_weights=None,
sample_weight_mode=None,
):
"""
To be used when you need to recompile a already existing model
Expand Down
2 changes: 1 addition & 1 deletion astroNN/models/base_master_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ def uses_learning_phase(self):

def get_layer(self, *args, **kwargs):
"""
get_layer() method of tensorflow
get_layer() method of Keras
"""
return self.keras_model.get_layer(*args, **kwargs)

Expand Down
8 changes: 2 additions & 6 deletions astroNN/models/base_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class CVAEDataGenerator(GeneratorMaster):
:type shuffle: bool
:param data: List of data to NN
:type data: list
:param manual_reset: Whether need to reset the generator manually, usually it is handled by tensorflow
:param manual_reset: Whether need to reset the generator manually, usually it is handled by Keras
:type manual_reset: bool
:param sample_weight: Sample weights (if any)
:type sample_weight: Union([NoneType, ndarray])
Expand Down Expand Up @@ -108,7 +108,7 @@ class CVAEPredDataGenerator(GeneratorMaster):
:type data: list
:param key_name: key_name for the input data, default to "input"
:type key_name: str
:param manual_reset: Whether need to reset the generator manually, usually it is handled by tensorflow
:param manual_reset: Whether need to reset the generator manually, usually it is handled by Keras
:type manual_reset: bool
:param pbar: tqdm progress bar
:type pbar: obj
Expand Down Expand Up @@ -214,7 +214,6 @@ def compile(
metrics=None,
weighted_metrics=None,
loss_weights=None,
sample_weight_mode=None,
):
self.keras_encoder, self.keras_decoder = self.model()
self.keras_model = keras.Model(
Expand Down Expand Up @@ -245,7 +244,6 @@ def compile(
metrics=self.metrics,
weighted_metrics=weighted_metrics,
loss_weights=loss_weights,
sample_weight_mode=sample_weight_mode,
)
self.keras_model.total_loss_tracker = keras.metrics.Mean(name="loss")
self.keras_model.reconstruction_loss_tracker = keras.metrics.Mean(
Expand Down Expand Up @@ -275,7 +273,6 @@ def recompile(
loss=None,
weighted_metrics=None,
loss_weights=None,
sample_weight_mode=None,
):
"""
To be used when you need to recompile a already existing model
Expand All @@ -286,7 +283,6 @@ def recompile(
metrics=self.metrics,
weighted_metrics=weighted_metrics,
loss_weights=loss_weights,
sample_weight_mode=sample_weight_mode,
)

def custom_train_step(self, data):
Expand Down
Loading

0 comments on commit 3c6a14b

Please sign in to comment.