Skip to content

Commit

Permalink
fix testing assertion
Browse files Browse the repository at this point in the history
  • Loading branch information
henrysky committed Jul 20, 2024
1 parent f5cb1ab commit 615373c
Show file tree
Hide file tree
Showing 11 changed files with 82 additions and 62 deletions.
3 changes: 1 addition & 2 deletions src/astroNN/gaia/gaia_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,7 @@ def logsol_to_fakemag(logsol, band="K"):

with warnings.catch_warnings(): # suppress numpy Runtime warning caused by MAGIC_NUMBER
warnings.simplefilter("ignore")
fakemag = absmag_to_fakemag(solar_absmag_bands[band] - logsol / 0.4)

fakemag = np.array(absmag_to_fakemag(solar_absmag_bands[band] - logsol / 0.4))
if fakemag.shape != (): # check if its only 1 element
fakemag[magic_idx] = MAGIC_NUMBER
else: # for float
Expand Down
14 changes: 6 additions & 8 deletions src/astroNN/nn/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,12 +454,12 @@ def call(self, inputs, training=None):
:rtype: tf.Tensor
"""
if self.always_on:
return tf.stop_gradient(inputs)
return keras.ops.stop_gradient(inputs)
else:
if training is None:
training = keras.backend.learning_phase()
output_tensor = keras.ops.where(
keras.ops.equal(training, True), tf.stop_gradient(inputs), inputs
keras.ops.equal(training, True), keras.ops.stop_gradient(inputs), inputs
)
output_tensor._uses_learning_phase = True
return output_tensor
Expand Down Expand Up @@ -495,10 +495,9 @@ def __init__(self, mask, name=None, **kwargs):
super().__init__(name=name, **kwargs)

def compute_output_shape(self, input_shape):
input_shape = len(input_shape)
# TODO: convert to keras
input_shape = input_shape.with_rank_at_least(2)
return input_shape[:-1].concatenate(self.mask_shape)
if len(input_shape) < 2:
raise ValueError(f"Shape {input_shape} must have rank at least 2")
return input_shape[:-1] + (self.mask_shape,)

def call(self, inputs, training=None):
"""
Expand All @@ -513,8 +512,7 @@ def call(self, inputs, training=None):

boolean_mask = keras.ops.any(keras.ops.not_equal(inputs, self.boolmask), axis=1, keepdims=True)

