From 5e89e32e9d173126bcecc5970f9490fce0b8de98 Mon Sep 17 00:00:00 2001 From: Jeff Newman Date: Sun, 28 Jan 2024 21:56:41 -0600 Subject: [PATCH] InitializeToursSettings --- activitysim/abm/models/initialize_tours.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/activitysim/abm/models/initialize_tours.py b/activitysim/abm/models/initialize_tours.py index 7cd416a89..da69e8d22 100644 --- a/activitysim/abm/models/initialize_tours.py +++ b/activitysim/abm/models/initialize_tours.py @@ -8,6 +8,8 @@ from activitysim.abm.models.util import tour_frequency as tf from activitysim.core import expressions, tracing, workflow +from activitysim.core.configuration import PydanticReadable +from activitysim.core.configuration.base import PreprocessorSettings from activitysim.core.input import read_input_table logger = logging.getLogger(__name__) @@ -76,6 +78,14 @@ def set_tour_index(state: workflow.State, tours, parent_tour_num_col, is_joint): return patched_tours +class InitializeToursSettings(PydanticReadable): + annotate_tours: PreprocessorSettings | None = None + """Preprocessor settings to annotate tours""" + + skip_patch_tour_ids: bool = False + """Skip patching tour_ids""" + + @workflow.step def initialize_tours( state: workflow.State, @@ -96,17 +106,17 @@ def initialize_tours( tours = tours[tours.person_id.isin(persons.index)] # annotate before patching tour_id to allow addition of REQUIRED_TOUR_COLUMNS defined above - model_settings = state.filesystem.read_model_settings( - "initialize_tours.yaml", mandatory=True + model_settings = InitializeToursSettings.read_settings_file( + state.filesystem, "initialize_tours.yaml", mandatory=True ) expressions.assign_columns( state, df=tours, - model_settings=model_settings.get("annotate_tours"), + model_settings=model_settings.annotate_tours, trace_label=tracing.extend_trace_label(trace_label, "annotate_tours"), ) - skip_patch_tour_ids = model_settings.get("skip_patch_tour_ids", False) + skip_patch_tour_ids = model_settings.skip_patch_tour_ids if skip_patch_tour_ids: pass else: