diff --git a/src/astroNN/models/apogee_models.py b/src/astroNN/models/apogee_models.py index 4a17fea3..ef294c9b 100644 --- a/src/astroNN/models/apogee_models.py +++ b/src/astroNN/models/apogee_models.py @@ -1233,9 +1233,7 @@ class DeNormAdd(keras.layers.Layer): def __init__(self, norm, name=None, **kwargs): self.norm = norm self.supports_masking = True - if not name: - prefix = self.__class__.__name__ - name = prefix + "_" + str(keras.utils.naming.auto_name(prefix)) + super().__init__(name=name, **kwargs) def call(self, inputs, training=None): diff --git a/src/astroNN/nn/layers.py b/src/astroNN/nn/layers.py index f5328247..d013752b 100644 --- a/src/astroNN/nn/layers.py +++ b/src/astroNN/nn/layers.py @@ -18,9 +18,6 @@ class KLDivergenceLayer(Layer): def __init__(self, name=None, **kwargs): self.is_placeholder = True - if not name: - prefix = self.__class__.__name__ - name = prefix + "_" + str(keras.utils.naming.auto_name(prefix)) super().__init__(name=name, **kwargs) def call(self, inputs, training=None): @@ -59,9 +56,6 @@ class VAESampling(Layer): def __init__(self, name=None, **kwargs): self.supports_masking = True - if not name: - prefix = self.__class__.__name__ - name = prefix + "_" + str(keras.utils.naming.auto_name(prefix)) super().__init__(name=name, **kwargs) def call(self, inputs): @@ -90,9 +84,6 @@ def __init__(self, rate, disable=False, noise_shape=None, name=None, **kwargs): self.disable_layer = disable self.supports_masking = True self.noise_shape = noise_shape - if not name: - prefix = self.__class__.__name__ - name = prefix + "_" + str(keras.utils.naming.auto_name(prefix)) super().__init__(name=name, **kwargs) def _get_noise_shape(self, inputs): @@ -200,9 +191,6 @@ def __init__(self, rate, disable=False, name=None, **kwargs): self.disable_layer = disable self.supports_masking = True self.rate = rate - if not name: - prefix = self.__class__.__name__ - name = prefix + "_" + str(keras.utils.naming.auto_name(prefix)) super().__init__(name=name, **kwargs) def call(self, inputs, training=None): @@ -244,9 +232,6 @@ class ErrorProp(Layer): def __init__(self, name=None, **kwargs): self.supports_masking = True - if not name: - prefix = self.__class__.__name__ - name = prefix + "_" + str(keras.utils.naming.auto_name(prefix)) super().__init__(name=name, **kwargs) def call(self, inputs, training=None): @@ -368,9 +353,6 @@ class FastMCInferenceMeanVar(Layer): """ def __init__(self, name=None, **kwargs): - if not name: - prefix = self.__class__.__name__ - name = prefix + "_" + str(keras.utils.naming.auto_name(prefix)) super().__init__(name=name, **kwargs) def compute_output_shape(self, input_shape): @@ -413,9 +395,6 @@ class FastMCRepeat(Layer): def __init__(self, n, name=None, **kwargs): self.n = n - if not name: - prefix = self.__class__.__name__ - name = prefix + "_" + str(keras.utils.naming.auto_name(prefix)) super().__init__(name=name, **kwargs) def compute_output_shape(self, input_shape): @@ -460,9 +439,6 @@ class StopGrad(Layer): """ def __init__(self, name=None, always_on=False, **kwargs): - if not name: - prefix = self.__class__.__name__ - name = prefix + "_" + str(keras.utils.naming.auto_name(prefix)) super().__init__(name=name, **kwargs) self.always_on = always_on @@ -516,9 +492,6 @@ def __init__(self, mask, name=None, **kwargs): self.boolmask = mask self.mask_shape = self.boolmask.sum() self.supports_masking = True - if not name: - prefix = self.__class__.__name__ - name = prefix + "_" + str(keras.utils.naming.auto_name(prefix)) super().__init__(name=name, **kwargs) def compute_output_shape(self, input_shape): diff --git a/src/astroNN/shared/nn_tools.py b/src/astroNN/shared/nn_tools.py index d9a6771f..24db1e44 100644 --- a/src/astroNN/shared/nn_tools.py +++ b/src/astroNN/shared/nn_tools.py @@ -3,11 +3,15 @@ # ---------------------------------------------------------# import datetime import os -import keras import inspect import warnings from astroNN.config import _KERAS_BACKEND +try: + import keras.src as keras +except ModuleNotFoundError: + import keras + # TODO: removed gpu_memory_manage() and gpu_availability() as they are not used in astroNN