-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from wgurecky/100d_gauss_test
100D multivariate normal distribution test
- Loading branch information
Showing
11 changed files
with
244 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from __future__ import print_function, division | ||
import numpy as np | ||
|
||
|
||
def var_ball(varepsilon, dim): | ||
"""! | ||
@brief Draw single sample from tight gaussian ball | ||
@param varepsilon float or 1d_array of len dim | ||
@param dim dimension of gaussian ball | ||
""" | ||
eps = 0. | ||
if np.all(np.asarray(varepsilon) > 0): | ||
eps = np.random.multivariate_normal(np.zeros(dim), | ||
np.eye(dim) * np.asarray(varepsilon), | ||
size=1)[0] | ||
return eps | ||
|
||
def var_box(varepsilon, dim): | ||
"""! | ||
@brief Draw single sample from tight uniform box | ||
@param varepsilon float or 1d_array of len dim | ||
@param dim dimension of uniform distribution | ||
""" | ||
eps = 0. | ||
if np.all(np.asarray(varepsilon) > 0): | ||
eps = np.random.uniform(low=-np.asarray(varepsilon) * np.ones(dim), | ||
high=np.asarray(varepsilon) * np.ones(dim)) | ||
return eps |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#!/usr/bin/python | ||
## | ||
# Description: Implements 100d normal dist | ||
## | ||
import numpy as np | ||
from scipy.stats import multivariate_normal | ||
import matplotlib.pyplot as plt | ||
|
||
|
||
class Gauss_100D(object): | ||
""" | ||
High dimension multivariate normal from LANL DREAM report | ||
""" | ||
def __init__(self, rho=0.5, dim=100): | ||
# rho is pairwise correlation between all rvs | ||
self.mu = np.zeros(dim) | ||
self.var = np.sqrt(np.arange(dim) + 1.0) | ||
self.cov = np.zeros((dim, dim)) | ||
for i in range(dim): | ||
for j in range(dim): | ||
if i == j: | ||
self.cov[i][j] = self.var[i] ** 2.0 | ||
else: | ||
self.cov[i][j] = self.var[i] * self.var[j] * rho | ||
self.dim = dim | ||
self.rho = rho | ||
self.rv_100d = multivariate_normal(self.mu, self.cov) | ||
|
||
def pdf(self, y): | ||
# return self.rv_100d.pdf(y) / 1e-100 | ||
return self.rv_100d.pdf(y) | ||
|
||
def ln_like(self, y): | ||
assert len(y) == self.dim | ||
return np.log(self.pdf(y)) | ||
|
||
def rvs(self, n_samples): | ||
rv_samples = self.rv_100d.rvs(size=n_samples) | ||
return rv_samples | ||
|
||
|
||
if __name__ == "__main__": | ||
d100_gauss = Gauss_100D() | ||
y = d100_gauss.rvs(10) | ||
print(y.shape) | ||
|
||
p = d100_gauss.pdf(np.zeros(100)) | ||
print(p) | ||
p = d100_gauss.pdf(np.ones(100)) | ||
print(p) | ||
|
||
ln_p = d100_gauss.ln_like(np.zeros(100)) | ||
print(ln_p) | ||
ln_p = d100_gauss.ln_like(np.ones(100)) | ||
print(ln_p) | ||
print(np.diag(d100_gauss.cov)) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
Copyright (c) 2018 William Gurecky | ||
All rights reserved. | ||
|
||
Redistribution and use in source and binary forms, with or without | ||
modification, are permitted provided that the following conditions are met: | ||
1. Redistributions of source code must retain the above copyright | ||
notice, this list of conditions and the following disclaimer. | ||
2. Redistributions in binary form must reproduce the above copyright | ||
notice, this list of conditions and the following disclaimer in the | ||
documentation and/or other materials provided with the distribution. | ||
3. Neither the name of the organization nor the | ||
names of its contributors may be used to endorse or promote products | ||
derived from this software without specific prior written permission. | ||
|
||
THIS SOFTWARE IS PROVIDED BY William Gurecky ''AS IS'' AND ANY | ||
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | ||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
DISCLAIMED. IN NO EVENT SHALL William Gurecky BE LIABLE FOR ANY | ||
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES | ||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; | ||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND | ||
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | ||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
#!/usr/bin/python | ||
## | ||
# Description: Tests samplers on a 100d gauss shaped distribution | ||
## | ||
from __future__ import print_function, division | ||
import unittest | ||
import numpy as np | ||
from six import iteritems | ||
from mpi4py import MPI | ||
import matplotlib.pyplot as plt | ||
import time | ||
import pytest | ||
# | ||
from bipymc.utils import banana_rv, dblgauss_rv, d100_gauss | ||
from bipymc.mc_plot import mc_plot | ||
from bipymc.demc import DeMcMpi | ||
from bipymc.dream import DreamMpi | ||
from bipymc.dram import DrMetropolis, Dram | ||
np.random.seed(42) | ||
n_samples = 500000 | ||
n_burn = 200000 | ||
|
||
|
||
@pytest.mark.heavy | ||
class TestMcmc100DGauss(unittest.TestCase): | ||
def setUp(self): | ||
""" | ||
Setup the 100D gauss distribution | ||
""" | ||
self.comm = MPI.COMM_WORLD | ||
self.gauss = d100_gauss.Gauss_100D() | ||
self.sampler_dict = { | ||
'demc': (self._setup_demc(self.gauss.ln_like), 200), | ||
'dream': (self._setup_dream(self.gauss.ln_like), 100), | ||
} | ||
|
||
if self.comm.rank == 0: | ||
# plot true pdf and true samples | ||
self._plot_gauss() | ||
|
||
def test_samplers(self): | ||
""" | ||
Test ability of each mcmc method to draw samples | ||
from the 100d gauss distribution. | ||
""" | ||
global n_burn | ||
global n_samples | ||
for sampler_name, (my_mcmc, n_chains) in iteritems(self.sampler_dict): | ||
t0 = time.time() | ||
my_mcmc.run_mcmc(n_samples) | ||
t1 = time.time() | ||
theta_est, sig_est, chain = my_mcmc.param_est(n_burn=n_burn) | ||
theta_est_, sig_est_, full_chain = my_mcmc.param_est(n_burn=0) | ||
|
||
if self.comm.rank == 0: | ||
print("=== " + str(sampler_name) + " ===") | ||
print("Sampler walltime: %d (s)" % int(t1 - t0)) | ||
print("Esimated params: %s" % str(theta_est)) | ||
print("Estimated params sigma: %s " % str(sig_est)) | ||
print("Acceptance fraction: %f" % my_mcmc.acceptance_fraction) | ||
try: | ||
print("P_cr: %s" % str(my_mcmc.p_cr)) | ||
except: | ||
pass | ||
y1, y2 = chain[:, 0], chain[:, 1] | ||
|
||
if sampler_name == 'demc' or sampler_name == 'dream': | ||
self.assertAlmostEqual(0.0, theta_est[0], delta=0.2) | ||
self.assertAlmostEqual(0.0, theta_est[1], delta=0.2) | ||
|
||
# plot mcmc samples | ||
plt.figure() | ||
plt.scatter(y1, y2, s=2, alpha=0.08) | ||
plt.grid(ls='--', alpha=0.5) | ||
plt.xlim(-4, 4) | ||
plt.ylim(-4, 4) | ||
plt.savefig(str(sampler_name) + "_100d_gauss_slice_sample.png") | ||
plt.close() | ||
|
||
# plot mcmc chains | ||
""" | ||
mc_plot.plot_mcmc_indep_chains(full_chain, n_chains, | ||
labels=["x1", "x2"], | ||
savefig=str(sampler_name) + "_chains.png", | ||
scatter=True) | ||
""" | ||
|
||
def _plot_gauss(self): | ||
n_samples = 10000 | ||
y = self.gauss.rvs(n_samples) | ||
y1, y2 = y[:, 0], y[:, 1] | ||
plt.figure() | ||
plt.scatter(y1, y2, s=2, alpha=0.3) | ||
plt.xlim(-4, 4) | ||
plt.ylim(-4, 4) | ||
plt.grid(ls='--', alpha=0.5) | ||
plt.savefig("true_100d_gauss_slice_samples.png") | ||
plt.close() | ||
|
||
def _setup_demc(self, log_like_fn, n_chains=200): | ||
theta_0 = np.zeros(100) | ||
my_mcmc = DeMcMpi(log_like_fn, theta_0, n_chains=n_chains, mpi_comm=self.comm) | ||
return my_mcmc | ||
|
||
def _setup_dream(self, log_like_fn, n_chains=100): | ||
global n_burn | ||
theta_0 = np.zeros(100) | ||
my_mcmc = DreamMpi(log_like_fn, theta_0, n_chains=n_chains, mpi_comm=self.comm, | ||
n_cr_gen=50, burnin_gen=int(n_burn / n_chains)) | ||
return my_mcmc | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters