-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rsg devel #54
base: main
Are you sure you want to change the base?
Rsg devel #54
Conversation
rsanchezgarc
commented
Jan 12, 2024
- Set of changes to allow extract_central_slices_rfft compilation
- New def compute_vol_dtf function extracted from project_fourier.project_fourier to be used outside. Avoids computing the dft many times if several projections are going to be computed at different times.
- Minor api changes to select devices
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey Ruben, nice to see you here and happy new year!!
Thank you, there's loads of great stuff in this PR and i general most changes are small and easy to get in - a couple require a little bit of discussion. For the future, smaller PRs with isolated changes should be easier to get merged quickly
Let's try to get this in ASAP!!
[samples] = einops.unpack(samples, pattern='*', packed_shapes=ps) | ||
# [samples] = einops.unpack(samples, pattern='*', packed_shapes=ps) | ||
samples = samples.reshape(*ps) #Ask Alister if this will work in any situation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
interesting, is this a significant performance gain? I don't think this works in the general case but it should always work here as far as I can tell
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ohhh I see - torch compile couldn't go through unpack?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Exactly. This is a requirement for the compiler
conjugate_mask = einops.repeat(conjugate_mask, '... -> ... 3') | ||
# conjugate_mask = einops.repeat(conjugate_mask, '... -> ... 3') | ||
conjugate_mask.unsqueeze(-1).repeat(1, 1, 1, 3) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
repeat should be creating a view here and shouldn't be memory intensive even though the tensor is huge - is this not the case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I am correct, is expand and not repeat the one that is a view. This change was again a requirement for the compiler.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you're right for the torch API but einops it creates a view where possible - regardless, compilation is super important.
I'm a little hesitant to lose the rank polymorphism here and it looks like this unsqueeze/repeat is specific to b h w 3
rather than ... h w 3
-> could you try adding some code to intepret the current shape and unsqueeze/repeat according to that? This should allow us to maintain the current flexibility and have compatibility with the compiler
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should work conjugate_mask.unsqueeze(-1).expand(*[-1] * len(conjugate_mask.shape), 3)
and being more memory efficient, since it is a view.
pad: bool = True, | ||
pad_length: int | None = None | ||
) -> Tuple[torch.Tensor, Tuple[int,int,int], int]: | ||
"""Project a cubic volume by sampling a central slice through its DFT. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this docstring needs fixing
Happy New Year too! And sorry for not answering before, I had a few unexpected issues to solve before going into this. |
sorry for the delay here @rsanchezgarc - travelled back to the US and been a bit hectic I just found some interesting notes on how to make einops work with torch.compile so will probably revert to einops in a few places - I'm aiming to take a look at this final version over the next few days :) |
No problem at all! Good to know about the hack for the einops! |
@alisterburt what is the status of this PR? |
@rsanchezgarc still need to finalise it, free time has been taken up with preparing taxes and my partner visiting - will find some time ASAP! |
@alisterburt . No need to rush! Good luck with your taxes |