From 7e7ccdbe464314a801045bdc21da2260b62f4a70 Mon Sep 17 00:00:00 2001 From: jdebacker Date: Fri, 12 Apr 2024 21:20:53 -0400 Subject: [PATCH] conditions to avoid of raise issue if num_years_to_plot too high --- ogcore/output_plots.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/ogcore/output_plots.py b/ogcore/output_plots.py index 661044c3f..020948f44 100644 --- a/ogcore/output_plots.py +++ b/ogcore/output_plots.py @@ -66,6 +66,7 @@ def plot_aggregates( """ assert isinstance(start_year, (int, np.integer)) assert isinstance(num_years_to_plot, int) + assert num_years_to_plot <= base_params.T # Make sure both runs cover same time period if reform_tpi: assert base_params.start_year == reform_params.start_year @@ -221,6 +222,7 @@ def plot_industry_aggregates( """ assert isinstance(start_year, (int, np.integer)) assert isinstance(num_years_to_plot, int) + assert num_years_to_plot <= base_params.T dims = base_tpi[var_list[0]].shape[1] if ind_names_list: assert len(ind_names_list) == dims @@ -438,6 +440,7 @@ def plot_gdp_ratio( """ assert isinstance(start_year, (int, np.integer)) assert isinstance(num_years_to_plot, int) + assert num_years_to_plot <= base_params.T if plot_type == "diff": assert reform_tpi is not None # Make sure both runs cover same time period @@ -865,7 +868,7 @@ def plot_all(base_output_path, reform_output_path, save_path): reform_params=reform_params, var_list=["Y", "K", "L", "C"], plot_type="pct_diff", - num_years_to_plot=150, + num_years_to_plot=min(base_params.T, 150), start_year=base_params.start_year, vertical_line_years=[ base_params.start_year + base_params.tG1, @@ -883,7 +886,7 @@ def plot_all(base_output_path, reform_output_path, save_path): reform_params=reform_params, var_list=["D", "G", "TR", "total_tax_revenue"], plot_type="pct_diff", - num_years_to_plot=150, + num_years_to_plot=min(base_params.T, 150), start_year=base_params.start_year, vertical_line_years=[ base_params.start_year + base_params.tG1, @@ -901,7 +904,7 @@ def plot_all(base_output_path, reform_output_path, save_path): reform_params=reform_params, var_list=["r"], plot_type="levels", - num_years_to_plot=150, + num_years_to_plot=min(base_params.T, 150), start_year=base_params.start_year, vertical_line_years=[ base_params.start_year + base_params.tG1, @@ -918,7 +921,7 @@ def plot_all(base_output_path, reform_output_path, save_path): reform_params=reform_params, var_list=["w"], plot_type="levels", - num_years_to_plot=150, + num_years_to_plot=min(base_params.T, 150), start_year=base_params.start_year, vertical_line_years=[ base_params.start_year + base_params.tG1, @@ -935,7 +938,7 @@ def plot_all(base_output_path, reform_output_path, save_path): reform_tpi, reform_params, var_list=["D"], - num_years_to_plot=150, + num_years_to_plot=min(base_params.T, 150), start_year=base_params.start_year, vertical_line_years=[ base_params.start_year + base_params.tG1, @@ -952,7 +955,7 @@ def plot_all(base_output_path, reform_output_path, save_path): reform_tpi, reform_params, var_list=["total_tax_revenue"], - num_years_to_plot=150, + num_years_to_plot=min(base_params.T, 150), start_year=base_params.start_year, vertical_line_years=[ base_params.start_year + base_params.tG1, @@ -1096,6 +1099,7 @@ def inequality_plot( """ assert isinstance(start_year, (int, np.integer)) assert isinstance(num_years_to_plot, int) + assert num_years_to_plot <= base_params.T # Make sure both runs cover same time period if reform_tpi: assert base_params.start_year == reform_params.start_year