Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Zygote work for Wiener kernel #168

Closed
wants to merge 8 commits into from
Closed

Conversation

sharanry
Copy link
Contributor

@sharanry sharanry commented Sep 15, 2020

#116

I have defined kernelmatrix for wiener kernel. This avoid the "slow method" but gives NaNs in the Zygote gradient.

@sharanry
Copy link
Contributor Author

The NaN values in the Zygote gradient seem to be caused by the the pairwise(Euclidean(), x, y). I am not sure why this is happening.

@devmotion
Copy link
Member

Can you try the Zygote master branch locally? Maybe it was fixed by FluxML/Zygote.jl#787.

@sharanry
Copy link
Contributor Author

sharanry commented Sep 15, 2020

Can you try the Zygote master branch locally? Maybe it was fixed by FluxML/Zygote.jl#787.

Thanks! The Zygote and ForwardDiff tests pass on Zygote#master. However, ReverseDiff fails, probably requires nan checking.

@devmotion
Copy link
Member

There's a new Zygote version which contains the fix mentioned above.

@sharanry sharanry requested a review from devmotion September 16, 2020 09:53
Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, just had some minor comments/questions left. The only major thing that is missing is the version bump it seems 🙂

return minXY.^7 ./ 252 .+ minXY.^4 .* pairwise(Euclidean(), x, y) .*
( 5 .* max.(permutedims(X), Y).^2 .+ 2 .* X .* Y .+ 3 .* minXY.^2 ) ./ 720
end
return error("Invalid I=$I")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The constructor guarantees that this can't be reached.

Alternatively, (and maybe a bit nicer) might be to dispatch depending on the kernel (similar to (::WienerKernel)(x, y)), maybe just for computing the result from minXY since this part is always the same.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you suggest dispatching just the if statement part? Problem with that is such a function would need five inputs - the kernel, x, y, X, Y, minXY. I think this will be a very untidy way to do this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering if one should keep the common part (i.e., the first two or three lines) together but actually as you say that would require inner functions with many arguments. So maybe it would be cleaner to just dispatch the kernelmatrix call on I (similar to the methods for individual samples).

src/basekernels/wiener.jl Outdated Show resolved Hide resolved
@@ -1,4 +1,5 @@
@testset "wiener" begin
rng = MersenneTwister(123)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, we could just set a seed in runtests.jl and use the global RNG - the seed is reset in every @testset.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, the tests are not uniform in this regard. This probably warrants a separate PR.

Copy link
Member

@devmotion devmotion Sep 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I agree, if we want to change this in all tests, this should be part of a separate PR. I mainly commented on this since I saw that you changed large parts of the test file and added a specific RNG - I'm not sure if this is actually useful since the default RNG will be reset in every test file.

test/test_utils.jl Show resolved Hide resolved
@sharanry
Copy link
Contributor Author

It is weird that Zygote on Gabor kernel tests are failing on Mac without any changes related to it. https://travis-ci.com/github/JuliaGaussianProcesses/KernelFunctions.jl/jobs/386076575#L369

@willtebbutt
Copy link
Member

This looks to be stale, so I'm closing.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants