Add r-hat calculation for expensive experiments #26

Merged · 6 commits · May 28, 2024

Changes from all commits
workflows/benchmark.smk (27 changes: 24 additions & 3 deletions)
@@ -10,7 +10,7 @@ import matplotlib.transforms as mtransforms
matplotlib.use("Agg")

import numpy as np

from numpyro.diagnostics import summary

import labelshift.algorithms.api as algo
import labelshift.experiments.api as exp
@@ -22,7 +22,7 @@ ESTIMATORS = {
"BBS": algo.BlackBoxShiftEstimator(),
"CC": algo.ClassifyAndCount(),
"RIR": algo.InvariantRatioEstimator(restricted=True),
"BAY": algo.DiscreteCategoricalMeanEstimator(),
"BAY": algo.DiscreteCategoricalMeanEstimator(params=algo.SamplingParams(chains=4)),
}
ESTIMATOR_COLORS = {
"BBS": "orangered",
@@ -146,6 +146,7 @@ _k_vals = [2, 3, 5, 7, 9]
_quality = [0.55, 0.65, 0.75, 0.85, 0.95]
_quality_prime = [0.45, 0.55, 0.65, 0.75, 0.80, 0.85, 0.90, 0.95]


BENCHMARKS = {
"change_prevalence": BenchmarkSettings(
param_name="Prevalence $\\pi'_1$",
@@ -179,6 +180,7 @@ BENCHMARKS = {
),
}


def get_data_setting(benchmark: str, param: int | str) -> DataSetting:
return BENCHMARKS[str(benchmark)].settings[int(param)]

@@ -234,6 +236,14 @@ rule apply_estimator:
elapsed_time = timer.check()
run_ok = True
additional_info = {}

if hasattr(estimator, "get_mcmc"):
samples = estimator.get_mcmc().get_samples(group_by_chain=True)
summ = summary(samples)
n_eff_list = [np.min(d["n_eff"]) for d in summ.values()]
r_hat_list = [np.max(d["r_hat"]) for d in summ.values()]
additional_info = additional_info | {"min_n_eff": min(n_eff_list), "max_r_hat": max(r_hat_list)}

except Exception as e:
elapsed_time = float("nan")
estimate = np.full_like(data.n_y_labeled, fill_value=float("nan"))
@@ -267,9 +277,13 @@ def _get_paths_to_be_assembled(wildcards):
rule assemble_results:
output:
csv = "results/benchmark-{benchmark}-metric-{metric}.csv",
-err = "results/status/benchmark-{benchmark}-metric-{metric}.txt"
+err = "results/status/benchmark-{benchmark}-metric-{metric}.txt",
+convergence = "results/convergence/benchmark-{benchmark}-metric-{metric}.txt",
input: _get_paths_to_be_assembled
run:
max_r_hat = -1e9
min_n_eff = 1e9

results = []
for pth in input:
res = joblib.load(pth)
@@ -285,6 +299,10 @@
}
results.append(nice)

if "max_r_hat" in res.additional_info:
max_r_hat = max(max_r_hat, res.additional_info["max_r_hat"])
min_n_eff = min(min_n_eff, res.additional_info["min_n_eff"])

results = pd.DataFrame(results)

df_ok = results[results["run_ok"]]
@@ -298,6 +316,9 @@
df_ok = df_ok.drop(columns=["run_ok", "additional_info"])
df_ok.to_csv(str(output.csv), index=False)

with open(output.convergence, "w") as f:
f.write(f"Max r_hat: {max_r_hat}\n")
f.write(f"Min n_eff: {min_n_eff}\n")


def plot_results(ax, df, plot_std: bool = True, alpha: float = 0.5):
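For readers unfamiliar with `numpyro.diagnostics.summary`, here is a self-contained sketch (synthetic data, illustrative names) of the reduction used in `apply_estimator` and `assemble_results` above: `summary` maps each sample site to per-coordinate statistics, including `"n_eff"` and `"r_hat"`, and the workflow keeps only the worst value of each across all sites.

```python
import numpy as np
from numpyro.diagnostics import summary

rng = np.random.default_rng(0)
# Fake posterior draws grouped by chain: each site has shape (chains, draws, ...).
samples = {
    "P_test_Y": rng.dirichlet(np.ones(3), size=(4, 1000)),
    "sigma": rng.gamma(2.0, size=(4, 1000)),
}

summ = summary(samples, group_by_chain=True)  # site -> {"mean", ..., "n_eff", "r_hat"}
min_n_eff = min(float(np.min(d["n_eff"])) for d in summ.values())
max_r_hat = max(float(np.max(d["r_hat"])) for d in summ.values())
print({"min_n_eff": min_n_eff, "max_r_hat": max_r_hat})
```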
workflows/misspecified.smk (134 changes: 107 additions & 27 deletions)
@@ -3,7 +3,7 @@
# ----------------------------------------------------------------------------------
from dataclasses import dataclass
import numpy as np
-import pandas as pd
+import json
import joblib

import matplotlib
@@ -14,7 +14,7 @@ matplotlib.use("agg")
import jax
import numpyro
import numpyro.distributions as dist

from numpyro.diagnostics import summary

import labelshift.algorithms.api as algo
from labelshift.datasets.discrete_categorical import SummaryStatistic
@@ -57,18 +57,18 @@ N_POINTS = [100, 1000, 10_000]
PI_LABELED = 0.5
PI_UNLABELED = 0.2

-N_MCMC_WARMUP = 500
-N_MCMC_SAMPLES = 1000
+N_MCMC_WARMUP = 1500
+N_MCMC_SAMPLES = 2000
+N_MCMC_CHAINS = 4


COVERAGES = np.arange(0.05, 0.96, 0.05)


rule all:
-input: expand("plots/{n_points}.pdf", n_points=N_POINTS)

-# rule all:
-#     input: expand("figures/{setting}-{seed}.pdf", setting=SETTINGS.keys(), seed=SEEDS)
+input:
+    plots = expand("plots/{n_points}.pdf", n_points=N_POINTS),
+    convergence = "convergence_overall.json",


rule generate_data:
@@ -82,7 +82,7 @@ rule generate_data:

def gaussian_model(observed: Data, unobserved: np.ndarray):
sigma = numpyro.sample('sigma', dist.HalfCauchy(np.ones(2)))
-mu = numpyro.sample('mu', dist.Normal(np.zeros(2), 1))
+mu = numpyro.sample('mu', dist.Normal(np.zeros(2), 3))

pi = numpyro.sample(algo.DiscreteCategoricalMeanEstimator.P_TEST_Y, dist.Dirichlet(np.ones(2)))

@@ -101,7 +101,7 @@ def gaussian_model(observed: Data, unobserved: np.ndarray):
def student_model(observed: Data, unobserved: np.ndarray):
df = numpyro.sample('df', dist.Gamma(np.ones(2), np.ones(2)))
sigma = numpyro.sample('sigma', dist.HalfCauchy(np.ones(2)))
-mu = numpyro.sample('mu', dist.Normal(np.zeros(2), 1))
+mu = numpyro.sample('mu', dist.Normal(np.zeros(2), 3))

pi = numpyro.sample(algo.DiscreteCategoricalMeanEstimator.P_TEST_Y, dist.Dirichlet(np.ones(2)))

@@ -117,36 +117,55 @@ def student_model(observed: Data, unobserved: np.ndarray):
numpyro.sample('x', mixture, obs=unobserved)
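
Both models are only partially visible in this diff. Their shared pattern is a two-component mixture whose weights come from the sampled test prevalence, and the PR widens the `mu` prior from `Normal(0, 1)` to `Normal(0, 3)`. A minimal self-contained sketch of the pattern (names and shapes are assumptions, not the repo's exact code):

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def mixture_sketch(xs):
    # Component parameters; Normal(0, 3) mirrors the widened prior above.
    mu = numpyro.sample("mu", dist.Normal(jnp.zeros(2), 3.0))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(jnp.ones(2)))
    # Class prevalence in the unlabeled data doubles as the mixture weights.
    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(2)))
    mixture = dist.MixtureSameFamily(dist.Categorical(probs=pi), dist.Normal(mu, sigma))
    numpyro.sample("x", mixture, obs=xs)
```

As in the rules below, such a model would be handed to `numpyro.infer.NUTS` and `numpyro.infer.MCMC`.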


def generate_summary(samples):
summ = summary(samples)
n_eff_list = [float(np.min(d["n_eff"])) for d in summ.values()]
r_hat_list = [float(np.max(d["r_hat"])) for d in summ.values()]
return {"min_n_eff": min(n_eff_list), "max_r_hat": max(r_hat_list)}

rule run_gaussian_mcmc:
input: "data/{n_points}/{seed}.npy"
-output: "samples/{n_points}/Gaussian/{seed}.npy"
+output:
+    samples = "samples/{n_points}/Gaussian/{seed}.npy",
+    convergence = "convergence/{n_points}/Gaussian/{seed}.joblib",
run:
data_labeled, data_unlabeled = joblib.load(str(input))
mcmc = numpyro.infer.MCMC(
numpyro.infer.NUTS(gaussian_model),
num_warmup=N_MCMC_WARMUP,
num_samples=N_MCMC_SAMPLES,
num_chains=N_MCMC_CHAINS,
)
rng_key = jax.random.PRNGKey(int(wildcards.seed) + 101)
mcmc.run(rng_key, observed=data_labeled, unobserved=data_unlabeled.xs)
samples = mcmc.get_samples()
-joblib.dump(samples, str(output))
+joblib.dump(samples, output.samples)
+
+summ = generate_summary(mcmc.get_samples(group_by_chain=True))
+joblib.dump(summ, output.convergence)


rule run_student_mcmc:
input: "data/{n_points}/{seed}.npy"
-output: "samples/{n_points}/Student/{seed}.npy"
-run:
+output:
+    samples = "samples/{n_points}/Student/{seed}.npy",
+    convergence = "convergence/{n_points}/Student/{seed}.joblib",
+run:
data_labeled, data_unlabeled = joblib.load(str(input))
mcmc = numpyro.infer.MCMC(
numpyro.infer.NUTS(student_model),
num_warmup=N_MCMC_WARMUP,
num_samples=N_MCMC_SAMPLES,
num_chains=N_MCMC_CHAINS,
)
rng_key = jax.random.PRNGKey(int(wildcards.seed) + 101)
mcmc.run(rng_key, observed=data_labeled, unobserved=data_unlabeled.xs)
samples = mcmc.get_samples()
-joblib.dump(samples, str(output))
+joblib.dump(samples, output.samples)
+
+summ = generate_summary(mcmc.get_samples(group_by_chain=True))
+joblib.dump(summ, output.convergence)



def _calculate_bins(n: int):
@@ -169,15 +188,24 @@ def generate_summary_statistic(

rule run_discrete_mcmc:
input: "data/{n_points}/{seed}.npy"
-output: "samples/{n_points}/Discrete-{n_bins}/{seed}.npy"
+output:
+    samples = "samples/{n_points}/Discrete-{n_bins}/{seed}.npy",
+    convergence = "convergence/{n_points}/Discrete-{n_bins}/{seed}.joblib",
run:
data_labeled, data_unlabeled = joblib.load(str(input))
estimator = algo.DiscreteCategoricalMeanEstimator(
seed=int(wildcards.seed) + 101,
-params=algo.SamplingParams(warmup=N_MCMC_WARMUP, samples=N_MCMC_SAMPLES),
+params=algo.SamplingParams(
+    warmup=N_MCMC_WARMUP,
+    samples=N_MCMC_SAMPLES,
+    chains=N_MCMC_CHAINS,
+),
)
samples = estimator.sample_posterior(generate_summary_statistic(data_labeled, data_unlabeled.xs, int(wildcards.n_bins)))
-joblib.dump(samples, str(output))
+joblib.dump(samples, output.samples)
+
+summ = generate_summary(estimator.get_mcmc().get_samples(group_by_chain=True))
+joblib.dump(summ, output.convergence)


def calculate_hdi(arr, prob: float) -> tuple[float, float]:
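
The body of `calculate_hdi` is collapsed in this view. For reference, the standard narrowest-window construction of a highest-density interval from posterior draws looks like the sketch below; this is the textbook algorithm, not necessarily this repo's exact implementation.

```python
import numpy as np

def hdi_sketch(arr, prob: float) -> tuple[float, float]:
    # Among all windows spanning floor(prob * n) consecutive sorted draws,
    # return the narrowest one.
    xs = np.sort(np.asarray(arr))
    window = max(1, int(np.floor(prob * len(xs))))
    widths = xs[window:] - xs[: len(xs) - window]
    start = int(np.argmin(widths))
    return float(xs[start]), float(xs[start + window])
```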
@@ -196,12 +224,17 @@


rule contains_ground_truth:
-input: "samples/{n_points}/{algorithm}/{seed}.npy"
+input:
+    samples = "samples/{n_points}/{algorithm}/{seed}.npy",
+    convergence = "convergence/{n_points}/{algorithm}/{seed}.joblib",
output: "contains/{n_points}/{algorithm}/{seed}.joblib"
run:
-samples = joblib.load(str(input))
+samples = joblib.load(input.samples)
+convergence = joblib.load(input.convergence)
+run_ok = convergence["max_r_hat"] < 1.02

pi_samples = samples[algo.DiscreteCategoricalMeanEstimator.P_TEST_Y][:, 1]

results = []
intervals = []
for coverage in COVERAGES:
@@ -212,7 +245,7 @@ rule contains_ground_truth:

results = np.asarray(results, dtype=float)
intervals = np.asarray(intervals, dtype=float)
-joblib.dump((results, intervals), str(output))
+joblib.dump((results, intervals, run_ok), str(output))
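
The `max_r_hat < 1.02` gate above is slightly looser than the r_hat < 1.01 rule of thumb from Vehtari et al. (2021). A tiny standalone demonstration of what the gate catches, on synthetic chains:

```python
import numpy as np
from numpyro.diagnostics import split_gelman_rubin

rng = np.random.default_rng(1)
good = rng.normal(size=(4, 1000))  # four well-mixed chains
bad = good.copy()
bad[0] += 5.0                      # one chain stuck in a different mode

print(split_gelman_rubin(good))    # close to 1.0: the run would be kept
print(split_gelman_rubin(bad))     # far above 1.02: the run would be excluded
```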


def _input_paths_calculate_coverages(wildcards):
@@ -221,15 +254,62 @@

rule calculate_coverages:
input: _input_paths_calculate_coverages
-output: "coverages/{n_points}/{algorithm}.npy"
+output:
+    coverages = "coverages/{n_points}/{algorithm}.npy",
+    excluded_runs = "excluded/{n_points}-{algorithm}.json"
run:
results = []

+ok_runs = 0
+excluded_runs = 0
for pth in input:
-    res, _ = joblib.load(pth)
-    results.append(res)
+    res, _, run_ok = joblib.load(pth)
+    if run_ok:
+        results.append(res)
+        ok_runs += 1
+    else:
+        excluded_runs += 1

results = np.asarray(results)
coverages = results.mean(axis=0)
-np.save(str(output), coverages)
+np.save(output.coverages, coverages)
+
+with open(output.excluded_runs, "w") as fh:
+    json.dump({"excluded_runs": excluded_runs, "ok_runs": ok_runs}, fh)
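
In isolation, the coverage reduction above works like this: each run that passed the r_hat gate contributes a row of indicators, one entry per nominal level in `COVERAGES`, and the column mean is the empirical coverage (toy numbers, illustrative only).

```python
import numpy as np

# Rows: surviving runs; columns: nominal credibility levels.
contains = np.array([
    [1, 1, 0],
    [1, 0, 0],
    [1, 1, 1],
    [1, 1, 0],
], dtype=float)

print(contains.mean(axis=0))  # empirical coverage per level: [1.0, 0.75, 0.25]
```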

def _input_paths_summarize_convergence(wildcards):
return [f"convergence/{wildcards.n_points}/{wildcards.algorithm}/{seed}.joblib" for seed in SEEDS]


rule summarize_convergence:
input: _input_paths_summarize_convergence
output: "convergence/{n_points}/{algorithm}.json"
run:
min_n_effs = []
max_r_hats = []
for pth in input:
res = joblib.load(pth)
min_n_effs.append(res["min_n_eff"])
max_r_hats.append(res["max_r_hat"])

with open(str(output), "w") as fh:
json.dump({"min_n_eff": min(min_n_effs), "max_r_hat": max(max_r_hats)}, fh)


rule summarize_convergence_overall:
input: expand("convergence/{n_points}/{algorithm}.json", n_points=N_POINTS, algorithm=["Gaussian", "Student", "Discrete-5", "Discrete-10"])
output: "convergence_overall.json"
run:
min_n_effs = []
max_r_hats = []
for pth in input:
with open(pth) as fh:
res = json.load(fh)
min_n_effs.append(res["min_n_eff"])
max_r_hats.append(res["max_r_hat"])

with open(str(output), "w") as fh:
json.dump({"min_n_eff": min(min_n_effs), "max_r_hat": max(max_r_hats)}, fh)
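
The overall report reduces the whole study to two numbers. One possible consumer, sketched here as an assumption (the threshold mirrors the per-run gate in `contains_ground_truth`):

```python
import json

with open("convergence_overall.json") as fh:
    report = json.load(fh)

# Fail loudly if any chain anywhere looked unconverged.
assert report["max_r_hat"] < 1.02, report
```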

rule plot_coverage:
input:
@@ -243,7 +323,7 @@ rule plot_coverage:
sample_discrete10 = "samples/{n_points}/Discrete-10/1.npy",
output: "plots/{n_points}.pdf"
run:
-fig, axs = subplots_from_axsize(axsize=(2, 1), wspace=[0.2, 0.3, 0.6], dpi=150, left=0.2, top=0.3, right=1.8)
+fig, axs = subplots_from_axsize(axsize=(2, 1), wspace=[0.2, 0.3, 0.6], dpi=400, left=0.2, top=0.3, right=1.8)
axs = axs.ravel()

# Conditional distributions P(X|Y)