Skip to content

Commit

Permalink
conditions to avoid of raise issue if num_years_to_plot too high
Browse files Browse the repository at this point in the history
  • Loading branch information
jdebacker committed Apr 13, 2024
1 parent 8f25b15 commit 7e7ccdb
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions ogcore/output_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def plot_aggregates(
"""
assert isinstance(start_year, (int, np.integer))
assert isinstance(num_years_to_plot, int)
assert num_years_to_plot <= base_params.T
# Make sure both runs cover same time period
if reform_tpi:
assert base_params.start_year == reform_params.start_year
Expand Down Expand Up @@ -221,6 +222,7 @@ def plot_industry_aggregates(
"""
assert isinstance(start_year, (int, np.integer))
assert isinstance(num_years_to_plot, int)
assert num_years_to_plot <= base_params.T
dims = base_tpi[var_list[0]].shape[1]
if ind_names_list:
assert len(ind_names_list) == dims
Expand Down Expand Up @@ -438,6 +440,7 @@ def plot_gdp_ratio(
"""
assert isinstance(start_year, (int, np.integer))
assert isinstance(num_years_to_plot, int)
assert num_years_to_plot <= base_params.T
if plot_type == "diff":
assert reform_tpi is not None
# Make sure both runs cover same time period
Expand Down Expand Up @@ -865,7 +868,7 @@ def plot_all(base_output_path, reform_output_path, save_path):
reform_params=reform_params,
var_list=["Y", "K", "L", "C"],
plot_type="pct_diff",
num_years_to_plot=150,
num_years_to_plot=min(base_params.T, 150),
start_year=base_params.start_year,
vertical_line_years=[
base_params.start_year + base_params.tG1,
Expand All @@ -883,7 +886,7 @@ def plot_all(base_output_path, reform_output_path, save_path):
reform_params=reform_params,
var_list=["D", "G", "TR", "total_tax_revenue"],
plot_type="pct_diff",
num_years_to_plot=150,
num_years_to_plot=min(base_params.T, 150),
start_year=base_params.start_year,
vertical_line_years=[
base_params.start_year + base_params.tG1,
Expand All @@ -901,7 +904,7 @@ def plot_all(base_output_path, reform_output_path, save_path):
reform_params=reform_params,
var_list=["r"],
plot_type="levels",
num_years_to_plot=150,
num_years_to_plot=min(base_params.T, 150),
start_year=base_params.start_year,
vertical_line_years=[
base_params.start_year + base_params.tG1,
Expand All @@ -918,7 +921,7 @@ def plot_all(base_output_path, reform_output_path, save_path):
reform_params=reform_params,
var_list=["w"],
plot_type="levels",
num_years_to_plot=150,
num_years_to_plot=min(base_params.T, 150),
start_year=base_params.start_year,
vertical_line_years=[
base_params.start_year + base_params.tG1,
Expand All @@ -935,7 +938,7 @@ def plot_all(base_output_path, reform_output_path, save_path):
reform_tpi,
reform_params,
var_list=["D"],
num_years_to_plot=150,
num_years_to_plot=min(base_params.T, 150),
start_year=base_params.start_year,
vertical_line_years=[
base_params.start_year + base_params.tG1,
Expand All @@ -952,7 +955,7 @@ def plot_all(base_output_path, reform_output_path, save_path):
reform_tpi,
reform_params,
var_list=["total_tax_revenue"],
num_years_to_plot=150,
num_years_to_plot=min(base_params.T, 150),
start_year=base_params.start_year,
vertical_line_years=[
base_params.start_year + base_params.tG1,
Expand Down Expand Up @@ -1096,6 +1099,7 @@ def inequality_plot(
"""
assert isinstance(start_year, (int, np.integer))
assert isinstance(num_years_to_plot, int)
assert num_years_to_plot <= base_params.T
# Make sure both runs cover same time period
if reform_tpi:
assert base_params.start_year == reform_params.start_year
Expand Down

0 comments on commit 7e7ccdb

Please sign in to comment.