Add Bayesian VAR(2) example script #1658 #1915

Merged: 11 commits, Dec 4, 2024
Binary file added docs/source/_static/img/examples/var2.png
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -83,6 +83,7 @@ NumPyro documentation
   examples/zero_inflated_poisson
   examples/cvae
   tutorials/tbip
   examples/var2

.. nbgallery::
   :maxdepth: 1
211 changes: 211 additions & 0 deletions examples/var2.py
@@ -0,0 +1,211 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

r"""
Example: VAR(2) process
=======================

In this example, we demonstrate how to implement a Vector Autoregressive process of
order 2 (VAR(2)) and perform Bayesian inference over its parameters. VAR models are
widely used in time series analysis, especially for capturing the joint dynamics of
multiple variables.

A VAR(2) process for a multivariate time series :math:`y_t` with :math:`K` variables is defined as:

.. math::

y_t = c + \Phi_1 y_{t-1} + \Phi_2 y_{t-2} + \epsilon_t

Here, :math:`c` is a constant vector, :math:`\Phi_1` and :math:`\Phi_2` are coefficient matrices for lag 1
and lag 2, respectively, and :math:`\epsilon_t` is a Gaussian noise term with zero mean and a
covariance matrix :math:`\Sigma`.

This example uses NumPyro's `scan` utility to efficiently model the temporal dependencies without
explicit Python loops.

Reference
---------
For more information on Vector Autoregressive models, see:
https://otexts.com/fpp2/VAR.html

.. image:: ../_static/img/examples/var2.png
:align: center
"""

import argparse
import os
import time

import matplotlib.pyplot as plt
import numpy as np

from jax import random
import jax.numpy as jnp

import numpyro
from numpyro.contrib.control_flow import scan
import numpyro.distributions as dist


def var2_scan(y):
    T, K = y.shape  # Number of time steps and number of variables

    # Priors for constants and coefficients
    c = numpyro.sample("c", dist.Normal(0, 1).expand([K]))  # Constant vector of size K
    Phi1 = numpyro.sample(
        "Phi1", dist.Normal(0, 1).expand([K, K])
    )  # Coefficients for lag 1

Review comment (Member): could you add .to_event(2) at the end to make sure that no batch dimension appears here?
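
A minimal sketch of the suggested change (an illustration of the reviewer's suggestion, not part of the diff as shown; the Phi2 site below would change analogously):

    Phi1 = numpyro.sample(
        "Phi1", dist.Normal(0, 1).expand([K, K]).to_event(2)
    )  # Coefficients for lag 1, with both (K, K) dimensions declared as event dims

Declaring the expanded dimensions as event dims makes the site a single matrix-valued sample, so no batch dimensions can leak into downstream shape handling.
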
    Phi2 = numpyro.sample(
        "Phi2", dist.Normal(0, 1).expand([K, K])
    )  # Coefficients for lag 2

Review comment (Member): same as above.

    # Priors for error terms
    with numpyro.plate("K", K):
        sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))  # Standard deviations

Review comment (Member): like above, prefer .expand([K]).to_event(1) over using plate.
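
A sketch of the suggested replacement for the plate (again illustrative; it assumes nothing else in the model relies on the "K" plate):

    sigma = numpyro.sample(
        "sigma", dist.HalfNormal(1.0).expand([K]).to_event(1)
    )  # Standard deviations, drawn as one length-K event
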

    L_omega = numpyro.sample(
        "L_omega", dist.LKJCholesky(dimension=K, concentration=1.0)
    )
    L_Sigma = jnp.matmul(jnp.diag(sigma), L_omega)  # Cholesky factor of the covariance

Review comment (Member): maybe jnp.einsum("...i,...ij->...ij", sigma, L_omega) or sigma[..., None] * L_omega to improve the performance.
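
All three forms compute the same row scaling of L_omega by sigma; the broadcast variants just avoid materializing the K x K diagonal matrix. A quick illustrative check with hypothetical values (K = 2):

    sigma_demo = jnp.array([0.3, 0.5])
    L_demo = jnp.array([[1.0, 0.0], [0.4, 0.9]])
    a = jnp.matmul(jnp.diag(sigma_demo), L_demo)  # current form
    b = sigma_demo[..., None] * L_demo  # suggested broadcast form
    assert jnp.allclose(a, b)
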


    def transition(carry, t):
        y_prev1, y_prev2, y_obs = carry  # Previous two observations and the observed data
        m_t = c + jnp.dot(Phi1, y_prev1) + jnp.dot(Phi2, y_prev2)  # Mean prediction
        # Condition on the observed value at time t
        y_t = numpyro.sample(
            f"y_{t}",
            dist.MultivariateNormal(loc=m_t, scale_tril=L_Sigma),
            obs=y_obs[t],
        )
        # Shift the lags: y_t becomes lag 1, the old lag 1 becomes lag 2
        new_carry = (y_t, y_prev1, y_obs)
        return new_carry, m_t

    # Initial carry: observations at time steps 1 and 0
    init_carry = (y[1], y[0], y[2:])

    # Time indices starting from time step 2
    time_indices = jnp.arange(T - 2)

    # Run the scan over the remaining T - 2 steps
    _, mu = scan(transition, init_carry, time_indices)

    # Store the mean trajectory as a deterministic variable
    numpyro.deterministic("mu", mu)


def generate_var2_data(T, K, c, Phi1, Phi2, sigma):
    """
    Generate time series data from a VAR(2) process.

    Args:
        T (int): Number of time steps.
        K (int): Number of variables in the time series.
        c (array): Constants (shape: (K,)).
        Phi1 (array): Coefficients for lag 1 (shape: (K, K)).
        Phi2 (array): Coefficients for lag 2 (shape: (K, K)).
        sigma (array): Covariance matrix for the noise (shape: (K, K)).

    Returns:
        np.ndarray: Generated time series data (shape: (T, K)).
    """
    # Initialize time series with random values
    y = np.zeros((T, K))
    y[:2] = np.random.multivariate_normal(mean=np.zeros(K), cov=sigma, size=2)

    # Generate the time series
    for t in range(2, T):
        y[t] = (
            c
            + Phi1 @ y[t - 1]
            + Phi2 @ y[t - 2]
            + np.random.multivariate_normal(mean=np.zeros(K), cov=sigma)
        )

    return y


def run_inference(model, args, rng_key, y):
    """
    Run MCMC inference for the given model.

    Args:
        model: The probabilistic model to infer.
        args: Command-line arguments.
        rng_key: PRNG key for randomness.
        y: Observed time series data.

    Returns:
        dict: Posterior samples.
    """
    start = time.time()
    sampler = numpyro.infer.NUTS(model)
    mcmc = numpyro.infer.MCMC(
        sampler,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ,
    )
    mcmc.run(rng_key, y=y)
    mcmc.print_summary()
    print("\nMCMC elapsed time:", time.time() - start)
    return mcmc.get_samples()


def main(args):
    # Generate an artificial dataset
    T = args.num_data  # Number of time steps
    K = 2  # Number of variables
    c_true = jnp.array([0.5, -0.3])  # Constants
    Phi1_true = jnp.array([[0.7, 0.1], [0.2, 0.5]])  # Coefficients for lag 1
    Phi2_true = jnp.array([[0.2, -0.1], [-0.1, 0.2]])  # Coefficients for lag 2
    sigma_true = jnp.array([[0.1, 0.02], [0.02, 0.1]])  # Covariance matrix

    rng_key = random.PRNGKey(0)
    y = generate_var2_data(T, K, c_true, Phi1_true, Phi2_true, sigma_true)

    # Perform inference
    samples = run_inference(var2_scan, args, rng_key, y)

    # Posterior mean and 95% credible interval of the latent mean trajectory
    mean_prediction = samples["mu"].mean(axis=0)
    lower_bound = jnp.percentile(samples["mu"], 2.5, axis=0)  # 2.5th percentile
    upper_bound = jnp.percentile(samples["mu"], 97.5, axis=0)  # 97.5th percentile

    # Plot results
    fig, axes = plt.subplots(K, 1, figsize=(10, 6), sharex=True)
    time_steps = jnp.arange(T)

    for i in range(K):
        # True values
        axes[i].plot(time_steps, y[:, i], label=f"True Variable {i + 1}", color="blue")
        # Posterior mean prediction (starts at t = 2, the first predicted step)
        axes[i].plot(
            time_steps[2:],
            mean_prediction[:, i],
            label=f"Predicted Mean Variable {i + 1}",
            color="orange",
        )
        # 95% credible interval
        axes[i].fill_between(
            time_steps[2:],
            lower_bound[:, i],
            upper_bound[:, i],
            color="orange",
            alpha=0.2,
            label="95% CI",
        )
        axes[i].set_title(f"Variable {i + 1}")
        axes[i].legend()
        axes[i].grid(True)

    plt.xlabel("Time Steps")
    plt.tight_layout()
    plt.savefig("var2.png")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="VAR(2) example")
    parser.add_argument("--num-data", nargs="?", default=100, type=int)
    parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
    parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
    parser.add_argument("--num-chains", nargs="?", default=1, type=int)
    parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)