From 66a15b7e9bf594b13034f6778f0d002f5065f617 Mon Sep 17 00:00:00 2001 From: Sebastian Gsell Date: Wed, 2 Oct 2024 16:23:12 +0200 Subject: [PATCH] Cleanup second (#128) Co-authored-by: MaxBlesch --- src/dcegm/egm/solve_euler_equation.py | 8 +- src/dcegm/final_periods.py | 2 - src/dcegm/solve.py | 3 - src/dcegm/wealth_correction.py | 4 +- tests/test_egm.py | 389 -------------------------- tests/test_exog_processes.py | 4 +- tests/test_law_of_motion.py | 31 +- 7 files changed, 12 insertions(+), 429 deletions(-) delete mode 100644 tests/test_egm.py diff --git a/src/dcegm/egm/solve_euler_equation.py b/src/dcegm/egm/solve_euler_equation.py index be29409b..c6baa997 100644 --- a/src/dcegm/egm/solve_euler_equation.py +++ b/src/dcegm/egm/solve_euler_equation.py @@ -33,7 +33,7 @@ def calculate_candidate_solutions_from_euler_equation( ) = vmap( vmap( vmap( - wrapper_cont_optimal_policy_and_value, + compute_optimal_policy_and_value_wrapper, in_axes=(1, 1, None, 0, None, None, None), # savings ), in_axes=(1, 1, 0, None, None, None, None), # second continuous state @@ -77,7 +77,7 @@ def calculate_candidate_solutions_from_euler_equation( ) -def wrapper_cont_optimal_policy_and_value( +def compute_optimal_policy_and_value_wrapper( marg_util_next: np.ndarray, emax_next: np.ndarray, second_continuous_grid: np.ndarray, @@ -86,9 +86,9 @@ def wrapper_cont_optimal_policy_and_value( model_funcs: Dict[str, Callable], params: Dict[str, float], ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Simple wrapper function to write second continuous grid point into - state_choice_vec.""" + """Write second continuous grid point into state_choice_vec.""" state_choice_vec["second_continuous"] = second_continuous_grid + return compute_optimal_policy_and_value( marg_util_next, emax_next, diff --git a/src/dcegm/final_periods.py b/src/dcegm/final_periods.py index 3cf7366f..227387ce 100644 --- a/src/dcegm/final_periods.py +++ b/src/dcegm/final_periods.py @@ -93,8 +93,6 @@ def solve_last_two_periods( value_solved, policy_solved, endog_grid_solved, - emax, - marg_util, ) diff --git a/src/dcegm/solve.py b/src/dcegm/solve.py index 8bc19f50..43996382 100644 --- a/src/dcegm/solve.py +++ b/src/dcegm/solve.py @@ -216,7 +216,6 @@ def backward_induction( """ taste_shock_scale = params["lambda"] - # Calculate the continuous grids for the next period cont_grids_next_period = calc_cont_grids_next_period( state_space_dict=state_space_dict, exog_grids=exog_grids, @@ -245,8 +244,6 @@ def backward_induction( value_solved, policy_solved, endog_grid_solved, - value_interp_final_period, - marginal_utility_final_last_period, ) = solve_last_two_periods( cont_grids_next_period=cont_grids_next_period, params=params, diff --git a/src/dcegm/wealth_correction.py b/src/dcegm/wealth_correction.py index 571ff23b..566ed568 100644 --- a/src/dcegm/wealth_correction.py +++ b/src/dcegm/wealth_correction.py @@ -1,7 +1,7 @@ import jax.numpy as jnp from jax import vmap -from dcegm.budget import calculate_resources_for_each_grid_point +from dcegm.law_of_motion import calc_resources_for_each_savings_grid_point def adjust_observed_wealth(observed_states_dict, wealth, params, model): @@ -17,7 +17,7 @@ def adjust_observed_wealth(observed_states_dict, wealth, params, model): savings_last_period = jnp.asarray(wealth / (1 + params["interest_rate"])) adjusted_resources = vmap( - calculate_resources_for_each_grid_point, in_axes=(0, 0, None, None, None) + calc_resources_for_each_savings_grid_point, in_axes=(0, 0, None, None, None) )( observed_states_dict, savings_last_period, diff --git a/tests/test_egm.py b/tests/test_egm.py deleted file mode 100644 index 15ad8fbe..00000000 --- a/tests/test_egm.py +++ /dev/null @@ -1,389 +0,0 @@ -"""Test module for EGM steps: - -- aggregate_marg_utils_and_exp_values -- calculate_candidate_solutions_from_euler_equation - -""" - -import jax.numpy as jnp -import jax.random as random -import numpy as np -import pytest - -from dcegm.egm.aggregate_marginal_utility import aggregate_marg_utils_and_exp_values -from dcegm.egm.solve_euler_equation import ( - calculate_candidate_solutions_from_euler_equation, -) -from dcegm.pre_processing.exog_processes import create_exog_transition_function -from dcegm.pre_processing.shared import determine_function_arguments_and_partial_options -from tests.two_period_models.model import prob_exog_ltc -from toy_models.consumption_retirement_model.utility_functions import ( - create_final_period_utility_function_dict, - create_utility_function_dict, -) - -WEALTH_GRID_POINTS = 100 - -PARAMS = { - "rho": 0.5, - "delta": 0.5, - "interest_rate": 0.02, - "ltc_cost": 5, - "wage_avg": 8, - "sigma": 1, - "lambda": 10, - "beta": 0.95, - # Exogenous parameters - "ltc_prob_constant": 0.3, - "ltc_prob_age": 0.1, - "job_offer_constant": 0.5, - "job_offer_age": 0, - "job_offer_educ": 0, - "job_offer_type_two": 0.4, -} - -OPTIONS = { - "model_params": { - "n_grid_points": WEALTH_GRID_POINTS, - "max_wealth": 50, - "quadrature_points_stochastic": 5, - "n_choices": 2, - }, - "state_space": { - "n_periods": 2, - "choices": np.arange(2), - "endogenous_states": { - "married": [0, 1], - }, - "exogenous_processes": { - "ltc": {"transition": prob_exog_ltc, "states": [0, 1]}, - }, - }, -} - - -@pytest.fixture() -def input_for_aggregation(): - - key = random.PRNGKey(0) - - taste_shock_scale = 1 - income_shock_weights = jnp.array( - [0.11846344, 0.23931434, 0.28444444, 0.23931434, 0.11846344] - ) - states_to_choices_child_states = jnp.array( - [ - [0, 1, 2], - [3, 4, 5], - [6, 7, 8], - [9, 10, 11], - [12, 13, 14], - [15, 16, 17], - [18, 19, 20], - [21, 22, 23], - [65311, 65311, 24], - [65311, 65311, 25], - [65311, 65311, 26], - [65311, 65311, 27], - [28, 29, 30], - [31, 32, 33], - [34, 35, 36], - [37, 38, 39], - [40, 41, 42], - [43, 44, 45], - [46, 47, 48], - [49, 50, 51], - [65311, 65311, 52], - [65311, 65311, 53], - [65311, 65311, 54], - [65311, 65311, 55], - [56, 57, 58], - [59, 60, 61], - [62, 63, 64], - [65, 66, 67], - [68, 69, 70], - [71, 72, 73], - [74, 75, 76], - [77, 78, 79], - [65311, 65311, 80], - [65311, 65311, 81], - [65311, 65311, 82], - [65311, 65311, 83], - [84, 85, 86], - [87, 88, 89], - [90, 91, 92], - [93, 94, 95], - [96, 97, 98], - [99, 100, 101], - [102, 103, 104], - [105, 106, 107], - [65311, 65311, 108], - [65311, 65311, 109], - [65311, 65311, 110], - [65311, 65311, 111], - [112, 113, 114], - [115, 116, 117], - [118, 119, 120], - [121, 122, 123], - [124, 125, 126], - [127, 128, 129], - [130, 131, 132], - [133, 134, 135], - [65311, 65311, 136], - [65311, 65311, 137], - [65311, 65311, 138], - [65311, 65311, 139], - ], - dtype=np.uint16, - ) - - base_array = jnp.linspace(0, 100, 100) # Shape (100,) - variation = random.uniform( - key, shape=(140, 1), minval=0, maxval=20 - ) # Shape (140, 1) - - varied_array = base_array + variation # Shape (140, 100) - - value_interp = jnp.expand_dims(varied_array, axis=-1) # Shape (140, 100, 1) - value_interp = jnp.tile(value_interp, (1, 1, 5)) # Shape (140, 100, 5) - - marg_util_interp = value_interp.copy() - - return ( - value_interp, - marg_util_interp, - states_to_choices_child_states, - taste_shock_scale, - income_shock_weights, - ) - - -@pytest.fixture() -def test_input_for_euler_equation(): - model_params_options = OPTIONS["model_params"] - compute_exog_transition_vec, _ = create_exog_transition_function(OPTIONS) - - utility_functions = create_utility_function_dict() - - compute_utility = determine_function_arguments_and_partial_options( - func=utility_functions["utility"], options=model_params_options - ) - compute_inverse_marginal_utility = determine_function_arguments_and_partial_options( - func=utility_functions["inverse_marginal_utility"], - options=model_params_options, - ) - model_funcs = { - "compute_utility": compute_utility, - "compute_inverse_marginal_utility": compute_inverse_marginal_utility, - "compute_exog_transition_vec": compute_exog_transition_vec, - } - - exog_savings_grid = jnp.linspace(0, 10_000, 100) - - state_choice_mat = { - "choice": jnp.array([0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1], dtype=np.uint8), - "lagged_choice": jnp.array( - [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1], dtype=np.uint8 - ), - "ltc": jnp.array([0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1], dtype=np.uint8), - "married": jnp.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], dtype=np.uint8), - "period": jnp.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.uint8), - } - - child_state_idxs = jnp.array( - [ - [0, 1], - [2, 3], - [0, 1], - [2, 3], - [2, 3], - [2, 3], - [4, 5], - [6, 7], - [4, 5], - [6, 7], - [6, 7], - [6, 7], - ], - dtype=np.uint8, - ) - - return ( - exog_savings_grid, - state_choice_mat, - child_state_idxs, - model_funcs, - ) - - -def test_aggregation_1d(input_for_aggregation): - - ( - value_interp, - marg_util_interp, - states_to_choices_child_states, - taste_shock_scale, - income_shock_weights, - ) = input_for_aggregation - - marg_util, emax = aggregate_marg_utils_and_exp_values( - value_state_choice_specific=value_interp, - marg_util_state_choice_specific=marg_util_interp, - reshape_state_choice_vec_to_mat=states_to_choices_child_states, - taste_shock_scale=taste_shock_scale, - income_shock_weights=income_shock_weights, - ) - - np.testing.assert_equal(marg_util.shape, (60, 100)) - np.testing.assert_equal(emax.shape, (60, 100)) - - -def test_aggregation_2d(input_for_aggregation): - - ( - value_interp1d, - marg_util_interp1d, - states_to_choices_child_states, - taste_shock_scale, - income_shock_weights, - ) = input_for_aggregation - - _value_interp2d = jnp.expand_dims(value_interp1d, axis=1) # Shape (140, 1, 100, 5) - value_interp2d = jnp.tile(_value_interp2d, (1, 6, 1, 1)) # Shape (140, 6, 100, 5) - - _marg_util_interp2d = jnp.expand_dims( - marg_util_interp1d, axis=1 - ) # Shape (140, 1, 100, 5) - marg_util_interp2d = jnp.tile( - _marg_util_interp2d, (1, 6, 1, 1) - ) # Shape (140, 6, 100, 5) - - marg_util_2d, emax_2d = aggregate_marg_utils_and_exp_values( - value_state_choice_specific=value_interp2d, - marg_util_state_choice_specific=marg_util_interp2d, - reshape_state_choice_vec_to_mat=states_to_choices_child_states, - taste_shock_scale=taste_shock_scale, - income_shock_weights=income_shock_weights, - ) - - np.testing.assert_equal(marg_util_2d.shape, (60, 6, 100)) - np.testing.assert_equal(emax_2d.shape, (60, 6, 100)) - - -def test_euler_1d(test_input_for_euler_equation): - - ( - exog_savings_grid, - state_choice_mat, - child_state_idxs, - model_funcs, - ) = test_input_for_euler_equation - - key = random.PRNGKey(42) - - key, subkey = random.split(key) - marg_util = random.uniform(subkey, shape=(8, 100), minval=0.0, maxval=100.0) - - key, subkey = random.split(key) - emax = random.uniform(subkey, shape=(8, 100), minval=0.0, maxval=100.0) - - ( - endog_grid_candidate, - value_candidate, - policy_candidate, - expected_values, - ) = calculate_candidate_solutions_from_euler_equation( - exog_grids={"wealth": exog_savings_grid}, - marg_util_next=marg_util, - emax_next=emax, - state_choice_mat=state_choice_mat, - idx_post_decision_child_states=child_state_idxs, - model_funcs=model_funcs, - has_second_continuous_state=False, - params=PARAMS, - ) - - np.testing.assert_equal( - endog_grid_candidate.shape, - (child_state_idxs.shape[0], exog_savings_grid.shape[0]), - ) - np.testing.assert_equal( - value_candidate.shape, - (child_state_idxs.shape[0], exog_savings_grid.shape[0]), - ) - np.testing.assert_equal( - policy_candidate.shape, - (child_state_idxs.shape[0], exog_savings_grid.shape[0]), - ) - np.testing.assert_equal( - expected_values.shape, - (child_state_idxs.shape[0], exog_savings_grid.shape[0]), - ) - - -def test_euler_2d(test_input_for_euler_equation): - - n_continuous_state = 7 - - ( - exog_savings_grid, - state_choice_mat, - child_state_idxs, - model_funcs, - ) = test_input_for_euler_equation - - key = random.PRNGKey(42) - - key, subkey = random.split(key) - marg_util = random.uniform( - subkey, - shape=(8, n_continuous_state, exog_savings_grid.shape[0]), - minval=0.0, - maxval=100.0, - ) - - key, subkey = random.split(key) - emax = random.uniform( - subkey, - shape=(8, n_continuous_state, exog_savings_grid.shape[0]), - minval=0.0, - maxval=100.0, - ) - - exog_grids = { - "wealth": exog_savings_grid, - "second_continuous": jnp.linspace(0, 1, 7), - } - - ( - endog_grid_candidate, - value_candidate, - policy_candidate, - expected_values, - ) = calculate_candidate_solutions_from_euler_equation( - exog_grids=exog_grids, - marg_util_next=marg_util, - emax_next=emax, - state_choice_mat=state_choice_mat, - idx_post_decision_child_states=child_state_idxs, - model_funcs=model_funcs, - has_second_continuous_state=True, - params=PARAMS, - ) - - np.testing.assert_equal( - endog_grid_candidate.shape, - (child_state_idxs.shape[0], n_continuous_state, exog_savings_grid.shape[0]), - ) - np.testing.assert_equal( - value_candidate.shape, - (child_state_idxs.shape[0], n_continuous_state, exog_savings_grid.shape[0]), - ) - np.testing.assert_equal( - policy_candidate.shape, - (child_state_idxs.shape[0], n_continuous_state, exog_savings_grid.shape[0]), - ) - np.testing.assert_equal( - expected_values.shape, - (child_state_idxs.shape[0], n_continuous_state, exog_savings_grid.shape[0]), - ) diff --git a/tests/test_exog_processes.py b/tests/test_exog_processes.py index 9c107783..9fe48f85 100644 --- a/tests/test_exog_processes.py +++ b/tests/test_exog_processes.py @@ -9,7 +9,9 @@ from dcegm.pre_processing.exog_processes import create_exog_state_mapping from dcegm.pre_processing.model_functions import process_model_functions -from dcegm.pre_processing.state_space import create_state_space_and_choice_objects +from dcegm.pre_processing.state_space import ( + create_discrete_state_space_and_choice_objects, +) from tests.two_period_models.model import prob_exog_health from toy_models.consumption_retirement_model.budget_functions import budget_constraint from toy_models.consumption_retirement_model.state_space_objects import ( diff --git a/tests/test_law_of_motion.py b/tests/test_law_of_motion.py index f3d7a5f7..65e019ef 100644 --- a/tests/test_law_of_motion.py +++ b/tests/test_law_of_motion.py @@ -150,14 +150,14 @@ def test_get_beginning_of_period_wealth( aaae(wealth_beginning_of_period, max(consump_floor, budget_expected)) -# ===================================================================================== +TEST_CASES_SECOND_CONTINUOUS = list(product(model, max_wealth, n_grid_points)) @pytest.mark.parametrize( - "model, period, labor_choice, max_wealth, n_grid_points", TEST_CASES + "model, max_wealth, n_grid_points", TEST_CASES_SECOND_CONTINUOUS ) def test_wealth_and_second_continuous_state( - model, period, labor_choice, max_wealth, n_grid_points, load_example_model + model, max_wealth, n_grid_points, load_example_model ): # parametrize over number of experience points @@ -195,28 +195,3 @@ def test_wealth_and_second_continuous_state( ) aaae(exp_next, experience_next) - - # ======================================================================== - - _quad_points, _ = roots_sh_legendre(n_quad_points) - quad_points = norm.ppf(_quad_points) * sigma - - compute_beginning_of_period_resources = ( - determine_function_arguments_and_partial_options( - func=budget_constraint_based_on_experience, options={} - ) - ) - - wealth_next = calculate_resources_for_second_continuous_state( - discrete_states_beginning_of_next_period=child_state_dict, - continuous_state_beginning_of_next_period=experience_next, - savings_grid=savings_grid, - income_shocks=quad_points, - params=params, - compute_beginning_of_period_resources=compute_beginning_of_period_resources, - ) - - np.testing.assert_equal( - wealth_next.shape, - (len(child_state_dict["period"]), n_exp_points, n_grid_points, n_quad_points), - )