Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for different model in simulation #142

Merged
merged 5 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ repos:
args:
- --profile=black
- repo: https://github.com/asottile/setup-cfg-fmt
rev: v2.5.0
rev: v2.7.0
hooks:
- id: setup-cfg-fmt
- repo: https://github.com/psf/black
Expand Down Expand Up @@ -90,7 +90,7 @@ repos:
- id: nbqa-ruff
exclude: tests/sandbox/
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.17
rev: 0.7.18
hooks:
- id: mdformat
additional_dependencies:
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ project_urls =

[options]
packages = find:
python_requires = >=3.8
python_requires = >=3.9
include_package_data = True
package_dir =
=src
Expand Down
40 changes: 22 additions & 18 deletions src/dcegm/simulation/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def simulate_all_periods(
policy_solved,
value_solved,
model,
model_sim=None,
):
model_sim = model if model_sim is None else model_sim

second_continuous_state_dict = next(
(
Expand All @@ -45,7 +47,8 @@ def simulate_all_periods(
else None
)

discrete_state_space = model["model_structure"]["state_space_dict"]
model_structure_solution = model["model_structure"]
discrete_state_space = model_structure_solution["state_space_dict"]

# Set initial states to internal dtype
states_initial_dtype = {
Expand All @@ -54,7 +57,7 @@ def simulate_all_periods(
if key in discrete_state_space
}

if "dummy_exog" in model["model_structure"]["exog_states_names"]:
if "dummy_exog" in model_structure_solution["exog_states_names"]:
states_initial_dtype["dummy_exog"] = np.zeros_like(
states_initial_dtype["period"]
)
Expand All @@ -72,32 +75,30 @@ def simulate_all_periods(
for period in range(n_periods)
]
)

model_structure = model["model_structure"]
model_funcs = model["model_funcs"]
model_funcs_sim = model_sim["model_funcs"]

compute_next_period_states = {
"get_next_period_state": model_funcs["get_next_period_state"],
"update_continuous_state": model_funcs["update_continuous_state"],
"get_next_period_state": model_funcs_sim["get_next_period_state"],
"update_continuous_state": model_funcs_sim["update_continuous_state"],
}

simulate_body = partial(
simulate_single_period,
params=params,
discrete_states_names=model_structure["discrete_states_names"],
discrete_states_names=model_structure_solution["discrete_states_names"],
endog_grid_solved=endog_grid_solved,
value_solved=value_solved,
policy_solved=policy_solved,
map_state_choice_to_index=jnp.asarray(
model_structure["map_state_choice_to_index"]
model_structure_solution["map_state_choice_to_index"]
),
choice_range=model_structure["choice_range"],
compute_exog_transition_vec=model_funcs["compute_exog_transition_vec"],
compute_utility=model_funcs["compute_utility"],
compute_beginning_of_period_wealth=model_funcs[
choice_range=model_structure_solution["choice_range"],
compute_exog_transition_vec=model_funcs_sim["compute_exog_transition_vec"],
compute_utility=model_funcs_sim["compute_utility"],
compute_beginning_of_period_wealth=model_funcs_sim[
"compute_beginning_of_period_wealth"
],
exog_state_mapping=model_funcs["exog_state_mapping"],
exog_state_mapping=model_funcs_sim["exog_state_mapping"],
compute_next_period_states=compute_next_period_states,
second_continuous_state_dict=second_continuous_state_dict,
)
Expand All @@ -118,16 +119,19 @@ def simulate_all_periods(
states_and_wealth_beginning_of_final_period,
sim_specific_keys=sim_specific_keys[-1],
params=params,
discrete_states_names=model_structure["discrete_states_names"],
choice_range=model_structure["choice_range"],
map_state_choice_to_index=model_structure["map_state_choice_to_index"],
compute_utility_final_period=model_funcs["compute_utility_final"],
discrete_states_names=model_structure_solution["discrete_states_names"],
choice_range=model_structure_solution["choice_range"],
map_state_choice_to_index=model_structure_solution["map_state_choice_to_index"],
compute_utility_final_period=model_funcs_sim["compute_utility_final"],
)

result = {
key: np.vstack([sim_dict[key], final_period_dict[key]])
for key in sim_dict.keys()
}
if "dummy_exog" in model_structure_solution["exog_states_names"]:
if "dummy_exog" not in model_sim["model_structure"]["exog_states_names"]:
result.pop("dummy_exog")

return result

Expand Down
52 changes: 0 additions & 52 deletions src/toy_models/cons_ret_model_dcegm_paper/utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,58 +17,6 @@ def create_utility_function_dict():
}


def utiility_log_crra(
    consumption: jnp.array,
    choice: int,
    params: Dict[str, float],  # delta: float
) -> jnp.array:
    """Compute the agent's log utility (the CRRA limit as theta -> 1).

    Utility is ``log(consumption) - (1 - choice) * params["delta"]``, i.e.
    the additive penalty ``delta`` is subtracted only when ``choice == 0``.
    Note that only ``params["delta"]`` is read here; the CRRA coefficient
    theta does not appear because log utility is the theta -> 1 limit.

    Args:
        consumption (jnp.array): Level of the agent's consumption.
            Array of shape (i) (n_quad_stochastic * n_grid_wealth,)
            when called by :func:`~dcgm.call_egm_step.map_exog_to_endog_grid`
            and :func:`~dcgm.call_egm_step.get_next_period_value`, or
            (ii) of shape (n_grid_wealth,) when called by
            :func:`~dcgm.call_egm_step.get_current_period_value`.
        choice (int): Choice of the agent, e.g. 0 = "retirement", 1 = "working".
        params (dict): Dictionary containing model parameters.
            Relevant here is the (dis)utility term "delta".

    Returns:
        utility (jnp.array): Agent's utility. Array of shape
        (n_quad_stochastic * n_grid_wealth,) or (n_grid_wealth,).

    """
    return jnp.log(consumption) - (1 - choice) * params["delta"]


def utiility_log_crra_final_consume_all(
    wealth: jnp.array,
    choice: int,
    params: Dict[str, float],  # delta: float
) -> jnp.array:
    """Compute log utility in the final period, where all wealth is consumed.

    Utility is ``log(wealth) - (1 - choice) * params["delta"]``; the penalty
    ``delta`` is subtracted only when ``choice == 0``. Only ``params["delta"]``
    is read here — the CRRA coefficient theta does not appear because log
    utility is the theta -> 1 limit.

    Args:
        wealth (jnp.array): Remaining wealth, all of which is consumed in the
            final period. Array of shape (n_quad_stochastic * n_grid_wealth,)
            or (n_grid_wealth,).
        choice (int): Choice of the agent, e.g. 0 = "retirement", 1 = "working".
        params (dict): Dictionary containing model parameters.
            Relevant here is the (dis)utility term "delta".

    Returns:
        utility (jnp.array): Agent's utility. Array of shape
        (n_quad_stochastic * n_grid_wealth,) or (n_grid_wealth,).

    """
    return jnp.log(wealth) - (1 - choice) * params["delta"]


def utility_crra(
consumption: jnp.array,
choice: int,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Dict

from jax import numpy as jnp


def utiility_log_crra(
    consumption: jnp.array,
    choice: int,
    params: Dict[str, float],  # delta: float
) -> jnp.array:
    """Compute the agent's log utility (the CRRA limit as theta -> 1).

    The utility is ``log(consumption)`` minus an additive penalty
    ``params["delta"]`` that applies only when ``choice == 0``.

    Args:
        consumption (jnp.array): Level of the agent's consumption. Array of
            shape (n_quad_stochastic * n_grid_wealth,) or (n_grid_wealth,).
        choice (int): Choice of the agent, e.g. 0 = "retirement", 1 = "working".
        params (dict): Dictionary of model parameters. Only the key "delta"
            (additive disutility term) is read here.

    Returns:
        utility (jnp.array): Agent's utility, same shape as ``consumption``.

    """
    # Penalty is switched on only for choice == 0; multiplication is
    # commutative, so this is identical to (1 - choice) * delta.
    delta_penalty = params["delta"] * (1 - choice)
    utility = jnp.log(consumption) - delta_penalty
    return utility


def utiility_log_crra_final_consume_all(
    wealth: jnp.array,
    choice: int,
    params: Dict[str, float],  # delta: float
) -> jnp.array:
    """Compute log utility in the final period, where all wealth is consumed.

    The utility is ``log(wealth)`` minus an additive penalty
    ``params["delta"]`` that applies only when ``choice == 0``.

    Args:
        wealth (jnp.array): Remaining wealth, all of which is consumed in the
            final period. Array of shape (n_quad_stochastic * n_grid_wealth,)
            or (n_grid_wealth,).
        choice (int): Choice of the agent, e.g. 0 = "retirement", 1 = "working".
        params (dict): Dictionary of model parameters. Only the key "delta"
            (additive disutility term) is read here.

    Returns:
        utility (jnp.array): Agent's utility, same shape as ``wealth``.

    """
    # Penalty is switched on only for choice == 0; multiplication is
    # commutative, so this is identical to (1 - choice) * delta.
    delta_penalty = params["delta"] * (1 - choice)
    utility = jnp.log(wealth) - delta_penalty
    return utility
122 changes: 116 additions & 6 deletions tests/sandbox/time_functions_jax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "35ab80e3",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-20T14:03:27.083891148Z",
"start_time": "2024-06-20T14:03:26.534171812Z"
"end_time": "2024-11-25T15:44:09.704943Z",
"start_time": "2024-11-25T15:44:09.292524Z"
}
},
"outputs": [],
"source": [
"from jax import vmap, jit\n",
"import pickle\n",
Expand All @@ -19,8 +17,120 @@
"import yaml\n",
"from functools import partial\n",
"import jax.numpy as jnp\n",
"import numpy as np"
]
"import numpy as np\n",
"from tests.utils.markov_simulator import markov_simulator"
],
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-25T16:03:54.707014Z",
"start_time": "2024-11-25T16:03:54.700726Z"
}
},
"cell_type": "code",
"source": [
"n_periods = 10\n",
"init_dist = np.array([0.5, 0.5])\n",
"trans_mat = np.array([[0.8, 0.2], [0.1, 0.9]])\n",
"\n",
"markov_simulator(n_periods, init_dist, trans_mat)"
],
"id": "c2b7c16010b9ba85",
"outputs": [
{
"data": {
"text/plain": [
"array([[0.5 , 0.5 ],\n",
" [0.45 , 0.55 ],\n",
" [0.415 , 0.585 ],\n",
" [0.3905 , 0.6095 ],\n",
" [0.37335 , 0.62665 ],\n",
" [0.361345 , 0.638655 ],\n",
" [0.3529415 , 0.6470585 ],\n",
" [0.34705905, 0.65294095],\n",
" [0.34294134, 0.65705866],\n",
" [0.34005893, 0.65994107]])"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 34
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-25T16:08:11.928721Z",
"start_time": "2024-11-25T16:08:11.908425Z"
}
},
"cell_type": "code",
"source": [
"n_agents = 100_000\n",
"current_agents_in_states = (np.ones(2) * n_agents / 2).astype(int)\n",
"for period in range(n_periods):\n",
" print(current_agents_in_states / n_agents)\n",
" next_period_agents_states = np.zeros(2, dtype=int)\n",
" for state in range(2):\n",
" agents_in_state = current_agents_in_states[state]\n",
" transition_draws = np.random.choice(\n",
" a=[0, 1], size=agents_in_state, p=trans_mat[state, :]\n",
" )\n",
" next_period_agents_states[1] += transition_draws.sum()\n",
" next_period_agents_states[0] += agents_in_state - transition_draws.sum()\n",
" current_agents_in_states = next_period_agents_states"
],
"id": "ae676759dd2627d2",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.5 0.5]\n",
"[0.4502 0.5498]\n",
"[0.4189 0.5811]\n",
"[0.39164 0.60836]\n",
"[0.37405 0.62595]\n",
"[0.35994 0.64006]\n",
"[0.35166 0.64834]\n",
"[0.34544 0.65456]\n",
"[0.34263 0.65737]\n",
"[0.34015 0.65985]\n"
]
}
],
"execution_count": 47
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-11-25T16:02:05.105098Z",
"start_time": "2024-11-25T16:02:05.100262Z"
}
},
"cell_type": "code",
"source": [
"trans_mat[0, :]"
],
"id": "28dac7ec90b5d015",
"outputs": [
{
"data": {
"text/plain": [
"Array([0.8, 0.2], dtype=float32)"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 29
},
{
"cell_type": "code",
Expand Down
Loading
Loading