diff --git a/jwst/outlier_detection/tests/test_outlier_detection.py b/jwst/outlier_detection/tests/test_outlier_detection.py index 4c4d3cdfb3..bca8e6361f 100644 --- a/jwst/outlier_detection/tests/test_outlier_detection.py +++ b/jwst/outlier_detection/tests/test_outlier_detection.py @@ -553,13 +553,13 @@ def test_outlier_step_image_weak_cr_coron(exptype, tsovisit, tmp_cwd): @pytest.mark.parametrize("exptype, tsovisit", exptypes_tso) -def test_outlier_step_weak_cr_tso(exptype, tsovisit): +@pytest.mark.parametrize("rolling_window_width", [7, 0]) +def test_outlier_step_weak_cr_tso(exptype, tsovisit, rolling_window_width): '''Test outlier detection with rolling median on time-varying source - This test fails if rolling_window_width is set to 100, i.e., take simple median + This test fails if rolling_window_width is set to 0, i.e., take simple median ''' bkg = 1.5 sig = 0.02 - rolling_window_width = 7 numsci = 50 signal = 7.0 im = we_many_sci( @@ -588,8 +588,13 @@ def test_outlier_step_weak_cr_tso(exptype, tsovisit): assert np.all(np.isnan(result.data[i][dnu])) assert np.allclose(model.data[~dnu], result.data[i][~dnu]) - # Verify source is not flagged - assert np.all(result.dq[:, 7, 7] == datamodels.dqflags.pixel["GOOD"]) + # Verify source is not flagged for rolling median + if rolling_window_width == 7: + assert np.all(result.dq[:, 7, 7] == datamodels.dqflags.pixel["GOOD"]) + # But this fails for simple median + elif rolling_window_width == 0: + with pytest.raises(AssertionError): + assert np.all(result.dq[:, 7, 7] == datamodels.dqflags.pixel["GOOD"]) # Verify CR is flagged assert result.dq[cr_timestep, 12, 12] == OUTLIER_DO_NOT_USE diff --git a/jwst/outlier_detection/tso.py b/jwst/outlier_detection/tso.py index 2c0020986f..e5524bc920 100644 --- a/jwst/outlier_detection/tso.py +++ b/jwst/outlier_detection/tso.py @@ -4,7 +4,7 @@ from jwst import datamodels as dm from stcal.outlier_detection.utils import compute_weight_threshold -from .utils import flag_model_crs +from .utils import flag_model_crs, nanmedian3D from ._fileio import save_median import logging @@ -41,7 +41,7 @@ def detect_outliers( medians = compute_rolling_median(weighted_cube, weight_threshold, w=rolling_window_width) else: - medians = np.nanmedian(weighted_cube.data, axis=0) + medians = nanmedian3D(weighted_cube.data, overwrite_input=False) # this is a 2-D array, need to repeat it into the time axis # for consistent shape with rolling median case medians = np.broadcast_to(medians, weighted_cube.shape)