Skip to content

Commit

Permalink
Filter out nonconverged runs
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed May 27, 2024
1 parent 7fd0e42 commit e754cd0
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions workflows/misspecified.smk
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,17 @@ def calculate_hdi(arr, prob: float) -> tuple[float, float]:


rule contains_ground_truth:
input: "samples/{n_points}/{algorithm}/{seed}.npy"
input:
samples = "samples/{n_points}/{algorithm}/{seed}.npy",
convergence = "convergence/{n_points}/{algorithm}/{seed}.joblib",
output: "contains/{n_points}/{algorithm}/{seed}.joblib"
run:
samples = joblib.load(str(input))
samples = joblib.load(input.samples)
convergence = joblib.load(input.convergence)
run_ok = True if convergence["max_r_hat"] < 1.02 else False

pi_samples = samples[algo.DiscreteCategoricalMeanEstimator.P_TEST_Y][:, 1]

results = []
intervals = []
for coverage in COVERAGES:
Expand All @@ -240,7 +245,7 @@ rule contains_ground_truth:

results = np.asarray(results, dtype=float)
intervals = np.asarray(intervals, dtype=float)
joblib.dump((results, intervals), str(output))
joblib.dump((results, intervals, run_ok), str(output))


def _input_paths_calculate_coverages(wildcards):
Expand All @@ -249,16 +254,28 @@ def _input_paths_calculate_coverages(wildcards):

rule calculate_coverages:
input: _input_paths_calculate_coverages
output: "coverages/{n_points}/{algorithm}.npy"
output:
coverages = "coverages/{n_points}/{algorithm}.npy",
excluded_runs = "excluded/{n_points}-{algorithm}.json"
run:
results = []

ok_runs = 0
excluded_runs = 0
for pth in input:
res, _ = joblib.load(pth)
results.append(res)
res, _, run_ok = joblib.load(pth)
if run_ok:
results.append(res)
ok_runs += 1
else:
excluded_runs += 1

results = np.asarray(results)
coverages = results.mean(axis=0)
np.save(str(output), coverages)
np.save(output.coverages, coverages)

with open(output.excluded_runs, "w") as fh:
json.dump({"excluded_runs": excluded_runs, "ok_runs": ok_runs}, fh)

def _input_paths_summarize_convergence(wildcards):
return [f"convergence/{wildcards.n_points}/{wildcards.algorithm}/{seed}.joblib" for seed in SEEDS]
Expand Down

0 comments on commit e754cd0

Please sign in to comment.