2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
@@ -21,7 +21,7 @@ jobs:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: '1'
version: '1.11'
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-docdeploy@v1
env:
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
@@ -18,8 +18,8 @@ jobs:
fail-fast: false
matrix:
version:
- "1.9"
- "1"
- "lts"
- "1.11"
test_suite:
- "Standard"
- "HMMBase"
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "HiddenMarkovModels"
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
authors = ["Guillaume Dalle"]
version = "0.7.0"
version = "0.7.1"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
@@ -33,4 +33,4 @@ Random = "1"
SparseArrays = "1"
StatsAPI = "1.6"
StatsFuns = "1.3"
julia = "1.9"
julia = "1.10"
2 changes: 2 additions & 0 deletions docs/Project.toml
@@ -1,4 +1,5 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
@@ -12,6 +13,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1 change: 1 addition & 0 deletions docs/make.jl
@@ -51,6 +51,7 @@ pages = [
joinpath("examples", "controlled.md"),
joinpath("examples", "autoregression.md"),
joinpath("examples", "autodiff.md"),
joinpath("examples", "gradientdescent.md"),
],
"API reference" => "api.md",
"Advanced" => ["alternatives.md", "debugging.md", "formulas.md"],
248 changes: 248 additions & 0 deletions examples/gradientdescent.jl
@@ -0,0 +1,248 @@
# # Gradient Descent in HMMs

#=
In this tutorial we explore two ways to use gradient descent when fitting HMMs:

1. Fitting the parameters of an observation model that lacks closed-form updates
(e.g., GLMs or neural networks) inside the EM algorithm.
2. Fitting the entire HMM with gradient-based optimization by leveraging automatic
differentiation.

We start with the first approach and then turn to the second.
=#

using ADTypes
using ComponentArrays
using DensityInterface
using ForwardDiff
using HiddenMarkovModels
using HMMTest #src
using LinearAlgebra
using Optim
using Random
using StableRNGs
using StatsAPI
using Test #src

rng = StableRNG(42)

#=
For both parts of this tutorial we use a simple HMM with Gaussian observations.
Gradient-based optimization is overkill here, since Gaussian observation models have
closed-form EM updates, but it keeps the tutorial simple while illustrating the
relevant methods.

We begin by defining a Normal observation model.
=#

mutable struct NormalModel{T}
μ::T
logσ::T # unconstrained parameterization; σ = exp(logσ)
end

model_mean(mod::NormalModel) = mod.μ
stddev(mod::NormalModel) = exp(mod.logσ)

#=
We have defined a simple probability model with two parameters: the mean and the
log of the standard deviation. Using `logσ` is intentional so we can optimize over
all real numbers without worrying about the positivity constraint on `σ`.

Next, we provide the minimal interface HiddenMarkovModels.jl expects from a custom
observation model: `logdensityof`, `rand`, and `fit!`.
=#

function DensityInterface.logdensityof(mod::NormalModel, obs::T) where {T<:Real}
s = stddev(mod)
return - log(2π) / 2 - log(s) - ((obs - model_mean(mod)) / s)^2 / 2
end
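
#=
For reference, this is just the Gaussian log-density with `σ = exp(logσ)`:
`log p(x | μ, σ) = -log(2π)/2 - log(σ) - (x - μ)² / (2σ²)`.
=#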

DensityInterface.DensityKind(::NormalModel) = DensityInterface.HasDensity()

function Random.rand(rng::AbstractRNG, mod::NormalModel{T}) where {T}
return stddev(mod) * randn(rng, T) + model_mean(mod)
end

#=
Because we are fitting a Gaussian (and the variance can collapse to ~0), we add
weak priors to regularize the parameters. We use:
- A weak Normal prior on `μ`
- A moderate-strength Normal prior on `logσ` that pulls `σ` toward ~1
=#

const μ_prior = NormalModel(0.0, log(10.0))
const logσ_prior = NormalModel(log(1.0), log(0.5))

function neglogpost(
μ::T,
logσ::T,
data::AbstractVector{<:Real},
weights::AbstractVector{<:Real},
μ_prior::NormalModel,
logσ_prior::NormalModel,
) where {T<:Real}
tmp = NormalModel(μ, logσ)

nll = mapreduce(
i -> -weights[i] * logdensityof(tmp, data[i]), +, eachindex(data, weights)
)

nll += -logdensityof(μ_prior, μ)
nll += -logdensityof(logσ_prior, logσ)

return nll
end

function neglogpost(
θ::AbstractVector{T},
data::AbstractVector{<:Real},
weights::AbstractVector{<:Real},
μ_prior::NormalModel,
logσ_prior::NormalModel,
) where {T<:Real}
μ, logσ = θ
return neglogpost(μ, logσ, data, weights, μ_prior, logσ_prior)
end

function StatsAPI.fit!(
mod::NormalModel, data::AbstractVector{<:Real}, weights::AbstractVector{<:Real}
)
T = promote_type(typeof(mod.μ), typeof(mod.logσ))
θ0 = T[T(mod.μ), T(mod.logσ)]
obj = θ -> neglogpost(θ, data, weights, μ_prior, logσ_prior)
result = Optim.optimize(obj, θ0, BFGS(); autodiff=AutoForwardDiff())
mod.μ, mod.logσ = Optim.minimizer(result)
return mod
end
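
#=
As a quick standalone sanity check (independent of any HMM, and not required by the
interface), we can fit the model to synthetic i.i.d. data with unit weights and verify
that it roughly recovers the generating mean and standard deviation:
=#

toy_rng = StableRNG(0)
toy_model = NormalModel(0.0, 0.0)
toy_data = 1.0 .+ 2.0 .* randn(toy_rng, 1_000)
StatsAPI.fit!(toy_model, toy_data, ones(length(toy_data)))
(model_mean(toy_model), stddev(toy_model))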

#=
Now that we have fully defined our observation model, we can create an HMM using it.
=#

init_dist = [0.2, 0.7, 0.1]
init_trans = [
0.9 0.05 0.05;
0.075 0.9 0.025;
0.1 0.1 0.8
]

obs_dists = [
NormalModel(-3.0, log(0.25)), NormalModel(0.0, log(0.5)), NormalModel(3.0, log(0.75))
]

hmm_true = HMM(init_dist, init_trans, obs_dists)

#=
We can now generate data from this HMM.
Note: `rand(rng, hmm, T)` returns `(state_seq, obs_seq)`.
=#

state_seq, obs_seq = rand(rng, hmm_true, 10_000)

#=
Next we fit a new HMM to this data. Baum–Welch will perform EM updates for the
HMM parameters; during the M-step, our observation model parameters are fit via
gradient-based optimization (BFGS).
=#

init_dist_guess = fill(1.0 / 3, 3)
init_trans_guess = [
0.98 0.01 0.01;
0.01 0.98 0.01;
0.01 0.01 0.98
]

obs_dist_guess = [
NormalModel(-2.0, log(1.0)), NormalModel(2.0, log(1.0)), NormalModel(0.0, log(1.0))
]

hmm_guess = HMM(init_dist_guess, init_trans_guess, obs_dist_guess)

hmm_est, lls = baum_welch(hmm_guess, obs_seq)
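
#=
The second return value `lls` tracks the log-likelihood across Baum–Welch iterations
(it should be non-decreasing), and the fitted observation parameters can be read off
the estimated HMM directly:
=#

last(lls) - first(lls) # total log-likelihood improvement during EM
[(model_mean(d), stddev(d)) for d in hmm_est.dists]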

#=
Great! We were able to fit the model using gradient descent inside EM.

Now we will fit the entire HMM using gradient-based optimization by leveraging
automatic differentiation. The key idea is that the forward algorithm marginalizes
out the latent states, providing the likelihood of the observations directly as a
function of all model parameters.

We can therefore optimize the negative log-likelihood returned by `forward`.
Each objective evaluation runs the forward algorithm, which can be expensive for
large datasets, but this approach allows end-to-end gradient-based fitting for
arbitrary parameterized HMMs.

To respect HMM constraints, we optimize unconstrained parameters and map them to
valid probability distributions via softmax:
- the initial distribution is `π = softmax(ηπ)`
- each row of the transition matrix `A` is the softmax of the corresponding row of `ηA`
=#

function softmax(v::AbstractVector)
m = maximum(v)
ex = exp.(v .- m)
return ex ./ sum(ex)
end

function rowsoftmax(M::AbstractMatrix)
A = similar(M)
for i in 1:size(M, 1)
A[i, :] .= softmax(view(M, i, :))
end
return A
end

function unpack_to_hmm(θ::ComponentVector)
K = length(θ.ηπ)

π = softmax(θ.ηπ)
A = rowsoftmax(θ.ηA)
dists = [NormalModel(θ.μ[k], θ.logσ[k]) for k in 1:K]

return HMM(π, A, dists)
end

function hmm_to_θ0(hmm::HMM)
K = length(hmm.init)

T = promote_type(
eltype(hmm.init),
eltype(hmm.trans),
eltype(hmm.dists[1].μ),
eltype(hmm.dists[1].logσ),
)

ηπ = log.(hmm.init .+ eps(T))
ηA = log.(hmm.trans .+ eps(T))

μ = [hmm.dists[k].μ for k in 1:K]
logσ = [hmm.dists[k].logσ for k in 1:K]

return ComponentVector(; ηπ=ηπ, ηA=ηA, μ=μ, logσ=logσ)
end
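
#=
As a sanity check, packing `hmm_guess` and then unpacking it should recover the same
initial distribution and transition matrix (up to floating-point error), since applying
softmax to elementwise logs simply renormalizes the probabilities:
=#

hmm_roundtrip = unpack_to_hmm(hmm_to_θ0(hmm_guess))
(hmm_roundtrip.init, hmm_roundtrip.trans)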

function negloglik_from_θ(θ::ComponentVector, obs_seq)
hmm = unpack_to_hmm(θ)
_, loglik = forward(hmm, obs_seq; error_if_not_finite=false)
return -loglik[1]
end

θ0 = hmm_to_θ0(hmm_guess)
ax = getaxes(θ0)

obj(x) = negloglik_from_θ(ComponentVector(x, ax), obs_seq)
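
#=
Before running the optimizer, it is worth checking that the objective returns a finite
value at the starting point:
=#

obj(Vector(θ0))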

result = Optim.optimize(obj, Vector(θ0), BFGS(); autodiff=AutoForwardDiff())
hmm_est2 = unpack_to_hmm(ComponentVector(result.minimizer, ax))

#=
We have now trained an HMM using gradient-based optimization over *all* parameters!
=#
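
#=
Since both fits start from the same initialization, the state ordering matches and we
can compare the fitted means from EM and from direct optimization side by side:
=#

[(model_mean(d1), model_mean(d2)) for (d1, d2) in zip(hmm_est.dists, hmm_est2.dists)]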

@test isapprox(hmm_est.init, hmm_est2.init; atol=1e-3) #src
@test isapprox(hmm_est.trans, hmm_est2.trans; atol=1e-3) #src

for k in 1:length(hmm_est.init) #src
@test isapprox(hmm_est.dists[k].μ, hmm_est2.dists[k].μ; atol=1e-3) #src
@test isapprox(stddev(hmm_est.dists[k]), stddev(hmm_est2.dists[k]); atol=1e-3) #src
end #src
2 changes: 2 additions & 0 deletions test/Project.toml
@@ -1,4 +1,5 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
@@ -12,6 +13,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -48,7 +48,7 @@ end
using Distributions
using Zygote
if VERSION >= v"1.10"
JET.test_package(HiddenMarkovModels; target_defined_modules=true)
JET.test_package(HiddenMarkovModels; target_modules=(HiddenMarkovModels,))
end
end
