diff --git a/CHANGES.rst b/CHANGES.rst index b22ce3e2..48d9eda8 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -86,6 +86,9 @@ jump - Added more allowable selections for the number of cores to use for multiprocessing [#183]. +- Fixed the computation of the number of rows per slice for multiprocessing, + which caused different results when running the step with multiprocess [#239] + ramp_fitting ~~~~~~~~~~~~ diff --git a/src/stcal/jump/jump.py b/src/stcal/jump/jump.py index 6139d103..d6aeb3ad 100644 --- a/src/stcal/jump/jump.py +++ b/src/stcal/jump/jump.py @@ -235,8 +235,8 @@ def detect_jumps( err *= gain_2d readnoise_2d *= gain_2d # also apply to the after_jump thresholds - after_jump_flag_e1 = after_jump_flag_dn1 * gain_2d - after_jump_flag_e2 = after_jump_flag_dn2 * gain_2d + after_jump_flag_e1 = after_jump_flag_dn1 * np.nanmedian(gain_2d) + after_jump_flag_e2 = after_jump_flag_dn2 * np.nanmedian(gain_2d) # Apply the 2-point difference method as a first pass log.info("Executing two-point difference method") @@ -277,6 +277,12 @@ def detect_jumps( minimum_sigclip_groups=minimum_sigclip_groups, only_use_ints=only_use_ints, ) + # remove redundant bits in pixels that have jump flagged but were + # already flagged as do_not_use or saturated. + gdq[gdq == np.bitwise_or(dqflags['DO_NOT_USE'], dqflags['JUMP_DET'])] = \ + dqflags['DO_NOT_USE'] + gdq[gdq == np.bitwise_or(dqflags['SATURATED'], dqflags['JUMP_DET'])] = \ + dqflags['SATURATED'] # This is the flag that controls the flagging of snowballs. if expand_large_events: total_snowballs = flag_large_events( @@ -314,7 +320,7 @@ def detect_jumps( log.info("Total showers= %i", num_showers) number_extended_events = num_showers else: - yinc = int(n_rows / n_slices) + yinc = int(n_rows // n_slices) slices = [] # Slice up data, gdq, readnoise_2d into slices # Each element of slices is a tuple of @@ -323,17 +329,16 @@ def detect_jumps( # must copy arrays here, find_crs will make copies but if slices # are being passed in for multiprocessing then the original gdq will be - # modified unless copied beforehand + # modified unless copied beforehand. gdq = gdq.copy() data = data.copy() copy_arrs = False # we don't need to copy arrays again in find_crs - for i in range(n_slices - 1): slices.insert( i, ( data[:, :, i * yinc : (i + 1) * yinc, :], - gdq[:, :, i * yinc : (i + 1) * yinc, :], + gdq[:, :, i * yinc : (i + 1) * yinc, :].copy(), readnoise_2d[i * yinc : (i + 1) * yinc, :], rejection_thresh, three_grp_thresh, @@ -359,7 +364,7 @@ def detect_jumps( n_slices - 1, ( data[:, :, (n_slices - 1) * yinc : n_rows, :], - gdq[:, :, (n_slices - 1) * yinc : n_rows, :], + gdq[:, :, (n_slices - 1) * yinc : n_rows, :].copy() , readnoise_2d[(n_slices - 1) * yinc : n_rows, :], rejection_thresh, three_grp_thresh, @@ -381,6 +386,8 @@ def detect_jumps( ) log.info("Creating %d processes for jump detection ", n_slices) pool = multiprocessing.Pool(processes=n_slices) +######### JUST FOR DEBUGGING ######################### +# pool = multiprocessing.Pool(processes=1) # Starts each slice in its own process. Starmap allows more than one # parameter to be passed. real_result = pool.starmap(twopt.find_crs, slices) @@ -427,6 +434,13 @@ def detect_jumps( # save the neighbors to be flagged that will be in the next slice previous_row_above_gdq = row_above_gdq.copy() k += 1 + # remove redundant bits in pixels that have jump flagged but were + # already flagged as do_not_use or saturated. + gdq[gdq == np.bitwise_or(dqflags['DO_NOT_USE'], dqflags['JUMP_DET'])] = \ + dqflags['DO_NOT_USE'] + gdq[gdq == np.bitwise_or(dqflags['SATURATED'], dqflags['JUMP_DET'])] = \ + dqflags['SATURATED'] + # This is the flag that controls the flagging of snowballs. if expand_large_events: total_snowballs = flag_large_events( diff --git a/src/stcal/jump/twopoint_difference.py b/src/stcal/jump/twopoint_difference.py index 62d44b1a..b9111870 100644 --- a/src/stcal/jump/twopoint_difference.py +++ b/src/stcal/jump/twopoint_difference.py @@ -401,21 +401,17 @@ def find_crs( # the transient seen after ramp jumps flag_e_threshold = [after_jump_flag_e1, after_jump_flag_e2] flag_groups = [after_jump_flag_n1, after_jump_flag_n2] - for cthres, cgroup in zip(flag_e_threshold, flag_groups): - if cgroup > 0: + if cgroup > 0 and cthres > 0: cr_intg, cr_group, cr_row, cr_col = np.where(np.bitwise_and(gdq, jump_flag)) for j in range(len(cr_group)): intg = cr_intg[j] group = cr_group[j] row = cr_row[j] col = cr_col[j] - if e_jump_4d[intg, group - 1, row, col] >= cthres[row, col]: - for kk in range(group, min(group + cgroup + 1, ngroups)): - if (gdq[intg, kk, row, col] & sat_flag) == 0 and ( - gdq[intg, kk, row, col] & dnu_flag - ) == 0: - gdq[intg, kk, row, col] = np.bitwise_or(gdq[integ, kk, row, col], jump_flag) + if e_jump_4d[intg, group - 1, row, col] >= cthres: + for kk in range(group + 1, min(group + cgroup + 1, ngroups)): + gdq[intg, kk, row, col] = np.bitwise_or(gdq[intg, kk, row, col], jump_flag) if "stddev" in locals(): return gdq, row_below_gdq, row_above_gdq, num_primary_crs, stddev diff --git a/tests/test_jump.py b/tests/test_jump.py index 0ddbefb1..a85eef51 100644 --- a/tests/test_jump.py +++ b/tests/test_jump.py @@ -8,6 +8,7 @@ find_faint_extended, flag_large_events, point_inside_ellipse, + detect_jumps, ) DQFLAGS = {"JUMP_DET": 4, "SATURATED": 2, "DO_NOT_USE": 1, "GOOD": 0, "NO_GAIN_VALUE": 8} @@ -30,6 +31,99 @@ def _cube(ngroups, readnoise=10): return _cube +def test_multiprocessing(): + nints = 1 + nrows = 13 + ncols = 2 + ngroups = 13 + readnoise = 10 + frames_per_group = 1 + + data = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.float32) + readnoise_2d = np.ones((nrows, ncols), dtype=np.float32) * readnoise + gain_2d = np.ones((nrows, ncols), dtype=np.float32) * 4 + gdq = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.uint32) + pdq = np.zeros(shape=(nrows, ncols), dtype=np.uint32) + err = np.zeros(shape=(nrows, ncols), dtype=np.float32) + num_cores = "1" + data[0, 4:, 5, 1] = 2000 + gdq[0, 4:, 6, 1] = DQFLAGS['DO_NOT_USE'] + gdq, pdq, total_primary_crs, number_extended_events, stddev = detect_jumps( + frames_per_group, data, gdq, pdq, err, gain_2d, readnoise_2d, rejection_thresh=5, three_grp_thresh=6, + four_grp_thresh=7, max_cores=num_cores, max_jump_to_flag_neighbors=10000, min_jump_to_flag_neighbors=100, + flag_4_neighbors=True, dqflags=DQFLAGS) + print(data[0, 4, :, :]) + print(gdq[0, 4, :, :]) + assert gdq[0, 4, 5, 1] == DQFLAGS['JUMP_DET'] + assert gdq[0, 4, 6, 1] == DQFLAGS['DO_NOT_USE'] + + # This section of code will fail without the fixes for PR #239 that prevent + # the double flagging pixels with jump which already have do_not_use or saturation set. + num_cores = "5" + data = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.float32) + gdq = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.uint32) + pdq = np.zeros(shape=(nrows, ncols), dtype=np.uint32) + readnoise_2d = np.ones((nrows, ncols), dtype=np.float32) * readnoise + gain_2d = np.ones((nrows, ncols), dtype=np.float32) * 3 + err = np.zeros(shape=(nrows, ncols), dtype=np.float32) + data[0, 4:, 5, 1] = 2000 + gdq[0, 4:, 6, 1] = DQFLAGS['DO_NOT_USE'] + gdq, pdq, total_primary_crs, number_extended_events, stddev = detect_jumps( + frames_per_group, data, gdq, pdq, err, gain_2d, readnoise_2d, rejection_thresh=5, three_grp_thresh=6, + four_grp_thresh=7, max_cores=num_cores, max_jump_to_flag_neighbors=10000, min_jump_to_flag_neighbors=100, + flag_4_neighbors=True, dqflags=DQFLAGS) + assert gdq[0, 4, 5, 1] == DQFLAGS['JUMP_DET'] + assert gdq[0, 4, 6, 1] == DQFLAGS['DO_NOT_USE'] #This value would have been 5 without the fix. + + +def test_multiprocessing_big(): + nints = 1 + nrows = 2048 + ncols = 7 + ngroups = 13 + readnoise = 10 + frames_per_group = 1 + + data = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.float32) + readnoise_2d = np.ones((nrows, ncols), dtype=np.float32) * readnoise + gain_2d = np.ones((nrows, ncols), dtype=np.float32) * 4 + gdq = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.uint32) + pdq = np.zeros(shape=(nrows, ncols), dtype=np.uint32) + err = np.zeros(shape=(nrows, ncols), dtype=np.float32) + num_cores = "1" + data[0, 4:, 204, 5] = 2000 + gdq[0, 4:, 204, 6] = DQFLAGS['DO_NOT_USE'] + gdq, pdq, total_primary_crs, number_extended_events, stddev = detect_jumps( + frames_per_group, data, gdq, pdq, err, gain_2d, readnoise_2d, rejection_thresh=5, three_grp_thresh=6, + four_grp_thresh=7, max_cores=num_cores, max_jump_to_flag_neighbors=10000, min_jump_to_flag_neighbors=100, + flag_4_neighbors=True, dqflags=DQFLAGS) + print(data[0, 4, :, :]) + print(gdq[0, 4, :, :]) + assert gdq[0, 4, 204, 5] == DQFLAGS['JUMP_DET'] + assert gdq[0, 4, 205, 5] == DQFLAGS['JUMP_DET'] + assert gdq[0, 4, 204, 6] == DQFLAGS['DO_NOT_USE'] + + # This section of code will fail without the fixes for PR #239 that prevent + # the double flagging pixels with jump which already have do_not_use or saturation set. + num_cores = "10" + data = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.float32) + gdq = np.zeros(shape=(nints, ngroups, nrows, ncols), dtype=np.uint32) + pdq = np.zeros(shape=(nrows, ncols), dtype=np.uint32) + readnoise_2d = np.ones((nrows, ncols), dtype=np.float32) * readnoise + gain_2d = np.ones((nrows, ncols), dtype=np.float32) * 3 + err = np.zeros(shape=(nrows, ncols), dtype=np.float32) + data[0, 4:, 204, 5] = 2000 + gdq[0, 4:, 204, 6] = DQFLAGS['DO_NOT_USE'] + gdq, pdq, total_primary_crs, number_extended_events, stddev = detect_jumps( + frames_per_group, data, gdq, pdq, err, gain_2d, readnoise_2d, rejection_thresh=5, three_grp_thresh=6, + four_grp_thresh=7, max_cores=num_cores, max_jump_to_flag_neighbors=10000, min_jump_to_flag_neighbors=100, + flag_4_neighbors=True, dqflags=DQFLAGS) + assert gdq[0, 4, 204, 5] == DQFLAGS['JUMP_DET'] + assert gdq[0, 4, 205, 5] == DQFLAGS['JUMP_DET'] + assert gdq[0, 4, 204, 6] == DQFLAGS['DO_NOT_USE'] #This value would have been 5 without the fix. + + + def test_find_simple_ellipse(): plane = np.zeros(shape=(5, 5), dtype=np.uint8) plane[2, 2] = DQFLAGS["JUMP_DET"] diff --git a/tests/test_twopoint_difference.py b/tests/test_twopoint_difference.py index c6443bc7..f066b4b1 100644 --- a/tests/test_twopoint_difference.py +++ b/tests/test_twopoint_difference.py @@ -855,7 +855,7 @@ def test_10grps_1cr_afterjump(setup_cube): data[0, 8, 100, 100] = 1190 data[0, 9, 100, 100] = 1209 - after_jump_flag_e1 = np.full(data.shape[2:4], 1.0) * 0.0 + after_jump_flag_e1 = 1.0 out_gdq, row_below_gdq, rows_above_gdq, total_crs, stddev = find_crs( data, gdq, @@ -891,7 +891,7 @@ def test_10grps_1cr_afterjump_2group(setup_cube): data[0, 8, 100, 100] = 1190 data[0, 9, 100, 100] = 1209 - after_jump_flag_e1 = np.full(data.shape[2:4], 1.0) * 0.0 + after_jump_flag_e1 = 1.0 out_gdq, row_below_gdq, rows_above_gdq, total_crs, stddev = find_crs( data, gdq, @@ -932,7 +932,7 @@ def test_10grps_1cr_afterjump_toosmall(setup_cube): data[0, 8, 100, 100] = 1190 data[0, 9, 100, 100] = 1209 - after_jump_flag_e1 = np.full(data.shape[2:4], 1.0) * 10000.0 + after_jump_flag_e1 = 10000.0 out_gdq, row_below_gdq, rows_above_gdq, total_crs, stddev = find_crs( data, gdq, @@ -968,8 +968,8 @@ def test_10grps_1cr_afterjump_twothresholds(setup_cube): data[0, 8, 100, 100] = 1190 data[0, 9, 100, 100] = 1209 - after_jump_flag_e1 = np.full(data.shape[2:4], 1.0) * 500.0 - after_jump_flag_e2 = np.full(data.shape[2:4], 1.0) * 10.0 + after_jump_flag_e1 = 500.0 + after_jump_flag_e2 = 10.0 out_gdq, row_below_gdq, rows_above_gdq, total_crs, stddev = find_crs( data, gdq,