diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 1296ebe0..5dcad3d4 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -21,7 +21,7 @@ jobs:
       - uses: actions/checkout@v2
       - uses: julia-actions/setup-julia@v1
         with:
-          version: '1'
+          version: '1.11'
       - uses: julia-actions/julia-buildpkg@v1
       - uses: julia-actions/julia-docdeploy@v1
         env:
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 975ccb9b..f6589c00 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -18,8 +18,8 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - "1.9"
-          - "1"
+          - "lts"
+          - "1.11"
         test_suite:
           - "Standard"
           - "HMMBase"
diff --git a/Project.toml b/Project.toml
index a840b69e..d02d6d6b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "HiddenMarkovModels"
 uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
 authors = ["Guillaume Dalle"]
-version = "0.7.0"
+version = "0.7.1"

 [deps]
 ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
@@ -33,4 +33,4 @@ Random = "1"
 SparseArrays = "1"
 StatsAPI = "1.6"
 StatsFuns = "1.3"
-julia = "1.9"
+julia = "1.10"
diff --git a/docs/Project.toml b/docs/Project.toml
index 917d9f78..b24ed4e6 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -1,4 +1,5 @@
 [deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
@@ -12,6 +13,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
 LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
 Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
+Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
diff --git a/docs/make.jl b/docs/make.jl
index e9d2ffb1..8f6a1fc0 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -51,6 +51,7 @@ pages = [
         joinpath("examples", "controlled.md"),
         joinpath("examples", "autoregression.md"),
         joinpath("examples", "autodiff.md"),
+        joinpath("examples", "gradientdescent.md"),
     ],
     "API reference" => "api.md",
     "Advanced" => ["alternatives.md", "debugging.md", "formulas.md"],
diff --git a/examples/gradientdescent.jl b/examples/gradientdescent.jl
new file mode 100644
index 00000000..ff7365bf
--- /dev/null
+++ b/examples/gradientdescent.jl
@@ -0,0 +1,248 @@
# # Gradient Descent in HMMs

#=
In this tutorial we explore two ways to use gradient descent when fitting HMMs:

1. Fitting parameters of an observation model that do not have closed-form updates
   (e.g., GLMs, neural networks, etc.), inside the EM algorithm.
2. Fitting the entire HMM with gradient-based optimization by leveraging automatic
   differentiation.

We will explore both approaches below.
=#

using ADTypes
using ComponentArrays
using DensityInterface
using ForwardDiff
using HiddenMarkovModels
using HMMTest #src
using LinearAlgebra
using Optim
using Random
using StableRNGs
using StatsAPI
using Test #src

rng = StableRNG(42)

#=
For both parts of this tutorial we use a simple HMM with Gaussian observations.
Using gradient-based optimization here is overkill, but it keeps the tutorial
simple while illustrating the relevant methods.

We begin by defining a Normal observation model.
+=# + +mutable struct NormalModel{T} + μ::T + logσ::T # unconstrained parameterization; σ = exp(logσ) +end + +model_mean(mod::NormalModel) = mod.μ +stddev(mod::NormalModel) = exp(mod.logσ) + +#= +We have defined a simple probability model with two parameters: the mean and the +log of the standard deviation. Using `logσ` is intentional so we can optimize over +all real numbers without worrying about the positivity constraint on `σ`. + +Next, we provide the minimal interface expected by HiddenMarkovModels.jl: +`(logdensityof, rand, fit!)`. +=# + +function DensityInterface.logdensityof(mod::NormalModel, obs::T) where {T<:Real} + s = stddev(mod) + return - log(2π) / 2 - log(s) - ((obs - model_mean(mod)) / s)^2 / 2 +end + +DensityInterface.DensityKind(::NormalModel) = DensityInterface.HasDensity() + +function Random.rand(rng::AbstractRNG, mod::NormalModel{T}) where {T} + return stddev(mod) * randn(rng, T) + model_mean(mod) +end + +#= +Because we are fitting a Gaussian (and the variance can collapse to ~0), we add +weak priors to regularize the parameters. We use: +- A weak Normal prior on `μ` +- A moderate-strength Normal prior on `logσ` that pulls `σ` toward ~1 +=# + +const μ_prior = NormalModel(0.0, log(10.0)) +const logσ_prior = NormalModel(log(1.0), log(0.5)) + +function neglogpost( + μ::T, + logσ::T, + data::AbstractVector{<:Real}, + weights::AbstractVector{<:Real}, + μ_prior::NormalModel, + logσ_prior::NormalModel, +) where {T<:Real} + tmp = NormalModel(μ, logσ) + + nll = mapreduce( + i -> -weights[i] * logdensityof(tmp, data[i]), +, eachindex(data, weights) + ) + + nll += -logdensityof(μ_prior, μ) + nll += -logdensityof(logσ_prior, logσ) + + return nll +end + +function neglogpost( + θ::AbstractVector{T}, + data::AbstractVector{<:Real}, + weights::AbstractVector{<:Real}, + μ_prior::NormalModel, + logσ_prior::NormalModel, +) where {T<:Real} + μ, logσ = θ + return neglogpost(μ, logσ, data, weights, μ_prior, logσ_prior) +end + +function StatsAPI.fit!( + mod::NormalModel, data::AbstractVector{<:Real}, weights::AbstractVector{<:Real} +) + T = promote_type(typeof(mod.μ), typeof(mod.logσ)) + θ0 = T[T(mod.μ), T(mod.logσ)] + obj = θ -> neglogpost(θ, data, weights, μ_prior, logσ_prior) + result = Optim.optimize(obj, θ0, BFGS(); autodiff=AutoForwardDiff()) + mod.μ, mod.logσ = Optim.minimizer(result) + return mod +end + +#= +Now that we have fully defined our observation model, we can create an HMM using it. +=# + +init_dist = [0.2, 0.7, 0.1] +init_trans = [ + 0.9 0.05 0.05; + 0.075 0.9 0.025; + 0.1 0.1 0.8 +] + +obs_dists = [ + NormalModel(-3.0, log(0.25)), NormalModel(0.0, log(0.5)), NormalModel(3.0, log(0.75)) +] + +hmm_true = HMM(init_dist, init_trans, obs_dists) + +#= +We can now generate data from this HMM. +Note: `rand(rng, hmm, T)` returns `(state_seq, obs_seq)`. +=# + +state_seq, obs_seq = rand(rng, hmm_true, 10_000) + +#= +Next we fit a new HMM to this data. Baum–Welch will perform EM updates for the +HMM parameters; during the M-step, our observation model parameters are fit via +gradient-based optimization (BFGS). +=# + +init_dist_guess = fill(1.0 / 3, 3) +init_trans_guess = [ + 0.98 0.01 0.01; + 0.01 0.98 0.01; + 0.01 0.01 0.98 +] + +obs_dist_guess = [ + NormalModel(-2.0, log(1.0)), NormalModel(2.0, log(1.0)), NormalModel(0.0, log(1.0)) +] + +hmm_guess = HMM(init_dist_guess, init_trans_guess, obs_dist_guess) + +hmm_est, lls = baum_welch(hmm_guess, obs_seq) + +#= +Great! We were able to fit the model using gradient descent inside EM. 
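=#

#=
As a quick sanity check (an illustrative addition, not part of the original
tutorial), we can compare the observation parameters recovered by EM with the
ones used to simulate the data. This assumes the estimated states come out in
the same order as the true ones; in general, hidden states are only identified
up to a permutation of the labels.
=#

em_params = [(model_mean(d), stddev(d)) for d in hmm_est.dists]     # (μ, σ) per state after EM
true_params = [(model_mean(d), stddev(d)) for d in hmm_true.dists]  # (μ, σ) per state used for simulation
collect(zip(em_params, true_params))

#=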
+ +Now we will fit the entire HMM using gradient-based optimization by leveraging +automatic differentiation. The key idea is that the forward algorithm marginalizes +out the latent states, providing the likelihood of the observations directly as a +function of all model parameters. + +We can therefore optimize the negative log-likelihood returned by `forward`. +Each objective evaluation runs the forward algorithm, which can be expensive for +large datasets, but this approach allows end-to-end gradient-based fitting for +arbitrary parameterized HMMs. + +To respect HMM constraints, we optimize unconstrained parameters and map them to +valid probability distributions via softmax: +- `π = softmax(ηπ)` +- each row of `A` = `softmax(row_logits)` +=# + +function softmax(v::AbstractVector) + m = maximum(v) + ex = exp.(v .- m) + return ex ./ sum(ex) +end + +function rowsoftmax(M::AbstractMatrix) + A = similar(M) + for i in 1:size(M, 1) + A[i, :] .= softmax(view(M, i, :)) + end + return A +end + +function unpack_to_hmm(θ::ComponentVector) + K = length(θ.ηπ) + + π = softmax(θ.ηπ) + A = rowsoftmax(θ.ηA) + dists = [NormalModel(θ.μ[k], θ.logσ[k]) for k in 1:K] + + return HMM(π, A, dists) +end + +function hmm_to_θ0(hmm::HMM) + K = length(hmm.init) + + T = promote_type( + eltype(hmm.init), + eltype(hmm.trans), + eltype(hmm.dists[1].μ), + eltype(hmm.dists[1].logσ), + ) + + ηπ = log.(hmm.init .+ eps(T)) + ηA = log.(hmm.trans .+ eps(T)) + + μ = [hmm.dists[k].μ for k in 1:K] + logσ = [hmm.dists[k].logσ for k in 1:K] + + return ComponentVector(; ηπ=ηπ, ηA=ηA, μ=μ, logσ=logσ) +end + +function negloglik_from_θ(θ::ComponentVector, obs_seq) + hmm = unpack_to_hmm(θ) + _, loglik = forward(hmm, obs_seq; error_if_not_finite=false) + return -loglik[1] +end + +θ0 = hmm_to_θ0(hmm_guess) +ax = getaxes(θ0) + +obj(x) = negloglik_from_θ(ComponentVector(x, ax), obs_seq) + +result = Optim.optimize(obj, Vector(θ0), BFGS(); autodiff=AutoForwardDiff()) +hmm_est2 = unpack_to_hmm(ComponentVector(result.minimizer, ax)) + +#= +We have now trained an HMM using gradient-based optimization over *all* parameters! 
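=#

#=
As an illustrative check (not in the original source), we can inspect what the
optimizer recovered: the transition matrix and the per-state observation
parameters of `hmm_est2`. These should be close to the values found by EM above.
=#

hmm_est2.trans

[(model_mean(d), stddev(d)) for d in hmm_est2.dists]  # (μ, σ) per state from the gradient-based fit

#=
The `#src` lines below are consistency checks comparing the EM fit with the
gradient-based fit.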
=#

@test isapprox(hmm_est.init, hmm_est2.init; atol=1e-3) #src
@test isapprox(hmm_est.trans, hmm_est2.trans; atol=1e-3) #src

for k in 1:length(hmm_est.init) #src
    @test isapprox(hmm_est.dists[k].μ, hmm_est2.dists[k].μ; atol=1e-3) #src
    @test isapprox(stddev(hmm_est.dists[k]), stddev(hmm_est2.dists[k]); atol=1e-3) #src
end #src
diff --git a/test/Project.toml b/test/Project.toml
index 80cb0053..9cdbb314 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,4 +1,5 @@
 [deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
@@ -12,6 +13,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
 LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
 Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
+Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
diff --git a/test/runtests.jl b/test/runtests.jl
index 6e8dc362..c6ebe4c5 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -48,7 +48,7 @@ end
     using Distributions
     using Zygote
    if VERSION >= v"1.10"
-        JET.test_package(HiddenMarkovModels; target_defined_modules=true)
+        JET.test_package(HiddenMarkovModels; target_modules=(HiddenMarkovModels,))
     end
 end