From 5916dd63bcfe149858874b4380aa305179bacb9a Mon Sep 17 00:00:00 2001 From: MaxBlesch Date: Mon, 13 Jan 2025 00:27:57 +0100 Subject: [PATCH] State specific draw functions. --- src/dcegm/final_periods.py | 7 +- src/dcegm/likelihood.py | 14 +++- src/dcegm/pre_processing/check_params.py | 2 - src/dcegm/pre_processing/model_functions.py | 66 +++++++++++++++++++ src/dcegm/pre_processing/setup_model.py | 2 + src/dcegm/pre_processing/shared.py | 11 +++- src/dcegm/simulation/sim_utils.py | 12 +++- src/dcegm/simulation/simulate.py | 12 +++- src/dcegm/solve.py | 13 +++- src/dcegm/solve_single_period.py | 3 + tests/test_pre_processing.py | 3 +- tests/test_simulate.py | 3 + .../test_two_period_continuous_experience.py | 2 +- 13 files changed, 138 insertions(+), 12 deletions(-) diff --git a/src/dcegm/final_periods.py b/src/dcegm/final_periods.py index 9beeb68b..32aaa8b3 100644 --- a/src/dcegm/final_periods.py +++ b/src/dcegm/final_periods.py @@ -14,7 +14,7 @@ def solve_last_two_periods( cont_grids_next_period: Dict[str, jnp.ndarray], params: Dict[str, float], - taste_shock_scale: float, + taste_shock_scale, income_shock_weights: jnp.ndarray, exog_grids: Dict[str, jnp.ndarray], model_funcs: Dict[str, Callable], @@ -75,6 +75,11 @@ def solve_last_two_periods( has_second_continuous_state=has_second_continuous_state, ) + if len(taste_shock_scale) > 1: + taste_shock_scale = taste_shock_scale[ + last_two_period_batch_info["idxs_parent_states_final_period"] + ] + endog_grid, policy, value = solve_for_interpolated_values( value_interpolated=value_interp_final_period, marginal_utility_interpolated=marginal_utility_final_last_period, diff --git a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py index 1bdd4496..0db5a2c8 100644 --- a/src/dcegm/likelihood.py +++ b/src/dcegm/likelihood.py @@ -312,9 +312,21 @@ def calc_choice_probs_for_states( params, compute_utility, ) + + model_funcs = model["model_funcs"] + + # The following allows to specify a function to return taste shock scales for each + # state differently. + if model_funcs["draw_functions"]["draw_taste_shock_per_state"]: + taste_shock_scale = model_funcs["draw_functions"]["draw_taste_shock_per_state"]( + params=params, state_space_dict=observed_states + ) + else: + taste_shock_scale = model_funcs["draw_functions"]["taste_shock_scale"](params) + choice_prob_across_choices, _, _ = calculate_choice_probs_and_unsqueezed_logsum( choice_values_per_state=value_per_agent_interp, - taste_shock_scale=params["lambda"], + taste_shock_scale=taste_shock_scale, ) return choice_prob_across_choices diff --git a/src/dcegm/pre_processing/check_params.py b/src/dcegm/pre_processing/check_params.py index 8e5e0a91..2b5f1a2e 100644 --- a/src/dcegm/pre_processing/check_params.py +++ b/src/dcegm/pre_processing/check_params.py @@ -20,8 +20,6 @@ def process_params(params: Union[dict, pd.Series, pd.DataFrame]) -> Dict[str, fl if "interest_rate" not in params: params["interest_rate"] = 0 - if "lambda" not in params: - params["lambda"] = 0 if "sigma" not in params: params["sigma"] = 0 if "beta" not in params: diff --git a/src/dcegm/pre_processing/model_functions.py b/src/dcegm/pre_processing/model_functions.py index e800d655..1a440475 100644 --- a/src/dcegm/pre_processing/model_functions.py +++ b/src/dcegm/pre_processing/model_functions.py @@ -1,5 +1,6 @@ from typing import Callable, Dict +import jax import jax.numpy as jnp from upper_envelope.fues_jax.fues_jax import fues_jax @@ -15,6 +16,7 @@ def process_model_functions( utility_functions: Dict[str, Callable], utility_functions_final_period: Dict[str, Callable], budget_constraint: Callable, + draw_functions: Dict[str, Callable] = None, ): """Create wrapped functions from user supplied functions. @@ -125,6 +127,10 @@ def process_model_functions( continuous_state=continuous_state_name, ) + draw_functions_processed = process_draw_functions( + draw_functions, options, continuous_state_name + ) + model_funcs = { "compute_utility": compute_utility, "compute_marginal_utility": compute_marginal_utility, @@ -139,11 +145,71 @@ def process_model_functions( "state_specific_choice_set": state_specific_choice_set, "next_period_endogenous_state": next_period_endogenous_state, "compute_upper_envelope": compute_upper_envelope, + "draw_functions": draw_functions_processed, } return model_funcs +def process_draw_functions(draw_functions, options, continuous_state_name): + draw_functions_processed = {} if draw_functions is None else draw_functions + if "taste_shock_per_state" in draw_functions_processed.keys(): + draw_functions_processed["taste_shock_per_state"] = ( + get_taste_shock_over_all_states_func( + draw_function_taste_shocks=draw_functions_processed[ + "taste_shock_per_state" + ], + options=options, + continuous_state_name=continuous_state_name, + ) + ) + draw_taste_shock_per_state = True + else: + draw_taste_shock_per_state = False + if "lambda" in options["model_params"]: + # Check if lambda is a scalar + lambda_val = options["model_params"]["lambda"] + if not isinstance(lambda_val, (int, float)): + raise ValueError( + f"Lambda is not a scalar. If there is no draw function provided, " + f"lambda must be a scalar. Got {lambda_val}." + ) + read_function = lambda params_in: jnp.asarray( + [options["model_params"]["lambda"]] + ) + else: + read_function = lambda params_in: jnp.asarray([params_in["lambda"]]) + + draw_functions_processed["taste_shock_scale"] = read_function + + draw_functions_processed["draw_taste_shock_per_state"] = draw_taste_shock_per_state + return draw_functions_processed + + +def get_taste_shock_over_all_states_func( + draw_function_taste_shocks, options, continuous_state_name +): + not_allowed_states = ["wealth"] + if continuous_state_name is not None: + not_allowed_states += [continuous_state_name] + taste_shock_per_state_function = determine_function_arguments_and_partial_options( + func=draw_function_taste_shocks, + options=options["model_params"], + not_allowed_state_choices=not_allowed_states, + continuous_state_name=continuous_state_name, + ) + + def vectorized_taste_shock_per_state(state_dict_vec, params): + return taste_shock_per_state_function(params=params, **state_dict_vec) + + def taste_shock_over_all_states_func(state_dict, params): + return jax.vmap(vectorized_taste_shock_per_state, in_axes=(0, None))( + state_dict, params + ) + + return taste_shock_over_all_states_func + + def process_state_space_functions( state_space_functions, options, continuous_state_name ): diff --git a/src/dcegm/pre_processing/setup_model.py b/src/dcegm/pre_processing/setup_model.py index 3046deb0..63e1fd78 100644 --- a/src/dcegm/pre_processing/setup_model.py +++ b/src/dcegm/pre_processing/setup_model.py @@ -23,6 +23,7 @@ def setup_model( utility_functions_final_period: Dict[str, Callable], budget_constraint: Callable, state_space_functions: Dict[str, Callable] = None, + draw_functions: Dict[str, Callable] = None, debug_output: str = None, ): """Set up the model for dcegm. @@ -62,6 +63,7 @@ def setup_model( utility_functions=utility_functions, utility_functions_final_period=utility_functions_final_period, budget_constraint=budget_constraint, + draw_functions=draw_functions, ) model_structure = create_model_structure( diff --git a/src/dcegm/pre_processing/shared.py b/src/dcegm/pre_processing/shared.py index 01cbd0eb..557d0a86 100644 --- a/src/dcegm/pre_processing/shared.py +++ b/src/dcegm/pre_processing/shared.py @@ -7,15 +7,24 @@ def determine_function_arguments_and_partial_options( - func, options, continuous_state_name=None + func, options, not_allowed_state_choices=None, continuous_state_name=None ): signature = set(inspect.signature(func).parameters) + not_allowed_state_choices = ( + [] if not_allowed_state_choices is None else not_allowed_state_choices + ) partialed_func, signature = partial_options_and_update_signature( func=func, signature=signature, options=options, ) + if len(not_allowed_state_choices) > 0: + for var in signature: + if var in not_allowed_state_choices: + raise ValueError( + f"{func.__name__}() has a not allowed input variable: {var}" + ) @functools.wraps(func) def processed_func(**kwargs): diff --git a/src/dcegm/simulation/sim_utils.py b/src/dcegm/simulation/sim_utils.py index 9f31036a..9b7f7da4 100644 --- a/src/dcegm/simulation/sim_utils.py +++ b/src/dcegm/simulation/sim_utils.py @@ -240,8 +240,18 @@ def next_period_continuous_state_for_one_agent( ) -def draw_taste_shocks(n_agents, n_choices, taste_shock_scale, key): +def draw_taste_shocks(n_agents, n_choices, states, params, draw_functions, key): taste_shocks = jax.random.gumbel(key=key, shape=(n_agents, n_choices)) + # The following allows to specify a function to return taste shock scales for each + # state differently. + if draw_functions["draw_taste_shock_per_state"]: + taste_shock_scale = draw_functions["draw_taste_shock_per_state"]( + params=params, + state_space_dict=states, + ) + else: + taste_shock_scale = draw_functions["taste_shock_scale"](params) + taste_shocks = taste_shock_scale * (taste_shocks - jnp.euler_gamma) return taste_shocks diff --git a/src/dcegm/simulation/simulate.py b/src/dcegm/simulation/simulate.py index 4d6632f1..a170e6de 100644 --- a/src/dcegm/simulation/simulate.py +++ b/src/dcegm/simulation/simulate.py @@ -100,6 +100,7 @@ def simulate_all_periods( ], exog_state_mapping=model_funcs_sim["exog_state_mapping"], compute_next_period_states=compute_next_period_states, + draw_functions=model_funcs_sim["draw_functions"], second_continuous_state_dict=second_continuous_state_dict, ) @@ -123,6 +124,7 @@ def simulate_all_periods( choice_range=model_structure_solution["choice_range"], map_state_choice_to_index=model_structure_solution["map_state_choice_to_index"], compute_utility_final_period=model_funcs_sim["compute_utility_final"], + draw_functions=model_funcs_sim["draw_functions"], ) result = { @@ -151,6 +153,7 @@ def simulate_single_period( compute_beginning_of_period_wealth, exog_state_mapping, compute_next_period_states, + draw_functions, second_continuous_state_dict=None, ): ( @@ -195,7 +198,9 @@ def simulate_single_period( taste_shocks = draw_taste_shocks( n_agents=len(wealth_beginning_of_period), n_choices=len(choice_range), - taste_shock_scale=params["lambda"], + states=states_beginning_of_period, + params=params, + draw_functions=draw_functions, key=sim_specific_keys[0, :], ) values_across_choices = values_pre_taste_shock + taste_shocks @@ -268,6 +273,7 @@ def simulate_final_period( choice_range, map_state_choice_to_index, compute_utility_final_period, + draw_functions, ): invalid_number = np.iinfo(map_state_choice_to_index.dtype).max @@ -305,7 +311,9 @@ def simulate_final_period( taste_shocks = draw_taste_shocks( n_agents=n_agents, n_choices=n_choices, - taste_shock_scale=params["lambda"], + states=states_beginning_of_final_period, + params=params, + draw_functions=draw_functions, key=sim_specific_keys[0, :], ) values_across_choices = utilities_pre_taste_shock + taste_shocks diff --git a/src/dcegm/solve.py b/src/dcegm/solve.py index 0ee15972..7476a45a 100644 --- a/src/dcegm/solve.py +++ b/src/dcegm/solve.py @@ -24,6 +24,7 @@ def solve_dcegm( utility_functions_final_period: Dict[str, Callable], budget_constraint: Callable, state_space_functions: Dict[str, Callable] = None, + draw_functions: Dict[str, Callable] = None, ) -> Dict[int, np.ndarray]: """Solve a discrete-continuous life-cycle model using the DC-EGM algorithm. @@ -57,6 +58,7 @@ def solve_dcegm( utility_functions=utility_functions, budget_constraint=budget_constraint, utility_functions_final_period=utility_functions_final_period, + draw_functions=draw_functions, ) results = backward_jit(params=params) @@ -70,6 +72,7 @@ def get_solve_function( budget_constraint: Callable, utility_functions_final_period: Dict[str, Callable], state_space_functions: Dict[str, Callable] = None, + draw_functions: Dict[str, Callable] = None, ) -> Callable: """Create a solve function, which only takes params as input. @@ -101,6 +104,7 @@ def get_solve_function( utility_functions=utility_functions, utility_functions_final_period=utility_functions_final_period, budget_constraint=budget_constraint, + draw_functions=draw_functions, ) return get_solve_func_for_model(model=model) @@ -214,7 +218,14 @@ def backward_induction( from the backward induction. """ - taste_shock_scale = params["lambda"] + # The following allows to specify a function to return taste shock scales for each + # state differently. + if model_funcs["draw_functions"]["draw_taste_shock_per_state"]: + taste_shock_scale = model_funcs["draw_functions"]["draw_taste_shock_per_state"]( + params=params, state_space_dict=state_space_dict + ) + else: + taste_shock_scale = model_funcs["draw_functions"]["taste_shock_scale"](params) cont_grids_next_period = calc_cont_grids_next_period( state_space_dict=state_space_dict, diff --git a/src/dcegm/solve_single_period.py b/src/dcegm/solve_single_period.py index 3a14d7b8..b15258e7 100644 --- a/src/dcegm/solve_single_period.py +++ b/src/dcegm/solve_single_period.py @@ -48,6 +48,9 @@ def solve_single_period( params=params, ) + if len(taste_shock_scale) > 1: + taste_shock_scale = taste_shock_scale[child_state_idxs] + endog_grid_state_choice, policy_state_choice, value_state_choice = ( solve_for_interpolated_values( value_interpolated=value_interpolated, diff --git a/tests/test_pre_processing.py b/tests/test_pre_processing.py index b201c2bf..369f929f 100644 --- a/tests/test_pre_processing.py +++ b/tests/test_pre_processing.py @@ -107,11 +107,10 @@ def test_missing_parameter( params.pop("interest_rate") params.pop("sigma") - params.pop("lambda") params_dict = process_params(params) - for param in ["interest_rate", "sigma", "lambda"]: + for param in ["interest_rate", "sigma"]: assert param in params_dict.keys() params.pop("beta") diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 6708ae05..03b42775 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -120,6 +120,7 @@ def test_simulate_lax_scan(model_setup): "compute_beginning_of_period_wealth" ], exog_state_mapping=exog_state_mapping, + draw_functions=model_funcs["draw_functions"], compute_next_period_states={ "next_period_endogenous_state": next_period_endogenous_state }, @@ -149,6 +150,7 @@ def test_simulate_lax_scan(model_setup): choice_range=choice_range, map_state_choice_to_index=jnp.array(map_state_choice_to_index), compute_utility_final_period=model_funcs["compute_utility_final"], + draw_functions=model_funcs["draw_functions"], ) final_period_dict = simulate_final_period( states_and_wealth_beginning_of_final_period, @@ -158,6 +160,7 @@ def test_simulate_lax_scan(model_setup): choice_range=choice_range, map_state_choice_to_index=jnp.array(map_state_choice_to_index), compute_utility_final_period=model_funcs["compute_utility_final"], + draw_functions=model_funcs["draw_functions"], ) aaae(np.squeeze(lax_sim_dict_zero["taste_shocks"]), sim_dict_zero["taste_shocks"]) diff --git a/tests/test_two_period_continuous_experience.py b/tests/test_two_period_continuous_experience.py index 0e5b94a0..b7efb671 100644 --- a/tests/test_two_period_continuous_experience.py +++ b/tests/test_two_period_continuous_experience.py @@ -328,7 +328,7 @@ def create_test_inputs(): "state_to_choices_final_period" ], params=params, - taste_shock_scale=taste_shock_scale, + taste_shock_scale=jnp.array([taste_shock_scale]), income_shock_weights=income_shock_weights, exog_grids=exog_grids_cont, model_funcs=model_funcs_cont,