Merge pull request #9 from rossviljoen/base_implementation
Base implementation of SVGP
rossviljoen authored Jul 30, 2021
2 parents ff1df6c + ef3292c commit ffa3fa5
Showing 14 changed files with 925 additions and 0 deletions.
1 change: 1 addition & 0 deletions .JuliaFormatter.toml
@@ -0,0 +1 @@
style = "blue"
17 changes: 17 additions & 0 deletions Project.toml
@@ -0,0 +1,17 @@
name = "SparseGPs"
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
authors = ["Ross Viljoen <[email protected]>"]
version = "0.1.0"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
KLDivergences = "3c9cd921-3d3f-41e2-830c-e020174918cc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
138 changes: 138 additions & 0 deletions examples/classification.jl
@@ -0,0 +1,138 @@
# Recreation of https://gpflow.readthedocs.io/en/master/notebooks/basics/classification.html

# %%
using SparseGPs
using AbstractGPs
using GPLikelihoods
using StatsFuns
using FastGaussQuadrature
using Distributions
using LinearAlgebra
using DelimitedFiles
using IterTools

using Plots
default(; legend=:outertopright, size=(700, 400))

using Random
Random.seed!(1234)

# %%
# Read in the classification data
data_file = pkgdir(SparseGPs) * "/examples/data/classif_1D.csv"
x, y = eachcol(readdlm(data_file))
scatter(x, y)

# %%
# First, create the GP kernel from the given parameters k
function make_kernel(k)
    return softplus(k[1]) * (SqExponentialKernel() ∘ ScaleTransform(softplus(k[2])))
end

k = [10, 0.1]

kernel = make_kernel(k)
f = LatentGP(GP(kernel), BernoulliLikelihood(), 0.1)
fx = f(x)

# %%
# Then, plot some samples from the prior of the underlying GP
x_plot = 0:0.02:6
prior_f_samples = rand(f.f(x_plot, 1e-6), 20)

plt = plot(x_plot, prior_f_samples; seriescolor="red", linealpha=0.2, label="")
scatter!(plt, x, y; seriescolor="blue", label="Data points")

# %%
# Plot the same samples, but pushed through a logistic sigmoid to constrain
# them in (0, 1).
prior_y_samples = mean.(f.lik.(prior_f_samples))

plt = plot(x_plot, prior_y_samples; seriescolor="red", linealpha=0.2, label="")
scatter!(plt, x, y; seriescolor="blue", label="Data points")
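
# %%
# A quick sanity check (a minimal sketch, assuming GPLikelihoods' default logistic
# link): the BernoulliLikelihood maps a latent value f to Bernoulli(logistic(f)),
# so its mean is simply logistic(f). That is why `mean.(f.lik.(...))` above yields
# probabilities in (0, 1).
let f_val = 2.0
    @assert mean(BernoulliLikelihood()(f_val)) ≈ logistic(f_val)
end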

# %%
# A simple Flux model
using Flux

struct SVGPModel
    k # kernel parameters
    m # variational mean
    A # variational covariance
    z # inducing points
end

Flux.@functor SVGPModel (k, m, A) # Don't train the inducing inputs

lik = BernoulliLikelihood()
jitter = 1e-4

function (m::SVGPModel)(x)
    kernel = make_kernel(m.k)
    f = LatentGP(GP(kernel), lik, jitter)
    q = MvNormal(m.m, m.A'm.A)
    fx = f(x)
    fu = f(m.z).fx
    return fx, fu, q
end

function flux_loss(x, y; n_data=length(y))
    fx, fu, q = model(x)
    return -SparseGPs.elbo(fx, y, fu, q; n_data, method=MonteCarlo())
end
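# Note on the loss above: the expected log-likelihood of a Bernoulli (logistic)
# likelihood under a Gaussian has no closed form, hence the `method=MonteCarlo()`
# estimator in the call to `SparseGPs.elbo`.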

# %%
M = 15 # number of inducing points

# Initialise the parameters
k = [10, 0.1]
m = zeros(M)
A = Matrix{Float64}(I, M, M)
z = x[1:M]

model = SVGPModel(k, m, A, z)

opt = ADAM(0.1)
parameters = Flux.params(model)
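# Since `Flux.@functor SVGPModel (k, m, A)` marks only those three fields as
# trainable, `parameters` should hold exactly three arrays; the inducing inputs z
# stay fixed during training.
@assert length(collect(parameters)) == 3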

# %%
# Negative ELBO before training
println(flux_loss(x, y))

# %%
# Train the model
Flux.train!(
    (x, y) -> flux_loss(x, y),
    parameters,
    ncycle([(x, y)], 2000), # Train for 2000 epochs
    opt,
)
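
# %%
# An optional variation (a sketch, kept commented out so it does not re-train):
# `Flux.train!` also accepts a `cb` keyword, so the objective can be logged during
# training, throttled to at most once every few seconds.
# Flux.train!(
#     (x, y) -> flux_loss(x, y),
#     parameters,
#     ncycle([(x, y)], 2000),
#     opt;
#     cb=Flux.throttle(() -> println("negative ELBO: ", flux_loss(x, y)), 5),
# )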

# %%
# Negative ELBO after training
println(flux_loss(x, y))

# %%
# After optimisation, plot samples from the underlying posterior GP.
kernel = make_kernel(k) # k (along with m and A) has been updated in place by training
f = LatentGP(GP(kernel), BernoulliLikelihood(), jitter)
fu = f(z).fx # want the underlying FiniteGP
post = SparseGPs.approx_posterior(SVGP(), fu, MvNormal(m, A'A))
l_post = LatentGP(post, BernoulliLikelihood(), jitter)

post_f_samples = rand(l_post.f(x_plot, 1e-6), 20)

plt = plot(x_plot, post_f_samples; seriescolor="red", linealpha=0.2, legend=false)

# %%
# As above, push these samples through a logistic sigmoid to get posterior predictions.
post_y_samples = mean.(l_post.lik.(post_f_samples))

plt = plot(
    x_plot,
    post_y_samples;
    seriescolor="red",
    linealpha=0.2,
    # legend=false,
    label="",
)
scatter!(plt, x, y; seriescolor="blue", label="Data points")
vline!(z; label="Pseudo-points")
50 changes: 50 additions & 0 deletions examples/data/classif_1D.csv
@@ -0,0 +1,50 @@
5.668341708542713242e+00 0.000000000000000000e+00
5.758793969849246075e+00 0.000000000000000000e+00
5.517587939698492150e+00 0.000000000000000000e+00
2.954773869346733584e+00 1.000000000000000000e+00
3.648241206030150785e+00 1.000000000000000000e+00
2.110552763819095290e+00 1.000000000000000000e+00
4.613065326633165597e+00 0.000000000000000000e+00
4.793969849246231263e+00 0.000000000000000000e+00
4.703517587939698430e+00 0.000000000000000000e+00
6.030150753768843686e-01 1.000000000000000000e+00
3.015075376884421843e-01 0.000000000000000000e+00
3.979899497487437099e+00 0.000000000000000000e+00
3.226130653266331638e+00 1.000000000000000000e+00
1.899497487437185939e+00 1.000000000000000000e+00
1.145728643216080256e+00 1.000000000000000000e+00
3.316582914572864249e-01 0.000000000000000000e+00
6.030150753768843686e-01 1.000000000000000000e+00
2.231155778894472252e+00 1.000000000000000000e+00
3.256281407035175768e+00 1.000000000000000000e+00
1.085427135678391997e+00 1.000000000000000000e+00
1.809045226130653106e+00 1.000000000000000000e+00
4.492462311557789079e+00 0.000000000000000000e+00
1.959798994974874198e+00 1.000000000000000000e+00
0.000000000000000000e+00 0.000000000000000000e+00
3.346733668341708601e+00 1.000000000000000000e+00
1.507537688442210921e-01 0.000000000000000000e+00
1.809045226130653328e-01 1.000000000000000000e+00
5.517587939698492150e+00 0.000000000000000000e+00
2.201005025125628123e+00 1.000000000000000000e+00
5.577889447236180409e+00 0.000000000000000000e+00
1.809045226130653328e-01 0.000000000000000000e+00
1.688442211055276365e+00 1.000000000000000000e+00
4.160804020100502321e+00 0.000000000000000000e+00
2.170854271356783993e+00 1.000000000000000000e+00
4.311557788944723413e+00 0.000000000000000000e+00
3.075376884422110546e+00 1.000000000000000000e+00
5.125628140703517133e+00 0.000000000000000000e+00
1.989949748743718549e+00 1.000000000000000000e+00
5.366834170854271058e+00 0.000000000000000000e+00
4.100502512562814061e+00 0.000000000000000000e+00
7.236180904522613311e-01 1.000000000000000000e+00
2.261306532663316382e+00 1.000000000000000000e+00
3.467336683417085119e+00 1.000000000000000000e+00
1.085427135678391997e+00 1.000000000000000000e+00
5.095477386934673447e+00 0.000000000000000000e+00
5.185929648241205392e+00 0.000000000000000000e+00
2.743718592964823788e+00 1.000000000000000000e+00
2.773869346733668362e+00 1.000000000000000000e+00
1.417085427135678311e+00 1.000000000000000000e+00
1.989949748743718549e+00 1.000000000000000000e+00
174 changes: 174 additions & 0 deletions examples/regression.jl
@@ -0,0 +1,174 @@
# A recreation of https://gpflow.readthedocs.io/en/master/notebooks/advanced/gps_for_big_data.html

using AbstractGPs
using SparseGPs
using Distributions
using LinearAlgebra
using Optim
using IterTools

using Plots
default(; legend=:outertopright, size=(700, 400))

using Random
Random.seed!(1234)

# %%
# The data generating function
function g(x)
    return sin(3π * x) + 0.3 * cos(9π * x) + 0.5 * sin(7π * x)
end

N = 10000 # Number of training points
x = rand(Uniform(-1, 1), N)
y = g.(x) + 0.3 * randn(N)

scatter(x, y; xlabel="x", ylabel="y", legend=false)

# %%
# A simple Flux model
using Flux

lik_noise = 0.3
jitter = 1e-5

struct SVGPModel
    k # kernel parameters
    m # variational mean
    A # variational covariance
    z # inducing points
end

Flux.@functor SVGPModel (k, m, A) # Don't train the inducing inputs

function make_kernel(k)
    return softplus(k[1]) * (SqExponentialKernel() ∘ ScaleTransform(softplus(k[2])))
end

# Create the 'model' from the parameters - i.e. return the FiniteGP at inputs x,
# the FiniteGP at inducing inputs z and the variational posterior over inducing
# points - q(u).
function (m::SVGPModel)(x)
    kernel = make_kernel(m.k)
    f = GP(kernel)
    q = MvNormal(m.m, m.A'm.A)
    fx = f(x, lik_noise)
    fu = f(m.z, jitter)
    return fx, fu, q
end

# Create the posterior GP from the model parameters.
function posterior(m::SVGPModel)
    kernel = make_kernel(m.k)
    f = GP(kernel)
    fu = f(m.z, jitter)
    q = MvNormal(m.m, m.A'm.A)
    return SparseGPs.approx_posterior(SVGP(), fu, q)
end

# Return the loss given data - in this case the negative ELBO.
function flux_loss(x, y; n_data=length(y))
    fx, fu, q = model(x)
    return -SparseGPs.elbo(fx, y, fu, q; n_data)
end
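# Passing `n_data` is what makes minibatching work here: the ELBO rescales the
# minibatch log-likelihood term by n_data / batchsize, so each minibatch yields a
# stochastic estimate of the full-data objective.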

# %%
M = 50 # number of inducing points

# Select the first M inputs as inducing inputs
z = x[1:M]

# Initialise the parameters
k = [0.3, 10]
m = zeros(M)
A = Matrix{Float64}(I, M, M)

model = SVGPModel(k, m, A, z)

b = 100 # minibatch size
opt = ADAM(0.001)
parameters = Flux.params(model)
data_loader = Flux.Data.DataLoader((x, y); batchsize=b)
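# With N = 10_000 points and batchsize = 100, the loader yields 100 minibatches per
# epoch, so the 300 epochs of training below correspond to 30_000 gradient steps.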

# %%
# Negative ELBO before training
println(flux_loss(x, y))

# %%
# Train the model
Flux.train!(
    (x, y) -> flux_loss(x, y; n_data=N),
    parameters,
    ncycle(data_loader, 300), # Train for 300 epochs
    opt,
)

# %%
# Negative ELBO after training
println(flux_loss(x, y))

# %%
# Plot samples from the optimised approximate posterior.
post = posterior(model)

scatter(
    x,
    y;
    markershape=:xcross,
    markeralpha=0.1,
    xlim=(-1, 1),
    xlabel="x",
    ylabel="y",
    title="posterior (VI with sparse grid)",
    label="Train Data",
)
plot!(-1:0.001:1, post; label="Posterior")
vline!(z; label="Pseudo-points")

# %%
# There is a closed-form optimal solution for the variational posterior q(u)
# (e.g. https://krasserm.github.io/2020/12/12/gaussian-processes-sparse/
# equations (11) & (12)). The SVGP posterior with this optimal q(u) should
# therefore be equivalent to the 'exact' sparse GP (Titsias) posterior.
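#
# Concretely, in this example's notation (noise variance σ², Kuu = cov(fu),
# Kuf = cov(fu, fx)), those equations read:
#
#     Σ  = Kuu + σ⁻² Kuf Kufᵀ
#     m* = σ⁻² Kuu Σ⁻¹ Kuf y
#     S* = Kuu Σ⁻¹ Kuu
#
# and q(u) = N(m*, S*), which is what `exact_q` below computes.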

function exact_q(fu, fx, y)
    σ² = fx.Σy[1]
    Kuf = cov(fu, fx)
    Kuu = Symmetric(cov(fu))
    Σ = Symmetric(cov(fu) + (1 / σ²) * Kuf * Kuf')
    m = ((1 / σ²) * Kuu * (Σ \ Kuf)) * y
    S = Symmetric(Kuu * (Σ \ Kuu))
    return MvNormal(m, S)
end

kernel = make_kernel([0.3, 10])
f = GP(kernel)
fx = f(x, lik_noise)
fu = f(z, jitter)
q_ex = exact_q(fu, fx, y)

scatter(x, y)
scatter!(z, q_ex.μ)

# These two should be the same - and they are, as the plot below shows
ap_ex = SparseGPs.approx_posterior(SVGP(), fu, q_ex) # Hensman (2013) exact posterior
ap_tits = AbstractGPs.approx_posterior(VFE(), fx, y, fu) # Titsias posterior

# These are also approximately equal
SparseGPs.elbo(fx, y, fu, q_ex)
AbstractGPs.elbo(fx, y, fu)
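# A sketch of a numerical check (the tolerance is an arbitrary choice, since both
# quantities are computed in floating point):
# @assert isapprox(SparseGPs.elbo(fx, y, fu, q_ex), AbstractGPs.elbo(fx, y, fu); rtol=1e-5)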

# %%
scatter(
    x,
    y;
    markershape=:xcross,
    markeralpha=0.1,
    xlim=(-1, 1),
    xlabel="x",
    ylabel="y",
    title="posterior (VI with sparse grid)",
    label="Train Data",
)
plot!(-1:0.001:1, ap_ex; label="SVGP posterior")
plot!(-1:0.001:1, ap_tits; label="Titsias posterior")
vline!(z; label="Pseudo-points")