Zygote errors with parameterized mean functions and multidimensional input #344
Hmmm, this issue has come up repeatedly recently. It was also a problem in this Stheno.jl issue, I suspect for the same reasons. To be honest, the simplest solution is going to be to implement

```julia
function AbstractGPs._map_meanfunction(f::CustomMean{typeof(your_mean_function)}, x::RowVecs)
    # ...
end
```

etc. @simsurace, could you let me know if this solves the problem? I can't see this issue with the mean function getting fixed any time soon (because it's AD-related), so I'm wondering whether we should change our approach to documenting it, e.g. making it clear that if you might well need to implement
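The `CustomMean{typeof(your_mean_function)}` signature works because every Julia function has its own singleton type, so a method can be specialized on one particular function stored inside a parametric wrapper. Below is a minimal Base-Julia sketch of that dispatch pattern. `Wrapped`, `mean_values`, and `my_mean` are hypothetical stand-ins for `CustomMean`, `_map_meanfunction`, and your mean function; they are not AbstractGPs API.

```julia
# Hypothetical stand-in for CustomMean: a parametric wrapper around any callable.
struct Wrapped{F}
    f::F
end

# Hypothetical mean function that we want to special-case.
my_mean(x) = 2 * first(x)

# Generic fallback: apply the wrapped callable to each row (one input per row).
mean_values(w::Wrapped, X::AbstractMatrix) = [w.f(row) for row in eachrow(X)]

# Specialized method: chosen only when the wrapped callable is exactly `my_mean`,
# because `typeof(my_mean)` is a singleton type unique to that function.
# It works on the whole first column instead of iterating over rows.
mean_values(w::Wrapped{typeof(my_mean)}, X::AbstractMatrix) = 2 .* X[:, 1]

X = [1.0 10.0; 2.0 20.0; 3.0 30.0]
mean_values(Wrapped(my_mean), X)      # dispatches to the specialized method
mean_values(Wrapped(x -> sum(x)), X)  # falls back to the generic method
```

Julia picks the specialized method because `Wrapped{typeof(my_mean)}` is more specific than `Wrapped`, so any other wrapped function keeps using the generic path.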
Thanks for the suggestion. Maybe I misunderstood, but this did not solve the issue:

```julia
using AbstractGPs
using AbstractGPs: CustomMean

struct LinearMean{T}
    a::T
    b::T
end
(f::LinearMean)(x) = f.a * first(x) + f.b

function AbstractGPs._map_meanfunction(f::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    return [f.f.a * first(xi) + f.f.b for xi in x]
end

function build_model(pars)
    a, b = pars
    return GP(CustomMean(LinearMean(a, b)), SEKernel())
end
```
Ah, sorry, I meant something more like

```julia
function AbstractGPs._map_meanfunction(f::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    # Each row of x.X is one input, so column 1 holds the first coordinate of every input.
    return f.f.a .* x.X[:, 1] .+ f.f.b
end
```

so that you're interacting with the underlying matrix.
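Because `RowVecs` stores the inputs as the rows of `x.X`, the first coordinate of every input is the first *column*, `x.X[:, 1]`. A quick Base-Julia check, with a plain matrix standing in for `x.X` and hypothetical coefficients `a` and `b`, shows that the vectorized form matches the row-by-row definition `a * first(xi) + b`, whereas indexing the first row does not:

```julia
a, b = 2.0, 0.5                  # hypothetical LinearMean coefficients
X = [1.0 4.0; 2.0 5.0; 3.0 6.0]  # 3 inputs (rows) in 2 dimensions

# Row-by-row definition, as in `(f::LinearMean)(x) = f.a * first(x) + f.b`.
rowwise = [a * first(xi) + b for xi in eachrow(X)]

# Vectorized form acting on the underlying matrix: one entry per input.
columnwise = a .* X[:, 1] .+ b

# Indexing the first *row* instead gives one entry per input dimension.
first_row = a .* X[1, :] .+ b
```

Working on a whole column keeps the differentiated code to array indexing and broadcasting, rather than iteration over `RowVecs` elements, which appears to be where the error in this thread is triggered.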
Oh, I get it. Thanks, this seems to do the trick. Actually, wrapping
P.S.
That should indeed work!
I ended up with a general struct
I also ran into this issue recently, and because you end up hitting the definition in KernelFunctions, debugging is somewhat confusing 😕. It's probably worth an entry in the docs, and maybe a change to the error message in KernelFunctions?
I agree that the docs should probably be improved here.
Hi @simsurace, I used your code to check that it works with Zygote. However, I am getting an error. Could you please explain how you solved this issue?
Hi @kjrathore, thanks for reaching out.

```julia
using AbstractGPs, Zygote
using AbstractGPs: CustomMean

pars = [1., 0.]
rand_data(n::Integer) = rand(n), randn(n)
rand_data_2d(n::Integer) = RowVecs(rand(n, 2)), randn(n)

struct LinearMean{T}
    a::T
    b::T
end
(f::LinearMean)(x) = f.a * first(x) + f.b

function AbstractGPs.mean_vector(m::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    # Operate on the underlying matrix: each row of x.X is one input, so column 1
    # holds the first coordinate of every input, matching (f::LinearMean)(x).
    return m.f.a .* x.X[:, 1] .+ m.f.b
end

function build_model(pars)
    a, b = pars
    return GP(CustomMean(LinearMean(a, b)), SEKernel())
end

function test_logpdf(pars)
    f = build_model(pars)
    x, y = rand_data_2d(10)  # 2-D inputs, so the specialized method above is used
    return logpdf(f(x, 1e-3), y)
end

test_logpdf(pars)
Zygote.gradient(test_logpdf, pars)
```
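To sanity-check what `Zygote.gradient(test_logpdf, pars)` should return, the gradient with respect to the mean parameters can also be written by hand: with residual r = y - m and noisy kernel matrix K, the derivatives are x₁ᵀK⁻¹r for `a` and 1ᵀK⁻¹r for `b`. The stdlib-only sketch below compares that expression against central finite differences. It assumes `SEKernel()` is the unit-lengthscale kernel exp(-‖x - x'‖²/2) and that `f(x, 1e-3)` adds observation-noise variance 1e-3; `se`, `noisy_chol`, `loglik`, `loglik_grad`, and `fd_grad` are hypothetical helpers, not AbstractGPs code.

```julia
using LinearAlgebra, Random

# Squared-exponential kernel with unit lengthscale (assumed to match SEKernel()).
se(xi, xj) = exp(-sum(abs2, xi .- xj) / 2)

# Cholesky factor of the kernel matrix over the rows of X, plus noise variance σ².
function noisy_chol(X, σ²)
    rows = eachrow(X)
    K = [se(xi, xj) for xi in rows, xj in rows] + σ² * I
    return cholesky(Symmetric(K))
end

# Gaussian log-likelihood with mean a * x[1] + b (the LinearMean in this thread).
function loglik(pars, X, y; σ² = 1e-3)
    a, b = pars
    r = y .- (a .* X[:, 1] .+ b)
    C = noisy_chol(X, σ²)
    return -dot(r, C \ r) / 2 - logdet(C) / 2 - length(y) * log(2π) / 2
end

# Analytic gradient: d/da = x₁ᵀ K⁻¹ r and d/db = 1ᵀ K⁻¹ r, with r = y - m.
function loglik_grad(pars, X, y; σ² = 1e-3)
    a, b = pars
    r = y .- (a .* X[:, 1] .+ b)
    α = noisy_chol(X, σ²) \ r
    return [dot(X[:, 1], α), sum(α)]
end

# Central finite differences; exact up to rounding here, since loglik is quadratic in pars.
function fd_grad(f, pars; h = 1e-3)
    map(eachindex(pars)) do i
        e = zeros(length(pars)); e[i] = h
        (f(pars .+ e) - f(pars .- e)) / (2h)
    end
end

rng = MersenneTwister(0)
X, y = rand(rng, 10, 2), randn(rng, 10)
pars = [1.0, 0.0]
g_analytic = loglik_grad(pars, X, y)
g_fd = fd_grad(p -> loglik(p, X, y), pars)
```

For a fixed `X` and `y`, the vector returned by Zygote for the model above should agree with `g_analytic` up to floating-point error.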
Thanks @simsurace! Note: `using` and `import` in Julia are not the same. While `using` brings all exported names from a module into the current namespace, `import` allows you to extend a function without prefixing it with its module name (https://stackoverflow.com/questions/42888911/function-base-must-be-explicitly-imported-to-be-extended).
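A minimal Base-Julia sketch of the distinction, with a hypothetical module `MeanLib` standing in for a package such as AbstractGPs: after `using`, a new method is added with the module prefix (the style of `AbstractGPs.mean_vector` in the example above), while after an explicit `import` the bare name can be extended. `MeanLib`, `mean_at`, `ExtendQualified`, and `ExtendImported` are all illustrative names.

```julia
# Hypothetical stand-in for a package that exports a function we want to extend.
module MeanLib
export mean_at
mean_at(x::Real) = zero(x)
end

# With `using`, the exported name is visible, but a new method needs the module prefix.
module ExtendQualified
using ..MeanLib
struct Shift
    c::Float64
end
MeanLib.mean_at(s::Shift) = s.c
end

# With an explicit `import`, the function can be extended without the prefix.
module ExtendImported
import ..MeanLib: mean_at
struct Scale
    k::Float64
end
mean_at(s::Scale) = 2 * s.k
end

using .MeanLib
```

Both new methods end up attached to the same `MeanLib.mean_at`. Writing a bare `mean_at(s::Shift) = ...` inside `ExtendQualified` would instead, depending on whether the name had already been used there, either raise the "must be explicitly imported to be extended" error or silently define a separate, unrelated function.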
When trying to differentiate `logpdf` or other scalar functions with a parameterized mean function and multidimensional input, there are errors. Is there a simple fix?

The error for `test_mean` gives a suggestion to overload a `kernelmatrix` method, but that does not seem to be the issue, since we are talking about the mean here. Why does the existing `rrule` for `RowVecs` not suffice?