Skip to content

Commit

Permalink
added more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tgoelles committed Sep 21, 2023
1 parent 328d0dd commit 3cb24cf
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/specarray/specarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def spectral_albedo(self):
)
spectral_albedo = xr.where(spectral_albedo < 0.0, 0.0, spectral_albedo)
spectral_albedo = xr.where(spectral_albedo > 1.0, 1.0, spectral_albedo)
spectral_albedo.name = "spectral albedo"
return spectral_albedo

@property
Expand Down
44 changes: 44 additions & 0 deletions src/specarray/specarray.test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
import pandas as pd
import xarray as xr
from pathlib import Path
from pytest_check import check

from specarray.specarray import SpecArray


def test_specarray():
# Create a SpecArray object from a folder
folder = Path("/path/to/folder")
specarray = SpecArray.from_folder(folder)

# Test the __len__ method
check(len(specarray) == specarray.capture.shape[0])

# Test the __getitem__ method
check(np.array_equal(specarray[0], specarray.capture[0]))

# Test the shape property
check(specarray.shape == specarray.capture.shape)

# Test the broadband_albedo property
broadband_albedo = specarray.broadband_albedo
check(isinstance(broadband_albedo, xr.DataArray))
check(broadband_albedo.shape == (specarray.capture.shape[0], specarray.capture.shape[2]))
check(broadband_albedo.name == "broadband_albedo")

# Test the spectral_albedo property
spectral_albedo = specarray.spectral_albedo
check(isinstance(spectral_albedo, xr.DataArray))
check(spectral_albedo.shape == specarray.capture.shape)
check(spectral_albedo.name == "capture")
check(spectral_albedo.coords["wavelength"].equals(specarray.capture.coords["wavelength"]))
check(spectral_albedo.min() >= 0.0)
check(spectral_albedo.max() <= 1.0)

# Test the _gen_wavelength_point_df method
raw_array = np.random.rand(specarray.capture.shape[0], specarray.capture.shape[2])
df = specarray._gen_wavelength_point_df(raw_array)
check(isinstance(df, pd.DataFrame))
check(df.shape == (len(specarray.wavelengths), specarray.capture.shape[2]))
check(df.index.equals(specarray.wavelengths))
10 changes: 10 additions & 0 deletions tests/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,13 @@ def test_getitem(testdata_specim: SpecArray):

def test_shape(testdata_specim: SpecArray):
check.equal(testdata_specim.shape, (2, 1024, 448))


def test_spectral_albedo(testdata_specim: SpecArray):
spectral_albedo = testdata_specim.spectral_albedo
check.is_instance(spectral_albedo, DataArray)
check.equal(spectral_albedo.shape, testdata_specim.capture.shape)
check.equal(spectral_albedo.name, "spectral albedo")
check.is_true(spectral_albedo.coords["wavelength"].equals(testdata_specim.capture.coords["wavelength"]))
check.greater_equal(spectral_albedo.min(), 0.0)
check.less_equal(spectral_albedo.min(), 1.0)

0 comments on commit 3cb24cf

Please sign in to comment.