diff --git a/.gitignore b/.gitignore index 90fae26f..509aea12 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ !tests/data/small-osm.pbf !tests/data/chester-20230816-small_gtfs.zip !tests/data/gtfs/newport-20230613_gtfs.zip +!tests/data/gtfs/repeated_pair_gtfs_fixture.zip !src/transport_performance/data/gtfs/route_lookup.pkl !tests/data/gtfs/report/html_template.html !tests/data/metrics/mock_centroid_gdf.pkl diff --git a/src/transport_performance/gtfs/cleaners.py b/src/transport_performance/gtfs/cleaners.py index 3d4d7157..863f67cf 100644 --- a/src/transport_performance/gtfs/cleaners.py +++ b/src/transport_performance/gtfs/cleaners.py @@ -1,9 +1,26 @@ """A set of functions that clean the gtfs data.""" from typing import Union +import warnings import numpy as np - -from transport_performance.utils.defence import _gtfs_defence, _check_iterable +import pandas as pd +from gtfs_kit.cleaners import ( + clean_ids as clean_ids_gk, + clean_route_short_names as clean_route_short_names_gk, + clean_times as clean_times_gk, + drop_zombies as drop_zombies_gk, +) + +from transport_performance.gtfs.gtfs_utils import ( + _get_validation_warnings, + _remove_validation_row, +) +from transport_performance.utils.defence import ( + _gtfs_defence, + _check_iterable, + _type_defence, + _check_attribute, +) def drop_trips(gtfs, trip_id: Union[str, list, np.ndarray]) -> None: @@ -32,19 +49,21 @@ def drop_trips(gtfs, trip_id: Union[str, list, np.ndarray]) -> None: if isinstance(trip_id, str): trip_id = [trip_id] - # _check_iterable only takes lists, therefore convert numpy arrays - if isinstance(trip_id, np.ndarray): - trip_id = list(trip_id) - # ensure trip ids are string _check_iterable( iterable=trip_id, param_nm="trip_id", - iterable_type=list, + iterable_type=type(trip_id), check_elements=True, exp_type=str, ) + # warn users if passed one of the passed trip_id's is not present in the + # GTFS. 
+ for _id in trip_id: + if _id not in gtfs.feed.trips.trip_id.unique(): + warnings.warn(UserWarning(f"trip_id '{_id}' not found in GTFS")) + # drop relevant records from tables gtfs.feed.trips = gtfs.feed.trips[ ~gtfs.feed.trips["trip_id"].isin(trip_id) @@ -61,44 +80,58 @@ def drop_trips(gtfs, trip_id: Union[str, list, np.ndarray]) -> None: return None -def clean_consecutive_stop_fast_travel_warnings( - gtfs, validate: bool = False -) -> None: - """Clean 'Fast Travel Between Consecutive Stops' warnings from validity_df. +def _clean_fast_travel_preparation(gtfs, warning_re: str) -> pd.DataFrame: + """Prepare to clean fast travel errors. + + At the beggining of both of the fast travel cleaners, the gtfs is type + checked, attr checked and then warnings are obtained. Because of this, this + has been functionalised Parameters ---------- - gtfs : GtfsInstance - The GtfsInstance to clean warnings within - validate : bool, optional - Whether or not to validate the gtfs before carrying out this cleaning - operation + gtfs : _type_ + The GtfsInstance. + warning_re : str + Regex used to obtain warnings. Returns ------- - None + pd.DataFrame + A dataframe containing warnings. """ - # defences _gtfs_defence(gtfs, "gtfs") - if "validity_df" not in gtfs.__dict__.keys() and not validate: - raise AttributeError( + _type_defence(warning_re, "warning_re", str) + _check_attribute( + gtfs, + "validity_df", + message=( "The gtfs has not been validated, therefore no" "warnings can be identified. You can pass " "validate=True to this function to validate the " "gtfs." - ) + ), + ) + needed_warning = _get_validation_warnings(gtfs, warning_re) + return needed_warning + - if validate: - gtfs.is_valid() +def clean_consecutive_stop_fast_travel_warnings(gtfs) -> None: + """Clean 'Fast Travel Between Consecutive Stops' warnings from validity_df. 
- needed_warning = ( - gtfs.validity_df[ - gtfs.validity_df["message"] - == "Fast Travel Between Consecutive Stops" - ] - .copy() - .values + Parameters + ---------- + gtfs : GtfsInstance + The GtfsInstance to clean warnings within + + Returns + ------- + None + + """ + # defences + needed_warning = _clean_fast_travel_preparation( + gtfs, "Fast Travel Between Consecutive Stops" ) if len(needed_warning) < 1: @@ -116,45 +149,22 @@ def clean_consecutive_stop_fast_travel_warnings( return None -def clean_multiple_stop_fast_travel_warnings( - gtfs, validate: bool = False -) -> None: +def clean_multiple_stop_fast_travel_warnings(gtfs) -> None: """Clean 'Fast Travel Over Multiple Stops' warnings from validity_df. Parameters ---------- gtfs : GtfsInstance The GtfsInstance to clean warnings within - validate : bool, optional - Whether or not to validate the gtfs before carrying out this cleaning - operation Returns ------- None """ - # defences - _gtfs_defence(gtfs, "gtfs") - if "validity_df" not in gtfs.__dict__.keys() and not validate: - raise AttributeError( - "The gtfs has not been validated, therefore no" - "warnings can be identified. You can pass " - "validate=True to this function to validate the " - "gtfs." - ) - - if validate: - gtfs.is_valid() - - needed_warning = ( - gtfs.validity_df[ - gtfs.validity_df["message"] == "Fast Travel Over Multiple Stops" - ] - .copy() - .values + needed_warning = _clean_fast_travel_preparation( + gtfs, "Fast Travel Over Multiple Stops" ) - if len(needed_warning) < 1: return None @@ -168,3 +178,122 @@ def clean_multiple_stop_fast_travel_warnings( ~gtfs.multiple_stops_invalid["trip_id"].isin(trip_ids) ] return None + + +def core_cleaners( + gtfs, + clean_ids: bool = True, + clean_times: bool = True, + clean_route_short_names: bool = True, + drop_zombies: bool = True, +) -> None: + """Clean the gtfs with the core cleaners of gtfs-kit. 
+ + The source code for the cleaners, along with detailed descriptions of the + cleaning they are performing can be found here: + https://github.com/mrcagney/gtfs_kit/blob/master/gtfs_kit/cleaners.py + + All credit for these cleaners goes to the creators of the gtfs_kit package. + HOMEPAGE: https://github.com/mrcagney/gtfs_kit + + Parameters + ---------- + gtfs : GtfsInstance + The gtfs to clean + clean_ids : bool, optional + Whether or not to use clean_ids, by default True + clean_times : bool, optional + Whether or not to use clean_times, by default True + clean_route_short_names : bool, optional + Whether or not to use clean_route_short_names, by default True + drop_zombies : bool, optional + Whether or not to use drop_zombies, by default True + + Returns + ------- + None + + """ + # defences + _gtfs_defence(gtfs, "gtfs") + _type_defence(clean_ids, "clean_ids", bool) + _type_defence(clean_times, "clean_times", bool) + _type_defence(clean_route_short_names, "clean_route_short_names", bool) + _type_defence(drop_zombies, "drop_zombies", bool) + # cleaning + if clean_ids: + clean_ids_gk(gtfs.feed) + if clean_times: + clean_times_gk(gtfs.feed) + if clean_route_short_names: + clean_route_short_names_gk(gtfs.feed) + if drop_zombies: + try: + drop_zombies_gk(gtfs.feed) + except KeyError: + warnings.warn( + UserWarning( + "The drop_zombies cleaner was unable to operate on " + "clean_feed as the trips table has no shape_id column" + ) + ) + return None + + +def clean_unrecognised_column_warnings(gtfs) -> None: + """Clean warnings for unrecognised columns. 
+ + Parameters + ---------- + gtfs : GtfsInstance + The GtfsInstance to clean warnings from + + Returns + ------- + None + + """ + _gtfs_defence(gtfs, "gtfs") + warnings = _get_validation_warnings( + gtfs=gtfs, message="Unrecognized column .*" + ) + for warning in warnings: + tbl = gtfs.table_map[warning[2]] + # parse column from warning message + column = warning[1].split("column")[1].strip() + tbl.drop(column, inplace=True, axis=1) + _remove_validation_row(gtfs, warning[1]) + return None + + +def clean_duplicate_stop_times(gtfs) -> None: + """Clean duplicates from stop_times with repeated pair (trip_id, ... + + departure_time. + + Parameters + ---------- + gtfs : GtfsInstance + The gtfs to clean + + Returns + ------- + None + + """ + _gtfs_defence(gtfs, "gtfs") + warning_re = r".* \(trip_id, departure_time\)" + # we are only expecting one warning here + warning = _get_validation_warnings(gtfs, warning_re) + if len(warning) == 0: + return None + warning = warning[0] + # drop from actual table + gtfs.table_map[warning[2]].drop_duplicates( + subset=["arrival_time", "departure_time", "trip_id", "stop_id"], + inplace=True, + ) + _remove_validation_row(gtfs, message=warning_re) + # re-validate with gtfs-kit validator + gtfs.is_valid({"core_validation": None}) + return None diff --git a/src/transport_performance/gtfs/gtfs_utils.py b/src/transport_performance/gtfs/gtfs_utils.py index 100d2924..60fffdad 100644 --- a/src/transport_performance/gtfs/gtfs_utils.py +++ b/src/transport_performance/gtfs/gtfs_utils.py @@ -5,11 +5,11 @@ from pyprojroot import here import pandas as pd import os -import math import plotly.graph_objects as go from typing import Union, TYPE_CHECKING import pathlib from geopandas import GeoDataFrame +import numpy as np import warnings if TYPE_CHECKING: @@ -19,9 +19,12 @@ _is_expected_filetype, _check_iterable, _type_defence, + _check_attribute, + _gtfs_defence, _validate_datestring, _enforce_file_extension, - _gtfs_defence, + _check_parent_dir_exists, 
+ _check_item_in_iter, ) from transport_performance.utils.constants import PKG_PATH @@ -284,14 +287,20 @@ def _add_validation_row( An error is raised if the validity df does not exist """ - # TODO: add dtype defences from defence.py once gtfs-html-new is merged - if "validity_df" not in gtfs.__dict__.keys(): - raise AttributeError( + _gtfs_defence(gtfs, "gtfs") + _type_defence(_type, "_type", str) + _type_defence(message, "message", str) + _type_defence(rows, "rows", list) + _check_attribute( + gtfs, + "validity_df", + message=( "The validity_df does not exist as an " "attribute of your GtfsInstance object, \n" "Did you forget to run the .is_valid() method?" - ) - + ), + ) + _check_item_in_iter(_type, ["warning", "error"], "_type") temp_df = pd.DataFrame( { "type": [_type], @@ -311,7 +320,6 @@ def filter_gtfs_around_trip( gtfs, trip_id: str, buffer_dist: int = 10000, - units: str = "m", crs: str = "27700", out_pth=os.path.join("data", "external", "trip_gtfs.zip"), ) -> None: @@ -325,8 +333,6 @@ def filter_gtfs_around_trip( The trip ID buffer_dist : int, optional The distance to create a buffer around the trip, by default 10000 - units : str, optional - Distance units of the original GTFS, by default "m" crs : str, optional The CRS to use for adding a buffer, by default "27700" out_pth : _type_, optional @@ -343,21 +349,21 @@ def filter_gtfs_around_trip( An error is raised if a shapeID is not available """ - # TODO: Add datatype defences once merged + _gtfs_defence(gtfs, "gtfs") + _type_defence(trip_id, "trip_id", str) + _type_defence(buffer_dist, "buffer_dist", int) + _type_defence(crs, "crs", str) + _check_parent_dir_exists(out_pth, "out_pth", create=True) trips = gtfs.feed.trips shapes = gtfs.feed.shapes shape_id = list(trips[trips["trip_id"] == trip_id]["shape_id"])[0] # defence - # try/except for math.isnan() returning TypeError for strings - try: - if math.isnan(shape_id): - raise ValueError( - "'shape_id' not available for trip with trip_id: " f"{trip_id}" - 
) - except TypeError: - pass + if pd.isna(shape_id): + raise ValueError( + "'shape_id' not available for trip with trip_id: " f"{trip_id}" + ) # create a buffer around the trip trip_shape = shapes[shapes["shape_id"] == shape_id] @@ -376,7 +382,7 @@ def filter_gtfs_around_trip( in_pth=gtfs.gtfs_path, bbox=list(bbox), crs=crs, - units=units, + units=gtfs.units, out_pth=out_pth, ) @@ -465,3 +471,143 @@ def convert_pandas_to_plotly( if return_html: return fig.to_html(full_html=False) return fig + + +def _get_validation_warnings( + gtfs, message: str, return_type: str = "values" +) -> pd.DataFrame: + """Get warnings from the validity_df table based on a regex. + + Parameters + ---------- + gtfs : GtfsInstance() + The gtfs instance to obtain the warnings from. + message : str + The regex to use for filtering the warnings. + return_type : str, optional + The return type of the warnings. Can be either 'values' or 'dataframe', + by default 'values' + + Returns + ------- + pd.DataFrame + A dataframe containing all warnings matching the regex. + + """ + _gtfs_defence(gtfs, "gtfs") + _check_attribute( + gtfs, + "validity_df", + message=( + "The gtfs has not been validated, therefore no" + "warnings can be identified." + ), + ) + _type_defence(message, "message", str) + return_type = return_type.lower().strip() + if return_type not in ["values", "dataframe"]: + raise ValueError( + "'return_type' expected one of ['values', 'dataframe]" + f". Got {return_type}" + ) + needed_warnings = gtfs.validity_df[ + gtfs.validity_df["message"].str.contains(message, regex=True, na=False) + ].copy() + if return_type == "dataframe": + return needed_warnings + return needed_warnings.values + + +def _remove_validation_row( + gtfs, message: str = None, index: Union[list, np.array] = None +) -> None: + """Remove rows from the 'validity_df' attr of a GtfsInstance(). + + Both a regex on messages and index locations can be used to drop drops. 
If + values are passed to both the 'index' and 'message' params, rows will be + removed using the regex passed to the 'message' param. + + Parameters + ---------- + gtfs : GtfsInstance + The GtfsInstance to remove the warnings/errors from. + message : str, optional + The regex to filter messages by, by default None. + index : Union[list, np.array], optional + The index locations of rows to be removed, by default None. + + Returns + ------- + None + + Raises + ------ + ValueError + Error raised if both 'message' and 'index' params are None. + UserWarning + A warning is raised if both 'message' and 'index' params are not None. + + """ + # defences + _gtfs_defence(gtfs, "gtfs") + _type_defence(message, "message", (str, type(None))) + _type_defence(index, "index", (list, np.ndarray, type(None))) + _check_attribute(gtfs, "validity_df") + if message is None and index is None: + raise ValueError( + "Both 'message' and 'index' are None, therefore no" + "warnings/errors are able to be cleaned." + ) + if message is not None and index is not None: + warnings.warn( + UserWarning( + "Both 'index' and 'message' are not None. Warnings/" + "Errors have been cleaned on 'message'" + ) + ) + # remove row from validation table + if message is not None: + gtfs.validity_df = gtfs.validity_df[ + ~gtfs.validity_df.message.str.contains( + message, regex=True, na=False + ) + ] + return None + + gtfs.validity_df = gtfs.validity_df.loc[ + list(set(gtfs.validity_df.index) - set(index)) + ] + return None + + +def _function_pipeline( + gtfs, func_map: dict, operations: Union[dict, type(None)] +) -> None: + """Iterate through and act on a functional pipeline.""" + _gtfs_defence(gtfs, "gtfs") + _type_defence(func_map, "func_map", dict) + _type_defence(operations, "operations", (dict, type(None))) + if operations: + for key in operations.keys(): + if key not in func_map.keys(): + raise KeyError( + f"'{key}' function passed to 'operations' is not a " + "known operation. 
Known operation include: " + f"{func_map.keys()}" + ) + for operation in operations: + # check value is dict or none (for kwargs) + _type_defence( + operations[operation], + f"operations[{operation}]", + (dict, type(None)), + ) + operations[operation] = ( + {} if operations[operation] is None else operations[operation] + ) + func_map[operation](gtfs=gtfs, **operations[operation]) + # if no operations passed, carry out all operations + else: + for operation in func_map: + func_map[operation](gtfs=gtfs) + return None diff --git a/src/transport_performance/gtfs/validation.py b/src/transport_performance/gtfs/validation.py index dcad1e72..3ce69e94 100644 --- a/src/transport_performance/gtfs/validation.py +++ b/src/transport_performance/gtfs/validation.py @@ -17,19 +17,15 @@ from plotly.graph_objects import Figure as PlotlyFigure from geopandas import GeoDataFrame -from transport_performance.gtfs.validators import ( - validate_travel_over_multiple_stops, - validate_travel_between_consecutive_stops, -) +import transport_performance.gtfs.cleaners as cleaners +import transport_performance.gtfs.validators as gtfs_validators from transport_performance.gtfs.calendar import create_calendar_from_dates -from transport_performance.gtfs.cleaners import ( - clean_consecutive_stop_fast_travel_warnings, - clean_multiple_stop_fast_travel_warnings, -) from transport_performance.gtfs.routes import ( scrape_route_type_lookup, get_saved_route_type_lookup, ) + +from transport_performance.gtfs.gtfs_utils import _function_pipeline from transport_performance.utils.defence import ( _is_expected_filetype, _check_namespace_export, @@ -40,15 +36,44 @@ _check_attribute, _enforce_file_extension, ) - from transport_performance.gtfs.report.report_utils import ( TemplateHTML, _set_up_report_dir, ) - +from transport_performance.utils.constants import ( + PKG_PATH, +) from transport_performance.gtfs.gtfs_utils import filter_gtfs -from transport_performance.utils.constants import PKG_PATH +# THESE 
MAPPINGS CAN NOT BE MOVED TO CONSTANTS AS THEY INTRODUCE DEPENDENCY +# ISSUES. +CLEAN_FEED_FUNCTION_MAP = { + "core_cleaners": cleaners.core_cleaners, + "clean_unrecognised_column_warnings": ( + cleaners.clean_unrecognised_column_warnings + ), + "clean_duplicate_stop_times": (cleaners.clean_duplicate_stop_times), + "clean_consecutive_stop_fast_travel_warnings": ( + cleaners.clean_consecutive_stop_fast_travel_warnings + ), + "clean_multiple_stop_fast_travel_warnings": ( + cleaners.clean_multiple_stop_fast_travel_warnings + ), +} + +VALIDATE_FEED_FUNC_MAP = { + "core_validation": gtfs_validators.core_validation, + "validate_gtfs_files": gtfs_validators.validate_gtfs_files, + "validate_travel_between_consecutive_stops": ( + gtfs_validators.validate_travel_between_consecutive_stops + ), + "validate_travel_over_multiple_stops": ( + gtfs_validators.validate_travel_over_multiple_stops + ), + "validate_route_type_warnings": ( + gtfs_validators.validate_route_type_warnings + ), +} def _get_intermediate_dates( @@ -182,6 +207,8 @@ class GtfsInstance: A gtfs_kit feed produced using the files at `gtfs_pth` on init. gtfs_path : Union[str, pathlib.Path] The path to the GTFS archive. + units : str + Spatial units of the GTFS file, defaults to "km". file_list: list Files in the GTFS archive. validity_df: pd.DataFrame @@ -283,7 +310,7 @@ def __init__( accepted_units = ["m", "km"] _check_item_in_iter(units, accepted_units, "units") - + self.units = units self.feed = gk.read_feed(gtfs_pth, dist_units=units) self.gtfs_path = gtfs_pth if route_lookup_pth is not None: @@ -329,6 +356,16 @@ def __init__( "shapes": [], } + # CONSTANT TO LINK TO GTFS TABLES USING STRINGS + self.table_map = { + "agency": self.feed.agency, + "routes": self.feed.routes, + "stop_times": self.feed.stop_times, + "stops": self.feed.stops, + "trips": self.feed.trips, + "calendar": self.feed.calendar, + } + def ensure_populated_calendar(self) -> None: """If calendar is absent, creates one from calendar_dates. 
@@ -376,14 +413,13 @@ def get_gtfs_files(self) -> list: self.file_list = file_list return self.file_list - def is_valid(self, far_stops: bool = True) -> pd.DataFrame: + def is_valid(self, validators: dict = None) -> pd.DataFrame: """Check a feed is valid with `gtfs_kit`. Parameters ---------- - far_stops : bool, optional - Whether or not to perform validation for far stops (both - between consecutive stops and over multiple stops) + validators : dict, optional + A dictionary of function name to kwargs mappings. Returns ------- @@ -391,10 +427,14 @@ def is_valid(self, far_stops: bool = True) -> pd.DataFrame: Table of errors, warnings & their descriptions. """ - self.validity_df = self.feed.validate() - if far_stops: - validate_travel_between_consecutive_stops(self) - validate_travel_over_multiple_stops(self) + _type_defence(validators, "validators", (dict, type(None))) + # create validity df + self.validity_df = pd.DataFrame( + columns=["type", "message", "table", "rows"] + ) + _function_pipeline( + gtfs=self, func_map=VALIDATE_FEED_FUNC_MAP, operations=validators + ) return self.validity_df def print_alerts(self, alert_type: str = "error") -> None: @@ -445,34 +485,27 @@ def print_alerts(self, alert_type: str = "error") -> None: return None - def clean_feed( - self, validate: bool = False, fast_travel: bool = True - ) -> None: - """Attempt to clean feed using `gtfs_kit`. + def clean_feed(self, cleansers: dict = None) -> None: + """Clean the gtfs feed. Parameters ---------- - validate: bool, optional - Whether or not to validate the dataframe before cleaning - fast_travel: bool, optional - Whether or not to clean warnings related to fast travel. 
+ cleansers : dict, optional + A mapping of cleansing functions and kwargs, by default None + + Returns + ------- + None """ - _type_defence(fast_travel, "fast_travel", bool) - _type_defence(validate, "valiidate", bool) - if validate: - self.is_valid(far_stops=fast_travel) - try: - # In cases where shape_id is missing, keyerror is raised. - # https://developers.google.com/transit/gtfs/reference#shapestxt - # shows that shapes.txt is optional file. - self.feed = self.feed.clean() - if fast_travel: - clean_consecutive_stop_fast_travel_warnings(self) - clean_multiple_stop_fast_travel_warnings(self) - except KeyError: - # TODO: Issue 74 - Improve this to clean feed when KeyError raised - print("KeyError. Feed was not cleaned.") + # DEV NOTE: Opting not to allow for validation in clean_feed(). + # .is_valid() should be used before hand. + # DEV NOTE 2: Use of param name 'cleansers' is to avoid conflicts + _type_defence(cleansers, "cleansers", (dict, type(None))) + _function_pipeline( + gtfs=self, func_map=CLEAN_FEED_FUNCTION_MAP, operations=cleansers + ) + return None def _produce_stops_map( self, what_geoms: str, is_filtered: bool, crs: Union[int, str] @@ -1317,15 +1350,6 @@ def _extended_validation( None """ - table_map = { - "agency": self.feed.agency, - "routes": self.feed.routes, - "stop_times": self.feed.stop_times, - "stops": self.feed.stops, - "trips": self.feed.trips, - "calendar": self.feed.calendar, - } - # determine which errors/warnings have rows that can be located validation_table = self.is_valid() validation_table["valid_row"] = validation_table["rows"].apply( @@ -1354,7 +1378,9 @@ def _extended_validation( for col in self.GTFS_UNNEEDED_COLUMNS[table] if col not in join_vars ] - filtered_tbl = table_map[table].copy().drop(drop_cols, axis=1) + filtered_tbl = ( + self.table_map[table].copy().drop(drop_cols, axis=1) + ) impacted_rows = self._create_extended_repeated_pair_table( table=filtered_tbl, join_vars=join_vars, @@ -1372,7 +1398,7 @@ def 
_extended_validation( == impacted_rows[f"{col}_duplicate"] ].shape[0] else: - impacted_rows = table_map[table].copy().iloc[rows] + impacted_rows = self.table_map[table].copy().iloc[rows] # create the html to display the impacted rows (clean possibly) table_html = f""" @@ -1486,9 +1512,11 @@ def html_report( date = datetime.datetime.strftime(datetime.datetime.now(), "%d-%m-%Y") # feed evaluation - self.clean_feed(validate=True, fast_travel=True) + # TODO: make this optional (and allow params) + self.is_valid() + self.clean_feed() # re-validate to clean any newly raised errors/warnings - validation_dataframe = self.is_valid(far_stops=True) + validation_dataframe = self.is_valid() # create extended reports if requested if extended_validation: @@ -1503,7 +1531,7 @@ def html_report( ) validation_dataframe["info"] = [ f""" Further Info""" - if len(rows) > 1 + if len(rows) > 0 else "Unavailable" for href, rows in zip(info_href, validation_dataframe["rows"]) ] diff --git a/src/transport_performance/gtfs/validators.py b/src/transport_performance/gtfs/validators.py index 6ad9f8e5..cc0c441e 100644 --- a/src/transport_performance/gtfs/validators.py +++ b/src/transport_performance/gtfs/validators.py @@ -1,12 +1,17 @@ """A set of functions that validate the GTFS data.""" from typing import TYPE_CHECKING +import os import numpy as np import pandas as pd from haversine import Unit, haversine_vector -from transport_performance.gtfs.gtfs_utils import _add_validation_row -from transport_performance.utils.defence import _gtfs_defence +from transport_performance.gtfs.gtfs_utils import ( + _add_validation_row, + _get_validation_warnings, + _remove_validation_row, +) +from transport_performance.utils.defence import _gtfs_defence, _check_attribute if TYPE_CHECKING: from transport_performance.gtfs.validation import GtfsInstance @@ -26,6 +31,28 @@ 200: 120, } +# EXPECTED GTFS FILES +# DESC: USED TO RAISE WARNINGS WHEN .TXT FILES INCLUDED IN THE ZIP PASSED TO +# GTFSINSTANCE AREN'T A 
RECOGNISED GTFS TABLE# +# NOTE: 'Levels' and 'Translation' tables are ignored by gtfs-kit + +ACCEPTED_GTFS_TABLES = [ + "agency", + "attributions", + "calendar", + "calendar_dates", + "fare_attributes", + "fare_rules", + "feed_info", + "frequencies", + "routes", + "shapes", + "stops", + "stop_times", + "transfers", + "trips", +] + def validate_travel_between_consecutive_stops(gtfs: "GtfsInstance"): """Validate the travel between consecutive stops in the GTFS data. @@ -145,6 +172,8 @@ def _join_max_speed(r_type: int) -> int: ) gtfs.full_stop_schedule = stop_sched + gtfs.table_map["full_stop_schedule"] = gtfs.full_stop_schedule + # find the stops that exceed the speed boundary invalid_stops = stop_sched[stop_sched["speed"] > stop_sched["speed_bound"]] @@ -153,7 +182,6 @@ def _join_max_speed(r_type: int) -> int: return invalid_stops # add the error to the validation table - # TODO: After merge add full_stop_schedule to HTML output table keys _add_validation_row( gtfs=gtfs, _type="warning", @@ -252,9 +280,8 @@ def validate_travel_over_multiple_stops(gtfs: "GtfsInstance") -> None: } ) - # TODO: Add this table to the lookup once gtfs HTML is merged gtfs.multiple_stops_invalid = far_stops_df - + gtfs.table_map["multiple_stops_invalid"] = gtfs.multiple_stops_invalid if len(gtfs.multiple_stops_invalid) > 0: _add_validation_row( gtfs=gtfs, @@ -265,3 +292,98 @@ def validate_travel_over_multiple_stops(gtfs: "GtfsInstance") -> None: ) return far_stops_df + + +def validate_route_type_warnings(gtfs: "GtfsInstance") -> None: + """Valiidate that the route type warnings are reasonable and just. + + Parameters + ---------- + gtfs : GtfsInstance + The GtfsInstance to validate the warnings of. 
+ + Returns + ------- + None + + """ + # defences + _gtfs_defence(gtfs, "gtfs") + _check_attribute(gtfs, "validity_df") + # identify and clean warnings + warnings = _get_validation_warnings(gtfs, "Invalid route_type.*") + if len(warnings) <= 0: + return None + route_rows = gtfs.feed.routes.loc[warnings[0][3]].copy() + route_rows = route_rows[ + ~route_rows.route_type.astype("str").isin( + gtfs.ROUTE_LKP["route_type"].unique() + ) + ] + _remove_validation_row(gtfs, "Invalid route_type.*") + if len(route_rows) > 0: + _add_validation_row( + gtfs, + _type="error", + message="Invalid route_type; maybe has extra space characters", + table="routes", + rows=list(route_rows.index), + ) + return None + + +def core_validation(gtfs: "GtfsInstance"): + """Carry out the main validators of gtfs-kit.""" + _gtfs_defence(gtfs, "gtfs") + validation_df = gtfs.feed.validate() + gtfs.validity_df = pd.concat( + [validation_df, gtfs.validity_df], axis=0 + ).reset_index(drop=True) + + +def validate_gtfs_files(gtfs: "GtfsInstance") -> None: + """Validate to raise warnings if tables in GTFS zip aren't being read. + + Parameters + ---------- + gtfs : GtfsInstance + The gtfs instance to run the validation on. + + Returns + ------- + None + + """ + _gtfs_defence(gtfs, "gtfs") + files = gtfs.get_gtfs_files() + invalid_extensions = [] + non_implemented_files = [] + for fname in files: + root, ext = os.path.splitext(fname) + # check for instances of invalid file extensions (only .txt accepted) + if ext.lower() != ".txt": + invalid_extensions.append(fname) + # check that each file is within the GTFS specification. + # NOTE: The 'levels' and 'translation' tables are in the GTFS spec + # but aren't used by gtfs-kit. + # GTFS REFERENCE: https://developers.google.com/transit/gtfs/reference + if root not in ACCEPTED_GTFS_TABLES: + non_implemented_files.append(fname) + # raise warnings + if len(invalid_extensions) > 0: + msg = ( + "GTFS zip includes files not of type '.txt'. 
These files " + f"include {invalid_extensions}" + ) + _add_validation_row( + gtfs, _type="warning", message=msg, table="", rows=[] + ) + if len(non_implemented_files) > 0: + msg = ( + "GTFS zip includes files that aren't recognised by the GTFS " + f"spec. These include {non_implemented_files}" + ) + _add_validation_row( + gtfs, _type="warning", message=msg, table="", rows=[] + ) + return None diff --git a/src/transport_performance/utils/constants.py b/src/transport_performance/utils/constants.py index 97b98598..a1521f2b 100644 --- a/src/transport_performance/utils/constants.py +++ b/src/transport_performance/utils/constants.py @@ -1,5 +1,6 @@ """Constants to be used throughout the transport-performance package.""" +from importlib import resources as pkg_resources # + import transport_performance -from importlib import resources as pkg_resources PKG_PATH = pkg_resources.files(transport_performance) diff --git a/tests/data/gtfs/repeated_pair_gtfs_fixture.zip b/tests/data/gtfs/repeated_pair_gtfs_fixture.zip new file mode 100644 index 00000000..8c723d01 Binary files /dev/null and b/tests/data/gtfs/repeated_pair_gtfs_fixture.zip differ diff --git a/tests/gtfs/test_cleaners.py b/tests/gtfs/test_cleaners.py index b94b1cf9..262aea4f 100644 --- a/tests/gtfs/test_cleaners.py +++ b/tests/gtfs/test_cleaners.py @@ -10,6 +10,13 @@ drop_trips, clean_consecutive_stop_fast_travel_warnings, clean_multiple_stop_fast_travel_warnings, + core_cleaners, + clean_unrecognised_column_warnings, + clean_duplicate_stop_times, +) +from transport_performance.gtfs.gtfs_utils import ( + _get_validation_warnings, + _remove_validation_row, ) @@ -64,6 +71,10 @@ def test_drop_trips_defences(self, gtfs_fixture): ] assert len(found_df) == 0, "Failed to drop trip in format 'string'" + # test dropping non existent trip + with pytest.warns(UserWarning, match="trip_id .* not found in GTFS"): + drop_trips(gtfs_fixture, ["NOT_AT_TRIP_ID"]) + def test_drop_trips_on_pass(self, gtfs_fixture): """General tests 
for drop_trips().""" fixture = gtfs_fixture @@ -117,7 +128,7 @@ class Test_CleanConsecutiveStopFastTravelWarnings(object): def test_clean_consecutive_stop_fast_travel_warnings_defence( self, gtfs_fixture ): - """Defensive tests forclean_consecutive_stop_fast_travel_warnings().""" + """Defensive tests for clean_consecutive_stop_fast_travel_warnings.""" with pytest.raises( AttributeError, match=re.escape( @@ -127,15 +138,13 @@ def test_clean_consecutive_stop_fast_travel_warnings_defence( "gtfs." ), ): - clean_consecutive_stop_fast_travel_warnings( - gtfs=gtfs_fixture, validate=False - ) + clean_consecutive_stop_fast_travel_warnings(gtfs=gtfs_fixture) def test_clean_consecutive_stop_fast_travel_warnings_on_pass( self, gtfs_fixture ): """General tests for clean_consecutive_stop_fast_travel_warnings().""" - gtfs_fixture.is_valid(far_stops=True) + gtfs_fixture.is_valid() original_validation = { "type": { 0: "warning", @@ -180,18 +189,15 @@ def test_clean_consecutive_stop_fast_travel_warnings_on_pass( assert ( original_validation == gtfs_fixture.validity_df.to_dict() ), "Original validity df is not as expected" - clean_consecutive_stop_fast_travel_warnings( - gtfs=gtfs_fixture, validate=False - ) + clean_consecutive_stop_fast_travel_warnings(gtfs=gtfs_fixture) gtfs_fixture.is_valid() assert expected_validation == gtfs_fixture.validity_df.to_dict(), ( "Validation table is not as expected after cleaning consecutive " "stop fast travel warnings" ) # test validation; test gtfs with no warnings - clean_consecutive_stop_fast_travel_warnings( - gtfs=gtfs_fixture, validate=True - ) + gtfs_fixture.is_valid() + clean_consecutive_stop_fast_travel_warnings(gtfs=gtfs_fixture) class Test_CleanMultipleStopFastTravelWarnings(object): @@ -210,15 +216,13 @@ def test_clean_multiple_stop_fast_travel_warnings_defence( "gtfs." 
), ): - clean_multiple_stop_fast_travel_warnings( - gtfs=gtfs_fixture, validate=False - ) + clean_multiple_stop_fast_travel_warnings(gtfs=gtfs_fixture) def test_clean_multiple_stop_fast_travel_warnings_on_pass( self, gtfs_fixture ): """General tests for clean_multiple_stop_fast_travel_warnings().""" - gtfs_fixture.is_valid(far_stops=True) + gtfs_fixture.is_valid() original_validation = { "type": { 0: "warning", @@ -263,15 +267,211 @@ def test_clean_multiple_stop_fast_travel_warnings_on_pass( assert ( original_validation == gtfs_fixture.validity_df.to_dict() ), "Original validity df is not as expected" - clean_multiple_stop_fast_travel_warnings( - gtfs=gtfs_fixture, validate=False - ) + clean_multiple_stop_fast_travel_warnings(gtfs=gtfs_fixture) gtfs_fixture.is_valid() assert expected_validation == gtfs_fixture.validity_df.to_dict(), ( "Validation table is not as expected after cleaning consecutive " "stop fast travel warnings" ) # test validation; test gtfs with no warnings - clean_multiple_stop_fast_travel_warnings( - gtfs=gtfs_fixture, validate=True - ) + gtfs_fixture.is_valid() + clean_multiple_stop_fast_travel_warnings(gtfs=gtfs_fixture) + + +class TestCoreCleaner(object): + """Tests for core_cleaners(). + + Notes + ----- + There are no passing tests for this function as it relies on function from + gtfs-kit which have already been tested. 
+ + """ + + @pytest.mark.parametrize( + ( + "clean_ids, clean_times, clean_route_short_names, drop_zombies, " + "raises, match" + ), + [ + ( + 1, + True, + True, + True, + TypeError, + r".*expected .*bool.* Got .*int.*", + ), + ( + True, + dict(), + True, + True, + TypeError, + r".*expected .*bool.* Got .*dict.*", + ), + ( + True, + True, + "test string", + True, + TypeError, + r".*expected .*bool.* Got .*str.*", + ), + ( + True, + True, + True, + 2.12, + TypeError, + r".*expected .*bool.* Got .*float.*", + ), + ], + ) + def test_core_claners_defence( + self, + gtfs_fixture, + clean_ids, + clean_times, + clean_route_short_names, + drop_zombies, + raises, + match, + ): + """Defensive tests for core_cleaners.""" + with pytest.raises(raises, match=match): + gtfs_fixture.is_valid() + core_cleaners( + gtfs_fixture, + clean_ids, + clean_times, + clean_route_short_names, + drop_zombies, + ) + + def test_core_cleaners_drop_zombies_warns(self, gtfs_fixture): + """Test that warnings are emitted when shape_id isn't present in... + + trips. 
+ """ + gtfs_fixture.feed.trips.drop("shape_id", axis=1, inplace=True) + with pytest.warns( + UserWarning, + match=r".*drop_zombies cleaner was unable to operate.*", + ): + gtfs_fixture.is_valid(validators={"core_validation": None}) + gtfs_fixture.clean_feed() + + +class TestCleanUnrecognisedColumnWarnings(object): + """Tests for clean_unrecognised_column_warnings.""" + + def test_clean_unrecognised_column_warnings(self, gtfs_fixture): + """Tests for clean_unrecognised_column_warnings.""" + # initial assertions to ensure test data is correct + gtfs_fixture.is_valid(validators={"core_validation": None}) + assert len(gtfs_fixture.validity_df) == 3, "validity_df wrong length" + assert np.array_equal( + gtfs_fixture.feed.trips.columns, + [ + "route_id", + "service_id", + "trip_id", + "trip_headsign", + "block_id", + "shape_id", + "wheelchair_accessible", + "vehicle_journey_code", + ], + ), "Initial trips columns not as expected" + # clean warnings + clean_unrecognised_column_warnings(gtfs_fixture) + assert len(gtfs_fixture.validity_df) == 0, "Warnings no cleaned" + assert np.array_equal( + gtfs_fixture.feed.trips.columns, + [ + "route_id", + "service_id", + "trip_id", + "trip_headsign", + "block_id", + "shape_id", + "wheelchair_accessible", + ], + ), "Failed to drop unrecognised columns" + + +class TestCleanDuplicateStopTimes(object): + """Tests for clean_duplicate_stop_times.""" + + def test_clean_duplicate_stop_times_defence(self): + """Defensive functionality tests for clean_duplicate_stop_times.""" + with pytest.raises( + TypeError, match=".*gtfs.* GtfsInstance object.*bool" + ): + clean_duplicate_stop_times(True) + + def test_clean_duplicate_stop_times_on_pass(self): + """General functionality tests for clean_duplicate_stop_times.""" + gtfs = GtfsInstance("tests/data/gtfs/repeated_pair_gtfs_fixture.zip") + gtfs.is_valid(validators={"core_validation": None}) + _warnings = _get_validation_warnings( + gtfs, r"Repeated pair \(trip_id, departure_time\)" + )[0] + # 
check that the correct number of rows are impacted + assert ( + len(_warnings[3]) == 24 + ), "Unexpected number of rows originally impacted" + # test case for one of the instances that are removed (1) + assert 309 in _warnings[3], "Test case (1) missing from data" + assert np.array_equal( + list(gtfs.feed.stop_times.loc[309].values), + [ + "VJf98a18e2c314219b4a98f75c5fa44e3f9618e72f", + "08:57:56", + "08:57:56", + "5540AWB33241", + 3, + np.nan, + 0, + 0, + 1.38726, + 0, + ], + ), "data for test case (1) invalid" + # test case for one of the instances that remain (2) + assert 18 in _warnings[3], "Test case (2) missing from data" + assert np.array_equal( + list(gtfs.feed.stop_times.loc[18].values), + [ + "VJ1b6f3baad353b0ba8d42e0290bfd0b39cb1130bc", + "11:10:23", + "11:10:23", + "5540AWA17133", + 6, + np.nan, + 0, + 0, + np.nan, + 0, + ], + ), "data for test case (2) invalid" + # clean gtfs + clean_duplicate_stop_times(gtfs) + _warnings = _get_validation_warnings( + gtfs, r"Repeated pair \(trip_id, departure_time\)" + )[0] + assert ( + len(_warnings[3]) == 6 + ), "Unexpected number of rows impacted after cleaning" + # assert test case (1) is removed from the data + with pytest.raises(KeyError, match=".*309.*"): + gtfs.feed.stop_times.loc[309] + # assert test case (2) remains part of the data + assert ( + len(gtfs.feed.stop_times.loc[18]) > 0 + ), "Test case removed from data" + + # test return of none when no warnings are found (for cov) + _remove_validation_row(gtfs, message=r".* \(trip_id, departure_time\)") + assert clean_duplicate_stop_times(gtfs) is None diff --git a/tests/gtfs/test_gtfs_utils.py b/tests/gtfs/test_gtfs_utils.py index e4aaffa1..dc1a1385 100644 --- a/tests/gtfs/test_gtfs_utils.py +++ b/tests/gtfs/test_gtfs_utils.py @@ -8,8 +8,12 @@ import geopandas as gpd from shapely.geometry import box from plotly.graph_objects import Figure as PlotlyFigure +import numpy as np -from transport_performance.gtfs.validation import GtfsInstance +from 
transport_performance.gtfs.validation import ( + GtfsInstance, + VALIDATE_FEED_FUNC_MAP, +) from transport_performance.gtfs.gtfs_utils import ( bbox_filter_gtfs, filter_gtfs, @@ -17,6 +21,9 @@ filter_gtfs_around_trip, convert_pandas_to_plotly, _validate_datestring, + _remove_validation_row, + _get_validation_warnings, + _function_pipeline, ) # location of GTFS test fixture @@ -276,7 +283,7 @@ class Test_AddValidationRow(object): """Tests for _add_validation_row().""" def test__add_validation_row_defence(self): - """Defensive tests for _add_test_validation_row().""" + """Defensive tests for _add_validation_row().""" gtfs = GtfsInstance(gtfs_pth=GTFS_FIX_PTH) with pytest.raises( AttributeError, @@ -291,9 +298,9 @@ def test__add_validation_row_defence(self): ) def test__add_validation_row_on_pass(self): - """General tests for _add_test_validation_row().""" + """General tests for _add_validation_row().""" gtfs = GtfsInstance(gtfs_pth=GTFS_FIX_PTH) - gtfs.is_valid(far_stops=False) + gtfs.is_valid(validators={"core_validation": None}) _add_validation_row( gtfs=gtfs, _type="warning", message="test", table="stops" @@ -335,12 +342,12 @@ def test_filter_gtfs_around_trip_on_pass(self, tmpdir): trip_id="VJbedb4cfd0673348e017d42435abbdff3ddacbf82", out_pth=out_pth, ) - assert os.path.exists(out_pth), "Failed to filtere GTFS around trip." + assert os.path.exists(out_pth), "Failed to filter GTFS around trip." 
# check the new gtfs can be read feed = GtfsInstance(gtfs_pth=out_pth) assert isinstance( feed, GtfsInstance - ), f"Expected class `Gtfs_Instance but found: {type(feed)}`" + ), f"Expected class `GtfsInstance` but found: {type(feed)}`" @pytest.fixture(scope="function") @@ -389,6 +396,171 @@ def test_convert_pandas_to_plotly_on_pass(self, test_df): ) +class TestGetValidationWarnings(object): + """Tests for _get_validation_warnings.""" + + def test__get_validation_warnings_defence(self): + """Test thhe defences of _get_validation_warnings.""" + with pytest.raises( + TypeError, match=".* expected a GtfsInstance object" + ): + _get_validation_warnings(True, "test_msg") + gtfs = GtfsInstance(gtfs_pth=GTFS_FIX_PTH) + with pytest.raises( + AttributeError, match="The gtfs has not been validated.*" + ): + _get_validation_warnings(gtfs, "test") + gtfs.is_valid() + with pytest.raises( + ValueError, match=r"'return_type' expected one of \[.*\]\. Got .*" + ): + _get_validation_warnings(gtfs, "tester", "tester") + + def test__get_validation_warnings(self): + """Test _get_validation_warnings on pass.""" + gtfs = GtfsInstance(GTFS_FIX_PTH) + gtfs.is_valid() + # test return types + df_exp = _get_validation_warnings( + gtfs, "test", return_type="dataframe" + ) + assert isinstance( + df_exp, pd.DataFrame + ), f"Expected df, got {type(df_exp)}" + ndarray_exp = _get_validation_warnings(gtfs, "test") + assert isinstance( + ndarray_exp, np.ndarray + ), f"Expected np.ndarray, got {type(ndarray_exp)}" + # test with valld regex (assertions on DF data without DF) + regex_matches = _get_validation_warnings( + gtfs, "Unrecognized column *.", return_type="dataframe" + ) + assert len(regex_matches) == 5, ( + "Getting validaiton warnings returned" + "unexpected number of warnings" + ) + assert list(regex_matches["type"].unique()) == [ + "warning" + ], "Dataframe type column not asd expected" + assert list(regex_matches.table) == [ + "agency", + "stop_times", + "stops", + "trips", + "trips", + ], 
"Dataframe table column not as expected" + # test with matching message (no regex) + exact_match = _get_validation_warnings( + gtfs, "Unrecognized column agency_noc", return_type="Dataframe" + ) + assert list(exact_match.values[0]) == [ + "warning", + "Unrecognized column agency_noc", + "agency", + [], + ], "Dataframe values not as expected" + assert ( + len(exact_match) == 1 + ), f"Expected one match, found {len(exact_match)}" + # test with no matches (regex) + regex_no_match = _get_validation_warnings( + gtfs, ".*This is a test.*", return_type="Dataframe" + ) + assert len(regex_no_match) == 0, "No matches expected. Matches found" + # test with no match (no regex) + no_match = _get_validation_warnings( + gtfs, "This is a test!!!", return_type="Dataframe" + ) + assert len(no_match) == 0, "No matches expected. Matched found" + + +class TestRemoveValidationRow(object): + """Tests for _remove_validation_row.""" + + def test__remove_validation_row_defence(self): + """Tests the defences of _remove_validation_row.""" + gtfs = GtfsInstance(GTFS_FIX_PTH) + gtfs.is_valid() + # no message or index provided + with pytest.raises( + ValueError, match=r"Both .* and .* are None, .* to be cleaned" + ): + _remove_validation_row(gtfs) + # both provided + with pytest.warns( + UserWarning, + match=r"Both .* and .* are not None.* cleaned on 'message'", + ): + _remove_validation_row(gtfs, message="test", index=[0, 1]) + + def test__remove_validation_row_on_pass(self): + """Tests for _remove_validation_row on pass.""" + gtfs = GtfsInstance(GTFS_FIX_PTH) + gtfs.is_valid(validators={"core_validation": None}) + # with message + msg = "Unrecognized column agency_noc" + _remove_validation_row(gtfs, message=msg) + assert len(gtfs.validity_df) == 6, "DF is incorrect size" + found_cols = _get_validation_warnings( + gtfs, message=msg, return_type="dataframe" + ) + assert ( + len(found_cols) == 0 + ), "Invalid errors/warnings still in validity_df" + # with index (removing the same error) + gtfs = 
GtfsInstance(GTFS_FIX_PTH) + gtfs.is_valid(validators={"core_validation": None}) + ind = [1] + _remove_validation_row(gtfs, index=ind) + assert len(gtfs.validity_df) == 6, "DF is incorrect size" + found_cols = _get_validation_warnings( + gtfs, message=msg, return_type="dataframe" + ) + print(gtfs.validity_df) + assert ( + len(found_cols) == 0 + ), "Invalid errors/warnings still in validity_df" + + +class TestFunctionPipeline(object): + """Tests for _function_pipeline. + + Notes + ----- + Not testing on pass here as better cases can be found in the tests for + GtfsInstance's is_valid() and clean_feed() methods. + + """ + + @pytest.mark.parametrize( + "operations, raises, match", + [ + # invalid type for 'validators' + (True, TypeError, ".*expected .*dict.*. Got .*bool.*"), + # invalid validator + ( + {"not_a_valid_validator": None}, + KeyError, + ( + r"'not_a_valid_validator' function passed to 'operations'" + r" is not a known operation.*" + ), + ), + # invalid type for kwargs for validator + ( + {"core_validation": pd.DataFrame()}, + TypeError, + ".* expected .*dict.*NoneType.*", + ), + ], + ) + def test_function_pipeline_defence(self, operations, raises, match): + """Defensive test for _function_pipeline.""" + gtfs = GtfsInstance(GTFS_FIX_PTH) + with pytest.raises(raises, match=match): + _function_pipeline(gtfs, VALIDATE_FEED_FUNC_MAP, operations) + + class Test_ValidateDatestring(object): """Tests for _validate_datestring.""" diff --git a/tests/gtfs/test_multi_validation.py b/tests/gtfs/test_multi_validation.py index cabaa175..f2dd79bf 100644 --- a/tests/gtfs/test_multi_validation.py +++ b/tests/gtfs/test_multi_validation.py @@ -271,15 +271,15 @@ def test_clean_feeds_on_pass(self, multi_gtfs_fixture): """General tests for .clean_feeds().""" # validate and do quick check on validity_df valid_df = multi_gtfs_fixture.is_valid() - assert len(valid_df) == 12, "validity_df not as expected" + assert len(valid_df) == 11, "validity_df not as expected" # clean feed 
multi_gtfs_fixture.clean_feeds() # ensure cleaning has occured new_valid = multi_gtfs_fixture.is_valid() - assert len(new_valid) == 9 + assert len(new_valid) == 1 assert np.array_equal( - list(new_valid.iloc[3][["type", "table"]].values), - ["error", "routes"], + list(new_valid.iloc[0][["type", "table"]].values), + ["warning", "routes"], ), "Validity df after cleaning not as expected" def test_is_valid_defences(self, multi_gtfs_fixture): @@ -290,7 +290,7 @@ def test_is_valid_defences(self, multi_gtfs_fixture): def test_is_valid_on_pass(self, multi_gtfs_fixture): """General tests for is_valid().""" valid_df = multi_gtfs_fixture.is_valid() - assert len(valid_df) == 12, "Validation df not as expected" + assert len(valid_df) == 11, "Validation df not as expected" assert np.array_equal( list(valid_df.iloc[3][["type", "message"]].values), (["warning", "Fast Travel Between Consecutive Stops"]), diff --git a/tests/gtfs/test_validation.py b/tests/gtfs/test_validation.py index 66d832a5..9997c141 100644 --- a/tests/gtfs/test_validation.py +++ b/tests/gtfs/test_validation.py @@ -27,12 +27,19 @@ @pytest.fixture(scope="function") # some funcs expect cleaned feed others dont -def gtfs_fixture(): +def newp_gtfs_fixture(): """Fixture for test funcs expecting a valid feed object.""" gtfs = GtfsInstance(gtfs_pth=GTFS_FIX_PTH) return gtfs +@pytest.fixture(scope="function") +def chest_gtfs_fixture(): + """Fixture for test funcs expecting a valid feed object.""" + gtfs = GtfsInstance(here("tests/data/chester-20230816-small_gtfs.zip")) + return gtfs + + class TestGtfsInstance(object): """Tests related to the GtfsInstance class.""" @@ -119,7 +126,7 @@ def test_init_on_pass(self): without_pth.to_dict() == with_pth.to_dict() ), "Failed to get route type lookup correctly" - def test_get_gtfs_files(self, gtfs_fixture): + def test_get_gtfs_files(self, newp_gtfs_fixture): """Assert files that make up the GTFS.""" expected_files = [ # smaller filter has resulted in a GTFS with no calendar dates 
/ @@ -135,40 +142,77 @@ def test_get_gtfs_files(self, gtfs_fixture): "calendar.txt", "routes.txt", ] - foundf = gtfs_fixture.get_gtfs_files() + foundf = newp_gtfs_fixture.get_gtfs_files() assert ( foundf == expected_files ), f"GTFS files not as expected. Expected {expected_files}," "found: {foundf}" - def test_is_valid(self, gtfs_fixture): - """Assertions about validity_df table.""" - gtfs_fixture.is_valid() - assert isinstance( - gtfs_fixture.validity_df, pd.core.frame.DataFrame - ), f"Expected DataFrame. Found: {type(gtfs_fixture.validity_df)}" - shp = gtfs_fixture.validity_df.shape - assert shp == ( - 7, - 4, - ), f"Attribute `validity_df` expected a shape of (7,4). Found: {shp}" - exp_cols = pd.Index(["type", "message", "table", "rows"]) - found_cols = gtfs_fixture.validity_df.columns - assert ( - found_cols == exp_cols - ).all(), f"Expected columns {exp_cols}. Found: {found_cols}" + @pytest.mark.parametrize( + "which, validators, shape", + [ + # only core validation + ("n", {"core_validation": None}, (7, 4)), + # route type warning validation + ( + "n", + { + "core_validation": None, + "validate_route_type_warnings": None, + }, + (6, 4), + ), + # fast travel validators + ( + "c", + { + "core_validation": None, + "validate_travel_between_consecutive_stops": None, + "validate_travel_over_multiple_stops": None, + }, + (5, 4), + ), + # all validators + ("n", None, (6, 4)), + ], + ) + def test_is_valid_on_pass( + self, newp_gtfs_fixture, chest_gtfs_fixture, which, validators, shape + ): + """Tests/assertions for is_valid() while passing. + + Notes + ----- + These tests are mostly to assure that the validators are being + identified and run, and that the validation df returned is as expected. + + I will be refraining from over testing here as it would essentially be + testing the validators again, which occurs in test_validators.py. + + Tests for validators with kwargs would be useful, once they are added. + + """ + # Bypassing any defensive checks for wich. 
# Correct inputs are assured as these are internal tests.
Parameters ---------- - gtfs_fixture : GtfsInstance + newp_gtfs_fixture : GtfsInstance a GtfsInstance test fixure Notes @@ -264,7 +308,7 @@ def test_unmatched_service_id_behaviour(self, gtfs_fixture): calendar table contains duplicate service_ids. """ - feed = gtfs_fixture.feed + feed = newp_gtfs_fixture.feed original_error_count = len(feed.validate()) # introduce a dummy row with a non matching service_id @@ -300,25 +344,25 @@ def test_unmatched_service_id_behaviour(self, gtfs_fixture): len(new_valid[new_valid.message == "Undefined service_id"]) == 1 ), "gtfs-kit failed to identify missing service_id" - def test_print_alerts_defence(self, gtfs_fixture): + def test_print_alerts_defence(self, newp_gtfs_fixture): """Check defensive behaviour of print_alerts().""" with pytest.raises( AttributeError, match=r"is None, did you forget to use `self.is_valid()`?", ): - gtfs_fixture.print_alerts() + newp_gtfs_fixture.print_alerts() - gtfs_fixture.is_valid() + newp_gtfs_fixture.is_valid() with pytest.warns( UserWarning, match="No alerts of type doesnt_exist were found." ): - gtfs_fixture.print_alerts(alert_type="doesnt_exist") + newp_gtfs_fixture.print_alerts(alert_type="doesnt_exist") @patch("builtins.print") # testing print statements - def test_print_alerts_single_case(self, mocked_print, gtfs_fixture): + def test_print_alerts_single_case(self, mocked_print, newp_gtfs_fixture): """Check alerts print as expected without truncation.""" - gtfs_fixture.is_valid() - gtfs_fixture.print_alerts() + newp_gtfs_fixture.is_valid(validators={"core_validation": None}) + newp_gtfs_fixture.print_alerts() # fixture contains single error fun_out = mocked_print.mock_calls assert fun_out == [ @@ -326,11 +370,11 @@ def test_print_alerts_single_case(self, mocked_print, gtfs_fixture): ], f"Expected a print about invalid route type. 
Found {fun_out}" @patch("builtins.print") - def test_print_alerts_multi_case(self, mocked_print, gtfs_fixture): + def test_print_alerts_multi_case(self, mocked_print, newp_gtfs_fixture): """Check multiple alerts are printed as expected.""" - gtfs_fixture.is_valid() + newp_gtfs_fixture.is_valid() # fixture contains several warnings - gtfs_fixture.print_alerts(alert_type="warning") + newp_gtfs_fixture.print_alerts(alert_type="warning") fun_out = mocked_print.mock_calls assert fun_out == [ call("Unrecognized column agency_noc"), @@ -341,7 +385,7 @@ def test_print_alerts_multi_case(self, mocked_print, gtfs_fixture): call("Unrecognized column vehicle_journey_code"), ], f"Expected print statements about GTFS warnings. Found: {fun_out}" - def test_viz_stops_defence(self, tmpdir, gtfs_fixture): + def test_viz_stops_defence(self, tmpdir, newp_gtfs_fixture): """Check defensive behaviours of viz_stops().""" tmp = os.path.join(tmpdir, "somefile.html") with pytest.raises( @@ -351,12 +395,12 @@ def test_viz_stops_defence(self, tmpdir, gtfs_fixture): "Got " ), ): - gtfs_fixture.viz_stops(out_pth=True) + newp_gtfs_fixture.viz_stops(out_pth=True) with pytest.raises( TypeError, match="`geoms` expected . Got ", ): - gtfs_fixture.viz_stops(out_pth=tmp, geoms=38) + newp_gtfs_fixture.viz_stops(out_pth=tmp, geoms=38) with pytest.raises( ValueError, match=re.escape( @@ -364,7 +408,7 @@ def test_viz_stops_defence(self, tmpdir, gtfs_fixture): "['point', 'hull']. 
Got foobar: " ), ): - gtfs_fixture.viz_stops(out_pth=tmp, geoms="foobar") + newp_gtfs_fixture.viz_stops(out_pth=tmp, geoms="foobar") with pytest.raises( TypeError, match=re.escape( @@ -372,28 +416,28 @@ def test_viz_stops_defence(self, tmpdir, gtfs_fixture): "" ), ): - gtfs_fixture.viz_stops(out_pth=tmp, geom_crs=1.1) + newp_gtfs_fixture.viz_stops(out_pth=tmp, geom_crs=1.1) # check missing stop_id results in an informative error message - gtfs_fixture.feed.stops.drop("stop_id", axis=1, inplace=True) + newp_gtfs_fixture.feed.stops.drop("stop_id", axis=1, inplace=True) with pytest.raises( KeyError, match="The stops table has no 'stop_code' column. While " "this is an optional field in a GTFS file, it " "raises an error through the gtfs-kit package.", ): - gtfs_fixture.viz_stops(out_pth=tmp, filtered_only=False) + newp_gtfs_fixture.viz_stops(out_pth=tmp, filtered_only=False) @patch("builtins.print") - def test_viz_stops_point(self, mock_print, tmpdir, gtfs_fixture): + def test_viz_stops_point(self, mock_print, tmpdir, newp_gtfs_fixture): """Check behaviour of viz_stops when plotting point geom.""" tmp = os.path.join(tmpdir, "points.html") - gtfs_fixture.viz_stops(out_pth=pathlib.Path(tmp)) + newp_gtfs_fixture.viz_stops(out_pth=pathlib.Path(tmp)) assert os.path.exists( tmp ), f"{tmp} was expected to exist but it was not found." # check behaviour when parent directory doesn't exist no_parent_pth = os.path.join(tmpdir, "notfound", "points1.html") - gtfs_fixture.viz_stops( + newp_gtfs_fixture.viz_stops( out_pth=pathlib.Path(no_parent_pth), create_out_parent=True ) assert os.path.exists( @@ -408,7 +452,7 @@ def test_viz_stops_point(self, mock_print, tmpdir, gtfs_fixture): "to 'out_pth'. Path defaulted to .html" ), ): - gtfs_fixture.viz_stops(out_pth=pathlib.Path(tmp1)) + newp_gtfs_fixture.viz_stops(out_pth=pathlib.Path(tmp1)) # need to use regex for the first print statement, as tmpdir will # change. 
start_pat = re.compile(r"Creating parent directory:.*") @@ -421,20 +465,22 @@ def test_viz_stops_point(self, mock_print, tmpdir, gtfs_fixture): write_pth ), f"Map should have been written to {write_pth} but was not found." - def test_viz_stops_hull(self, tmpdir, gtfs_fixture): + def test_viz_stops_hull(self, tmpdir, newp_gtfs_fixture): """Check viz_stops behaviour when plotting hull geom.""" tmp = os.path.join(tmpdir, "hull.html") - gtfs_fixture.viz_stops(out_pth=pathlib.Path(tmp), geoms="hull") + newp_gtfs_fixture.viz_stops(out_pth=pathlib.Path(tmp), geoms="hull") assert os.path.exists(tmp), f"Map file not found at {tmp}." # assert file created when not filtering the hull tmp1 = os.path.join(tmpdir, "filtered_hull.html") - gtfs_fixture.viz_stops(out_pth=tmp1, geoms="hull", filtered_only=False) + newp_gtfs_fixture.viz_stops( + out_pth=tmp1, geoms="hull", filtered_only=False + ) assert os.path.exists(tmp1), f"Map file not found at {tmp1}." - def test__create_map_title_text_defence(self, gtfs_fixture): + def test__create_map_title_text_defence(self, newp_gtfs_fixture): """Test the defences for _create_map_title_text().""" # CRS without m or km units - gtfs_hull = gtfs_fixture.feed.compute_convex_hull() + gtfs_hull = newp_gtfs_fixture.feed.compute_convex_hull() gdf = GeoDataFrame({"geometry": gtfs_hull}, index=[0], crs="epsg:4326") with pytest.raises(ValueError), pytest.warns(UserWarning): _create_map_title_text(gdf=gdf, units="m", geom_crs=4326) @@ -507,7 +553,7 @@ def test__convert_multi_index_to_single(self): expected_cols.remove(col) assert len(expected_cols) == 0, "Not all expected cols in output cols" - def test__order_dataframe_by_day_defence(self, gtfs_fixture): + def test__order_dataframe_by_day_defence(self, newp_gtfs_fixture): """Test __order_dataframe_by_day defences.""" with pytest.raises( TypeError, @@ -516,7 +562,7 @@ def test__order_dataframe_by_day_defence(self, gtfs_fixture): "Got " ), ): - (gtfs_fixture._order_dataframe_by_day(df="test")) + 
(newp_gtfs_fixture._order_dataframe_by_day(df="test")) with pytest.raises( TypeError, match=re.escape( @@ -525,12 +571,12 @@ def test__order_dataframe_by_day_defence(self, gtfs_fixture): ), ): ( - gtfs_fixture._order_dataframe_by_day( + newp_gtfs_fixture._order_dataframe_by_day( df=pd.DataFrame(), day_column_name=5 ) ) - def test_get_route_modes(self, gtfs_fixture, mocker): + def test_get_route_modes(self, newp_gtfs_fixture, mocker): """Assertions about the table returned by get_route_modes().""" patch_scrape_lookup = mocker.patch( "transport_performance.gtfs.validation.scrape_route_type_lookup", @@ -539,25 +585,28 @@ def test_get_route_modes(self, gtfs_fixture, mocker): {"route_type": ["3"], "desc": ["Mocked bus"]} ), ) - gtfs_fixture.get_route_modes() + newp_gtfs_fixture.get_route_modes() # check mocker was called assert ( patch_scrape_lookup.called ), "mocker.patch `patch_scrape_lookup` was not called." - found = gtfs_fixture.route_mode_summary_df["desc"][0] + found = newp_gtfs_fixture.route_mode_summary_df["desc"][0] assert found == "Mocked bus", f"Expected 'Mocked bus', found: {found}" assert isinstance( - gtfs_fixture.route_mode_summary_df, pd.core.frame.DataFrame - ), f"Expected pd df. Found: {type(gtfs_fixture.route_mode_summary_df)}" + newp_gtfs_fixture.route_mode_summary_df, pd.core.frame.DataFrame + ), ( + "Expected pd df. " + f"Found: {type(newp_gtfs_fixture.route_mode_summary_df)}" + ) exp_cols = pd.Index(["route_type", "desc", "n_routes", "prop_routes"]) - found_cols = gtfs_fixture.route_mode_summary_df.columns + found_cols = newp_gtfs_fixture.route_mode_summary_df.columns assert ( found_cols == exp_cols ).all(), f"Expected columns are different. 
Found: {found_cols}" - def test__preprocess_trips_and_routes(self, gtfs_fixture): + def test__preprocess_trips_and_routes(self, newp_gtfs_fixture): """Check the outputs of _pre_process_trips_and_route() (test data).""" - returned_df = gtfs_fixture._preprocess_trips_and_routes() + returned_df = newp_gtfs_fixture._preprocess_trips_and_routes() assert isinstance(returned_df, pd.core.frame.DataFrame), ( "Expected DF for _preprocess_trips_and_routes() return," f"found {type(returned_df)}" @@ -591,13 +640,13 @@ def test__preprocess_trips_and_routes(self, gtfs_fixture): f"Found {returned_df.shape}", ) - def test_summarise_trips_defence(self, gtfs_fixture): + def test_summarise_trips_defence(self, newp_gtfs_fixture): """Defensive checks for summarise_trips().""" with pytest.raises( TypeError, match="Each item in `summ_ops`.*. Found : np.mean", ): - gtfs_fixture.summarise_trips(summ_ops=[np.mean, "np.mean"]) + newp_gtfs_fixture.summarise_trips(summ_ops=[np.mean, "np.mean"]) # case where is function but not exported from numpy def dummy_func(): @@ -611,18 +660,18 @@ def dummy_func(): " : dummy_func" ), ): - gtfs_fixture.summarise_trips(summ_ops=[np.min, dummy_func]) + newp_gtfs_fixture.summarise_trips(summ_ops=[np.min, dummy_func]) # case where a single non-numpy func is being passed with pytest.raises( NotImplementedError, match="`summ_ops` expects numpy functions only.", ): - gtfs_fixture.summarise_trips(summ_ops=dummy_func) + newp_gtfs_fixture.summarise_trips(summ_ops=dummy_func) with pytest.raises( TypeError, match="`summ_ops` expects a numpy function.*. Found ", ): - gtfs_fixture.summarise_trips(summ_ops=38) + newp_gtfs_fixture.summarise_trips(summ_ops=38) # cases where return_summary are not of type boolean with pytest.raises( TypeError, @@ -630,7 +679,7 @@ def dummy_func(): "`return_summary` expected . 
Got " ), ): - gtfs_fixture.summarise_trips(return_summary=5) + newp_gtfs_fixture.summarise_trips(return_summary=5) with pytest.raises( TypeError, match=re.escape( @@ -638,15 +687,15 @@ def dummy_func(): "'str'>" ), ): - gtfs_fixture.summarise_trips(return_summary="true") + newp_gtfs_fixture.summarise_trips(return_summary="true") - def test_summarise_routes_defence(self, gtfs_fixture): + def test_summarise_routes_defence(self, newp_gtfs_fixture): """Defensive checks for summarise_routes().""" with pytest.raises( TypeError, match="Each item in `summ_ops`.*. Found : np.mean", ): - gtfs_fixture.summarise_trips(summ_ops=[np.mean, "np.mean"]) + newp_gtfs_fixture.summarise_trips(summ_ops=[np.mean, "np.mean"]) # case where is function but not exported from numpy def dummy_func(): @@ -660,18 +709,18 @@ def dummy_func(): " : dummy_func" ), ): - gtfs_fixture.summarise_routes(summ_ops=[np.min, dummy_func]) + newp_gtfs_fixture.summarise_routes(summ_ops=[np.min, dummy_func]) # case where a single non-numpy func is being passed with pytest.raises( NotImplementedError, match="`summ_ops` expects numpy functions only.", ): - gtfs_fixture.summarise_routes(summ_ops=dummy_func) + newp_gtfs_fixture.summarise_routes(summ_ops=dummy_func) with pytest.raises( TypeError, match="`summ_ops` expects a numpy function.*. Found ", ): - gtfs_fixture.summarise_routes(summ_ops=38) + newp_gtfs_fixture.summarise_routes(summ_ops=38) # cases where return_summary are not of type boolean with pytest.raises( TypeError, @@ -679,38 +728,36 @@ def dummy_func(): "`return_summary` expected . Got " ), ): - gtfs_fixture.summarise_routes(return_summary=5) + newp_gtfs_fixture.summarise_routes(return_summary=5) with pytest.raises( TypeError, match=re.escape( "`return_summary` expected . 
Got " ), ): - gtfs_fixture.summarise_routes(return_summary="true") + newp_gtfs_fixture.summarise_routes(return_summary="true") - @patch("builtins.print") - def test_clean_feed_defence(self, mock_print, gtfs_fixture): + def test_clean_feed_defence(self, newp_gtfs_fixture): """Check defensive behaviours of clean_feed().""" - # Simulate condition where shapes.txt has no shape_id - gtfs_fixture.feed.shapes.drop("shape_id", axis=1, inplace=True) - gtfs_fixture.clean_feed() - fun_out = mock_print.mock_calls - assert fun_out == [ - call("KeyError. Feed was not cleaned.") - ], f"Expected print statement about KeyError. Found: {fun_out}." + with pytest.raises( + TypeError, match=r".*expected .*dict.* Got .*int.*" + ): + fixt = newp_gtfs_fixture + fixt.is_valid(validators={"core_validation": None}) + fixt.clean_feed(cleansers=1) - def test_summarise_trips_on_pass(self, gtfs_fixture): + def test_summarise_trips_on_pass(self, newp_gtfs_fixture): """Assertions about the outputs from summarise_trips().""" - gtfs_fixture.summarise_trips() + newp_gtfs_fixture.summarise_trips() # tests the daily_routes_summary return schema assert isinstance( - gtfs_fixture.daily_trip_summary, pd.core.frame.DataFrame + newp_gtfs_fixture.daily_trip_summary, pd.core.frame.DataFrame ), ( "Expected DF for daily_summary," - f"found {type(gtfs_fixture.daily_trip_summary)}" + f"found {type(newp_gtfs_fixture.daily_trip_summary)}" ) - found_ds = gtfs_fixture.daily_trip_summary.columns + found_ds = newp_gtfs_fixture.daily_trip_summary.columns exp_cols_ds = pd.Index( [ "day", @@ -729,13 +776,13 @@ def test_summarise_trips_on_pass(self, gtfs_fixture): # tests the self.dated_route_counts return schema assert isinstance( - gtfs_fixture.dated_trip_counts, pd.core.frame.DataFrame + newp_gtfs_fixture.dated_trip_counts, pd.core.frame.DataFrame ), ( "Expected DF for dated_route_counts," - f"found {type(gtfs_fixture.dated_trip_counts)}" + f"found {type(newp_gtfs_fixture.dated_trip_counts)}" ) - found_drc = 
gtfs_fixture.dated_trip_counts.columns + found_drc = newp_gtfs_fixture.dated_trip_counts.columns exp_cols_drc = pd.Index(["date", "route_type", "trip_count", "day"]) assert ( @@ -756,8 +803,8 @@ def test_summarise_trips_on_pass(self, gtfs_fixture): ) found_df = ( - gtfs_fixture.daily_trip_summary[ - gtfs_fixture.daily_trip_summary["day"] == "friday" + newp_gtfs_fixture.daily_trip_summary[ + newp_gtfs_fixture.daily_trip_summary["day"] == "friday" ] .sort_values(by="route_type", ascending=True) .reset_index(drop=True) @@ -773,24 +820,26 @@ def test_summarise_trips_on_pass(self, gtfs_fixture): # test that the dated_trip_counts can be returned expected_size = (504, 4) - found_size = gtfs_fixture.summarise_trips(return_summary=False).shape + found_size = newp_gtfs_fixture.summarise_trips( + return_summary=False + ).shape assert expected_size == found_size, ( "Size of date_route_counts not as expected. " "Expected {expected_size}" ) - def test_summarise_routes_on_pass(self, gtfs_fixture): + def test_summarise_routes_on_pass(self, newp_gtfs_fixture): """Assertions about the outputs from summarise_routes().""" - gtfs_fixture.summarise_routes() + newp_gtfs_fixture.summarise_routes() # tests the daily_routes_summary return schema assert isinstance( - gtfs_fixture.daily_route_summary, pd.core.frame.DataFrame + newp_gtfs_fixture.daily_route_summary, pd.core.frame.DataFrame ), ( "Expected DF for daily_summary," - f"found {type(gtfs_fixture.daily_route_summary)}" + f"found {type(newp_gtfs_fixture.daily_route_summary)}" ) - found_ds = gtfs_fixture.daily_route_summary.columns + found_ds = newp_gtfs_fixture.daily_route_summary.columns exp_cols_ds = pd.Index( [ "day", @@ -809,13 +858,13 @@ def test_summarise_routes_on_pass(self, gtfs_fixture): # tests the self.dated_route_counts return schema assert isinstance( - gtfs_fixture.dated_route_counts, pd.core.frame.DataFrame + newp_gtfs_fixture.dated_route_counts, pd.core.frame.DataFrame ), ( "Expected DF for dated_route_counts," - 
f"found {type(gtfs_fixture.dated_route_counts)}" + f"found {type(newp_gtfs_fixture.dated_route_counts)}" ) - found_drc = gtfs_fixture.dated_route_counts.columns + found_drc = newp_gtfs_fixture.dated_route_counts.columns exp_cols_drc = pd.Index(["date", "route_type", "day", "route_count"]) assert ( @@ -836,8 +885,8 @@ def test_summarise_routes_on_pass(self, gtfs_fixture): ) found_df = ( - gtfs_fixture.daily_route_summary[ - gtfs_fixture.daily_route_summary["day"] == "friday" + newp_gtfs_fixture.daily_route_summary[ + newp_gtfs_fixture.daily_route_summary["day"] == "friday" ] .sort_values(by="route_type", ascending=True) .reset_index(drop=True) @@ -853,13 +902,15 @@ def test_summarise_routes_on_pass(self, gtfs_fixture): # test that the dated_route_counts can be returned expected_size = (504, 4) - found_size = gtfs_fixture.summarise_routes(return_summary=False).shape + found_size = newp_gtfs_fixture.summarise_routes( + return_summary=False + ).shape assert expected_size == found_size, ( "Size of date_route_counts not as expected. " "Expected {expected_size}" ) - def test__plot_summary_defences(self, tmp_path, gtfs_fixture): + def test__plot_summary_defences(self, tmp_path, newp_gtfs_fixture): """Test defences for _plot_summary().""" # test defences for checks summaries exist with pytest.raises( @@ -869,7 +920,7 @@ def test__plot_summary_defences(self, tmp_path, gtfs_fixture): " Did you forget to call '.summarise_trips()' first?" ), ): - gtfs_fixture._plot_summary(which="trip", target_column="mean") + newp_gtfs_fixture._plot_summary(which="trip", target_column="mean") with pytest.raises( AttributeError, @@ -878,9 +929,11 @@ def test__plot_summary_defences(self, tmp_path, gtfs_fixture): " Did you forget to call '.summarise_routes()' first?" 
), ): - gtfs_fixture._plot_summary(which="route", target_column="mean") + newp_gtfs_fixture._plot_summary( + which="route", target_column="mean" + ) - gtfs_fixture.summarise_routes() + newp_gtfs_fixture.summarise_routes() # test parameters that are yet to be tested options = ["v", "h"] @@ -891,7 +944,7 @@ def test__plot_summary_defences(self, tmp_path, gtfs_fixture): f"{options}. Got i: " ), ): - gtfs_fixture._plot_summary( + newp_gtfs_fixture._plot_summary( which="route", target_column="route_count_mean", orientation="i", @@ -906,7 +959,7 @@ def test__plot_summary_defences(self, tmp_path, gtfs_fixture): " given to 'img_type'. Path defaulted to .png" ), ): - gtfs_fixture._plot_summary( + newp_gtfs_fixture._plot_summary( which="route", target_column="route_count_mean", save_image=True, @@ -922,15 +975,17 @@ def test__plot_summary_defences(self, tmp_path, gtfs_fixture): "['trip', 'route']. Got tester: " ), ): - gtfs_fixture._plot_summary(which="tester", target_column="tester") + newp_gtfs_fixture._plot_summary( + which="tester", target_column="tester" + ) - def test__plot_summary_on_pass(self, gtfs_fixture, tmp_path): + def test__plot_summary_on_pass(self, newp_gtfs_fixture, tmp_path): """Test plotting a summary when defences are passed.""" - current_fixture = gtfs_fixture + current_fixture = newp_gtfs_fixture current_fixture.summarise_routes() # test returning a html string - test_html = gtfs_fixture._plot_summary( + test_html = newp_gtfs_fixture._plot_summary( which="route", target_column="route_count_mean", return_html=True, @@ -938,7 +993,7 @@ def test__plot_summary_on_pass(self, gtfs_fixture, tmp_path): assert type(test_html) is str, "Failed to return HTML for the plot" # test returning a plotly figure - test_image = gtfs_fixture._plot_summary( + test_image = newp_gtfs_fixture._plot_summary( which="route", target_column="route_count_mean" ) assert ( @@ -946,8 +1001,8 @@ def test__plot_summary_on_pass(self, gtfs_fixture, tmp_path): ), "Failed to return 
plotly.graph_objects.Figure type" # test returning a plotly for trips - gtfs_fixture.summarise_trips() - test_image = gtfs_fixture._plot_summary( + newp_gtfs_fixture.summarise_trips() + test_image = newp_gtfs_fixture._plot_summary( which="trip", target_column="trip_count_mean" ) assert ( @@ -955,7 +1010,7 @@ def test__plot_summary_on_pass(self, gtfs_fixture, tmp_path): ), "Failed to return plotly.graph_objects.Figure type" # test saving plots in html and png format - gtfs_fixture._plot_summary( + newp_gtfs_fixture._plot_summary( which="route", target_column="mean", width=1200, @@ -984,7 +1039,7 @@ def test__plot_summary_on_pass(self, gtfs_fixture, tmp_path): assert counts["html"] == 1, "Failed to save plot as HTML" assert counts["png"] == 1, "Failed to save plot as png" - def test__create_extended_repeated_pair_table(self, gtfs_fixture): + def test__create_extended_repeated_pair_table(self, newp_gtfs_fixture): """Test _create_extended_repeated_pair_table().""" test_table = pd.DataFrame( { @@ -1003,17 +1058,19 @@ def test__create_extended_repeated_pair_table(self, gtfs_fixture): } ).to_dict() - returned_table = gtfs_fixture._create_extended_repeated_pair_table( - table=test_table, - join_vars=["trip_name", "trip_abbrev"], - original_rows=[0], - ).to_dict() + returned_table = ( + newp_gtfs_fixture._create_extended_repeated_pair_table( + table=test_table, + join_vars=["trip_name", "trip_abbrev"], + original_rows=[0], + ).to_dict() + ) assert ( expected_table == returned_table ), "_create_extended_repeated_pair_table() failed" - def test_html_report_defences(self, gtfs_fixture, tmp_path): + def test_html_report_defences(self, newp_gtfs_fixture, tmp_path): """Test the defences whilst generating a HTML report.""" with pytest.raises( ValueError, @@ -1022,15 +1079,15 @@ def test_html_report_defences(self, gtfs_fixture, tmp_path): "['mean', 'min', 'max', 'median']. 
Got test_sum: " ), ): - gtfs_fixture.html_report( + newp_gtfs_fixture.html_report( report_dir=tmp_path, overwrite=True, summary_type="test_sum", ) - def test_html_report_on_pass(self, gtfs_fixture, tmp_path): + def test_html_report_on_pass(self, newp_gtfs_fixture, tmp_path): """Test that a HTML report is generated if defences are passed.""" - gtfs_fixture.html_report(report_dir=pathlib.Path(tmp_path)) + newp_gtfs_fixture.html_report(report_dir=pathlib.Path(tmp_path)) # assert that the report has been completely generated assert os.path.exists( @@ -1065,33 +1122,35 @@ def test_html_report_on_pass(self, gtfs_fixture, tmp_path): ("invalid_ext.txt", "invalid_ext.zip", True), ], ) - def test_save(self, tmp_path, gtfs_fixture, path, final_path, warns): + def test_save(self, tmp_path, newp_gtfs_fixture, path, final_path, warns): """Test the .save() methohd of GtfsInstance().""" complete_path = os.path.join(tmp_path, path) expected_path = os.path.join(tmp_path, final_path) if warns: # catch UserWarning from invalid file extension with pytest.warns(UserWarning): - gtfs_fixture.save(complete_path) + newp_gtfs_fixture.save(complete_path) else: with does_not_raise(): - gtfs_fixture.save(complete_path, overwrite=True) + newp_gtfs_fixture.save(complete_path, overwrite=True) assert os.path.exists(expected_path), "GTFS not saved correctly" - def test_save_overwrite(self, tmp_path, gtfs_fixture): + def test_save_overwrite(self, tmp_path, newp_gtfs_fixture): """Test the .save()'s method of GtfsInstance overwrite feature.""" # original save save_pth = f"{tmp_path}/test_save.zip" - gtfs_fixture.save(save_pth, overwrite=True) + newp_gtfs_fixture.save(save_pth, overwrite=True) assert os.path.exists(save_pth), "GTFS not saved at correct path" # test saving without overwrite enabled with pytest.raises( FileExistsError, match="File already exists at path.*" ): - gtfs_fixture.save(f"{tmp_path}/test_save.zip", overwrite=False) + newp_gtfs_fixture.save( + f"{tmp_path}/test_save.zip", 
overwrite=False + ) # test saving with overwrite enabled raises no errors with does_not_raise(): - gtfs_fixture.save(f"{tmp_path}/test_save.zip", overwrite=True) + newp_gtfs_fixture.save(f"{tmp_path}/test_save.zip", overwrite=True) assert os.path.exists(save_pth), "GTFS save not found" @pytest.mark.parametrize( @@ -1114,14 +1173,14 @@ def test_filter_to_date(self, date, expected_len): len(gtfs.feed.stop_times) == expected_len ), "GTFS not filtered to singular date as expected" - def test_filter_to_bbox(self, gtfs_fixture): + def test_filter_to_bbox(self, newp_gtfs_fixture): """Small tests for the shallow wrapper filter_to_bbox().""" assert ( - len(gtfs_fixture.feed.stop_times) == 7765 + len(newp_gtfs_fixture.feed.stop_times) == 7765 ), "feed.stop_times is an unexpected size" - gtfs_fixture.filter_to_bbox( + newp_gtfs_fixture.filter_to_bbox( [-2.985535, 51.551459, -2.919617, 51.606077] ) assert ( - len(gtfs_fixture.feed.stop_times) == 217 + len(newp_gtfs_fixture.feed.stop_times) == 217 ), "GTFS not filtered to bbox as expected" diff --git a/tests/gtfs/test_validators.py b/tests/gtfs/test_validators.py index 6d5cc627..9dd4903c 100644 --- a/tests/gtfs/test_validators.py +++ b/tests/gtfs/test_validators.py @@ -2,26 +2,42 @@ from pyprojroot import here import pytest import re +import shutil +import os +import zipfile +import pathlib + +import numpy as np from transport_performance.gtfs.validation import GtfsInstance from transport_performance.gtfs.validators import ( validate_travel_between_consecutive_stops, validate_travel_over_multiple_stops, + validate_route_type_warnings, + validate_gtfs_files, ) +from transport_performance.gtfs.gtfs_utils import _get_validation_warnings @pytest.fixture(scope="function") -def gtfs_fixture(): +def chest_gtfs_fixture(): """Fixture for test funcs expecting a valid feed object.""" gtfs = GtfsInstance(here("tests/data/chester-20230816-small_gtfs.zip")) return gtfs +@pytest.fixture(scope="function") +def newp_gtfs_fixture(): + """Fixture 
for test funcs expecting a valid feed object.""" + gtfs = GtfsInstance(here("tests/data/gtfs/newport-20230613_gtfs.zip")) + return gtfs + + class Test_ValidateTravelBetweenConsecutiveStops(object): """Tests for the validate_travel_between_consecutive_stops function().""" def test_validate_travel_between_consecutive_stops_defences( - self, gtfs_fixture + self, chest_gtfs_fixture ): """Defensive tests for validating travel between consecutive stops.""" with pytest.raises( @@ -32,13 +48,15 @@ def test_validate_travel_between_consecutive_stops_defences( "Did you forget to run the .is_valid() method?" ), ): - validate_travel_between_consecutive_stops(gtfs_fixture) + validate_travel_between_consecutive_stops(chest_gtfs_fixture) pass - def test_validate_travel_between_consecutive_stops(self, gtfs_fixture): + def test_validate_travel_between_consecutive_stops( + self, chest_gtfs_fixture + ): """General tests for validating travel between consecutive stops.""" - gtfs_fixture.is_valid(far_stops=False) - validate_travel_between_consecutive_stops(gtfs=gtfs_fixture) + chest_gtfs_fixture.is_valid(validators={"core_validation": None}) + validate_travel_between_consecutive_stops(gtfs=chest_gtfs_fixture) expected_validation = { "type": {0: "warning", 1: "warning", 2: "warning", 3: "warning"}, @@ -62,20 +80,51 @@ def test_validate_travel_between_consecutive_stops(self, gtfs_fixture): }, } - found_dataframe = gtfs_fixture.validity_df + found_dataframe = chest_gtfs_fixture.validity_df assert expected_validation == found_dataframe.to_dict(), ( "'_validate_travel_between_consecutive_stops()' failed to raise " "warnings in the validity df" ) + def test__join_max_speed(self, newp_gtfs_fixture): + """Tests for the _join_max_speed function.""" + newp_gtfs_fixture.is_valid(validators={"core_validation": None}) + # assert route_type's beforehand + existing_types = newp_gtfs_fixture.feed.routes.route_type.unique() + assert np.array_equal( + existing_types, [3, 200] + ), "Existing route types 
not as expected." + # replace 3 with an invlid route_type + newp_gtfs_fixture.feed.routes.route_type = ( + newp_gtfs_fixture.feed.routes.route_type.apply( + lambda x: 12345 if x == 200 else x + ) + ) + new_types = newp_gtfs_fixture.feed.routes.route_type.unique() + assert np.array_equal( + new_types, [3, 12345] + ), "Route types of 200 not replaced correctly" + # validate and assert a speed bound of 150 is set for these cases + validate_travel_between_consecutive_stops(newp_gtfs_fixture) + cases = newp_gtfs_fixture.full_stop_schedule[ + newp_gtfs_fixture.full_stop_schedule.route_type == 12345 + ] + print(cases.speed_bound) + assert np.array_equal( + cases.route_type.unique(), [12345] + ), "Dataframe filter to cases did not work" + assert np.array_equal( + cases.speed_bound.unique(), [200] + ), "Obtaining max speed for unrecognised route_type failed" + class Test_ValidateTravelOverMultipleStops(object): """Tests for validate_travel_over_multiple_stops().""" - def test_validate_travel_over_multiple_stops(self, gtfs_fixture): + def test_validate_travel_over_multiple_stops(self, chest_gtfs_fixture): """General tests for validate_travel_over_multiple_stops().""" - gtfs_fixture.is_valid(far_stops=False) - validate_travel_over_multiple_stops(gtfs=gtfs_fixture) + chest_gtfs_fixture.is_valid(validators={"core_validation": None}) + validate_travel_over_multiple_stops(gtfs=chest_gtfs_fixture) expected_validation = { "type": { @@ -108,9 +157,124 @@ def test_validate_travel_over_multiple_stops(self, gtfs_fixture): }, } - found_dataframe = gtfs_fixture.validity_df + found_dataframe = chest_gtfs_fixture.validity_df assert expected_validation == found_dataframe.to_dict(), ( "'_validate_travel_over_multiple_stops()' failed to raise " "warnings in the validity df" ) + + +class TestValidateRouteTypeWarnings(object): + """Tests for valdate_route_type_warnings.""" + + def test_validate_route_type_warnings_defence(self, newp_gtfs_fixture): + """Tests for validate_route_type_warnings on 
fail.""" + with pytest.raises( + TypeError, match=r".* expected a GtfsInstance object. Got .*" + ): + validate_route_type_warnings(1) + with pytest.raises( + AttributeError, match=r".* has no attribute validity_df" + ): + validate_route_type_warnings(newp_gtfs_fixture) + + def test_validate_route_type_warnings_on_pass(self, newp_gtfs_fixture): + """Tests for validate_route_type_warnings on pass.""" + newp_gtfs_fixture.is_valid(validators={"core_validation": None}) + route_errors = _get_validation_warnings( + newp_gtfs_fixture, message="Invalid route_type" + ) + assert len(route_errors) == 1, "No route_type warnings found" + # clean the route_type errors + newp_gtfs_fixture.is_valid() + validate_route_type_warnings(newp_gtfs_fixture) + new_route_errors = _get_validation_warnings( + newp_gtfs_fixture, message="Invalid route_type" + ) + assert ( + len(new_route_errors) == 0 + ), "Found route_type warnings after cleaning" + + def test_validate_route_type_warnings_creates_warnings( + self, newp_gtfs_fixture + ): + """Test validate_route_type_warnings re-raises route_type warnings.""" + newp_gtfs_fixture.feed.routes[ + "route_type" + ] = newp_gtfs_fixture.feed.routes["route_type"].apply( + lambda x: 310030 if x == 200 else 200 + ) + newp_gtfs_fixture.is_valid( + {"core_validation": None, "validate_route_type_warnings": None} + ) + new_route_errors = _get_validation_warnings( + newp_gtfs_fixture, message="Invalid route_type" + ) + assert ( + len(new_route_errors) == 1 + ), "route_type warnings not found after cleaning" + + +@pytest.fixture(scope="function") +def create_test_zip(tmp_path) -> pathlib.Path: + """Create a gtfs zip with invalid files.""" + gtfs_pth = here("tests/data/chester-20230816-small_gtfs.zip") + # create dir for unzipped gtfs contents + gtfs_contents_pth = os.path.join(tmp_path, "gtfs_contents") + os.mkdir(gtfs_contents_pth) + # extract unzipped gtfs file to new dir + with zipfile.ZipFile(gtfs_pth, "r") as gtfs_zip: + 
gtfs_zip.extractall(gtfs_contents_pth) + # write some dummy files with test cases + with open(os.path.join(gtfs_contents_pth, "not_in_spec.txt"), "w") as f: + f.write("test_date") + with open(os.path.join(gtfs_contents_pth, "routes.invalid"), "w") as f: + f.write("test_date") + # zip contents + new_zip_pth = os.path.join(tmp_path, "gtfs_zip") + shutil.make_archive(new_zip_pth, "zip", gtfs_contents_pth) + full_zip_pth = pathlib.Path(new_zip_pth + ".zip") + return full_zip_pth + + +class TestValidateGtfsFile(object): + """Tests for validate_gtfs_files.""" + + def test_validate_gtfs_files_defence(self): + """Defensive tests for validate_gtfs_files.""" + with pytest.raises( + TypeError, match="'gtfs' expected a GtfsInstance object.*" + ): + validate_gtfs_files(False) + + def test_validate_gtfs_files_on_pass(self, create_test_zip): + """General tests for validte_gtfs_files.""" + gtfs = GtfsInstance(create_test_zip) + gtfs.is_valid(validators={"core_validation": None}) + validate_gtfs_files(gtfs) + # tests for invalid extensions + warnings = _get_validation_warnings( + gtfs, r".*files not of type .*txt.*", return_type="dataframe" + ) + assert ( + len(warnings) == 1 + ), "More warnings than expected for invalid extension" + assert warnings.loc[3]["message"] == ( + "GTFS zip includes files not of type '.txt'. These files include " + "['routes.invalid']" + ), "Warnings not appearing as expected" + + # tests for unrecognised files + warnings = _get_validation_warnings( + gtfs, + r".*files that aren't recognised by the GTFS.*", + return_type="dataframe", + ) + assert ( + len(warnings) == 1 + ), "More warnings than expected for not implemented tables" + assert warnings.loc[4]["message"] == ( + "GTFS zip includes files that aren't recognised by the GTFS " + "spec. 
These include ['not_in_spec.txt']" + ) diff --git a/tests/utils/test_defence.py b/tests/utils/test_defence.py index 3dd0d1fe..06dfb212 100644 --- a/tests/utils/test_defence.py +++ b/tests/utils/test_defence.py @@ -8,6 +8,7 @@ from _pytest.python_api import RaisesContext import pandas as pd from pyprojroot import here +from contextlib import nullcontext as does_not_raise from transport_performance.utils.defence import ( _check_iterable, @@ -21,6 +22,7 @@ _is_expected_filetype, _enforce_file_extension, ) +from transport_performance.gtfs.validation import GtfsInstance class Test_CheckIter(object): @@ -244,6 +246,7 @@ def test_check_parents_dir_exists(self, tmp_path): def test__gtfs_defence(): """Tests for _gtfs_defence().""" + # defensive tests with pytest.raises( TypeError, match=re.escape( @@ -251,6 +254,10 @@ def test__gtfs_defence(): ), ): _gtfs_defence("tester", "test") + # passing test + with does_not_raise(): + gtfs = GtfsInstance(here("tests/data/chester-20230816-small_gtfs.zip")) + _gtfs_defence(gtfs=gtfs, param_nm="gtfs") class Test_TypeDefence(object):