
Jump step refactoring #8679

Open · stscijgbot-jp opened this issue Jul 29, 2024 · 28 comments


Issue JP-3697 was created on JIRA by Maria Pena-Guerrero:

The jump step uses an excessive amount of RAM and takes a long time to complete. The problem is worse for large files: for a 2.64 GB MIRI uncal file, detector1 reaches the jump step with about 22 GB of memory in use, jumps to 30 GB after a few seconds, and climbs to 89 GB a few minutes later when it reaches "Executing two-point difference method". I killed the run after 24 hrs on my MacBook Pro with 32 GB of RAM. I am using the main branch of jwst, stpipe, and stcal. The behavior is likely due to the number of copies and array operations; some of these might be unavoidable, but others can perhaps be vectorized.


Comment by David Law on JIRA:

Maria Pena-Guerrero Which MIRI file are you using for the above?


Comment by Maria Pena-Guerrero on JIRA:

David Law the file is jw01283001001_03101_00001_mirimage_uncal.fits


Comment by David Law on JIRA:

Hm, just to add to this puzzle: I tried processing this file myself with all of the latest versions, via a default strun of calwebb_detector1.

The entire process completed in about 12 minutes, with jump the longest running step at about 7 minutes.

I've got 64 GB of physical memory in my MacBook Pro; 61 GB of that got used at peak, though I saw the python usage spike up to 153 GB during the jump step (presumably using virtual memory to handle it).


Comment by Maria Pena-Guerrero on JIRA:

David Law more to the mystery... did you run the Detector1 pipeline on a file, or did you open the datamodel first? i.e.,

option A -> det1.call(uncal_file.fits)

option B -> det1.call(uncal_model)

For me the run takes significantly more memory with option B.
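For reference, a minimal sketch of the two invocation patterns being compared, assuming jwst's Detector1Pipeline and datamodels.open (the filename is just the example file from this ticket):

```python
from jwst.pipeline import Detector1Pipeline
from jwst import datamodels

uncal_file = "jw01283001001_03101_00001_mirimage_uncal.fits"

# Option A: pass the filename; the pipeline opens the model itself and can
# release it when the run finishes.
result_a = Detector1Pipeline.call(uncal_file)

# Option B: open the datamodel first and pass the object; the model stays
# referenced in the calling session, which can keep extra memory alive.
uncal_model = datamodels.open(uncal_file)
result_b = Detector1Pipeline.call(uncal_model)
```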


Comment by David Law on JIRA:

My initial run just used strun from a terminal command line. It took 12 minutes, peaked at 153 GB of Python memory (61 GB physical), and when it finished, RAM usage returned to the normal baseline.

Using your option A (i.e., running on a file) within a jupyter notebook, it took 12 minutes and peaked at 153 GB (with a brief excursion to 172 GB that I may have missed before) and 61 GB physical. When it finished, though, physical RAM usage stayed high at 34 GB until I explicitly halted the notebook.

Using your option B (i.e., running on a datamodel) within a jupyter notebook, it peaked at 182 GB used and stalled out in the jump step. After 20 minutes in the 'two point difference' section alone I killed it.

So I'm also seeing a big difference between passing in a file and passing in a datamodel. The datamodel presumably required enough additional resources that it stalled out for me as well.


Comment by Tyler Pauly on JIRA:

It may be worth testing in an environment with the dev version of stpipe - some of the memory issues using file vs. datamodel may be linked to the log referencing bug found and fixed here: spacetelescope/stpipe#171

I believe a patch release of stpipe is planned, but I don't know its status.


Comment by David Law on JIRA:

Just repeated the test as suggested by Tyler Pauly with the dev version of stpipe.

Major memory difference. Option A ran in 11 minutes, maxing at 60 GB used by Python vs. the 172 GB it used to take. Option B maxed at 94 GB vs. the previous 182 GB, but still hung in the jump step; I killed it after it spent over 20 min in the two-point-difference section that took option A just 2.5 minutes to complete.

So the memory has improved, but running with a datamodel input still hangs despite using far less memory than what ran successfully in the original option A. Also, about 12 GB still lingers in memory until the kernel is restarted, even with option A.


Comment by Tyler Pauly on JIRA:

Opened a PR with the first steps toward reducing memory usage in jump; the flamegraph shows a run of the standalone jump step on a refpix output. More work to be done, but for this input file it's a quick 10 GB reduction.


stscijgbot-jp commented Sep 27, 2024

Comment by Maria Pena-Guerrero on JIRA:

While working on a help desk ticket, I took a deeper look at the jump step. The multiprocessing option is not fully implemented: it parallelizes only a small part of the work, and the part that takes the longest is not parallelized, so the user does not really get a big advantage from that option when the data has many integrations. The user's file is NIRSpec TSO data of about 1.5 GB; by the time the pipeline gets to jump it is using about 25 GB according to tracemalloc, and virtual memory is much higher. Detector1 takes about 6.5 hrs to finish on my computer, and the user was not able to run it at all because they ran out of memory. The shape of this data is (3814, 12, 32, 512). The step spends a long time in the function flag_large_events in jump.py in stcal, where there is a nested loop over the integrations and groups. I'm not sure the loop can be avoided, but it could be sped up by writing temporary files instead of repeatedly updating this huge array. What I am thinking is perhaps to run the nested for loop just to gather the inputs of what happens inside it, and then invoke multiprocessing to run all those inputs and save temporary files; if the multiprocessing option is off, the step could write temporary files instead of updating the array. A rough sketch of that idea is below.
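A rough, hypothetical sketch of that gather-then-dispatch idea; flag_plane and flag_all_planes are stand-ins for the body of the nested loop in flag_large_events, not actual stcal code:

```python
import multiprocessing as mp


def flag_plane(args):
    # Placeholder for the per-(integration, group) work currently done
    # inside the nested loop in flag_large_events().
    intg, grp, plane = args
    # ... snowball/shower flagging of this 2D plane would go here ...
    return intg, grp, plane


def flag_all_planes(gdq, nproc=4):
    # Gather the loop inputs first, then farm them out to worker processes,
    # and write results back into the big 4D array only once, in the parent.
    work = [(intg, grp, gdq[intg, grp].copy())
            for intg in range(gdq.shape[0])
            for grp in range(gdq.shape[1])]
    with mp.Pool(nproc) as pool:
        for intg, grp, plane in pool.map(flag_plane, work):
            gdq[intg, grp] = plane
    return gdq
```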


Comment by David Law on JIRA:

Adding a clarifying note, based on discussion with Tyler Pauly, that this ticket is an exploratory general effort to improve memory usage, runtime, and the overall coding of the jump step. This includes both the 'basic' jump detection and the snowball/shower detection and flagging sections. Specific deliverables are unclear, as the situation first needs to be investigated before we know how challenging various levels of improvement will be to achieve.

Maria Pena-Guerrero What's the filename of that NIRSpec TSO example?


Comment by Maria Pena-Guerrero on JIRA:

David Law the file name is jw02571001001_04101_00001-seg001_nrs2_uncal.fits, but the data is proprietary.


stscijgbot-jp commented Oct 2, 2024

Comment by Timothy Brandt on JIRA:

A couple of quick thoughts, hopefully useful.

There are several places where np.where is used, e.g., array[np.where(x > 0)] = np.nan. np.where creates two int64 index arrays for a 2D input. If you just write array[x > 0] = np.nan, you create only one uint8 (boolean) array. It is significantly faster and less memory intensive, and just as clear to my eye.

It looks to me like you could see a significant improvement in memory usage here by writing an explicit for loop over rows.  Based on the comments I see, I think these lines are a likely culprit.  For example,

```python
row4, col4 = np.where(num_usable_diffs >= 4)  # locations of >= 4 usable diffs pixels
if len(row4) > 0:
    four_slice = shaped_diffs[:, row4, col4]
    loc0 = np.nanargmax(four_slice, axis=0)
    shaped_diffs[loc0, row4, col4] = np.nan
    median_diffs[row4, col4] = np.nanmedian(shaped_diffs[:, row4, col4], axis=0)
```

could become

```python
for row in range(shaped_diffs.shape[1]):
    # np.where returns a tuple; take the first element to get the column indices
    cols = np.where(num_usable_diffs[row] >= 4)[0]
    if len(cols) == 0:
        continue
    four_slice = shaped_diffs[:, row, cols]
    loc0 = np.nanargmax(four_slice, axis=0)
    shaped_diffs[loc0, row, cols] = np.nan
    median_diffs[row, cols] = np.nanmedian(shaped_diffs[:, row, cols], axis=0)
```

I think this should reduce memory usage by a lot where shaped_diffs has a very long first axis, and incur a negligible penalty in runtime (maybe a substantial savings if memory access is a major bottleneck).

The other thought I have is that stats.sigma_clip can be inefficient.  Again, the memory aspect of this can almost certainly be fixed with an explicit for loop over rows.
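As a sketch of that row-loop idea for the sigma clipping, assuming astropy.stats.sigma_clip and a hypothetical (ndiffs, nrows, ncols) float difference cube:

```python
import numpy as np
from astropy.stats import sigma_clip


def rowwise_clipped_median(shaped_diffs, sigma=3.0):
    # Clip one detector row at a time so the temporary masked arrays stay
    # small even when the first (difference) axis is very long.
    ndiffs, nrows, ncols = shaped_diffs.shape
    median_diffs = np.empty((nrows, ncols), dtype=shaped_diffs.dtype)
    for row in range(nrows):
        clipped = sigma_clip(shaped_diffs[:, row, :], sigma=sigma,
                             axis=0, masked=True)
        median_diffs[row] = np.ma.median(clipped, axis=0).filled(np.nan)
    return median_diffs
```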


Comment by Timothy Brandt on JIRA:

After a little more poking around, I am pretty sure that just adding an explicit for loop over row in a few places will solve the excessive memory usage with no (or almost no) runtime penalty.  Kenneth MacDonald, Tyler Pauly, David Law, please tell me if you would like me to create and modify a branch of stcal.


Comment by Timothy Brandt on JIRA:

I have a PR to address at least some of this.  I removed a number of redundant calculations, changed masked array median to nanmedian, and changed the name of a couple of variables to avoid a naming conflict.  Memory usage is significantly lower on my machine for jw01283001001_03101_00001_mirimage_uncal.fits and runtime for the jump step is 60s instead of 160s.

PR: spacetelescope/stcal#302


Comment by Timothy Brandt on JIRA:

A couple of other notes on specific cases and possibilities to think about:

If an inappropriately large number of jumps is flagged, it will grind the step to a halt because of the for loops over individual jumps. Much of this can be avoided by using array operations (see the sketch below). The break-even point between the two methods is likely a ~1% jump rate. We could greatly simplify the code (removing about 50 lines) and speed it up in the case of many jumps, at the cost of a slight slowdown with few jumps. There would be no memory issues. Let me know if you would like me to show this in a branch.

If memory usage remains high, a couple of judicious for loops over rows can probably fix a lot of it.  Again, tell me if memory usage remains high, and if so, what the shape of the offending uncal file is.
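As an illustration of the array-operation idea in the first note, a hedged sketch only (the plus-shaped neighbor footprint and flag handling are assumptions, not the stcal implementation):

```python
import numpy as np
from scipy.ndimage import binary_dilation


def flag_jump_neighbors(gdq_plane, jump_flag):
    # Expand every jump-flagged pixel to its 4 nearest neighbors in one
    # vectorized dilation instead of looping over individual jumps.
    jumps = (gdq_plane & jump_flag).astype(bool)
    footprint = np.array([[0, 1, 0],
                          [1, 1, 1],
                          [0, 1, 0]], dtype=bool)
    neighbors = binary_dilation(jumps, structure=footprint) & ~jumps
    gdq_plane[neighbors] |= jump_flag
    return gdq_plane
```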


Comment by David Law on JIRA:

Just adding a note that I tested out stcal PR 302 above on jw01864005001_03102_00001_mirimage_uncal (292-group MIRI imaging). The jump step originally took about 2.5 minutes: 1 min in regular jump and 1.5 min in the showers code. With the PR it's about 2 minutes: 30 sec in regular jump and 1.5 min in showers. The results show identically zero difference. This suggests the showers code will likely be another tall pole.


stscijgbot-jp commented Oct 15, 2024

Comment by David Law on JIRA:

Also noting here that I just made a pair of PRs against jwst and stcal relating to the MIRI cosmic ray showers flagging routine, as described at https://jira.stsci.edu/browse/JP-3677

May or may not make sense to merge that work with the broader jump step refactor being discussed here.  It may also be possible to streamline the new code for better performance.


Comment by David Law on JIRA:

Also linking https://jira.stsci.edu/projects/JP/issues/JP-3793 for another issue found in the jump code, currently being investigated by Michael Regan.


Comment by Timothy Brandt on JIRA:

A quick note: using bottleneck.nanmedian (cf. https://jira.stsci.edu/browse/JP-3819) cut the cost of the jump step for me by roughly a factor of 2, mostly in the shower detection part: 60s -> 47s in two-point jump detection and 257s -> 115s in shower detection for jw01283001001_03101_00001_mirimage_uncal.fits (both measured on top of my two-point difference improvements mentioned previously). For showers, 70% of the time is now spent in astropy.convolve. The adopted kernel is binary valued, which means that, for an n x n array and an m x m kernel, the convolution can be computed at a cost of order n x n x m rather than the naive n x n x m x m. I suspect this would give another factor of 2 speedup for shower detection, but it would require additional code; I will see if I can do something useful in pure Python. If we are OK with adding bottleneck as a dependency, I can go ahead and make a few very minor changes for significant speedups.
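For reference, the kind of drop-in swap being proposed, as a sketch assuming bottleneck is available as an optional dependency (the array shape here is arbitrary):

```python
import numpy as np

try:
    import bottleneck as bn
    nanmedian = bn.nanmedian
except ImportError:          # fall back to numpy if bottleneck is absent
    nanmedian = np.nanmedian

diffs = np.random.standard_normal((99, 512, 512)).astype(np.float32)
diffs[::7, ::5, ::3] = np.nan
median_diffs = nanmedian(diffs, axis=0)   # same values as np.nanmedian, much faster
```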


stscijgbot-jp commented Dec 16, 2024

Comment by Timothy Brandt on JIRA:

Ok, I made the changes discussed here:

https://github.com/t-brandt/stcal/tree/jump_shower_speedup

I was able to get a good improvement in convolution performance by switching to a different scipy.signal routine. On jw01283001001_03101_00001_mirimage_uncal.fits, shower detection goes from 257s in the default branch to about 70s in this new branch, for a final speedup factor of about 3.5, and the resulting 4D group DQ array is identical for that file. I am happy to write up the changes and open a PR if people would like; it will introduce a bottleneck dependency. I would also like another set of eyes to make sure I haven't broken anything in the various conditional branches.

gdq values differ for jw01701012001_02105_00001_mirifulong_uncal; I will investigate.  Masked arrays are used in places in the default branch where that behavior may not be intended.

 


stscijgbot-jp commented Dec 16, 2024

Comment by Timothy Brandt on JIRA:

Ah, I think I get it.  Lines

https://github.com/spacetelescope/stcal/blob/495338006fa6faea23e8de3ddd8a506534331b98/src/stcal/jump/jump.py#L1003

and the following two reset some of the initially NaN pixels to NaN, but not all of them.  In my branch I am resetting all of the initially NaN pixels to NaN.  This explains most of the difference in the behavior of astropy.convolve(), astropy.convolve(preserve_nan=True), and signal.oaconvolve().  The remaining very small differences are probably due to differences in rounding with the algorithms; I will confirm.  David Law: is this a bug?

Edit to confirm that the small differences between astropy.convolve(preserve_nan=True) and signal.oaconvolve() are due to my treatment of the edges, and (other than expected floating point truncation error) they are confined to the array edges.


Comment by Timothy Brandt on JIRA:

Sorry for the barrage of comments. One more question. The two lines at

https://github.com/spacetelescope/stcal/blob/495338006fa6faea23e8de3ddd8a506534331b98/src/stcal/jump/twopoint_difference.py#L229

```python
jumpy, jumpx = np.where(gdq[integ, grp, :, :] == jump_flag)
gdq[integ, grp, jumpy, jumpx] = 0
```

are preceded by the comment

"if grp is all jump set to do not use"

If we do not want to use these groups, should we not instead have something like

```python
onlyjumpset = gdq[integ, grp, :, :] == jump_flag
gdq[integ, grp][onlyjumpset] = dnu_flag
```

or

```python
onlyjumpset = gdq[integ, grp, :, :] == jump_flag
gdq[integ, grp][onlyjumpset] |= dnu_flag
```

Is this a bug?

 


Comment by David Law on JIRA:

Timothy Brandt Looking at the issue with resetting initially-NaN pixels to NaN again after the convolution (L1003), it looks to me as if there's some odd behavior where the DO_NOT_USE pixels pick up every pixel in many groups for (e.g.) integ 1. I suspect groups that are entirely masked are propagating forward erroneously here, though we'll clearly need to look more closely.

For the second question on line 229, that sure looks like a bug to me.


Comment by Timothy Brandt on JIRA:

Ah, David Law, I think I see the bug.  The array first_diffs is defined here

https://github.com/spacetelescope/stcal/blob/495338006fa6faea23e8de3ddd8a506534331b98/src/stcal/jump/jump.py#L942

as

```python
first_diffs = np.diff(data, axis=1)
```

The mask for keeping track of, preserving, and replacing NaNs is defined here

https://github.com/spacetelescope/stcal/blob/495338006fa6faea23e8de3ddd8a506534331b98/src/stcal/jump/jump.py#L986C13-L986C81

beginning with

```python
combined_pixel_mask = np.bitwise_or(gdq[intg, grp, :, :], pdq[:, :])
```

This uses the mask of the current group only. However, the difference will be NaN if **either** contributing group is NaN, so I think this would need to be replaced with something like

```python
combined_pixel_mask = np.bitwise_or(gdq[intg, grp] | gdq[intg, grp - 1], pdq)
```

(or just use preserve_nan=True in astropy.convolve, use oaconvolve because it is twice as fast, or explicitly mask NaN pixels in the difference).
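A minimal sketch of the "explicitly mask NaN pixels" option, assuming a 2D difference image and a normalized 2D kernel; this roughly mimics astropy.convolve with preserve_nan=True while using scipy.signal.oaconvolve:

```python
import numpy as np
from scipy import signal


def nan_aware_convolve(image, kernel):
    # Zero out NaNs, convolve both the data and a validity mask, renormalize
    # by the local sum of valid kernel weights, then restore NaN at the
    # originally invalid pixels.
    invalid = ~np.isfinite(image)
    data = np.where(invalid, 0.0, image)
    smoothed = signal.oaconvolve(data, kernel, mode="same")
    weight = signal.oaconvolve((~invalid).astype(float), kernel, mode="same")
    with np.errstate(invalid="ignore", divide="ignore"):
        result = smoothed / weight
    result[invalid] = np.nan
    return result
```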


stscijgbot-jp commented Dec 18, 2024

Comment by Timothy Brandt on JIRA:

Ok, I have somewhat more substantial changes to twopoint_difference.py here:

https://github.com/t-brandt/stcal/tree/jumpdetection_speedup

A few important changes/improvements:

Before using bottleneck, runtime on jw01283001001_03101_00001_mirimage_uncal.fits (10 ints, 100 groups) improves from 160s (in the main stcal branch) to 35s in this branch. Using bottleneck to replace np.nanmedian with bn.nanmedian and the like improves runtime further to 21s, a total improvement factor of 7.5. Memory usage in the two-point difference routine decreases from 60 GB to 33 GB in this branch. Eliminating the computation of an unused array (https://github.com/spacetelescope/stcal/blob/495338006fa6faea23e8de3ddd8a506534331b98/src/stcal/jump/jump.py#L253) saves another 4 GB, and I could get 4 GB more by moving the gain scaling to be internal to the two-point difference routine.

With many jumps flagged, the main branch grinds to a halt due to the flagging of neighbors and subsequent groups.  The new branch takes almost the same time and memory independent of the number of flagged jumps.

DQ flags should be identical to those computed using the main branch. The standard deviation reported will change very slightly (<<1%) as it is being computed as a byproduct of a different astropy.stats function.

 


stscijgbot-jp commented Dec 19, 2024

Comment by Timothy Brandt on JIRA:

There are also some pretty egregious time sinks in the snowball code.  I have pushed some changes to the same branch that do not change the resulting DQ arrays:

https://github.com/t-brandt/stcal/blob/jump_shower_speedup/src/stcal/jump/jump.py

By removing excessive and unnecessary copying of arrays and unused calculations, and by operating on the appropriate subset of the detector for each snowball (sketched below), I get a factor of 30 improvement in snowball step runtime for an 8x4x2048x2048 dataset (45s -> 1.5s) and a factor of 1000 improvement for a 171x30x64x2048 dataset (800s -> 0.7s). The jump_shower_speedup branch should be ready for review (I can make a PR if people agree). It does impose a bottleneck dependency, even though bottleneck is not used in the snowball flagging itself. Since that dependency is already there, I can also revise the jumpdetection_speedup branch

https://github.com/t-brandt/stcal/tree/jumpdetection_speedup

(which modifies twopoint_difference.py) to also require bottleneck; this would save about 30-40% of the runtime in that routine. The jumpdetection_speedup branch should also be ready for a PR.
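A hedged sketch of the "operate on a subset of the detector" idea for snowballs; the ellipse tuple follows OpenCV's fitEllipse convention, and flag_snowball_region is a hypothetical stand-in for the per-snowball flagging work:

```python
import numpy as np


def snowball_bbox(ellipse, shape, pad=10):
    # Return row/column slices for a padded bounding box around one snowball
    # ellipse ((cx, cy), (width, height), angle), clipped to the detector.
    (cx, cy), (width, height), _ = ellipse
    half = int(max(width, height) / 2) + pad
    rows = slice(max(0, int(cy) - half), min(shape[0], int(cy) + half + 1))
    cols = slice(max(0, int(cx) - half), min(shape[1], int(cx) + half + 1))
    return rows, cols

# usage sketch: do the dilation/flagging only inside the cropped window
#   rows, cols = snowball_bbox(ellipse, gdq_plane.shape)
#   gdq_plane[rows, cols] = flag_snowball_region(gdq_plane[rows, cols], ellipse)
```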


Comment by Timothy Brandt on JIRA:

One more note on memory usage since it was mentioned in the original ticket by Maria Pena-Guerrero.  There are a number of unnecessary allocations and copies that dramatically inflate the memory usage of this step.  Several of the related issues that I linked above are about this.  Removing these unneeded allocations reduces the memory demand of the jump step from about 60 GB to 15 GB on jw01283001001_03101_00001_mirimage_uncal.fits (a 2.4 GB file).  I think that memory usage throughout Stage 1 can be reduced to this level relatively straightforwardly.
