Skip to content
This repository has been archived by the owner on Nov 15, 2024. It is now read-only.

Commit

Permalink
Further decouple street network handling from layer projects
Browse files Browse the repository at this point in the history
  • Loading branch information
nihar1024 committed Sep 6, 2024
1 parent 11cb495 commit 766b135
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 100 deletions.
6 changes: 6 additions & 0 deletions src/core/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict, Optional
from uuid import UUID

from pydantic import BaseSettings, PostgresDsn, validator

Expand Down Expand Up @@ -27,6 +28,11 @@ class Settings(BaseSettings):
CATCHMENT_AREA_CAR_BUFFER_DEFAULT_SPEED = 80 # km/h
CATCHMENT_AREA_HOLE_THRESHOLD_SQM = 200000 # 20 hectares, ~450m x 450m

BASE_STREET_NETWORK: Optional[UUID] = "903ecdca-b717-48db-bbce-0219e41439cf"
DEFAULT_STREET_NETWORK_NODE_LAYER_PROJECT_ID = (
37319 # Hardcoded until node layers are added to GOAT projects by default
)

DATA_INSERT_BATCH_SIZE = 800

CELERY_BROKER_URL: Optional[str] = "pyamqp://guest@rabbitmq//"
Expand Down
91 changes: 32 additions & 59 deletions src/core/street_network/street_network_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,12 @@ class StreetNetworkUtil:
def __init__(self, db_connection: AsyncSession):
self.db_connection = db_connection

async def _get_layer_and_user_id(self, layer_project_id: int):
"""Get the layer ID and user ID of the specified layer project ID."""
async def _get_user_id(self, layer_id: UUID):
"""Get the user ID of the specified layer ID."""

layer_id: UUID = None
user_id: UUID = None

try:
# Get the associated layer ID
result = await self.db_connection.execute(
text(
f"""SELECT layer_id
FROM {settings.CUSTOMER_SCHEMA}.layer_project
WHERE id = {layer_project_id};"""
)
)
layer_id = UUID(str(result.fetchone()[0]))

