diff --git a/workflows/nearly_nonidentifiable.smk b/workflows/nearly_nonidentifiable.smk index ded9b73..5545cda 100644 --- a/workflows/nearly_nonidentifiable.smk +++ b/workflows/nearly_nonidentifiable.smk @@ -1,6 +1,7 @@ # --------------------------------------------------- # - Experiment with a nearly non-identifiable model - # --------------------------------------------------- +from contextlib import redirect_stdout from dataclasses import dataclass import numpy as np import matplotlib.pyplot as plt @@ -72,12 +73,19 @@ rule generate_data: rule run_mcmc: input: "data/{setting}-{seed}.joblib" - output: "samples/MCMC/{setting}-{seed}.npy" + output: + array = "samples/MCMC/{setting}-{seed}.npy", + convergence = "samples/MCMC/convergence/{setting}-{seed}.txt" run: data = joblib.load(str(input)) - estimator = algo.DiscreteCategoricalMeanEstimator() + + estimator = algo.DiscreteCategoricalMeanEstimator(params=algo.SamplingParams(chains=4)) samples = np.asarray(estimator.sample_posterior(data)[estimator.P_TEST_Y]) - np.save(str(output), samples) + with open(output.convergence, "w") as fh: + with redirect_stdout(fh): + estimator.get_mcmc().print_summary() + + np.save(output.array, samples) def _bootstrap(rng, stat: dc.SummaryStatistic) -> dc.SummaryStatistic: