diff --git a/backend/src/graph.rs b/backend/src/graph.rs index 2f5beca..4f6f74d 100644 --- a/backend/src/graph.rs +++ b/backend/src/graph.rs @@ -43,6 +43,7 @@ pub enum Direction { None, } +// TODO Justify why PublicTransit isn't captured here #[derive(Clone, Copy, Enum, Debug, Serialize, Deserialize)] pub enum Mode { Car, diff --git a/backend/src/isochrone.rs b/backend/src/isochrone.rs index 07151a1..ddfd79f 100644 --- a/backend/src/isochrone.rs +++ b/backend/src/isochrone.rs @@ -2,6 +2,7 @@ use std::collections::{BinaryHeap, HashMap, HashSet}; use std::time::Duration; use anyhow::Result; +use chrono::NaiveTime; use geo::{Coord, Densify}; use utils::{Grid, PriorityQueueItem}; @@ -14,13 +15,14 @@ pub fn calculate( req: Coord, mode: Mode, contours: bool, + public_transit: bool, mut timer: Timer, ) -> Result { // 15 minutes let limit = Duration::from_secs(15 * 60); timer.step("get_costs"); - let cost_per_road = get_costs(graph, req, mode, limit); + let cost_per_road = get_costs(graph, req, mode, public_transit, limit); timer.push("render to GJ"); // Show cost per road @@ -51,7 +53,17 @@ pub fn calculate( Ok(x) } -fn get_costs(graph: &Graph, req: Coord, mode: Mode, limit: Duration) -> HashMap { +fn get_costs( + graph: &Graph, + req: Coord, + mode: Mode, + public_transit: bool, + limit: Duration, +) -> HashMap { + // TODO plumb in + let start_time = NaiveTime::from_hms_opt(7, 0, 0).unwrap(); + let end_time = start_time + limit; + let start = graph.closest_intersection[mode] .nearest_neighbor(&[req.x, req.y]) .unwrap() @@ -59,23 +71,25 @@ fn get_costs(graph: &Graph, req: Coord, mode: Mode, limit: Duration) -> HashMap< let mut visited: HashSet = HashSet::new(); let mut cost_per_road: HashMap = HashMap::new(); - let mut queue: BinaryHeap> = BinaryHeap::new(); + let mut queue: BinaryHeap> = BinaryHeap::new(); - queue.push(PriorityQueueItem::new(Duration::ZERO, start)); + queue.push(PriorityQueueItem::new(start_time, start)); while let Some(current) = queue.pop() { if visited.contains(¤t.value) { continue; } visited.insert(current.value); - if current.cost > limit { + if current.cost > end_time { continue; } for r in &graph.intersections[current.value.0].roads { let road = &graph.roads[r.0]; let total_cost = current.cost + cost(road, mode); - cost_per_road.entry(*r).or_insert(total_cost); + cost_per_road + .entry(*r) + .or_insert((total_cost - start_time).to_std().unwrap()); if road.src_i == current.value && road.allows_forwards(mode) { queue.push(PriorityQueueItem::new(total_cost, road.dst_i)); @@ -83,6 +97,23 @@ fn get_costs(graph: &Graph, req: Coord, mode: Mode, limit: Duration) -> HashMap< if road.dst_i == current.value && road.allows_backwards(mode) { queue.push(PriorityQueueItem::new(total_cost, road.src_i)); } + + if public_transit { + for stop1 in &road.stops { + // Find all trips leaving from this step before the end_time + for next_step in graph.gtfs.trips_from( + *stop1, + current.cost, + (end_time - current.cost).to_std().unwrap(), + ) { + // TODO Awkwardly, arrive at both intersections for the next stop's road + let stop2_road = &graph.roads[graph.gtfs.stops[next_step.stop2.0].road.0]; + for i in [stop2_road.src_i, stop2_road.dst_i] { + queue.push(PriorityQueueItem::new(next_step.time2, i)); + } + } + } + } } } diff --git a/backend/src/lib.rs b/backend/src/lib.rs index 6c52637..79191ee 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -84,7 +84,7 @@ impl MapModel { "car" => Mode::Car, "bicycle" => Mode::Bicycle, "foot" => Mode::Foot, - // TODO Unimplemented + // Plumbed separately "transit" => Mode::Foot, // TODO error plumbing x => panic!("bad input {x}"), @@ -94,6 +94,7 @@ impl MapModel { start, mode, req.contours, + req.mode == "transit", Timer::new("isochrone request", None), ) .map_err(err_to_js)