Skip to content

Commit

Permalink
Merge pull request #138 from fact-project/disp_plot_fixes
Browse files Browse the repository at this point in the history
Calculate errorbars for disp metrics, fix unit, drop bins with too few events
  • Loading branch information
maxnoe authored May 29, 2020
2 parents 393a621 + aa1dc7a commit 66daf64
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
36 changes: 32 additions & 4 deletions aict_tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import warnings

from sklearn import metrics
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.calibration import CalibratedClassifierCV

from .preprocessing import horizontal_to_camera
Expand Down Expand Up @@ -97,6 +99,7 @@ def plot_bias_resolution(
binned['upper_sigma'] = grouped['rel_error'].agg(lambda s: np.percentile(s, 85))
binned['resolution_quantiles'] = (binned.upper_sigma - binned.lower_sigma) / 2
binned['resolution'] = grouped['rel_error'].std()
binned = binned[grouped.count() > 5] # at least five events

for key in ('bias', 'resolution', 'resolution_quantiles'):
if matplotlib.get_backend() == 'pgf' or plt.rcParams['text.usetex']:
Expand Down Expand Up @@ -335,20 +338,45 @@ def r2(group):
'e_width': np.diff(edges),
}, index=pd.Series(np.arange(1, len(edges)), name='bin_idx'))

binned['accuracy'] = df.groupby('bin_idx').apply(accuracy)
binned['r2_score'] = df.groupby('bin_idx').apply(r2)
r2_scores = pd.DataFrame(index=binned.index)
accuracies = pd.DataFrame(index=binned.index)
counts = pd.DataFrame(index=binned.index)

with warnings.catch_warnings():
# warns when there are less than 2 events for calculating metrics,
# but we throw those away anyways
warnings.filterwarnings('ignore', category=UndefinedMetricWarning)
for cv_fold, cv in df.groupby('cv_fold'):
grouped = cv.groupby('bin_idx')
accuracies[cv_fold] = grouped.apply(accuracy)
r2_scores[cv_fold] = grouped.apply(r2)
counts[cv_fold] = grouped.size()

binned['r2_score'] = r2_scores.mean(axis=1)
binned['r2_std'] = r2_scores.std(axis=1)
binned['accuracy'] = accuracies.mean(axis=1)
binned['accuracy_std'] = accuracies.std(axis=1)
# at least 10 events in each crossval iteration
binned['valid'] = (counts > 10).any(axis=1)
binned = binned.query('valid')

fig = fig or plt.figure()

ax1 = fig.add_subplot(2, 1, 1)
ax2 = fig.add_subplot(2, 1, 2, sharex=ax1)

ax1.errorbar(
binned.e_center, binned.accuracy, xerr=binned.e_width / 2, ls='',
binned.e_center, binned.accuracy,
yerr=binned.accuracy_std, xerr=binned.e_width / 2,
ls='',
)
ax1.set_ylabel(r'Accuracy for $\mathrm{sgn} \mathtt{disp}$')

ax2.errorbar(binned.e_center, binned.r2_score, xerr=binned.e_width / 2, ls='')
ax2.errorbar(
binned.e_center, binned.r2_score,
yerr=binned.r2_std, xerr=binned.e_width / 2,
ls='',
)
ax2.set_ylabel(r'$r^2$ score for $|\mathtt{disp}|$')

ax2.set_xlabel(
Expand Down
6 changes: 4 additions & 2 deletions aict_tools/scripts/plot_disp_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def main(configuration_path, performance_path, data_path, disp_model_path, sign_
df = fact.io.read_data(performance_path, key=key)

columns = model_config.columns_to_read_train

if model_config.coordinate_transformation == 'CTA':
camera_unit = r'\mathrm{m}'
else:
Expand Down Expand Up @@ -123,7 +123,9 @@ def main(configuration_path, performance_path, data_path, disp_model_path, sign_
plot_true_delta_delta(df_data, model_config, ax)

if config.true_energy_column in df.columns:
fig = plot_energy_dependent_disp_metrics(df, config.true_energy_column)
fig = plot_energy_dependent_disp_metrics(
df, config.true_energy_column, energy_unit=config.energy_unit
)
figures.append(fig)

if output is None:
Expand Down

0 comments on commit 66daf64

Please sign in to comment.