From f5de226832014bed21478fdd83948fc762bc6a15 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?=
Date: Mon, 27 May 2024 12:23:48 +0200
Subject: [PATCH 1/6] Add r-hat diagnostic for the benchmark runs

---
 workflows/benchmark.smk | 27 ++++++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/workflows/benchmark.smk b/workflows/benchmark.smk
index 391b5e1..8735bff 100644
--- a/workflows/benchmark.smk
+++ b/workflows/benchmark.smk
@@ -10,7 +10,7 @@ import matplotlib.transforms as mtransforms
 matplotlib.use("Agg")
 
 import numpy as np
-
+from numpyro.diagnostics import summary
 import labelshift.algorithms.api as algo
 import labelshift.experiments.api as exp
 
@@ -22,7 +22,7 @@ ESTIMATORS = {
     "BBS": algo.BlackBoxShiftEstimator(),
     "CC": algo.ClassifyAndCount(),
     "RIR": algo.InvariantRatioEstimator(restricted=True),
-    "BAY": algo.DiscreteCategoricalMeanEstimator(),
+    "BAY": algo.DiscreteCategoricalMeanEstimator(params=algo.SamplingParams(chains=4)),
 }
 ESTIMATOR_COLORS = {
     "BBS": "orangered",
@@ -146,6 +146,7 @@ _k_vals = [2, 3, 5, 7, 9]
 _quality = [0.55, 0.65, 0.75, 0.85, 0.95]
 _quality_prime = [0.45, 0.55, 0.65, 0.75, 0.80, 0.85, 0.90, 0.95]
 
+
 BENCHMARKS = {
     "change_prevalence": BenchmarkSettings(
         param_name="Prevalence $\\pi'_1$",
@@ -179,6 +180,7 @@ BENCHMARKS = {
     ),
 }
 
+
 def get_data_setting(benchmark: str, param: int | str) -> DataSetting:
     return BENCHMARKS[str(benchmark)].settings[int(param)]
 
@@ -234,6 +236,14 @@ rule apply_estimator:
             elapsed_time = timer.check()
             run_ok = True
             additional_info = {}
+
+            if hasattr(estimator, "get_mcmc"):
+                samples = estimator.get_mcmc().get_samples(group_by_chain=True)
+                summ = summary(samples)
+                n_eff_list = [np.min(d["n_eff"]) for d in summ.values()]
+                r_hat_list = [np.max(d["r_hat"]) for d in summ.values()]
+                additional_info = additional_info | {"min_n_eff": min(n_eff_list), "max_r_hat": max(r_hat_list)}
+
         except Exception as e:
             elapsed_time = float("nan")
             estimate = np.full_like(data.n_y_labeled, fill_value=float("nan"))
@@ -267,9 +277,13 @@ def _get_paths_to_be_assembled(wildcards):
 rule assemble_results:
     output:
         csv = "results/benchmark-{benchmark}-metric-{metric}.csv",
-        err = "results/status/benchmark-{benchmark}-metric-{metric}.txt"
+        err = "results/status/benchmark-{benchmark}-metric-{metric}.txt",
+        convergence = "results/convergence/benchmark-{benchmark}-metric-{metric}.txt",
     input: _get_paths_to_be_assembled
     run:
+        max_r_hat = -1e9
+        min_n_eff = 1e9
+
         results = []
         for pth in input:
             res = joblib.load(pth)
@@ -285,6 +299,10 @@ rule assemble_results:
             }
             results.append(nice)
 
+            if "max_r_hat" in res.additional_info:
+                max_r_hat = max(max_r_hat, res.additional_info["max_r_hat"])
+                min_n_eff = min(min_n_eff, res.additional_info["min_n_eff"])
+
         results = pd.DataFrame(results)
 
         df_ok = results[results["run_ok"]]
@@ -298,6 +316,9 @@ rule assemble_results:
         df_ok = df_ok.drop(columns=["run_ok", "additional_info"])
         df_ok.to_csv(str(output.csv), index=False)
 
+        with open(output.convergence, "w") as f:
+            f.write(f"Max r_hat: {max_r_hat}\n")
+            f.write(f"Min n_eff: {min_n_eff}\n")
 
 
 def plot_results(ax, df, plot_std: bool = True, alpha: float = 0.5):

From e1a4809f7be31d2712da814c9d9b459d36cf86da Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?=
Date: Mon, 27 May 2024 14:17:36 +0200
Subject: [PATCH 2/6] Add convergence checks to experiment with misspecified models

---
 workflows/misspecified.smk | 84 +++++++++++++++++++++++++++++++-------
 1 file changed, 70 insertions(+), 14 deletions(-)

diff --git a/workflows/misspecified.smk b/workflows/misspecified.smk
index 2e79b54..e0369cb 100644
--- a/workflows/misspecified.smk
+++ b/workflows/misspecified.smk
@@ -3,7 +3,7 @@
 # ----------------------------------------------------------------------------------
 from dataclasses import dataclass
 import numpy as np
-import pandas as pd
+import json
 import joblib
 
 import matplotlib
@@ -14,7 +14,7 @@ matplotlib.use("agg")
 import jax
 import numpyro
 import numpyro.distributions as dist
-
+from numpyro.diagnostics import summary
 import labelshift.algorithms.api as algo
 from labelshift.datasets.discrete_categorical import SummaryStatistic
 
@@ -65,10 +65,9 @@ COVERAGES = np.arange(0.05, 0.96, 0.05)
 
 
 rule all:
-    input: expand("plots/{n_points}.pdf", n_points=N_POINTS)
-
-# rule all:
-#     input: expand("figures/{setting}-{seed}.pdf", setting=SETTINGS.keys(), seed=SEEDS)
+    input:
+        plots = expand("plots/{n_points}.pdf", n_points=N_POINTS),
+        convergence = "convergence_overall.json",
 
 
 rule generate_data:
@@ -82,7 +81,7 @@ def gaussian_model(observed: Data, unobserved: np.ndarray):
     sigma = numpyro.sample('sigma', dist.HalfCauchy(np.ones(2)))
-    mu = numpyro.sample('mu', dist.Normal(np.zeros(2), 1))
+    mu = numpyro.sample('mu', dist.Normal(np.zeros(2), 3))
 
     pi = numpyro.sample(algo.DiscreteCategoricalMeanEstimator.P_TEST_Y, dist.Dirichlet(np.ones(2)))
@@ -101,7 +100,7 @@ def student_model(observed: Data, unobserved: np.ndarray):
     df = numpyro.sample('df', dist.Gamma(np.ones(2), np.ones(2)))
     sigma = numpyro.sample('sigma', dist.HalfCauchy(np.ones(2)))
-    mu = numpyro.sample('mu', dist.Normal(np.zeros(2), 1))
+    mu = numpyro.sample('mu', dist.Normal(np.zeros(2), 3))
 
     pi = numpyro.sample(algo.DiscreteCategoricalMeanEstimator.P_TEST_Y, dist.Dirichlet(np.ones(2)))
@@ -117,9 +116,17 @@ def student_model(observed: Data, unobserved: np.ndarray):
     numpyro.sample('x', mixture, obs=unobserved)
 
 
+def generate_summary(samples):
+    summ = summary(samples)
+    n_eff_list = [float(np.min(d["n_eff"])) for d in summ.values()]
+    r_hat_list = [float(np.max(d["r_hat"])) for d in summ.values()]
+    return {"min_n_eff": min(n_eff_list), "max_r_hat": max(r_hat_list)}
+
 rule run_gaussian_mcmc:
     input: "data/{n_points}/{seed}.npy"
-    output: "samples/{n_points}/Gaussian/{seed}.npy"
+    output:
+        samples = "samples/{n_points}/Gaussian/{seed}.npy",
+        convergence = "convergence/{n_points}/Gaussian/{seed}.joblib",
     run:
         data_labeled, data_unlabeled = joblib.load(str(input))
         mcmc = numpyro.infer.MCMC(
@@ -130,12 +137,17 @@ rule run_gaussian_mcmc:
         rng_key = jax.random.PRNGKey(int(wildcards.seed) + 101)
         mcmc.run(rng_key, observed=data_labeled, unobserved=data_unlabeled.xs)
         samples = mcmc.get_samples()
-        joblib.dump(samples, str(output))
+        joblib.dump(samples, output.samples)
+
+        summ = generate_summary(mcmc.get_samples(group_by_chain=True))
+        joblib.dump(summ, output.convergence)
 
 
 rule run_student_mcmc:
     input: "data/{n_points}/{seed}.npy"
-    output: "samples/{n_points}/Student/{seed}.npy"
+    output:
+        samples = "samples/{n_points}/Student/{seed}.npy",
+        convergence = "convergence/{n_points}/Student/{seed}.joblib",
     run:
         data_labeled, data_unlabeled = joblib.load(str(input))
         mcmc = numpyro.infer.MCMC(
@@ -146,7 +158,11 @@ rule run_student_mcmc:
         rng_key = jax.random.PRNGKey(int(wildcards.seed) + 101)
         mcmc.run(rng_key, observed=data_labeled, unobserved=data_unlabeled.xs)
         samples = mcmc.get_samples()
-        joblib.dump(samples, str(output))
+        joblib.dump(samples, output.samples)
+
+        summ = generate_summary(mcmc.get_samples(group_by_chain=True))
+        joblib.dump(summ, output.convergence)
+
 
 
 def _calculate_bins(n: int):
@@ -169,7 +185,9 @@ def generate_summary_statistic(
 
 rule run_discrete_mcmc:
     input: "data/{n_points}/{seed}.npy"
-    output: "samples/{n_points}/Discrete-{n_bins}/{seed}.npy"
+    output:
+        samples = "samples/{n_points}/Discrete-{n_bins}/{seed}.npy",
+        convergence = "convergence/{n_points}/Discrete-{n_bins}/{seed}.joblib",
     run:
         data_labeled, data_unlabeled = joblib.load(str(input))
         estimator = algo.DiscreteCategoricalMeanEstimator(
@@ -177,7 +195,10 @@ rule run_discrete_mcmc:
             params=algo.SamplingParams(warmup=N_MCMC_WARMUP, samples=N_MCMC_SAMPLES),
         )
         samples = estimator.sample_posterior(generate_summary_statistic(data_labeled, data_unlabeled.xs, int(wildcards.n_bins)))
-        joblib.dump(samples, str(output))
+        joblib.dump(samples, output.samples)
+
+        summ = generate_summary(estimator.get_mcmc().get_samples(group_by_chain=True))
+        joblib.dump(summ, output.convergence)
 
 
 def calculate_hdi(arr, prob: float) -> tuple[float, float]:
@@ -231,6 +252,41 @@ rule calculate_coverages:
         coverages = results.mean(axis=0)
         np.save(str(output), coverages)
 
+
+def _input_paths_summarize_convergence(wildcards):
+    return [f"convergence/{wildcards.n_points}/{wildcards.algorithm}/{seed}.joblib" for seed in SEEDS]
+
+
+rule summarize_convergence:
+    input: _input_paths_summarize_convergence
+    output: "convergence/{n_points}/{algorithm}.json"
+    run:
+        min_n_effs = []
+        max_r_hats = []
+        for pth in input:
+            res = joblib.load(pth)
+            min_n_effs.append(res["min_n_eff"])
+            max_r_hats.append(res["max_r_hat"])
+
+        with open(str(output), "w") as fh:
+            json.dump({"min_n_eff": min(min_n_effs), "max_r_hat": max(max_r_hats)}, fh)
+
+
+rule summarize_convergence_overall:
+    input: expand("convergence/{n_points}/{algorithm}.json", n_points=N_POINTS, algorithm=["Gaussian", "Student", "Discrete-5", "Discrete-10"])
+    output: "convergence_overall.json"
+    run:
+        min_n_effs = []
+        max_r_hats = []
+        for pth in input:
+            with open(pth) as fh:
+                res = json.load(fh)
+            min_n_effs.append(res["min_n_eff"])
+            max_r_hats.append(res["max_r_hat"])
+
+        with open(str(output), "w") as fh:
+            json.dump({"min_n_eff": min(min_n_effs), "max_r_hat": max(max_r_hats)}, fh)
+
 rule plot_coverage:
     input:
         gaussian = "coverages/{n_points}/Gaussian.npy",

From 02585b92c50d82aea87a6750ce8dc6994e737f8e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?=
Date: Mon, 27 May 2024 16:20:19 +0200
Subject: [PATCH 3/6] Use four chains

---
 workflows/misspecified.smk | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/workflows/misspecified.smk b/workflows/misspecified.smk
index e0369cb..3fff50b 100644
--- a/workflows/misspecified.smk
+++ b/workflows/misspecified.smk
@@ -131,8 +131,9 @@ rule run_gaussian_mcmc:
         data_labeled, data_unlabeled = joblib.load(str(input))
         mcmc = numpyro.infer.MCMC(
             numpyro.infer.NUTS(gaussian_model),
-            num_warmup=N_MCMC_WARMUP,
-            num_samples=N_MCMC_SAMPLES,
+            num_warmup=1500,
+            num_samples=2000,
+            num_chains=4,
         )
         rng_key = jax.random.PRNGKey(int(wildcards.seed) + 101)
         mcmc.run(rng_key, observed=data_labeled, unobserved=data_unlabeled.xs)
@@ -152,8 +153,9 @@ rule run_student_mcmc:
         data_labeled, data_unlabeled = joblib.load(str(input))
         mcmc = numpyro.infer.MCMC(
             numpyro.infer.NUTS(student_model),
-            num_warmup=N_MCMC_WARMUP,
-            num_samples=N_MCMC_SAMPLES,
+            num_warmup=1500,
+            num_samples=2000,
+            num_chains=4,
         )
         rng_key = jax.random.PRNGKey(int(wildcards.seed) + 101)
         mcmc.run(rng_key, observed=data_labeled, unobserved=data_unlabeled.xs)

From 4463f595539f2389e8660ca974989d808fb51de1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?=
Date: Mon, 27 May 2024 16:24:36 +0200
Subject: [PATCH 4/6] Refactor code.

---
 workflows/misspecified.smk | 23 ++++++++++++++---------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/workflows/misspecified.smk b/workflows/misspecified.smk
index 3fff50b..f3d2ba6 100644
--- a/workflows/misspecified.smk
+++ b/workflows/misspecified.smk
@@ -57,8 +57,9 @@ N_POINTS = [100, 1000, 10_000]
 PI_LABELED = 0.5
 PI_UNLABELED = 0.2
 
-N_MCMC_WARMUP = 500
-N_MCMC_SAMPLES = 1000
+N_MCMC_WARMUP = 1500
+N_MCMC_SAMPLES = 2000
+N_MCMC_CHAINS = 4
 
 COVERAGES = np.arange(0.05, 0.96, 0.05)
 
@@ -131,9 +132,9 @@ rule run_gaussian_mcmc:
         data_labeled, data_unlabeled = joblib.load(str(input))
         mcmc = numpyro.infer.MCMC(
             numpyro.infer.NUTS(gaussian_model),
-            num_warmup=1500,
-            num_samples=2000,
-            num_chains=4,
+            num_warmup=N_MCMC_WARMUP,
+            num_samples=N_MCMC_SAMPLES,
+            num_chains=N_MCMC_CHAINS,
         )
         rng_key = jax.random.PRNGKey(int(wildcards.seed) + 101)
         mcmc.run(rng_key, observed=data_labeled, unobserved=data_unlabeled.xs)
@@ -153,9 +154,9 @@ rule run_student_mcmc:
         data_labeled, data_unlabeled = joblib.load(str(input))
         mcmc = numpyro.infer.MCMC(
             numpyro.infer.NUTS(student_model),
-            num_warmup=1500,
-            num_samples=2000,
-            num_chains=4,
+            num_warmup=N_MCMC_WARMUP,
+            num_samples=N_MCMC_SAMPLES,
+            num_chains=N_MCMC_CHAINS,
        )
         rng_key = jax.random.PRNGKey(int(wildcards.seed) + 101)
         mcmc.run(rng_key, observed=data_labeled, unobserved=data_unlabeled.xs)
@@ -194,7 +195,11 @@ rule run_discrete_mcmc:
         data_labeled, data_unlabeled = joblib.load(str(input))
         estimator = algo.DiscreteCategoricalMeanEstimator(
             seed=int(wildcards.seed) + 101,
-            params=algo.SamplingParams(warmup=N_MCMC_WARMUP, samples=N_MCMC_SAMPLES),
+            params=algo.SamplingParams(
+                warmup=N_MCMC_WARMUP,
+                samples=N_MCMC_SAMPLES,
+                chains=N_MCMC_CHAINS,
+            ),
         )
         samples = estimator.sample_posterior(generate_summary_statistic(data_labeled, data_unlabeled.xs, int(wildcards.n_bins)))

From 7fd0e420726a12116d7a99cc1b815c9bdcb7cf82 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?=
Date: Mon, 27 May 2024 16:58:17 +0200
Subject: [PATCH 5/6] Improve DPI

---
 workflows/misspecified.smk | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/workflows/misspecified.smk b/workflows/misspecified.smk
index f3d2ba6..9e4aa8c 100644
--- a/workflows/misspecified.smk
+++ b/workflows/misspecified.smk
@@ -150,7 +150,7 @@ rule run_student_mcmc:
     output:
         samples = "samples/{n_points}/Student/{seed}.npy",
         convergence = "convergence/{n_points}/Student/{seed}.joblib",
-    run: 
+    run:
         data_labeled, data_unlabeled = joblib.load(str(input))
         mcmc = numpyro.infer.MCMC(
             numpyro.infer.NUTS(student_model),
@@ -306,7 +306,7 @@ rule plot_coverage:
         sample_discrete10 = "samples/{n_points}/Discrete-10/1.npy",
     output: "plots/{n_points}.pdf"
     run:
-        fig, axs = subplots_from_axsize(axsize=(2, 1), wspace=[0.2, 0.3, 0.6], dpi=150, left=0.2, top=0.3, right=1.8)
+        fig, axs = subplots_from_axsize(axsize=(2, 1), wspace=[0.2, 0.3, 0.6], dpi=400, left=0.2, top=0.3, right=1.8)
         axs = axs.ravel()
 
         # Conditional distributions P(X|Y)

From e754cd021d9a7ec02625fcb1e443fe2f7529f992 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?=
Date: Mon, 27 May 2024 19:41:52 +0200
Subject: [PATCH 6/6] Filter out nonconverged runs

---
 workflows/misspecified.smk | 33 +++++++++++++++++++++++++--------
 1 file changed, 25 insertions(+), 8 deletions(-)

diff --git a/workflows/misspecified.smk b/workflows/misspecified.smk
index 9e4aa8c..b85f20a 100644
--- a/workflows/misspecified.smk
+++ b/workflows/misspecified.smk
@@ -224,12 +224,17 @@ def calculate_hdi(arr, prob: float) -> tuple[float, float]:
 
 
 rule contains_ground_truth:
-    input: "samples/{n_points}/{algorithm}/{seed}.npy"
+    input:
+        samples = "samples/{n_points}/{algorithm}/{seed}.npy",
+        convergence = "convergence/{n_points}/{algorithm}/{seed}.joblib",
     output: "contains/{n_points}/{algorithm}/{seed}.joblib"
     run:
-        samples = joblib.load(str(input))
+        samples = joblib.load(input.samples)
+        convergence = joblib.load(input.convergence)
+        run_ok = True if convergence["max_r_hat"] < 1.02 else False
+
         pi_samples = samples[algo.DiscreteCategoricalMeanEstimator.P_TEST_Y][:, 1]
-
+
         results = []
         intervals = []
         for coverage in COVERAGES:
@@ -240,7 +245,7 @@ rule contains_ground_truth:
         results = np.asarray(results, dtype=float)
         intervals = np.asarray(intervals, dtype=float)
 
-        joblib.dump((results, intervals), str(output))
+        joblib.dump((results, intervals, run_ok), str(output))
 
 
 def _input_paths_calculate_coverages(wildcards):
@@ -249,16 +254,28 @@ def _input_paths_calculate_coverages(wildcards):
 
 rule calculate_coverages:
     input: _input_paths_calculate_coverages
-    output: "coverages/{n_points}/{algorithm}.npy"
+    output:
+        coverages = "coverages/{n_points}/{algorithm}.npy",
+        excluded_runs = "excluded/{n_points}-{algorithm}.json"
     run:
         results = []
+
+        ok_runs = 0
+        excluded_runs = 0
         for pth in input:
-            res, _ = joblib.load(pth)
-            results.append(res)
+            res, _, run_ok = joblib.load(pth)
+            if run_ok:
+                results.append(res)
+                ok_runs += 1
+            else:
+                excluded_runs += 1
+
         results = np.asarray(results)
         coverages = results.mean(axis=0)
-        np.save(str(output), coverages)
+        np.save(output.coverages, coverages)
+        with open(output.excluded_runs, "w") as fh:
+            json.dump({"excluded_runs": excluded_runs, "ok_runs": ok_runs}, fh)
 
 
 def _input_paths_summarize_convergence(wildcards):
     return [f"convergence/{wildcards.n_points}/{wildcards.algorithm}/{seed}.joblib" for seed in SEEDS]
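
Note on the convergence-checking pattern introduced by this series: every MCMC-based rule now reduces the per-site output of numpyro.diagnostics.summary() to a single worst-case record ({"min_n_eff": ..., "max_r_hat": ...}), stores it next to the posterior samples, and downstream rules discard runs whose largest r-hat is not below 1.02. The following is a minimal, self-contained sketch of that pattern, not code from the workflows: toy_model, convergence_record and the script scaffolding are hypothetical stand-ins, while the get_samples(group_by_chain=True)/summary() calls, the 1500/2000/4 sampling settings and the 1.02 threshold mirror the diffs above.

    # convergence_sketch.py -- illustrative only, not part of the patch series
    import jax
    import numpy as np
    import numpyro
    import numpyro.distributions as dist
    from numpyro.diagnostics import summary

    numpyro.set_host_device_count(4)  # allow the four chains to run in parallel when devices permit


    def toy_model(y=None):
        # Hypothetical stand-in for gaussian_model/student_model in the workflow.
        mu = numpyro.sample("mu", dist.Normal(0.0, 3.0))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(1.0))
        numpyro.sample("y", dist.Normal(mu, sigma), obs=y)


    def convergence_record(mcmc: numpyro.infer.MCMC) -> dict:
        # Same reduction as generate_summary() in misspecified.smk: keep only the
        # worst-case diagnostics across all sample sites.
        samples = mcmc.get_samples(group_by_chain=True)
        summ = summary(samples)  # per-site dicts containing "n_eff" and "r_hat"
        return {
            "min_n_eff": float(min(np.min(d["n_eff"]) for d in summ.values())),
            "max_r_hat": float(max(np.max(d["r_hat"]) for d in summ.values())),
        }


    if __name__ == "__main__":
        y_obs = np.random.default_rng(0).normal(loc=0.3, scale=1.0, size=50)
        mcmc = numpyro.infer.MCMC(
            numpyro.infer.NUTS(toy_model),
            num_warmup=1500,   # N_MCMC_WARMUP in the workflow
            num_samples=2000,  # N_MCMC_SAMPLES
            num_chains=4,      # N_MCMC_CHAINS
        )
        mcmc.run(jax.random.PRNGKey(0), y=y_obs)
        record = convergence_record(mcmc)
        # Acceptance rule from patch 6: keep the run only if the largest r-hat is below 1.02.
        run_ok = record["max_r_hat"] < 1.02
        print(record, "run_ok:", run_ok)

In the workflows this record is serialized with joblib per (rule, seed), aggregated by the summarize_convergence rules into JSON, and used by calculate_coverages to count and exclude nonconverged runs before averaging coverage.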