-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Bayesian VAR(2) example script #1658 #1915
Changes from 8 commits
2ac4c49
40ec61e
20eab26
c21c50d
426214d
a3256fe
91f01ce
c0cf4c4
c47ef14
7fb36ab
4be5513
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
r""" | ||
Example: VAR(2) process | ||
======================= | ||
|
||
In this example, we demonstrate how to implement and perform Bayesian inference for a | ||
Vector Autoregressive process of order 2 (VAR(2)). VAR models are widely used in | ||
time series analysis, especially for capturing the dynamics between multiple variables. | ||
|
||
A VAR(2) process for a multivariate time series :math:`y_t` with :math:`K` variables is defined as: | ||
|
||
.. math:: | ||
|
||
y_t = c + \Phi_1 y_{t-1} + \Phi_2 y_{t-2} + \epsilon_t | ||
|
||
Here, :math:`c` is a constant vector, :math:`\Phi_1` and :math:`\Phi_2` are coefficient matrices for lag 1 | ||
and lag 2, respectively, and :math:`\epsilon_t` is a Gaussian noise term with zero mean and a | ||
covariance matrix :math:`\Sigma`. | ||
|
||
This example uses NumPyro's `scan` utility to efficiently model the temporal dependencies without | ||
explicit Python loops. | ||
|
||
Reference | ||
--------- | ||
For more information on Vector Autoregressive models, see: | ||
https://otexts.com/fpp2/VAR.html | ||
|
||
.. image:: ../_static/img/examples/var2.png | ||
:align: center | ||
""" | ||
|
||
import argparse | ||
import os | ||
import time | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
from jax import random | ||
import jax.numpy as jnp | ||
|
||
import numpyro | ||
from numpyro.contrib.control_flow import scan | ||
import numpyro.distributions as dist | ||
|
||
|
||
def var2_scan(y):
    """
    NumPyro model of a VAR(2) process whose time recursion is expressed with `scan`.

    The first two rows of ``y`` seed the recursion; every later row is
    conditioned on as an observation of a multivariate normal centred on
    ``c + Phi1 @ y[t-1] + Phi2 @ y[t-2]``.

    Args:
        y: Observed time series, indexed as (time step, variable).

    Sample sites: ``c``, ``Phi1``, ``Phi2``, ``sigma``, ``L_omega`` and the
    observed site ``y``. The per-step means are recorded as the deterministic
    site ``mu`` (one row for each time step from 2 onwards).
    """
    # Index with traced time steps inside `scan`, which plain NumPy arrays
    # do not support.
    y = jnp.asarray(y)
    T, K = y.shape  # Number of time steps and number of variables

    # Priors for constants and coefficients. `to_event` declares the trailing
    # dimensions as event dimensions so that no batch dimension is left over.
    c = numpyro.sample("c", dist.Normal(0, 1).expand([K]).to_event(1))
    Phi1 = numpyro.sample(
        "Phi1", dist.Normal(0, 1).expand([K, K]).to_event(2)
    )  # Coefficients for lag 1
    Phi2 = numpyro.sample(
        "Phi2", dist.Normal(0, 1).expand([K, K]).to_event(2)
    )  # Coefficients for lag 2

    # Priors for error terms: per-variable scales and a correlation Cholesky factor
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0).expand([K]).to_event(1))
    L_omega = numpyro.sample(
        "L_omega", dist.LKJCholesky(dimension=K, concentration=1.0)
    )
    # Scale row i of L_omega by sigma[i]; same result as diag(sigma) @ L_omega
    L_Sigma = sigma[..., None] * L_omega

    def transition(carry, t):
        y_prev1, y_prev2 = carry  # Observations at t-1 and t-2
        m_t = c + jnp.dot(Phi1, y_prev1) + jnp.dot(Phi2, y_prev2)  # Mean prediction
        # A single static site name: `t` is traced inside `scan`, so it must
        # not be formatted into the name.
        y_t = numpyro.sample(
            "y",
            dist.MultivariateNormal(loc=m_t, scale_tril=L_Sigma),
            obs=y[t],
        )
        return (y_t, y_prev1), m_t

    # Initial carry: observations at time steps 1 and 0
    init_carry = (y[1], y[0])

    # Absolute time indices of the modelled steps (2, ..., T - 1)
    time_indices = jnp.arange(2, T)

    # Run the scan
    _, mu = scan(transition, init_carry, time_indices)

    # Store the mean trajectory as a deterministic variable
    numpyro.deterministic("mu", mu)
