Basic value_and_gradient examples not working #415

Open
Jollywatt opened this issue Dec 11, 2024 · 8 comments
Labels: enhancement (error messages) (The error was produced that should be improved upon), enhancement (New feature or request)

Comments

@Jollywatt

Forgive me if this is dumb, but I can't get basic usage of Mooncake to work. For example:

using DifferentiationInterface
import Mooncake

f(x) = sum(x)

x = [1, 2]

value_and_gradient(f, AutoMooncake(config=nothing), x)

produces

ERROR: MethodError: no method matching copyto!(::Mooncake.NoTangent, ::Bool)

How do I differentiate simple functions?

julia> versioninfo()
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin24.0.0)
  CPU: 8 × Apple M2
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, apple-m2)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

(@v1.11) pkg> st
  ...
  [a0c0ee7d] DifferentiationInterface v0.6.27
  [da2b9cff] Mooncake v0.4.60
  ...
@Jollywatt
Author

I realised my mistake — Ints are not differentiable.

Perhaps consider this issue a request to produce kinder error messages when trying to take the tangent of a non-differentiable type?
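A minimal sketch of the workaround, using only Base Julia: convert the integer input to floating point before differentiating, so that both the argument's element type and the value returned by `sum` are `Float64`.

```julia
f(x) = sum(x)

x_int = [1, 2]       # Vector{Int}: Int elements are treated as non-differentiable
x = float(x_int)     # Vector{Float64} with the same values, [1.0, 2.0]
```

With DifferentiationInterface and Mooncake loaded as in the original report, we'd expect `value_and_gradient(f, AutoMooncake(config=nothing), x)` on the converted `x` to return the value `3.0` together with the gradient `[1.0, 1.0]`.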

@yebai yebai reopened this Dec 11, 2024
@yebai
Contributor

yebai commented Dec 11, 2024

I agree that a better error message would be helpful.

Cc @gdalle since this is related to DI.

@yebai yebai added the enhancement New feature or request label Dec 11, 2024
@willtebbutt
Member

willtebbutt commented Dec 12, 2024

Ah, right, yes. The basic problem here is that we don't have enough constraints on the input types. @gdalle is not keen to impose input type constraints at an interface-wide level in DifferentiationInterface.

@gdalle, would it be permissible to impose constraints on the functions passed to Mooncake for differentiation? Specifically:

  1. if they return a scalar, it must be an IEEEFloat, and
  2. the element type of the argument w.r.t. which the adjoint is computed must be a subtype of IEEEFloat?

I think this would probably prevent this kind of problem in the future, and it would be quite straightforward to document.
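The two rules above can be written as simple pre-flight predicates. This is a sketch in plain Julia, not Mooncake or DI code; `valid_scalar_output` and `valid_input_eltype` are hypothetical names, and `Base.IEEEFloat` is Base's unexported alias for `Union{Float16, Float32, Float64}`.

```julia
# Hypothetical predicates for the two proposed rules (not Mooncake/DI API).

# Rule 1: a scalar-valued function must return an IEEEFloat.
valid_scalar_output(y) = y isa Base.IEEEFloat

# Rule 2: the element type of the argument we differentiate w.r.t. must be an IEEEFloat.
# eltype of a plain number is its own type, so this covers scalars and arrays alike.
valid_input_eltype(x) = eltype(x) <: Base.IEEEFloat
```

On the original report, both `valid_input_eltype([1, 2])` and `valid_scalar_output(sum([1, 2]))` are `false`, so either rule would flag the problem before any AD code runs.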

In fact, in the interest of having a very predictable interface, we should probably also restrict the set of input types that people may pass to Mooncake to a finite list, so that we can be sure to support them all. That is, literally saying that your input must be of type:

Union{IEEEFloat, Array{<:IEEEFloat}, SArray{<:IEEEFloat}, <:AnotherSpecificArrayType{<:IEEEFloat}}

or whatever. Possibly also Tuples or NamedTuples of any of these, because I know these give user-friendly results. This would make it very clear to users what is and isn't currently supported, and would give clear instructions for how to extend the list. Thoughts, @gdalle?
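A minimal sketch of such a whitelist in plain Julia, using multiple dispatch so that anything not explicitly listed is rejected. `SupportedLeaf` and `is_supported` are hypothetical names; `SArray` and other array types are left out so the block needs nothing beyond Base, and Tuples and NamedTuples are accepted recursively, which is only one possible reading of "Tuples or NamedTuples of any of these".

```julia
# Hypothetical finite set of "leaf" argument types (StaticArrays etc. omitted).
const SupportedLeaf = Union{Base.IEEEFloat, Array{<:Base.IEEEFloat}}

is_supported(x) = false                         # anything not listed is rejected
is_supported(x::SupportedLeaf) = true           # floats and float Arrays are accepted
is_supported(x::Tuple) = all(is_supported, x)   # Tuples of supported things
is_supported(x::NamedTuple) = all(is_supported, values(x))  # NamedTuples likewise
```

Extending the list then amounts to adding a type to `SupportedLeaf` or adding one more method, which matches the "clear instructions for how to extend this list" idea.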

@willtebbutt
Member

As to whether this kind of restriction would constitute a breaking change for DifferentiationInterface: it turns out that "is Base.copy defined for this gradient type?" is currently an implicit part of the interface (see here and here).

Mooncake doesn't currently define Base.copy for most things, including Mooncake.Tangent, Mooncake.MutableTangent, and Mooncake.NoTangent, and Base.copy is not defined for Tuples or NamedTuples. This means that the vast majority of tangent types currently won't work if you try to return them as the result of differentiating with DI + Mooncake; at present the user just gets an unintelligible error. In fact, Array is one of the very few tangent types for which copy works.

This is actually a win for this proposal, because it gives us wide latitude to make the interface for Mooncake quite explicit (if we want to) in a non-breaking fashion.

@willtebbutt
Member

If we do this, we should also enforce a no-aliasing policy on the arguments (this is functionality I'm going to add to Mooncake soon anyway, because if you can exclude the possibility of aliasing and circular referencing in the arguments to a function, a bunch of things become quite a bit cheaper to do).
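A minimal sketch of a no-aliasing check on the top-level arguments, assuming only Base. `check_no_aliasing` is a hypothetical name; `Base.mightalias` is Base's unexported, conservative test (based on `Base.dataids`) for whether two arrays may share memory. The sketch does not look inside structs or detect circular references, both of which a real policy would also need to handle.

```julia
# Hypothetical check: reject calls where two array arguments may share memory.
function check_no_aliasing(args...)
    n = length(args)
    for i in 1:n, j in (i + 1):n
        a, b = args[i], args[j]
        # Only arrays are compared; mightalias may report false positives.
        if a isa AbstractArray && b isa AbstractArray && Base.mightalias(a, b)
            throw(ArgumentError("arguments $i and $j may share memory; pass copies instead"))
        end
    end
    return nothing
end
```

For example, `check_no_aliasing(x, view(x, 1:1))` throws, while two independently allocated arrays pass.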

@Jollywatt
Author

The basic problem here is that we don't have enough constraints on the input types

Rather than restricting the input tangent types (what if I have a custom array type?), wouldn't it be more prudent to explicitly assert that dy_righttype is not NoTangent() where appropriate and raise an informative error message otherwise?
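A sketch of the kind of check suggested here, written in plain Julia rather than against DI's internals: evaluate the function once and, if the result is not an IEEEFloat, throw an error that names the offending type and suggests a fix. `NonDifferentiableOutputError` and `check_output` are hypothetical names; a real implementation would more likely inspect the tangent type inside DI's Mooncake extension (e.g. the `dy_righttype` value mentioned above) than re-run `f`.

```julia
# Hypothetical error type for a scalar output that cannot carry a gradient.
struct NonDifferentiableOutputError <: Exception
    T::Type
end

function Base.showerror(io::IO, e::NonDifferentiableOutputError)
    print(io, "NonDifferentiableOutputError: the function returned a value of type ",
          e.T, "; a scalar output must be a Float16, Float32 or Float64. ",
          "If your inputs are integers, convert them with `float` first.")
end

# Evaluate f once and reject outputs that are not IEEE floats.
function check_output(f, x)
    y = f(x)
    y isa Base.IEEEFloat || throw(NonDifferentiableOutputError(typeof(y)))
    return y
end
```

Applied to the original report, `check_output(sum, [1, 2])` would fail with a message naming `Int64` instead of a `MethodError` about `copyto!`.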

@willtebbutt
Member

willtebbutt commented Dec 12, 2024

I agree that it makes sense to, e.g., check that the thing returned from a function whose gradient you're taking is not an Int or anything else whose tangent type is NoTangent. I'm using this issue as an exemplar of a more general jankiness in what users get when they interact with Mooncake.jl via DI.jl, which results from us not having carefully mapped out what kinds of things we want to offer really good support for.

In terms of things like custom array types, the basic problem is that you typically won't get the kind of answer you might expect (e.g. the tangent type for a custom array type is typically Tangent, not the custom array type itself); this has already caused some problems with StaticArrays.jl.

I wouldn't object to letting users opt out of the restrictions. By default, the input / output types would be restricted to things that we

  1. know for sure we can work with, and
  2. have conversions defined for tangent types to ensure that people get the thing that they expect (e.g. if you take the gradient w.r.t. an SArray{Float64}, you get an SArray{Float64} as the answer, not a Tangent).

As part of the error message you'd get for an input / output that the restricted version doesn't support, we'd say "if you really don't want to modify your code to fit within the restrictions we've imposed, in which we promise everything works as expected, feel free to use this other mode. We offer no promises around whether or not this other mode will work, or give the kind of answer you expect, but it might work."
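A minimal sketch of this restricted-by-default behaviour with an explicit opt-out, in plain Julia. `checked_gradient`, `supported` and the `strict` keyword are hypothetical names, not part of Mooncake or DI, and `grad` stands in for whatever AD backend actually computes the gradient.

```julia
# Hypothetical wrapper: restricted by default, with an explicit opt-out.
# `grad` is any function (f, x) -> gradient of f at x; no AD package is assumed.
supported(x) = x isa Union{Base.IEEEFloat, Array{<:Base.IEEEFloat}}

function checked_gradient(grad, f, x; strict::Bool = true)
    if !supported(x)
        msg = "input of type $(typeof(x)) is outside the supported set"
        # Default mode: refuse, and point the user at the escape hatch.
        strict && throw(ArgumentError(msg * "; pass strict = false to try anyway (no guarantees)"))
        # Opt-out mode: proceed, but make it visible that we are off the supported path.
        @warn msg * "; continuing because strict = false"
    end
    return grad(f, x)
end

# Usage with a hand-written gradient standing in for an AD backend:
sum_grad(f, x) = ones(length(x))   # exact gradient of `sum` for a vector input
checked_gradient(sum_grad, sum, [1.0, 2.0])   # returns [1.0, 1.0]
```

Keeping the escape hatch as a keyword means the supported path needs no extra arguments, while anyone leaving it has to say so explicitly and sees a warning.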

I guess my overall point is that, at the day-to-day interface between Mooncake and the external world (DI.jl), it's a really good idea to constrain the set of problems you consider in exchange for robustness and reliability. The reason is that, as a user, it's generally much easier to tweak a function you've written yourself to satisfy some basic requirements than it is to rummage around in the internals of DI trying to figure out what went wrong.

Put differently, I never want users to see errors that stem from problems originating inside DI; I want to know for sure that either they've done something wrong or something has gone wrong inside Mooncake.jl itself.

@gdalle
Collaborator

gdalle commented Dec 12, 2024

Just popping by to say that I'm currently busy moving from Switzerland to France, but I'll catch up on the discussion in a few days. Feel free to also open issues in DI if necessary.

@yebai yebai added the enhancement (error messages) The error was produced that should be improved upon label Dec 22, 2024
Projects
None yet
Development

No branches or pull requests

4 participants