diff --git a/src/specarray/specarray.py b/src/specarray/specarray.py index 3c28f64..7cdd213 100644 --- a/src/specarray/specarray.py +++ b/src/specarray/specarray.py @@ -49,6 +49,7 @@ def spectral_albedo(self): ) spectral_albedo = xr.where(spectral_albedo < 0.0, 0.0, spectral_albedo) spectral_albedo = xr.where(spectral_albedo > 1.0, 1.0, spectral_albedo) + spectral_albedo.name = "spectral albedo" return spectral_albedo @property diff --git a/src/specarray/specarray.test.py b/src/specarray/specarray.test.py new file mode 100644 index 0000000..118234d --- /dev/null +++ b/src/specarray/specarray.test.py @@ -0,0 +1,44 @@ +import numpy as np +import pandas as pd +import xarray as xr +from pathlib import Path +from pytest_check import check + +from specarray.specarray import SpecArray + + +def test_specarray(): + # Create a SpecArray object from a folder + folder = Path("/path/to/folder") + specarray = SpecArray.from_folder(folder) + + # Test the __len__ method + check(len(specarray) == specarray.capture.shape[0]) + + # Test the __getitem__ method + check(np.array_equal(specarray[0], specarray.capture[0])) + + # Test the shape property + check(specarray.shape == specarray.capture.shape) + + # Test the broadband_albedo property + broadband_albedo = specarray.broadband_albedo + check(isinstance(broadband_albedo, xr.DataArray)) + check(broadband_albedo.shape == (specarray.capture.shape[0], specarray.capture.shape[2])) + check(broadband_albedo.name == "broadband_albedo") + + # Test the spectral_albedo property + spectral_albedo = specarray.spectral_albedo + check(isinstance(spectral_albedo, xr.DataArray)) + check(spectral_albedo.shape == specarray.capture.shape) + check(spectral_albedo.name == "capture") + check(spectral_albedo.coords["wavelength"].equals(specarray.capture.coords["wavelength"])) + check(spectral_albedo.min() >= 0.0) + check(spectral_albedo.max() <= 1.0) + + # Test the _gen_wavelength_point_df method + raw_array = np.random.rand(specarray.capture.shape[0], specarray.capture.shape[2]) + df = specarray._gen_wavelength_point_df(raw_array) + check(isinstance(df, pd.DataFrame)) + check(df.shape == (len(specarray.wavelengths), specarray.capture.shape[2])) + check(df.index.equals(specarray.wavelengths)) diff --git a/tests/test_methods.py b/tests/test_methods.py index c572761..15a4a8a 100644 --- a/tests/test_methods.py +++ b/tests/test_methods.py @@ -25,3 +25,13 @@ def test_getitem(testdata_specim: SpecArray): def test_shape(testdata_specim: SpecArray): check.equal(testdata_specim.shape, (2, 1024, 448)) + + +def test_spectral_albedo(testdata_specim: SpecArray): + spectral_albedo = testdata_specim.spectral_albedo + check.is_instance(spectral_albedo, DataArray) + check.equal(spectral_albedo.shape, testdata_specim.capture.shape) + check.equal(spectral_albedo.name, "spectral albedo") + check.is_true(spectral_albedo.coords["wavelength"].equals(testdata_specim.capture.coords["wavelength"])) + check.greater_equal(spectral_albedo.min(), 0.0) + check.less_equal(spectral_albedo.min(), 1.0)