Merge pull request #9 from rossviljoen/base_implementation
Base implementation of SVGP
Showing 14 changed files with 925 additions and 0 deletions.
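To orient the reader, here is a minimal sketch of how the API introduced by this PR fits together, assembled from the call sites in the examples below. The signatures of SparseGPs.elbo and SparseGPs.approx_posterior are taken from those examples; the toy data and hyperparameters are illustrative only.

using AbstractGPs, SparseGPs, Distributions, LinearAlgebra

x = rand(100)                           # toy inputs
y = sin.(2π .* x) .+ 0.1 .* randn(100)  # toy targets
z = x[1:10]                             # inducing inputs
f = GP(SqExponentialKernel())
fx = f(x, 0.1)   # FiniteGP at the training inputs (0.1 = observation noise)
fu = f(z, 1e-6)  # FiniteGP at the inducing inputs (1e-6 = jitter)
q = MvNormal(zeros(10), Matrix{Float64}(I, 10, 10))  # variational q(u)

SparseGPs.elbo(fx, y, fu, q)                      # evidence lower bound
post = SparseGPs.approx_posterior(SVGP(), fu, q)  # approximate posterior GP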
@@ -0,0 +1 @@
style = "blue"
@@ -0,0 +1,17 @@
name = "SparseGPs"
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
authors = ["Ross Viljoen <[email protected]>"]
version = "0.1.0"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
KLDivergences = "3c9cd921-3d3f-41e2-830c-e020174918cc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -0,0 +1,138 @@
# Recreation of https://gpflow.readthedocs.io/en/master/notebooks/basics/classification.html

# %%
using SparseGPs
using AbstractGPs
using GPLikelihoods
using StatsFuns
using FastGaussQuadrature
using Distributions
using LinearAlgebra
using DelimitedFiles
using IterTools

using Plots
default(; legend=:outertopright, size=(700, 400))

using Random
Random.seed!(1234)

# %%
# Read in the classification data
data_file = pkgdir(SparseGPs) * "/examples/data/classif_1D.csv"
x, y = eachcol(readdlm(data_file))
scatter(x, y)
# %%
# First, create the GP kernel from the given parameters k. The softplus
# transform keeps the kernel variance and inverse lengthscale positive during
# unconstrained optimisation.
function make_kernel(k)
    return softplus(k[1]) * (SqExponentialKernel() ∘ ScaleTransform(softplus(k[2])))
end

k = [10, 0.1]

kernel = make_kernel(k)
f = LatentGP(GP(kernel), BernoulliLikelihood(), 0.1)
fx = f(x)
# %%
# Then, plot some samples from the prior underlying GP
x_plot = 0:0.02:6
prior_f_samples = rand(f.f(x_plot, 1e-6), 20)

plt = plot(x_plot, prior_f_samples; seriescolor="red", linealpha=0.2, label="")
scatter!(plt, x, y; seriescolor="blue", label="Data points")

# %%
# Plot the same samples, but pushed through a logistic sigmoid to constrain
# them in (0, 1).
prior_y_samples = mean.(f.lik.(prior_f_samples))

plt = plot(x_plot, prior_y_samples; seriescolor="red", linealpha=0.2, label="")
scatter!(plt, x, y; seriescolor="blue", label="Data points")
# %%
# A simple Flux model
using Flux

struct SVGPModel
    k  # kernel parameters
    m  # variational mean
    A  # square-root factor of the variational covariance (S = A'A)
    z  # inducing inputs
end

Flux.@functor SVGPModel (k, m, A)  # Register only k, m, A as trainable - don't train the inducing inputs

lik = BernoulliLikelihood()
jitter = 1e-4

# Create the 'model' from the parameters - i.e. return the LatentGP at inputs x,
# the FiniteGP at the inducing inputs z, and the variational posterior over
# inducing points, q(u).
function (m::SVGPModel)(x)
    kernel = make_kernel(m.k)
    f = LatentGP(GP(kernel), lik, jitter)
    q = MvNormal(m.m, m.A'm.A)  # A'A guarantees a valid (PSD) covariance
    fx = f(x)
    fu = f(m.z).fx  # want the underlying FiniteGP
    return fx, fu, q
end

# Return the loss given data - in this case the negative ELBO. The Bernoulli
# likelihood has no closed-form expected log-likelihood, so the ELBO is
# estimated by Monte Carlo.
function flux_loss(x, y; n_data=length(y))
    fx, fu, q = model(x)
    return -SparseGPs.elbo(fx, y, fu, q; n_data, method=MonteCarlo())
end
# %%
M = 15 # number of inducing points

# Initialise the parameters
k = [10, 0.1]
m = zeros(M)
A = Matrix{Float64}(I, M, M)
z = x[1:M]

model = SVGPModel(k, m, A, z)

opt = ADAM(0.1)
parameters = Flux.params(model)

# %%
# Negative ELBO before training
println(flux_loss(x, y))
# %%
# Train the model
Flux.train!(
    (x, y) -> flux_loss(x, y),
    parameters,
    ncycle([(x, y)], 2000), # Train for 2000 epochs
    opt,
)

# %%
# Negative ELBO after training
println(flux_loss(x, y))
# %%
# After optimisation, plot samples from the underlying posterior GP. Rebuild
# the LatentGP from the trained parameters: the `kernel` constructed before
# training captured the initial values of k, so it must be remade here.
kernel = make_kernel(model.k)
f = LatentGP(GP(kernel), lik, jitter)
fu = f(z).fx # want the underlying FiniteGP
post = SparseGPs.approx_posterior(SVGP(), fu, MvNormal(model.m, model.A' * model.A))
l_post = LatentGP(post, BernoulliLikelihood(), jitter)
post_f_samples = rand(l_post.f(x_plot, 1e-6), 20)

plt = plot(x_plot, post_f_samples; seriescolor="red", linealpha=0.2, legend=false)

# %%
# As above, push these samples through a logistic sigmoid to get posterior predictions.
post_y_samples = mean.(l_post.lik.(post_f_samples))

plt = plot(
    x_plot,
    post_y_samples;
    seriescolor="red",
    linealpha=0.2,
    # legend=false,
    label="",
)
scatter!(plt, x, y; seriescolor="blue", label="Data points")
vline!(z; label="Pseudo-points")
@@ -0,0 +1,50 @@
5.668341708542713242e+00 0.000000000000000000e+00
5.758793969849246075e+00 0.000000000000000000e+00
5.517587939698492150e+00 0.000000000000000000e+00
2.954773869346733584e+00 1.000000000000000000e+00
3.648241206030150785e+00 1.000000000000000000e+00
2.110552763819095290e+00 1.000000000000000000e+00
4.613065326633165597e+00 0.000000000000000000e+00
4.793969849246231263e+00 0.000000000000000000e+00
4.703517587939698430e+00 0.000000000000000000e+00
6.030150753768843686e-01 1.000000000000000000e+00
3.015075376884421843e-01 0.000000000000000000e+00
3.979899497487437099e+00 0.000000000000000000e+00
3.226130653266331638e+00 1.000000000000000000e+00
1.899497487437185939e+00 1.000000000000000000e+00
1.145728643216080256e+00 1.000000000000000000e+00
3.316582914572864249e-01 0.000000000000000000e+00
6.030150753768843686e-01 1.000000000000000000e+00
2.231155778894472252e+00 1.000000000000000000e+00
3.256281407035175768e+00 1.000000000000000000e+00
1.085427135678391997e+00 1.000000000000000000e+00
1.809045226130653106e+00 1.000000000000000000e+00
4.492462311557789079e+00 0.000000000000000000e+00
1.959798994974874198e+00 1.000000000000000000e+00
0.000000000000000000e+00 0.000000000000000000e+00
3.346733668341708601e+00 1.000000000000000000e+00
1.507537688442210921e-01 0.000000000000000000e+00
1.809045226130653328e-01 1.000000000000000000e+00
5.517587939698492150e+00 0.000000000000000000e+00
2.201005025125628123e+00 1.000000000000000000e+00
5.577889447236180409e+00 0.000000000000000000e+00
1.809045226130653328e-01 0.000000000000000000e+00
1.688442211055276365e+00 1.000000000000000000e+00
4.160804020100502321e+00 0.000000000000000000e+00
2.170854271356783993e+00 1.000000000000000000e+00
4.311557788944723413e+00 0.000000000000000000e+00
3.075376884422110546e+00 1.000000000000000000e+00
5.125628140703517133e+00 0.000000000000000000e+00
1.989949748743718549e+00 1.000000000000000000e+00
5.366834170854271058e+00 0.000000000000000000e+00
4.100502512562814061e+00 0.000000000000000000e+00
7.236180904522613311e-01 1.000000000000000000e+00
2.261306532663316382e+00 1.000000000000000000e+00
3.467336683417085119e+00 1.000000000000000000e+00
1.085427135678391997e+00 1.000000000000000000e+00
5.095477386934673447e+00 0.000000000000000000e+00
5.185929648241205392e+00 0.000000000000000000e+00
2.743718592964823788e+00 1.000000000000000000e+00
2.773869346733668362e+00 1.000000000000000000e+00
1.417085427135678311e+00 1.000000000000000000e+00
1.989949748743718549e+00 1.000000000000000000e+00
@@ -0,0 +1,174 @@
# A recreation of https://gpflow.readthedocs.io/en/master/notebooks/advanced/gps_for_big_data.html

using AbstractGPs
using SparseGPs
using Distributions
using LinearAlgebra
using Optim
using IterTools

using Plots
default(; legend=:outertopright, size=(700, 400))

using Random
Random.seed!(1234)

# %%
# The data generating function
function g(x)
    return sin(3π * x) + 0.3 * cos(9π * x) + 0.5 * sin(7π * x)
end

N = 10000 # Number of training points
x = rand(Uniform(-1, 1), N)
y = g.(x) + 0.3 * randn(N)

scatter(x, y; xlabel="x", ylabel="y", legend=false)
# %%
# A simple Flux model
using Flux

lik_noise = 0.3
jitter = 1e-5

struct SVGPModel
    k  # kernel parameters
    m  # variational mean
    A  # square-root factor of the variational covariance (S = A'A)
    z  # inducing points
end

Flux.@functor SVGPModel (k, m, A)  # Don't train the inducing inputs

# Construct a kernel from the (unconstrained) parameters k; softplus keeps the
# variance and inverse lengthscale positive.
function make_kernel(k)
    return softplus(k[1]) * (SqExponentialKernel() ∘ ScaleTransform(softplus(k[2])))
end
# Create the 'model' from the parameters - i.e. return the FiniteGP at inputs x,
# the FiniteGP at inducing inputs z and the variational posterior over inducing
# points - q(u).
function (m::SVGPModel)(x)
    kernel = make_kernel(m.k)
    f = GP(kernel)
    q = MvNormal(m.m, m.A'm.A)
    fx = f(x, lik_noise)
    fu = f(m.z, jitter)
    return fx, fu, q
end

# Create the posterior GP from the model parameters.
function posterior(m::SVGPModel)
    kernel = make_kernel(m.k)
    f = GP(kernel)
    fu = f(m.z, jitter)
    q = MvNormal(m.m, m.A'm.A)
    return SparseGPs.approx_posterior(SVGP(), fu, q)
end

# Return the loss given data - in this case the negative ELBO. When training on
# minibatches, pass the full dataset size as `n_data` so that the likelihood
# term of the ELBO is rescaled to an unbiased estimate of the full-data value.
function flux_loss(x, y; n_data=length(y))
    fx, fu, q = model(x)
    return -SparseGPs.elbo(fx, y, fu, q; n_data)
end
# %%
M = 50 # number of inducing points

# Select the first M inputs as inducing inputs
z = x[1:M]

# Initialise the parameters
k = [0.3, 10]
m = zeros(M)
A = Matrix{Float64}(I, M, M)

model = SVGPModel(k, m, A, z)

b = 100 # minibatch size
opt = ADAM(0.001)
parameters = Flux.params(model)
data_loader = Flux.Data.DataLoader((x, y); batchsize=b)

# %%
# Negative ELBO before training
println(flux_loss(x, y))
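# For reference, a single manual optimisation step on one minibatch looks like
# this (a sketch of what Flux.train! below does for every batch in the loader;
# the (xb, yb) destructuring assumes the tuple layout passed to DataLoader):
xb, yb = first(data_loader)
grads = Flux.gradient(() -> flux_loss(xb, yb; n_data=N), parameters)
Flux.Optimise.update!(opt, parameters, grads)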
# %%
# Train the model
Flux.train!(
    (x, y) -> flux_loss(x, y; n_data=N),
    parameters,
    ncycle(data_loader, 300), # Train for 300 epochs
    opt,
)

# %%
# Negative ELBO after training
println(flux_loss(x, y))
# %%
# Plot the optimised approximate posterior.
post = posterior(model)

scatter(
    x,
    y;
    markershape=:xcross,
    markeralpha=0.1,
    xlim=(-1, 1),
    xlabel="x",
    ylabel="y",
    title="posterior (VI with sparse grid)",
    label="Train Data",
)
plot!(-1:0.001:1, post; label="Posterior")
vline!(z; label="Pseudo-points")
# %% There is a closed form optimal solution for the variational posterior q(u)
# (e.g. https://krasserm.github.io/2020/12/12/gaussian-processes-sparse/
# equations (11) & (12)). The SVGP posterior with this optimal q(u) should
# therefore be equivalent to the 'exact' sparse GP (Titsias) posterior.
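# Concretely (a sketch of those equations in the notation of `exact_q` below,
# with observation noise σ², Kuu = cov(fu) and Kuf = cov(fu, fx)):
#
#   Σ  = Kuu + σ⁻² Kuf Kuf'
#   m* = σ⁻² Kuu Σ⁻¹ Kuf y
#   S* = Kuu Σ⁻¹ Kuu
#
# giving the optimal q*(u) = N(m*, S*), which is what `exact_q` computes.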
function exact_q(fu, fx, y)
    σ² = fx.Σy[1]
    Kuf = cov(fu, fx)
    Kuu = Symmetric(cov(fu))
    Σ = Symmetric(Kuu + (1 / σ²) * Kuf * Kuf')
    m = ((1 / σ²) * Kuu * (Σ \ Kuf)) * y
    S = Symmetric(Kuu * (Σ \ Kuu))
    return MvNormal(m, S)
end
kernel = make_kernel([0.3, 10])
f = GP(kernel)
fx = f(x, lik_noise)
fu = f(z, jitter)
q_ex = exact_q(fu, fx, y)

scatter(x, y)
scatter!(z, q_ex.μ)
# These two should be the same - and they are, as the plot below shows
ap_ex = SparseGPs.approx_posterior(SVGP(), fu, q_ex) # Hensman (2013) exact posterior
ap_tits = AbstractGPs.approx_posterior(VFE(), fx, y, fu) # Titsias posterior
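# A quick numerical sanity check of that claim (a sketch: `mean` and `cov` of a
# FiniteGP are standard AbstractGPs API; the test grid and tolerance here are
# arbitrary choices):
x_test = collect(-1:0.1:1)
println(isapprox(mean(ap_ex(x_test)), mean(ap_tits(x_test)); rtol=1e-4))
println(isapprox(cov(ap_ex(x_test)), cov(ap_tits(x_test)); rtol=1e-4))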
# These are also approximately equal (printed so both show in a script run)
println(SparseGPs.elbo(fx, y, fu, q_ex))
println(AbstractGPs.elbo(fx, y, fu))
# %%
scatter(
    x,
    y;
    markershape=:xcross,
    markeralpha=0.1,
    xlim=(-1, 1),
    xlabel="x",
    ylabel="y",
    title="posterior (VI with sparse grid)",
    label="Train Data",
)
plot!(-1:0.001:1, ap_ex; label="SVGP posterior")
plot!(-1:0.001:1, ap_tits; label="Titsias posterior")
vline!(z; label="Pseudo-points")