return keras.ops.reshape(
tf.boolean_mask(inputs, self.boolmask, axis=1), [batchsize, self.mask_shape]
return keras.ops.reshape(inputs[self.boolmask], [batchsize, self.mask_shape]
)

def get_config(self):
Expand Down
6 changes: 3 additions & 3 deletions src/astroNN/nn/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def median_internal(_x):
median = median_internal(x_flattened)
return median
else:
x_unstacked = keras.backend.core.unstack(
x_unstacked = keras.ops.unstack(
keras.ops.transpose(x), axis=axis
)
median = keras.ops.stack([median_internal(_x) for _x in x_unstacked])
Expand Down Expand Up @@ -529,7 +529,7 @@ def categorical_crossentropy(y_true, y_pred, sample_weight=None, from_logits=Fal
return weighted_loss(losses, sample_weight)
else:
losses = (
keras.ops.categorical_crossentropy(y_true, y_pred, from_logits=True)
keras.ops.categorical_crossentropy(target=y_true, output=y_pred, from_logits=True)
* correction
)
return weighted_loss(losses, sample_weight)
Expand Down Expand Up @@ -562,7 +562,7 @@ def binary_crossentropy(y_true, y_pred, sample_weight=None, from_logits=False):
y_pred = keras.ops.log(y_pred / (1.0 - y_pred))

cross_entropy = keras.ops.binary_crossentropy(
labels=y_true, logits=y_pred, from_logits=True
target=y_true, output=y_pred, from_logits=True
)
corrected_cross_entropy = keras.ops.where(
magic_num_check(y_true),
Expand Down
21 changes: 15 additions & 6 deletions src/astroNN/nn/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
from astroNN.config import MAGIC_NUMBER


def mask_magicnum(x):
"""
Mask generation logic
"""
return (x == MAGIC_NUMBER) | np.isnan(x)


def sigmoid(x):
"""
NumPy implementation of tf.sigmoid, mask ``magicnumber``
Expand All @@ -17,7 +24,7 @@ def sigmoid(x):
:rtype: Union[ndarray, float]
:History: 2018-Apr-11 - Written - Henry Leung (University of Toronto)
"""
x = np.ma.array(x, mask=(x == MAGIC_NUMBER))
x = np.ma.array(x, mask=mask_magicnum(x))
return np.ma.divide(1, np.ma.add(1, np.divide(1, np.ma.exp(x))))


Expand All @@ -31,7 +38,7 @@ def sigmoid_inv(x):
:rtype: Union[numpy.ndarray, float]
:History: 2018-Apr-11 - Written - Henry Leung (University of Toronto)
"""
x = np.ma.array(x, mask=(x == MAGIC_NUMBER))
x = np.ma.array(x, mask=mask_magicnum(x))
return np.ma.log(np.ma.divide(x, np.ma.subtract(1, x)))


Expand Down Expand Up @@ -102,19 +109,20 @@ def mape_core(x, y, axis=None, mode=None):
)
else:
percentage = (x - y) / y
mask = (mask_magicnum(x) | mask_magicnum(y))
if mode == "mean":
return np.ma.mean(
np.ma.array(
np.abs(percentage) * 100.0,
mask=((x == MAGIC_NUMBER) | (y == MAGIC_NUMBER)),
mask=mask,
),
axis=axis,
)
elif mode == "median":
return np.ma.median(
np.ma.array(
np.abs(percentage) * 100.0,
mask=[(x == MAGIC_NUMBER) | (y == MAGIC_NUMBER)],
mask=mask,
),
axis=axis,
)
Expand Down Expand Up @@ -180,14 +188,15 @@ def mae_core(x, y, axis=None, mode=None):
)
else:
diff = x - y
mask = (mask_magicnum(x) | mask_magicnum(y))
if mode == "mean":
return np.ma.mean(
np.ma.array(np.abs(diff), mask=((x == MAGIC_NUMBER) | (y == MAGIC_NUMBER))),
np.ma.array(np.abs(diff), mask=(mask)),
axis=axis,
)
elif mode == "median":
return np.ma.median(
np.ma.array(np.abs(diff), mask=[(x == MAGIC_NUMBER) | (y == MAGIC_NUMBER)]),
np.ma.array(np.abs(diff), mask=mask),
axis=axis,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_apogee_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def test_apogee_identical_transfer(self):
mad_2 = median_absolute_deviation(pred2[:, 0], ydata[neuralnet.val_idx][:, 0], axis=None).numpy()

# accurancy sould be very similar as they are the same model
self.assertAlmostEqual(mad_1, mad_2)
npt.assert_almost_equal(mad_1, mad_2)


def test_apogee_transferlearning(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_apogee_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_apogee_continuum(self):
raw_spectra_err = np.zeros((10, 8575))
# continuum
cont_spectra, cont_spectra_arr = apogee_continuum(raw_spectra, raw_spectra_err)
self.assertAlmostEqual(float(np.mean(cont_spectra)), 1.0)
npt.assert_almost_equal(float(np.mean(cont_spectra)), 1.0)

def test_apogee_digit_extractor(self):
# Test apogeeid digit extractor
Expand Down
13 changes: 5 additions & 8 deletions tests/test_gaia_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ def test_logsol(self):
npt.assert_array_equal(logsol_to_fakemag(fakemag_to_logsol(np.array([100, 100, 100]))), [100., 100., 100.])
npt.assert_equal(fakemag_to_logsol(MAGIC_NUMBER), MAGIC_NUMBER)
npt.assert_equal(logsol_to_fakemag(fakemag_to_logsol(MAGIC_NUMBER)), MAGIC_NUMBER)
npt.assert_equal(fakemag_to_logsol([MAGIC_NUMBER, 1000])[1], MAGIC_NUMBER)
npt.assert_equal(fakemag_to_logsol([MAGIC_NUMBER, 1000])[0], MAGIC_NUMBER)

npt.assert_equal(logsol_to_absmag(absmag_to_logsol(99.)), 99.)
self.assertAlmostEqual(logsol_to_absmag(absmag_to_logsol(-99.)), -99.)
npt.assert_almost_equal(logsol_to_absmag(absmag_to_logsol(-99.)), -99.)
npt.assert_array_equal(logsol_to_absmag(absmag_to_logsol([99., 99.])), [99., 99.])
npt.assert_array_almost_equal(logsol_to_absmag(absmag_to_logsol([-99., -99.])), [-99., -99.])
npt.assert_array_almost_equal(logsol_to_absmag(absmag_to_logsol(np.array([99., 99., 99.]))), [99., 99., 99.])
Expand All @@ -101,18 +101,15 @@ def test_logsol(self):
def test_extinction(self):
from astroNN.gaia import extinction_correction

npt.assert_raises(AssertionError, npt.assert_array_equal, extinction_correction(10., -90.)[1], -9999)
npt.assert_equal(extinction_correction(10., -90.), 10.)
npt.assert_equal(np.any([extinction_correction(-99.99, -90.) == -9999.]))
npt.assert_equal(extinction_correction(-99.99, -90.), MAGIC_NUMBER)

def test_known_regression(self):
# prevent regression of known bug
npt.assert_equal(mag_to_absmag(1., MAGIC_NUMBER), MAGIC_NUMBER)
npt.assert_equal(mag_to_absmag(MAGIC_NUMBER, MAGIC_NUMBER), MAGIC_NUMBER)
npt.assert_equal(
np.all(mag_to_absmag(MAGIC_NUMBER, MAGIC_NUMBER, 1.), (MAGIC_NUMBER, MAGIC_NUMBER)))
npt.assert_equal(
np.all(mag_to_fakemag(MAGIC_NUMBER, MAGIC_NUMBER, 1.), (MAGIC_NUMBER, MAGIC_NUMBER)))
npt.assert_equal(mag_to_absmag(MAGIC_NUMBER, MAGIC_NUMBER, 1.), (MAGIC_NUMBER, MAGIC_NUMBER))
npt.assert_equal(mag_to_fakemag(MAGIC_NUMBER, MAGIC_NUMBER, 1.), (MAGIC_NUMBER, MAGIC_NUMBER))

npt.assert_equal(mag_to_fakemag(1., MAGIC_NUMBER), MAGIC_NUMBER)
npt.assert_equal(mag_to_fakemag(MAGIC_NUMBER, MAGIC_NUMBER), MAGIC_NUMBER)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def test_FastMCInference(self):
y = acc_model.predict(random_xdata)
self.assertEqual(np.any(np.not_equal(x, y[:, :, 0])), True)
# make sure accelerated model has no variance (uncertainty) on deterministic model prediction
self.assertAlmostEqual(np.sum(y[:, :, 1]), 0.0)
npt.assert_almost_equal(np.sum(y[:, :, 1]), 0.0)

# assert error raised for things other than keras model
self.assertRaises(TypeError, FastMCInference(10), "123")
Expand All @@ -248,7 +248,7 @@ def test_FastMCInference(self):
sy = acc_smodel.predict(random_xdata)
self.assertEqual(np.any(np.not_equal(sx, sy[:, :, 0])), True)
# make sure accelerated model has no variance (uncertainty) on deterministic model prediction
self.assertAlmostEqual(np.sum(sy[:, :, 1]), 0.0)
npt.assert_almost_equal(np.sum(sy[:, :, 1]), 0.0)

if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_mnist(self):
# test with astype boolean deliberately
eval_result_again = mnist_reloaded_again.evaluate(x_test, keras.utils.to_categorical(y_test, 10).astype(bool))
# assert saving again wont affect the model
self.assertAlmostEqual(eval_result_again['loss'], eval_result['loss'], places=3)
npt.assert_almost_equal(eval_result_again['loss'], eval_result['loss'], places=3)


class Models_TestCase2(unittest.TestCase):
Expand Down
8 changes: 3 additions & 5 deletions tests/test_numpy_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,14 @@ def test_regularizator(self):
npt.assert_array_almost_equal(tf_x_2.numpy(), astroNN_x_2)

def test_numpy_metrics(self):
x = np.array([-2., 2.])
x = np.array([-2., 2.])
y = np.array([MAGIC_NUMBER, 4.])

# ------------------- Mean ------------------- #
mean_absolute_error([2., 3., 7.], [2., 0., 7.])
mape = mean_absolute_percentage_error(x * u.kpc, y * u.kpc)
mape_ubnitless = mean_absolute_percentage_error(x, y)
mape_unitless = mean_absolute_percentage_error(x, y)
npt.assert_array_equal(mape, 50.)
npt.assert_array_equal(mape, mape_ubnitless)
npt.assert_array_equal(mape, mape_unitless)
# assert error raise if only x or y carries astropy units
self.assertRaises(TypeError, mean_absolute_percentage_error, x * u.kpc, y)
self.assertRaises(TypeError, mean_absolute_percentage_error, x, y * u.kpc)
Expand All @@ -95,7 +94,6 @@ def test_numpy_metrics(self):
self.assertRaises(TypeError, mean_absolute_error, x, y * u.kpc)

# ------------------- Median ------------------- #

self.assertEqual(median_absolute_percentage_error([2., 3., 7.], [2., 1., 7.]), 0.)
self.assertEqual(median_absolute_error([2., 3., 7.], [2., 1., 7.]), 0.)

Expand Down
69 changes: 44 additions & 25 deletions tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,22 @@ class UtilitiesTestCase(unittest.TestCase):
def test_checksum(self):
import astroNN
from astroNN.shared.downloader_tools import filehash
anderson2017_path = os.path.join(os.path.dirname(astroNN.__path__[0]), 'astroNN', 'data',
'anderson_2017_dr14_parallax.npz')
md5_pred = filehash(anderson2017_path, algorithm='md5')
sha1_pred = filehash(anderson2017_path, algorithm='sha1')
sha256_pred = filehash(anderson2017_path, algorithm='sha256')

test_data_path = os.path.join(
os.path.dirname(astroNN.__path__[0]), "astroNN", "data", "dr17_contmask.npy"
)
md5_pred = filehash(test_data_path, algorithm="md5")
sha1_pred = filehash(test_data_path, algorithm="sha1")
sha256_pred = filehash(test_data_path, algorithm="sha256")

# read answer hashed by Windows Get-FileHash
self.assertEqual(md5_pred, '9C714F5FE22BB7C4FF9EA32F3E859D73'.lower())
self.assertEqual(sha1_pred, '733C0227CF93DB0CD6106B5349402F251E7ED735'.lower())
self.assertEqual(sha256_pred, '36C265C907F440114D747DA21D2A014D32B5E442D541F183C0EE862F5865FD26'.lower())
self.assertRaises(ValueError, filehash, anderson2017_path, algorithm='sha123')
self.assertEqual(md5_pred, "a646a9707e7aa2d943417c7e603e3731".lower())
self.assertEqual(sha1_pred, "f701087e845b12b43f87c0d49fd15597bac9f171".lower())
self.assertEqual(
sha256_pred,
"a5705443e33698547ff6f7d6145ff8a4b8b3051a425aef468490a06e233dadb1".lower(),
)
self.assertRaises(ValueError, filehash, test_data_path, algorithm="sha123")

def test_normalizer(self):
from astroNN.nn.utilities.normalizer import Normalizer
Expand All @@ -32,40 +37,54 @@ def test_normalizer(self):
# create a normalizer instance for mode 0
normer = Normalizer(mode=0)
norm_data = normer.normalize(data)
self.assertEqual(norm_data[magic_idx], MAGIC_NUMBER) # make sure normalizer preserve magic_number
# make sure normalizer preserve magic_number
npt.assert_equal(norm_data[magic_idx], MAGIC_NUMBER)
# test demoralize
data_denorm = normer.denormalize(norm_data)
# make sure demoralizer preserve magic_number
self.assertEqual(data_denorm[magic_idx], MAGIC_NUMBER)
npt.assert_equal(data_denorm[magic_idx], MAGIC_NUMBER)
npt.assert_array_almost_equal(data_denorm, data)
npt.assert_array_almost_equal(norm_data, data)

# create a normalizer instance for mode 1
normer = Normalizer(mode=1)
norm_data = normer.normalize(data)
self.assertEqual(norm_data[magic_idx], MAGIC_NUMBER) # make sure normalizer preserve magic_number
# make sure normalizer preserve magic_number
npt.assert_equal(norm_data[magic_idx], MAGIC_NUMBER)
# test demoralize
data_denorm = normer.denormalize(norm_data)
# make sure demoralizer preserve magic_number
self.assertEqual(data_denorm[magic_idx], MAGIC_NUMBER)
npt.assert_equal(data_denorm[magic_idx], MAGIC_NUMBER)
npt.assert_array_almost_equal(data_denorm, data)

# test mode='3s' can do identity transformation
s3_norm = Normalizer(mode='3s')
s3_norm = Normalizer(mode="3s")
data = np.random.normal(0, 1, (100, 10))
npt.assert_array_almost_equal(s3_norm.denormalize(s3_norm.normalize(data)), data, decimal=5)
npt.assert_array_almost_equal(
s3_norm.denormalize(s3_norm.normalize(data)), data, decimal=5
)

data_8bit = np.random.randint(0, 256, (100, 50, 50))
normer = Normalizer(mode=255)
norm_data_8bit = normer.normalize(data_8bit)
self.assertEqual(np.max(norm_data_8bit), 1.) # make sure max of normalized image is 1.
self.assertEqual(np.min(norm_data_8bit), 0.) # make sure max of normalized image is 0.

normer = Normalizer(mode={'input': 255, 'aux': 0})
norm_data_dict = normer.normalize({'input': data_8bit, 'aux': data})
self.assertEqual(np.max(norm_data_dict['input']), 1.) # make sure max of normalized image is 1.
self.assertEqual(np.min(norm_data_dict['input']), 0.) # make sure max of normalized image is 0.
npt.assert_array_almost_equal(norm_data_dict['aux'], data) # make sure aux data is not normalized in this case
self.assertEqual(
np.max(norm_data_8bit), 1.0
) # make sure max of normalized image is 1.
self.assertEqual(
np.min(norm_data_8bit), 0.0
) # make sure max of normalized image is 0.

normer = Normalizer(mode={"input": 255, "aux": 0})
norm_data_dict = normer.normalize({"input": data_8bit, "aux": data})
self.assertEqual(
np.max(norm_data_dict["input"]), 1.0
) # make sure max of normalized image is 1.
self.assertEqual(
np.min(norm_data_dict["input"]), 0.0
) # make sure max of normalized image is 0.
npt.assert_array_almost_equal(
norm_data_dict["aux"], data
) # make sure aux data is not normalized in this case

errorous_norm = Normalizer(mode=-1234)
self.assertRaises(ValueError, errorous_norm.normalize, data)
Expand Down Expand Up @@ -94,9 +113,9 @@ def test_config(self):

def test_pltstyle(self):
from astroNN.shared import pylab_style

pylab_style()


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

0 comments on commit 615373c

Please sign in to comment.