Skip to content

Commit

Permalink
More outputs.
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxBlesch committed Oct 27, 2023
1 parent d07f83e commit 5bef46a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 27 deletions.
20 changes: 20 additions & 0 deletions src/dcegm/budget.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,23 @@ def calculate_resources_for_each_grid_point(
params=params,
)
return out


def calculate_resources_for_all_agents(
states_beginning_of_period,
savings_end_of_last_period,
income_shocks_of_period,
params,
compute_beginning_of_period_wealth,
):
resources_beginning_of_next_period = vmap(
calculate_resources_for_each_grid_point,
in_axes=(0, 0, 0, None, None),
)(
states_beginning_of_period,
savings_end_of_last_period,
income_shocks_of_period,
params,
compute_beginning_of_period_wealth,
)
return resources_beginning_of_next_period
46 changes: 19 additions & 27 deletions src/dcegm/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import jax
import jax.numpy as jnp
from dcegm.budget import calculate_resources_for_each_grid_point
from dcegm.budget import calculate_resources_for_all_agents
from dcegm.egm.interpolate_marginal_utility import interpolate_policy_and_check_value
from dcegm.interpolation import get_index_high_and_low
from jax import vmap
Expand Down Expand Up @@ -47,6 +47,8 @@ def simulate_all_periods(
states_and_wealth_last_period, sim_data = jax.lax.scan(
f=simulate_body, init=states_and_wealth_period_0, xs=jnp.arange(num_periods - 1)
)
# ToDo: Last period.

return states_and_wealth_last_period, sim_data


Expand Down Expand Up @@ -125,7 +127,11 @@ def simulate_single_period(
savings_this_period = wealth_beginning_of_period - consumption_period

# Transition to next period.
wealth_at_beginning_of_next_period, states_next_period = transition_to_next_period(
(
wealth_at_beginning_of_next_period,
states_next_period,
income_shocks_next_period,
) = transition_to_next_period(
states_beginning_of_period=states_beginning_of_period,
savings_this_period=savings_this_period,
choice_period=choice_period,
Expand All @@ -145,6 +151,8 @@ def simulate_single_period(
"taste_shocks": taste_shocks,
"value": value_period,
"savings": savings_this_period,
"income_shock": income_shocks_next_period,
**states_beginning_of_period,
}

return carry, result_data
Expand Down Expand Up @@ -234,45 +242,30 @@ def transition_to_next_period(
choice_period,
params,
)
# Draw income shocks.
income_shocks_next_period = draw_normal_shocks(
key=key, num_agents=num_agents, mean=0, std=params["sigma"]
)
# Generate states next period and apply budged constraint for wealth at the
# beginning of next period.
# Initialize states by copying
states_next_period = states_beginning_of_period.copy()
# Then update
states_to_update = {**endog_states_next_period, **exog_states_next_period}
states_next_period.update(states_to_update)

# Draw income shocks.
income_shocks_next_period = draw_normal_shocks(
key=key, num_agents=num_agents, mean=0, std=params["sigma"]
)
wealth_at_beginning_of_next_period = calculate_resources_for_all_agents(
states_beginning_of_period=states_next_period,
savings_end_of_last_period=savings_this_period,
income_shocks_of_period=income_shocks_next_period,
params=params,
compute_beginning_of_period_wealth=compute_beginning_of_period_wealth,
)
return wealth_at_beginning_of_next_period, states_next_period


def calculate_resources_for_all_agents(
states_beginning_of_period,
savings_end_of_last_period,
income_shocks_of_period,
params,
compute_beginning_of_period_wealth,
):
resources_beginning_of_next_period = vmap(
calculate_resources_for_each_grid_point,
in_axes=(0, 0, 0, None, None),
)(
states_beginning_of_period,
savings_end_of_last_period,
income_shocks_of_period,
params,
compute_beginning_of_period_wealth,
return (
wealth_at_beginning_of_next_period,
states_next_period,
income_shocks_next_period,
)
return resources_beginning_of_next_period


def draw_normal_shocks(key, num_agents, mean=0, std=1):
Expand Down Expand Up @@ -302,7 +295,6 @@ def realize_exog_process(state, choice, key, params, exog_func, exog_state_mappi
key=key, a=transition_vec.shape[0], p=transition_vec
)
exog_states_next_period = exog_state_mapping(exog_proc_next_period)

return exog_states_next_period


Expand Down

0 comments on commit 5bef46a

Please sign in to comment.