Skip to content

Commit

Permalink
Fix saturation flag issue for numpy 2+ in romancal (#305)
Browse files Browse the repository at this point in the history
* Fix for roman numpy 2 issue

* Add tests

* Add changelog
  • Loading branch information
WilliamJamieson authored Oct 10, 2024
1 parent e651753 commit 28b1245
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
1 change: 1 addition & 0 deletions changes/305.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix `IntEnum` saturation flag issue with numpy 2+ for romancal.
8 changes: 4 additions & 4 deletions src/stcal/saturation/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ def flag_saturated_pixels(
updated pixel dq array
"""
nints, ngroups, nrows, ncols = data.shape
dnu = dqflags["DO_NOT_USE"]
saturated = dqflags["SATURATED"]
ad_floor = dqflags["AD_FLOOR"]
no_sat_check = dqflags["NO_SAT_CHECK"]
dnu = int(dqflags["DO_NOT_USE"])
saturated = int(dqflags["SATURATED"])
ad_floor = int(dqflags["AD_FLOOR"])
no_sat_check = int(dqflags["NO_SAT_CHECK"])

# Identify pixels flagged in reference file as NO_SAT_CHECK,
no_sat_check_mask = np.bitwise_and(sat_dq, no_sat_check) == no_sat_check
Expand Down
25 changes: 25 additions & 0 deletions tests/test_saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Unit tests for saturation flagging
"""
from enum import IntEnum

import numpy as np

Expand Down Expand Up @@ -227,3 +228,27 @@ def test_zero_frame():
# Check ZEROFRAME flagged elements are zeroed out.
assert zframe[0, 0, 0] == 0.0
assert zframe[1, 0, 1] == 0.0


def test_intenum_flags():
"""
In numpy 2.0 IntEnums induce a failure in bitwise_or. Romancal uses IntEnums
for clarity rather than raw dictionaries of integers
"""

class DqFlags(IntEnum):
GOOD = 0
DO_NOT_USE = 1
SATURATED = 2
AD_FLOOR = 64
NO_SAT_CHECK = 2**21

# Create inputs, data, and saturation maps
data = np.zeros((1, 5, 20, 20), dtype=np.float32)
gdq = np.zeros((1, 5, 20, 20), dtype=np.uint32)
pdq = np.zeros((20, 20), dtype=np.uint32)
sat_thresh = np.ones((20, 20)) * 100000.0
sat_dq = np.zeros((20, 20), dtype=np.uint32)

# Simple test to check no errors are raised
_ = flag_saturated_pixels(data, gdq, pdq, sat_thresh, sat_dq, ATOD_LIMIT, DqFlags)

0 comments on commit 28b1245

Please sign in to comment.