Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Will Tebbutt <[email protected]>
  • Loading branch information
samanklesaria and willtebbutt authored Apr 1, 2024
1 parent c906218 commit 8f5d288
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 19 deletions.
3 changes: 1 addition & 2 deletions src/ApproximateGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ include("LaplaceApproximationModule.jl")
build_laplace_objective, build_laplace_objective!

include("NearestNeighborsModule.jl")
@reexport using .NearestNeighborsModule:
NearestNeighbors
@reexport using .NearestNeighborsModule: NearestNeighbors

include("deprecations.jl")

Expand Down
26 changes: 11 additions & 15 deletions src/NearestNeighborsModule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ChainRulesCore
using KernelFunctions, LinearAlgebra, SparseArrays, AbstractGPs

"""
Constructs the matrix ``B`` for which ``f = Bf + \epsilon`` were ``f``
Constructs the matrix ``B`` for which ``f = Bf + \epsilon`` where ``f``
are the values of the GP and ``\epsilon`` is a vector of zero mean
independent Gaussian noise.
This matrix builds the conditional mean for function value ``f_i``
Expand All @@ -21,11 +21,11 @@ function make_B(pts::AbstractVector{T}, k::Int, kern::Kernel) where {T}
end

function make_rows(pts::AbstractVector{T}, k::Int, kern::Kernel) where {T}
[make_row(kern, pts[max(1, i-k):i-1], pts[i]) for i in 2:length(pts)]
return [make_row(kern, pts[max(1, i-k):i-1], pts[i]) for i in 2:length(pts)]
end

function make_row(kern::Kernel, ns::AbstractVector{T}, p::T) where {T}
kernelmatrix(kern,ns) \ kern.(ns, p)
return kernelmatrix(kern,ns) \ kern.(ns, p)
end

function make_js(rows, k)
Expand All @@ -35,17 +35,15 @@ function make_js(rows, k)
end for (row, i) in zip(rows, 2:(length(rows)+1))]
end

function make_is(js)
[fill(i, length(col_ix)) for (col_ix, i) in zip(js, 2:(length(js)+1))]
end
make_is(js) = [fill(i, length(col_ix)) for (col_ix, i) in zip(js, 2:(length(js)+1))]

"""
Constructs the diagonal covariance matrix for noise vector ``\epsilon``
for which ``f = Bf + \epsilon``.
See equation (10) of (Datta, A. Nearest neighbor sparse Cholesky
matrices in spatial statistics. 2022).
"""
function make_F(pts::AbstractVector{T}, k::Int, kern::Kernel) where {T}
function make_F(pts::AbstractVector, k::Int, kern::Kernel)
n = length(pts)
vals = [
begin
Expand Down Expand Up @@ -88,13 +86,11 @@ AbstractGPs.diag_Xt_invA_X(A::InvRoot, X::AbstractVecOrMat) = AbstractGPs.diag_A

AbstractGPs.Xt_invA_X(A::InvRoot, X::AbstractVecOrMat) = AbstractGPs.At_A(A.U' * X)

"""
Make a sparse approximation of the square root of the precision matrix
"""
# Make a sparse approximation of the square root of the precision matrix
function approx_root_prec(x::AbstractVector, k::Int, kern::Kernel)
F = make_F(x, k, kern)
B = make_B(x, k, kern)
UpperTriangular((I - B)' * inv(sqrt(F)))
return UpperTriangular((I - B)' * inv(sqrt(F)))
end

function AbstractGPs.posterior(nn::NearestNeighbors, fx::AbstractGPs.FiniteGP, y::AbstractVector)
Expand All @@ -107,10 +103,10 @@ function AbstractGPs.posterior(nn::NearestNeighbors, fx::AbstractGPs.FiniteGP, y
end

function API.approx_lml(nn::NearestNeighbors, fx::AbstractGPs.FiniteGP, y::AbstractVector)
post = posterior(nn, fx, y)
quadform = post.data.α' * post.data.δ
ld = logdet(post.data.C)
return -0.5 * ld -(length(y)/2) * log(2 * pi) - 0.5 * quadform
post = posterior(nn, fx, y)
quadform = post.data.α' * post.data.δ
ld = logdet(post.data.C)
return -(ld + length(y) * eltype(y)(log2π) + quadform) / 2
end

end
3 changes: 1 addition & 2 deletions test/NearestNeighborsModule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
y = sin.(x)

@testset "Using all neighbors is the same as the exact GP" begin
opt_pred = mean_and_cov(posterior(NearestNeighbors(length(x) - 1),
fx, y)(x2))
opt_pred = mean_and_cov(posterior(NearestNeighbors(length(x) - 1), fx, y)(x2))
pred = mean_and_cov(posterior(fx, y)(x2))
for i in 1:2
@test all(isapprox.(opt_pred[i], pred[i]; atol=1e-4))
Expand Down

0 comments on commit 8f5d288

Please sign in to comment.