diff --git a/src/core/config.py b/src/core/config.py index 2d12a99..7362207 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -1,4 +1,5 @@ from typing import Any, Dict, Optional +from uuid import UUID from pydantic import BaseSettings, PostgresDsn, validator @@ -27,6 +28,11 @@ class Settings(BaseSettings): CATCHMENT_AREA_CAR_BUFFER_DEFAULT_SPEED = 80 # km/h CATCHMENT_AREA_HOLE_THRESHOLD_SQM = 200000 # 20 hectares, ~450m x 450m + BASE_STREET_NETWORK: Optional[UUID] = "903ecdca-b717-48db-bbce-0219e41439cf" + DEFAULT_STREET_NETWORK_NODE_LAYER_PROJECT_ID = ( + 37319 # Hardcoded until node layers are added to GOAT projects by default + ) + DATA_INSERT_BATCH_SIZE = 800 CELERY_BROKER_URL: Optional[str] = "pyamqp://guest@rabbitmq//" diff --git a/src/core/street_network/street_network_util.py b/src/core/street_network/street_network_util.py index e7ef816..599c911 100644 --- a/src/core/street_network/street_network_util.py +++ b/src/core/street_network/street_network_util.py @@ -14,23 +14,12 @@ class StreetNetworkUtil: def __init__(self, db_connection: AsyncSession): self.db_connection = db_connection - async def _get_layer_and_user_id(self, layer_project_id: int): - """Get the layer ID and user ID of the specified layer project ID.""" + async def _get_user_id(self, layer_id: UUID): + """Get the user ID of the specified layer ID.""" - layer_id: UUID = None user_id: UUID = None try: - # Get the associated layer ID - result = await self.db_connection.execute( - text( - f"""SELECT layer_id - FROM {settings.CUSTOMER_SCHEMA}.layer_project - WHERE id = {layer_project_id};""" - ) - ) - layer_id = UUID(str(result.fetchone()[0])) - # Get the user ID of the layer result = await self.db_connection.execute( text( @@ -41,55 +30,47 @@ async def _get_layer_and_user_id(self, layer_project_id: int): ) user_id = UUID(str(result.fetchone()[0])) except Exception: - raise ValueError( - f"Could not fetch layer and user ID for layer project ID {layer_project_id}." - ) + raise ValueError(f"Could not fetch user ID for layer ID {layer_id}.") - return layer_id, user_id + return user_id async def _get_street_network_tables( self, - street_network_edge_layer_project_id: int, - street_network_node_layer_project_id: int, + edge_layer_id: UUID, + node_layer_id: UUID, ): """Get table names and layer IDs of the edge and node tables.""" edge_table: str = None - edge_layer_id: UUID = None node_table: str = None - node_layer_id: UUID = None - # Get edge table name if a layer project ID is specified - if street_network_edge_layer_project_id: + # Get edge table name if a layer ID is specified + if edge_layer_id: try: # Get the edge layer ID and associated user ID - edge_layer_id, user_id = await self._get_layer_and_user_id( - street_network_edge_layer_project_id - ) + user_id = await self._get_user_id(edge_layer_id) # Produce the edge table name edge_table = f"{settings.USER_DATA_SCHEMA}.street_network_line_{str(user_id).replace('-', '')}" except Exception: raise ValueError( - f"Could not fetch edge table name for layer project ID {street_network_edge_layer_project_id}." + f"Could not fetch edge table name for layer ID {edge_layer_id}." ) - # Get node table name if a layer project ID is specified - if street_network_node_layer_project_id: + # Get node table name if a layer ID is specified + if node_layer_id: try: # Get the node layer ID and associated user ID - node_layer_id, user_id = await self._get_layer_and_user_id( - street_network_node_layer_project_id - ) + user_id = await self._get_user_id(node_layer_id) # Produce the node table name node_table = f"{settings.USER_DATA_SCHEMA}.street_network_point_{str(user_id).replace('-', '')}" except Exception: raise ValueError( - f"Could not fetch node table name for layer project ID {street_network_node_layer_project_id}." + f"Could not fetch node table name for layer ID {node_layer_id}." ) - return edge_table, edge_layer_id, node_table, node_layer_id + return edge_table, node_table async def _get_street_network_region_h3_3_cells(self, region_geofence_table: str): """Get list of H3_3 cells covering the street network region.""" @@ -118,8 +99,8 @@ async def _get_street_network_region_h3_3_cells(self, region_geofence_table: str async def fetch( self, - edge_layer_project_id: int, - node_layer_project_id: int, + edge_layer_id: UUID, + node_layer_id: UUID, region_geofence_table: str, ): """Fetch street network from specified layer and load into Polars dataframes.""" @@ -139,25 +120,19 @@ async def fetch( # Get table names and layer IDs of the edge and node tables ( street_network_edge_table, - street_network_edge_layer_id, street_network_node_table, - street_network_node_layer_id, - ) = await self._get_street_network_tables( - edge_layer_project_id, node_layer_project_id - ) + ) = await self._get_street_network_tables(edge_layer_id, node_layer_id) # Initialize cache street_network_cache = StreetNetworkCache() try: for h3_short in street_network_region_h3_3_cells: - if street_network_edge_layer_id is not None: - if street_network_cache.edge_cache_exists( - street_network_edge_layer_id, h3_short - ): + if edge_layer_id is not None: + if street_network_cache.edge_cache_exists(edge_layer_id, h3_short): # Read edge data from cache edge_df = street_network_cache.read_edge_cache( - street_network_edge_layer_id, h3_short + edge_layer_id, h3_short ) else: if settings.DEBUG_MODE: @@ -174,7 +149,7 @@ async def fetch( maxspeed_backward, source, target, h3_3, h3_6 FROM {street_network_edge_table} WHERE h3_3 = {h3_short} - AND layer_id = '{str(street_network_edge_layer_id)}' + AND layer_id = '{str(edge_layer_id)}' """, uri=settings.POSTGRES_DATABASE_URI, schema_overrides=SEGMENT_DATA_SCHEMA, @@ -185,19 +160,17 @@ async def fetch( # Write edge data into cache street_network_cache.write_edge_cache( - street_network_edge_layer_id, h3_short, edge_df + edge_layer_id, h3_short, edge_df ) # Update street network edge dictionary and memory usage street_network_edge[h3_short] = edge_df street_network_size += edge_df.estimated_size("gb") - if street_network_node_layer_id is not None: - if street_network_cache.node_cache_exists( - street_network_node_layer_id, h3_short - ): + if node_layer_id is not None: + if street_network_cache.node_cache_exists(node_layer_id, h3_short): # Read node data from cache node_df = street_network_cache.read_node_cache( - street_network_node_layer_id, h3_short + node_layer_id, h3_short ) else: if settings.DEBUG_MODE: @@ -211,7 +184,7 @@ async def fetch( SELECT node_id AS id, h3_3, h3_6 FROM {street_network_node_table} WHERE h3_3 = {h3_short} - AND layer_id = '{str(street_network_node_layer_id)}' + AND layer_id = '{str(node_layer_id)}' """, uri=settings.POSTGRES_DATABASE_URI, schema_overrides=CONNECTOR_DATA_SCHEMA, @@ -219,7 +192,7 @@ async def fetch( # Write node data into cache street_network_cache.write_node_cache( - street_network_node_layer_id, h3_short, node_df + node_layer_id, h3_short, node_df ) # Update street network node dictionary and memory usage @@ -231,15 +204,15 @@ async def fetch( ) # Raise error if a edge layer project ID is specified but no edge data is fetched - if edge_layer_project_id is not None and len(street_network_edge) == 0: + if edge_layer_id is not None and len(street_network_edge) == 0: raise RuntimeError( - f"Failed to fetch street network edge data for layer project ID {edge_layer_project_id}." + f"Failed to fetch street network edge data for layer project ID {edge_layer_id}." ) # Raise error if a node layer project ID is specified but no node data is fetched - if node_layer_project_id is not None and len(street_network_node) == 0: + if node_layer_id is not None and len(street_network_node) == 0: raise RuntimeError( - f"Failed to fetch street network node data for layer project ID {node_layer_project_id}." + f"Failed to fetch street network node data for layer project ID {node_layer_id}." ) end_time = time.time() diff --git a/src/crud/crud_catchment_area.py b/src/crud/crud_catchment_area.py index 35c6bcd..7c4f690 100644 --- a/src/crud/crud_catchment_area.py +++ b/src/crud/crud_catchment_area.py @@ -27,6 +27,7 @@ ) from src.schemas.error import BufferExceedsNetworkError, DisconnectedOriginError from src.schemas.status import ProcessingStatus +from src.utils import format_value_null_sql class CRUDCatchmentArea: @@ -116,24 +117,22 @@ async def read_network( sub_network = sub_df # Produce all network modifications required to apply the specified scenario - scenario_id = ( - f"'{obj_in.scenario_id}'" if obj_in.scenario_id is not None else "NULL" - ) - sql_produce_network_modifications = text( - f""" - SELECT basic.produce_network_modifications( - {scenario_id}, - {obj_in.street_network_edge_layer_project_id}, - {obj_in.street_network_node_layer_project_id} - ); - """ - ) - network_modifications_table = ( - await self.db_connection.execute(sql_produce_network_modifications) - ).fetchone()[0] + network_modifications_table = None + if obj_in.scenario_id: + sql_produce_network_modifications = text( + f""" + SELECT basic.produce_network_modifications( + {format_value_null_sql(obj_in.scenario_id)}, + {obj_in.street_network.edge_layer_project_id}, + {obj_in.street_network.node_layer_project_id} + ); + """ + ) + network_modifications_table = ( + await self.db_connection.execute(sql_produce_network_modifications) + ).fetchone()[0] - # Apply network modifications to the sub-network - if network_modifications_table is not None: + # Apply network modifications to the sub-network segments_to_discard = [] sql_get_network_modifications = text( f""" @@ -198,9 +197,9 @@ async def read_network( maxspeed_forward, maxspeed_backward, source, target, h3_3, h3_6, point_cell_index, point_h3_3 FROM basic.get_artificial_segments( - {obj_in.street_network_edge_layer_project_id}, - {f"'{network_modifications_table}'" if network_modifications_table is not None else "NULL"}, - '{input_table}', + {format_value_null_sql(settings.BASE_STREET_NETWORK)}, + {format_value_null_sql(network_modifications_table)}, + {format_value_null_sql(input_table)}, {num_points}, '{",".join(valid_segment_classes)}', 10 @@ -593,8 +592,8 @@ async def run(self, obj_in: ICatchmentAreaActiveMobility | ICatchmentAreaCar): # Fetch routing network (processed segments) and load into memory if self.routing_network is None: self.routing_network, _ = await StreetNetworkUtil(self.db_connection).fetch( - edge_layer_project_id=obj_in.street_network_edge_layer_project_id, - node_layer_project_id=None, + edge_layer_id=settings.BASE_STREET_NETWORK, + node_layer_id=None, region_geofence_table=settings.NETWORK_REGION_TABLE, ) routing_network = self.routing_network diff --git a/src/crud/crud_catchment_area_sync.py b/src/crud/crud_catchment_area_sync.py index 29e9f90..a34793c 100644 --- a/src/crud/crud_catchment_area_sync.py +++ b/src/crud/crud_catchment_area_sync.py @@ -22,6 +22,10 @@ from src.schemas.error import BufferExceedsNetworkError, DisconnectedOriginError from src.utils import make_dir +#################################################################################################### +# TODO: Refactor and fix +#################################################################################################### + class FetchRoutingNetwork: def __init__(self, db_cursor): diff --git a/src/schemas/catchment_area.py b/src/schemas/catchment_area.py index f317520..28205ca 100644 --- a/src/schemas/catchment_area.py +++ b/src/schemas/catchment_area.py @@ -1,10 +1,12 @@ from enum import Enum -from typing import List +from typing import List, Optional from uuid import UUID import polars as pl from pydantic import BaseModel, Field, validator +from src.core.config import settings + SEGMENT_DATA_SCHEMA = { "id": pl.Int64, "length_m": pl.Float64, @@ -231,6 +233,19 @@ def valid_num_steps(cls, v): return v +class CatchmentAreaStreetNetwork(BaseModel): + edge_layer_project_id: int = Field( + ..., + title="Edge Layer Project ID", + description="The layer project ID of the street network edge layer.", + ) + node_layer_project_id: int = Field( + default=settings.DEFAULT_STREET_NETWORK_NODE_LAYER_PROJECT_ID, + title="Node Layer Project ID", + description="The layer project ID of the street network node layer.", + ) + + class ICatchmentAreaActiveMobility(BaseModel): """Model for the active mobility catchment area request.""" @@ -257,15 +272,10 @@ class ICatchmentAreaActiveMobility(BaseModel): title="Scenario ID", description="The ID of the scenario that is to be applied on the base network.", ) - street_network_edge_layer_project_id: int = Field( - ..., - title="Street Network Edge Layer Project ID", - description="The layer project ID of the street network edge layer.", - ) - street_network_node_layer_project_id: int = Field( - ..., - title="Street Network Node Layer Project ID", - description="The layer project ID of the street network node layer.", + street_network: Optional[CatchmentAreaStreetNetwork] = Field( + None, + title="Street Network Layer Config", + description="The configuration of the street network layers to use.", ) catchment_area_type: CatchmentAreaType = Field( ..., @@ -288,6 +298,15 @@ class ICatchmentAreaActiveMobility(BaseModel): description="The ID of the layer the results should be saved.", ) + # Ensure street network is specified if a scenario ID is provided + @validator("street_network", pre=True, always=True) + def check_street_network(cls, v, values): + if values["scenario_id"] is not None and v is None: + raise ValueError( + "The street network must be set if a scenario ID is provided." + ) + return v + # Check that polygon difference exists if catchment area type is polygon @validator("polygon_difference", pre=True, always=True) def check_polygon_difference(cls, v, values): @@ -339,15 +358,10 @@ class ICatchmentAreaCar(BaseModel): title="Scenario ID", description="The ID of the scenario that is used for the routing.", ) - street_network_edge_layer_project_id: int = Field( - ..., - title="Street Network Edge Layer Project ID", - description="The layer project ID of the street network edge layer.", - ) - street_network_node_layer_project_id: int = Field( - ..., - title="Street Network Node Layer Project ID", - description="The layer project ID of the street network node layer.", + street_network: Optional[CatchmentAreaStreetNetwork] = Field( + None, + title="Street Network Layer Config", + description="The configuration of the street network layers to use.", ) catchment_area_type: CatchmentAreaType = Field( ..., @@ -370,6 +384,15 @@ class ICatchmentAreaCar(BaseModel): description="The ID of the layer the results should be saved.", ) + # Ensure street network is specified if a scenario ID is provided + @validator("street_network", pre=True, always=True) + def check_street_network(cls, v, values): + if values["scenario_id"] is not None and v is None: + raise ValueError( + "The street network must be set if a scenario ID is provided." + ) + return v + # Check that polygon difference exists if catchment area type is polygon @validator("polygon_difference", pre=True, always=True) def check_polygon_difference(cls, v, values): diff --git a/src/utils.py b/src/utils.py index beba0e7..8787537 100644 --- a/src/utils.py +++ b/src/utils.py @@ -188,3 +188,10 @@ def make_dir(dir_path: str): """Creates a new directory if it doesn't already exist""" if not os.path.exists(dir_path): os.makedirs(dir_path) + + +def format_value_null_sql(value) -> str: + if value is None: + return "NULL" + else: + return f"'{value}'"