Zygote errors with parameterized mean functions and multidimensional input #344

Closed
simsurace opened this issue Dec 6, 2022 · 13 comments

Comments

@simsurace
Member

When trying to differentiate logpdf or other scalar functions with a parameterized mean function and multidimensional input, there are errors:

using AbstractGPs
using Zygote

pars = [1., 0.]

function build_model(pars)
    a, b = pars
    return GP(x -> a * first(x) + b, SEKernel())
end

rand_data(n::Integer) = rand(n), randn(n)
rand_data_2d(n::Integer) = RowVecs(rand(n, 2)), randn(n)

function test_logpdf(pars)
    f = build_model(pars)
    x, y = rand_data(10)
    return logpdf(f(x, 1e-3), y)
end

test_logpdf(pars)
Zygote.gradient(test_logpdf, pars) # works

function test_logpdf2(pars)
    f = build_model(pars)
    x, y = rand_data_2d(10)
    return logpdf(f(x, 1e-3), y)
end

test_logpdf2(pars)
Zygote.gradient(test_logpdf2, pars)
# ERROR: MethodError: no method matching +(::NamedTuple{(:X,), Tuple{LinearAlgebra.Transpose{Float64, Matrix{Float64}}}}, ::Vector{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}})

function test_mean(pars)
    f = build_model(pars)
    x, _ = rand_data_2d(10)
    return sum(mean(f(x, 1e-3)))
end

test_mean(pars)
Zygote.gradient(test_mean, pars) # ERROR: Pullback on AbstractVector{<:AbstractVector}.

function test_post_mean(pars)
    f = build_model(pars)
    x, y = rand_data_2d(10)
    fp = posterior(f(x, 1e-3), y)
    return sum(mean(fp(x, 1e-3)))
end

test_post_mean(pars)
Zygote.gradient(test_post_mean, pars) 
# ERROR: MethodError: no method matching +(::NamedTuple{(:X,), Tuple{LinearAlgebra.Transpose{Float64, Matrix{Float64}}}}, ::Vector{Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}})

Is there a simple fix? The error for test_mean gives a suggestion to overload a kernelmatrix method, but that does not seem to be the issue since we are talking about the mean here. Why does the existing rrule for RowVecs not suffice?

@willtebbutt
Member

willtebbutt commented Dec 6, 2022

Hmm, this issue has come up repeatedly recently. It is also a problem in this Stheno.jl issue, I suspect for the same reasons.

To be honest, the simplest solution is going to be to implement AbstractGPs._map_meanfunction for your custom mean function and a RowVecs / ColVecs input, so that you can be sure that it's differentiable. So, something like

function AbstractGPs._map_meanfunction(f::CustomMean{typeof(your_mean_function)}, x::RowVecs)
    # a differentiable implementation of your mean function goes here
end

etc. @simsurace could you let me know if this solves the problem?

I can't see this issue surrounding the mean function getting fixed any time soon (because it's AD-related), so I'm wondering whether we should change our approach to documenting it, e.g. by making it clear that you might well need to implement _map_meanfunction if you're using a custom mean function.

@simsurace
Member Author

Thanks for the suggestion. Maybe I misunderstood, but this did not solve the issue:

struct LinearMean{T}
    a::T
    b::T
end
(f::LinearMean)(x) = f.a * first(x) + f.b

using AbstractGPs: CustomMean
function AbstractGPs._map_meanfunction(f::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    return [f.f.a * first(xi) + f.f.b for xi in x]
end

function build_model(pars)
    a, b = pars
    return GP(CustomMean(LinearMean(a, b)), SEKernel())
end

@willtebbutt
Member

Ah, sorry, I meant something more like

function AbstractGPs._map_meanfunction(f::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    # RowVecs stores one input per row of x.X, so the first coordinates are column 1
    return f.f.a .* x.X[:, 1] .+ f.f.b
end

so that you're interacting with the underlying matrix.
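To make the layout concrete, here is a minimal sketch, assuming only the documented behaviour that RowVecs treats each row of the wrapped matrix as one input and ColVecs treats each column as one input. The matrix values are made up for illustration.

```julia
using AbstractGPs  # re-exports RowVecs and ColVecs from KernelFunctions

# Three 2-dimensional inputs, one per row (illustrative values).
X = [1.0 2.0;
     3.0 4.0;
     5.0 6.0]

x_rows = RowVecs(X)               # x_rows[i] is the i-th row of X
x_cols = ColVecs(permutedims(X))  # x_cols[i] is the i-th column of the wrapped matrix

# The first coordinate of every input, read directly from the underlying matrix.
first_coords_rows = x_rows.X[:, 1]
first_coords_cols = x_cols.X[1, :]
```

Slicing the matrix like this avoids iterating over the vector-of-vectors, which is the step the automatic differentiation struggles with in this issue.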

@simsurace
Member Author

simsurace commented Dec 6, 2022

Oh, I get it. Thanks, this seems to do the trick. Actually, wrapping LinearMean in CustomMean seems overly complicated. I could make LinearMean <: MeanFunction and then define _map_meanfunction accordingly, right?

@simsurace
Member Author

P.S. CustomMean currently does not seem to be exported or documented, for that matter.

@willtebbutt
Member

> Oh, I get it. Thanks, this seems to do the trick. Actually, wrapping LinearMean in CustomMean seems overly complicated. I could make LinearMean <: MeanFunction and then define _map_meanfunction accordingly, right?

That should indeed work!
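A rough sketch of that subtype approach, written against the mean_vector function that a later comment in this thread says replaced _map_meanfunction. It assumes AbstractGPs.MeanFunction is the abstract supertype to extend and that mean of a finite GP dispatches to mean_vector; the input values and the helper names build_model and sum_of_means are only illustrative.

```julia
using AbstractGPs, Zygote

# Mean function that is linear in the first input coordinate.
struct LinearMean{T<:Real} <: AbstractGPs.MeanFunction
    a::T
    b::T
end

# Operate on the wrapped matrix directly instead of iterating over the inputs.
AbstractGPs.mean_vector(m::LinearMean, x::RowVecs) = m.a .* x.X[:, 1] .+ m.b
# One-dimensional inputs stored as a plain vector of reals.
AbstractGPs.mean_vector(m::LinearMean, x::AbstractVector{<:Real}) = m.a .* x .+ m.b

X = [0.0 1.0; 0.5 2.0; 1.0 3.0]  # three 2-d inputs, one per row
x = RowVecs(X)

build_model(pars) = GP(LinearMean(pars[1], pars[2]), SEKernel())
sum_of_means(pars) = sum(mean(build_model(pars)(x, 1e-3)))

# Gradient of the summed prior mean with respect to (a, b).
grad = only(Zygote.gradient(sum_of_means, [1.0, 0.0]))
```

Analytically the gradient should be the sum of the first coordinates for a and the number of inputs for b, which gives a quick check that the override is the method being differentiated.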

@simsurace
Member Author

I ended up with a general struct FunctionOfTime{Tf} <: MeanFunction with overloads that map its field over slices. This works. Thanks for the tips!

@torfjelde

I also ran into this issue recently, and because you end up hitting the def in KernelFunctions, debugging is somewhat confusing 😕

Probably worth an entry in the docs + maybe changing the error in KernelFunctions?

@willtebbutt
Member

I agree that the docs should probably be improved here

@willtebbutt
Member

willtebbutt commented Sep 12, 2023

We should probably add a note about when you need to implement mean_vector yourself for CustomMean here and here, and provide an example.
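One possible shape for such an example, as a sketch rather than the wording the docs ended up with: a callable struct wrapped in CustomMean, plus a mean_vector method for RowVecs inputs that reads the underlying matrix. It assumes CustomMean stores the wrapped callable in a field named f, as the snippets earlier in this thread do; the input values are made up.

```julia
using AbstractGPs
using AbstractGPs: CustomMean

struct LinearMean{T<:Real}
    a::T
    b::T
end
# Pointwise definition: linear in the first coordinate of a single input.
(f::LinearMean)(x) = f.a * first(x) + f.b

# Specialised method for multi-dimensional inputs stored row-wise.
function AbstractGPs.mean_vector(m::CustomMean{<:LinearMean}, x::RowVecs)
    return m.f.a .* x.X[:, 1] .+ m.f.b
end

x = RowVecs([0.0 1.0; 0.5 2.0; 1.0 3.0])
m = CustomMean(LinearMean(2.0, 1.0))

specialised = AbstractGPs.mean_vector(m, x)
pointwise = map(m.f, x)  # what evaluating the mean function input by input gives
```

Comparing the specialised result against the pointwise evaluation is a cheap way to confirm that the override computes the same mean and only changes how it is computed.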

@kjrathore

kjrathore commented Jul 30, 2024

pars = [1., 0.]

rand_data(n::Integer) = rand(n), randn(n)
rand_data_2d(n::Integer) = RowVecs(rand(n, 2)), randn(n)

struct LinearMean{T}
    a::T
    b::T
end
(f::LinearMean)(x) = f.a * first(x) + f.b

using AbstractGPs: CustomMean
function AbstractGPs._map_meanfunction(m::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    return vec(sum(m.f.(x.X); dims=2))
end


function build_model(pars)
    a, b = pars
    return GP(CustomMean(LinearMean(a, b)), SEKernel())
end

function test_logpdf(pars)
    f = build_model(pars)
    x, y = rand_data(10)
    return logpdf(f(x, 1e-3), y)
end

test_logpdf(pars)
Zygote.gradient(test_logpdf, pars)

Hi @simsurace, I used your code to check that it works with Zygote.

However, I am getting an error. Could you please explain how you solved this issue?
image

@simsurace
Member Author

simsurace commented Jul 31, 2024

Hi @kjrathore, thanks for reaching out. _map_meanfunction has been removed since then. There is now a public mean_vector that can be overloaded. This should work on the current release of AbstractGPs:

using AbstractGPs, Zygote

pars = [1., 0.]

rand_data(n::Integer) = rand(n), randn(n)
rand_data_2d(n::Integer) = RowVecs(rand(n, 2)), randn(n)

struct LinearMean{T}
    a::T
    b::T
end
(f::LinearMean)(x) = f.a * first(x) + f.b

using AbstractGPs: CustomMean
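function AbstractGPs.mean_vector(m::CustomMean{<:LinearMean}, x::RowVecs)
    @info "Calling specialized function"
    # Match the pointwise definition a * first(x) + b: with RowVecs, the first
    # coordinate of every input is the first column of x.X.
    return m.f.a .* x.X[:, 1] .+ m.f.b
end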


function build_model(pars)
    a, b = pars
    return GP(CustomMean(LinearMean(a, b)), SEKernel())
end

function test_logpdf(pars)
    f = build_model(pars)
    x, y = rand_data(10)
    return logpdf(f(x, 1e-3), y)
end

test_logpdf(pars)
Zygote.gradient(test_logpdf, pars)

@kjrathore

kjrathore commented Jul 31, 2024

Thanks @simsurace!
A slight update to this code: I needed to add
import AbstractGPs: mean_vector

Note: "using" and "import" in Julia are not the same. "using" brings all exported names from a module into the current namespace, while "import" lets you extend a function without prefixing it with its module name. A method defined with the qualified name, as in function AbstractGPs.mean_vector(...) above, does not need the import (https://stackoverflow.com/questions/42888911/function-base-must-be-explicitly-imported-to-be-extended)
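The difference can be seen without AbstractGPs at all. The following is a small sketch with a made-up module Shapes and made-up types Square and Circle, showing the two ways of adding a method to a function owned by another module.

```julia
# Toy module standing in for a package that owns a function.
module Shapes
export area
area(x) = error("no area method for $(typeof(x))")
end

using .Shapes

struct Square; side::Float64; end
struct Circle; r::Float64; end

# Option 1: qualify the function name. No import is required.
Shapes.area(s::Square) = s.side^2

# Option 2: import the name first, then extend it without the prefix.
import .Shapes: area
area(c::Circle) = pi * c.r^2
```

Both definitions add methods to the same function, so either style works for mean_vector; which one to use is mostly a matter of taste.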
