diff --git a/.gitignore b/.gitignore
index 90fae26f..509aea12 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,7 @@
!tests/data/small-osm.pbf
!tests/data/chester-20230816-small_gtfs.zip
!tests/data/gtfs/newport-20230613_gtfs.zip
+!tests/data/gtfs/repeated_pair_gtfs_fixture.zip
!src/transport_performance/data/gtfs/route_lookup.pkl
!tests/data/gtfs/report/html_template.html
!tests/data/metrics/mock_centroid_gdf.pkl
diff --git a/src/transport_performance/gtfs/cleaners.py b/src/transport_performance/gtfs/cleaners.py
index 3d4d7157..863f67cf 100644
--- a/src/transport_performance/gtfs/cleaners.py
+++ b/src/transport_performance/gtfs/cleaners.py
@@ -1,9 +1,26 @@
"""A set of functions that clean the gtfs data."""
from typing import Union
+import warnings
import numpy as np
-
-from transport_performance.utils.defence import _gtfs_defence, _check_iterable
+import pandas as pd
+from gtfs_kit.cleaners import (
+ clean_ids as clean_ids_gk,
+ clean_route_short_names as clean_route_short_names_gk,
+ clean_times as clean_times_gk,
+ drop_zombies as drop_zombies_gk,
+)
+
+from transport_performance.gtfs.gtfs_utils import (
+ _get_validation_warnings,
+ _remove_validation_row,
+)
+from transport_performance.utils.defence import (
+ _gtfs_defence,
+ _check_iterable,
+ _type_defence,
+ _check_attribute,
+)
def drop_trips(gtfs, trip_id: Union[str, list, np.ndarray]) -> None:
@@ -32,19 +49,21 @@ def drop_trips(gtfs, trip_id: Union[str, list, np.ndarray]) -> None:
if isinstance(trip_id, str):
trip_id = [trip_id]
- # _check_iterable only takes lists, therefore convert numpy arrays
- if isinstance(trip_id, np.ndarray):
- trip_id = list(trip_id)
-
# ensure trip ids are string
_check_iterable(
iterable=trip_id,
param_nm="trip_id",
- iterable_type=list,
+ iterable_type=type(trip_id),
check_elements=True,
exp_type=str,
)
+    # warn users if any of the passed trip_ids are not present in the
+    # GTFS.
+ for _id in trip_id:
+ if _id not in gtfs.feed.trips.trip_id.unique():
+ warnings.warn(UserWarning(f"trip_id '{_id}' not found in GTFS"))
+
# drop relevant records from tables
gtfs.feed.trips = gtfs.feed.trips[
~gtfs.feed.trips["trip_id"].isin(trip_id)
@@ -61,44 +80,58 @@ def drop_trips(gtfs, trip_id: Union[str, list, np.ndarray]) -> None:
return None
-def clean_consecutive_stop_fast_travel_warnings(
- gtfs, validate: bool = False
-) -> None:
- """Clean 'Fast Travel Between Consecutive Stops' warnings from validity_df.
+def _clean_fast_travel_preparation(gtfs, warning_re: str) -> pd.DataFrame:
+ """Prepare to clean fast travel errors.
+
+    At the beginning of both of the fast travel cleaners, the gtfs is type
+    checked, attr checked and then warnings are obtained. Because of this,
+    the shared logic has been functionalised here.
Parameters
----------
- gtfs : GtfsInstance
- The GtfsInstance to clean warnings within
- validate : bool, optional
- Whether or not to validate the gtfs before carrying out this cleaning
- operation
+    gtfs : GtfsInstance
+ The GtfsInstance.
+ warning_re : str
+ Regex used to obtain warnings.
Returns
-------
- None
+ pd.DataFrame
+ A dataframe containing warnings.
"""
- # defences
_gtfs_defence(gtfs, "gtfs")
- if "validity_df" not in gtfs.__dict__.keys() and not validate:
- raise AttributeError(
+ _type_defence(warning_re, "warning_re", str)
+ _check_attribute(
+ gtfs,
+ "validity_df",
+ message=(
"The gtfs has not been validated, therefore no"
"warnings can be identified. You can pass "
"validate=True to this function to validate the "
"gtfs."
- )
+ ),
+ )
+ needed_warning = _get_validation_warnings(gtfs, warning_re)
+ return needed_warning
+
- if validate:
- gtfs.is_valid()
+def clean_consecutive_stop_fast_travel_warnings(gtfs) -> None:
+ """Clean 'Fast Travel Between Consecutive Stops' warnings from validity_df.
- needed_warning = (
- gtfs.validity_df[
- gtfs.validity_df["message"]
- == "Fast Travel Between Consecutive Stops"
- ]
- .copy()
- .values
+ Parameters
+ ----------
+ gtfs : GtfsInstance
+ The GtfsInstance to clean warnings within
+
+ Returns
+ -------
+ None
+
+ """
+ # defences
+ needed_warning = _clean_fast_travel_preparation(
+ gtfs, "Fast Travel Between Consecutive Stops"
)
if len(needed_warning) < 1:
@@ -116,45 +149,22 @@ def clean_consecutive_stop_fast_travel_warnings(
return None
-def clean_multiple_stop_fast_travel_warnings(
- gtfs, validate: bool = False
-) -> None:
+def clean_multiple_stop_fast_travel_warnings(gtfs) -> None:
"""Clean 'Fast Travel Over Multiple Stops' warnings from validity_df.
Parameters
----------
gtfs : GtfsInstance
The GtfsInstance to clean warnings within
- validate : bool, optional
- Whether or not to validate the gtfs before carrying out this cleaning
- operation
Returns
-------
None
"""
- # defences
- _gtfs_defence(gtfs, "gtfs")
- if "validity_df" not in gtfs.__dict__.keys() and not validate:
- raise AttributeError(
- "The gtfs has not been validated, therefore no"
- "warnings can be identified. You can pass "
- "validate=True to this function to validate the "
- "gtfs."
- )
-
- if validate:
- gtfs.is_valid()
-
- needed_warning = (
- gtfs.validity_df[
- gtfs.validity_df["message"] == "Fast Travel Over Multiple Stops"
- ]
- .copy()
- .values
+ needed_warning = _clean_fast_travel_preparation(
+ gtfs, "Fast Travel Over Multiple Stops"
)
-
if len(needed_warning) < 1:
return None
@@ -168,3 +178,122 @@ def clean_multiple_stop_fast_travel_warnings(
~gtfs.multiple_stops_invalid["trip_id"].isin(trip_ids)
]
return None
+
+
+def core_cleaners(
+ gtfs,
+ clean_ids: bool = True,
+ clean_times: bool = True,
+ clean_route_short_names: bool = True,
+ drop_zombies: bool = True,
+) -> None:
+ """Clean the gtfs with the core cleaners of gtfs-kit.
+
+ The source code for the cleaners, along with detailed descriptions of the
+ cleaning they are performing can be found here:
+ https://github.com/mrcagney/gtfs_kit/blob/master/gtfs_kit/cleaners.py
+
+ All credit for these cleaners goes to the creators of the gtfs_kit package.
+ HOMEPAGE: https://github.com/mrcagney/gtfs_kit
+
+ Parameters
+ ----------
+ gtfs : GtfsInstance
+ The gtfs to clean
+ clean_ids : bool, optional
+ Whether or not to use clean_ids, by default True
+ clean_times : bool, optional
+ Whether or not to use clean_times, by default True
+ clean_route_short_names : bool, optional
+ Whether or not to use clean_route_short_names, by default True
+ drop_zombies : bool, optional
+ Whether or not to use drop_zombies, by default True
+
+ Returns
+ -------
+ None
+
+ """
+ # defences
+ _gtfs_defence(gtfs, "gtfs")
+ _type_defence(clean_ids, "clean_ids", bool)
+ _type_defence(clean_times, "clean_times", bool)
+ _type_defence(clean_route_short_names, "clean_route_short_names", bool)
+ _type_defence(drop_zombies, "drop_zombies", bool)
+ # cleaning
+ if clean_ids:
+ clean_ids_gk(gtfs.feed)
+ if clean_times:
+ clean_times_gk(gtfs.feed)
+ if clean_route_short_names:
+ clean_route_short_names_gk(gtfs.feed)
+ if drop_zombies:
+ try:
+ drop_zombies_gk(gtfs.feed)
+ except KeyError:
+ warnings.warn(
+ UserWarning(
+ "The drop_zombies cleaner was unable to operate on "
+ "clean_feed as the trips table has no shape_id column"
+ )
+ )
+ return None
+
+
+def clean_unrecognised_column_warnings(gtfs) -> None:
+ """Clean warnings for unrecognised columns.
+
+ Parameters
+ ----------
+ gtfs : GtfsInstance
+ The GtfsInstance to clean warnings from
+
+ Returns
+ -------
+ None
+
+ """
+ _gtfs_defence(gtfs, "gtfs")
+ warnings = _get_validation_warnings(
+ gtfs=gtfs, message="Unrecognized column .*"
+ )
+ for warning in warnings:
+ tbl = gtfs.table_map[warning[2]]
+ # parse column from warning message
+ column = warning[1].split("column")[1].strip()
+ tbl.drop(column, inplace=True, axis=1)
+ _remove_validation_row(gtfs, warning[1])
+ return None
+
+
+def clean_duplicate_stop_times(gtfs) -> None:
+ """Clean duplicates from stop_times with repeated pair (trip_id, ...
+
+ departure_time.
+
+ Parameters
+ ----------
+ gtfs : GtfsInstance
+ The gtfs to clean
+
+ Returns
+ -------
+ None
+
+ """
+ _gtfs_defence(gtfs, "gtfs")
+ warning_re = r".* \(trip_id, departure_time\)"
+ # we are only expecting one warning here
+ warning = _get_validation_warnings(gtfs, warning_re)
+ if len(warning) == 0:
+ return None
+ warning = warning[0]
+ # drop from actual table
+ gtfs.table_map[warning[2]].drop_duplicates(
+ subset=["arrival_time", "departure_time", "trip_id", "stop_id"],
+ inplace=True,
+ )
+ _remove_validation_row(gtfs, message=warning_re)
+ # re-validate with gtfs-kit validator
+ gtfs.is_valid({"core_validation": None})
+ return None
diff --git a/src/transport_performance/gtfs/gtfs_utils.py b/src/transport_performance/gtfs/gtfs_utils.py
index 100d2924..60fffdad 100644
--- a/src/transport_performance/gtfs/gtfs_utils.py
+++ b/src/transport_performance/gtfs/gtfs_utils.py
@@ -5,11 +5,11 @@
from pyprojroot import here
import pandas as pd
import os
-import math
import plotly.graph_objects as go
from typing import Union, TYPE_CHECKING
import pathlib
from geopandas import GeoDataFrame
+import numpy as np
import warnings
if TYPE_CHECKING:
@@ -19,9 +19,12 @@
_is_expected_filetype,
_check_iterable,
_type_defence,
+ _check_attribute,
+ _gtfs_defence,
_validate_datestring,
_enforce_file_extension,
- _gtfs_defence,
+ _check_parent_dir_exists,
+ _check_item_in_iter,
)
from transport_performance.utils.constants import PKG_PATH
@@ -284,14 +287,20 @@ def _add_validation_row(
An error is raised if the validity df does not exist
"""
- # TODO: add dtype defences from defence.py once gtfs-html-new is merged
- if "validity_df" not in gtfs.__dict__.keys():
- raise AttributeError(
+ _gtfs_defence(gtfs, "gtfs")
+ _type_defence(_type, "_type", str)
+ _type_defence(message, "message", str)
+ _type_defence(rows, "rows", list)
+ _check_attribute(
+ gtfs,
+ "validity_df",
+ message=(
"The validity_df does not exist as an "
"attribute of your GtfsInstance object, \n"
"Did you forget to run the .is_valid() method?"
- )
-
+ ),
+ )
+ _check_item_in_iter(_type, ["warning", "error"], "_type")
temp_df = pd.DataFrame(
{
"type": [_type],
@@ -311,7 +320,6 @@ def filter_gtfs_around_trip(
gtfs,
trip_id: str,
buffer_dist: int = 10000,
- units: str = "m",
crs: str = "27700",
out_pth=os.path.join("data", "external", "trip_gtfs.zip"),
) -> None:
@@ -325,8 +333,6 @@ def filter_gtfs_around_trip(
The trip ID
buffer_dist : int, optional
The distance to create a buffer around the trip, by default 10000
- units : str, optional
- Distance units of the original GTFS, by default "m"
crs : str, optional
The CRS to use for adding a buffer, by default "27700"
out_pth : _type_, optional
@@ -343,21 +349,21 @@ def filter_gtfs_around_trip(
An error is raised if a shapeID is not available
"""
- # TODO: Add datatype defences once merged
+ _gtfs_defence(gtfs, "gtfs")
+ _type_defence(trip_id, "trip_id", str)
+ _type_defence(buffer_dist, "buffer_dist", int)
+ _type_defence(crs, "crs", str)
+ _check_parent_dir_exists(out_pth, "out_pth", create=True)
trips = gtfs.feed.trips
shapes = gtfs.feed.shapes
shape_id = list(trips[trips["trip_id"] == trip_id]["shape_id"])[0]
# defence
- # try/except for math.isnan() returning TypeError for strings
- try:
- if math.isnan(shape_id):
- raise ValueError(
- "'shape_id' not available for trip with trip_id: " f"{trip_id}"
- )
- except TypeError:
- pass
+ if pd.isna(shape_id):
+ raise ValueError(
+ "'shape_id' not available for trip with trip_id: " f"{trip_id}"
+ )
# create a buffer around the trip
trip_shape = shapes[shapes["shape_id"] == shape_id]
@@ -376,7 +382,7 @@ def filter_gtfs_around_trip(
in_pth=gtfs.gtfs_path,
bbox=list(bbox),
crs=crs,
- units=units,
+ units=gtfs.units,
out_pth=out_pth,
)
@@ -465,3 +471,143 @@ def convert_pandas_to_plotly(
if return_html:
return fig.to_html(full_html=False)
return fig
+
+
+def _get_validation_warnings(
+ gtfs, message: str, return_type: str = "values"
+) -> pd.DataFrame:
+ """Get warnings from the validity_df table based on a regex.
+
+ Parameters
+ ----------
+ gtfs : GtfsInstance()
+ The gtfs instance to obtain the warnings from.
+ message : str
+ The regex to use for filtering the warnings.
+ return_type : str, optional
+ The return type of the warnings. Can be either 'values' or 'dataframe',
+ by default 'values'
+
+ Returns
+ -------
+ pd.DataFrame
+ A dataframe containing all warnings matching the regex.
+
+ """
+ _gtfs_defence(gtfs, "gtfs")
+ _check_attribute(
+ gtfs,
+ "validity_df",
+ message=(
+ "The gtfs has not been validated, therefore no"
+ "warnings can be identified."
+ ),
+ )
+ _type_defence(message, "message", str)
+ return_type = return_type.lower().strip()
+ if return_type not in ["values", "dataframe"]:
+ raise ValueError(
+            "'return_type' expected one of ['values', 'dataframe']"
+ f". Got {return_type}"
+ )
+ needed_warnings = gtfs.validity_df[
+ gtfs.validity_df["message"].str.contains(message, regex=True, na=False)
+ ].copy()
+ if return_type == "dataframe":
+ return needed_warnings
+ return needed_warnings.values
+
+
+def _remove_validation_row(
+ gtfs, message: str = None, index: Union[list, np.array] = None
+) -> None:
+ """Remove rows from the 'validity_df' attr of a GtfsInstance().
+
+ Both a regex on messages and index locations can be used to drop drops. If
+ values are passed to both the 'index' and 'message' params, rows will be
+ removed using the regex passed to the 'message' param.
+
+ Parameters
+ ----------
+ gtfs : GtfsInstance
+ The GtfsInstance to remove the warnings/errors from.
+ message : str, optional
+ The regex to filter messages by, by default None.
+ index : Union[list, np.array], optional
+ The index locations of rows to be removed, by default None.
+
+ Returns
+ -------
+ None
+
+ Raises
+ ------
+ ValueError
+ Error raised if both 'message' and 'index' params are None.
+ UserWarning
+ A warning is raised if both 'message' and 'index' params are not None.
+
+ """
+ # defences
+ _gtfs_defence(gtfs, "gtfs")
+ _type_defence(message, "message", (str, type(None)))
+ _type_defence(index, "index", (list, np.ndarray, type(None)))
+ _check_attribute(gtfs, "validity_df")
+ if message is None and index is None:
+ raise ValueError(
+            "Both 'message' and 'index' are None, therefore no "
+ "warnings/errors are able to be cleaned."
+ )
+ if message is not None and index is not None:
+ warnings.warn(
+ UserWarning(
+ "Both 'index' and 'message' are not None. Warnings/"
+ "Errors have been cleaned on 'message'"
+ )
+ )
+ # remove row from validation table
+ if message is not None:
+ gtfs.validity_df = gtfs.validity_df[
+ ~gtfs.validity_df.message.str.contains(
+ message, regex=True, na=False
+ )
+ ]
+ return None
+
+ gtfs.validity_df = gtfs.validity_df.loc[
+ list(set(gtfs.validity_df.index) - set(index))
+ ]
+ return None
+
+
+def _function_pipeline(
+ gtfs, func_map: dict, operations: Union[dict, type(None)]
+) -> None:
+ """Iterate through and act on a functional pipeline."""
+ _gtfs_defence(gtfs, "gtfs")
+ _type_defence(func_map, "func_map", dict)
+ _type_defence(operations, "operations", (dict, type(None)))
+ if operations:
+ for key in operations.keys():
+ if key not in func_map.keys():
+ raise KeyError(
+ f"'{key}' function passed to 'operations' is not a "
+ "known operation. Known operation include: "
+ f"{func_map.keys()}"
+ )
+ for operation in operations:
+ # check value is dict or none (for kwargs)
+ _type_defence(
+ operations[operation],
+ f"operations[{operation}]",
+ (dict, type(None)),
+ )
+ operations[operation] = (
+ {} if operations[operation] is None else operations[operation]
+ )
+ func_map[operation](gtfs=gtfs, **operations[operation])
+ # if no operations passed, carry out all operations
+ else:
+ for operation in func_map:
+ func_map[operation](gtfs=gtfs)
+ return None
diff --git a/src/transport_performance/gtfs/validation.py b/src/transport_performance/gtfs/validation.py
index dcad1e72..3ce69e94 100644
--- a/src/transport_performance/gtfs/validation.py
+++ b/src/transport_performance/gtfs/validation.py
@@ -17,19 +17,15 @@
from plotly.graph_objects import Figure as PlotlyFigure
from geopandas import GeoDataFrame
-from transport_performance.gtfs.validators import (
- validate_travel_over_multiple_stops,
- validate_travel_between_consecutive_stops,
-)
+import transport_performance.gtfs.cleaners as cleaners
+import transport_performance.gtfs.validators as gtfs_validators
from transport_performance.gtfs.calendar import create_calendar_from_dates
-from transport_performance.gtfs.cleaners import (
- clean_consecutive_stop_fast_travel_warnings,
- clean_multiple_stop_fast_travel_warnings,
-)
from transport_performance.gtfs.routes import (
scrape_route_type_lookup,
get_saved_route_type_lookup,
)
+
+from transport_performance.gtfs.gtfs_utils import _function_pipeline
from transport_performance.utils.defence import (
_is_expected_filetype,
_check_namespace_export,
@@ -40,15 +36,44 @@
_check_attribute,
_enforce_file_extension,
)
-
from transport_performance.gtfs.report.report_utils import (
TemplateHTML,
_set_up_report_dir,
)
-
+from transport_performance.utils.constants import (
+ PKG_PATH,
+)
from transport_performance.gtfs.gtfs_utils import filter_gtfs
-from transport_performance.utils.constants import PKG_PATH
+# THESE MAPPINGS CAN NOT BE MOVED TO CONSTANTS AS THEY INTRODUCE DEPENDENCY
+# ISSUES.
+CLEAN_FEED_FUNCTION_MAP = {
+ "core_cleaners": cleaners.core_cleaners,
+ "clean_unrecognised_column_warnings": (
+ cleaners.clean_unrecognised_column_warnings
+ ),
+ "clean_duplicate_stop_times": (cleaners.clean_duplicate_stop_times),
+ "clean_consecutive_stop_fast_travel_warnings": (
+ cleaners.clean_consecutive_stop_fast_travel_warnings
+ ),
+ "clean_multiple_stop_fast_travel_warnings": (
+ cleaners.clean_multiple_stop_fast_travel_warnings
+ ),
+}
+
+VALIDATE_FEED_FUNC_MAP = {
+ "core_validation": gtfs_validators.core_validation,
+ "validate_gtfs_files": gtfs_validators.validate_gtfs_files,
+ "validate_travel_between_consecutive_stops": (
+ gtfs_validators.validate_travel_between_consecutive_stops
+ ),
+ "validate_travel_over_multiple_stops": (
+ gtfs_validators.validate_travel_over_multiple_stops
+ ),
+ "validate_route_type_warnings": (
+ gtfs_validators.validate_route_type_warnings
+ ),
+}
def _get_intermediate_dates(
@@ -182,6 +207,8 @@ class GtfsInstance:
A gtfs_kit feed produced using the files at `gtfs_pth` on init.
gtfs_path : Union[str, pathlib.Path]
The path to the GTFS archive.
+ units : str
+ Spatial units of the GTFS file, defaults to "km".
file_list: list
Files in the GTFS archive.
validity_df: pd.DataFrame
@@ -283,7 +310,7 @@ def __init__(
accepted_units = ["m", "km"]
_check_item_in_iter(units, accepted_units, "units")
-
+ self.units = units
self.feed = gk.read_feed(gtfs_pth, dist_units=units)
self.gtfs_path = gtfs_pth
if route_lookup_pth is not None:
@@ -329,6 +356,16 @@ def __init__(
"shapes": [],
}
+ # CONSTANT TO LINK TO GTFS TABLES USING STRINGS
+ self.table_map = {
+ "agency": self.feed.agency,
+ "routes": self.feed.routes,
+ "stop_times": self.feed.stop_times,
+ "stops": self.feed.stops,
+ "trips": self.feed.trips,
+ "calendar": self.feed.calendar,
+ }
+
def ensure_populated_calendar(self) -> None:
"""If calendar is absent, creates one from calendar_dates.
@@ -376,14 +413,13 @@ def get_gtfs_files(self) -> list:
self.file_list = file_list
return self.file_list
- def is_valid(self, far_stops: bool = True) -> pd.DataFrame:
+ def is_valid(self, validators: dict = None) -> pd.DataFrame:
"""Check a feed is valid with `gtfs_kit`.
Parameters
----------
- far_stops : bool, optional
- Whether or not to perform validation for far stops (both
- between consecutive stops and over multiple stops)
+ validators : dict, optional
+ A dictionary of function name to kwargs mappings.
Returns
-------
@@ -391,10 +427,14 @@ def is_valid(self, far_stops: bool = True) -> pd.DataFrame:
Table of errors, warnings & their descriptions.
"""
- self.validity_df = self.feed.validate()
- if far_stops:
- validate_travel_between_consecutive_stops(self)
- validate_travel_over_multiple_stops(self)
+ _type_defence(validators, "validators", (dict, type(None)))
+ # create validity df
+ self.validity_df = pd.DataFrame(
+ columns=["type", "message", "table", "rows"]
+ )
+ _function_pipeline(
+ gtfs=self, func_map=VALIDATE_FEED_FUNC_MAP, operations=validators
+ )
return self.validity_df
def print_alerts(self, alert_type: str = "error") -> None:
@@ -445,34 +485,27 @@ def print_alerts(self, alert_type: str = "error") -> None:
return None
- def clean_feed(
- self, validate: bool = False, fast_travel: bool = True
- ) -> None:
- """Attempt to clean feed using `gtfs_kit`.
+ def clean_feed(self, cleansers: dict = None) -> None:
+ """Clean the gtfs feed.
Parameters
----------
- validate: bool, optional
- Whether or not to validate the dataframe before cleaning
- fast_travel: bool, optional
- Whether or not to clean warnings related to fast travel.
+ cleansers : dict, optional
+ A mapping of cleansing functions and kwargs, by default None
+
+ Returns
+ -------
+ None
"""
- _type_defence(fast_travel, "fast_travel", bool)
- _type_defence(validate, "valiidate", bool)
- if validate:
- self.is_valid(far_stops=fast_travel)
- try:
- # In cases where shape_id is missing, keyerror is raised.
- # https://developers.google.com/transit/gtfs/reference#shapestxt
- # shows that shapes.txt is optional file.
- self.feed = self.feed.clean()
- if fast_travel:
- clean_consecutive_stop_fast_travel_warnings(self)
- clean_multiple_stop_fast_travel_warnings(self)
- except KeyError:
- # TODO: Issue 74 - Improve this to clean feed when KeyError raised
- print("KeyError. Feed was not cleaned.")
+ # DEV NOTE: Opting not to allow for validation in clean_feed().
+        # .is_valid() should be used beforehand.
+ # DEV NOTE 2: Use of param name 'cleansers' is to avoid conflicts
+ _type_defence(cleansers, "cleansers", (dict, type(None)))
+ _function_pipeline(
+ gtfs=self, func_map=CLEAN_FEED_FUNCTION_MAP, operations=cleansers
+ )
+ return None
def _produce_stops_map(
self, what_geoms: str, is_filtered: bool, crs: Union[int, str]
@@ -1317,15 +1350,6 @@ def _extended_validation(
None
"""
- table_map = {
- "agency": self.feed.agency,
- "routes": self.feed.routes,
- "stop_times": self.feed.stop_times,
- "stops": self.feed.stops,
- "trips": self.feed.trips,
- "calendar": self.feed.calendar,
- }
-
# determine which errors/warnings have rows that can be located
validation_table = self.is_valid()
validation_table["valid_row"] = validation_table["rows"].apply(
@@ -1354,7 +1378,9 @@ def _extended_validation(
for col in self.GTFS_UNNEEDED_COLUMNS[table]
if col not in join_vars
]
- filtered_tbl = table_map[table].copy().drop(drop_cols, axis=1)
+ filtered_tbl = (
+ self.table_map[table].copy().drop(drop_cols, axis=1)
+ )
impacted_rows = self._create_extended_repeated_pair_table(
table=filtered_tbl,
join_vars=join_vars,
@@ -1372,7 +1398,7 @@ def _extended_validation(
== impacted_rows[f"{col}_duplicate"]
].shape[0]
else:
- impacted_rows = table_map[table].copy().iloc[rows]
+ impacted_rows = self.table_map[table].copy().iloc[rows]
# create the html to display the impacted rows (clean possibly)
table_html = f"""
@@ -1486,9 +1512,11 @@ def html_report(
date = datetime.datetime.strftime(datetime.datetime.now(), "%d-%m-%Y")
# feed evaluation
- self.clean_feed(validate=True, fast_travel=True)
+ # TODO: make this optional (and allow params)
+ self.is_valid()
+ self.clean_feed()
# re-validate to clean any newly raised errors/warnings
- validation_dataframe = self.is_valid(far_stops=True)
+ validation_dataframe = self.is_valid()
# create extended reports if requested
if extended_validation:
@@ -1503,7 +1531,7 @@ def html_report(
)
validation_dataframe["info"] = [
f""" Further Info"""
- if len(rows) > 1
+ if len(rows) > 0
else "Unavailable"
for href, rows in zip(info_href, validation_dataframe["rows"])
]
diff --git a/src/transport_performance/gtfs/validators.py b/src/transport_performance/gtfs/validators.py
index 6ad9f8e5..cc0c441e 100644
--- a/src/transport_performance/gtfs/validators.py
+++ b/src/transport_performance/gtfs/validators.py
@@ -1,12 +1,17 @@
"""A set of functions that validate the GTFS data."""
from typing import TYPE_CHECKING
+import os
import numpy as np
import pandas as pd
from haversine import Unit, haversine_vector
-from transport_performance.gtfs.gtfs_utils import _add_validation_row
-from transport_performance.utils.defence import _gtfs_defence
+from transport_performance.gtfs.gtfs_utils import (
+ _add_validation_row,
+ _get_validation_warnings,
+ _remove_validation_row,
+)
+from transport_performance.utils.defence import _gtfs_defence, _check_attribute
if TYPE_CHECKING:
from transport_performance.gtfs.validation import GtfsInstance
@@ -26,6 +31,28 @@
200: 120,
}
+# EXPECTED GTFS FILES
+# DESC: USED TO RAISE WARNINGS WHEN .TXT FILES INCLUDED IN THE ZIP PASSED TO
+# GTFSINSTANCE AREN'T A RECOGNISED GTFS TABLE
+# NOTE: 'Levels' and 'Translation' tables are ignored by gtfs-kit
+
+ACCEPTED_GTFS_TABLES = [
+ "agency",
+ "attributions",
+ "calendar",
+ "calendar_dates",
+ "fare_attributes",
+ "fare_rules",
+ "feed_info",
+ "frequencies",
+ "routes",
+ "shapes",
+ "stops",
+ "stop_times",
+ "transfers",
+ "trips",
+]
+
def validate_travel_between_consecutive_stops(gtfs: "GtfsInstance"):
"""Validate the travel between consecutive stops in the GTFS data.
@@ -145,6 +172,8 @@ def _join_max_speed(r_type: int) -> int:
)
gtfs.full_stop_schedule = stop_sched
+ gtfs.table_map["full_stop_schedule"] = gtfs.full_stop_schedule
+
# find the stops that exceed the speed boundary
invalid_stops = stop_sched[stop_sched["speed"] > stop_sched["speed_bound"]]
@@ -153,7 +182,6 @@ def _join_max_speed(r_type: int) -> int:
return invalid_stops
# add the error to the validation table
- # TODO: After merge add full_stop_schedule to HTML output table keys
_add_validation_row(
gtfs=gtfs,
_type="warning",
@@ -252,9 +280,8 @@ def validate_travel_over_multiple_stops(gtfs: "GtfsInstance") -> None:
}
)
- # TODO: Add this table to the lookup once gtfs HTML is merged
gtfs.multiple_stops_invalid = far_stops_df
-
+ gtfs.table_map["multiple_stops_invalid"] = gtfs.multiple_stops_invalid
if len(gtfs.multiple_stops_invalid) > 0:
_add_validation_row(
gtfs=gtfs,
@@ -265,3 +292,98 @@ def validate_travel_over_multiple_stops(gtfs: "GtfsInstance") -> None:
)
return far_stops_df
+
+
+def validate_route_type_warnings(gtfs: "GtfsInstance") -> None:
+ """Valiidate that the route type warnings are reasonable and just.
+
+ Parameters
+ ----------
+ gtfs : GtfsInstance
+ The GtfsInstance to validate the warnings of.
+
+ Returns
+ -------
+ None
+
+ """
+ # defences
+ _gtfs_defence(gtfs, "gtfs")
+ _check_attribute(gtfs, "validity_df")
+ # identify and clean warnings
+ warnings = _get_validation_warnings(gtfs, "Invalid route_type.*")
+ if len(warnings) <= 0:
+ return None
+ route_rows = gtfs.feed.routes.loc[warnings[0][3]].copy()
+ route_rows = route_rows[
+ ~route_rows.route_type.astype("str").isin(
+ gtfs.ROUTE_LKP["route_type"].unique()
+ )
+ ]
+ _remove_validation_row(gtfs, "Invalid route_type.*")
+ if len(route_rows) > 0:
+ _add_validation_row(
+ gtfs,
+ _type="error",
+ message="Invalid route_type; maybe has extra space characters",
+ table="routes",
+ rows=list(route_rows.index),
+ )
+ return None
+
+
+def core_validation(gtfs: "GtfsInstance"):
+ """Carry out the main validators of gtfs-kit."""
+ _gtfs_defence(gtfs, "gtfs")
+ validation_df = gtfs.feed.validate()
+ gtfs.validity_df = pd.concat(
+ [validation_df, gtfs.validity_df], axis=0
+ ).reset_index(drop=True)
+
+
+def validate_gtfs_files(gtfs: "GtfsInstance") -> None:
+ """Validate to raise warnings if tables in GTFS zip aren't being read.
+
+ Parameters
+ ----------
+ gtfs : GtfsInstance
+ The gtfs instance to run the validation on.
+
+ Returns
+ -------
+ None
+
+ """
+ _gtfs_defence(gtfs, "gtfs")
+ files = gtfs.get_gtfs_files()
+ invalid_extensions = []
+ non_implemented_files = []
+ for fname in files:
+ root, ext = os.path.splitext(fname)
+ # check for instances of invalid file extensions (only .txt accepted)
+ if ext.lower() != ".txt":
+ invalid_extensions.append(fname)
+ # check that each file is within the GTFS specification.
+ # NOTE: The 'levels' and 'translation' tables are in the GTFS spec
+ # but aren't used by gtfs-kit.
+ # GTFS REFERENCE: https://developers.google.com/transit/gtfs/reference
+ if root not in ACCEPTED_GTFS_TABLES:
+ non_implemented_files.append(fname)
+ # raise warnings
+ if len(invalid_extensions) > 0:
+ msg = (
+ "GTFS zip includes files not of type '.txt'. These files "
+ f"include {invalid_extensions}"
+ )
+ _add_validation_row(
+ gtfs, _type="warning", message=msg, table="", rows=[]
+ )
+ if len(non_implemented_files) > 0:
+ msg = (
+ "GTFS zip includes files that aren't recognised by the GTFS "
+ f"spec. These include {non_implemented_files}"
+ )
+ _add_validation_row(
+ gtfs, _type="warning", message=msg, table="", rows=[]
+ )
+ return None
diff --git a/src/transport_performance/utils/constants.py b/src/transport_performance/utils/constants.py
index 97b98598..a1521f2b 100644
--- a/src/transport_performance/utils/constants.py
+++ b/src/transport_performance/utils/constants.py
@@ -1,5 +1,6 @@
"""Constants to be used throughout the transport-performance package."""
+from importlib import resources as pkg_resources
+
import transport_performance
-from importlib import resources as pkg_resources
PKG_PATH = pkg_resources.files(transport_performance)
diff --git a/tests/data/gtfs/repeated_pair_gtfs_fixture.zip b/tests/data/gtfs/repeated_pair_gtfs_fixture.zip
new file mode 100644
index 00000000..8c723d01
Binary files /dev/null and b/tests/data/gtfs/repeated_pair_gtfs_fixture.zip differ
diff --git a/tests/gtfs/test_cleaners.py b/tests/gtfs/test_cleaners.py
index b94b1cf9..262aea4f 100644
--- a/tests/gtfs/test_cleaners.py
+++ b/tests/gtfs/test_cleaners.py
@@ -10,6 +10,13 @@
drop_trips,
clean_consecutive_stop_fast_travel_warnings,
clean_multiple_stop_fast_travel_warnings,
+ core_cleaners,
+ clean_unrecognised_column_warnings,
+ clean_duplicate_stop_times,
+)
+from transport_performance.gtfs.gtfs_utils import (
+ _get_validation_warnings,
+ _remove_validation_row,
)
@@ -64,6 +71,10 @@ def test_drop_trips_defences(self, gtfs_fixture):
]
assert len(found_df) == 0, "Failed to drop trip in format 'string'"
+ # test dropping non existent trip
+ with pytest.warns(UserWarning, match="trip_id .* not found in GTFS"):
+ drop_trips(gtfs_fixture, ["NOT_AT_TRIP_ID"])
+
def test_drop_trips_on_pass(self, gtfs_fixture):
"""General tests for drop_trips()."""
fixture = gtfs_fixture
@@ -117,7 +128,7 @@ class Test_CleanConsecutiveStopFastTravelWarnings(object):
def test_clean_consecutive_stop_fast_travel_warnings_defence(
self, gtfs_fixture
):
- """Defensive tests forclean_consecutive_stop_fast_travel_warnings()."""
+ """Defensive tests for clean_consecutive_stop_fast_travel_warnings."""
with pytest.raises(
AttributeError,
match=re.escape(
@@ -127,15 +138,13 @@ def test_clean_consecutive_stop_fast_travel_warnings_defence(
"gtfs."
),
):
- clean_consecutive_stop_fast_travel_warnings(
- gtfs=gtfs_fixture, validate=False
- )
+ clean_consecutive_stop_fast_travel_warnings(gtfs=gtfs_fixture)
def test_clean_consecutive_stop_fast_travel_warnings_on_pass(
self, gtfs_fixture
):
"""General tests for clean_consecutive_stop_fast_travel_warnings()."""
- gtfs_fixture.is_valid(far_stops=True)
+ gtfs_fixture.is_valid()
original_validation = {
"type": {
0: "warning",
@@ -180,18 +189,15 @@ def test_clean_consecutive_stop_fast_travel_warnings_on_pass(
assert (
original_validation == gtfs_fixture.validity_df.to_dict()
), "Original validity df is not as expected"
- clean_consecutive_stop_fast_travel_warnings(
- gtfs=gtfs_fixture, validate=False
- )
+ clean_consecutive_stop_fast_travel_warnings(gtfs=gtfs_fixture)
gtfs_fixture.is_valid()
assert expected_validation == gtfs_fixture.validity_df.to_dict(), (
"Validation table is not as expected after cleaning consecutive "
"stop fast travel warnings"
)
# test validation; test gtfs with no warnings
- clean_consecutive_stop_fast_travel_warnings(
- gtfs=gtfs_fixture, validate=True
- )
+ gtfs_fixture.is_valid()
+ clean_consecutive_stop_fast_travel_warnings(gtfs=gtfs_fixture)
class Test_CleanMultipleStopFastTravelWarnings(object):
@@ -210,15 +216,13 @@ def test_clean_multiple_stop_fast_travel_warnings_defence(
"gtfs."
),
):
- clean_multiple_stop_fast_travel_warnings(
- gtfs=gtfs_fixture, validate=False
- )
+ clean_multiple_stop_fast_travel_warnings(gtfs=gtfs_fixture)
def test_clean_multiple_stop_fast_travel_warnings_on_pass(
self, gtfs_fixture
):
"""General tests for clean_multiple_stop_fast_travel_warnings()."""
- gtfs_fixture.is_valid(far_stops=True)
+ gtfs_fixture.is_valid()
original_validation = {
"type": {
0: "warning",
@@ -263,15 +267,211 @@ def test_clean_multiple_stop_fast_travel_warnings_on_pass(
assert (
original_validation == gtfs_fixture.validity_df.to_dict()
), "Original validity df is not as expected"
- clean_multiple_stop_fast_travel_warnings(
- gtfs=gtfs_fixture, validate=False
- )
+ clean_multiple_stop_fast_travel_warnings(gtfs=gtfs_fixture)
gtfs_fixture.is_valid()
assert expected_validation == gtfs_fixture.validity_df.to_dict(), (
"Validation table is not as expected after cleaning consecutive "
"stop fast travel warnings"
)
# test validation; test gtfs with no warnings
- clean_multiple_stop_fast_travel_warnings(
- gtfs=gtfs_fixture, validate=True
- )
+ gtfs_fixture.is_valid()
+ clean_multiple_stop_fast_travel_warnings(gtfs=gtfs_fixture)
+
+
+class TestCoreCleaner(object):
+ """Tests for core_cleaners().
+
+ Notes
+ -----
+    There are no passing tests for this function as it relies on functions from
+ gtfs-kit which have already been tested.
+
+ """
+
+ @pytest.mark.parametrize(
+ (
+ "clean_ids, clean_times, clean_route_short_names, drop_zombies, "
+ "raises, match"
+ ),
+ [
+ (
+ 1,
+ True,
+ True,
+ True,
+ TypeError,
+ r".*expected .*bool.* Got .*int.*",
+ ),
+ (
+ True,
+ dict(),
+ True,
+ True,
+ TypeError,
+ r".*expected .*bool.* Got .*dict.*",
+ ),
+ (
+ True,
+ True,
+ "test string",
+ True,
+ TypeError,
+ r".*expected .*bool.* Got .*str.*",
+ ),
+ (
+ True,
+ True,
+ True,
+ 2.12,
+ TypeError,
+ r".*expected .*bool.* Got .*float.*",
+ ),
+ ],
+ )
+    def test_core_cleaners_defence(
+ self,
+ gtfs_fixture,
+ clean_ids,
+ clean_times,
+ clean_route_short_names,
+ drop_zombies,
+ raises,
+ match,
+ ):
+ """Defensive tests for core_cleaners."""
+ with pytest.raises(raises, match=match):
+ gtfs_fixture.is_valid()
+ core_cleaners(
+ gtfs_fixture,
+ clean_ids,
+ clean_times,
+ clean_route_short_names,
+ drop_zombies,
+ )
+
+ def test_core_cleaners_drop_zombies_warns(self, gtfs_fixture):
+ """Test that warnings are emitted when shape_id isn't present in...
+
+ trips.
+ """
+ gtfs_fixture.feed.trips.drop("shape_id", axis=1, inplace=True)
+ with pytest.warns(
+ UserWarning,
+ match=r".*drop_zombies cleaner was unable to operate.*",
+ ):
+ gtfs_fixture.is_valid(validators={"core_validation": None})
+ gtfs_fixture.clean_feed()
+
+
+class TestCleanUnrecognisedColumnWarnings(object):
+ """Tests for clean_unrecognised_column_warnings."""
+
+ def test_clean_unrecognised_column_warnings(self, gtfs_fixture):
+ """Tests for clean_unrecognised_column_warnings."""
+ # initial assertions to ensure test data is correct
+ gtfs_fixture.is_valid(validators={"core_validation": None})
+ assert len(gtfs_fixture.validity_df) == 3, "validity_df wrong length"
+ assert np.array_equal(
+ gtfs_fixture.feed.trips.columns,
+ [
+ "route_id",
+ "service_id",
+ "trip_id",
+ "trip_headsign",
+ "block_id",
+ "shape_id",
+ "wheelchair_accessible",
+ "vehicle_journey_code",
+ ],
+ ), "Initial trips columns not as expected"
+ # clean warnings
+ clean_unrecognised_column_warnings(gtfs_fixture)
+        assert len(gtfs_fixture.validity_df) == 0, "Warnings not cleaned"
+ assert np.array_equal(
+ gtfs_fixture.feed.trips.columns,
+ [
+ "route_id",
+ "service_id",
+ "trip_id",
+ "trip_headsign",
+ "block_id",
+ "shape_id",
+ "wheelchair_accessible",
+ ],
+ ), "Failed to drop unrecognised columns"
+
+
+class TestCleanDuplicateStopTimes(object):
+ """Tests for clean_duplicate_stop_times."""
+
+ def test_clean_duplicate_stop_times_defence(self):
+ """Defensive functionality tests for clean_duplicate_stop_times."""
+ with pytest.raises(
+ TypeError, match=".*gtfs.* GtfsInstance object.*bool"
+ ):
+ clean_duplicate_stop_times(True)
+
+ def test_clean_duplicate_stop_times_on_pass(self):
+ """General functionality tests for clean_duplicate_stop_times."""
+ gtfs = GtfsInstance("tests/data/gtfs/repeated_pair_gtfs_fixture.zip")
+ gtfs.is_valid(validators={"core_validation": None})
+ _warnings = _get_validation_warnings(
+ gtfs, r"Repeated pair \(trip_id, departure_time\)"
+ )[0]
+ # check that the correct number of rows are impacted
+ assert (
+ len(_warnings[3]) == 24
+ ), "Unexpected number of rows originally impacted"
+ # test case for one of the instances that are removed (1)
+ assert 309 in _warnings[3], "Test case (1) missing from data"
+ assert np.array_equal(
+ list(gtfs.feed.stop_times.loc[309].values),
+ [
+ "VJf98a18e2c314219b4a98f75c5fa44e3f9618e72f",
+ "08:57:56",
+ "08:57:56",
+ "5540AWB33241",
+ 3,
+ np.nan,
+ 0,
+ 0,
+ 1.38726,
+ 0,
+ ],
+ ), "data for test case (1) invalid"
+ # test case for one of the instances that remain (2)
+ assert 18 in _warnings[3], "Test case (2) missing from data"
+ assert np.array_equal(
+ list(gtfs.feed.stop_times.loc[18].values),
+ [
+ "VJ1b6f3baad353b0ba8d42e0290bfd0b39cb1130bc",
+ "11:10:23",
+ "11:10:23",
+ "5540AWA17133",
+ 6,
+ np.nan,
+ 0,
+ 0,
+ np.nan,
+ 0,
+ ],
+ ), "data for test case (2) invalid"
+ # clean gtfs
+ clean_duplicate_stop_times(gtfs)
+ _warnings = _get_validation_warnings(
+ gtfs, r"Repeated pair \(trip_id, departure_time\)"
+ )[0]
+ assert (
+ len(_warnings[3]) == 6
+ ), "Unexpected number of rows impacted after cleaning"
+ # assert test case (1) is removed from the data
+ with pytest.raises(KeyError, match=".*309.*"):
+ gtfs.feed.stop_times.loc[309]
+ # assert test case (2) remains part of the data
+ assert (
+ len(gtfs.feed.stop_times.loc[18]) > 0
+ ), "Test case removed from data"
+
+ # test return of none when no warnings are found (for cov)
+ _remove_validation_row(gtfs, message=r".* \(trip_id, departure_time\)")
+ assert clean_duplicate_stop_times(gtfs) is None
diff --git a/tests/gtfs/test_gtfs_utils.py b/tests/gtfs/test_gtfs_utils.py
index e4aaffa1..dc1a1385 100644
--- a/tests/gtfs/test_gtfs_utils.py
+++ b/tests/gtfs/test_gtfs_utils.py
@@ -8,8 +8,12 @@
import geopandas as gpd
from shapely.geometry import box
from plotly.graph_objects import Figure as PlotlyFigure
+import numpy as np
-from transport_performance.gtfs.validation import GtfsInstance
+from transport_performance.gtfs.validation import (
+ GtfsInstance,
+ VALIDATE_FEED_FUNC_MAP,
+)
from transport_performance.gtfs.gtfs_utils import (
bbox_filter_gtfs,
filter_gtfs,
@@ -17,6 +21,9 @@
filter_gtfs_around_trip,
convert_pandas_to_plotly,
_validate_datestring,
+ _remove_validation_row,
+ _get_validation_warnings,
+ _function_pipeline,
)
# location of GTFS test fixture
@@ -276,7 +283,7 @@ class Test_AddValidationRow(object):
"""Tests for _add_validation_row()."""
def test__add_validation_row_defence(self):
- """Defensive tests for _add_test_validation_row()."""
+ """Defensive tests for _add_validation_row()."""
gtfs = GtfsInstance(gtfs_pth=GTFS_FIX_PTH)
with pytest.raises(
AttributeError,
@@ -291,9 +298,9 @@ def test__add_validation_row_defence(self):
)
def test__add_validation_row_on_pass(self):
- """General tests for _add_test_validation_row()."""
+ """General tests for _add_validation_row()."""
gtfs = GtfsInstance(gtfs_pth=GTFS_FIX_PTH)
- gtfs.is_valid(far_stops=False)
+ gtfs.is_valid(validators={"core_validation": None})
_add_validation_row(
gtfs=gtfs, _type="warning", message="test", table="stops"
@@ -335,12 +342,12 @@ def test_filter_gtfs_around_trip_on_pass(self, tmpdir):
trip_id="VJbedb4cfd0673348e017d42435abbdff3ddacbf82",
out_pth=out_pth,
)
- assert os.path.exists(out_pth), "Failed to filtere GTFS around trip."
+ assert os.path.exists(out_pth), "Failed to filter GTFS around trip."
# check the new gtfs can be read
feed = GtfsInstance(gtfs_pth=out_pth)
assert isinstance(
feed, GtfsInstance
- ), f"Expected class `Gtfs_Instance but found: {type(feed)}`"
+    ), f"Expected class `GtfsInstance` but found: {type(feed)}"
@pytest.fixture(scope="function")
@@ -389,6 +396,171 @@ def test_convert_pandas_to_plotly_on_pass(self, test_df):
)
+class TestGetValidationWarnings(object):
+ """Tests for _get_validation_warnings."""
+
+ def test__get_validation_warnings_defence(self):
+        """Test the defences of _get_validation_warnings."""
+ with pytest.raises(
+ TypeError, match=".* expected a GtfsInstance object"
+ ):
+ _get_validation_warnings(True, "test_msg")
+ gtfs = GtfsInstance(gtfs_pth=GTFS_FIX_PTH)
+ with pytest.raises(
+ AttributeError, match="The gtfs has not been validated.*"
+ ):
+ _get_validation_warnings(gtfs, "test")
+ gtfs.is_valid()
+ with pytest.raises(
+ ValueError, match=r"'return_type' expected one of \[.*\]\. Got .*"
+ ):
+ _get_validation_warnings(gtfs, "tester", "tester")
+
+ def test__get_validation_warnings(self):
+ """Test _get_validation_warnings on pass."""
+ gtfs = GtfsInstance(GTFS_FIX_PTH)
+ gtfs.is_valid()
+ # test return types
+ df_exp = _get_validation_warnings(
+ gtfs, "test", return_type="dataframe"
+ )
+ assert isinstance(
+ df_exp, pd.DataFrame
+ ), f"Expected df, got {type(df_exp)}"
+ ndarray_exp = _get_validation_warnings(gtfs, "test")
+ assert isinstance(
+ ndarray_exp, np.ndarray
+ ), f"Expected np.ndarray, got {type(ndarray_exp)}"
+        # test with valid regex (assertions on DF data without DF)
+ regex_matches = _get_validation_warnings(
+ gtfs, "Unrecognized column *.", return_type="dataframe"
+ )
+ assert len(regex_matches) == 5, (
+            "Getting validation warnings returned "
+ "unexpected number of warnings"
+ )
+ assert list(regex_matches["type"].unique()) == [
+ "warning"
+        ], "Dataframe type column not as expected"
+ assert list(regex_matches.table) == [
+ "agency",
+ "stop_times",
+ "stops",
+ "trips",
+ "trips",
+ ], "Dataframe table column not as expected"
+ # test with matching message (no regex)
+ exact_match = _get_validation_warnings(
+ gtfs, "Unrecognized column agency_noc", return_type="Dataframe"
+ )
+ assert list(exact_match.values[0]) == [
+ "warning",
+ "Unrecognized column agency_noc",
+ "agency",
+ [],
+ ], "Dataframe values not as expected"
+ assert (
+ len(exact_match) == 1
+ ), f"Expected one match, found {len(exact_match)}"
+ # test with no matches (regex)
+ regex_no_match = _get_validation_warnings(
+ gtfs, ".*This is a test.*", return_type="Dataframe"
+ )
+ assert len(regex_no_match) == 0, "No matches expected. Matches found"
+ # test with no match (no regex)
+ no_match = _get_validation_warnings(
+ gtfs, "This is a test!!!", return_type="Dataframe"
+ )
+        assert len(no_match) == 0, "No matches expected. Matches found"
+
+
+class TestRemoveValidationRow(object):
+ """Tests for _remove_validation_row."""
+
+ def test__remove_validation_row_defence(self):
+ """Tests the defences of _remove_validation_row."""
+ gtfs = GtfsInstance(GTFS_FIX_PTH)
+ gtfs.is_valid()
+ # no message or index provided
+ with pytest.raises(
+ ValueError, match=r"Both .* and .* are None, .* to be cleaned"
+ ):
+ _remove_validation_row(gtfs)
+ # both provided
+ with pytest.warns(
+ UserWarning,
+ match=r"Both .* and .* are not None.* cleaned on 'message'",
+ ):
+ _remove_validation_row(gtfs, message="test", index=[0, 1])
+
+ def test__remove_validation_row_on_pass(self):
+ """Tests for _remove_validation_row on pass."""
+ gtfs = GtfsInstance(GTFS_FIX_PTH)
+ gtfs.is_valid(validators={"core_validation": None})
+ # with message
+ msg = "Unrecognized column agency_noc"
+ _remove_validation_row(gtfs, message=msg)
+ assert len(gtfs.validity_df) == 6, "DF is incorrect size"
+ found_cols = _get_validation_warnings(
+ gtfs, message=msg, return_type="dataframe"
+ )
+ assert (
+ len(found_cols) == 0
+ ), "Invalid errors/warnings still in validity_df"
+ # with index (removing the same error)
+ gtfs = GtfsInstance(GTFS_FIX_PTH)
+ gtfs.is_valid(validators={"core_validation": None})
+ ind = [1]
+ _remove_validation_row(gtfs, index=ind)
+ assert len(gtfs.validity_df) == 6, "DF is incorrect size"
+ found_cols = _get_validation_warnings(
+ gtfs, message=msg, return_type="dataframe"
+ )
+ print(gtfs.validity_df)
+ assert (
+ len(found_cols) == 0
+ ), "Invalid errors/warnings still in validity_df"
+
+
+class TestFunctionPipeline(object):
+ """Tests for _function_pipeline.
+
+ Notes
+ -----
+ Not testing on pass here as better cases can be found in the tests for
+ GtfsInstance's is_valid() and clean_feed() methods.
+
+ """
+
+ @pytest.mark.parametrize(
+ "operations, raises, match",
+ [
+ # invalid type for 'validators'
+ (True, TypeError, ".*expected .*dict.*. Got .*bool.*"),
+ # invalid validator
+ (
+ {"not_a_valid_validator": None},
+ KeyError,
+ (
+ r"'not_a_valid_validator' function passed to 'operations'"
+ r" is not a known operation.*"
+ ),
+ ),
+ # invalid type for kwargs for validator
+ (
+ {"core_validation": pd.DataFrame()},
+ TypeError,
+ ".* expected .*dict.*NoneType.*",
+ ),
+ ],
+ )
+ def test_function_pipeline_defence(self, operations, raises, match):
+ """Defensive test for _function_pipeline."""
+ gtfs = GtfsInstance(GTFS_FIX_PTH)
+ with pytest.raises(raises, match=match):
+ _function_pipeline(gtfs, VALIDATE_FEED_FUNC_MAP, operations)
+
+
class Test_ValidateDatestring(object):
"""Tests for _validate_datestring."""
diff --git a/tests/gtfs/test_multi_validation.py b/tests/gtfs/test_multi_validation.py
index cabaa175..f2dd79bf 100644
--- a/tests/gtfs/test_multi_validation.py
+++ b/tests/gtfs/test_multi_validation.py
@@ -271,15 +271,15 @@ def test_clean_feeds_on_pass(self, multi_gtfs_fixture):
"""General tests for .clean_feeds()."""
# validate and do quick check on validity_df
valid_df = multi_gtfs_fixture.is_valid()
- assert len(valid_df) == 12, "validity_df not as expected"
+ assert len(valid_df) == 11, "validity_df not as expected"
# clean feed
multi_gtfs_fixture.clean_feeds()
# ensure cleaning has occured
new_valid = multi_gtfs_fixture.is_valid()
- assert len(new_valid) == 9
+ assert len(new_valid) == 1
assert np.array_equal(
- list(new_valid.iloc[3][["type", "table"]].values),
- ["error", "routes"],
+ list(new_valid.iloc[0][["type", "table"]].values),
+ ["warning", "routes"],
), "Validity df after cleaning not as expected"
def test_is_valid_defences(self, multi_gtfs_fixture):
@@ -290,7 +290,7 @@ def test_is_valid_defences(self, multi_gtfs_fixture):
def test_is_valid_on_pass(self, multi_gtfs_fixture):
"""General tests for is_valid()."""
valid_df = multi_gtfs_fixture.is_valid()
- assert len(valid_df) == 12, "Validation df not as expected"
+ assert len(valid_df) == 11, "Validation df not as expected"
assert np.array_equal(
list(valid_df.iloc[3][["type", "message"]].values),
(["warning", "Fast Travel Between Consecutive Stops"]),
diff --git a/tests/gtfs/test_validation.py b/tests/gtfs/test_validation.py
index 66d832a5..9997c141 100644
--- a/tests/gtfs/test_validation.py
+++ b/tests/gtfs/test_validation.py
@@ -27,12 +27,19 @@
@pytest.fixture(scope="function") # some funcs expect cleaned feed others dont
-def gtfs_fixture():
+def newp_gtfs_fixture():
"""Fixture for test funcs expecting a valid feed object."""
gtfs = GtfsInstance(gtfs_pth=GTFS_FIX_PTH)
return gtfs
+@pytest.fixture(scope="function")
+def chest_gtfs_fixture():
+ """Fixture for test funcs expecting a valid feed object."""
+ gtfs = GtfsInstance(here("tests/data/chester-20230816-small_gtfs.zip"))
+ return gtfs
+
+
class TestGtfsInstance(object):
"""Tests related to the GtfsInstance class."""
@@ -119,7 +126,7 @@ def test_init_on_pass(self):
without_pth.to_dict() == with_pth.to_dict()
), "Failed to get route type lookup correctly"
- def test_get_gtfs_files(self, gtfs_fixture):
+ def test_get_gtfs_files(self, newp_gtfs_fixture):
"""Assert files that make up the GTFS."""
expected_files = [
# smaller filter has resulted in a GTFS with no calendar dates /
@@ -135,40 +142,77 @@ def test_get_gtfs_files(self, gtfs_fixture):
"calendar.txt",
"routes.txt",
]
- foundf = gtfs_fixture.get_gtfs_files()
+ foundf = newp_gtfs_fixture.get_gtfs_files()
assert (
foundf == expected_files
), f"GTFS files not as expected. Expected {expected_files},"
"found: {foundf}"
- def test_is_valid(self, gtfs_fixture):
- """Assertions about validity_df table."""
- gtfs_fixture.is_valid()
- assert isinstance(
- gtfs_fixture.validity_df, pd.core.frame.DataFrame
- ), f"Expected DataFrame. Found: {type(gtfs_fixture.validity_df)}"
- shp = gtfs_fixture.validity_df.shape
- assert shp == (
- 7,
- 4,
- ), f"Attribute `validity_df` expected a shape of (7,4). Found: {shp}"
- exp_cols = pd.Index(["type", "message", "table", "rows"])
- found_cols = gtfs_fixture.validity_df.columns
- assert (
- found_cols == exp_cols
- ).all(), f"Expected columns {exp_cols}. Found: {found_cols}"
+ @pytest.mark.parametrize(
+ "which, validators, shape",
+ [
+ # only core validation
+ ("n", {"core_validation": None}, (7, 4)),
+ # route type warning validation
+ (
+ "n",
+ {
+ "core_validation": None,
+ "validate_route_type_warnings": None,
+ },
+ (6, 4),
+ ),
+ # fast travel validators
+ (
+ "c",
+ {
+ "core_validation": None,
+ "validate_travel_between_consecutive_stops": None,
+ "validate_travel_over_multiple_stops": None,
+ },
+ (5, 4),
+ ),
+ # all validators
+ ("n", None, (6, 4)),
+ ],
+ )
+ def test_is_valid_on_pass(
+ self, newp_gtfs_fixture, chest_gtfs_fixture, which, validators, shape
+ ):
+ """Tests/assertions for is_valid() while passing.
+
+ Notes
+ -----
+ These tests are mostly to assure that the validators are being
+ identified and run, and that the validation df returned is as expected.
+
+ I will be refraining from over testing here as it would essentially be
+ testing the validators again, which occurs in test_validators.py.
+
+ Tests for validators with kwargs would be useful, once they are added.
+
+ """
+        # Bypassing any defensive checks for which.
+        # Correct inputs are assured as these are internal tests.
+ if which.lower().strip() == "n":
+ fixture = newp_gtfs_fixture
+ else:
+ fixture = chest_gtfs_fixture
+ df = fixture.is_valid(validators=validators)
+ assert isinstance(df, pd.DataFrame), "is_valid() failed to return df"
+ assert shape == df.shape, "validity_df not as expected"
@pytest.mark.sanitycheck
- def test_trips_unmatched_ids(self, gtfs_fixture):
+ def test_trips_unmatched_ids(self, newp_gtfs_fixture):
"""Tests to evaluate gtfs-klt's reaction to invalid IDs in trips.
Parameters
----------
- gtfs_fixture : GtfsInstance
+ newp_gtfs_fixture : GtfsInstance
a GtfsInstance test fixure
"""
- feed = gtfs_fixture.feed
+ feed = newp_gtfs_fixture.feed
# add row to tripas table with invald trip_id, route_id, service_id
feed.trips = pd.concat(
@@ -209,16 +253,16 @@ def test_trips_unmatched_ids(self, gtfs_fixture):
assert len(new_valid) == 10, "Validation table not expected size"
@pytest.mark.sanitycheck
- def test_routes_unmatched_ids(self, gtfs_fixture):
+ def test_routes_unmatched_ids(self, newp_gtfs_fixture):
"""Tests to evaluate gtfs-klt's reaction to invalid IDs in routes.
Parameters
----------
- gtfs_fixture : GtfsInstance
+ newp_gtfs_fixture : GtfsInstance
a GtfsInstance test fixure
"""
- feed = gtfs_fixture.feed
+ feed = newp_gtfs_fixture.feed
# add row to tripas table with invald trip_id, route_id, service_id
feed.routes = pd.concat(
@@ -248,12 +292,12 @@ def test_routes_unmatched_ids(self, gtfs_fixture):
assert len(new_valid) == 9, "Validation table not expected size"
@pytest.mark.sanitycheck
- def test_unmatched_service_id_behaviour(self, gtfs_fixture):
+ def test_unmatched_service_id_behaviour(self, newp_gtfs_fixture):
"""Tests to evaluate gtfs-klt's reaction to invalid IDs in calendar.
Parameters
----------
- gtfs_fixture : GtfsInstance
+ newp_gtfs_fixture : GtfsInstance
a GtfsInstance test fixure
Notes
@@ -264,7 +308,7 @@ def test_unmatched_service_id_behaviour(self, gtfs_fixture):
calendar table contains duplicate service_ids.
"""
- feed = gtfs_fixture.feed
+ feed = newp_gtfs_fixture.feed
original_error_count = len(feed.validate())
# introduce a dummy row with a non matching service_id
@@ -300,25 +344,25 @@ def test_unmatched_service_id_behaviour(self, gtfs_fixture):
len(new_valid[new_valid.message == "Undefined service_id"]) == 1
), "gtfs-kit failed to identify missing service_id"
- def test_print_alerts_defence(self, gtfs_fixture):
+ def test_print_alerts_defence(self, newp_gtfs_fixture):
"""Check defensive behaviour of print_alerts()."""
with pytest.raises(
AttributeError,
match=r"is None, did you forget to use `self.is_valid()`?",
):
- gtfs_fixture.print_alerts()
+ newp_gtfs_fixture.print_alerts()
- gtfs_fixture.is_valid()
+ newp_gtfs_fixture.is_valid()
with pytest.warns(
UserWarning, match="No alerts of type doesnt_exist were found."
):
- gtfs_fixture.print_alerts(alert_type="doesnt_exist")
+ newp_gtfs_fixture.print_alerts(alert_type="doesnt_exist")
@patch("builtins.print") # testing print statements
- def test_print_alerts_single_case(self, mocked_print, gtfs_fixture):
+ def test_print_alerts_single_case(self, mocked_print, newp_gtfs_fixture):
"""Check alerts print as expected without truncation."""
- gtfs_fixture.is_valid()
- gtfs_fixture.print_alerts()
+ newp_gtfs_fixture.is_valid(validators={"core_validation": None})
+ newp_gtfs_fixture.print_alerts()
# fixture contains single error
fun_out = mocked_print.mock_calls
assert fun_out == [
@@ -326,11 +370,11 @@ def test_print_alerts_single_case(self, mocked_print, gtfs_fixture):
], f"Expected a print about invalid route type. Found {fun_out}"
@patch("builtins.print")
- def test_print_alerts_multi_case(self, mocked_print, gtfs_fixture):
+ def test_print_alerts_multi_case(self, mocked_print, newp_gtfs_fixture):
"""Check multiple alerts are printed as expected."""
- gtfs_fixture.is_valid()
+ newp_gtfs_fixture.is_valid()
# fixture contains several warnings
- gtfs_fixture.print_alerts(alert_type="warning")
+ newp_gtfs_fixture.print_alerts(alert_type="warning")
fun_out = mocked_print.mock_calls
assert fun_out == [
call("Unrecognized column agency_noc"),
@@ -341,7 +385,7 @@ def test_print_alerts_multi_case(self, mocked_print, gtfs_fixture):
call("Unrecognized column vehicle_journey_code"),
], f"Expected print statements about GTFS warnings. Found: {fun_out}"
- def test_viz_stops_defence(self, tmpdir, gtfs_fixture):
+ def test_viz_stops_defence(self, tmpdir, newp_gtfs_fixture):
"""Check defensive behaviours of viz_stops()."""
tmp = os.path.join(tmpdir, "somefile.html")
with pytest.raises(
@@ -351,12 +395,12 @@ def test_viz_stops_defence(self, tmpdir, gtfs_fixture):
"Got "
),
):
- gtfs_fixture.viz_stops(out_pth=True)
+ newp_gtfs_fixture.viz_stops(out_pth=True)
with pytest.raises(
TypeError,
match="`geoms` expected . Got ",
):
- gtfs_fixture.viz_stops(out_pth=tmp, geoms=38)
+ newp_gtfs_fixture.viz_stops(out_pth=tmp, geoms=38)
with pytest.raises(
ValueError,
match=re.escape(
@@ -364,7 +408,7 @@ def test_viz_stops_defence(self, tmpdir, gtfs_fixture):
"['point', 'hull']. Got foobar: "
),
):
- gtfs_fixture.viz_stops(out_pth=tmp, geoms="foobar")
+ newp_gtfs_fixture.viz_stops(out_pth=tmp, geoms="foobar")
with pytest.raises(
TypeError,
match=re.escape(
@@ -372,28 +416,28 @@ def test_viz_stops_defence(self, tmpdir, gtfs_fixture):
""
),
):
- gtfs_fixture.viz_stops(out_pth=tmp, geom_crs=1.1)
+ newp_gtfs_fixture.viz_stops(out_pth=tmp, geom_crs=1.1)
# check missing stop_id results in an informative error message
- gtfs_fixture.feed.stops.drop("stop_id", axis=1, inplace=True)
+ newp_gtfs_fixture.feed.stops.drop("stop_id", axis=1, inplace=True)
with pytest.raises(
KeyError,
match="The stops table has no 'stop_code' column. While "
"this is an optional field in a GTFS file, it "
"raises an error through the gtfs-kit package.",
):
- gtfs_fixture.viz_stops(out_pth=tmp, filtered_only=False)
+ newp_gtfs_fixture.viz_stops(out_pth=tmp, filtered_only=False)
@patch("builtins.print")
- def test_viz_stops_point(self, mock_print, tmpdir, gtfs_fixture):
+ def test_viz_stops_point(self, mock_print, tmpdir, newp_gtfs_fixture):
"""Check behaviour of viz_stops when plotting point geom."""
tmp = os.path.join(tmpdir, "points.html")
- gtfs_fixture.viz_stops(out_pth=pathlib.Path(tmp))
+ newp_gtfs_fixture.viz_stops(out_pth=pathlib.Path(tmp))
assert os.path.exists(
tmp
), f"{tmp} was expected to exist but it was not found."
# check behaviour when parent directory doesn't exist
no_parent_pth = os.path.join(tmpdir, "notfound", "points1.html")
- gtfs_fixture.viz_stops(
+ newp_gtfs_fixture.viz_stops(
out_pth=pathlib.Path(no_parent_pth), create_out_parent=True
)
assert os.path.exists(
@@ -408,7 +452,7 @@ def test_viz_stops_point(self, mock_print, tmpdir, gtfs_fixture):
"to 'out_pth'. Path defaulted to .html"
),
):
- gtfs_fixture.viz_stops(out_pth=pathlib.Path(tmp1))
+ newp_gtfs_fixture.viz_stops(out_pth=pathlib.Path(tmp1))
# need to use regex for the first print statement, as tmpdir will
# change.
start_pat = re.compile(r"Creating parent directory:.*")
@@ -421,20 +465,22 @@ def test_viz_stops_point(self, mock_print, tmpdir, gtfs_fixture):
write_pth
), f"Map should have been written to {write_pth} but was not found."
- def test_viz_stops_hull(self, tmpdir, gtfs_fixture):
+ def test_viz_stops_hull(self, tmpdir, newp_gtfs_fixture):
"""Check viz_stops behaviour when plotting hull geom."""
tmp = os.path.join(tmpdir, "hull.html")
- gtfs_fixture.viz_stops(out_pth=pathlib.Path(tmp), geoms="hull")
+ newp_gtfs_fixture.viz_stops(out_pth=pathlib.Path(tmp), geoms="hull")
assert os.path.exists(tmp), f"Map file not found at {tmp}."
# assert file created when not filtering the hull
tmp1 = os.path.join(tmpdir, "filtered_hull.html")
- gtfs_fixture.viz_stops(out_pth=tmp1, geoms="hull", filtered_only=False)
+ newp_gtfs_fixture.viz_stops(
+ out_pth=tmp1, geoms="hull", filtered_only=False
+ )
assert os.path.exists(tmp1), f"Map file not found at {tmp1}."
- def test__create_map_title_text_defence(self, gtfs_fixture):
+ def test__create_map_title_text_defence(self, newp_gtfs_fixture):
"""Test the defences for _create_map_title_text()."""
# CRS without m or km units
- gtfs_hull = gtfs_fixture.feed.compute_convex_hull()
+ gtfs_hull = newp_gtfs_fixture.feed.compute_convex_hull()
gdf = GeoDataFrame({"geometry": gtfs_hull}, index=[0], crs="epsg:4326")
with pytest.raises(ValueError), pytest.warns(UserWarning):
_create_map_title_text(gdf=gdf, units="m", geom_crs=4326)
@@ -507,7 +553,7 @@ def test__convert_multi_index_to_single(self):
expected_cols.remove(col)
assert len(expected_cols) == 0, "Not all expected cols in output cols"
- def test__order_dataframe_by_day_defence(self, gtfs_fixture):
+ def test__order_dataframe_by_day_defence(self, newp_gtfs_fixture):
"""Test __order_dataframe_by_day defences."""
with pytest.raises(
TypeError,
@@ -516,7 +562,7 @@ def test__order_dataframe_by_day_defence(self, gtfs_fixture):
"Got "
),
):
- (gtfs_fixture._order_dataframe_by_day(df="test"))
+ (newp_gtfs_fixture._order_dataframe_by_day(df="test"))
with pytest.raises(
TypeError,
match=re.escape(
@@ -525,12 +571,12 @@ def test__order_dataframe_by_day_defence(self, gtfs_fixture):
),
):
(
- gtfs_fixture._order_dataframe_by_day(
+ newp_gtfs_fixture._order_dataframe_by_day(
df=pd.DataFrame(), day_column_name=5
)
)
- def test_get_route_modes(self, gtfs_fixture, mocker):
+ def test_get_route_modes(self, newp_gtfs_fixture, mocker):
"""Assertions about the table returned by get_route_modes()."""
patch_scrape_lookup = mocker.patch(
"transport_performance.gtfs.validation.scrape_route_type_lookup",
@@ -539,25 +585,28 @@ def test_get_route_modes(self, gtfs_fixture, mocker):
{"route_type": ["3"], "desc": ["Mocked bus"]}
),
)
- gtfs_fixture.get_route_modes()
+ newp_gtfs_fixture.get_route_modes()
# check mocker was called
assert (
patch_scrape_lookup.called
), "mocker.patch `patch_scrape_lookup` was not called."
- found = gtfs_fixture.route_mode_summary_df["desc"][0]
+ found = newp_gtfs_fixture.route_mode_summary_df["desc"][0]
assert found == "Mocked bus", f"Expected 'Mocked bus', found: {found}"
assert isinstance(
- gtfs_fixture.route_mode_summary_df, pd.core.frame.DataFrame
- ), f"Expected pd df. Found: {type(gtfs_fixture.route_mode_summary_df)}"
+ newp_gtfs_fixture.route_mode_summary_df, pd.core.frame.DataFrame
+ ), (
+ "Expected pd df. "
+ f"Found: {type(newp_gtfs_fixture.route_mode_summary_df)}"
+ )
exp_cols = pd.Index(["route_type", "desc", "n_routes", "prop_routes"])
- found_cols = gtfs_fixture.route_mode_summary_df.columns
+ found_cols = newp_gtfs_fixture.route_mode_summary_df.columns
assert (
found_cols == exp_cols
).all(), f"Expected columns are different. Found: {found_cols}"
- def test__preprocess_trips_and_routes(self, gtfs_fixture):
+ def test__preprocess_trips_and_routes(self, newp_gtfs_fixture):
"""Check the outputs of _pre_process_trips_and_route() (test data)."""
- returned_df = gtfs_fixture._preprocess_trips_and_routes()
+ returned_df = newp_gtfs_fixture._preprocess_trips_and_routes()
assert isinstance(returned_df, pd.core.frame.DataFrame), (
"Expected DF for _preprocess_trips_and_routes() return,"
f"found {type(returned_df)}"
@@ -591,13 +640,13 @@ def test__preprocess_trips_and_routes(self, gtfs_fixture):
f"Found {returned_df.shape}",
)
- def test_summarise_trips_defence(self, gtfs_fixture):
+ def test_summarise_trips_defence(self, newp_gtfs_fixture):
"""Defensive checks for summarise_trips()."""
with pytest.raises(
TypeError,
match="Each item in `summ_ops`.*. Found : np.mean",
):
- gtfs_fixture.summarise_trips(summ_ops=[np.mean, "np.mean"])
+ newp_gtfs_fixture.summarise_trips(summ_ops=[np.mean, "np.mean"])
# case where is function but not exported from numpy
def dummy_func():
@@ -611,18 +660,18 @@ def dummy_func():
" : dummy_func"
),
):
- gtfs_fixture.summarise_trips(summ_ops=[np.min, dummy_func])
+ newp_gtfs_fixture.summarise_trips(summ_ops=[np.min, dummy_func])
# case where a single non-numpy func is being passed
with pytest.raises(
NotImplementedError,
match="`summ_ops` expects numpy functions only.",
):
- gtfs_fixture.summarise_trips(summ_ops=dummy_func)
+ newp_gtfs_fixture.summarise_trips(summ_ops=dummy_func)
with pytest.raises(
TypeError,
match="`summ_ops` expects a numpy function.*. Found ",
):
- gtfs_fixture.summarise_trips(summ_ops=38)
+ newp_gtfs_fixture.summarise_trips(summ_ops=38)
# cases where return_summary are not of type boolean
with pytest.raises(
TypeError,
@@ -630,7 +679,7 @@ def dummy_func():
"`return_summary` expected . Got "
),
):
- gtfs_fixture.summarise_trips(return_summary=5)
+ newp_gtfs_fixture.summarise_trips(return_summary=5)
with pytest.raises(
TypeError,
match=re.escape(
@@ -638,15 +687,15 @@ def dummy_func():
"'str'>"
),
):
- gtfs_fixture.summarise_trips(return_summary="true")
+ newp_gtfs_fixture.summarise_trips(return_summary="true")
- def test_summarise_routes_defence(self, gtfs_fixture):
+ def test_summarise_routes_defence(self, newp_gtfs_fixture):
"""Defensive checks for summarise_routes()."""
with pytest.raises(
TypeError,
match="Each item in `summ_ops`.*. Found : np.mean",
):
- gtfs_fixture.summarise_trips(summ_ops=[np.mean, "np.mean"])
+ newp_gtfs_fixture.summarise_trips(summ_ops=[np.mean, "np.mean"])
# case where is function but not exported from numpy
def dummy_func():
@@ -660,18 +709,18 @@ def dummy_func():
" : dummy_func"
),
):
- gtfs_fixture.summarise_routes(summ_ops=[np.min, dummy_func])
+ newp_gtfs_fixture.summarise_routes(summ_ops=[np.min, dummy_func])
# case where a single non-numpy func is being passed
with pytest.raises(
NotImplementedError,
match="`summ_ops` expects numpy functions only.",
):
- gtfs_fixture.summarise_routes(summ_ops=dummy_func)
+ newp_gtfs_fixture.summarise_routes(summ_ops=dummy_func)
with pytest.raises(
TypeError,
match="`summ_ops` expects a numpy function.*. Found ",
):
- gtfs_fixture.summarise_routes(summ_ops=38)
+ newp_gtfs_fixture.summarise_routes(summ_ops=38)
# cases where return_summary are not of type boolean
with pytest.raises(
TypeError,
@@ -679,38 +728,36 @@ def dummy_func():
"`return_summary` expected . Got "
),
):
- gtfs_fixture.summarise_routes(return_summary=5)
+ newp_gtfs_fixture.summarise_routes(return_summary=5)
with pytest.raises(
TypeError,
match=re.escape(
"`return_summary` expected . Got "
),
):
- gtfs_fixture.summarise_routes(return_summary="true")
+ newp_gtfs_fixture.summarise_routes(return_summary="true")
- @patch("builtins.print")
- def test_clean_feed_defence(self, mock_print, gtfs_fixture):
+ def test_clean_feed_defence(self, newp_gtfs_fixture):
"""Check defensive behaviours of clean_feed()."""
- # Simulate condition where shapes.txt has no shape_id
- gtfs_fixture.feed.shapes.drop("shape_id", axis=1, inplace=True)
- gtfs_fixture.clean_feed()
- fun_out = mock_print.mock_calls
- assert fun_out == [
- call("KeyError. Feed was not cleaned.")
- ], f"Expected print statement about KeyError. Found: {fun_out}."
+ with pytest.raises(
+ TypeError, match=r".*expected .*dict.* Got .*int.*"
+ ):
+ fixt = newp_gtfs_fixture
+ fixt.is_valid(validators={"core_validation": None})
+ fixt.clean_feed(cleansers=1)
- def test_summarise_trips_on_pass(self, gtfs_fixture):
+ def test_summarise_trips_on_pass(self, newp_gtfs_fixture):
"""Assertions about the outputs from summarise_trips()."""
- gtfs_fixture.summarise_trips()
+ newp_gtfs_fixture.summarise_trips()
# tests the daily_routes_summary return schema
assert isinstance(
- gtfs_fixture.daily_trip_summary, pd.core.frame.DataFrame
+ newp_gtfs_fixture.daily_trip_summary, pd.core.frame.DataFrame
), (
"Expected DF for daily_summary,"
- f"found {type(gtfs_fixture.daily_trip_summary)}"
+ f"found {type(newp_gtfs_fixture.daily_trip_summary)}"
)
- found_ds = gtfs_fixture.daily_trip_summary.columns
+ found_ds = newp_gtfs_fixture.daily_trip_summary.columns
exp_cols_ds = pd.Index(
[
"day",
@@ -729,13 +776,13 @@ def test_summarise_trips_on_pass(self, gtfs_fixture):
# tests the self.dated_route_counts return schema
assert isinstance(
- gtfs_fixture.dated_trip_counts, pd.core.frame.DataFrame
+ newp_gtfs_fixture.dated_trip_counts, pd.core.frame.DataFrame
), (
"Expected DF for dated_route_counts,"
- f"found {type(gtfs_fixture.dated_trip_counts)}"
+ f"found {type(newp_gtfs_fixture.dated_trip_counts)}"
)
- found_drc = gtfs_fixture.dated_trip_counts.columns
+ found_drc = newp_gtfs_fixture.dated_trip_counts.columns
exp_cols_drc = pd.Index(["date", "route_type", "trip_count", "day"])
assert (
@@ -756,8 +803,8 @@ def test_summarise_trips_on_pass(self, gtfs_fixture):
)
found_df = (
- gtfs_fixture.daily_trip_summary[
- gtfs_fixture.daily_trip_summary["day"] == "friday"
+ newp_gtfs_fixture.daily_trip_summary[
+ newp_gtfs_fixture.daily_trip_summary["day"] == "friday"
]
.sort_values(by="route_type", ascending=True)
.reset_index(drop=True)
@@ -773,24 +820,26 @@ def test_summarise_trips_on_pass(self, gtfs_fixture):
# test that the dated_trip_counts can be returned
expected_size = (504, 4)
- found_size = gtfs_fixture.summarise_trips(return_summary=False).shape
+ found_size = newp_gtfs_fixture.summarise_trips(
+ return_summary=False
+ ).shape
assert expected_size == found_size, (
"Size of date_route_counts not as expected. "
"Expected {expected_size}"
)
- def test_summarise_routes_on_pass(self, gtfs_fixture):
+ def test_summarise_routes_on_pass(self, newp_gtfs_fixture):
"""Assertions about the outputs from summarise_routes()."""
- gtfs_fixture.summarise_routes()
+ newp_gtfs_fixture.summarise_routes()
# tests the daily_routes_summary return schema
assert isinstance(
- gtfs_fixture.daily_route_summary, pd.core.frame.DataFrame
+ newp_gtfs_fixture.daily_route_summary, pd.core.frame.DataFrame
), (
"Expected DF for daily_summary,"
- f"found {type(gtfs_fixture.daily_route_summary)}"
+ f"found {type(newp_gtfs_fixture.daily_route_summary)}"
)
- found_ds = gtfs_fixture.daily_route_summary.columns
+ found_ds = newp_gtfs_fixture.daily_route_summary.columns
exp_cols_ds = pd.Index(
[
"day",
@@ -809,13 +858,13 @@ def test_summarise_routes_on_pass(self, gtfs_fixture):
# tests the self.dated_route_counts return schema
assert isinstance(
- gtfs_fixture.dated_route_counts, pd.core.frame.DataFrame
+ newp_gtfs_fixture.dated_route_counts, pd.core.frame.DataFrame
), (
"Expected DF for dated_route_counts,"
- f"found {type(gtfs_fixture.dated_route_counts)}"
+ f"found {type(newp_gtfs_fixture.dated_route_counts)}"
)
- found_drc = gtfs_fixture.dated_route_counts.columns
+ found_drc = newp_gtfs_fixture.dated_route_counts.columns
exp_cols_drc = pd.Index(["date", "route_type", "day", "route_count"])
assert (
@@ -836,8 +885,8 @@ def test_summarise_routes_on_pass(self, gtfs_fixture):
)
found_df = (
- gtfs_fixture.daily_route_summary[
- gtfs_fixture.daily_route_summary["day"] == "friday"
+ newp_gtfs_fixture.daily_route_summary[
+ newp_gtfs_fixture.daily_route_summary["day"] == "friday"
]
.sort_values(by="route_type", ascending=True)
.reset_index(drop=True)
@@ -853,13 +902,15 @@ def test_summarise_routes_on_pass(self, gtfs_fixture):
# test that the dated_route_counts can be returned
expected_size = (504, 4)
- found_size = gtfs_fixture.summarise_routes(return_summary=False).shape
+ found_size = newp_gtfs_fixture.summarise_routes(
+ return_summary=False
+ ).shape
assert expected_size == found_size, (
"Size of date_route_counts not as expected. "
"Expected {expected_size}"
)
- def test__plot_summary_defences(self, tmp_path, gtfs_fixture):
+ def test__plot_summary_defences(self, tmp_path, newp_gtfs_fixture):
"""Test defences for _plot_summary()."""
# test defences for checks summaries exist
with pytest.raises(
@@ -869,7 +920,7 @@ def test__plot_summary_defences(self, tmp_path, gtfs_fixture):
" Did you forget to call '.summarise_trips()' first?"
),
):
- gtfs_fixture._plot_summary(which="trip", target_column="mean")
+ newp_gtfs_fixture._plot_summary(which="trip", target_column="mean")
with pytest.raises(
AttributeError,
@@ -878,9 +929,11 @@ def test__plot_summary_defences(self, tmp_path, gtfs_fixture):
" Did you forget to call '.summarise_routes()' first?"
),
):
- gtfs_fixture._plot_summary(which="route", target_column="mean")
+ newp_gtfs_fixture._plot_summary(
+ which="route", target_column="mean"
+ )
- gtfs_fixture.summarise_routes()
+ newp_gtfs_fixture.summarise_routes()
# test parameters that are yet to be tested
options = ["v", "h"]
@@ -891,7 +944,7 @@ def test__plot_summary_defences(self, tmp_path, gtfs_fixture):
f"{options}. Got i: "
),
):
- gtfs_fixture._plot_summary(
+ newp_gtfs_fixture._plot_summary(
which="route",
target_column="route_count_mean",
orientation="i",
@@ -906,7 +959,7 @@ def test__plot_summary_defences(self, tmp_path, gtfs_fixture):
" given to 'img_type'. Path defaulted to .png"
),
):
- gtfs_fixture._plot_summary(
+ newp_gtfs_fixture._plot_summary(
which="route",
target_column="route_count_mean",
save_image=True,
@@ -922,15 +975,17 @@ def test__plot_summary_defences(self, tmp_path, gtfs_fixture):
"['trip', 'route']. Got tester: "
),
):
- gtfs_fixture._plot_summary(which="tester", target_column="tester")
+ newp_gtfs_fixture._plot_summary(
+ which="tester", target_column="tester"
+ )
- def test__plot_summary_on_pass(self, gtfs_fixture, tmp_path):
+ def test__plot_summary_on_pass(self, newp_gtfs_fixture, tmp_path):
"""Test plotting a summary when defences are passed."""
- current_fixture = gtfs_fixture
+ current_fixture = newp_gtfs_fixture
current_fixture.summarise_routes()
# test returning a html string
- test_html = gtfs_fixture._plot_summary(
+ test_html = newp_gtfs_fixture._plot_summary(
which="route",
target_column="route_count_mean",
return_html=True,
@@ -938,7 +993,7 @@ def test__plot_summary_on_pass(self, gtfs_fixture, tmp_path):
assert type(test_html) is str, "Failed to return HTML for the plot"
# test returning a plotly figure
- test_image = gtfs_fixture._plot_summary(
+ test_image = newp_gtfs_fixture._plot_summary(
which="route", target_column="route_count_mean"
)
assert (
@@ -946,8 +1001,8 @@ def test__plot_summary_on_pass(self, gtfs_fixture, tmp_path):
), "Failed to return plotly.graph_objects.Figure type"
# test returning a plotly for trips
- gtfs_fixture.summarise_trips()
- test_image = gtfs_fixture._plot_summary(
+ newp_gtfs_fixture.summarise_trips()
+ test_image = newp_gtfs_fixture._plot_summary(
which="trip", target_column="trip_count_mean"
)
assert (
@@ -955,7 +1010,7 @@ def test__plot_summary_on_pass(self, gtfs_fixture, tmp_path):
), "Failed to return plotly.graph_objects.Figure type"
# test saving plots in html and png format
- gtfs_fixture._plot_summary(
+ newp_gtfs_fixture._plot_summary(
which="route",
target_column="mean",
width=1200,
@@ -984,7 +1039,7 @@ def test__plot_summary_on_pass(self, gtfs_fixture, tmp_path):
assert counts["html"] == 1, "Failed to save plot as HTML"
assert counts["png"] == 1, "Failed to save plot as png"
- def test__create_extended_repeated_pair_table(self, gtfs_fixture):
+ def test__create_extended_repeated_pair_table(self, newp_gtfs_fixture):
"""Test _create_extended_repeated_pair_table()."""
test_table = pd.DataFrame(
{
@@ -1003,17 +1058,19 @@ def test__create_extended_repeated_pair_table(self, gtfs_fixture):
}
).to_dict()
- returned_table = gtfs_fixture._create_extended_repeated_pair_table(
- table=test_table,
- join_vars=["trip_name", "trip_abbrev"],
- original_rows=[0],
- ).to_dict()
+ returned_table = (
+ newp_gtfs_fixture._create_extended_repeated_pair_table(
+ table=test_table,
+ join_vars=["trip_name", "trip_abbrev"],
+ original_rows=[0],
+ ).to_dict()
+ )
assert (
expected_table == returned_table
), "_create_extended_repeated_pair_table() failed"
- def test_html_report_defences(self, gtfs_fixture, tmp_path):
+ def test_html_report_defences(self, newp_gtfs_fixture, tmp_path):
"""Test the defences whilst generating a HTML report."""
with pytest.raises(
ValueError,
@@ -1022,15 +1079,15 @@ def test_html_report_defences(self, gtfs_fixture, tmp_path):
"['mean', 'min', 'max', 'median']. Got test_sum: "
),
):
- gtfs_fixture.html_report(
+ newp_gtfs_fixture.html_report(
report_dir=tmp_path,
overwrite=True,
summary_type="test_sum",
)
- def test_html_report_on_pass(self, gtfs_fixture, tmp_path):
+ def test_html_report_on_pass(self, newp_gtfs_fixture, tmp_path):
"""Test that a HTML report is generated if defences are passed."""
- gtfs_fixture.html_report(report_dir=pathlib.Path(tmp_path))
+ newp_gtfs_fixture.html_report(report_dir=pathlib.Path(tmp_path))
# assert that the report has been completely generated
assert os.path.exists(
@@ -1065,33 +1122,35 @@ def test_html_report_on_pass(self, gtfs_fixture, tmp_path):
("invalid_ext.txt", "invalid_ext.zip", True),
],
)
- def test_save(self, tmp_path, gtfs_fixture, path, final_path, warns):
+ def test_save(self, tmp_path, newp_gtfs_fixture, path, final_path, warns):
"""Test the .save() methohd of GtfsInstance()."""
complete_path = os.path.join(tmp_path, path)
expected_path = os.path.join(tmp_path, final_path)
if warns:
# catch UserWarning from invalid file extension
with pytest.warns(UserWarning):
- gtfs_fixture.save(complete_path)
+ newp_gtfs_fixture.save(complete_path)
else:
with does_not_raise():
- gtfs_fixture.save(complete_path, overwrite=True)
+ newp_gtfs_fixture.save(complete_path, overwrite=True)
assert os.path.exists(expected_path), "GTFS not saved correctly"
- def test_save_overwrite(self, tmp_path, gtfs_fixture):
+ def test_save_overwrite(self, tmp_path, newp_gtfs_fixture):
"""Test the .save()'s method of GtfsInstance overwrite feature."""
# original save
save_pth = f"{tmp_path}/test_save.zip"
- gtfs_fixture.save(save_pth, overwrite=True)
+ newp_gtfs_fixture.save(save_pth, overwrite=True)
assert os.path.exists(save_pth), "GTFS not saved at correct path"
# test saving without overwrite enabled
with pytest.raises(
FileExistsError, match="File already exists at path.*"
):
- gtfs_fixture.save(f"{tmp_path}/test_save.zip", overwrite=False)
+ newp_gtfs_fixture.save(
+ f"{tmp_path}/test_save.zip", overwrite=False
+ )
# test saving with overwrite enabled raises no errors
with does_not_raise():
- gtfs_fixture.save(f"{tmp_path}/test_save.zip", overwrite=True)
+ newp_gtfs_fixture.save(f"{tmp_path}/test_save.zip", overwrite=True)
assert os.path.exists(save_pth), "GTFS save not found"
@pytest.mark.parametrize(
@@ -1114,14 +1173,14 @@ def test_filter_to_date(self, date, expected_len):
len(gtfs.feed.stop_times) == expected_len
), "GTFS not filtered to singular date as expected"
- def test_filter_to_bbox(self, gtfs_fixture):
+ def test_filter_to_bbox(self, newp_gtfs_fixture):
"""Small tests for the shallow wrapper filter_to_bbox()."""
assert (
- len(gtfs_fixture.feed.stop_times) == 7765
+ len(newp_gtfs_fixture.feed.stop_times) == 7765
), "feed.stop_times is an unexpected size"
- gtfs_fixture.filter_to_bbox(
+ newp_gtfs_fixture.filter_to_bbox(
[-2.985535, 51.551459, -2.919617, 51.606077]
)
assert (
- len(gtfs_fixture.feed.stop_times) == 217
+ len(newp_gtfs_fixture.feed.stop_times) == 217
), "GTFS not filtered to bbox as expected"
diff --git a/tests/gtfs/test_validators.py b/tests/gtfs/test_validators.py
index 6d5cc627..9dd4903c 100644
--- a/tests/gtfs/test_validators.py
+++ b/tests/gtfs/test_validators.py
@@ -2,26 +2,42 @@
from pyprojroot import here
import pytest
import re
+import shutil
+import os
+import zipfile
+import pathlib
+
+import numpy as np
from transport_performance.gtfs.validation import GtfsInstance
from transport_performance.gtfs.validators import (
validate_travel_between_consecutive_stops,
validate_travel_over_multiple_stops,
+ validate_route_type_warnings,
+ validate_gtfs_files,
)
+from transport_performance.gtfs.gtfs_utils import _get_validation_warnings
@pytest.fixture(scope="function")
-def gtfs_fixture():
+def chest_gtfs_fixture():
"""Fixture for test funcs expecting a valid feed object."""
gtfs = GtfsInstance(here("tests/data/chester-20230816-small_gtfs.zip"))
return gtfs
+@pytest.fixture(scope="function")
+def newp_gtfs_fixture():
+ """Fixture for test funcs expecting a valid feed object."""
+ gtfs = GtfsInstance(here("tests/data/gtfs/newport-20230613_gtfs.zip"))
+ return gtfs
+
+
class Test_ValidateTravelBetweenConsecutiveStops(object):
"""Tests for the validate_travel_between_consecutive_stops function()."""
def test_validate_travel_between_consecutive_stops_defences(
- self, gtfs_fixture
+ self, chest_gtfs_fixture
):
"""Defensive tests for validating travel between consecutive stops."""
with pytest.raises(
@@ -32,13 +48,15 @@ def test_validate_travel_between_consecutive_stops_defences(
"Did you forget to run the .is_valid() method?"
),
):
- validate_travel_between_consecutive_stops(gtfs_fixture)
+ validate_travel_between_consecutive_stops(chest_gtfs_fixture)
pass
- def test_validate_travel_between_consecutive_stops(self, gtfs_fixture):
+ def test_validate_travel_between_consecutive_stops(
+ self, chest_gtfs_fixture
+ ):
"""General tests for validating travel between consecutive stops."""
- gtfs_fixture.is_valid(far_stops=False)
- validate_travel_between_consecutive_stops(gtfs=gtfs_fixture)
+ chest_gtfs_fixture.is_valid(validators={"core_validation": None})
+ validate_travel_between_consecutive_stops(gtfs=chest_gtfs_fixture)
expected_validation = {
"type": {0: "warning", 1: "warning", 2: "warning", 3: "warning"},
@@ -62,20 +80,51 @@ def test_validate_travel_between_consecutive_stops(self, gtfs_fixture):
},
}
- found_dataframe = gtfs_fixture.validity_df
+ found_dataframe = chest_gtfs_fixture.validity_df
assert expected_validation == found_dataframe.to_dict(), (
"'_validate_travel_between_consecutive_stops()' failed to raise "
"warnings in the validity df"
)
+ def test__join_max_speed(self, newp_gtfs_fixture):
+ """Tests for the _join_max_speed function."""
+ newp_gtfs_fixture.is_valid(validators={"core_validation": None})
+ # assert route_type's beforehand
+ existing_types = newp_gtfs_fixture.feed.routes.route_type.unique()
+ assert np.array_equal(
+ existing_types, [3, 200]
+ ), "Existing route types not as expected."
+ # replace route_type 200 with an invalid route_type
+ newp_gtfs_fixture.feed.routes.route_type = (
+ newp_gtfs_fixture.feed.routes.route_type.apply(
+ lambda x: 12345 if x == 200 else x
+ )
+ )
+ new_types = newp_gtfs_fixture.feed.routes.route_type.unique()
+ assert np.array_equal(
+ new_types, [3, 12345]
+ ), "Route types of 200 not replaced correctly"
+ # validate and assert a speed bound of 150 is set for these cases
+ validate_travel_between_consecutive_stops(newp_gtfs_fixture)
+ cases = newp_gtfs_fixture.full_stop_schedule[
+ newp_gtfs_fixture.full_stop_schedule.route_type == 12345
+ ]
+ print(cases.speed_bound)
+ assert np.array_equal(
+ cases.route_type.unique(), [12345]
+ ), "Dataframe filter to cases did not work"
+ assert np.array_equal(
+ cases.speed_bound.unique(), [200]
+ ), "Obtaining max speed for unrecognised route_type failed"
+
class Test_ValidateTravelOverMultipleStops(object):
"""Tests for validate_travel_over_multiple_stops()."""
- def test_validate_travel_over_multiple_stops(self, gtfs_fixture):
+ def test_validate_travel_over_multiple_stops(self, chest_gtfs_fixture):
"""General tests for validate_travel_over_multiple_stops()."""
- gtfs_fixture.is_valid(far_stops=False)
- validate_travel_over_multiple_stops(gtfs=gtfs_fixture)
+ chest_gtfs_fixture.is_valid(validators={"core_validation": None})
+ validate_travel_over_multiple_stops(gtfs=chest_gtfs_fixture)
expected_validation = {
"type": {
@@ -108,9 +157,124 @@ def test_validate_travel_over_multiple_stops(self, gtfs_fixture):
},
}
- found_dataframe = gtfs_fixture.validity_df
+ found_dataframe = chest_gtfs_fixture.validity_df
assert expected_validation == found_dataframe.to_dict(), (
"'_validate_travel_over_multiple_stops()' failed to raise "
"warnings in the validity df"
)
+
+
+class TestValidateRouteTypeWarnings(object):
+ """Tests for valdate_route_type_warnings."""
+
+ def test_validate_route_type_warnings_defence(self, newp_gtfs_fixture):
+ """Tests for validate_route_type_warnings on fail."""
+ with pytest.raises(
+ TypeError, match=r".* expected a GtfsInstance object. Got .*"
+ ):
+ validate_route_type_warnings(1)
+ with pytest.raises(
+ AttributeError, match=r".* has no attribute validity_df"
+ ):
+ validate_route_type_warnings(newp_gtfs_fixture)
+
+ def test_validate_route_type_warnings_on_pass(self, newp_gtfs_fixture):
+ """Tests for validate_route_type_warnings on pass."""
+ newp_gtfs_fixture.is_valid(validators={"core_validation": None})
+ route_errors = _get_validation_warnings(
+ newp_gtfs_fixture, message="Invalid route_type"
+ )
+ assert len(route_errors) == 1, "No route_type warnings found"
+ # clean the route_type errors
+ newp_gtfs_fixture.is_valid()
+ validate_route_type_warnings(newp_gtfs_fixture)
+ new_route_errors = _get_validation_warnings(
+ newp_gtfs_fixture, message="Invalid route_type"
+ )
+ assert (
+ len(new_route_errors) == 0
+ ), "Found route_type warnings after cleaning"
+
+ def test_validate_route_type_warnings_creates_warnings(
+ self, newp_gtfs_fixture
+ ):
+ """Test validate_route_type_warnings re-raises route_type warnings."""
+ newp_gtfs_fixture.feed.routes[
+ "route_type"
+ ] = newp_gtfs_fixture.feed.routes["route_type"].apply(
+ lambda x: 310030 if x == 200 else 200
+ )
+ newp_gtfs_fixture.is_valid(
+ {"core_validation": None, "validate_route_type_warnings": None}
+ )
+ new_route_errors = _get_validation_warnings(
+ newp_gtfs_fixture, message="Invalid route_type"
+ )
+ assert (
+ len(new_route_errors) == 1
+ ), "route_type warnings not found after cleaning"
+
+
+@pytest.fixture(scope="function")
+def create_test_zip(tmp_path) -> pathlib.Path:
+ """Create a gtfs zip with invalid files."""
+ gtfs_pth = here("tests/data/chester-20230816-small_gtfs.zip")
+ # create dir for unzipped gtfs contents
+ gtfs_contents_pth = os.path.join(tmp_path, "gtfs_contents")
+ os.mkdir(gtfs_contents_pth)
+ # extract unzipped gtfs file to new dir
+ with zipfile.ZipFile(gtfs_pth, "r") as gtfs_zip:
+ gtfs_zip.extractall(gtfs_contents_pth)
+ # write some dummy files with test cases
+ with open(os.path.join(gtfs_contents_pth, "not_in_spec.txt"), "w") as f:
+ f.write("test_date")
+ with open(os.path.join(gtfs_contents_pth, "routes.invalid"), "w") as f:
+ f.write("test_date")
+ # zip contents
+ new_zip_pth = os.path.join(tmp_path, "gtfs_zip")
+ shutil.make_archive(new_zip_pth, "zip", gtfs_contents_pth)
+ full_zip_pth = pathlib.Path(new_zip_pth + ".zip")
+ return full_zip_pth
+
+
+class TestValidateGtfsFile(object):
+ """Tests for validate_gtfs_files."""
+
+ def test_validate_gtfs_files_defence(self):
+ """Defensive tests for validate_gtfs_files."""
+ with pytest.raises(
+ TypeError, match="'gtfs' expected a GtfsInstance object.*"
+ ):
+ validate_gtfs_files(False)
+
+ def test_validate_gtfs_files_on_pass(self, create_test_zip):
+ """General tests for validte_gtfs_files."""
+ gtfs = GtfsInstance(create_test_zip)
+ gtfs.is_valid(validators={"core_validation": None})
+ validate_gtfs_files(gtfs)
+ # tests for invalid extensions
+ warnings = _get_validation_warnings(
+ gtfs, r".*files not of type .*txt.*", return_type="dataframe"
+ )
+ assert (
+ len(warnings) == 1
+ ), "More warnings than expected for invalid extension"
+ assert warnings.loc[3]["message"] == (
+ "GTFS zip includes files not of type '.txt'. These files include "
+ "['routes.invalid']"
+ ), "Warnings not appearing as expected"
+
+ # tests for unrecognised files
+ warnings = _get_validation_warnings(
+ gtfs,
+ r".*files that aren't recognised by the GTFS.*",
+ return_type="dataframe",
+ )
+ assert (
+ len(warnings) == 1
+ ), "More warnings than expected for not implemented tables"
+ assert warnings.loc[4]["message"] == (
+ "GTFS zip includes files that aren't recognised by the GTFS "
+ "spec. These include ['not_in_spec.txt']"
+ )
diff --git a/tests/utils/test_defence.py b/tests/utils/test_defence.py
index 3dd0d1fe..06dfb212 100644
--- a/tests/utils/test_defence.py
+++ b/tests/utils/test_defence.py
@@ -8,6 +8,7 @@
from _pytest.python_api import RaisesContext
import pandas as pd
from pyprojroot import here
+from contextlib import nullcontext as does_not_raise
from transport_performance.utils.defence import (
_check_iterable,
@@ -21,6 +22,7 @@
_is_expected_filetype,
_enforce_file_extension,
)
+from transport_performance.gtfs.validation import GtfsInstance
class Test_CheckIter(object):
@@ -244,6 +246,7 @@ def test_check_parents_dir_exists(self, tmp_path):
def test__gtfs_defence():
"""Tests for _gtfs_defence()."""
+ # defensive tests
with pytest.raises(
TypeError,
match=re.escape(
@@ -251,6 +254,10 @@ def test__gtfs_defence():
),
):
_gtfs_defence("tester", "test")
+ # passing test
+ with does_not_raise():
+ gtfs = GtfsInstance(here("tests/data/chester-20230816-small_gtfs.zip"))
+ _gtfs_defence(gtfs=gtfs, param_nm="gtfs")
class Test_TypeDefence(object):