Skip to content

Commit

Permalink
State specific draw functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxBlesch committed Jan 12, 2025
1 parent 4243ddf commit 5916dd6
Show file tree
Hide file tree
Showing 13 changed files with 138 additions and 12 deletions.
7 changes: 6 additions & 1 deletion src/dcegm/final_periods.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def solve_last_two_periods(
cont_grids_next_period: Dict[str, jnp.ndarray],
params: Dict[str, float],
taste_shock_scale: float,
taste_shock_scale,
income_shock_weights: jnp.ndarray,
exog_grids: Dict[str, jnp.ndarray],
model_funcs: Dict[str, Callable],
Expand Down Expand Up @@ -75,6 +75,11 @@ def solve_last_two_periods(
has_second_continuous_state=has_second_continuous_state,
)

if len(taste_shock_scale) > 1:
taste_shock_scale = taste_shock_scale[

Check warning on line 79 in src/dcegm/final_periods.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/final_periods.py#L79

Added line #L79 was not covered by tests
last_two_period_batch_info["idxs_parent_states_final_period"]
]

endog_grid, policy, value = solve_for_interpolated_values(
value_interpolated=value_interp_final_period,
marginal_utility_interpolated=marginal_utility_final_last_period,
Expand Down
14 changes: 13 additions & 1 deletion src/dcegm/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,21 @@ def calc_choice_probs_for_states(
params,
compute_utility,
)

model_funcs = model["model_funcs"]

# The following allows to specify a function to return taste shock scales for each
# state differently.
if model_funcs["draw_functions"]["draw_taste_shock_per_state"]:
taste_shock_scale = model_funcs["draw_functions"]["draw_taste_shock_per_state"](
params=params, state_space_dict=observed_states
)
else:
taste_shock_scale = model_funcs["draw_functions"]["taste_shock_scale"](params)

choice_prob_across_choices, _, _ = calculate_choice_probs_and_unsqueezed_logsum(
choice_values_per_state=value_per_agent_interp,
taste_shock_scale=params["lambda"],
taste_shock_scale=taste_shock_scale,
)
return choice_prob_across_choices

Expand Down
2 changes: 0 additions & 2 deletions src/dcegm/pre_processing/check_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ def process_params(params: Union[dict, pd.Series, pd.DataFrame]) -> Dict[str, fl

if "interest_rate" not in params:
params["interest_rate"] = 0
if "lambda" not in params:
params["lambda"] = 0
if "sigma" not in params:
params["sigma"] = 0
if "beta" not in params:
Expand Down
66 changes: 66 additions & 0 deletions src/dcegm/pre_processing/model_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable, Dict

import jax
import jax.numpy as jnp
from upper_envelope.fues_jax.fues_jax import fues_jax

Expand All @@ -15,6 +16,7 @@ def process_model_functions(
utility_functions: Dict[str, Callable],
utility_functions_final_period: Dict[str, Callable],
budget_constraint: Callable,
draw_functions: Dict[str, Callable] = None,
):
"""Create wrapped functions from user supplied functions.
Expand Down Expand Up @@ -125,6 +127,10 @@ def process_model_functions(
continuous_state=continuous_state_name,
)

draw_functions_processed = process_draw_functions(
draw_functions, options, continuous_state_name
)

model_funcs = {
"compute_utility": compute_utility,
"compute_marginal_utility": compute_marginal_utility,
Expand All @@ -139,11 +145,71 @@ def process_model_functions(
"state_specific_choice_set": state_specific_choice_set,
"next_period_endogenous_state": next_period_endogenous_state,
"compute_upper_envelope": compute_upper_envelope,
"draw_functions": draw_functions_processed,
}

return model_funcs


def process_draw_functions(draw_functions, options, continuous_state_name):
draw_functions_processed = {} if draw_functions is None else draw_functions
if "taste_shock_per_state" in draw_functions_processed.keys():
draw_functions_processed["taste_shock_per_state"] = (

Check warning on line 157 in src/dcegm/pre_processing/model_functions.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/pre_processing/model_functions.py#L157

Added line #L157 was not covered by tests
get_taste_shock_over_all_states_func(
draw_function_taste_shocks=draw_functions_processed[
"taste_shock_per_state"
],
options=options,
continuous_state_name=continuous_state_name,
)
)
draw_taste_shock_per_state = True

Check warning on line 166 in src/dcegm/pre_processing/model_functions.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/pre_processing/model_functions.py#L166

Added line #L166 was not covered by tests
else:
draw_taste_shock_per_state = False
if "lambda" in options["model_params"]:
# Check if lambda is a scalar
lambda_val = options["model_params"]["lambda"]
if not isinstance(lambda_val, (int, float)):
raise ValueError(

Check warning on line 173 in src/dcegm/pre_processing/model_functions.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/pre_processing/model_functions.py#L171-L173

Added lines #L171 - L173 were not covered by tests
f"Lambda is not a scalar. If there is no draw function provided, "
f"lambda must be a scalar. Got {lambda_val}."
)
read_function = lambda params_in: jnp.asarray(

Check warning on line 177 in src/dcegm/pre_processing/model_functions.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/pre_processing/model_functions.py#L177

Added line #L177 was not covered by tests
[options["model_params"]["lambda"]]
)
else:
read_function = lambda params_in: jnp.asarray([params_in["lambda"]])

draw_functions_processed["taste_shock_scale"] = read_function

draw_functions_processed["draw_taste_shock_per_state"] = draw_taste_shock_per_state
return draw_functions_processed


def get_taste_shock_over_all_states_func(
draw_function_taste_shocks, options, continuous_state_name
):
not_allowed_states = ["wealth"]
if continuous_state_name is not None:
not_allowed_states += [continuous_state_name]
taste_shock_per_state_function = determine_function_arguments_and_partial_options(

Check warning on line 195 in src/dcegm/pre_processing/model_functions.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/pre_processing/model_functions.py#L192-L195

Added lines #L192 - L195 were not covered by tests
func=draw_function_taste_shocks,
options=options["model_params"],
not_allowed_state_choices=not_allowed_states,
continuous_state_name=continuous_state_name,
)

def vectorized_taste_shock_per_state(state_dict_vec, params):
return taste_shock_per_state_function(params=params, **state_dict_vec)

Check warning on line 203 in src/dcegm/pre_processing/model_functions.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/pre_processing/model_functions.py#L202-L203

Added lines #L202 - L203 were not covered by tests

def taste_shock_over_all_states_func(state_dict, params):
return jax.vmap(vectorized_taste_shock_per_state, in_axes=(0, None))(

Check warning on line 206 in src/dcegm/pre_processing/model_functions.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/pre_processing/model_functions.py#L205-L206

Added lines #L205 - L206 were not covered by tests
state_dict, params
)

return taste_shock_over_all_states_func

Check warning on line 210 in src/dcegm/pre_processing/model_functions.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/pre_processing/model_functions.py#L210

Added line #L210 was not covered by tests


def process_state_space_functions(
state_space_functions, options, continuous_state_name
):
Expand Down
2 changes: 2 additions & 0 deletions src/dcegm/pre_processing/setup_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def setup_model(
utility_functions_final_period: Dict[str, Callable],
budget_constraint: Callable,
state_space_functions: Dict[str, Callable] = None,
draw_functions: Dict[str, Callable] = None,
debug_output: str = None,
):
"""Set up the model for dcegm.
Expand Down Expand Up @@ -62,6 +63,7 @@ def setup_model(
utility_functions=utility_functions,
utility_functions_final_period=utility_functions_final_period,
budget_constraint=budget_constraint,
draw_functions=draw_functions,
)

model_structure = create_model_structure(
Expand Down
11 changes: 10 additions & 1 deletion src/dcegm/pre_processing/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,24 @@


def determine_function_arguments_and_partial_options(
func, options, continuous_state_name=None
func, options, not_allowed_state_choices=None, continuous_state_name=None
):
signature = set(inspect.signature(func).parameters)
not_allowed_state_choices = (
[] if not_allowed_state_choices is None else not_allowed_state_choices
)

partialed_func, signature = partial_options_and_update_signature(
func=func,
signature=signature,
options=options,
)
if len(not_allowed_state_choices) > 0:
for var in signature:
if var in not_allowed_state_choices:
raise ValueError(

Check warning on line 25 in src/dcegm/pre_processing/shared.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/pre_processing/shared.py#L23-L25

Added lines #L23 - L25 were not covered by tests
f"{func.__name__}() has a not allowed input variable: {var}"
)

@functools.wraps(func)
def processed_func(**kwargs):
Expand Down
12 changes: 11 additions & 1 deletion src/dcegm/simulation/sim_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,18 @@ def next_period_continuous_state_for_one_agent(
)


def draw_taste_shocks(n_agents, n_choices, taste_shock_scale, key):
def draw_taste_shocks(n_agents, n_choices, states, params, draw_functions, key):
taste_shocks = jax.random.gumbel(key=key, shape=(n_agents, n_choices))
# The following allows to specify a function to return taste shock scales for each
# state differently.
if draw_functions["draw_taste_shock_per_state"]:
taste_shock_scale = draw_functions["draw_taste_shock_per_state"](

Check warning on line 248 in src/dcegm/simulation/sim_utils.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/simulation/sim_utils.py#L248

Added line #L248 was not covered by tests
params=params,
state_space_dict=states,
)
else:
taste_shock_scale = draw_functions["taste_shock_scale"](params)

taste_shocks = taste_shock_scale * (taste_shocks - jnp.euler_gamma)
return taste_shocks

Expand Down
12 changes: 10 additions & 2 deletions src/dcegm/simulation/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def simulate_all_periods(
],
exog_state_mapping=model_funcs_sim["exog_state_mapping"],
compute_next_period_states=compute_next_period_states,
draw_functions=model_funcs_sim["draw_functions"],
second_continuous_state_dict=second_continuous_state_dict,
)

Expand All @@ -123,6 +124,7 @@ def simulate_all_periods(
choice_range=model_structure_solution["choice_range"],
map_state_choice_to_index=model_structure_solution["map_state_choice_to_index"],
compute_utility_final_period=model_funcs_sim["compute_utility_final"],
draw_functions=model_funcs_sim["draw_functions"],
)

result = {
Expand Down Expand Up @@ -151,6 +153,7 @@ def simulate_single_period(
compute_beginning_of_period_wealth,
exog_state_mapping,
compute_next_period_states,
draw_functions,
second_continuous_state_dict=None,
):
(
Expand Down Expand Up @@ -195,7 +198,9 @@ def simulate_single_period(
taste_shocks = draw_taste_shocks(
n_agents=len(wealth_beginning_of_period),
n_choices=len(choice_range),
taste_shock_scale=params["lambda"],
states=states_beginning_of_period,
params=params,
draw_functions=draw_functions,
key=sim_specific_keys[0, :],
)
values_across_choices = values_pre_taste_shock + taste_shocks
Expand Down Expand Up @@ -268,6 +273,7 @@ def simulate_final_period(
choice_range,
map_state_choice_to_index,
compute_utility_final_period,
draw_functions,
):
invalid_number = np.iinfo(map_state_choice_to_index.dtype).max

Expand Down Expand Up @@ -305,7 +311,9 @@ def simulate_final_period(
taste_shocks = draw_taste_shocks(
n_agents=n_agents,
n_choices=n_choices,
taste_shock_scale=params["lambda"],
states=states_beginning_of_final_period,
params=params,
draw_functions=draw_functions,
key=sim_specific_keys[0, :],
)
values_across_choices = utilities_pre_taste_shock + taste_shocks
Expand Down
13 changes: 12 additions & 1 deletion src/dcegm/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def solve_dcegm(
utility_functions_final_period: Dict[str, Callable],
budget_constraint: Callable,
state_space_functions: Dict[str, Callable] = None,
draw_functions: Dict[str, Callable] = None,
) -> Dict[int, np.ndarray]:
"""Solve a discrete-continuous life-cycle model using the DC-EGM algorithm.
Expand Down Expand Up @@ -57,6 +58,7 @@ def solve_dcegm(
utility_functions=utility_functions,
budget_constraint=budget_constraint,
utility_functions_final_period=utility_functions_final_period,
draw_functions=draw_functions,
)

results = backward_jit(params=params)
Expand All @@ -70,6 +72,7 @@ def get_solve_function(
budget_constraint: Callable,
utility_functions_final_period: Dict[str, Callable],
state_space_functions: Dict[str, Callable] = None,
draw_functions: Dict[str, Callable] = None,
) -> Callable:
"""Create a solve function, which only takes params as input.
Expand Down Expand Up @@ -101,6 +104,7 @@ def get_solve_function(
utility_functions=utility_functions,
utility_functions_final_period=utility_functions_final_period,
budget_constraint=budget_constraint,
draw_functions=draw_functions,
)

return get_solve_func_for_model(model=model)
Expand Down Expand Up @@ -214,7 +218,14 @@ def backward_induction(
from the backward induction.
"""
taste_shock_scale = params["lambda"]
# The following allows to specify a function to return taste shock scales for each
# state differently.
if model_funcs["draw_functions"]["draw_taste_shock_per_state"]:
taste_shock_scale = model_funcs["draw_functions"]["draw_taste_shock_per_state"](

Check warning on line 224 in src/dcegm/solve.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/solve.py#L224

Added line #L224 was not covered by tests
params=params, state_space_dict=state_space_dict
)
else:
taste_shock_scale = model_funcs["draw_functions"]["taste_shock_scale"](params)

cont_grids_next_period = calc_cont_grids_next_period(
state_space_dict=state_space_dict,
Expand Down
3 changes: 3 additions & 0 deletions src/dcegm/solve_single_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def solve_single_period(
params=params,
)

if len(taste_shock_scale) > 1:
taste_shock_scale = taste_shock_scale[child_state_idxs]

Check warning on line 52 in src/dcegm/solve_single_period.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/solve_single_period.py#L52

Added line #L52 was not covered by tests

endog_grid_state_choice, policy_state_choice, value_state_choice = (
solve_for_interpolated_values(
value_interpolated=value_interpolated,
Expand Down
3 changes: 1 addition & 2 deletions tests/test_pre_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,10 @@ def test_missing_parameter(

params.pop("interest_rate")
params.pop("sigma")
params.pop("lambda")

params_dict = process_params(params)

for param in ["interest_rate", "sigma", "lambda"]:
for param in ["interest_rate", "sigma"]:
assert param in params_dict.keys()

params.pop("beta")
Expand Down
3 changes: 3 additions & 0 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def test_simulate_lax_scan(model_setup):
"compute_beginning_of_period_wealth"
],
exog_state_mapping=exog_state_mapping,
draw_functions=model_funcs["draw_functions"],
compute_next_period_states={
"next_period_endogenous_state": next_period_endogenous_state
},
Expand Down Expand Up @@ -149,6 +150,7 @@ def test_simulate_lax_scan(model_setup):
choice_range=choice_range,
map_state_choice_to_index=jnp.array(map_state_choice_to_index),
compute_utility_final_period=model_funcs["compute_utility_final"],
draw_functions=model_funcs["draw_functions"],
)
final_period_dict = simulate_final_period(
states_and_wealth_beginning_of_final_period,
Expand All @@ -158,6 +160,7 @@ def test_simulate_lax_scan(model_setup):
choice_range=choice_range,
map_state_choice_to_index=jnp.array(map_state_choice_to_index),
compute_utility_final_period=model_funcs["compute_utility_final"],
draw_functions=model_funcs["draw_functions"],
)

aaae(np.squeeze(lax_sim_dict_zero["taste_shocks"]), sim_dict_zero["taste_shocks"])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_two_period_continuous_experience.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def create_test_inputs():
"state_to_choices_final_period"
],
params=params,
taste_shock_scale=taste_shock_scale,
taste_shock_scale=jnp.array([taste_shock_scale]),
income_shock_weights=income_shock_weights,
exog_grids=exog_grids_cont,
model_funcs=model_funcs_cont,
Expand Down

0 comments on commit 5916dd6

Please sign in to comment.