|
||
|
||
def generate_var2_data(T, K, c, Phi1, Phi2, sigma, rng=None):
    """
    Generate time series data from a VAR(2) process.

    The first two rows are drawn from a zero-mean normal with covariance
    ``sigma``; each later row is
    ``c + Phi1 @ y[t-1] + Phi2 @ y[t-2] + noise`` with fresh noise of the
    same distribution.

    Args:
        T (int): Number of time steps.
        K (int): Number of variables in the time series.
        c (array): Constants (shape: (K,)).
        Phi1 (array): Coefficients for lag 1 (shape: (K, K)).
        Phi2 (array): Coefficients for lag 2 (shape: (K, K)).
        sigma (array): Covariance matrix for the noise (shape: (K, K)).
        rng (optional): Object providing ``multivariate_normal`` (for example
            ``np.random.default_rng(seed)``) for reproducible output.
            Defaults to NumPy's global random state.
    Returns:
        np.ndarray: Generated time series data (shape: (T, K)).
    """
    sampler = np.random if rng is None else rng

    # Convert once so the recursion runs in plain NumPy even when the
    # parameters arrive as other array types.
    c = np.asarray(c, dtype=float)
    Phi1 = np.asarray(Phi1, dtype=float)
    Phi2 = np.asarray(Phi2, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    zero_mean = np.zeros(K)

    # Initialize the first (at most two) time steps with random values;
    # min() keeps series shorter than two steps from failing on assignment.
    y = np.zeros((T, K))
    n_init = min(T, 2)
    y[:n_init] = sampler.multivariate_normal(mean=zero_mean, cov=sigma, size=n_init)

    # Generate the remaining time steps
    for t in range(2, T):
        noise = sampler.multivariate_normal(mean=zero_mean, cov=sigma)
        y[t] = c + Phi1 @ y[t - 1] + Phi2 @ y[t - 2] + noise

    return y
|
||
|
||
def run_inference(model, args, rng_key, y):
    """
    Run NUTS-based MCMC on ``model``, print a summary, and return the draws.

    Args:
        model: The probabilistic model to infer.
        args: Command-line arguments; ``num_warmup``, ``num_samples`` and
            ``num_chains`` are read from it.
        rng_key: PRNG key for randomness.
        y: Observed time series data, passed to the model as ``y``.
    Returns:
        The samples returned by ``mcmc.get_samples()``.
    """
    started_at = time.time()

    # The progress bar is disabled when NUMPYRO_SPHINXBUILD is set
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ

    kernel = numpyro.infer.NUTS(model)
    mcmc = numpyro.infer.MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=show_progress,
    )
    mcmc.run(rng_key, y=y)
    mcmc.print_summary()

    elapsed = time.time() - started_at
    print("\nMCMC elapsed time:", elapsed)
    return mcmc.get_samples()
|
||
|
||
def main(args):
    """
    Simulate a two-variable VAR(2) series, fit the model, and plot the fit.

    The posterior mean of ``mu`` and its 2.5th/97.5th percentile band are
    drawn over the simulated data and saved to ``var2.png``.

    Args:
        args: Command-line arguments; ``num_data`` sets the series length and
            the remaining fields are forwarded to ``run_inference``.
    """
    # Ground-truth parameters of the artificial dataset
    num_steps = args.num_data
    num_vars = 2
    true_c = jnp.array([0.5, -0.3])
    true_phi1 = jnp.array([[0.7, 0.1], [0.2, 0.5]])
    true_phi2 = jnp.array([[0.2, -0.1], [-0.1, 0.2]])
    true_cov = jnp.array([[0.1, 0.02], [0.02, 0.1]])

    rng_key = random.PRNGKey(0)
    y = generate_var2_data(num_steps, num_vars, true_c, true_phi1, true_phi2, true_cov)

    # Posterior draws from MCMC
    samples = run_inference(var2_scan, args, rng_key, y)

    # Summaries of the per-step mean over posterior draws
    mu_draws = samples["mu"]
    mean_prediction = mu_draws.mean(axis=0)
    lower_bound = jnp.percentile(mu_draws, 2.5, axis=0)
    upper_bound = jnp.percentile(mu_draws, 97.5, axis=0)

    # One panel per variable, sharing the time axis
    _fig, axes = plt.subplots(num_vars, 1, figsize=(10, 6), sharex=True)
    time_steps = jnp.arange(num_steps)
    fitted_steps = time_steps[2:]  # the model has no mean for the first two steps

    for i, ax in enumerate(axes):
        var_no = i + 1
        # Simulated data
        ax.plot(time_steps, y[:, i], label=f"True Variable {var_no}", color="blue")
        # Posterior mean of the one-step mean
        ax.plot(
            fitted_steps,
            mean_prediction[:, i],
            label=f"Predicted Mean Variable {var_no}",
            color="orange",
        )
        # Central 95% posterior interval
        ax.fill_between(
            fitted_steps,
            lower_bound[:, i],
            upper_bound[:, i],
            color="orange",
            alpha=0.2,
            label="95% CI",
        )
        ax.set_title(f"Variable {var_no}")
        ax.legend()
        ax.grid(True)

    plt.xlabel("Time Steps")
    plt.tight_layout()
    plt.savefig("var2.png")
|
||
|
||
if __name__ == "__main__":
    # Command-line interface: integer-valued options share the same settings,
    # so they are registered from a table of (flags, default) pairs.
    cli = argparse.ArgumentParser(description="VAR(2) example")
    int_options = (
        (("--num-data",), 100),
        (("-n", "--num-samples"), 1000),
        (("--num-warmup",), 1000),
        (("--num-chains",), 1),
    )
    for flags, default in int_options:
        cli.add_argument(*flags, nargs="?", default=default, type=int)
    cli.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
    args = cli.parse_args()

    # Configure the platform and host device count before running the example
    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you add
.to_event(2)
at the end to make sure that no batch dimension appears here?