Revamp batch mode for pushforward, pullback and hvp #412

Merged
merged 72 commits into from
Aug 29, 2024

@gdalle gdalle commented Aug 17, 2024

Overview

Before this PR:

  • There is a Batch type which stores several seeds at once, allowing batch mode pushforward, pullback and hvp.
  • Internal operators pushforward_batched, pullback_batched and hvp_batched (not part of the public API) resolve what would otherwise be a method ambiguity between:
# single mode in an extension
pushforward(f, backend::AutoSomething, x, dx)

# batch mode in the main package, falls back on B times single mode
pushforward_batched(f, backend::AbstractADType, x, dxs::Batch{B})  
  • This design causes a lot of code duplication. Some tests have special cases for batched scenarios, others just error.
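
To make the ambiguity concrete, here is a minimal self-contained sketch. All names in it (AbstractToyBackend, ToyBackend, naive_pushforward, toy_pushforward, toy_pushforward_batched) and the finite-difference rule are hypothetical stand-ins, not the actual DifferentiationInterface code. It shows why two same-named methods clash, and how a separately named batched operator that loops over single mode avoids the clash.

```julia
# Toy stand-ins for illustration only (not the real package types).
abstract type AbstractToyBackend end
struct ToyBackend <: AbstractToyBackend end

# Wrapper around a tuple of B seeds, mimicking the old Batch type.
struct Batch{B,T<:NTuple{B,Any}}
    elements::T
    Batch(xs::Vararg{Any,B}) where {B} = new{B,typeof(xs)}(xs)
end

# Same name for both modes: neither method is more specific than the other.
naive_pushforward(f, ::ToyBackend, x, dx) = :single                  # specific backend, any seed
naive_pushforward(f, ::AbstractToyBackend, x, dxs::Batch) = :batch   # any backend, specific seed

# Calling with a specific backend AND a Batch matches both methods equally well.
ambiguity_detected = try
    naive_pushforward(sin, ToyBackend(), 1.0, Batch(1.0, 2.0))
    false
catch err
    err isa MethodError
end

# Old-style workaround: single mode keeps the plain name (central difference here)...
toy_pushforward(f, ::ToyBackend, x, dx) = (f(x + 1e-6 * dx) - f(x - 1e-6 * dx)) / 2e-6

# ...and batch mode gets its own name, falling back on B single-mode calls.
function toy_pushforward_batched(f, backend::AbstractToyBackend, x, dxs::Batch)
    dys = map(dx -> toy_pushforward(f, backend, x, dx), dxs.elements)
    return Batch(dys...)
end
```

With these definitions, ambiguity_detected is true, while toy_pushforward_batched(x -> x^2, ToyBackend(), 3.0, Batch(1.0, 2.0)) runs without any dispatch conflict.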

After this PR:

  • Batch{B} is renamed to Tangents{B} (a minor detail).
  • Batch mode becomes the default that extensions need to implement, and the main package fallback is in the other direction:
# batch mode in an extension
pushforward(f, backend::AutoSomething, x, dx::Tangents{B})

# single mode in the main package, falls back on Tangents{1}, less specific on every argument
pushforward(f, backend::AbstractADType, x, dx)
  • This brings a bit more code to each extension, but more clarity in general. The PR even saves 200 lines of code overall, despite the heavy new machinery.
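
The reversed fallback direction can be sketched as follows. This is a toy model under stated assumptions: AbstractToyBackend and ToyForwardBackend are hypothetical, the Tangents struct is a simplified stand-in for the real one, and the central finite difference stands in for whatever an actual backend does.

```julia
# Toy stand-ins for illustration only (not the real package types).
abstract type AbstractToyBackend end
struct ToyForwardBackend <: AbstractToyBackend end

# Simplified stand-in for the Tangents wrapper: a tuple of B seeds.
struct Tangents{B,T<:NTuple{B,Any}}
    d::T
    Tangents(ds::Vararg{Any,B}) where {B} = new{B,typeof(ds)}(ds)
end

# "Extension" side: batch mode is what the backend implements.
function pushforward(f, ::ToyForwardBackend, x::Number, tx::Tangents{B}) where {B}
    h = 1e-6
    # Central finite difference as a placeholder for a real JVP.
    dys = map(dx -> (f(x + h * dx) - f(x - h * dx)) / (2h), tx.d)
    return Tangents(dys...)
end

# "Main package" side: single mode is less specific on every argument,
# so it never competes with the method above. It wraps the seed in
# Tangents{1}, calls batch mode, and unwraps the single result.
function pushforward(f, backend::AbstractToyBackend, x, dx)
    ty = pushforward(f, backend, x, Tangents(dx))
    return only(ty.d)
end
```

Calling pushforward(x -> x^2, ToyForwardBackend(), 3.0, 1.0) goes through the generic fallback and returns a plain number, whereas passing Tangents(1.0, 2.0) dispatches straight to the backend method and returns a Tangents{2}.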

Long-term vision

  • Extensions implement a single method on fully wrapped arguments:
pushforward(f, backend::AutoSomething, i::Input, t::Tangents, c::Constant, s::Scratch)
  • Meanwhile, the main package defines a fallback from untyped inputs:
function pushforward(f, backend::AbstractADType, x, dx)
    i = Input(x)
    t = Tangents(dx)
    c = NoConstant()
    s = NoScratch()
    return pushforward(f, backend, i, t, c, s)
end

Changes

This PR is non-breaking because batch mode was never part of the public API.

Main files to look at:

  • src/utils/tangents.jl for the type
  • src/fallbacks/no_tangents.jl for the fallbacks
  • src/first_order/pushforward.jl for the new seeded operators
  • src/first_order/derivative.jl for the non-seeded operators which are also modified
  • ext/DIFastDifferentiationExt for a clean example of extension adaptation

DI source

  • Introduce Tangents as an NTuple wrapper. The second type parameter is the tuple type T: if I had used the tuple's eltype, it wouldn't be well-defined for empty tuples (Aqua complained about that).
  • Modify pushforward, pullback and hvp to handle Tangents by default
  • Modify other operators to short-circuit the fallbacks and construct Tangents themselves
  • Adjust preparation for Jacobians and Hessians
  • Fix faulty HVP in reverse-over-forward (preparation is impossible)
  • Update FromPrimitive backends.
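
The choice of type parameter for Tangents can be illustrated with a small sketch. TupleTangents is a hypothetical, simplified stand-in (the real definition lives in src/utils/tangents.jl and may differ). The first lines show why an element-type parameter is problematic for empty tuples, and the struct shows the tuple-type alternative.

```julia
# For length 0, the element type leaves no trace in the tuple type:
# both NTuple{0,Int} and NTuple{0,Float64} are just Tuple{}.
empty_int = NTuple{0,Int}
empty_float = NTuple{0,Float64}
eltype_is_lost = (empty_int === empty_float) && (empty_int === Tuple{})

# Hypothetical wrapper parametrized by the batch size B and the full tuple type T,
# so every parameter is determined by the field, even when the tuple is empty.
struct TupleTangents{B,T<:NTuple{B,Any}}
    d::T
    TupleTangents(ds::Vararg{Any,B}) where {B} = new{B,typeof(ds)}(ds)
end

t0 = TupleTangents()           # empty batch
t2 = TupleTangents(1.0, 2.0)   # batch of two Float64 seeds
```

Here typeof(t0) is TupleTangents{0,Tuple{}} and typeof(t2) is TupleTangents{2,Tuple{Float64,Float64}}: both parameters can be read off the stored tuple, with nothing left unbound.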

DI extensions

  • Modify every extension to handle Tangents directly.
  • The Enzyme and Tapir fixes are rather ugly: they fall back on Tangents{1}.

Warning

For several extensions, this part is suboptimal and may lead to type instabilities that weren't present before. To create a Tangents output, I often need map or ntuple, which requires a closure. An easy fix is to specialize on Tangents{1} every time, but I haven't done it yet.
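
As a rough sketch of what such a specialization could look like: ToyTangents, toy_jvp and seeded_pushforward are hypothetical names and the finite-difference rule is a placeholder. Whether the closure actually hurts inference depends on the backend. The sketch only shows the shape of the fix, namely a dedicated method for batch size 1 that never builds a closure.

```julia
# Simplified stand-in for the Tangents wrapper (illustration only).
struct ToyTangents{B,T<:NTuple{B,Any}}
    d::T
    ToyTangents(ds::Vararg{Any,B}) where {B} = new{B,typeof(ds)}(ds)
end

# Placeholder for a backend's single-seed JVP (central finite difference).
toy_jvp(f, x, dx) = (f(x + 1e-6 * dx) - f(x - 1e-6 * dx)) / 2e-6

# Generic batch size: map over the seeds, which needs a closure capturing f and x.
function seeded_pushforward(f, x, tx::ToyTangents{B}) where {B}
    dys = map(dx -> toy_jvp(f, x, dx), tx.d)
    return ToyTangents(dys...)
end

# Batch size 1: more specific method, direct call, no closure involved.
function seeded_pushforward(f, x, tx::ToyTangents{1})
    dy = toy_jvp(f, x, only(tx.d))
    return ToyTangents(dy)
end
```

Dispatch picks the second method whenever the batch holds a single seed, so the common single-tangent path stays closure-free while larger batches keep the generic implementation.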

DI docs

  • No changes yet: Tangents stays out of the public API while it matures.

DIT source

  • Simplify correctness tests by removing the _batched special case.
  • Update AutoZero backends.

DIT tests

  • Deactivate type stability checks on AutoZero tests.

Warning

This needs to be reactivated before the PR is merged.

@gdalle gdalle marked this pull request as draft August 17, 2024 09:26
codecov-commenter commented Aug 17, 2024

Codecov Report

Attention: Patch coverage is 99.62406% with 3 lines in your changes missing coverage. Please review.

Project coverage is 93.63%. Comparing base (56fc186) to head (281e852).

Files with missing lines | Patch % | Lines
...rentiationInterface/src/first_order/pushforward.jl | 92.68% | 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #412      +/-   ##
==========================================
- Coverage   97.33%   93.63%   -3.71%     
==========================================
  Files         104      102       -2     
  Lines        4922     4885      -37     
==========================================
- Hits         4791     4574     -217     
- Misses        131      311     +180     
Flag | Coverage Δ
DI | 99.27% <99.58%> (-0.08%) ⬇️
DIT | 86.33% <100.00%> (-8.47%) ⬇️

@gdalle gdalle marked this pull request as draft August 26, 2024 15:50
@gdalle gdalle marked this pull request as ready for review August 28, 2024 08:13
@gdalle gdalle changed the title Simplify batched mode pushforward, pullback and hvp Revamp batch mode for pushforward, pullback and hvp Aug 28, 2024
@gdalle gdalle merged commit 7c60378 into main Aug 29, 2024
54 checks passed
@gdalle gdalle deleted the gd/tangents branch August 29, 2024 09:10