Skip to content

Commit

Permalink
Abandon loop in exog process creation (#96)
Browse files Browse the repository at this point in the history
Co-authored-by: Sebastian Gsell <[email protected]>
  • Loading branch information
MaxBlesch and segsell authored Mar 13, 2024
1 parent 2146513 commit 1698f8f
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 105 deletions.
35 changes: 29 additions & 6 deletions src/dcegm/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ def individual_likelihood(params):
endog_grid_in=endog_grid_solved,
params_in=params_initial,
).clip(min=1e-10)
log_value = jnp.sum(-jnp.log(choice_probs))
return log_value
likelihood_contributions = jnp.log(choice_probs)
log_value = jnp.sum(-likelihood_contributions)
return log_value, likelihood_contributions

return jax.jit(individual_likelihood)

Expand All @@ -153,7 +154,32 @@ def calc_choice_prob_for_observed_choices(
and then interpolates the wealth at the beginning of period on them.
"""
choice_prob_across_choices = calc_choice_probs_for_observed_states(
value_solved=value_solved,
endog_grid_solved=endog_grid_solved,
params=params,
observed_states=observed_states,
state_choice_indexes=state_choice_indexes,
oberseved_wealth=oberseved_wealth,
choice_range=choice_range,
compute_utility=compute_utility,
)
choice_probs = jnp.take_along_axis(
choice_prob_across_choices, observed_choices[:, None], axis=1
)[:, 0]
return choice_probs


def calc_choice_probs_for_observed_states(
value_solved,
endog_grid_solved,
params,
observed_states,
state_choice_indexes,
oberseved_wealth,
choice_range,
compute_utility,
):
value_grid_agent = jnp.take(
value_solved, state_choice_indexes, axis=0, mode="fill", fill_value=jnp.nan
)
Expand All @@ -179,10 +205,7 @@ def calc_choice_prob_for_observed_choices(
choice_values_per_state=value_per_agent_interp,
taste_shock_scale=params["lambda"],
)
choice_probs = jnp.take_along_axis(
choice_prob_across_choices, observed_choices[:, None], axis=1
)[:, 0]
return choice_probs
return choice_prob_across_choices


def interpolate_value_and_calc_choice_probabilities(
Expand Down
2 changes: 1 addition & 1 deletion src/dcegm/pre_processing/setup_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def load_and_setup_model(
budget_constraint=budget_constraint,
)

create_exog_mapping(
model["exog_mapping"] = create_exog_mapping(
np.array(model["exog_state_space"], dtype=np.int16), model["exog_state_names"]
)

Expand Down
113 changes: 59 additions & 54 deletions src/dcegm/pre_processing/state_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def create_state_space_and_choice_objects(
exog_state_space=exog_state_space,
options=state_space_options,
period_specific_state_objects=out,
state_choice_space=state_choice_space,
map_state_choice_to_index=map_state_choice_to_index,
state_space=state_space,
map_state_to_index=map_state_to_state_space_index,
states_names_without_exog=states_names_without_exog,
get_next_period_state=get_next_period_state,
Expand Down Expand Up @@ -158,25 +161,15 @@ def create_state_space(options):
)

(
add_exog_state_func,
exog_states_names,
num_states_of_all_exog_states,
n_exog_states,
exog_state_space,
) = process_exog_model_specifications(state_space_options=state_space_options)

states_names_without_exog = ["period", "lagged_choice"] + endog_states_names

shape = (
[n_periods, n_choices]
+ num_states_of_all_endog_states
+ num_states_of_all_exog_states
)

map_state_to_index = np.full(shape, -9999, dtype=np.int64)
state_space_list = []
state_space_wo_exog_list = []

index = 0
for period in range(n_periods):
for lagged_choice in range(n_choices):
for endog_state_id in range(num_endog_states):
Expand All @@ -197,15 +190,20 @@ def create_state_space(options):
if not is_state_valid:
continue
else:
# If is valid, we continue to add all exogenous processes
for exog_state_id in range(n_exog_states):
exog_states = add_exog_state_func(exog_state_id)
state = state_without_exog + exog_states
state_space_list += [state]
map_state_to_index[tuple(state)] = index
index += 1
state_space_wo_exog_list += [state_without_exog]

state_space_wo_exog = np.array(state_space_wo_exog_list)
state_space_wo_exog_full = np.repeat(state_space_wo_exog, n_exog_states, axis=0)
exog_state_space_full = np.tile(exog_state_space, (state_space_wo_exog.shape[0], 1))
state_space = np.concatenate(
(state_space_wo_exog_full, exog_state_space_full), axis=1
)

state_space = np.array(state_space_list)
# Create indexer array that maps states to indexes
max_states = np.max(state_space, axis=0)
map_state_to_index = np.full(max_states + 1, fill_value=-9999, dtype=int)
state_space_tuple = tuple(state_space[:, i] for i in range(state_space.shape[1]))
map_state_to_index[state_space_tuple] = np.arange(state_space.shape[0], dtype=int)

return (
state_space,
Expand Down Expand Up @@ -241,24 +239,14 @@ def process_exog_model_specifications(state_space_options):

exog_state_space = np.array([[0]], dtype=np.int16)

exog_states_add_func = create_exog_state_add_function(exog_state_space)

return (
exog_states_add_func,
exog_state_names,
num_states_of_all_exog_states,
n_exog_states,
exog_state_space,
)


def create_exog_state_add_function(exog_state_space):
def add_exog_states(id_exog_state):
return list(exog_state_space[id_exog_state])

return add_exog_states


def span_subspace_and_read_information(subdict_of_space, states_names):
all_states_values = []

Expand Down Expand Up @@ -469,7 +457,10 @@ def create_map_from_state_to_child_nodes(
exog_state_space: np.ndarray,
options: Dict[str, int],
period_specific_state_objects: Dict,
state_choice_space: np.ndarray,
map_state_to_index: np.ndarray,
state_space,
map_state_choice_to_index: np.ndarray,
states_names_without_exog: list,
get_next_period_state: Callable,
):
Expand Down Expand Up @@ -512,7 +503,14 @@ def create_map_from_state_to_child_nodes(

n_exog_vars = exog_state_space.shape[1]

map_state_to_feasible_child_states = np.full(
(state_choice_space.shape[0], n_exog_states), fill_value=-9999, dtype=int
)

exog_states_tuple = tuple(exog_state_space[:, i] for i in range(n_exog_vars))
current_state_choice_idx = -1
for period in range(n_periods - 1):
end_of_prev_period_index = current_state_choice_idx + 1
period_dict = period_specific_state_objects[period]
idx_min_state_space_next_period = map_state_to_index[
tuple(period_specific_state_objects[period + 1]["state_choice_mat"][0, :-1])
Expand All @@ -527,6 +525,7 @@ def create_map_from_state_to_child_nodes(

# Loop over all state-choice combinations in period.
for idx, state_choice_vec in enumerate(state_choice_space_period):
current_state_choice_idx = end_of_prev_period_index + idx
current_state = state_choice_vec[:-1]
current_state_without_exog = current_state[:-n_exog_vars]

Expand All @@ -543,21 +542,23 @@ def create_map_from_state_to_child_nodes(

state_dict_without_exog.update(endog_state_update)

state_next_without_exog = np.array(
[state_dict_without_exog[key] for key in states_names_without_exog]
states_next_tuple = (
tuple(
np.full(
n_exog_states,
fill_value=state_dict_without_exog[key],
dtype=int,
)
for key in states_names_without_exog
)
+ exog_states_tuple
)

for exog_process in range(n_exog_states):
_state_vec_next = np.empty_like(current_state)
# Fill up the next state with the endogenous part.
_state_vec_next[:-n_exog_vars] = state_next_without_exog
# Then with the endogenous part.
_state_vec_next[-n_exog_vars:] = exog_state_space[exog_process]
# We want the index every period to start at 0.
map_state_to_feasible_child_nodes_period[idx, exog_process] = (
map_state_to_index[tuple(_state_vec_next)]
- idx_min_state_space_next_period
)
child_ixs = map_state_to_index[states_next_tuple]
map_state_to_feasible_child_nodes_period[idx, :] = (
child_ixs - idx_min_state_space_next_period
)
map_state_to_feasible_child_states[current_state_choice_idx, :] = child_ixs

period_specific_state_objects[period][
"idx_feasible_child_nodes"
Expand Down Expand Up @@ -587,18 +588,17 @@ def inspect_state_space(
)

(
add_exog_state_func,
exog_states_names,
_,
n_exog_states,
_,
exog_state_space,
) = process_exog_model_specifications(state_space_options=state_space_options)

states_names_without_exog = ["period", "lagged_choice"] + endog_states_names

state_space_list = []
state_space_wo_exog_list = []
is_feasible_list = []

idx = 0
for period in range(n_periods):
for lagged_choice in range(n_choices):
for endog_state_id in range(num_endog_states):
Expand All @@ -607,24 +607,29 @@ def inspect_state_space(

# Create the state vector without the exogenous processes
state_without_exog = [period, lagged_choice] + endog_states
state_space_wo_exog_list += [state_without_exog]

# Transform to dictionary to call sparsity function from user
state_dict_without_exog = {
states_names_without_exog[i]: state_value
for i, state_value in enumerate(state_without_exog)
}

is_state_valid = sparsity_func(**state_dict_without_exog)
for exog_state_id in range(n_exog_states):
exog_states = add_exog_state_func(exog_state_id)
state = state_without_exog + exog_states
is_state_feasible = sparsity_func(**state_dict_without_exog)
is_feasible_list += [is_state_feasible]

state_space_list += [state + [is_state_valid]]
idx += 1
state_space_wo_exog = np.array(state_space_wo_exog_list)
state_space_wo_exog_full = np.repeat(state_space_wo_exog, n_exog_states, axis=0)
exog_state_space_full = np.tile(exog_state_space, (state_space_wo_exog.shape[0], 1))

state_space = np.concatenate(
(state_space_wo_exog_full, exog_state_space_full), axis=1
)

state_space_df = pd.DataFrame(
state_space_list,
columns=states_names_without_exog + exog_states_names + ["is_feasible"],
state_space, columns=states_names_without_exog + exog_states_names
)

state_space_df["is_feasible"] = is_feasible_list

return state_space_df
37 changes: 37 additions & 0 deletions src/dcegm/simulation/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,43 @@
from jax import vmap


def simulate_all_periods_for_model(
states_initial,
resources_initial,
n_periods,
params,
seed,
endog_grid_solved,
value_solved,
policy_left_solved,
policy_right_solved,
choice_range,
model,
):
return simulate_all_periods(
states_initial=states_initial,
resources_initial=resources_initial,
n_periods=n_periods,
params=params,
seed=seed,
state_space_names=model["state_space_names"],
endog_grid_solved=endog_grid_solved,
value_solved=value_solved,
policy_left_solved=policy_left_solved,
policy_right_solved=policy_right_solved,
map_state_choice_to_index=jnp.array(model["map_state_choice_to_index"]),
choice_range=choice_range,
compute_exog_transition_vec=model["model_funcs"]["compute_exog_transition_vec"],
compute_utility=model["model_funcs"]["compute_utility"],
compute_beginning_of_period_resources=model["model_funcs"][
"compute_beginning_of_period_resources"
],
exog_state_mapping=model["exog_mapping"],
get_next_period_state=model["get_next_period_state"],
compute_utility_final_period=model["model_funcs"]["compute_utility_final"],
)


def simulate_all_periods(
states_initial,
resources_initial,
Expand Down
Loading

0 comments on commit 1698f8f

Please sign in to comment.