add new section in HMM tutorial #508

Open · wants to merge 5 commits into base: master

Changes from 2 commits
10 changes: 10 additions & 0 deletions Manifest.toml
@@ -1130,6 +1130,16 @@
git-tree-sha1 = "401e4f3f30f43af2c8478fc008da50096ea5240f"
uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566"
version = "8.3.1+0"

[[deps.HiddenMarkovModels]]
deps = ["ArgCheck", "ChainRulesCore", "DensityInterface", "DocStringExtensions", "FillArrays", "LinearAlgebra", "Random", "SparseArrays", "StatsAPI", "StatsFuns"]
git-tree-sha1 = "f5f0f6e33b21487d39bcdfb6d67aa4c5e54faba3"
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
version = "0.5.3"
weakdeps = ["Distributions"]

[deps.HiddenMarkovModels.extensions]
HiddenMarkovModelsDistributionsExt = "Distributions"

[[deps.HostCPUFeatures]]
deps = ["BitTwiddlingConvenienceFunctions", "IfElse", "Libdl", "Static"]
git-tree-sha1 = "8e070b599339d622e9a081d17230d74a5c473293"
1 change: 1 addition & 0 deletions Project.toml
@@ -18,6 +18,7 @@
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
HiddenMarkovModels = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
100 changes: 93 additions & 7 deletions tutorials/04-hidden-markov-model/index.qmd
@@ -10,17 +10,17 @@
using Pkg;
Pkg.instantiate();
```

- This tutorial illustrates training Bayesian [Hidden Markov Models](https://en.wikipedia.org/wiki/Hidden_Markov_model) (HMM) using Turing. The main goals are learning the transition matrix, emission parameter, and hidden states. For a more rigorous academic overview on Hidden Markov Models, see [An introduction to Hidden Markov Models and Bayesian Networks](http://mlg.eng.cam.ac.uk/zoubin/papers/ijprai.pdf) (Ghahramani, 2001).
+ This tutorial illustrates training Bayesian [Hidden Markov Models](https://en.wikipedia.org/wiki/Hidden_Markov_model) (HMMs) using Turing. The main goals are learning the transition matrix, emission parameters, and hidden states. For a more rigorous academic overview of Hidden Markov Models, see [An introduction to Hidden Markov Models and Bayesian Networks](http://mlg.eng.cam.ac.uk/zoubin/papers/ijprai.pdf) (Ghahramani, 2001).

In this tutorial, we assume there are $k$ discrete hidden states; the observations are continuous and normally distributed, centered around the hidden states. This assumption reduces the number of parameters to be estimated in the emission matrix.
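To make this setup concrete, here is a minimal sketch of simulating data under these assumptions (our own illustration, not part of the tutorial; the transition matrix, emission means, and the `0.1` noise scale are invented for the example):

```julia
using Distributions, Random

# Simulate n observations from a 3-state HMM with Gaussian emissions:
# pick the next hidden state from the current state's transition row,
# then emit a noisy observation centered on that state's mean.
function simulate_hmm(n; rng=Random.default_rng())
    A = [0.8 0.1 0.1; 0.1 0.8 0.1; 0.1 0.1 0.8]  # row-stochastic transitions
    mu = [1.0, 2.0, 3.0]                          # emission mean per state
    s = 1
    y = Float64[]
    for _ in 1:n
        s = rand(rng, Categorical(A[s, :]))
        push!(y, rand(rng, Normal(mu[s], 0.1)))
    end
    return y
end

y_sim = simulate_hmm(30)
```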

- Let's load the libraries we'll need. We also set a random seed (for reproducibility) and the automatic differentiation backend to forward mode (more [here](https://turinglang.org/dev/docs/using-turing/autodiff) on why this is useful).
+ Let's load the libraries we'll need, and set a random seed for reproducibility.

```{julia}
# Load libraries.
using Turing, StatsPlots, Random

- # Set a random seed and use the forward_diff AD mode.
+ # Set a random seed
Random.seed!(12345678);
```

@@ -29,6 +29,9 @@ Random.seed!(12345678);
In this example, we'll use a dataset where the states and emission parameters are straightforward.

```{julia}
#| code-fold: true
#| code-summary: "Load and plot data for this tutorial."

# Define the emission parameter.
y = [
1.0,
@@ -66,16 +69,17 @@
N = length(y);
K = 3;

# Plot the data we just made.
- plot(y; xlim=(0, 30), ylim=(-1, 5), size=(500, 250))
+ plot(y; xlim=(0, 30), ylim=(-1, 5), size=(500, 250), legend = false)
+ scatter!(y, color = :blue; xlim=(0, 30), ylim=(-1, 5), size=(500, 250), legend = false)
```

We can see that we have three states, one for each height of the plot (1, 2, 3). This height is also our emission parameter, so state one produces a value of one, state two produces a value of two, and so on.

Ultimately, we would like to understand three major parameters:

1. The transition matrix, which assigns a probability of switching from one state to any other state, including the state we are already in (see the small numeric example after this list).
- 2. The emission matrix, which describes a typical value emitted by some state. In the plot above, the emission parameter for state one is simply one.
- 3. The state sequence is our understanding of what state we were actually in when we observed some data. This is very important in more sophisticated HMM models, where the emission value does not equal our state.
+ 2. The emission parameters, which describe a typical value emitted by some state. In the plot above, the emission parameter for state one is simply one.
+ 3. The state sequence is our understanding of what state we were actually in when we observed some data. This is very important in more sophisticated HMMs, where the emission value does not equal our state.
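To make the transition matrix concrete, here is a hypothetical example for `K = 3` (values invented for illustration), using the row-stochastic convention in which `T[i, j]` is the probability of moving from state `i` to state `j`; note that the models later in this diff draw `T` column-wise with `filldist(Dirichlet(...), K)` and transpose it before use.

```julia
# Rows index the current state, columns the next state; each row sums to one.
T_example = [
    0.90 0.05 0.05
    0.10 0.80 0.10
    0.20 0.20 0.60
]
sum(T_example; dims=2)   # every row sums to 1.0
```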

With this in mind, let's set up our model. We are going to use some of our knowledge as modelers to provide additional information about our system. This takes the form of the prior on our emission parameter.
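The model body itself is collapsed in this diff. As a minimal sketch of the kind of emission prior meant here (illustrative only, not necessarily the tutorial's exact code), each state's emission mean can be centered on the height we expect it to produce:

```julia
using Turing

# Hypothetical prior sketch: state i is expected to emit values near i,
# with moderate uncertainty.
@model function emission_prior_sketch(K)
    m ~ arraydist([Normal(i, 0.5) for i in 1:K])
    return m
end

prior_draws = sample(emission_prior_sketch(3), Prior(), 100)
```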

@@ -131,6 +135,7 @@
Time to run our sampler.

```{julia}
#| output: false
#| echo: false
setprogress!(false)
```

@@ -190,4 +195,85 @@
stationary. We can use the diagnostic functions provided by [MCMCChains](https://github.com/TuringLang/MCMCChains.jl):

```{julia}
heideldiag(MCMCChains.group(chn, :T))[1]
```

The p-values on the test suggest that we cannot reject the hypothesis that the observed sequence comes from a stationary distribution, so we can be reasonably confident that our transition matrix has converged to something reasonable.
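(`heideldiag` implements the Heidelberger-Welch diagnostic, which tests the null hypothesis that the sampled values come from a stationary distribution over successively larger portions of the chain.)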

## Efficient Inference With The Forward Algorithm

While the above method works well for the simple example in this tutorial, some users may desire a more efficient method, especially when their model is more complicated. One simple way to improve inference is to marginalize out the hidden states of the model with an appropriate algorithm, calculating only the posterior over the continuous random variables. Not only does this allow more efficient inference via Rao-Blackwellization, but now we can sample our model with `NUTS()` alone, which is usually a much more performant MCMC kernel.
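To make the marginalization concrete, here is a hand-rolled sketch of the forward algorithm (our own illustration; the tutorial relies on the optimized implementation in HiddenMarkovModels.jl below). It computes the marginal log-likelihood exactly in $O(TK^2)$ time, rescaling at each step for numerical stability; all names here are ours:

```julia
using Distributions

# Forward algorithm: marginal log-likelihood of `ys` under an HMM,
# summing over all hidden-state paths. `init` is the initial state
# distribution, `A` the row-stochastic transition matrix, and `dists`
# one emission distribution per state.
function forward_logpdf(init, A, dists, ys)
    alpha = init .* [pdf(d, ys[1]) for d in dists]
    logp = log(sum(alpha))
    alpha ./= sum(alpha)                    # rescale to avoid underflow
    for t in 2:length(ys)
        alpha = (A' * alpha) .* [pdf(d, ys[t]) for d in dists]
        logp += log(sum(alpha))
        alpha ./= sum(alpha)
    end
    return logp
end

# Usage sketch with three well-separated states:
A = [0.9 0.05 0.05; 0.05 0.9 0.05; 0.05 0.05 0.9]
forward_logpdf(fill(1 / 3, 3), A, [Normal(i, 0.1) for i in 1:3], [1.0, 1.0, 2.0])
```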

Thankfully, [HiddenMarkovModels.jl](https://github.com/gdalle/HiddenMarkovModels.jl) provides an extremely efficient implementation of many algorithms related to Hidden Markov Models. This allows us to rewrite our model as:

```{julia}
#| output: false
using HiddenMarkovModels
using FillArrays
using LinearAlgebra
using LogExpFunctions

@model function BayesHmm2(y, K)
    m ~ Bijectors.ordered(MvNormal([1.0, 2.0, 3.0], 0.5I))
    T ~ filldist(Dirichlet(Fill(1/K, K)), K)

    hmm = HMM(softmax(ones(K)), copy(T'), [Normal(m[i], 0.1) for i in 1:K])
    Turing.@addlogprob! logdensityof(hmm, y)
end

chn2 = sample(BayesHmm2(y, 3), NUTS(), 1000)
```

> **Review comment (Member), on this chunk's opening lines:** Could we do `setprogress!(false)` instead of using Quarto's `output: false`?
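One detail worth noting (our reading of the code, not spelled out in the PR): `filldist(Dirichlet(Fill(1/K, K)), K)` stacks `K` independent Dirichlet draws as the *columns* of `T`, so each column sums to one, whereas `HMM` expects a row-stochastic transition matrix. `copy(T')` converts between the two conventions and materializes a plain `Matrix` rather than a lazy `Adjoint`.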


We can compare the chains of these two models, confirming the posterior estimate is similar (modulo label switching concerns with the Gibbs model; the `Bijectors.ordered` constraint in `BayesHmm2` forces `m[1] < m[2] < m[3]` and pins down a single labelling, while the Gibbs model's labels are free to permute):
```{julia}
#| code-fold: true
#| code-summary: "Plotting Chains"

plot(chn["m[1]"], label = "m[1], Model 1, Gibbs", color = :lightblue)
plot!(chn2["m[1]"], label = "m[1], Model 2, NUTS", color = :blue)
plot!(chn["m[2]"], label = "m[2], Model 1, Gibbs", color = :pink)
plot!(chn2["m[2]"], label = "m[2], Model 2, NUTS", color = :red)
plot!(chn["m[3]"], label = "m[3], Model 1, Gibbs", color = :yellow)
plot!(chn2["m[3]"], label = "m[3], Model 2, NUTS", color = :orange)
```


### Recovering Marginalized Trajectories

We can use the `viterbi()` algorithm, also from the `HiddenMarkovModels` package, to recover the most probable hidden-state sequence (the Viterbi path) for each parameter draw in our posterior sample:
```{julia}
#| output: false
@model function BayesHmmRecover(y, K, IncludeGenerated = false)
    m ~ Bijectors.ordered(MvNormal([1.0, 2.0, 3.0], 0.5I))
    T ~ filldist(Dirichlet(Fill(1/K, K)), K)

    hmm = HMM(softmax(ones(K)), copy(T'), [Normal(m[i], 0.1) for i in 1:K])
    Turing.@addlogprob! logdensityof(hmm, y)

    # Conditional generation of the hidden states.
    if IncludeGenerated
        seq, _ = viterbi(hmm, y)
        s := [m[s] for s in seq]
    else
        return nothing
    end
end

chn_recover = sample(BayesHmmRecover(y, 3, true), NUTS(), 1000)
```

Review discussion on this chunk's `output: false` option:

> **Member:** As above – if we do `setprogress!(false)` above can we enable the output as usual?

> **JasonPekos (Member, Author):** I think we could, but do we want to? I assume most users will want the progress updates in their REPL when working through the tutorial, but they just aren't a good fit for the published quarto format. As long as we are not actually recommending users use `setprogress!(false)` (which I think we aren't?), it seems appropriate to just disable this in the chunk options for the rendered docs.

> **penelopeysm (Member, Sep 29, 2024):** In fact (I totally missed this earlier) but there is already a call to `setprogress!(false)` earlier in the page (which is hidden). Many other pages do call `setprogress!(false)` so it wouldn't be too out of place here, so I feel like it should be alright to just remove the `output: false` lines. I'd be happy to add a clarification about `setprogress!` on the get started page in a separate PR :)

> **Member:** (By the way, I'm assuming that the output you wanted to suppress was indeed the progress logging. Do correct me if that's not the case)
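Here `:=` records a deterministic quantity in the chain: for each posterior draw of `(m, T)`, the Viterbi path is mapped through the emission means `m`, so the recovered trajectories appear in `chn_recover` under `s` without being sampled as random variables.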

Plotting the estimated states, we can see that the results align well with our expectations:

```{julia}
#| code-fold: true
#| code-summary: "HMM Plotting Functions"

p = plot(xlim=(0, 30), ylim=(-1, 5), size=(500, 250))
for i in 1:100
    ind = rand(DiscreteUniform(1, 1000))
    plot!(MCMCChains.group(chn_recover, :s).value[ind, :], color = :grey, opacity = 0.1, legend = :false)
end
scatter!(y, color = :blue)

p
```

Review discussion on this chunk's folding options:

> **Member (suggested change):** Drop the `code-fold` and `code-summary` options so the chunk starts directly with `p = plot(xlim=(0, 30), ylim=(-1, 5), size=(500, 250))`.

> **JasonPekos (Member, Author):** I really don't like how this vector (which contains no information not present in the plot, and which is not part of anyone's actual workflow; in practice, people just import data via csvs etc.) takes up so much space on the page. That is why I originally introduced the folded code blocks (I agree that the other use especially is excessive, but I am more torn here). I am considering doing something like `y = [i for i in vcat(fill(1.0, 6), fill(2.0, 6), fill(3.0, 7), fill(2.0, 4), fill(1.0, 7))]`, but of course this is potentially more confusing. Thoughts?

> **penelopeysm (Member, Sep 29, 2024):** Hmm, the vector isn't this code block though – it's the one above. I do see your point though! I'm sorry if I sound too conservative here about code folding :) It's mainly because I'd like to preserve a consistent user experience across the docs. But if there are good places we can apply folding to in the docs I'm super happy to start using it. For this particular vector, would `y = [fill(1.0, 6)..., fill(2.0, 6)..., fill(3.0, 7)..., fill(2.0, 4)..., fill(1.0, 7)...]` work maybe? The other alternative would just be to put multiple values on one line, though I recognise JuliaFormatter doesn't like that. Although we aren't explicitly checking code style in this repo, there is a page outlining the style guide so I guess we should try to practise what we preach 😄