Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some fixes. Working on second cond #123

Merged
merged 33 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
4887b40
Starting to fix. LAst period working
MaxBlesch Sep 11, 2024
abf2cca
Draft.
MaxBlesch Sep 18, 2024
d8fc758
Fix typo
segsell Sep 20, 2024
3f8a9bd
Add two period test for second continuous state
segsell Sep 21, 2024
8801feb
Only pass at 3 decimal precision
segsell Sep 21, 2024
38ffd51
Some renaming
segsell Sep 21, 2024
617ca0e
Testing intermediate steps.
MaxBlesch Sep 23, 2024
7ef755f
WOrks?
MaxBlesch Sep 23, 2024
ce8d009
Trying.
MaxBlesch Sep 23, 2024
596df93
Fix
MaxBlesch Sep 23, 2024
3891eb1
.
segsell Sep 25, 2024
dca4359
Fix merge conflicts
segsell Sep 25, 2024
5519137
Adjust temporary interface
segsell Sep 25, 2024
eef5dd3
Some error in disc versus cont test
segsell Sep 25, 2024
eb00512
Fix test for final period
segsell Sep 25, 2024
651e4af
Clean up two period test
segsell Sep 25, 2024
58efaf2
Rename test arguments
segsell Sep 25, 2024
b37a184
Test passes when experience switched off
segsell Sep 25, 2024
e610d3b
Found bug (:
segsell Sep 25, 2024
c28a661
Add absorbing retirement again
segsell Sep 25, 2024
d515c25
For earlier periods still reasonably close
segsell Sep 25, 2024
54f0932
Save current
segsell Sep 25, 2024
0976127
Test fails for 1 discrete choice case
segsell Sep 26, 2024
4c658a3
Problems with sparsity in state space extends beyond two-period model
segsell Sep 26, 2024
29931b2
Found bug in batches
segsell Sep 26, 2024
b91f1b1
Fixed it.
MaxBlesch Sep 27, 2024
0438bd4
Add parametrize
segsell Sep 27, 2024
df72658
Merge branch 'max_cont' of https://github.com/OpenSourceEconomics/dce…
segsell Sep 27, 2024
be58acd
Add more test cases
segsell Sep 27, 2024
067472a
Remove unused code
segsell Sep 30, 2024
28978de
Setup test for utility with second continuous state
segsell Oct 1, 2024
0b8f850
Set utility parameter to zero
segsell Oct 1, 2024
96ec207
Adjust interface
segsell Oct 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/dcegm/egm/aggregate_marginal_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def aggregate_marg_utils_and_exp_values(
)

log_sum_unsqueezed = max_value_per_state + taste_shock_scale * jnp.log(sum_exp)

# Because we kept the dimensions in the maximum and sum over choice specific objects
# to perform subtraction and division, we now need to squeeze the log_sum again
# to remove the redundant axis.
Expand Down
10 changes: 5 additions & 5 deletions src/dcegm/egm/interpolate_marginal_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,6 @@ def interp2d_value_and_marg_util_for_state_choice(

def interp_on_single_wealth_point(wealth_point, regular_point):

# To-Do: Add second vmap for regular grid point

policy_interp, value_interp = (
interp2d_policy_and_value_on_wealth_and_regular_grid(
regular_grid=regular_grid,
Expand All @@ -235,7 +233,10 @@ def interp_on_single_wealth_point(wealth_point, regular_point):
)
)
marg_util_interp = compute_marginal_utility(
consumption=policy_interp, params=params, **state_choice_vec
consumption=policy_interp,
continuous_state=regular_point,
params=params,
**state_choice_vec
)

return value_interp, marg_util_interp
Expand All @@ -252,8 +253,7 @@ def interp_on_single_wealth_point(wealth_point, regular_point):
in_axes=(0, 0), # continuous state grid
)

# To-Do: Interpolate over next period regular and wealth point
# Old points regular grid and endog grid
# Old points: regular grid and endog grid
# New points: continuous state next period and wealth next period
value_interp, marg_util_interp = interp_over_single_wealth_and_income_shock_draw(
wealth_beginning_of_next_period, continuous_state_beginning_of_next_period
Expand Down
150 changes: 125 additions & 25 deletions src/dcegm/egm/solve_euler_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@


def calculate_candidate_solutions_from_euler_equation(
exog_savings_grid: np.ndarray,
marg_util: jnp.ndarray,
emax: jnp.ndarray,
exog_grids: np.ndarray,
marg_util_next: jnp.ndarray,
emax_next: jnp.ndarray,
state_choice_mat: np.ndarray,
idx_post_decision_child_states: np.ndarray,
compute_utility: Callable,
Expand All @@ -20,13 +20,11 @@ def calculate_candidate_solutions_from_euler_equation(
params: Dict[str, float],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Calculate candidates for the optimal policy and value function."""
feasible_marg_utils_child, feasible_emax_child = (
_get_post_decision_marg_utils_and_emax(
marg_util_next=marg_util,
emax_next=emax,
idx_post_decision_child_states=idx_post_decision_child_states,
)

feasible_marg_utils_child = jnp.take(
marg_util_next, idx_post_decision_child_states, axis=0
)
feasible_emax_child = jnp.take(emax_next, idx_post_decision_child_states, axis=0)

# transform exog_transition_mat to matrix with same shape as state_choice_vec

Expand All @@ -39,16 +37,47 @@ def calculate_candidate_solutions_from_euler_equation(
) = vmap(
vmap(
vmap(
compute_optimal_policy_and_value,
in_axes=(1, 1, 0, None, None, None, None, None), # savings grid
compute_optimal_policy_and_value_second_continuous,
in_axes=(
1,
1,
None,
0,
None,
None,
None,
None,
None,
), # savings
),
in_axes=(1, 1, None, None, None, None, None, None), # continuous state
in_axes=(
1,
1,
0,
None,
None,
None,
None,
None,
None,
), # second continuous state
),
in_axes=(0, 0, None, 0, None, None, None, None), # discrete states choices
in_axes=(
0,
0,
None,
None,
0,
None,
None,
None,
None,
), # discrete states choices
)(
feasible_marg_utils_child,
feasible_emax_child,
exog_savings_grid,
exog_grids["second_continuous"],
exog_grids["wealth"],
state_choice_mat,
compute_inverse_marginal_utility,
compute_utility,
Expand All @@ -70,7 +99,7 @@ def calculate_candidate_solutions_from_euler_equation(
)(
feasible_marg_utils_child,
feasible_emax_child,
exog_savings_grid,
exog_grids["wealth"],
state_choice_mat,
compute_inverse_marginal_utility,
compute_utility,
Expand All @@ -87,8 +116,8 @@ def calculate_candidate_solutions_from_euler_equation(


def compute_optimal_policy_and_value(
marg_utils: np.ndarray,
emax: np.ndarray,
marg_util_next: np.ndarray,
emax_next: np.ndarray,
exogenous_savings_grid: np.ndarray,
state_choice_vec: Dict,
compute_inverse_marginal_utility: Callable,
Expand Down Expand Up @@ -138,8 +167,8 @@ def compute_optimal_policy_and_value(

policy, expected_value = solve_euler_equation(
state_choice_vec=state_choice_vec,
marg_utils=marg_utils,
emax=emax,
marg_util_next=marg_util_next,
emax_next=emax_next,
compute_inverse_marginal_utility=compute_inverse_marginal_utility,
compute_exog_transition_vec=compute_exog_transition_vec,
params=params,
Expand All @@ -152,10 +181,82 @@ def compute_optimal_policy_and_value(
return endog_grid, policy, value, expected_value


def compute_optimal_policy_and_value_second_continuous(
marg_util_next: np.ndarray,
emax_next: np.ndarray,
second_continuous_grid: np.ndarray,
exogenous_savings_grid: np.ndarray,
state_choice_vec: Dict,
compute_inverse_marginal_utility: Callable,
compute_utility: Callable,
compute_exog_transition_vec: Callable,
params: Dict[str, float],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Compute optimal child-state- and choice-specific policy and value function.

Given the marginal utilities of possible child states and next period wealth, we
compute the optimal policy and value functions by solving the euler equation
and using the optimal consumption level in the bellman equation.

Args:
marg_utils (np.ndarray): 1d array of shape (n_exog_processes,) containing
the state-choice specific marginal utilities for a given point on
the savings grid.
emax (np.ndarray): 1d array of shape (n_exog_processes,) containing
the state-choice specific expected maximum value for a given point on
the savings grid.
exogenous_savings_grid (np.ndarray): 1d array of shape (n_grid_wealth,)
containing the exogenous savings grid.
trans_vec_state (np.ndarray): 1d array of shape (n_exog_processes,) containing
for each exogenous process state the corresponding transition probability.
state_choice_vec (np.ndarray): A dictionary containing the states and a
corresponding admissible choice of a particular state choice vector.
compute_inverse_marginal_utility (Callable): Function for calculating the
inverse marginal utility, which takes the marginal utility as only input.
compute_value (callable): Function for calculating the value from consumption
level, discrete choice and expected value. The inputs ```discount_rate```
and ```compute_utility``` are already partialled in.
params (dict): Dictionary of model parameters.

Returns:
tuple:

- endog_grid (np.ndarray): 1d array of shape (n_grid_wealth + 1,)
containing the current state- and choice-specific endogenous grid.
- policy (np.ndarray): 1d array of shape (n_grid_wealth + 1,)
containing the current state- and choice-specific policy function.
- value (np.ndarray): 1d array of shape (n_grid_wealth + 1,)
containing the current state- and choice-specific value function.
- expected_value_zero_savings (float): The agent's expected value given that
she saves nothing.

"""

policy, expected_value = solve_euler_equation(
state_choice_vec=state_choice_vec,
marg_util_next=marg_util_next,
emax_next=emax_next,
compute_inverse_marginal_utility=compute_inverse_marginal_utility,
compute_exog_transition_vec=compute_exog_transition_vec,
params=params,
)
endog_grid = exogenous_savings_grid + policy

utility = compute_utility(
consumption=policy,
continuous_state=second_continuous_grid,
params=params,
**state_choice_vec,
)
value = utility + params["beta"] * expected_value

return endog_grid, policy, value, expected_value


def solve_euler_equation(
state_choice_vec: dict,
marg_utils: np.ndarray,
emax: np.ndarray,
marg_util_next: np.ndarray,
emax_next: np.ndarray,
compute_inverse_marginal_utility: Callable,
compute_exog_transition_vec: Callable,
params: Dict[str, float],
Expand Down Expand Up @@ -192,11 +293,11 @@ def solve_euler_equation(
transition_vec = compute_exog_transition_vec(params=params, **state_choice_vec)

# Integrate out uncertainty over exogenous processes
marginal_utility = jnp.nansum(transition_vec * marg_utils)
expected_value = jnp.nansum(transition_vec * emax)
marginal_utility_next = jnp.nansum(transition_vec * marg_util_next)
expected_value = jnp.nansum(transition_vec * emax_next)

# RHS of Euler Eq., p. 337 IJRS (2017) by multiplying with marginal wealth
rhs_euler = marginal_utility * (1 + params["interest_rate"]) * params["beta"]
rhs_euler = marginal_utility_next * (1 + params["interest_rate"]) * params["beta"]

policy = compute_inverse_marginal_utility(
marginal_utility=rhs_euler,
Expand Down Expand Up @@ -243,7 +344,6 @@ def _get_post_decision_marg_utils_and_emax(
child states in the current period t.

"""

# state-choice specific
marg_utils_child = jnp.take(marg_util_next, idx_post_decision_child_states, axis=0)
emax_child = jnp.take(emax_next, idx_post_decision_child_states, axis=0)
Expand Down
Loading
Loading