More Friendly Types at the Interface Level #393

Open
willtebbutt opened this issue Nov 26, 2024 · 1 comment
Labels
enhancement New feature or request good first issue Good for newcomers

Comments

@willtebbutt
Member

willtebbutt commented Nov 26, 2024

This issue JuliaDiff/DifferentiationInterface.jl#642 makes me wonder whether we need a systematic approach to translating between primal types and tangent types at the interface level.

For example, while users probably want to represent the tangent of an `SArray` with another `SArray` rather than a `Tangent`, Mooncake requires that users provide a `Tangent`.

I think we can probably define sensible translation functionality between primals and tangents that makes some choices about how to handle non-differentiable fields but still works quite generically. Such a function might look something like:

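```julia
using Base: IEEEFloat
using Mooncake: NoTangent

# Placeholder for whichever integer types we treat as non-differentiable.
const VariousIntegerTypes = Union{Bool,Base.BitInteger}

translate_to_tangent(t::IEEEFloat) = t
translate_to_tangent(t::VariousIntegerTypes) = NoTangent()
translate_to_tangent(t::Array{<:IEEEFloat}) = t
translate_to_tangent(t::Array) = map(translate_to_tangent, t)
function translate_to_tangent(t::P) where {P}
    isprimitivetype(P) && error("need a translation rule for $P")
    # TODO: recursively transform `t` into a value of type `tangent_type(P)`
    error("generic translation for $P not yet implemented")
end
```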

This would have the effect of, for example, dropping any non-differentiable fields.
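A minimal, self-contained sketch of how the generic fallback might recurse through a struct and drop its non-differentiable fields is below. It is an illustration under assumptions, not Mooncake code: `NoTangent` is a local stand-in for Mooncake's type, `Particle` is a made-up example, and a struct's tangent is represented as a `NamedTuple` of its surviving fields rather than as whatever `tangent_type(P)` actually returns.

```julia
# Hypothetical, self-contained sketch: `NoTangent` is a local stand-in, not Mooncake's type.
struct NoTangent end

# Leaf rules (integers, strings, and symbols are treated as non-differentiable).
translate_to_tangent(t::Base.IEEEFloat) = t
translate_to_tangent(::Union{Integer,AbstractString,Symbol}) = NoTangent()
translate_to_tangent(t::Array{<:Base.IEEEFloat}) = t
translate_to_tangent(t::Array) = map(translate_to_tangent, t)

# Generic fallback: translate each field recursively, then drop every field whose
# tangent is NoTangent. The result is a NamedTuple keyed by the surviving field names.
function translate_to_tangent(t::P) where {P}
    isprimitivetype(P) && error("need a translation rule for $P")
    kept = Tuple{Symbol,Any}[]
    for name in fieldnames(P)
        tangent = translate_to_tangent(getfield(t, name))
        tangent isa NoTangent || push!(kept, (name, tangent))
    end
    isempty(kept) && return NoTangent()
    return NamedTuple{Tuple(first.(kept))}(Tuple(last.(kept)))
end

# Example primal type with a mix of differentiable and non-differentiable fields.
struct Particle
    mass::Float64
    id::Int
    label::String
    position::Vector{Float64}
end

p = Particle(2.0, 7, "a", [1.0, 2.0])
translate_to_tangent(p)  # (mass = 2.0, position = [1.0, 2.0])
```

If every field of a struct is non-differentiable, the sketch returns `NoTangent()` for the whole struct, so an enclosing struct drops that field as well.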

On the way back, we could do a similar thing, but we would need to pick a placeholder value for any non-differentiable fields. Not all types have a well-defined zero value (e.g. `String`s and `Symbol`s), so it might make more sense for the reverse conversion to require the primal value as an argument, from which we simply copy the non-differentiable fields. For example:

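```julia
using Base: IEEEFloat

# Placeholder for whichever integer types we treat as non-differentiable.
const VariousIntegerTypes = Union{Bool,Base.BitInteger}

translate_to_primal(::P, t::P) where {P<:IEEEFloat} = t
translate_to_primal(p::P, t) where {P<:VariousIntegerTypes} = p
translate_to_primal(::Array{P}, t::Array{P}) where {P<:IEEEFloat} = t
translate_to_primal(p::Array, t::Array) = map(translate_to_primal, p, t)
function translate_to_primal(p::P, t) where {P}
    # same idea as translate_to_tangent, but in the other direction
    error("generic translation for $P not yet implemented")
end
```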
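The reverse direction can be sketched the same way. Again this is a hypothetical, self-contained illustration rather than Mooncake code: `NoTangent` and `Particle` are local stand-ins, a struct's tangent is assumed to be a `NamedTuple` holding only its differentiable fields, and the primal type is assumed to have the default positional constructor. Fields missing from the tangent are copied straight from the primal, which sidesteps the need for a zero value for `String`s or `Symbol`s.

```julia
# Hypothetical, self-contained sketch: `NoTangent` is a local stand-in (not Mooncake's
# type), and a struct's tangent is assumed to be a NamedTuple of its differentiable fields.
struct NoTangent end

# Leaf rules, mirroring the ones proposed above.
translate_to_primal(::P, t::P) where {P<:Base.IEEEFloat} = t
translate_to_primal(p::Union{Integer,AbstractString,Symbol}, ::NoTangent) = p
translate_to_primal(::Array{P}, t::Array{P}) where {P<:Base.IEEEFloat} = t
translate_to_primal(p::Array, t::Array) = map(translate_to_primal, p, t)

# Generic fallback: use the tangent for fields it contains and copy every other field
# from the primal. Assumes P has the default positional constructor.
function translate_to_primal(p::P, t::NamedTuple) where {P}
    fields = map(fieldnames(P)) do name
        x = getfield(p, name)
        haskey(t, name) ? translate_to_primal(x, t[name]) : x
    end
    return P(fields...)
end

# Example primal type with a mix of differentiable and non-differentiable fields.
struct Particle
    mass::Float64
    id::Int
    label::String
    position::Vector{Float64}
end

p = Particle(2.0, 7, "a", [1.0, 2.0])
t = (mass = 0.5, position = [3.0, 4.0])
q = translate_to_primal(p, t)  # Particle(0.5, 7, "a", [3.0, 4.0])
```

Relying on the default constructor is the main limitation here: types whose constructors do not take the field values directly would need their own rule.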

We could then, e.g., apply this inside `value_and_pullback` or similar functions, so that users get "nice" types. Of course, we would want to ensure that there is a sensible way to opt out of this translation.
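One way the opt-out could look is a keyword on the user-facing wrapper. Everything here is hypothetical: `value_and_gradient_friendly`, its `translate` keyword, `Point`, `to_friendly`, and `raw_value_and_gradient` (a hand-written stand-in for the AD call that returns a `NamedTuple` in place of a Mooncake tangent) are illustrative names, not Mooncake API.

```julia
# Hypothetical sketch: none of these names are Mooncake API.
struct Point
    x::Float64
    y::Float64
    tag::Symbol  # non-differentiable field
end

# Stand-in for the AD backend, for f(p) = p.x * p.y: returns the value and a
# "raw" tangent that only holds the differentiable fields.
raw_value_and_gradient(p::Point) = (p.x * p.y, (x = p.y, y = p.x))

# Rebuild a primal-shaped gradient, copying the non-differentiable field from the primal.
to_friendly(p::Point, t::NamedTuple) = Point(t.x, t.y, p.tag)

# User-facing wrapper: translation is on by default, and `translate = false` opts out.
function value_and_gradient_friendly(p::Point; translate::Bool = true)
    value, raw = raw_value_and_gradient(p)
    return value, (translate ? to_friendly(p, raw) : raw)
end

p = Point(2.0, 3.0, :a)
value_and_gradient_friendly(p)                     # (6.0, Point(3.0, 2.0, :a))
value_and_gradient_friendly(p; translate = false)  # (6.0, (x = 3.0, y = 2.0))
```

Defaulting to translation keeps the common case friendly, while the keyword leaves the raw tangent available to callers who want Mooncake's own types.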

@willtebbutt willtebbutt added enhancement New feature or request good first issue Good for newcomers labels Nov 26, 2024
@gdalle
Collaborator

gdalle commented Nov 26, 2024

I think this would be very helpful for users, because what most of them want is something that "looks like the primal", not the Mooncake `Tangent`, which could be considered an implementation detail.
