Skip to content

Commit

Permalink
Cleanup second (#128)
Browse files Browse the repository at this point in the history
Co-authored-by: MaxBlesch <[email protected]>
  • Loading branch information
segsell and MaxBlesch authored Oct 2, 2024
1 parent 1fe471e commit 66a15b7
Show file tree
Hide file tree
Showing 7 changed files with 12 additions and 429 deletions.
8 changes: 4 additions & 4 deletions src/dcegm/egm/solve_euler_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def calculate_candidate_solutions_from_euler_equation(
) = vmap(
vmap(
vmap(
wrapper_cont_optimal_policy_and_value,
compute_optimal_policy_and_value_wrapper,
in_axes=(1, 1, None, 0, None, None, None), # savings
),
in_axes=(1, 1, 0, None, None, None, None), # second continuous state
Expand Down Expand Up @@ -77,7 +77,7 @@ def calculate_candidate_solutions_from_euler_equation(
)


def wrapper_cont_optimal_policy_and_value(
def compute_optimal_policy_and_value_wrapper(
marg_util_next: np.ndarray,
emax_next: np.ndarray,
second_continuous_grid: np.ndarray,
Expand All @@ -86,9 +86,9 @@ def wrapper_cont_optimal_policy_and_value(
model_funcs: Dict[str, Callable],
params: Dict[str, float],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Simple wrapper function to write second continuous grid point into
state_choice_vec."""
"""Write second continuous grid point into state_choice_vec."""
state_choice_vec["second_continuous"] = second_continuous_grid

return compute_optimal_policy_and_value(
marg_util_next,
emax_next,
Expand Down
2 changes: 0 additions & 2 deletions src/dcegm/final_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ def solve_last_two_periods(
value_solved,
policy_solved,
endog_grid_solved,
emax,
marg_util,
)


Expand Down
3 changes: 0 additions & 3 deletions src/dcegm/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ def backward_induction(
"""
taste_shock_scale = params["lambda"]

# Calculate the continuous grids for the next period
cont_grids_next_period = calc_cont_grids_next_period(
state_space_dict=state_space_dict,
exog_grids=exog_grids,
Expand Down Expand Up @@ -245,8 +244,6 @@ def backward_induction(
value_solved,
policy_solved,
endog_grid_solved,
value_interp_final_period,
marginal_utility_final_last_period,
) = solve_last_two_periods(
cont_grids_next_period=cont_grids_next_period,
params=params,
Expand Down
4 changes: 2 additions & 2 deletions src/dcegm/wealth_correction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jax.numpy as jnp
from jax import vmap

from dcegm.budget import calculate_resources_for_each_grid_point
from dcegm.law_of_motion import calc_resources_for_each_savings_grid_point


def adjust_observed_wealth(observed_states_dict, wealth, params, model):
Expand All @@ -17,7 +17,7 @@ def adjust_observed_wealth(observed_states_dict, wealth, params, model):
savings_last_period = jnp.asarray(wealth / (1 + params["interest_rate"]))

adjusted_resources = vmap(
calculate_resources_for_each_grid_point, in_axes=(0, 0, None, None, None)
calc_resources_for_each_savings_grid_point, in_axes=(0, 0, None, None, None)
)(
observed_states_dict,
savings_last_period,
Expand Down
Loading

0 comments on commit 66a15b7

Please sign in to comment.