Skip to content

Commit

Permalink
Fix script (#129)
Browse files Browse the repository at this point in the history
  • Loading branch information
theogf authored Jul 20, 2024
1 parent 4133697 commit de7eb7f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
20 changes: 14 additions & 6 deletions examples/categorical/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ liks = [
AugmentedGPLikelihoods.nlatent(::CategoricalLikelihood{<:BijectiveSimplexLink}) = Nclass - 1
AugmentedGPLikelihoods.nlatent(::CategoricalLikelihood{<:LogisticSoftMaxLink}) = Nclass

# This is type piracy, but it makes life so much easier ;)
SplitApplyCombine.invert(x::ArrayOfSimilarArrays) = nestedview(flatview(x)')
# We define the models
kernel = 5.0 * with_lengthscale(SqExponentialKernel(), 2.0)
Expand Down Expand Up @@ -134,14 +135,21 @@ plot(p_plts...)
# How can one compute the Augmented ELBO?
# Again AugmentedGPLikelihoods provides helper functions
# to not have to compute everything yourself
function aug_elbo(lik, u_post, x, y)
qf = marginals(u_post(x))
= aux_posterior(lik, y, qf)
return expected_logtilt(lik, qΩ, y, qf) - aux_kldivergence(lik, qΩ, y) -
kldivergence(u_post.approx.q, u_post.approx.fz)
function aug_elbo(lik, u_posts, x, y)
qfs = marginals.([post(x) for post in u_posts])
= aux_posterior(lik, y, invert(qfs))
return expected_logtilt(lik, qΩ, y, invert(qfs)) - aux_kldivergence(lik, qΩ, y) -
sum(post -> kldivergence(post.approx.q, post.approx.fz), u_posts)
end

# aug_elbo(lik, u_posterior(fz, m, S), x, y)
# However the ELBO is non-valid to compute for the bijective likelihood,
# we only test it there.
let (lik, (; m, S), y) = (liks[1], ms_Ss[1], Ys[1])
u_posts = u_posterior.(Ref(fz), m, S)
qfs = marginals.([post(x) for post in u_posts])
= aux_posterior(lik, y, invert(qfs))
aug_elbo(lik, u_posterior.(Ref(fz), m, S), x, y)
end
# ## Gibbs Sampling
# We create our Gibbs sampling algorithm (we could do something fancier with
# AbstractMCMC)
Expand Down
2 changes: 1 addition & 1 deletion src/SpecialDistributions/negativemultinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ struct NegativeMultinomial{Tx₀<:Real,Tp<:AbstractVector} <:
x₀ > 0 || throw(ArgumentError("x₀ has to be positive"))
(all(>=(0), p) && sum(p) < 1) || throw(
ArgumentError(
"All p should be positive and their sum should be strictly smaller than 1",
"All p should be positive and their sum should be strictly smaller than 1, got $(p).",
),
)
return new{typeof(x₀),typeof(p)}(x₀, p)
Expand Down
7 changes: 7 additions & 0 deletions src/likelihoods/categorical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ function aux_prior(lik::LogisticSoftMaxLikelihood, y::AbstractVector{<:Integer})
)
end

function aux_kldivergence(::LogisticSoftMaxLikelihood, qΩ::For, pΩ::For)
return error(
"due to the ill-conditioning of the augmented prior of the `LogisticSoftMaxLink` (non-bijective), the kl-divergence cannot be computed, use the `BijectiveLogisticSoftMaxLink` instead.
It might be possible to provide a more global workaround, but it is not listed as a priority for the package right now.",
)
end

function expected_logtilt(
::BijectiveLogisticSoftMaxLikelihood, qω, y, qf::AbstractVector{<:Normal}
)
Expand Down

2 comments on commit de7eb7f

@theogf
Copy link
Member Author

@theogf theogf commented on de7eb7f Jul 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: Version 0.4.18 already exists

Please sign in to comment.