Skip to content

Commit

Permalink
Fix sunz correction converting 32-bit floats to 64-bit floats
Browse files Browse the repository at this point in the history
Also fixes that input data types were inconsistent between dask arrays and computed numpy results.
  • Loading branch information
djhoese committed Oct 2, 2023
1 parent 37170e1 commit 8b8bbc2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
9 changes: 7 additions & 2 deletions satpy/modifiers/angles.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ def get_cos_sza(data_arr: xr.DataArray) -> xr.DataArray:
@cache_to_zarr_if("cache_lonlats", sanitize_args_func=_sanitize_args_with_chunks)
def _get_valid_lonlats(area: PRGeometry, chunks: Union[int, str, tuple] = "auto") -> tuple[da.Array, da.Array]:
with ignore_invalid_float_warnings():
# NOTE: This defaults to 64-bit floats due to needed precision for X/Y coordinates
lons, lats = area.get_lonlats(chunks=chunks)
lons = da.where(lons >= 1e30, np.nan, lons)
lats = da.where(lats >= 1e30, np.nan, lats)
Expand Down Expand Up @@ -526,7 +527,7 @@ def _sunzen_corr_cos_ndarray(data: np.ndarray,
max_sza_rad = np.deg2rad(max_sza) if max_sza is not None else max_sza

# Cosine correction
corr = 1. / cos_zen
corr = (1. / cos_zen).astype(data.dtype, copy=False)
if max_sza is not None:
# gradually fall off for larger zenith angle
grad_factor = (np.arccos(cos_zen) - limit_rad) / (max_sza_rad - limit_rad)
Expand All @@ -538,7 +539,11 @@ def _sunzen_corr_cos_ndarray(data: np.ndarray,
else:
# Use constant value (the limit) for larger zenith angles
grad_factor = 1.
corr = np.where(cos_zen > limit_cos, corr, grad_factor / limit_cos)
corr = np.where(
cos_zen > limit_cos,
corr,
(grad_factor / limit_cos).astype(data.dtype, copy=False)
)
# Force "night" pixels to 0 (where SZA is invalid)
corr[np.isnan(cos_zen)] = 0
return data * corr
10 changes: 8 additions & 2 deletions satpy/tests/test_modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,23 @@ def sunz_sza():
class TestSunZenithCorrector:
"""Test case for the zenith corrector."""

def test_basic_default_not_provided(self, sunz_ds1):
@pytest.mark.parametrize("as_32bit", [False, True])
def test_basic_default_not_provided(self, sunz_ds1, as_32bit):
"""Test default limits when SZA isn't provided."""
from satpy.modifiers.geometry import SunZenithCorrector

if as_32bit:
sunz_ds1 = sunz_ds1.astype(np.float32)
comp = SunZenithCorrector(name='sza_test', modifiers=tuple())
res = comp((sunz_ds1,), test_attr='test')
np.testing.assert_allclose(res.values, np.array([[22.401667, 22.31777], [22.437503, 22.353533]]))
assert 'y' in res.coords
assert 'x' in res.coords
ds1 = sunz_ds1.copy().drop_vars(('y', 'x'))
res = comp((ds1,), test_attr='test')
np.testing.assert_allclose(res.values, np.array([[22.401667, 22.31777], [22.437503, 22.353533]]))
res_np = res.compute()
np.testing.assert_allclose(res_np.values, np.array([[22.401667, 22.31777], [22.437503, 22.353533]]))
assert res.dtype == res_np.dtype
assert 'y' not in res.coords
assert 'x' not in res.coords

Expand Down

0 comments on commit 8b8bbc2

Please sign in to comment.