Make Zygote work for Wiener kernel #168
The |
Can you try the Zygote master branch locally? Maybe it was fixed by FluxML/Zygote.jl#787.
Thanks! The Zygote and ForwardDiff tests pass on Zygote#master. However, ReverseDiff fails; it probably requires NaN checking.
There's a new Zygote version which contains the fix mentioned above.
Looks good to me, just had some minor comments/questions left. The only major thing that is missing is the version bump it seems 🙂
src/basekernels/wiener.jl
Outdated
return minXY.^7 ./ 252 .+ minXY.^4 .* pairwise(Euclidean(), x, y) .*
( 5 .* max.(permutedims(X), Y).^2 .+ 2 .* X .* Y .+ 3 .* minXY.^2 ) ./ 720
end
return error("Invalid I=$I")
The constructor guarantees that this can't be reached.
Alternatively (and maybe a bit nicer), one could dispatch on the kernel, similar to (::WienerKernel)(x, y), maybe just for computing the result from minXY, since this part is always the same.
Do you suggest dispatching just the if statement part? The problem with that is that such a function would need six inputs: the kernel, x, y, X, Y, and minXY. I think this would be a very untidy way to do this.
I was wondering if one should keep the common part (i.e., the first two or three lines) together, but actually, as you say, that would require inner functions with many arguments. So maybe it would be cleaner to just dispatch the kernelmatrix call on I (similar to the methods for individual samples).
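The dispatch idea suggested above can be sketched roughly as follows. This is only an illustration, not the package's code: `ToyWiener`, `toy_kernelmatrix`, `pairwise_min`, and `pairwise_dist` are hypothetical names, the inputs are assumed to be vectors of positive reals, and only the closed forms for I = 0 and I = 1 of the integrated Wiener covariance are written out.

```julia
# Hypothetical stand-in for a kernel type parameterised by an integer I.
struct ToyWiener{I} end

# Shared helpers: rows index x, columns index y.
pairwise_min(x, y) = min.(x, permutedims(y))
pairwise_dist(x, y) = abs.(x .- permutedims(y))

# One method per value of I: no if/else chain and no unreachable error branch.
toy_kernelmatrix(::ToyWiener{0}, x, y) = pairwise_min(x, y)

function toy_kernelmatrix(::ToyWiener{1}, x, y)
    m = pairwise_min(x, y)
    return m .^ 3 ./ 3 .+ m .^ 2 .* pairwise_dist(x, y) ./ 2
end

x = [1.0, 2.0]
y = [1.0, 3.0, 4.0]
K0 = toy_kernelmatrix(ToyWiener{0}(), x, y)  # 2×3 matrix of pairwise minima
K1 = toy_kernelmatrix(ToyWiener{1}(), x, y)  # 2×3 matrix for I = 1
```

With this layout, calling `toy_kernelmatrix` for a value of I that has no method raises a `MethodError`, so a trailing `error("Invalid I=$I")` branch is not needed.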
@@ -1,4 +1,5 @@
@testset "wiener" begin
rng = MersenneTwister(123)
Alternatively, we could just set a seed in runtests.jl and use the global RNG; the seed is reset in every @testset.
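A minimal sketch of that alternative, assuming the behaviour described in the documentation of Julia's Test standard library (each `@testset` reseeds the global RNG with the current global seed before running its body). The seed value, the testset names, and the `draws` vector are made up for illustration.

```julia
using Random, Test

# A single seed, e.g. at the top of runtests.jl (the value is arbitrary).
Random.seed!(1234)

draws = Float64[]

@testset "first" begin
    # The global RNG is reseeded before this body runs.
    push!(draws, rand())
    @test 0 <= draws[end] < 1
end

@testset "second" begin
    # Reseeded again, so this testset does not depend on how many
    # random numbers the previous one consumed.
    push!(draws, rand())
    @test 0 <= draws[end] < 1
end
```

Under that assumption, both testsets start from the same RNG state, so passing an explicit `MersenneTwister` around is only needed when a test wants a stream that differs from the global one.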
Currently, the tests are not uniform in this regard. This probably warrants a separate PR.
Yep, I agree: if we want to change this in all tests, it should be part of a separate PR. I mainly commented on this since I saw that you changed large parts of the test file and added a specific RNG. I'm not sure this is actually useful, since the default RNG will be reset in every test file.
It is weird that the Zygote tests for the Gabor kernel are failing on Mac without any changes related to it. https://travis-ci.com/github/JuliaGaussianProcesses/KernelFunctions.jl/jobs/386076575#L369
This looks to be stale, so I'm closing.
#116
I have defined kernelmatrix for the Wiener kernel. This avoids the "slow method" but gives NaNs in the Zygote gradient.