From e754cd021d9a7ec02625fcb1e443fe2f7529f992 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Mon, 27 May 2024 19:41:52 +0200 Subject: [PATCH] Filter out nonconverged runs --- workflows/misspecified.smk | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/workflows/misspecified.smk b/workflows/misspecified.smk index 9e4aa8c..b85f20a 100644 --- a/workflows/misspecified.smk +++ b/workflows/misspecified.smk @@ -224,12 +224,17 @@ def calculate_hdi(arr, prob: float) -> tuple[float, float]: rule contains_ground_truth: - input: "samples/{n_points}/{algorithm}/{seed}.npy" + input: + samples = "samples/{n_points}/{algorithm}/{seed}.npy", + convergence = "convergence/{n_points}/{algorithm}/{seed}.joblib", output: "contains/{n_points}/{algorithm}/{seed}.joblib" run: - samples = joblib.load(str(input)) + samples = joblib.load(input.samples) + convergence = joblib.load(input.convergence) + run_ok = True if convergence["max_r_hat"] < 1.02 else False + pi_samples = samples[algo.DiscreteCategoricalMeanEstimator.P_TEST_Y][:, 1] - + results = [] intervals = [] for coverage in COVERAGES: @@ -240,7 +245,7 @@ rule contains_ground_truth: results = np.asarray(results, dtype=float) intervals = np.asarray(intervals, dtype=float) - joblib.dump((results, intervals), str(output)) + joblib.dump((results, intervals, run_ok), str(output)) def _input_paths_calculate_coverages(wildcards): @@ -249,16 +254,28 @@ def _input_paths_calculate_coverages(wildcards): rule calculate_coverages: input: _input_paths_calculate_coverages - output: "coverages/{n_points}/{algorithm}.npy" + output: + coverages = "coverages/{n_points}/{algorithm}.npy", + excluded_runs = "excluded/{n_points}-{algorithm}.json" run: results = [] + + ok_runs = 0 + excluded_runs = 0 for pth in input: - res, _ = joblib.load(pth) - results.append(res) + res, _, run_ok = joblib.load(pth) + if run_ok: + results.append(res) + ok_runs += 1 + else: + excluded_runs += 1 + results = np.asarray(results) coverages = results.mean(axis=0) - np.save(str(output), coverages) + np.save(output.coverages, coverages) + with open(output.excluded_runs, "w") as fh: + json.dump({"excluded_runs": excluded_runs, "ok_runs": ok_runs}, fh) def _input_paths_summarize_convergence(wildcards): return [f"convergence/{wildcards.n_points}/{wildcards.algorithm}/{seed}.joblib" for seed in SEEDS]