
How should the OILMM be parameterised when using AD? #50

Closed
wil-j-wil opened this issue Jun 8, 2022 · 13 comments
wil-j-wil commented Jun 8, 2022

Hi,

How should I parameterise an OILMM if I want to optimise the hyperparameters whilst ensuring that the columns of U remain orthogonal?

I have the following setup which uses the orthogonal constraint from ParameterHandling:

using AbstractGPs
using KernelFunctions
using LinearAlgebra
using LinearMixingModels
using ParameterHandling

num_outputs = 11
num_latents = 3

x_train = KernelFunctions.MOInputIsotopicByOutputs(collect(1:100), num_outputs)
y_train = rand(100 * num_outputs)

H_init = rand(num_outputs, num_outputs)
U_, S_, V_ = svd(H_init)
U_init = U_[:, 1:num_latents]
S_init = S_[1:num_latents]

θ_oilmm = (;
    U = orthogonal(U_init),
    S = positive.(S_init),
)

function build_gp(θ)
    sogp = GP(Matern52Kernel())
    latent_gp = independent_mogp([sogp for _ in 1:num_latents])
    return ILMM(latent_gp, Orthogonal(θ.U, Diagonal(θ.S)))
end

function objective(θ)
    oilmm = build_gp(θ)
    return -logpdf(oilmm(x_train, 0.1), y_train)
end

but when I try to compute the gradient of the objective with this parameterisation,

using Zygote

flat_θ_oilmm, unflatten = flatten(θ_oilmm)
unpack = ParameterHandling.value ∘ unflatten

Zygote.gradient(objective ∘ unpack, flat_θ_oilmm)

the gradients of U are NaN (due to the orthogonal constraint).
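A minimal way to see this, with the setup above:

# Inspect the returned gradient: the entries corresponding to U come back NaN.
g = only(Zygote.gradient(objective ∘ unpack, flat_θ_oilmm))
any(isnan, g)  # true with this initialisation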

What's the best way to set this up?

wil-j-wil changed the title from "How should the OILMM be parameterised when using AD" to "How should the OILMM be parameterised when using AD?" on Jun 8, 2022
@wil-j-wil (Author)

Perhaps also worth mentioning that for a smallish example (800 time steps, 11 outputs, 3 latents) the Zygote.gradient(...) step takes a long time to compile and my laptop runs out of memory.

@willtebbutt (Member)

You actually don't need to pass an already-orthogonal matrix to orthogonal. Something like

θ_oilmm = (;
    U = orthogonal(randn(num_outputs, num_latents)),
    S = positive.(rand(num_latents) .+ 0.1),
)

seems to work fine.
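To see why: orthogonal stores an unconstrained matrix and only maps it to one with orthonormal columns when you take its value (via an SVD under the hood), so any real matrix is a valid initialisation. A quick sketch:

using LinearAlgebra, ParameterHandling

# The value of an `orthogonal` parameter always has orthonormal columns,
# regardless of the raw matrix used to construct it.
p = orthogonal(randn(11, 3))
U = ParameterHandling.value(p)
U' * U ≈ I  # true, up to floating-point error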

> Perhaps also worth mentioning that for a smallish example (800 time steps, 11 outputs, 3 latents) the Zygote.gradient(...) step takes a long time to compile and my laptop runs out of memory.

Interestingly, I'm not seeing this kind of behaviour. Could you copy + paste your package versions?

My overall script is:

using AbstractGPs, KernelFunctions, LinearAlgebra, LinearMixingModels, ParameterHandling, Zygote

num_outputs = 11
num_latents = 3
N = 800

x_train = KernelFunctions.MOInputIsotopicByOutputs(collect(1:N), num_outputs)

θ_oilmm = (;
    U = orthogonal(randn(num_outputs, num_latents)),
    S = positive.(rand(num_latents) .+ 0.1),
)

flat_θ_oilmm, unpack = value_flatten(θ_oilmm)

function build_gp(θ)
    sogp = GP(SEKernel())
    latent_gp = independent_mogp([sogp for _ in 1:num_latents])
    return ILMM(latent_gp, Orthogonal(θ.U, Diagonal(θ.S)))
end

y_train = rand(build_gp(unpack(flat_θ_oilmm))(x_train, 0.1))

function objective(θ)
    oilmm = build_gp(θ)
    return -logpdf(oilmm(x_train, 0.1), y_train)
end

Zygote.gradient(objective ∘ unpack, flat_θ_oilmm)

@willtebbutt (Member)

I just tried timing stuff:

julia> @benchmark $objective($unpack($flat_θ_oilmm))
BenchmarkTools.Trial: 104 samples with 1 evaluation.
 Range (min … max):  32.383 ms … 70.464 ms  ┊ GC (min … max):  0.00% … 39.53%
 Time  (median):     49.666 ms              ┊ GC (median):    29.63%
 Time  (mean ± σ):   48.071 ms ±  8.882 ms  ┊ GC (mean ± σ):  26.62% ± 15.01%

  ▁▆▁▁ ▁                ▁▃▃▆  ▆ █ █▆▁▃▁ ▁   ▆
  ████▄█▇▇▇▁▁▁▁▁▁▁▁▇▄▁▇▇████▄▁█▇█▄█████▇█▁▁▁█▄▄▁▇▇▄▁▁▇▁▁▁▁▁▁▄ ▄
  32.4 ms         Histogram: frequency by time        67.7 ms <

 Memory estimate: 58.82 MiB, allocs estimate: 124.

julia> @benchmark Zygote.gradient($objective ∘ $unpack, $flat_θ_oilmm)
BenchmarkTools.Trial: 24 samples with 1 evaluation.
 Range (min … max):  182.655 ms … 380.152 ms  ┊ GC (min … max):  8.02% … 55.93%
 Time  (median):     196.132 ms               ┊ GC (median):    21.91%
 Time  (mean ± σ):   211.914 ms ±  46.889 ms  ┊ GC (mean ± σ):  25.77% ±  9.87%

   ▃ █▁ ▃
  ▄█▁██▄█▁▁▄▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▄ ▁
  183 ms           Histogram: frequency by time          380 ms <

 Memory estimate: 279.88 MiB, allocs estimate: 46104.

Would be good to know if you're seeing different performance.

@thomasgudjonwright (Member)

If it isn't the construction of the initial Orthogonal (at least its orthogonality) that is causing the NaNs in hyperopt, we should figure out what is. @willtebbutt, is there something obvious to you that could be causing this? If not, I can look into it. I am on a family trip until Friday, but will have lots of time starting then to do that (as well as finally getting to writing proper examples, including now one with hyperopt!).

My first thought is the noise added to the Diagonal, but I would have thought that obtaining initial values via svd should be fairly stable. Maybe that assumption is wrong.

@wil-j-wil (Author)

@willtebbutt your example works, thanks for that. Trying it out again now, it seems I just had a dodgy initialisation: sometimes it would result in NaNs and sometimes not. I should also have scaled the inputs down, as that was causing issues with the Matern kernel.

And the timing seems fine too now - I think I was using a session with a bunch of other stuff loaded.

Thanks both. And yes, an example notebook would be great!

wil-j-wil reopened this on Jun 8, 2022
@wil-j-wil (Author)

Ah, I just realised why I was no longer seeing the computational issue. I was using a RationalQuadraticKernel() previously. I switched to a MaternKernel() to create the example above, and things sped up. If I now switch back to RationalQuadraticKernel(), then Zygote.gradient() crashes my session (I started a new environment and added the latest package versions).

I think that ForwardDiff would be much faster in this case, but it doesn't play nicely with ParameterHandling.jl (even using the PR here, the orthogonal constraint doesn't work with ForwardDiff).
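For reference, the forward-mode attempt is just the analogue of the Zygote call above (a sketch; it fails at the orthogonal transform):

using ForwardDiff

# Forward-mode analogue of the Zygote call above. With the `orthogonal`
# constraint in θ this currently fails, since the SVD-based transform
# isn't ForwardDiff-compatible.
ForwardDiff.gradient(objective ∘ unpack, flat_θ_oilmm)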

This is likely a bit of a blocker in terms of using LinearMixingModels now, since for all our use cases / interesting examples we need to use more interesting kernels.

@willtebbutt (Member)

Ah, interesting. This sounds like a problem with the RationalQuadraticKernel rather than anything in LinearMixingModels.jl, so fixing it is a good thing to do more generally. I'll try to reproduce locally.

@willtebbutt (Member) commented Jun 8, 2022

Short-term hack: add this to your code, and call Zygote.refresh() if you don't want to start a new session:

using ChainRulesCore

# `only(x)` returns the single element of `x`; the pullback wraps the
# cotangent back up in a one-element vector.
function ChainRulesCore.rrule(::typeof(only), x::Vector{<:Real})
    only_pullback(Ω) = NoTangent(), [Ω]
    return only(x), only_pullback
end

Long story short, Zygote appears to not be very good at differentiating through only, which sucks. Probably we need to put this rule upstream in ChainRules at some point.
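A quick way to check that the rule is being picked up (a sketch):

# After defining the rrule, refresh Zygote's caches and differentiate a toy call:
Zygote.refresh()
Zygote.gradient(only, [2.0])  # ([1.0],), via the hand-written pullback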

edit: hmmm, I may have spoken too soon -- still taking over 1s to compute the gradient of the log marginal likelihood.

@willtebbutt (Member)

Re-implementing kernelmatrix for RationalQuadraticKernel in a way that is more AD-friendly seems to do the job. With

using KernelFunctions: pairwise, metric

function KernelFunctions.kernelmatrix(
    k::RationalQuadraticKernel, x::AbstractVector, y::AbstractVector
)
    D² = pairwise(metric(k), x, y)
    α = only(k.α)
    return (one(eltype(D²)) .+ D² ./ (2 * α)).^(-α)
end

function KernelFunctions.kernelmatrix(
    k::RationalQuadraticKernel, x::AbstractVector
)
    D² = pairwise(metric(k), x)
    α = only(k.α)
    return (one(eltype(D²)) .+ D² ./ (2 * α)).^(-α)
end

I'm seeing timings along the lines of:

julia> @benchmark $objective($unpack($flat_θ_oilmm))
BenchmarkTools.Trial: 57 samples with 1 evaluation.
 Range (min … max):  78.129 ms … 114.812 ms  ┊ GC (min … max):  0.00% … 31.26%
 Time  (median):     79.065 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   87.947 ms ±  11.295 ms  ┊ GC (mean ± σ):  10.17% ± 10.79%

  █▆
  ██▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁▃▃▄▄▆▄▁▃▁▁▁▁▄▁▁▃▄▁▁▁▃▃▃▁▁▁▁▁▁▁▁▁▁▃ ▁
  78.1 ms         Histogram: frequency by time          114 ms <

 Memory estimate: 58.82 MiB, allocs estimate: 125.

julia> @benchmark Zygote.gradient($objective ∘ $unpack, $flat_θ_oilmm)
BenchmarkTools.Trial: 22 samples with 1 evaluation.
 Range (min … max):  207.531 ms … 256.803 ms  ┊ GC (min … max):  7.72% … 23.06%
 Time  (median):     226.406 ms               ┊ GC (median):    15.23%
 Time  (mean ± σ):   227.271 ms ±  11.943 ms  ┊ GC (mean ± σ):  14.32% ±  4.29%

  ▁     █   █  ▁ █ ▁ ▁  ▁▁    ▁ ▁█▁▁    ▁   ▁  ▁              ▁
  █▁▁▁▁▁█▁▁▁█▁▁█▁█▁█▁█▁▁██▁▁▁▁█▁████▁▁▁▁█▁▁▁█▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  208 ms           Histogram: frequency by time          257 ms <

 Memory estimate: 294.53 MiB, allocs estimate: 46033.
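As a quick sanity check (a sketch), the specialised methods agree with evaluating the kernel elementwise:

using KernelFunctions

k = RationalQuadraticKernel()
x = randn(20)
kernelmatrix(k, x) ≈ k.(x, x')  # true, up to floating-point error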

@wil-j-wil (Author)

That's awesome, thanks a lot. Is this issue quite pervasive across many kernels from KernelFunctions.jl?

As another example, using a PeriodicTransform on a kernel with a lengthscale also slows things down a fair bit:
with_lengthscale(Matern12Kernel(), θ.len) ∘ PeriodicTransform(1 / period)
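A self-contained version of that construction, with placeholder values standing in for θ.len and period:

using KernelFunctions

# Placeholder hyperparameters (not the values from my actual model).
len, period = 0.5, 12.0
k_per = with_lengthscale(Matern12Kernel(), len) ∘ PeriodicTransform(1 / period)
K = kernelmatrix(k_per, collect(1.0:100.0))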

@willtebbutt (Member)

> That's awesome, thanks a lot. Is this issue quite pervasive across many kernels from KernelFunctions.jl?

Hopefully not? We don't have good systematic performance testing at the minute, unfortunately, so it's hard to say much.

I haven't been able to get code to run with the above kernel -- we really need to enable a jitter term inside some of the computations.

@willtebbutt (Member)

@wil-j-wil is this resolved now that we've improved KernelFunctions' performance?

@wil-j-wil (Author)

Yes, this is all sorted now (noting the slightly awkward SVD issue in the orthogonality constraint).
