Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aux objects in likelihood #138

Merged
merged 4 commits into from
Nov 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 63 additions & 125 deletions src/dcegm/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def create_individual_likelihood_function_for_model(
observed_choices: np.array,
params_all,
unobserved_state_specs=None,
return_model_solution=False,
):

solve_func = get_solve_func_for_model(
Expand All @@ -44,7 +45,6 @@ def create_individual_likelihood_function_for_model(
observed_states=observed_states,
observed_choices=observed_choices,
unobserved_state_specs=unobserved_state_specs,
weight_full_states=True,
)

def individual_likelihood(params):
Expand All @@ -64,7 +64,16 @@ def individual_likelihood(params):
# Negative ll contributions are positive numbers. The smaller the better the fit
# Add high fixed punishment for not explained choices
neg_likelihood_contributions = (-jnp.log(choice_probs)).clip(max=999)
return neg_likelihood_contributions

if return_model_solution:
solution = {
"value": value_solved,
"policy": policy_solved,
"endog_grid": endog_grid_solved,
}
return neg_likelihood_contributions, solution
else:
return neg_likelihood_contributions

return jax.jit(individual_likelihood)

Expand All @@ -74,46 +83,32 @@ def create_choice_prob_func_unobserved_states(
observed_states: Dict[str, int],
observed_choices: np.array,
unobserved_state_specs,
weight_full_states=True,
):
# First prepare full observed states, choices and pre period states for weighting
full_mask = unobserved_state_specs["observed_bool"]
state_space_names = model["model_structure"]["discrete_states_names"] + ["wealth"]
if len(model["options"]["exog_grids"]) == 2:
second_cont_state_name = model["options"]["second_continuous_state_name"]
state_space_names = model["model_structure"]["discrete_states_names"] + [
"wealth",
second_cont_state_name,
]
else:
state_space_names = model["model_structure"]["discrete_states_names"] + [
"wealth"
]

full_observed_states = {
name: observed_states[name][full_mask] for name in state_space_names
}
full_observed_choices = observed_choices[full_mask]
# Now the states of last period for weighting and also the unobserved states
# for this period
pre_period_full_observed_states = {
name: unobserved_state_specs["pre_period_states"][name][full_mask]
for name in unobserved_state_specs["pre_period_states"].keys()
}
for state_name in unobserved_state_specs["states"]:
pre_period_full_observed_states[state_name + "_new"] = full_observed_states[
state_name
]

# Finish with partial prob function for full observed states
partial_choice_probs_full_observed_states = create_partial_choice_prob_calculation(
observed_states=full_observed_states,
observed_choices=full_observed_choices,
model=model,
)
state_space_names += [second_cont_state_name]

unobserved_state_names = unobserved_state_specs["observed_bools_states"].keys()
observed_bools = unobserved_state_specs["observed_bools_states"]

# Create weighting vars by extracting states and choices
weighting_vars = unobserved_state_specs["state_choices_weighing"]["states"]
weighting_vars["choice"] = unobserved_state_specs["state_choices_weighing"][
"choices"
]

# Add unobserved states with appendix new and bools indicating if state is observed
for state_name in unobserved_state_names:
weighting_vars[state_name + "_new"] = observed_states[state_name]
weighting_vars[state_name + "_observed_bool"] = unobserved_state_specs[
"observed_bools_states"
][state_name]

# Read out possible values for unobserved states
unobserved_state_values = {}
for state_name in unobserved_state_specs["states"]:
for state_name in unobserved_state_specs["observed_bools_states"].keys():
if state_name in model["model_structure"]["exog_states_names"]:
state_values = model["options"]["state_space"]["exogenous_processes"][
state_name
Expand All @@ -124,131 +119,74 @@ def create_choice_prob_func_unobserved_states(
]
unobserved_state_values[state_name] = state_values

# Read out the observed states of the unobserved states
unobserved_states = {
name: observed_states[name][~full_mask] for name in state_space_names
}
# Also pre period states
pre_period_unobserved_states = {
name: unobserved_state_specs["pre_period_states"][name][~full_mask]
for name in unobserved_state_specs["pre_period_states"].keys()
}
# Now add the new states which correspond to the states of this period
for state_name in unobserved_state_specs["states"]:
pre_period_unobserved_states[state_name + "_new"] = unobserved_states[
state_name
]

# Now create a list which contains dictionaries with ach dictionary
# containing a unique combination of unobserved states. Note that this is
# only tested for one state with two values.
possible_unobserved_states = [unobserved_states]
possible_pre_period_unobserved_states = [pre_period_unobserved_states]
for state_name in unobserved_state_specs["states"]:
new_possible_unobserved_states = []
new_possible_pre_period_unobserved_states = []
possible_states = [observed_states]
weighting_vars_for_possible_states = [weighting_vars]
for state_name in unobserved_state_names:
# Create bool indicating if state is unobserved
unobserved_state_bool = ~observed_bools[state_name]

new_possible_states = []
new_weighting_vars_for_possible_states = []
for state_value in unobserved_state_values[state_name]:
for possible_state in possible_unobserved_states:
possible_state[state_name][:] = state_value
new_possible_unobserved_states.append(copy.deepcopy(possible_state))
for possible_state in possible_states:
possible_state[state_name][unobserved_state_bool] = state_value
new_possible_states.append(copy.deepcopy(possible_state))
# Same for pre period states
for pre_period_state in possible_pre_period_unobserved_states:
pre_period_state[state_name + "_new"][:] = state_value
new_possible_pre_period_unobserved_states.append(
copy.deepcopy(pre_period_state)
for weighting_vars in weighting_vars_for_possible_states:
weighting_vars[state_name + "_new"][unobserved_state_bool] = state_value
new_weighting_vars_for_possible_states.append(
copy.deepcopy(weighting_vars)
)
# Now overwrite existing lists
possible_unobserved_states = new_possible_unobserved_states
possible_pre_period_unobserved_states = (
new_possible_pre_period_unobserved_states
)
possible_states = new_possible_states
weighting_vars_for_possible_states = new_weighting_vars_for_possible_states

# Create a list of partial choice probability functions for each unique
# combination of unobserved states.
partial_choice_probs_unobserved_states = []
for unobserved_state in possible_unobserved_states:
for states in possible_states:
partial_choice_probs_unobserved_states.append(
create_partial_choice_prob_calculation(
observed_states=unobserved_state,
observed_choices=observed_choices[~full_mask],
observed_states=states,
observed_choices=observed_choices,
model=model,
)
)
partial_weight_func = (
lambda params_in, states, choices: calculate_weights_for_each_state(
lambda params_in, weight_vars: calculate_weights_for_each_state(
params=params_in,
state_vec=states,
choice=choices,
options=model["options"],
weight_vars=weight_vars,
options=model["options"]["model_params"],
weight_func=unobserved_state_specs["weight_func"],
)
)

unobserved_states_index = jnp.where(~full_mask)[0]
observed_states_index = jnp.where(full_mask)[0]
n_obs = len(observed_choices)

def choice_prob_func(value_in, endog_grid_in, params_in):
choice_probs_final = jnp.empty_like(observed_choices, dtype=jnp.float64)
unobserved_probs = jnp.zeros_like(
observed_choices[~full_mask], dtype=jnp.float64
)
objects = {}
i = 0
for partial_choice_prob, unobserved_state, pre_period_unobserved_states in zip(
choice_probs_final = jnp.zeros(n_obs, dtype=jnp.float64)
for partial_choice_prob, unobserved_state, weighting_vars in zip(
partial_choice_probs_unobserved_states,
possible_unobserved_states,
possible_pre_period_unobserved_states,
possible_states,
weighting_vars_for_possible_states,
):
weights = jax.vmap(
partial_weight_func,
in_axes=(None, 0, 0),
in_axes=(None, 0),
)(
params_in,
pre_period_unobserved_states,
unobserved_state_specs["pre_period_choices"][~full_mask],
weighting_vars,
)

unweighted_choice_probs = partial_choice_prob(
value_in=value_in,
endog_grid_in=endog_grid_in,
params_in=params_in,
)
objects[i] = {}
objects[i]["unweighted_choice_probs"] = unweighted_choice_probs
objects[i]["weights"] = weights

i += 1
unobserved_probs += jnp.nan_to_num(
weights * unweighted_choice_probs, nan=0.0
)

choice_probs_final = choice_probs_final.at[unobserved_states_index].set(
unobserved_probs
)

choice_probs_full = partial_choice_probs_full_observed_states(
value_in=value_in,
endog_grid_in=endog_grid_in,
params_in=params_in,
)

if weight_full_states:
weight_choice_probs_full = jax.vmap(
partial_weight_func,
in_axes=(None, 0, 0),
)(
params_in,
pre_period_full_observed_states,
unobserved_state_specs["pre_period_choices"][full_mask],
)

choice_probs_final = choice_probs_final.at[observed_states_index].set(
choice_probs_full * weight_choice_probs_full
)
else:
choice_probs_final = choice_probs_final.at[observed_states_index].set(
choice_probs_full
)
choice_probs_final += unweighted_choice_probs * weights

return choice_probs_final

Expand Down Expand Up @@ -422,7 +360,7 @@ def interp1d_value_for_state_in_each_choice(
return value_interp


def calculate_weights_for_each_state(params, state_vec, choice, options, weight_func):
def calculate_weights_for_each_state(params, weight_vars, options, weight_func):
"""Calculate the weights for each state.

Args:
Expand All @@ -436,4 +374,4 @@ def calculate_weights_for_each_state(params, state_vec, choice, options, weight_
float: Weight.

"""
return weight_func(**state_vec, params=params, choice=choice, options=options)
return weight_func(**weight_vars, params=params, options=options)
Loading