Skip to content

Commit

Permalink
Clarify API for GP approximations (#361)
Browse files Browse the repository at this point in the history
* API docstrings for the base forms of `posterior` and `approx_log_evidence`

* DTC as separate type

* ExactInference for fallback forms

* deprecations

* bump version to 0.5.17
  • Loading branch information
st-- authored May 8, 2023
1 parent 3e5f0a5 commit 1b7c135
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 153 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AbstractGPs"
uuid = "99985d1d-32ba-4be9-9821-2ec096f28918"
authors = ["JuliaGaussianProcesses Team"]
version = "0.5.16"
version = "0.5.17"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
1 change: 1 addition & 0 deletions src/AbstractGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export rand!,
mean_vector,
marginals,
logpdf,
approx_log_evidence,
elbo,
dtc,
posterior,
Expand Down
27 changes: 27 additions & 0 deletions src/abstract_gp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,30 @@ for (m, f) in [
)
end
end

"""
    approx_log_evidence(approx::<Approximation>, fx::FiniteGP, ys)
    approx_log_evidence(approx::<Approximation>, lfx::LatentFiniteGP, ys)

Compute an approximation to the log of the marginal likelihood (also known as
"evidence") under the given `approx`imation to the posterior. The return value
of `approx_log_evidence` can be used to optimise the hyperparameters of `fx` or `lfx`.
"""
function approx_log_evidence end

"""
    posterior(fx::FiniteGP, y::AbstractVector{<:Real})
    posterior(approx::<Approximation>, fx::FiniteGP, y::AbstractVector{<:Real})
    posterior(approx::<Approximation>, lfx::LatentFiniteGP, y::AbstractVector)

Construct the posterior distribution over the latent Gaussian process (`fx.f`
or `lfx.fx.f`), given the observations `y` corresponding to the process's
finite projection (`fx` or `lfx`).

In the two-argument form, this describes exact GP regression with `y` observed
under a Gaussian likelihood, and returns a `PosteriorGP`.

In the three-argument form, the first argument specifies the approximation to
be used (e.g. `VFE`, or an approximation defined in another package such as
ApproximateGPs.jl), and an `ApproxPosteriorGP` is returned. The exception is
`ExactInference`, whose three-argument form simply forwards to the two-argument
form.
"""
function posterior end
3 changes: 3 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@
@deprecate sampleplot!(plt::RecipesBase.AbstractPlot, gp::FiniteGP, n::Int; kwargs...) sampleplot!(
plt, gp; samples=n, kwargs...
)

# Before `DTC` became its own type it was an alias for `VFE`, so the objective was
# selected by the function name (`elbo` or `dtc`), not by the approximation type.
# Each deprecated method therefore rewraps the pseudo-points `fz` in the type whose
# `approx_log_evidence` method computes the objective the caller asked for by name;
# forwarding the argument unchanged would silently return the other objective.
@deprecate elbo(dtc::DTC, fx, y) approx_log_evidence(VFE(dtc.fz), fx, y)
@deprecate dtc(vfe::Union{VFE,DTC}, fx, y) approx_log_evidence(DTC(vfe.fz), fx, y)
10 changes: 9 additions & 1 deletion src/exact_gpr_posterior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,18 @@ struct PosteriorGP{Tprior,Tdata} <: AbstractGP
data::Tdata
end

"""
    ExactInference()

Marker type that selects exact inference in the three-argument forms of `posterior`
and `approx_log_evidence`. Both methods forward to their exact counterparts:
`posterior(fx, y)` and `logpdf(fx, y)` respectively.
"""
struct ExactInference end

# Three-argument form: ignore the marker and delegate to exact GP regression.
function posterior(::ExactInference, fx::FiniteGP, y::AbstractVector{<:Real})
    return posterior(fx, y)
end

# With exact inference the "approximate" log evidence is `logpdf(fx, y)` itself.
function approx_log_evidence(
    ::ExactInference, fx::FiniteGP, y::AbstractVector{<:Real}
)
    return logpdf(fx, y)
end

"""
posterior(fx::FiniteGP, y::AbstractVector{<:Real})
Construct the posterior distribution over `fx.f` given observations `y` at `x` made under
Construct the posterior distribution over `fx.f` given observations `y` at `fx.x` made under
noise `fx.Σy`. This is another `AbstractGP` object. See chapter 2 of [1] for a recap on
exact inference in GPs. This posterior process has mean function
```julia
Expand Down
69 changes: 44 additions & 25 deletions src/sparse_approximations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,14 @@ struct VFE{Tfz<:FiniteGP}
fz::Tfz
end

const DTC = VFE
"""
    DTC(fz::FiniteGP)

Deterministic Training Conditional approximation. `fz` is the finite projection of the
prior at the inducing points, which are its inputs `fz.x`.

Similar to `VFE`, but uses a different objective for `approx_log_evidence`: `VFE`
subtracts a trace correction term from the DTC objective, whereas `DTC` returns the DTC
objective unchanged. The approximate posterior is constructed by the same `posterior`
method as for `VFE`.
"""
struct DTC{Tfz<:FiniteGP}
fz::Tfz
end

struct ApproxPosteriorGP{Tapprox,Tprior,Tdata} <: AbstractGP
approx::Tapprox
Expand Down Expand Up @@ -48,7 +55,7 @@ true
processes". In: Proceedings of the Twelfth International Conference on Artificial
Intelligence and Statistics. 2009.
"""
function posterior(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
function posterior(vfe::Union{VFE,DTC}, fx::FiniteGP, y::AbstractVector{<:Real})
@assert vfe.fz.f === fx.f

U_y = _cholesky(_symmetric(fx.Σy)).U
Expand All @@ -69,7 +76,7 @@ end

"""
function update_posterior(
f_post_approx::ApproxPosteriorGP{<:VFE},
f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}},
fx::FiniteGP,
y::AbstractVector{<:Real}
)
Expand All @@ -78,7 +85,9 @@ Update the `ApproxPosteriorGP` given a new set of observations. Here, we retain
set of pseudo-points.
"""
function update_posterior(
f_post_approx::ApproxPosteriorGP{<:VFE}, fx::FiniteGP, y::AbstractVector{<:Real}
f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}},
fx::FiniteGP,
y::AbstractVector{<:Real},
)
@assert f_post_approx.prior === fx.f