# Get the user ID of the layer
result = await self.db_connection.execute(
text(
Expand All @@ -41,55 +30,47 @@ async def _get_layer_and_user_id(self, layer_project_id: int):
)
user_id = UUID(str(result.fetchone()[0]))
except Exception:
raise ValueError(
f"Could not fetch layer and user ID for layer project ID {layer_project_id}."
)
raise ValueError(f"Could not fetch user ID for layer ID {layer_id}.")

return layer_id, user_id
return user_id

async def _get_street_network_tables(
self,
street_network_edge_layer_project_id: int,
street_network_node_layer_project_id: int,
edge_layer_id: UUID,
node_layer_id: UUID,
):
"""Get table names and layer IDs of the edge and node tables."""

edge_table: str = None
edge_layer_id: UUID = None
node_table: str = None
node_layer_id: UUID = None

# Get edge table name if a layer project ID is specified
if street_network_edge_layer_project_id:
# Get edge table name if a layer ID is specified
if edge_layer_id:
try:
# Get the edge layer ID and associated user ID
edge_layer_id, user_id = await self._get_layer_and_user_id(
street_network_edge_layer_project_id
)
user_id = await self._get_user_id(edge_layer_id)

# Produce the edge table name
edge_table = f"{settings.USER_DATA_SCHEMA}.street_network_line_{str(user_id).replace('-', '')}"
except Exception:
raise ValueError(
f"Could not fetch edge table name for layer project ID {street_network_edge_layer_project_id}."
f"Could not fetch edge table name for layer ID {edge_layer_id}."
)

# Get node table name if a layer project ID is specified
if street_network_node_layer_project_id:
# Get node table name if a layer ID is specified
if node_layer_id:
try:
# Get the node layer ID and associated user ID
node_layer_id, user_id = await self._get_layer_and_user_id(
street_network_node_layer_project_id
)
user_id = await self._get_user_id(node_layer_id)

# Produce the node table name
node_table = f"{settings.USER_DATA_SCHEMA}.street_network_point_{str(user_id).replace('-', '')}"
except Exception:
raise ValueError(
f"Could not fetch node table name for layer project ID {street_network_node_layer_project_id}."
f"Could not fetch node table name for layer ID {node_layer_id}."
)

return edge_table, edge_layer_id, node_table, node_layer_id
return edge_table, node_table

async def _get_street_network_region_h3_3_cells(self, region_geofence_table: str):
"""Get list of H3_3 cells covering the street network region."""
Expand Down Expand Up @@ -118,8 +99,8 @@ async def _get_street_network_region_h3_3_cells(self, region_geofence_table: str

async def fetch(
self,
edge_layer_project_id: int,
node_layer_project_id: int,
edge_layer_id: UUID,
node_layer_id: UUID,
region_geofence_table: str,
):
"""Fetch street network from specified layer and load into Polars dataframes."""
Expand All @@ -139,25 +120,19 @@ async def fetch(
# Get table names and layer IDs of the edge and node tables
(
street_network_edge_table,
street_network_edge_layer_id,
street_network_node_table,
street_network_node_layer_id,
) = await self._get_street_network_tables(
edge_layer_project_id, node_layer_project_id
)
) = await self._get_street_network_tables(edge_layer_id, node_layer_id)

# Initialize cache
street_network_cache = StreetNetworkCache()

try:
for h3_short in street_network_region_h3_3_cells:
if street_network_edge_layer_id is not None:
if street_network_cache.edge_cache_exists(
street_network_edge_layer_id, h3_short
):
if edge_layer_id is not None:
if street_network_cache.edge_cache_exists(edge_layer_id, h3_short):
# Read edge data from cache
edge_df = street_network_cache.read_edge_cache(
street_network_edge_layer_id, h3_short
edge_layer_id, h3_short
)
else:
if settings.DEBUG_MODE:
Expand All @@ -174,7 +149,7 @@ async def fetch(
maxspeed_backward, source, target, h3_3, h3_6
FROM {street_network_edge_table}
WHERE h3_3 = {h3_short}
AND layer_id = '{str(street_network_edge_layer_id)}'
AND layer_id = '{str(edge_layer_id)}'
""",
uri=settings.POSTGRES_DATABASE_URI,
schema_overrides=SEGMENT_DATA_SCHEMA,
Expand All @@ -185,19 +160,17 @@ async def fetch(

# Write edge data into cache
street_network_cache.write_edge_cache(
street_network_edge_layer_id, h3_short, edge_df
edge_layer_id, h3_short, edge_df
)
# Update street network edge dictionary and memory usage
street_network_edge[h3_short] = edge_df
street_network_size += edge_df.estimated_size("gb")

if street_network_node_layer_id is not None:
if street_network_cache.node_cache_exists(
street_network_node_layer_id, h3_short
):
if node_layer_id is not None:
if street_network_cache.node_cache_exists(node_layer_id, h3_short):
# Read node data from cache
node_df = street_network_cache.read_node_cache(
street_network_node_layer_id, h3_short
node_layer_id, h3_short
)
else:
if settings.DEBUG_MODE:
Expand All @@ -211,15 +184,15 @@ async def fetch(
SELECT node_id AS id, h3_3, h3_6
FROM {street_network_node_table}
WHERE h3_3 = {h3_short}
AND layer_id = '{str(street_network_node_layer_id)}'
AND layer_id = '{str(node_layer_id)}'
""",
uri=settings.POSTGRES_DATABASE_URI,
schema_overrides=CONNECTOR_DATA_SCHEMA,
)

# Write node data into cache
street_network_cache.write_node_cache(
street_network_node_layer_id, h3_short, node_df
node_layer_id, h3_short, node_df
)

# Update street network node dictionary and memory usage
Expand All @@ -231,15 +204,15 @@ async def fetch(
)

# Raise error if a edge layer project ID is specified but no edge data is fetched
if edge_layer_project_id is not None and len(street_network_edge) == 0:
if edge_layer_id is not None and len(street_network_edge) == 0:
raise RuntimeError(
f"Failed to fetch street network edge data for layer project ID {edge_layer_project_id}."
f"Failed to fetch street network edge data for layer project ID {edge_layer_id}."
)

# Raise error if a node layer project ID is specified but no node data is fetched
if node_layer_project_id is not None and len(street_network_node) == 0:
if node_layer_id is not None and len(street_network_node) == 0:
raise RuntimeError(
f"Failed to fetch street network node data for layer project ID {node_layer_project_id}."
f"Failed to fetch street network node data for layer project ID {node_layer_id}."
)

end_time = time.time()
Expand Down
43 changes: 21 additions & 22 deletions src/crud/crud_catchment_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from src.schemas.error import BufferExceedsNetworkError, DisconnectedOriginError
from src.schemas.status import ProcessingStatus
from src.utils import format_value_null_sql


class CRUDCatchmentArea:
Expand Down Expand Up @@ -116,24 +117,22 @@ async def read_network(
sub_network = sub_df

# Produce all network modifications required to apply the specified scenario
scenario_id = (
f"'{obj_in.scenario_id}'" if obj_in.scenario_id is not None else "NULL"
)
sql_produce_network_modifications = text(
f"""
SELECT basic.produce_network_modifications(
{scenario_id},
{obj_in.street_network_edge_layer_project_id},
{obj_in.street_network_node_layer_project_id}
);
"""
)
network_modifications_table = (
await self.db_connection.execute(sql_produce_network_modifications)
).fetchone()[0]
network_modifications_table = None
if obj_in.scenario_id:
sql_produce_network_modifications = text(
f"""
SELECT basic.produce_network_modifications(
{format_value_null_sql(obj_in.scenario_id)},
{obj_in.street_network.edge_layer_project_id},
{obj_in.street_network.node_layer_project_id}
);
"""
)
network_modifications_table = (
await self.db_connection.execute(sql_produce_network_modifications)
).fetchone()[0]

# Apply network modifications to the sub-network
if network_modifications_table is not None:
# Apply network modifications to the sub-network
segments_to_discard = []
sql_get_network_modifications = text(
f"""
Expand Down Expand Up @@ -198,9 +197,9 @@ async def read_network(
maxspeed_forward, maxspeed_backward, source, target,
h3_3, h3_6, point_cell_index, point_h3_3
FROM basic.get_artificial_segments(
{obj_in.street_network_edge_layer_project_id},
{f"'{network_modifications_table}'" if network_modifications_table is not None else "NULL"},
'{input_table}',
{format_value_null_sql(settings.BASE_STREET_NETWORK)},
{format_value_null_sql(network_modifications_table)},
{format_value_null_sql(input_table)},
{num_points},
'{",".join(valid_segment_classes)}',
10
Expand Down Expand Up @@ -593,8 +592,8 @@ async def run(self, obj_in: ICatchmentAreaActiveMobility | ICatchmentAreaCar):
# Fetch routing network (processed segments) and load into memory
if self.routing_network is None:
self.routing_network, _ = await StreetNetworkUtil(self.db_connection).fetch(
edge_layer_project_id=obj_in.street_network_edge_layer_project_id,
node_layer_project_id=None,
edge_layer_id=settings.BASE_STREET_NETWORK,
node_layer_id=None,
region_geofence_table=settings.NETWORK_REGION_TABLE,
)
routing_network = self.routing_network
Expand Down
4 changes: 4 additions & 0 deletions src/crud/crud_catchment_area_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
from src.schemas.error import BufferExceedsNetworkError, DisconnectedOriginError
from src.utils import make_dir

####################################################################################################
# TODO: Refactor and fix
####################################################################################################


class FetchRoutingNetwork:
def __init__(self, db_cursor):
Expand Down
Loading

0 comments on commit 766b135

Please sign in to comment.