Skip to content

Commit

Permalink
Merge pull request PSLmodels#914 from jdebacker/param_plots
Browse files Browse the repository at this point in the history
Parameter plot fixes and extensions
  • Loading branch information
jdebacker authored Mar 19, 2024
2 parents 5a811f6 + 7de7a7c commit fdc00e2
Show file tree
Hide file tree
Showing 8 changed files with 353 additions and 61 deletions.
3 changes: 2 additions & 1 deletion docs/book/content/api/parameter_plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ ogcore.parameter_plots
.. automodule:: ogcore.parameter_plots
:members: plot_imm_rates, plot_mort_rates, plot_pop_growth,
plot_ability_profiles, plot_elliptical_u, plot_chi_n,
plot_fert_rates, plot_mort_rates_data, plot_omega_fixed,
plot_fert_rates, plot_mort_rates_data, plot_g_n, plot_omega_fixed,
plot_imm_fixed, plot_population_path, gen_3Dscatters_hist,
txfunc_graph, txfunc_sse_plot, plot_income_data, plot_2D_taxfunc

23 changes: 18 additions & 5 deletions ogcore/output_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def plot_aggregates(
reform_params=None,
var_list=["Y", "C", "K", "L"],
plot_type="pct_diff",
stationarized=True,
num_years_to_plot=50,
start_year=DEFAULT_START_YEAR,
forecast_data=None,
Expand All @@ -40,13 +41,15 @@ def plot_aggregates(
object
var_list (list): names of variable to plot
plot_type (string): type of plot, can be:
'pct_diff': plots percentage difference between baselien
'pct_diff': plots percentage difference between baseline
and reform ((reform-base)/base)
'diff': plots difference between baseline and reform
(reform-base)
'levels': plot variables in model units
'forecast': plots variables in levels relative to baseline
economic forecast
stationarized (bool): whether used stationarized variables (False
only affects pct_diff right now)
num_years_to_plot (integer): number of years to include in plot
start_year (integer): year to start plot
forecast_data (array_like): baseline economic forecast series,
Expand Down Expand Up @@ -78,11 +81,21 @@ def plot_aggregates(
# Compute just percentage point changes for rates
plot_var = reform_tpi[v] - base_tpi[v]
else:
plot_var = (reform_tpi[v] - base_tpi[v]) / base_tpi[v]
if stationarized:
plot_var = (reform_tpi[v] - base_tpi[v]) / base_tpi[v]
else:
pct_changes = utils.pct_change_unstationarized(
base_tpi,
base_params,
reform_tpi,
reform_params,
output_vars=[v],
)
plot_var = pct_changes[v]
ylabel = r"Pct. change"
plt.plot(
year_vec,
plot_var[start_index : start_index + num_years_to_plot],
plot_var[start_index : start_index + num_years_to_plot] * 100,
label=VAR_LABELS[v],
)
elif plot_type == "diff":
Expand Down Expand Up @@ -185,7 +198,7 @@ def plot_industry_aggregates(
object
var_list (list): names of variable to plot
plot_type (string): type of plot, can be:
'pct_diff': plots percentage difference between baselien
'pct_diff': plots percentage difference between baseline
and reform ((reform-base)/base)
'diff': plots difference between baseline and reform
(reform-base)
Expand Down Expand Up @@ -343,7 +356,7 @@ def ss_3Dplot(
reform_ss (dictionary): SS output from reform run
var (string): name of variable to plot
plot_type (string): type of plot, can be:
'pct_diff': plots percentage difference between baselien
'pct_diff': plots percentage difference between baseline
and reform ((reform-base)/base)
'diff': plots difference between baseline and reform (reform-base)
'levels': plot variables in model units
Expand Down
16 changes: 15 additions & 1 deletion ogcore/output_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ogcore.constants import VAR_LABELS, DEFAULT_START_YEAR
from ogcore import tax
from ogcore.utils import save_return_table, Inequality
from ogcore.utils import pct_change_unstationarized

cur_path = os.path.split(os.path.abspath(__file__))[0]

Expand All @@ -15,6 +16,7 @@ def macro_table(
reform_params=None,
var_list=["Y", "C", "K", "L", "r", "w"],
output_type="pct_diff",
stationarized=True,
num_years=10,
include_SS=True,
include_overall=True,
Expand All @@ -38,6 +40,8 @@ def macro_table(
and reform ((reform-base)/base)
'diff': plots difference between baseline and reform (reform-base)
'levels': variables in model units
stationarized (bool): whether used stationarized variables (False
only affects pct_diff right now)
num_years (integer): number of years to include in table
include_SS (bool): whether to include the steady-state results
in the table
Expand Down Expand Up @@ -72,7 +76,17 @@ def macro_table(
for i, v in enumerate(var_list):
if output_type == "pct_diff":
# multiple by 100 so in percentage points
results = ((reform_tpi[v] - base_tpi[v]) / base_tpi[v]) * 100
if stationarized:
results = ((reform_tpi[v] - base_tpi[v]) / base_tpi[v]) * 100
else:
pct_changes = pct_change_unstationarized(
base_tpi,
base_params,
reform_tpi,
reform_params,
output_vars=[v],
)
results = pct_changes[v] * 100
results_years = results[start_index : start_index + num_years]
results_overall = (
(
Expand Down
154 changes: 130 additions & 24 deletions ogcore/parameter_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,36 +64,62 @@ def plot_imm_rates(


def plot_mort_rates(
p, years=[DEFAULT_START_YEAR], include_title=False, path=None
p_list,
labels=[""],
years=[DEFAULT_START_YEAR],
survival_rates=False,
include_title=False,
path=None,
):
"""
Create a plot of mortality rates from OG-Core parameterization.
Args:
p (OG-Core Specifications class): parameters object
p_list (list): list of parameters objects
labels (list): list of labels for the legend
survival_rates (bool): whether to plot survival rates instead
of mortality rates
include_title (bool): whether to include a title in the plot
path (string): path to save figure to
Returns:
fig (Matplotlib plot object): plot of immigration rates
fig (Matplotlib plot object): plot of mortality rates
"""
age_per = np.linspace(p.E, p.E + p.S, p.S)
years = np.array(years) - p.start_year
p0 = p_list[0]
age_per = np.linspace(p0.E, p0.E + p0.S, p0.S)
fig, ax = plt.subplots()
for y in years:
plt.plot(age_per, p.rho[y, :], label=str(y + p.start_year))
t = y - p0.start_year
for i, p in enumerate(p_list):
if survival_rates:
plt.plot(
age_per,
np.cumprod(1 - p.rho[t, :]),
label=labels[i] + " " + str(y),
)
else:
plt.plot(age_per, p.rho[t, :], label=labels[i] + " " + str(y))
plt.xlabel(r"Age $s$ (model periods)")
plt.ylabel(r"Mortality Rates $\rho_{s}$")
plt.legend(loc="upper right")
if survival_rates:
plt.ylabel(r"Cumulative Survival Rates")
plt.legend(loc="lower left")
title = "Survival Rates"
else:
plt.ylabel(r"Mortality Rates $\rho_{s}$")
plt.legend(loc="upper right")
title = "Mortality Rates"
vals = ax.get_yticks()
ax.set_yticklabels(["{:,.0%}".format(x) for x in vals])
if include_title:
plt.title("Mortality Rates")
plt.title(title)
if path is None:
return fig
else:
fig_path = os.path.join(path, "mortality_rates")
if survival_rates:
fig_path = os.path.join(path, "survival_rates")
else:
fig_path = os.path.join(path, "mortality_rates")
plt.savefig(fig_path, dpi=300)


Expand Down Expand Up @@ -176,13 +202,17 @@ def plot_population(p, years_to_plot=["SS"], include_title=False, path=None):
plt.savefig(fig_path, dpi=300)


def plot_ability_profiles(p, t=None, include_title=False, path=None):
def plot_ability_profiles(
p, p2=None, t=None, log_scale=False, include_title=False, path=None
):
"""
Create a plot of earnings ability profiles.
Args:
p (OG-Core Specifications class): parameters object
t (int): model period for year, if None, then plot ability matrix for SS
log_scale (bool): whether to plot in log points
include_title (bool): whether to include a title in the plot
path (string): path to save figure to
Returns:
Expand All @@ -196,10 +226,32 @@ def plot_ability_profiles(p, t=None, include_title=False, path=None):
cm = plt.get_cmap("coolwarm")
ax.set_prop_cycle(color=[cm(1.0 * i / p.J) for i in range(p.J)])
for j in range(p.J):
plt.plot(age_vec, p.e[t, :, j], label=GROUP_LABELS[p.J][j])
if log_scale:
plt.plot(age_vec, np.log(p.e[t, :, j]), label=GROUP_LABELS[p.J][j])
else:
plt.plot(age_vec, p.e[t, :, j], label=GROUP_LABELS[p.J][j])
if p2 is not None:
for j in range(p.J):
if log_scale:
plt.plot(
age_vec,
np.log(p2.e[t, :, j]),
linestyle="--",
label=GROUP_LABELS[p.J][j],
)
else:
plt.plot(
age_vec,
p2.e[t, :, j],
linestyle="--",
label=GROUP_LABELS[p.J][j],
)
plt.xlabel(r"Age")
plt.ylabel(r"Earnings ability")
plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2)
if log_scale:
plt.ylabel(r"ln(Earnings ability)")
else:
plt.ylabel(r"Earnings ability")
plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncols=5)
if include_title:
plt.title("Lifecycle Profiles of Effective Labor Units")
if path is None:
Expand Down Expand Up @@ -267,21 +319,37 @@ def plot_elliptical_u(p, plot_MU=True, include_title=False, path=None):
plt.savefig(fig_path, dpi=300)


def plot_chi_n(p, include_title=False, path=None):
def plot_chi_n(
p_list,
labels=[""],
years_to_plot=[DEFAULT_START_YEAR],
include_title=False,
path=None,
):
"""
Create a plot of showing the values of the chi_n parameters.
Args:
p (OG-Core Specifications class): parameters object
p_list (list): parameters objects
labels (list): labels for legend
years_to_plot (list): list of years to plot
include_title (boolean): whether to include a title in the plot
path (string): path to save figure to
Returns:
fig (Matplotlib plot object): plot of chi_n parameters
"""
age = np.linspace(p.starting_age, p.ending_age, p.S)
p0 = p_list[0]
age = np.linspace(p0.starting_age, p0.ending_age, p0.S)
fig, ax = plt.subplots()
plt.plot(age, p.chi_n)
for y in years_to_plot:
for i, p in enumerate(p_list):
plt.plot(
age,
p.chi_n[y - p.start_year, :],
label=labels[i] + " " + str(y),
)
if include_title:
plt.title("Utility Weight on the Disutility of Labor Supply")
plt.xlabel("Age, $s$")
Expand All @@ -294,20 +362,24 @@ def plot_chi_n(p, include_title=False, path=None):


def plot_fert_rates(
fert_rates,
fert_rates_list,
labels=[""],
start_year=DEFAULT_START_YEAR,
years_to_plot=[DEFAULT_START_YEAR],
include_title=False,
source="United Nations, World Population Prospects",
path=None,
):
"""
Plot fertility rates from the data
Args:
fert_rates (NumPy array): fertility rates for each of
totpers
fert_rates_list (list): list of Numpy arrays of fertility rates
for each model period and age
labels (list): list of labels for the legend
start_year (int): first year of data
years_to_plot (list): list of years to plot
include_title (bool): whether to include a title in the plot
source (str): data source for fertility rates
path (str): path to save figure to, if None then figure
is returned
Expand All @@ -321,9 +393,10 @@ def plot_fert_rates(
fig, ax = plt.subplots()
for y in years_to_plot:
i = start_year - y
plt.plot(fert_rates[i, :], c="blue", label="Year " + str(y))
# plt.title('Fertility rates by age ($f_{s}$)',
# fontsize=20)
for i, fert_rates in enumerate(fert_rates_list):
plt.plot(fert_rates[i, :], label=labels[i] + " " + str(y))
if include_title:
plt.title("Fertility rates by age ($f_{s}$)", fontsize=20)
plt.xlabel(r"Age $s$")
plt.ylabel(r"Fertility rate $f_{s}$")
plt.legend(loc="upper right")
Expand Down Expand Up @@ -395,6 +468,39 @@ def plot_mort_rates_data(
return fig


def plot_g_n(p_list, label_list=[""], include_title=False, path=None):
"""
Create a plot of population growth rates from OG-Core parameterization.
Args:
p_list (list): list of OG-Core Specifications objects
label_list (list): list of labels for the legend
include_title (bool): whether to include a title in the plot
path (string): path to save figure to
Returns:
fig (Matplotlib plot object): plot of immigration rates
"""
p0 = p_list[0]
years = np.arange(p0.start_year, p0.start_year + p0.T)
fig, ax = plt.subplots()
for i, p in enumerate(p_list):
plt.plot(years, p.g_n[: p.T], label=label_list[i])
plt.xlabel(r"Year $s$ (model periods)")
plt.ylabel(r"Population Growth Rate $g_{n,t}$")
plt.legend(loc="upper right")
vals = ax.get_yticks()
ax.set_yticklabels(["{:,.0%}".format(x) for x in vals])
if include_title:
plt.title("Population Growth Rates")
if path is None:
return fig
else:
fig_path = os.path.join(path, "pop_growth_rates")
plt.savefig(fig_path, dpi=300)


def plot_omega_fixed(age_per_EpS, omega_SS_orig, omega_SSfx, E, S, path=None):
"""
Plot the steady-state population distribution implied by the data
Expand Down
Loading

0 comments on commit fdc00e2

Please sign in to comment.