refactor interface for projections/proximal operators #147

Open
wants to merge 10 commits into
base: master

Member

@Red-Portal Red-Portal commented Nov 17, 2024

This PR refactors how post-hoc modifications are applied to the iterates after a gradient descent step. Previously, updating the parameters of a LocationScale distribution always silently applied a projection step. Now, every such modification must be expressed as its own OptimisationRule, which makes the behaviour modular and explicit.

More concretely, this PR changes the following:

  • The scale matrix of a LocationScale distribution is no longer projected by default.
  • A new rule object, ProjectScale, wraps an underlying optimizer such as Adam and applies the projection instead.

The only major change in the interface is that, for LocationScale, the optimizer

optimizer = Optimisers.Adam(1e-3)

must now be wrapped in a ProjectScale object:

optimizer = ProjectScale(Optimisers.Adam(1e-3))
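
To illustrate the wrapping pattern described above, here is a minimal, dependency-free sketch. It is not the code in this PR: `Descent`, `ProjectScaleSketch`, `sketch_update`, and the `scale_eps` lower bound are hypothetical names, and the sketch assumes a diagonal scale whose projection simply clamps each entry from below.

```julia
# Hypothetical stand-in for a base optimiser: plain gradient descent.
struct Descent
    stepsize::Float64
end

# Hypothetical wrapper: runs the inner rule, then projects the scale.
struct ProjectScaleSketch{R}
    rule::R
    scale_eps::Float64  # assumed lower bound on the diagonal of the scale
end

# Inner rule: a gradient step on every parameter, with no projection.
function sketch_update(rule::Descent, params, grads)
    location = params.location .- rule.stepsize .* grads.location
    scale_diag = params.scale_diag .- rule.stepsize .* grads.scale_diag
    return (location = location, scale_diag = scale_diag)
end

# Wrapped rule: same step, followed by an explicit projection of the scale.
function sketch_update(rule::ProjectScaleSketch, params, grads)
    stepped = sketch_update(rule.rule, params, grads)
    projected_scale = max.(stepped.scale_diag, rule.scale_eps)
    return (location = stepped.location, scale_diag = projected_scale)
end

params = (location = [0.0, 0.0], scale_diag = [1.0, 1e-3])
grads = (location = [1.0, -1.0], scale_diag = [0.0, 1.0])

# Without the wrapper, the second scale entry becomes negative.
plain = sketch_update(Descent(0.1), params, grads)

# With the wrapper, it is clamped to scale_eps.
projected = sketch_update(ProjectScaleSketch(Descent(0.1), 1e-5), params, grads)
```

The point of the design is that the projection becomes visible at the call site: the same inner rule yields either behaviour depending on whether it is wrapped.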

Red-Portal and others added 7 commits November 16, 2024 21:10
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Contributor

@github-actions github-actions bot left a comment

Benchmark Results

| Benchmark suite | Current: ee36164 | Previous: c306736 | Ratio |
|---|---|---|---|
| normal/RepGradELBO + STL/meanfield/Zygote | 15941287712 ns | 15314090063 ns | 1.04 |
| normal/RepGradELBO + STL/meanfield/ForwardDiff | 3306357048 ns | 3302272665 ns | 1.00 |
| normal/RepGradELBO + STL/meanfield/ReverseDiff | 3210354047 ns | 3259154407 ns | 0.99 |
| normal/RepGradELBO + STL/fullrank/Zygote | 15813010444 ns | 15054091209 ns | 1.05 |
| normal/RepGradELBO + STL/fullrank/ForwardDiff | 3650345366 ns | 3645336604 ns | 1.00 |
| normal/RepGradELBO + STL/fullrank/ReverseDiff | 5811420124 ns | 5930367236 ns | 0.98 |
| normal/RepGradELBO/meanfield/Zygote | 7374634199 ns | 7259685436 ns | 1.02 |
| normal/RepGradELBO/meanfield/ForwardDiff | 2358200835 ns | 2478316147 ns | 0.95 |
| normal/RepGradELBO/meanfield/ReverseDiff | 1464290177 ns | 1512242787 ns | 0.97 |
| normal/RepGradELBO/fullrank/Zygote | 7411929844 ns | 7222198006 ns | 1.03 |
| normal/RepGradELBO/fullrank/ForwardDiff | 2609800143 ns | 2700811231 ns | 0.97 |
| normal/RepGradELBO/fullrank/ReverseDiff | 2566947884 ns | 2647199830 ns | 0.97 |
| normal + bijector/RepGradELBO + STL/meanfield/Zygote | 24474675946 ns | 23293427693 ns | 1.05 |
| normal + bijector/RepGradELBO + STL/meanfield/ForwardDiff | 10690007191 ns | 10479788705 ns | 1.02 |
| normal + bijector/RepGradELBO + STL/meanfield/ReverseDiff | 5040512801 ns | 5426326875 ns | 0.93 |
| normal + bijector/RepGradELBO + STL/fullrank/Zygote | 24276644334 ns | 23052236045 ns | 1.05 |
| normal + bijector/RepGradELBO + STL/fullrank/ForwardDiff | 10542771331 ns | 11253883102 ns | 0.94 |
| normal + bijector/RepGradELBO + STL/fullrank/ReverseDiff | 8168988329 ns | 8661186736 ns | 0.94 |
| normal + bijector/RepGradELBO/meanfield/Zygote | 15298860158 ns | 14161735088 ns | 1.08 |
| normal + bijector/RepGradELBO/meanfield/ForwardDiff | 9196864834 ns | 9257845085 ns | 0.99 |
| normal + bijector/RepGradELBO/meanfield/ReverseDiff | 3059438501 ns | 3080136920 ns | 0.99 |
| normal + bijector/RepGradELBO/fullrank/Zygote | 15462865510 ns | 14326320128 ns | 1.08 |
| normal + bijector/RepGradELBO/fullrank/ForwardDiff | 9964287185 ns | 9573388177 ns | 1.04 |
| normal + bijector/RepGradELBO/fullrank/ReverseDiff | 4456238223 ns | 4476721408 ns | 1.00 |

This comment was automatically generated by workflow using github-action-benchmark.

@Red-Portal Red-Portal requested review from yebai and mhauru December 10, 2024 08:20
@Red-Portal
Member Author

@yebai I'll mark the v0.3 release (at last!) after this PR

Member

@mhauru mhauru left a comment

Could we have a test where the eigenvalues drift too low, so that we can check both that it fails when not using ProjectScale and that it succeeds when using ProjectScale? Just to see the effect very concretely, and to confirm that the first case fails in the expected (rather than some other, unexpected) way.
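
A rough sketch of what such a test could look like, using only the Julia standard library. The helpers `gradient_step` and `project_scale`, the bound `scale_eps`, and the choice of a lower-triangular scale are assumptions made for illustration; they are not taken from this PR. For a triangular scale the eigenvalues are the diagonal entries, so the check looks at the smallest diagonal entry.

```julia
using LinearAlgebra, Test

# Hypothetical projection: clamp the diagonal of the scale factor from below.
function project_scale(L::LowerTriangular, scale_eps::Real)
    M = Matrix(L)
    for i in axes(M, 1)
        M[i, i] = max(M[i, i], scale_eps)
    end
    return LowerTriangular(M)
end

# Hypothetical plain gradient step on the scale factor.
gradient_step(L, G, stepsize) = LowerTriangular(Matrix(L) .- stepsize .* Matrix(G))

# Smallest eigenvalue of a triangular matrix is its smallest diagonal entry.
min_eigenvalue(L) = minimum(diag(L))

scale_eps = 1e-4
L0 = LowerTriangular([1.0 0.0; 0.5 0.1])
G = LowerTriangular([0.0 0.0; 0.0 1.0])  # pushes the last diagonal entry down

unprojected = gradient_step(L0, G, 0.2)
projected = project_scale(unprojected, scale_eps)

# Without projection the scale violates the bound in the expected way.
@test min_eigenvalue(unprojected) < scale_eps
# With projection the bound holds and off-diagonal entries are untouched.
@test min_eigenvalue(projected) >= scale_eps
@test projected[2, 1] == unprojected[2, 1]
```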

Except for the above request, I'm happy with the software engineering. I would prefer it though if someone else who has views on the design choices here gave a second, approving opinion. I have little idea of what users need and want from their interfaces here, e.g. if the name ProjectScale is intuitive for users, or if there should be a warning if someone tries to optimise a LocationScale without using ProjectScale.

Member

@yebai yebai left a comment

Wrapping an optimiser inside ProjectScale(...) feels slightly strange to me. While the name ProjectScale might be appropriate for a specific paper, the terminology is not (yet) widely accepted. I think we could introduce an additional keyword argument to pass this information instead of overloading the optimiser argument for too many purposes.

@Red-Portal
Member Author

Thank you both for chiming in!

@yebai I was thinking of this as similar in functionality to operations like gradient clipping. How about I rename it to ClipScale and use OptimiserChain, the standard composition feature in Optimisers.jl?
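
For reference, this is how the existing gradient-clipping rule composes with Adam through OptimiserChain in Optimisers.jl. The sketch uses only ClipGrad, Adam, OptimiserChain, setup, and update from that package; ClipScale is the proposed name and is not an Optimisers.jl rule, so how it would slot into a chain is left open here. The parameter names `location` and `scale_diag` are illustrative assumptions.

```julia
using Optimisers

# Existing composition in Optimisers.jl: clip the gradient, then apply Adam.
rule = OptimiserChain(ClipGrad(1.0), Adam(1e-3))

# Illustrative parameters and gradients (names are assumptions).
params = (location = zeros(2), scale_diag = ones(2))
grads = (location = [10.0, -10.0], scale_diag = [5.0, -5.0])

# Build the optimiser state and take one non-mutating update step.
state = Optimisers.setup(rule, params)
state, params = Optimisers.update(state, params, grads)
```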

@yebai
Member

yebai commented Dec 11, 2024

@Red-Portal Your proposal looks good!

@Red-Portal Red-Portal added this to the v0.3.0 milestone Dec 12, 2024