From 50721bcfdc47ab261d486506aa6f92ac2c6dc84b Mon Sep 17 00:00:00 2001 From: Ross Viljoen Date: Sat, 30 Apr 2022 13:56:30 +0100 Subject: [PATCH 1/3] Relax elbo type signature and factor homoscadasticity check --- src/SparseVariationalApproximationModule.jl | 22 ++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/SparseVariationalApproximationModule.jl b/src/SparseVariationalApproximationModule.jl index 2a134b4b..914cfbc6 100644 --- a/src/SparseVariationalApproximationModule.jl +++ b/src/SparseVariationalApproximationModule.jl @@ -306,26 +306,17 @@ Statistics. PMLR, 2015. """ function AbstractGPs.elbo( sva::SparseVariationalApproximation, - fx::FiniteGP{<:AbstractGP,<:AbstractVector,<:Union{Diagonal{<:Real,<:Fill},ScalMat}}, + fx::FiniteGP, y::AbstractVector{<:Real}; num_data=length(y), quadrature=DefaultExpectationMethod(), ) + σ² = _get_homoscedastic_noise(fx.Σy) σ² = fx.Σy[1] lik = GaussianLikelihood(σ²) return elbo(sva, LatentFiniteGP(fx, lik), y; num_data, quadrature) end -function AbstractGPs.elbo( - ::SparseVariationalApproximation, ::FiniteGP, ::AbstractVector; kwargs... -) - return error( - "The observation noise fx.Σy must be homoscedastic.\n", - "To avoid this error, construct fx using: f = GP(kernel); fx = f(x, σ²)", - ", where σ² is a positive Real.", - ) -end - """ elbo( sva::SparseVariationalApproximation, @@ -372,4 +363,13 @@ function _prior_kl(sva::SparseVariationalApproximation{NonCentered}) return (trace_term + m_ε'm_ε - length(m_ε) - logdet(C_ε)) / 2 end +_get_homoscedastic_noise(Σy::Union{Diagonal{<:Real,<:Fill},ScalMat}) = Σy[1] +function _get_homoscedastic_noise(_) + return error( + "The observation noise fx.Σy must be homoscedastic.\n", + "To avoid this error, construct fx using: f = GP(kernel); fx = f(x, σ²)", + ", where σ² is a positive Real.\n" + ) +end + end From 40195b45c973d34ef0b2ce6a2e426d97bfe00a2d Mon Sep 17 00:00:00 2001 From: Ross Viljoen Date: Sat, 30 Apr 2022 14:01:07 +0100 Subject: [PATCH 2/3] Remove old line --- src/SparseVariationalApproximationModule.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/SparseVariationalApproximationModule.jl b/src/SparseVariationalApproximationModule.jl index 914cfbc6..5494f969 100644 --- a/src/SparseVariationalApproximationModule.jl +++ b/src/SparseVariationalApproximationModule.jl @@ -312,7 +312,6 @@ function AbstractGPs.elbo( quadrature=DefaultExpectationMethod(), ) σ² = _get_homoscedastic_noise(fx.Σy) - σ² = fx.Σy[1] lik = GaussianLikelihood(σ²) return elbo(sva, LatentFiniteGP(fx, lik), y; num_data, quadrature) end From 894c51f5c8242bb572a9e2476dcc8ec239ecb0ab Mon Sep 17 00:00:00 2001 From: Ross Viljoen Date: Sat, 30 Apr 2022 14:11:53 +0100 Subject: [PATCH 3/3] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/SparseVariationalApproximationModule.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SparseVariationalApproximationModule.jl b/src/SparseVariationalApproximationModule.jl index 5494f969..bde2daf3 100644 --- a/src/SparseVariationalApproximationModule.jl +++ b/src/SparseVariationalApproximationModule.jl @@ -367,7 +367,7 @@ function _get_homoscedastic_noise(_) return error( "The observation noise fx.Σy must be homoscedastic.\n", "To avoid this error, construct fx using: f = GP(kernel); fx = f(x, σ²)", - ", where σ² is a positive Real.\n" + ", where σ² is a positive Real.\n", ) end