Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JP-3521: Fix jump multiprocessing error #239

Merged
merged 19 commits into from
Feb 1, 2024
Merged
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ jump
- Added more allowable selections for the number of cores to use for
multiprocessing [#183].

- Fixed the computation of the number of rows per slice for multiprocessing,
which caused different results when running the step with multiprocess [#239]

ramp_fitting
~~~~~~~~~~~~

Expand Down
28 changes: 21 additions & 7 deletions src/stcal/jump/jump.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ def detect_jumps(
err *= gain_2d
readnoise_2d *= gain_2d
# also apply to the after_jump thresholds
after_jump_flag_e1 = after_jump_flag_dn1 * gain_2d
after_jump_flag_e2 = after_jump_flag_dn2 * gain_2d
after_jump_flag_e1 = after_jump_flag_dn1 * np.nanmedian(gain_2d)
after_jump_flag_e2 = after_jump_flag_dn2 * np.nanmedian(gain_2d)

# Apply the 2-point difference method as a first pass
log.info("Executing two-point difference method")
Expand Down Expand Up @@ -277,6 +277,12 @@ def detect_jumps(
minimum_sigclip_groups=minimum_sigclip_groups,
only_use_ints=only_use_ints,
)
# remove redundant bits in pixels that have jump flagged but were
# already flagged as do_not_use or saturated.
gdq[gdq == np.bitwise_or(dqflags['DO_NOT_USE'], dqflags['JUMP_DET'])] = \
dqflags['DO_NOT_USE']
gdq[gdq == np.bitwise_or(dqflags['SATURATED'], dqflags['JUMP_DET'])] = \
dqflags['SATURATED']
# This is the flag that controls the flagging of snowballs.
if expand_large_events:
total_snowballs = flag_large_events(
Expand Down Expand Up @@ -314,7 +320,7 @@ def detect_jumps(
log.info("Total showers= %i", num_showers)
number_extended_events = num_showers
else:
yinc = int(n_rows / n_slices)
yinc = int(n_rows // n_slices)
slices = []
# Slice up data, gdq, readnoise_2d into slices
# Each element of slices is a tuple of
Expand All @@ -323,17 +329,16 @@ def detect_jumps(

# must copy arrays here, find_crs will make copies but if slices
# are being passed in for multiprocessing then the original gdq will be
# modified unless copied beforehand
# modified unless copied beforehand.
gdq = gdq.copy()
data = data.copy()
copy_arrs = False # we don't need to copy arrays again in find_crs

for i in range(n_slices - 1):
slices.insert(
i,
(
data[:, :, i * yinc : (i + 1) * yinc, :],
gdq[:, :, i * yinc : (i + 1) * yinc, :],
gdq[:, :, i * yinc : (i + 1) * yinc, :].copy(),
readnoise_2d[i * yinc : (i + 1) * yinc, :],
rejection_thresh,
three_grp_thresh,
Expand All @@ -359,7 +364,7 @@ def detect_jumps(
n_slices - 1,
(
data[:, :, (n_slices - 1) * yinc : n_rows, :],
gdq[:, :, (n_slices - 1) * yinc : n_rows, :],
gdq[:, :, (n_slices - 1) * yinc : n_rows, :].copy() ,
readnoise_2d[(n_slices - 1) * yinc : n_rows, :],
rejection_thresh,
three_grp_thresh,
Expand All @@ -381,6 +386,8 @@ def detect_jumps(
)
log.info("Creating %d processes for jump detection ", n_slices)
pool = multiprocessing.Pool(processes=n_slices)
######### JUST FOR DEBUGGING #########################
# pool = multiprocessing.Pool(processes=1)
# Starts each slice in its own process. Starmap allows more than one
# parameter to be passed.
real_result = pool.starmap(twopt.find_crs, slices)
Expand Down Expand Up @@ -427,6 +434,13 @@ def detect_jumps(
# save the neighbors to be flagged that will be in the next slice
previous_row_above_gdq = row_above_gdq.copy()
k += 1
# remove redundant bits in pixels that have jump flagged but were
# already flagged as do_not_use or saturated.
gdq[gdq == np.bitwise_or(dqflags['DO_NOT_USE'], dqflags['JUMP_DET'])] = \
dqflags['DO_NOT_USE']
gdq[gdq == np.bitwise_or(dqflags['SATURATED'], dqflags['JUMP_DET'])] = \
dqflags['SATURATED']

# This is the flag that controls the flagging of snowballs.
if expand_large_events:
total_snowballs = flag_large_events(
Expand Down
12 changes: 4 additions & 8 deletions src/stcal/jump/twopoint_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,21 +401,17 @@ def find_crs(
# the transient seen after ramp jumps
flag_e_threshold = [after_jump_flag_e1, after_jump_flag_e2]
flag_groups = [after_jump_flag_n1, after_jump_flag_n2]

for cthres, cgroup in zip(flag_e_threshold, flag_groups):
if cgroup > 0:
if cgroup > 0 and cthres > 0:
cr_intg, cr_group, cr_row, cr_col = np.where(np.bitwise_and(gdq, jump_flag))
for j in range(len(cr_group)):
intg = cr_intg[j]
group = cr_group[j]
row = cr_row[j]
col = cr_col[j]
if e_jump_4d[intg, group - 1, row, col] >= cthres[row, col]:
for kk in range(group, min(group + cgroup + 1, ngroups)):
if (gdq[intg, kk, row, col] & sat_flag) == 0 and (
gdq[intg, kk, row, col] & dnu_flag
) == 0:
gdq[intg, kk, row, col] = np.bitwise_or(gdq[integ, kk, row, col], jump_flag)
if e_jump_4d[intg, group - 1, row, col] >= cthres:
for kk in range(group + 1, min(group + cgroup + 1, ngroups)):
gdq[intg, kk, row, col] = np.bitwise_or(gdq[intg, kk, row, col], jump_flag)
if "stddev" in locals():
return gdq, row_below_gdq, row_above_gdq, num_primary_crs, stddev

Expand Down
94 changes: 94 additions & 0 deletions tests/test_jump.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
find_faint_extended,
flag_large_events,
point_inside_ellipse,
detect_jumps,
)

DQFLAGS = {"JUMP_DET": 4, "SATURATED": 2, "DO_NOT_USE": 1, "GOOD": 0, "NO_GAIN_VALUE": 8}
Expand All @@ -30,6 +31,99 @@ def _cube(ngroups, readnoise=10):
return _cube


def test_multiprocessing():
nints = 1
nrows = 13
ncols = 2
ngroups = 13
readnoise = 10
frames_per_group = 1

data = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.float32)
readnoise_2d = np.ones((nrows, ncols), dtype=np.float32) * readnoise
gain_2d = np.ones((nrows, ncols), dtype=np.float32) * 4
gdq = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.uint32)
pdq = np.zeros(shape=(nrows, ncols), dtype=np.uint32)
err = np.zeros(shape=(nrows, ncols), dtype=np.float32)
num_cores = "1"
data[0, 4:, 5, 1] = 2000
gdq[0, 4:, 6, 1] = DQFLAGS['DO_NOT_USE']
gdq, pdq, total_primary_crs, number_extended_events, stddev = detect_jumps(
frames_per_group, data, gdq, pdq, err, gain_2d, readnoise_2d, rejection_thresh=5, three_grp_thresh=6,
four_grp_thresh=7, max_cores=num_cores, max_jump_to_flag_neighbors=10000, min_jump_to_flag_neighbors=100,
flag_4_neighbors=True, dqflags=DQFLAGS)
print(data[0, 4, :, :])
print(gdq[0, 4, :, :])
assert gdq[0, 4, 5, 1] == DQFLAGS['JUMP_DET']
assert gdq[0, 4, 6, 1] == DQFLAGS['DO_NOT_USE']

# This section of code will fail without the fixes for PR #239 that prevent
# the double flagging pixels with jump which already have do_not_use or saturation set.
num_cores = "5"
data = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.float32)
gdq = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.uint32)
pdq = np.zeros(shape=(nrows, ncols), dtype=np.uint32)
readnoise_2d = np.ones((nrows, ncols), dtype=np.float32) * readnoise
gain_2d = np.ones((nrows, ncols), dtype=np.float32) * 3
err = np.zeros(shape=(nrows, ncols), dtype=np.float32)
data[0, 4:, 5, 1] = 2000
gdq[0, 4:, 6, 1] = DQFLAGS['DO_NOT_USE']
gdq, pdq, total_primary_crs, number_extended_events, stddev = detect_jumps(
frames_per_group, data, gdq, pdq, err, gain_2d, readnoise_2d, rejection_thresh=5, three_grp_thresh=6,
four_grp_thresh=7, max_cores=num_cores, max_jump_to_flag_neighbors=10000, min_jump_to_flag_neighbors=100,
flag_4_neighbors=True, dqflags=DQFLAGS)
assert gdq[0, 4, 5, 1] == DQFLAGS['JUMP_DET']
assert gdq[0, 4, 6, 1] == DQFLAGS['DO_NOT_USE'] #This value would have been 5 without the fix.


def test_multiprocessing_big():
nints = 1
nrows = 2048
ncols = 7
ngroups = 13
readnoise = 10
frames_per_group = 1

data = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.float32)
readnoise_2d = np.ones((nrows, ncols), dtype=np.float32) * readnoise
gain_2d = np.ones((nrows, ncols), dtype=np.float32) * 4
gdq = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.uint32)
pdq = np.zeros(shape=(nrows, ncols), dtype=np.uint32)
err = np.zeros(shape=(nrows, ncols), dtype=np.float32)
num_cores = "1"
data[0, 4:, 204, 5] = 2000
gdq[0, 4:, 204, 6] = DQFLAGS['DO_NOT_USE']
gdq, pdq, total_primary_crs, number_extended_events, stddev = detect_jumps(
frames_per_group, data, gdq, pdq, err, gain_2d, readnoise_2d, rejection_thresh=5, three_grp_thresh=6,
four_grp_thresh=7, max_cores=num_cores, max_jump_to_flag_neighbors=10000, min_jump_to_flag_neighbors=100,
flag_4_neighbors=True, dqflags=DQFLAGS)
print(data[0, 4, :, :])
print(gdq[0, 4, :, :])
assert gdq[0, 4, 204, 5] == DQFLAGS['JUMP_DET']
assert gdq[0, 4, 205, 5] == DQFLAGS['JUMP_DET']
assert gdq[0, 4, 204, 6] == DQFLAGS['DO_NOT_USE']

# This section of code will fail without the fixes for PR #239 that prevent
# the double flagging pixels with jump which already have do_not_use or saturation set.
num_cores = "10"
data = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.float32)
gdq = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.uint32)
pdq = np.zeros(shape=(nrows, ncols), dtype=np.uint32)
readnoise_2d = np.ones((nrows, ncols), dtype=np.float32) * readnoise
gain_2d = np.ones((nrows, ncols), dtype=np.float32) * 3
err = np.zeros(shape=(nrows, ncols), dtype=np.float32)
data[0, 4:, 204, 5] = 2000
gdq[0, 4:, 204, 6] = DQFLAGS['DO_NOT_USE']
gdq, pdq, total_primary_crs, number_extended_events, stddev = detect_jumps(
frames_per_group, data, gdq, pdq, err, gain_2d, readnoise_2d, rejection_thresh=5, three_grp_thresh=6,
four_grp_thresh=7, max_cores=num_cores, max_jump_to_flag_neighbors=10000, min_jump_to_flag_neighbors=100,
flag_4_neighbors=True, dqflags=DQFLAGS)
assert gdq[0, 4, 204, 5] == DQFLAGS['JUMP_DET']
assert gdq[0, 4, 205, 5] == DQFLAGS['JUMP_DET']
assert gdq[0, 4, 204, 6] == DQFLAGS['DO_NOT_USE'] #This value would have been 5 without the fix.



def test_find_simple_ellipse():
plane = np.zeros(shape=(5, 5), dtype=np.uint8)
plane[2, 2] = DQFLAGS["JUMP_DET"]
Expand Down
10 changes: 5 additions & 5 deletions tests/test_twopoint_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ def test_10grps_1cr_afterjump(setup_cube):
data[0, 8, 100, 100] = 1190
data[0, 9, 100, 100] = 1209

after_jump_flag_e1 = np.full(data.shape[2:4], 1.0) * 0.0
after_jump_flag_e1 = 1.0
out_gdq, row_below_gdq, rows_above_gdq, total_crs, stddev = find_crs(
data,
gdq,
Expand Down Expand Up @@ -891,7 +891,7 @@ def test_10grps_1cr_afterjump_2group(setup_cube):
data[0, 8, 100, 100] = 1190
data[0, 9, 100, 100] = 1209

after_jump_flag_e1 = np.full(data.shape[2:4], 1.0) * 0.0
after_jump_flag_e1 = 1.0
out_gdq, row_below_gdq, rows_above_gdq, total_crs, stddev = find_crs(
data,
gdq,
Expand Down Expand Up @@ -932,7 +932,7 @@ def test_10grps_1cr_afterjump_toosmall(setup_cube):
data[0, 8, 100, 100] = 1190
data[0, 9, 100, 100] = 1209

after_jump_flag_e1 = np.full(data.shape[2:4], 1.0) * 10000.0
after_jump_flag_e1 = 10000.0
out_gdq, row_below_gdq, rows_above_gdq, total_crs, stddev = find_crs(
data,
gdq,
Expand Down Expand Up @@ -968,8 +968,8 @@ def test_10grps_1cr_afterjump_twothresholds(setup_cube):
data[0, 8, 100, 100] = 1190
data[0, 9, 100, 100] = 1209

after_jump_flag_e1 = np.full(data.shape[2:4], 1.0) * 500.0
after_jump_flag_e2 = np.full(data.shape[2:4], 1.0) * 10.0
after_jump_flag_e1 = 500.0
after_jump_flag_e2 = 10.0
out_gdq, row_below_gdq, rows_above_gdq, total_crs, stddev = find_crs(
data,
gdq,
Expand Down
Loading