Expand Down Expand Up @@ -111,14 +120,14 @@ end

"""
function update_posterior(
f_post_approx::ApproxPosteriorGP{<:VFE},
f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}},
z::FiniteGP,
)
Update the `ApproxPosteriorGP` given a new set of pseudo-points to append to the existing
set of pseudo-points.
"""
function update_posterior(f_post_approx::ApproxPosteriorGP{<:VFE}, fz::FiniteGP)
function update_posterior(f_post_approx::ApproxPosteriorGP{<:Union{VFE,DTC}}, fz::FiniteGP)
@assert f_post_approx.prior === fz.f

z_old = inducing_points(f_post_approx)
Expand Down Expand Up @@ -161,48 +170,56 @@ function update_posterior(f_post_approx::ApproxPosteriorGP{<:VFE}, fz::FiniteGP)
x=f_post_approx.data.x,
Σy=f_post_approx.data.Σy,
)
return ApproxPosteriorGP(VFE(fz_new), f_post_approx.prior, cache)
return ApproxPosteriorGP(
_update_approx(f_post_approx.approx, fz_new), f_post_approx.prior, cache
)
end

# Rebuild the approximation around the new pseudo-points `fz_new`, keeping the same
# approximation type as the one being replaced.
function _update_approx(vfe::VFE, fz_new::FiniteGP)
    return VFE(fz_new)
end

function _update_approx(dtc::DTC, fz_new::FiniteGP)
    return DTC(fz_new)
end

# AbstractGP interface implementation.

