From 7a1e61631d9287dae7e9ca21e87f52da3a2610e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Fri, 19 Apr 2024 20:12:09 +0200 Subject: [PATCH] Report Rhat for the nearly-nonidentifiable experiment. --- workflows/nearly_nonidentifiable.smk | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/workflows/nearly_nonidentifiable.smk b/workflows/nearly_nonidentifiable.smk index ded9b73..5545cda 100644 --- a/workflows/nearly_nonidentifiable.smk +++ b/workflows/nearly_nonidentifiable.smk @@ -1,6 +1,7 @@ # --------------------------------------------------- # - Experiment with a nearly non-identifiable model - # --------------------------------------------------- +from contextlib import redirect_stdout from dataclasses import dataclass import numpy as np import matplotlib.pyplot as plt @@ -72,12 +73,19 @@ rule generate_data: rule run_mcmc: input: "data/{setting}-{seed}.joblib" - output: "samples/MCMC/{setting}-{seed}.npy" + output: + array = "samples/MCMC/{setting}-{seed}.npy", + convergence = "samples/MCMC/convergence/{setting}-{seed}.txt" run: data = joblib.load(str(input)) - estimator = algo.DiscreteCategoricalMeanEstimator() + + estimator = algo.DiscreteCategoricalMeanEstimator(params=algo.SamplingParams(chains=4)) samples = np.asarray(estimator.sample_posterior(data)[estimator.P_TEST_Y]) - np.save(str(output), samples) + with open(output.convergence, "w") as fh: + with redirect_stdout(fh): + estimator.get_mcmc().print_summary() + + np.save(output.array, samples) def _bootstrap(rng, stat: dc.SummaryStatistic) -> dc.SummaryStatistic: