Skip to content

Commit

Permalink
Implement data_types into remaining interface functions
Browse files Browse the repository at this point in the history
  • Loading branch information
stevebachmeier committed Nov 13, 2024
1 parent 7958298 commit 637b6a4
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 24 deletions.
10 changes: 0 additions & 10 deletions src/vivarium_inputs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,11 +723,6 @@ def get_age_bins(
data_type: utilities.DataType,
) -> pd.DataFrame:

if data_type.type != "draws":
raise utilities.DataTypeNotImplementedError(
f"Data type(s) {data_type.type} are not supported for this function."
)

age_bins = utility_data.get_age_bins()[["age_group_name", "age_start", "age_end"]]
return age_bins

Expand All @@ -739,11 +734,6 @@ def get_demographic_dimensions(
data_type: utilities.DataType,
) -> pd.DataFrame:

if data_type.type != "draws":
raise utilities.DataTypeNotImplementedError(
f"Data type(s) {data_type.type} are not supported for this function."
)

demographic_dimensions = utility_data.get_demographic_dimensions(location_id, years=years)
demographic_dimensions = utilities.normalize(
demographic_dimensions, data_type.value_columns
Expand Down
47 changes: 38 additions & 9 deletions src/vivarium_inputs/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def get_measure(
return utilities.sort_hierarchical_data(data)


# FIXME [mic-5573]: Add tests against this function
def get_population_structure(
location: int | str | list[int | str],
years: int | str | list[int] | None = None,
Expand All @@ -99,14 +100,20 @@ def get_population_structure(
"""
pop = Population()
data = core.get_data(pop, "structure", location, years)
# Hack: The data_type is set to "draws" to avoid NotImplementedErros, but the
# data is not actually draw-level data.
data_type = DataType("structure", "draws")
data = core.get_data(pop, "structure", location, years, data_type)
data = utilities.scrub_gbd_conventions(data, location)
validation.validate_for_simulation(data, pop, "structure", location, years)
validation.validate_for_simulation(
data, pop, "structure", location, years, data_type.value_columns
)
data = utilities.split_interval(data, interval_column="age", split_column_prefix="age")
data = utilities.split_interval(data, interval_column="year", split_column_prefix="year")
return utilities.sort_hierarchical_data(data)


# FIXME [mic-5573]: Add tests against this function
def get_theoretical_minimum_risk_life_expectancy() -> pd.DataFrame:
"""Pull GBD theoretical minimum risk life expectancy data and standardize
to the expected simulation input format, including binning age parameters
Expand All @@ -119,16 +126,31 @@ def get_theoretical_minimum_risk_life_expectancy() -> pd.DataFrame:
"""
pop = Population()
data = core.get_data(pop, "theoretical_minimum_risk_life_expectancy", "Global")
# Hack: The data_type is set to "draws" to avoid NotImplementedErros, but the
# data is not actually draw-level data.
data_type = DataType("theoretical_minimum_risk_life_expectancy", "draws")
data = core.get_data(
pop,
"theoretical_minimum_risk_life_expectancy",
"Global",
years=None,
data_type=data_type,
)
data = utilities.set_age_interval(data)
validation.validate_for_simulation(
data, pop, "theoretical_minimum_risk_life_expectancy", "Global"
data,
pop,
"theoretical_minimum_risk_life_expectancy",
"Global",
years=None,
value_columns=data_type.value_columns,
)
data = utilities.split_interval(data, interval_column="age", split_column_prefix="age")
data = utilities.split_interval(data, interval_column="year", split_column_prefix="year")
return utilities.sort_hierarchical_data(data)


# FIXME [mic-5573]: Add tests against this function
def get_age_bins() -> pd.DataFrame:
"""Pull GBD age bin data and standardize to the expected simulation input
format.
Expand All @@ -140,14 +162,18 @@ def get_age_bins() -> pd.DataFrame:
"""
pop = Population()
data = core.get_data(pop, "age_bins", "Global")
data_type = DataType("age_bins", None)
data = core.get_data(pop, "age_bins", "Global", years=None, data_type=data_type)
data = utilities.set_age_interval(data)
validation.validate_for_simulation(data, pop, "age_bins", "Global")
validation.validate_for_simulation(
data, pop, "age_bins", "Global", years=None, value_columns=data_type.value_columns
)
data = utilities.split_interval(data, interval_column="age", split_column_prefix="age")
data = utilities.split_interval(data, interval_column="year", split_column_prefix="year")
return utilities.sort_hierarchical_data(data)


# FIXME [mic-5573]: Add tests against this function
def get_demographic_dimensions(
location: int | str | list[int | str],
years: int | str | list[int] | None = None,
Expand All @@ -170,14 +196,18 @@ def get_demographic_dimensions(
"""
pop = Population()
data = core.get_data(pop, "demographic_dimensions", location, years=years)
data_type = DataType("demographic_dimensions", None)
data = core.get_data(pop, "demographic_dimensions", location, years, data_type=data_type)
data = utilities.scrub_gbd_conventions(data, location)
validation.validate_for_simulation(data, pop, "demographic_dimensions", location, years)
validation.validate_for_simulation(
data, pop, "demographic_dimensions", location, years, data_type.value_columns
)
data = utilities.split_interval(data, interval_column="age", split_column_prefix="age")
data = utilities.split_interval(data, interval_column="year", split_column_prefix="year")
return utilities.sort_hierarchical_data(data)


# FIXME [mic-5573]: Add tests against this function
def get_raw_data(
entity: ModelableEntity,
measure: str,
Expand Down Expand Up @@ -237,7 +267,6 @@ def get_raw_data(
Data for the entity-measure pair and specific location requested, with no
formatting or reshaping.
"""
# FIXME: Add tests against this function
data_type = DataType(measure, data_type)
if not isinstance(location, list):
location = [location]
Expand Down
12 changes: 9 additions & 3 deletions src/vivarium_inputs/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ class DataTypeNotImplementedError(NotImplementedError):
class DataType:
"""Class to handle data types and their corresponding differences."""

def __init__(self, measure: str, data_type: str | list[str]) -> None:
def __init__(self, measure: str, data_type: str | list[str] | None) -> None:

self._validate_data_type(data_type)

Expand All @@ -583,6 +583,7 @@ def __init__(self, measure: str, data_type: str | list[str]) -> None:
Supported values include:
- 'means' for getting mean data
- 'draws' for getting draw-level data
- None for measures that do not have meaningful value columns (e.g. age bins)
The data for the following measures do not adhere standard data_types
(i.e. they are not mean or draw-level data) and so this attribute
Expand All @@ -607,9 +608,12 @@ def __init__(self, measure: str, data_type: str | list[str]) -> None:
"""

@staticmethod
def _validate_data_type(data_type: str | list[str]) -> None:
def _validate_data_type(data_type: str | list[str] | None) -> None:
"""Validate that the provided data type is supported."""

if data_type is None:
return

# Temporarily raise for lists of data types
if isinstance(data_type, list):
raise DataTypeNotImplementedError("Lists of data types are not yet supported.")
Expand All @@ -627,13 +631,15 @@ def _validate_data_type(data_type: str | list[str]) -> None:
)

@staticmethod
def _get_value_columns(measure: str, data_type: str | list[str]) -> list[str]:
def _get_value_columns(measure: str, data_type: str | list[str] | None) -> list[str]:
"""Get the value columns corresponding to the provided data type(s).
If the measure is one of 'structure', 'theoretical_minimum_risk_life_expectancy',
'estimate', or 'exposure_distribution_weights', the value columns are always 'value'.
"""
value_cols = []
if data_type is None:
return value_cols
if isinstance(data_type, str):
data_type = [data_type]
for value in data_type:
Expand Down
8 changes: 6 additions & 2 deletions tests/unit/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ def test_validate_data_type(data_type, should_raise):

@pytest.mark.parametrize("measure", POSSIBLE_MEASURES, ids=lambda x: x)
@pytest.mark.parametrize(
"data_type", ("means", "draws", ["means", "draws"]), ids=["means", "draws", "means_draws"]
"data_type",
("means", "draws", None, ["means", "draws"]),
ids=["means", "draws", "None", "means_draws"],
)
def test_get_value_columns(measure, data_type):
if isinstance(data_type, list):
Expand All @@ -116,7 +118,9 @@ def test_get_value_columns(measure, data_type):
):
utilities.DataType(measure, data_type).value_columns
else:
if measure in NON_STANDARD_MEASURES:
if data_type == None: # Hacky: this goes first in the DataType class
expected_returned_cols = []
elif measure in NON_STANDARD_MEASURES:
expected_returned_cols = ["value"]
elif data_type == "means":
expected_returned_cols = MEAN_COLUMNS
Expand Down

0 comments on commit 637b6a4

Please sign in to comment.