function Statistics.mean(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector)
function Statistics.mean(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
return mean(f.prior, x) + cov(f.prior, x, inducing_points(f)) * f.data.α
end

function Statistics.cov(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector)
function Statistics.cov(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
A = f.data.U' \ cov(f.prior, inducing_points(f), x)
return cov(f.prior, x) - At_A(A) + Xt_invA_X(f.data.Λ_ε, A)
end

function Statistics.var(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector)
function Statistics.var(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
A = f.data.U' \ cov(f.prior, inducing_points(f), x)
return var(f.prior, x) - diag_At_A(A) + diag_Xt_invA_X(f.data.Λ_ε, A)
end

function Statistics.cov(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector, y::AbstractVector)
function Statistics.cov(
f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector, y::AbstractVector
)
A_zx = f.data.U' \ cov(f.prior, inducing_points(f), x)
A_zy = f.data.U' \ cov(f.prior, inducing_points(f), y)
return cov(f.prior, x, y) - A_zx'A_zy + Xt_invA_Y(A_zx, f.data.Λ_ε, A_zy)
end

function StatsBase.mean_and_cov(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector)
function StatsBase.mean_and_cov(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
A = f.data.U' \ cov(f.prior, inducing_points(f), x)
m_post = mean(f.prior, x) + A' * f.data.m_ε
C_post = cov(f.prior, x) - At_A(A) + Xt_invA_X(f.data.Λ_ε, A)
return m_post, C_post
end

function StatsBase.mean_and_var(f::ApproxPosteriorGP{<:VFE}, x::AbstractVector)
function StatsBase.mean_and_var(f::ApproxPosteriorGP{<:Union{VFE,DTC}}, x::AbstractVector)
A = f.data.U' \ cov(f.prior, inducing_points(f), x)
m_post = mean(f.prior, x) + A' * f.data.m_ε
c_post = var(f.prior, x) - diag_At_A(A) + diag_Xt_invA_X(f.data.Λ_ε, A)
return m_post, c_post
end

inducing_points(f::ApproxPosteriorGP{<:VFE}) = f.approx.fz.x
inducing_points(f::ApproxPosteriorGP{<:Union{VFE,DTC}}) = f.approx.fz.x

"""
approx_log_evidence(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
elbo(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
The Titsias Evidence Lower BOund (ELBO) [1]. `y` are observations of `fx`, and `vfe.fz.x`
Expand All @@ -228,14 +245,16 @@ true
processes". In: Proceedings of the Twelfth International Conference on Artificial
Intelligence and Statistics. 2009.
"""
function elbo(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
function approx_log_evidence(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
@assert vfe.fz.f === fx.f
_dtc, A = _compute_intermediates(fx, y, vfe.fz)
return _dtc - (tr_Cf_invΣy(fx, fx.Σy) - sum(abs2, A)) / 2
dtc_objective, A = _compute_intermediates(fx, y, vfe.fz)
return dtc_objective - (tr_Cf_invΣy(fx, fx.Σy) - sum(abs2, A)) / 2
end

# `elbo` is kept as a name for the `VFE` objective; it forwards to `approx_log_evidence`.
function elbo(vfe::VFE, fx, y)
    return approx_log_evidence(vfe, fx, y)
end

"""
dtc(v::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
approx_log_evidence(dtc::DTC, fx::FiniteGP, y::AbstractVector{<:Real})
The Deterministic Training Conditional (DTC) [1]. `y` are observations of `fx`, and
`dtc.fz.x` are the inducing points.
Expand All @@ -248,25 +267,25 @@ julia> x = randn(1000);
julia> z = range(-5.0, 5.0; length=256);
julia> v = VFE(f(z));
julia> d = DTC(f(z));
julia> y = rand(f(x, 0.1));
julia> isapprox(dtc(v, f(x, 0.1), y), logpdf(f(x, 0.1), y); atol=1e-6, rtol=1e-6)
julia> isapprox(approx_log_evidence(d, f(x, 0.1), y), logpdf(f(x, 0.1), y); atol=1e-6, rtol=1e-6)
true
```
[1] - M. Seeger, C. K. I. Williams and N. D. Lawrence. "Fast Forward Selection to Speed Up
Sparse Gaussian Process Regression". In: Proceedings of the Ninth International Workshop on
Artificial Intelligence and Statistics. 2003
"""
function dtc(vfe::VFE, fx::FiniteGP, y::AbstractVector{<:Real})
@assert vfe.fz.f === fx.f
_dtc, _ = _compute_intermediates(fx, y, vfe.fz)
return _dtc
function approx_log_evidence(dtc::DTC, fx::FiniteGP, y::AbstractVector{<:Real})
@assert dtc.fz.f === fx.f
dtc_objective, _ = _compute_intermediates(fx, y, dtc.fz)
return dtc_objective
end

# Factor out computations common to the `elbo` and `dtc`.
# Factor out computations of `approx_log_evidence` common to `VFE` and `DTC`
function _compute_intermediates(fx::FiniteGP, y::AbstractVector{<:Real}, fz::FiniteGP)
length(fx) == length(y) || throw(
DimensionMismatch(
Expand Down
Loading

2 comments on commit 1b7c135

@st--
Copy link
Member Author

@st-- st-- commented on 1b7c135 May 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/83109

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.17 -m "<description of version>" 1b7c13563107257640437c72587499e16fbf65af
git push origin v0.5.17

Please sign in to comment.