diff --git a/requirements.txt b/requirements.txt index dc618750..3c777179 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,6 @@ pytest pyvelox pyyaml snowflake +ruamel.yaml +deepdiff +pytz diff --git a/tools/convert_tests/check_test_format_conversion_roundtrip.py b/tools/convert_tests/check_test_format_conversion_roundtrip.py new file mode 100644 index 00000000..c2b55de6 --- /dev/null +++ b/tools/convert_tests/check_test_format_conversion_roundtrip.py @@ -0,0 +1,97 @@ +import os +import re +import shutil + +from ruamel.yaml import YAML +from deepdiff import DeepDiff + +from convert_tests_to_new_format import convert_directory, load_test_file +from convert_tests_to_old_format import convert_directory as convert_directory_roundtrip + +# Initialize the YAML handler with ruamel to ensure consistency in parsing and dumping +yaml = YAML() +yaml.default_flow_style = None + + +def normalize_data(data): + """Normalize the data by removing spaces around values and quotes around strings.""" + if isinstance(data, dict): + return {k: normalize_data(v) for k, v in data.items()} + elif isinstance(data, list): + return [normalize_data(item) for item in data] + elif isinstance(data, str): + # Remove extra spaces in specific cases, like decimal<38,0> vs decimal<38, 0> + data = re.sub(r"\s*,\s*", ",", data) # Remove spaces around commas + if data.lower() == "null": + return "null" + return data + else: + return data # Return non-string values as the + + +def compare_yaml_files(file1, file2): + data1 = normalize_data(load_test_file(file1)) + data2 = normalize_data(load_test_file(file2)) + + diff = DeepDiff( + data1, data2, ignore_order=True, ignore_numeric_type_changes=True, view="text" + ) + if diff: + print(f"\nDifferences found in '{file1}' vs '{file2}':") + print(diff) + return not diff + + +def compare_directories(original_dir, roundtrip_dir): + count = 0 + for root, _, files in os.walk(original_dir): + for file_name in files: + if file_name.endswith(".yaml"): + original_file = os.path.join(root, file_name) + relative_path = os.path.relpath(original_file, original_dir) + roundtrip_file = os.path.join(roundtrip_dir, relative_path).replace( + ".test", ".yaml" + ) + + if not os.path.exists(roundtrip_file): + print(f"File missing in roundtrip directory: {roundtrip_file}") + count += 1 + continue + + if not compare_yaml_files(original_file, roundtrip_file): + count += 1 + else: + print(f"YAML content matches: {original_file} and {roundtrip_file}") + return count + + +def main(): + # Directories + initial_cases_dir = "../../cases" + temp_dir = "./temp" + intermediate_dir = f"{temp_dir}/substrait_cases" + roundtrip_dir = f"{temp_dir}/cases" + uri_prefix = ( + "https://github.com/substrait-io/substrait/blob/main/extensions/substrait" + ) + + # Step 1: Convert from `../../cases` to `./temp/substrait_cases/` + convert_directory(initial_cases_dir, intermediate_dir, uri_prefix) + + # Step 2: Convert from `./temp/substrait_cases/` to `./temp/roundtrip_cases/` + convert_directory_roundtrip(intermediate_dir, roundtrip_dir) + + # Step 3: Compare YAML content in `./cases` and `./temp/roundtrip_cases/` + count = compare_directories(initial_cases_dir, roundtrip_dir) + if count == 0: + print("All YAML files match between original and roundtrip directories.") + else: + print( + f"Differences found in {count} YAML files between original and roundtrip directories." + ) + + shutil.rmtree(temp_dir) + + +if __name__ == "__main__": + main() diff --git a/tools/convert_tests/convert_tests_helper.py b/tools/convert_tests/convert_tests_helper.py new file mode 100644 index 00000000..7f7b9d93 --- /dev/null +++ b/tools/convert_tests/convert_tests_helper.py @@ -0,0 +1,409 @@ +from datetime import datetime +import pytz +import re + +# Define short names for various data types +short_name_map = { + "string": "str", + "boolean": "bool", + "varbinary": "vbin", + "timestamp": "ts", + "timestamp_tz": "tstz", + "interval_year": "iyear", + "interval_day": "iday", + "interval": "iday", + "decimal": "dec", + "precision_timestamp": "pts", + "precision_timestamp_tz": "ptstz", + "fixedchar": "fchar", + "varchar": "vchar", + "fixedbinary": "fbin", +} + +long_name_map = {v: k for k, v in short_name_map.items()} + +# Mapping of timezone abbreviations to pytz time zones +timezone_abbr_map = { + "PST": "America/Los_Angeles", + "EST": "America/New_York", + "CST": "America/Chicago", + "MST": "America/Denver", + "HST": "Pacific/Honolulu", + "UTC": "UTC", +} + +SQUOTE = "\u200B" +DQUOTE = "&" + +# Map timezone abbreviations to valid pytz timezone names +timediff_abbr_map = { + "-0800": "US/Pacific", # Pacific Standard Time + "-0500": "US/Eastern", # Eastern Standard Time +} + +timediff_zone_map = { + "-0800": "PST", # Pacific Standard Time + "-0500": "EST", # Eastern Standard Time +} + + +def convert_type(type_str, mapping): + """Helper function to convert a type string using the provided mapping.""" + if type_str.startswith("list<") and type_str.endswith(">"): + # Extract the inner type and convert it + inner_type = type_str[5:-1] # Extract what's inside "list<...>" + short_inner_type = mapping.get(inner_type.lower(), inner_type) + return f"list<{short_inner_type}>" + + base_type, parameters = (type_str.split("<", 1) + [""])[:2] + parameters = f"<{parameters}" if parameters else "" + return mapping.get(base_type.lower(), base_type) + parameters + + +def convert_to_long_type(type_str): + type_str = convert_type(type_str, long_name_map) + if type_str == "interval_year" or type_str == "interval_day": + type_str = "interval" + return type_str + + +def match_sql_duration(value, value_type): + """ + Check if a string is in the format of 'X days, HH:MM:SS'. + Returns a match object if successful, or None if it doesn't match. + """ + if value_type != "str": + return None + pattern = r"^(?:(\d+)\s+days?,\s*)?(\d+):(\d+):(\d+)$" + return re.match(pattern, value) + + +def format_timestamp(value): + """Format a timestamp value into ISO 8601 format, handling optional fractional seconds.""" + # Try parsing with fractional seconds first + try: + dt = datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f") + return dt.isoformat() + except ValueError: + pass # Continue to the next format if parsing fails + + # Fallback to parsing without fractional seconds + try: + dt = datetime.strptime(value, "%Y-%m-%d %H:%M:%S") + return dt.isoformat() + except ValueError: + # Return the original value if it doesn't match expected formats + return value + + +def has_last_colon_after_last_dash(s): + """ + Check if the last ':' comes after the last '-' in the string. + + Args: + s (str): Input string, e.g., "-10:30". + + Returns: + bool: True if the last ':' comes after the last '-', False otherwise. + """ + last_dash_index = s.rfind("-") # Find the last occurrence of '-' + last_colon_index = s.rfind(":") # Find the last occurrence of ':' + + # Check if both '-' and ':' exist and if last ':' comes after last '-' + return ( + last_dash_index != -1 + and last_colon_index != -1 + and last_dash_index < last_colon_index + ) + + +def format_timestamp_tz(timestamp_with_tz): + """Convert a timestamp with timezone abbreviation to ISO 8601 with offset.""" + if "-" in timestamp_with_tz and not has_last_colon_after_last_dash( + timestamp_with_tz + ): + return timestamp_with_tz.replace(" ", "T") + ":00" + datetime_str, tz_abbr = timestamp_with_tz.rsplit(" ", 1) + dt = datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S") + + if tz_abbr in timezone_abbr_map: + dt = pytz.timezone(timezone_abbr_map[tz_abbr]).localize(dt) + else: + return timestamp_with_tz + # raise ValueError(f"Invalid timezone abbreviation: {tz_abbr}") + + return dt.isoformat() + + +def convert_sql_interval_to_iso_duration(interval: str): + """Convert SQL-style interval to ISO 8601 format.""" + # Regular expression to match interval components (e.g., '1 DAY 10 HOUR' or '360 DAY, 1 HOUR') + match = re.match( + r"INTERVAL\s*'((\d+)\s*(YEAR|MONTH|DAY|HOUR|MINUTE|SECOND)(\s+\d+\s*(YEAR|MONTH|DAY|HOUR|MINUTE|SECOND))*)'", + interval.strip(), + ) + if not match: + raise ValueError(f"Invalid SQL interval format: {interval}") + + # Capture all parts of the interval, including multiple parts (space-separated) + parts = match.group(1).split() + + iso_parts = [] + unit_map = { + "YEAR": "P{}Y", + "MONTH": "P{}M", + "DAY": "P{}D", + "HOUR": "PT{}H", + "MINUTE": "PT{}M", + "SECOND": "PT{}S", + } + + i = 0 + iyear = False + while i < len(parts): + num = int(parts[i]) # The number part + unit = parts[i + 1] # The unit part + if unit in ["YEAR", "MONTH"]: + iyear = True + iso_parts.append(unit_map[unit].format(num)) + i += 2 # Skip to the next number-unit pair + + # Join the components together to form the final ISO 8601 duration string + iso_interval = "".join(iso_parts) + + # Determine whether it's a 'iyear' or 'iday' based on units + if iyear: + return iso_interval, "iyear" + else: + return iso_interval, "iday" + + +def convert_sql_duration_to_iso_duration(match): + """ + Convert a valid "X days, HH:MM:SS" string to ISO 8601 duration format. + If the input does not match, return the original string. + """ + + days, hours, minutes, seconds = match.groups(default="0") + + # Construct the ISO 8601 duration format + iso_duration = f"P{int(days)}D" # Add days + if int(hours) or int(minutes) or int(seconds): + iso_duration += f"T{int(hours)}H{int(minutes)}M{int(seconds)}S" + + return iso_duration + + +def iso_format_type(s): + # Define regex patterns for each ISO 8601 format + iso_patterns = { + "time": r"^\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:\d{2})?$", + "iday": r"^P(\d+D)?(T(\d+H)?(\d+M)?(\d+S)?)?$", # Day-based duration, specifically for day/time-only formats + "iyear": r"^P(\d+Y)?(\d+M)?(\d+W)?(\d+D)?(T(\d+H)?(\d+M)?(\d+S)?)?$", + # Full ISO duration format, can include year/month and day/time components + } + + # Check against each pattern and return the type if it matches + for fmt, pattern in iso_patterns.items(): + if re.match(pattern, s): + return fmt # Return the matching format type + + # Return None if no pattern matches + return None + + +def format_null(value_type, value, level): + """Handle null values, with special handling for interval types.""" + if value_type == "interval": + # Special handling for null intervals + return "null::iday" if level == 0 else "null" + if value_type: + return f"null::{value_type}" if level == 0 else value + return "null" + + +def is_list_type(value_type): + return bool( + value_type and value_type.startswith("list<") and value_type.endswith(">") + ) + + +def needs_quotes(type_str): + quote_types = { + "str", + "string", + "fchar", + "vchar", + "date", + "time", + "list", + "iday", + "iyear", + "timestamp_tz", + "timestamp", + } + + """Check if the type requires quotes around its values.""" + # Extract base type, ignoring any <...> parameters, and lowercase for case-insensitive comparison + base_type = type_str.split("<", 1)[0].lower() + # Check against short and long versions of each type in quote_types + if base_type not in quote_types and type_str not in quote_types: + return False + return True + + +def iso_to_sql_interval(iso_duration): + # Match ISO 8601 duration format + match = re.match( + r"P(?:(\d+)Y)?(?:(\d+)M)?(?:(\d+)D)?(?:T(?:(\d+)H)?(?:(\d+)M)?(?:(\d+)S)?)?", + iso_duration, + ) + + if not match: + raise ValueError("Invalid ISO 8601 duration format") + + # Create interval parts by pairing matched values with corresponding units + units = ["YEAR", "MONTH", "DAY", "HOUR", "MINUTE", "SECOND"] + parts = [f"{value} {unit}" for value, unit in zip(match.groups(), units) if value] + interval = " ".join(parts) + + return "INTERVAL " + SQUOTE + interval + SQUOTE + + +def convert_iso_to_timezone_format(iso_string): + # Parse the ISO 8601 string into a datetime object with timezone info + dt = datetime.fromisoformat(iso_string) + + # Handle timezone-aware datetime + if dt.tzinfo is not None: + # Get the UTC offset + offset_str = dt.strftime("%z") + + # Look up the timezone abbreviation from the offset + timezone_name = timediff_abbr_map.get(offset_str, None) + + if timezone_name: + # Convert the datetime object to the local time based on the timezone info + tz = pytz.timezone(timezone_name) + dt_in_timezone = dt.astimezone(tz) + + # Format the datetime into the required format (without the UTC offset and with time zone abbreviation) + return dt_in_timezone.strftime( + f"%Y-%m-%d %H:%M:%S {timediff_zone_map[offset_str]}" + ) + else: + return dt.strftime("%Y-%m-%d %H:%M:%S UTC") + else: + return dt.strftime("%Y-%m-%d %H:%M:%S") + + +def iso_duration_to_timedelta(iso_duration): + # Match ISO 8601 duration pattern + match = re.match( + r"^P(?:(\d+)D)?(?:T(?:(\d+)H)?(?:(\d+)M)?(?:(\d+)S)?)?$", iso_duration + ) + + if not match: + return iso_duration # Return the original string if it's not a valid duration format + + days = int(match.group(1) or 0) + hours = int(match.group(2) or 0) + minutes = int(match.group(3) or 0) + seconds = int(match.group(4) or 0) + + # Use singular/plural for "day" and handle zero days + day_str = f"{days} day{'s' if days != 1 else ''}" if days > 0 else "" + time_str = f"{hours}:{minutes:02}:{seconds:02}" + + return f"{day_str}, {time_str}" if day_str else time_str + + +def convert_to_old_value(value, value_type, level=0): + if value_type is not None: + value_type = convert_to_long_type(value_type) + + if isinstance(value, list): + formatted_values = f"[{', '.join(convert_to_old_value(v, value_type, level + 1) for v in value)}]" + return ( + f"!decimallist {formatted_values}" + if value_type.startswith("decimal") and level == 0 + else formatted_values + ) + + if value_type is None and value.lower() == "null": + return DQUOTE + "NULL" + DQUOTE + if value_type is None and value.upper() == "IYEAR_360DAYS": + return "360" + if value_type is None and value.upper() == "IYEAR_365DAYS": + return "365" + if value in {None, "'Null'"}: + return str("Null") + value = str(value) + if value == "NaN": + return str("NAN") + if value_type is not None: + if value_type.startswith("decimal") and level == 0: + value = "!decimal " + value + if value_type == "interval": + value = iso_to_sql_interval(str(value)) + if value_type == "timestamp": + value = value.replace("T", " ") + if "." in value: + # Strip only the trailing zeros in the subsecond field + value = value.rstrip("0").rstrip(".") + if value_type == "timestamp_tz": + value = convert_iso_to_timezone_format(value) + if needs_quotes(value_type): + value = DQUOTE + value + DQUOTE + return value + return str(value) + + +def convert_to_new_value(value, value_type, level=0): + """Format a value based on its type, if specified.""" + if value_type: + value_type = convert_type(value_type, short_name_map) + + # Handle list values recursively + if isinstance(value, list): + left_delim, right_delim = ("[", "]") if is_list_type(value_type) else ("(", ")") + formatted_values = [ + convert_to_new_value(x, value_type, level + 1) for x in value + ] + return ( + f"{left_delim}" + + ", ".join(map(str, formatted_values)) + + f"{right_delim}::{value_type}" + ) + + if (value is None) or str(value).lower() == "null": + return format_null(value_type, value, level) + + if value_type is None: + if value == "360": + return "IYEAR_360DAYS" + if value == "365": + return "IYEAR_365DAYS" + return str(value) + + # convert duration + if value_type in ("iday", "iyear"): + value, value_type = convert_sql_interval_to_iso_duration(str(value)) + match = match_sql_duration(value, value_type) + if match: + value = convert_sql_duration_to_iso_duration(match) + value_type = iso_format_type(value) + + if value_type == "time": + value = value.split("+")[0] + if value_type == "ts": + value = f"'{format_timestamp(str(value))}'" + if value_type == "tstz": + value = f"'{format_timestamp_tz(str(value))}'" + if value_type == "bool": + value = str(value).lower() + + if needs_quotes(value_type): + value = f"'{value}'" + + return f"{value}::{value_type}" if level == 0 else value diff --git a/tools/convert_tests/convert_tests_to_new_format.py b/tools/convert_tests/convert_tests_to_new_format.py new file mode 100644 index 00000000..120b0fc1 --- /dev/null +++ b/tools/convert_tests/convert_tests_to_new_format.py @@ -0,0 +1,173 @@ +import yaml +import os +from collections import defaultdict +from itertools import count +from tools.convert_tests.convert_tests_helper import convert_to_new_value + + +# Define a custom YAML loader that interprets all values as strings +def string_loader(loader, node): + return str(loader.construct_scalar(node)) + + +def list_of_decimal_constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode): + return [string_loader(loader, item) for item in node.value] + + +def load_test_file(file_path): + """Load a YAML file, interpreting all values as strings.""" + # Override default YAML constructors to load all types as strings + for tag in ("str", "int", "float", "bool", "null", "decimal"): + yaml.add_constructor(f"tag:yaml.org,2002:{tag}", string_loader) + + yaml.add_constructor("!decimal", string_loader) + yaml.add_constructor("!isostring", string_loader) + yaml.add_constructor("!decimallist", list_of_decimal_constructor) + + with open(file_path, "r") as file: + return yaml.load(file, Loader=yaml.FullLoader) + + +def format_return_value(case): + result = case.get("result", {}) + special = result.get("special") + + if special: + special = special.lower() + + # Handle special cases for ERROR and UNDEFINED + if special in {"error", "undefined"}: + return f"" + + if special == "nan": + return "nan::fp64" + + # Return formatted result with format_value + return convert_to_new_value(result.get("value"), result.get("type")) + + +def format_test_case_group(case, description_map): + """Extract group name and description for test case.""" + group = case.get("group", "basic") + group_name = group if isinstance(group, str) else group.get("id", "basic") + description = group.get("description", "") if isinstance(group, dict) else "" + + if group_name not in description_map: + description_map[group_name] = description + + return f"{group_name}: {description_map.get(group_name, '')}" + + +def generate_define_table(case, table_id): + """Generates the table definition only if there are arguments with 'is_not_a_func_arg'.""" + args = case.get("args", []) + + # If args is empty, return an empty string, as no table is needed + if not args: + return "" + + # Gather column types and names based on args + formatted_columns = ", ".join(str(arg["type"]) for arg in args) if args else "" + + # Transpose the arguments' values to construct rows + values = [ + [convert_to_new_value(value, arg["type"], 1) for value in arg.get("value", [])] + for arg in args + ] + rows = zip(*values) # zip will combine each nth element of each argument + + # Format rows as strings for the table definition + formatted_rows = ", ".join(f"({', '.join(map(str, row))})" for row in rows) + + # Define table format with column types + table_definition = ( + f"DEFINE t{table_id}({formatted_columns}) = ({formatted_rows}) \n" + ) + + return table_definition + + +def format_test_case(case, function, description_map, table_id_counter, is_aggregate): + """Format a single test case.""" + description = format_test_case_group(case, description_map) + options = case.get("options") + options = ( + f" [{', '.join(f'{k}:{convert_to_new_value(v, None)}' for k, v in options.items())}]" + if options + else "" + ) + results = format_return_value(case) + + args = [arg for arg in case.get("args", []) if not arg.get("is_not_a_func_arg")] + if is_aggregate and len(args) != 1: + table_id = next(table_id_counter) + args = ", ".join(f"t{table_id}.col{idx}" for idx in range(len(args))) + table_definition = generate_define_table(case, table_id) + return description, f"{table_definition}{function}({args}){options} = {results}" + + args = ", ".join( + convert_to_new_value(arg.get("value"), str(arg["type"])) + for arg in case.get("args", []) + ) + return description, f"{function}({args}){options} = {results}" + + +def convert_test_file_to_new_format(input_data, prefix, is_aggregate): + """Parse YAML test data to formatted cases.""" + function = input_data["function"] + base_uri = input_data["base_uri"][len(prefix) :] + description_map = {} + table_id_counter = count(0) + groups = defaultdict(lambda: {"tests": []}) + + for case in input_data["cases"]: + description, formatted_test = format_test_case( + case, function, description_map, table_id_counter, is_aggregate + ) + groups[description]["tests"].append(formatted_test) + + output_lines = [ + f"{'### SUBSTRAIT_AGGREGATE_TEST: v1.0' if is_aggregate else '### SUBSTRAIT_SCALAR_TEST: v1.0'}\n", + f"### SUBSTRAIT_INCLUDE: '{base_uri}'\n", + ] + + for description, details in groups.items(): + output_lines.append(f"\n# {description}\n") + output_lines.extend(f"{test}\n" for test in details["tests"]) + + return output_lines + + +def output_test_data(output_file, lines): + """Write formatted lines to a file.""" + os.makedirs(os.path.dirname(output_file), exist_ok=True) + with open(output_file, "w") as file: + file.writelines(lines) + + print(f"Converted '{output_file}' successfully.") + + +def convert_directory(input_dir, output_dir, prefix): + """Process all YAML files in a directory, convert and save them to output directory.""" + for root, _, files in os.walk(input_dir): + for filename in filter(lambda f: f.endswith(".yaml"), files): + input_file = os.path.join(root, filename) + output_file = os.path.join( + output_dir, os.path.relpath(input_file, input_dir) + ).replace(".yaml", ".test") + is_aggregate = "aggregate" in input_file + + yaml_data = load_test_file(input_file) + output_lines = convert_test_file_to_new_format( + yaml_data, prefix, is_aggregate + ) + output_test_data(output_file, output_lines) + + +if __name__ == "__main__": + input_directory = "./cases" + output_directory = "./substrait/tests/cases" + uri_prefix = ( + "https://github.com/substrait-io/substrait/blob/main/extensions/substrait" + ) + convert_directory(input_directory, output_directory, uri_prefix) diff --git a/tools/convert_tests/convert_tests_to_old_format.py b/tools/convert_tests/convert_tests_to_old_format.py new file mode 100644 index 00000000..9764bd7b --- /dev/null +++ b/tools/convert_tests/convert_tests_to_old_format.py @@ -0,0 +1,196 @@ +import os + +from ruamel.yaml import YAML +from tests.coverage.nodes import ( + TestFile, + AggregateArgument, +) +from tests.coverage.case_file_parser import load_all_testcases +from tools.convert_tests.convert_tests_helper import ( + convert_to_old_value, + convert_to_long_type, + SQUOTE, + DQUOTE, + iso_duration_to_timedelta, +) + +yaml = YAML() +yaml.indent(mapping=2, sequence=4, offset=2) # Adjust indentations as needed +yaml.width = 4096 # Extend line width to prevent line breaks + + +def convert_result(test_case): + """Convert the result section based on specific conditions.""" + if test_case.is_return_type_error(): + return {"special": str(test_case.result.error)} + elif str(test_case.result.value) == "nan": + return {"special": "nan"} + elif test_case.func_name == "add_intervals" and test_case.result.type == "iday": + return { + "value": convert_to_old_value( + iso_duration_to_timedelta(test_case.result.value), "str" + ), + "type": "string", + } + else: + return { + "value": convert_to_old_value( + test_case.result.value, test_case.result.type + ), + "type": convert_to_long_type(test_case.result.type), + } + + +def convert_table_definition(test_case): + column_types = None + + if test_case.column_types is not None: + column_types = [convert_to_long_type(type) for type in test_case.column_types] + elif test_case.args is not None: + column_types = [ + convert_to_long_type( + arg.scalar_value.type + if isinstance(arg, AggregateArgument) + else arg.type + ) + for arg in test_case.args + ] + + columns = list(map(list, zip(*test_case.rows))) + if not columns: + # Handle the case where columns is empty, but column_types is not + return [ + {"value": [], "type": col_type, "is_not_a_func_arg": "true"} + for col_type in column_types + ] + else: + # Handle the case where columns is not empty + return [ + { + "value": convert_to_old_value(column, col_type), + "type": col_type, + "is_not_a_func_arg": "true", + } + for column, col_type in zip(columns, column_types) + ] + + +def convert_group(test_case, groups): + id = str(test_case.group.name.split(": ")[0]) + desc = test_case.group.name.split(": ")[1] if ": " in test_case.group.name else "" + group = id if id in groups else {"id": id, "description": desc} + groups[id] = desc + return group + + +def convert_test_case_to_old_format(test_case, groups): + # Match group headers with descriptions + print(f"converting test '{test_case}'") + case = {} + case["group"] = convert_group(test_case, groups) + + if test_case.rows is not None: + case["args"] = convert_table_definition(test_case) + else: + if isinstance(test_case.args[0], AggregateArgument): + case["args"] = [ + { + "value": convert_to_old_value( + arg.scalar_value.value, arg.scalar_value.type + ), + "type": convert_to_long_type(arg.scalar_value.type), + } + for arg in test_case.args + ] + else: + case["args"] = [ + { + "value": convert_to_old_value(arg.value, arg.type), + "type": convert_to_long_type(arg.type), + } + for arg in test_case.args + ] + + if len(test_case.options) > 0: + case["options"] = { + key: convert_to_old_value(value, None) + for key, value in test_case.options.items() + } + + case["result"] = convert_result(test_case) + return case + + +def convert_test_file_to_yaml(testFile: TestFile): + # Get function name from the first expression + function = None + cases = [] + groups = {} + + for test_case in testFile.testcases: + function = test_case.func_name + cases.append(convert_test_case_to_old_format(test_case, groups)) + + # Construct the full YAML structure + return { + "base_uri": f"https://github.com/substrait-io/substrait/blob/main/extensions/substrait{testFile.include}", + "function": function, + "cases": cases, + } + + +def output_test_data(output_file, input_path, yaml_data): + with open(output_file, "w") as f: + yaml.dump(yaml_data, f) + + fix_quotes(output_file) + + print(f"Converted '{input_path}' to '{output_file}'.") + + +def fix_quotes(file_path): + with open(file_path, "r") as file: + content = file.read() + + # Remove all single quotes + content = ( + content.replace("'", "") + .replace('"', "") + .replace(SQUOTE, "'") + .replace(DQUOTE, '"') + ) + + with open(file_path, "w") as file: + file.write(content) + + +def convert_directory(input_dir, output_dir): + input_test_files = load_all_testcases(input_dir) + for input_test_file in input_test_files: + input_file = input_test_file.path + relative_path = os.path.relpath(input_file, input_dir) + output_file = os.path.join(output_dir, relative_path).replace(".test", ".yaml") + os.makedirs(os.path.dirname(output_file), exist_ok=True) + yaml_data = convert_test_file_to_yaml(input_test_file) + output_test_data(output_file, input_test_file.path, yaml_data) + + +def main(): + input_dir = "../../substrait/tests/cases" + output_dir = "../../cases" # Specify the output directory + convert_directory(input_dir, output_dir) + + +if __name__ == "__main__": + main() + +from io import StringIO + + +def normalize_yaml(yaml_string): + """Normalize YAML by loading it into Python objects and then dumping it back to a string.""" + # If the input is a dictionary or list, convert it to a YAML string first + yaml_stream = StringIO(yaml_string) + + # Load the YAML from the string (as a stream) + return yaml.load(yaml_stream) diff --git a/tools/convert_tests/test_convert_tests_to_new_format.py b/tools/convert_tests/test_convert_tests_to_new_format.py new file mode 100644 index 00000000..c65c2f56 --- /dev/null +++ b/tools/convert_tests/test_convert_tests_to_new_format.py @@ -0,0 +1,503 @@ +import pytest + +from convert_tests_to_new_format import convert_test_file_to_new_format + + +@pytest.mark.parametrize( + "input_data, prefix, is_aggregate, expected_output", + [ + ( + { + "base_uri": "https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic.yaml", + "function": "add", + "cases": [ + { + "group": { + "id": "basic", + "description": "Basic examples without any special cases", + }, + "args": [ + {"value": 120, "type": "i8"}, + {"value": 5, "type": "i8"}, + ], + "result": {"value": 125, "type": "i8"}, + }, + { + "group": "basic", + "args": [ + {"value": 100, "type": "i16"}, + {"value": 100, "type": "i16"}, + ], + "result": {"value": 200, "type": "i16"}, + }, + { + "group": "basic", + "args": [ + {"value": 30000, "type": "i32"}, + {"value": 30000, "type": "i32"}, + ], + "result": {"value": 60000, "type": "i32"}, + }, + { + "group": "basic", + "args": [ + {"value": 2000000000, "type": "i64"}, + {"value": 2000000000, "type": "i64"}, + ], + "result": {"value": 4000000000, "type": "i64"}, + }, + { + "group": { + "id": "overflow", + "description": "Examples demonstrating overflow behavior", + }, + "args": [ + {"value": 120, "type": "i8"}, + {"value": 10, "type": "i8"}, + ], + "options": {"overflow": "ERROR"}, + "result": {"special": "error"}, + }, + { + "group": "overflow", + "args": [ + {"value": 30000, "type": "i16"}, + {"value": 30000, "type": "i16"}, + ], + "options": {"overflow": "ERROR"}, + "result": {"special": "error"}, + }, + { + "group": "overflow", + "args": [ + {"value": 2000000000, "type": "i32"}, + {"value": 2000000000, "type": "i32"}, + ], + "options": {"overflow": "ERROR"}, + "result": {"special": "error"}, + }, + { + "group": "overflow", + "args": [ + {"value": 9223372036854775807, "type": "i64"}, + {"value": 1, "type": "i64"}, + ], + "options": {"overflow": "ERROR"}, + "result": {"special": "error"}, + }, + ], + }, + "https://github.com/substrait-io/substrait/blob/main/extensions/substrait", + False, + [ + "### SUBSTRAIT_SCALAR_TEST: v1.0\n", + "### SUBSTRAIT_INCLUDE: '/extensions/functions_arithmetic.yaml'\n", + "\n# basic: Basic examples without any special cases\n", + "add(120::i8, 5::i8) = 125::i8\n", + "add(100::i16, 100::i16) = 200::i16\n", + "add(30000::i32, 30000::i32) = 60000::i32\n", + "add(2000000000::i64, 2000000000::i64) = 4000000000::i64\n", + "\n# overflow: Examples demonstrating overflow behavior\n", + "add(120::i8, 10::i8) [overflow:ERROR] = \n", + "add(30000::i16, 30000::i16) [overflow:ERROR] = \n", + "add(2000000000::i32, 2000000000::i32) [overflow:ERROR] = \n", + "add(9223372036854775807::i64, 1::i64) [overflow:ERROR] = \n", + ], + ), + # Second test case for "max" function + ( + { + "base_uri": "https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic.yaml", + "function": "max", + "cases": [ + { + "group": { + "id": "basic", + "description": "Basic examples without any special cases", + }, + "args": [{"value": [20, -3, 1, -10, 0, 5], "type": "i8"}], + "result": {"value": 20, "type": "i8"}, + }, + { + "group": "basic", + "args": [ + {"value": [-32768, 32767, 20000, -30000], "type": "i16"} + ], + "result": {"value": 32767, "type": "i16"}, + }, + { + "group": "basic", + "args": [ + { + "value": [-214748648, 214748647, 21470048, 4000000], + "type": "i32", + } + ], + "result": {"value": 214748647, "type": "i32"}, + }, + { + "group": "basic", + "args": [ + { + "value": [ + 2000000000, + -3217908979, + 629000000, + -100000000, + 0, + 987654321, + ], + "type": "i64", + } + ], + "result": {"value": 2000000000, "type": "i64"}, + }, + { + "group": "basic", + "args": [{"value": [2.5, 0, 5.0, -2.5, -7.5], "type": "fp32"}], + "result": {"value": 5.0, "type": "fp32"}, + }, + { + "group": "basic", + "args": [ + { + "value": [ + "1.5e+308", + "1.5e+10", + "-1.5e+8", + "-1.5e+7", + "-1.5e+70", + ], + "type": "fp64", + } + ], + "result": {"value": "1.5e+308", "type": "fp64"}, + }, + { + "group": { + "id": "null_handling", + "description": "Examples with null as input or output", + }, + "args": [{"value": ["Null", "Null", "Null"], "type": "i16"}], + "result": {"value": "Null", "type": "i16"}, + }, + { + "group": "null_handling", + "args": [{"value": [], "type": "i16"}], + "result": {"value": "Null", "type": "i16"}, + }, + { + "group": "null_handling", + "args": [ + { + "value": [ + 2000000000, + "Null", + 629000000, + -100000000, + "Null", + 987654321, + ], + "type": "i64", + } + ], + "result": {"value": 2000000000, "type": "i64"}, + }, + { + "group": "null_handling", + "args": [{"value": ["Null", "inf"], "type": "fp64"}], + "result": {"value": "inf", "type": "fp64"}, + }, + { + "group": "null_handling", + "args": [ + { + "value": [ + "Null", + "-inf", + "-1.5e+8", + "-1.5e+7", + "-1.5e+70", + ], + "type": "fp64", + } + ], + "result": {"value": "-1.5e+7", "type": "fp64"}, + }, + { + "group": "null_handling", + "args": [ + { + "value": [ + "1.5e+308", + "1.5e+10", + "Null", + "-1.5e+7", + "Null", + ], + "type": "fp64", + } + ], + "result": {"value": "1.5e+308", "type": "fp64"}, + }, + ], + }, + "https://github.com/substrait-io/substrait/blob/main/extensions/substrait", + False, + [ + "### SUBSTRAIT_SCALAR_TEST: v1.0\n", + "### SUBSTRAIT_INCLUDE: '/extensions/functions_arithmetic.yaml'\n", + "\n# basic: Basic examples without any special cases\n", + "max((20, -3, 1, -10, 0, 5)::i8) = 20::i8\n", + "max((-32768, 32767, 20000, -30000)::i16) = 32767::i16\n", + "max((-214748648, 214748647, 21470048, 4000000)::i32) = 214748647::i32\n", + "max((2000000000, -3217908979, 629000000, -100000000, 0, 987654321)::i64) = 2000000000::i64\n", + "max((2.5, 0, 5.0, -2.5, -7.5)::fp32) = 5.0::fp32\n", + "max((1.5e+308, 1.5e+10, -1.5e+8, -1.5e+7, -1.5e+70)::fp64) = 1.5e+308::fp64\n", + "\n# null_handling: Examples with null as input or output\n", + "max((Null, Null, Null)::i16) = null::i16\n", + "max(()::i16) = null::i16\n", + "max((2000000000, Null, 629000000, -100000000, Null, 987654321)::i64) = 2000000000::i64\n", + "max((Null, inf)::fp64) = inf::fp64\n", + "max((Null, -inf, -1.5e+8, -1.5e+7, -1.5e+70)::fp64) = -1.5e+7::fp64\n", + "max((1.5e+308, 1.5e+10, Null, -1.5e+7, Null)::fp64) = 1.5e+308::fp64\n", + ], + ), + # Test case for "lt" function + ( + { + "base_uri": "https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_datetime.yaml", + "function": "lt", + "cases": [ + { + "group": { + "id": "timestamps", + "description": "examples using the timestamp type", + }, + "args": [ + {"value": "2016-12-31 13:30:15", "type": "timestamp"}, + {"value": "2017-12-31 13:30:15", "type": "timestamp"}, + ], + "result": {"value": True, "type": "boolean"}, + }, + { + "group": "timestamps", + "args": [ + {"value": "2018-12-31 13:30:15", "type": "timestamp"}, + {"value": "2017-12-31 13:30:15", "type": "timestamp"}, + ], + "result": {"value": False, "type": "boolean"}, + }, + { + "group": { + "id": "timestamp_tz", + "description": "examples using the timestamp_tz type", + }, + "args": [ + { + "value": "1999-01-08 01:05:05 PST", + "type": "timestamp_tz", + }, + { + "value": "1999-01-08 04:05:06 EST", + "type": "timestamp_tz", + }, + ], + "result": {"value": True, "type": "boolean"}, + }, + { + "group": "timestamp_tz", + "args": [ + { + "value": "1999-01-08 01:05:06 PST", + "type": "timestamp_tz", + }, + { + "value": "1999-01-08 04:05:06 EST", + "type": "timestamp_tz", + }, + ], + "result": {"value": False, "type": "boolean"}, + }, + { + "group": { + "id": "date", + "description": "examples using the date type", + }, + "args": [ + {"value": "2020-12-30", "type": "date"}, + {"value": "2020-12-31", "type": "date"}, + ], + "result": {"value": True, "type": "boolean"}, + }, + { + "group": "date", + "args": [ + {"value": "2020-12-31", "type": "date"}, + {"value": "2020-12-30", "type": "date"}, + ], + "result": {"value": False, "type": "boolean"}, + }, + { + "group": { + "id": "interval", + "description": "examples using the interval type", + }, + "args": [ + {"value": "INTERVAL '5 DAY'", "type": "interval"}, + {"value": "INTERVAL '6 DAY'", "type": "interval"}, + ], + "result": {"value": True, "type": "boolean"}, + }, + { + "group": "interval", + "args": [ + {"value": "INTERVAL '7 DAY'", "type": "interval"}, + {"value": "INTERVAL '6 DAY'", "type": "interval"}, + ], + "result": {"value": False, "type": "boolean"}, + }, + { + "group": "interval", + "args": [ + {"value": "INTERVAL '5 YEAR'", "type": "interval"}, + {"value": "INTERVAL '6 YEAR'", "type": "interval"}, + ], + "result": {"value": True, "type": "boolean"}, + }, + { + "group": "interval", + "args": [ + {"value": "INTERVAL '7 YEAR'", "type": "interval"}, + {"value": "INTERVAL '6 YEAR'", "type": "interval"}, + ], + "result": {"value": False, "type": "boolean"}, + }, + { + "group": { + "id": "null_input", + "description": "examples with null args or return", + }, + "args": [ + {"value": None, "type": "interval"}, + {"value": "INTERVAL '5 DAY'", "type": "interval"}, + ], + "result": {"value": None, "type": "boolean"}, + }, + { + "group": "null_input", + "args": [ + {"value": None, "type": "date"}, + {"value": "2020-12-30", "type": "date"}, + ], + "result": {"value": None, "type": "boolean"}, + }, + { + "group": "null_input", + "args": [ + {"value": None, "type": "timestamp"}, + {"value": "2018-12-31 13:30:15", "type": "timestamp"}, + ], + "result": {"value": None, "type": "boolean"}, + }, + ], + }, + "https://github.com/substrait-io/substrait/blob/main/extensions/substrait", + False, + [ + "### SUBSTRAIT_SCALAR_TEST: v1.0\n", + "### SUBSTRAIT_INCLUDE: '/extensions/functions_datetime.yaml'\n", + "\n# timestamps: examples using the timestamp type\n", + "lt('2016-12-31T13:30:15'::ts, '2017-12-31T13:30:15'::ts) = true::bool\n", + "lt('2018-12-31T13:30:15'::ts, '2017-12-31T13:30:15'::ts) = false::bool\n", + "\n# timestamp_tz: examples using the timestamp_tz type\n", + "lt('1999-01-08T01:05:05-08:00'::tstz, '1999-01-08T04:05:06-05:00'::tstz) = true::bool\n", + "lt('1999-01-08T01:05:06-08:00'::tstz, '1999-01-08T04:05:06-05:00'::tstz) = false::bool\n", + "\n# date: examples using the date type\n", + "lt('2020-12-30'::date, '2020-12-31'::date) = true::bool\n", + "lt('2020-12-31'::date, '2020-12-30'::date) = false::bool\n", + "\n# interval: examples using the interval type\n", + "lt('P5D'::iday, 'P6D'::iday) = true::bool\n", + "lt('P7D'::iday, 'P6D'::iday) = false::bool\n", + "lt('P5Y'::iyear, 'P6Y'::iyear) = true::bool\n", + "lt('P7Y'::iyear, 'P6Y'::iyear) = false::bool\n", + "\n# null_input: examples with null args or return\n", + "lt(null::iday, 'P5D'::iday) = null::bool\n", + "lt(null::date, '2020-12-30'::date) = null::bool\n", + "lt(null::ts, '2018-12-31T13:30:15'::ts) = null::bool\n", + ], + ), + ( + { + "base_uri": "https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic.yaml", + "function": "power", + "cases": [ + { + "group": { + "id": "basic", + "description": "Basic examples without any special cases", + }, + "args": [ + {"value": 8, "type": "i64"}, + {"value": 2, "type": "i64"}, + ], + "result": {"value": 64, "type": "i64"}, + }, + { + "group": "basic", + "args": [ + {"value": 1.0, "type": "fp32"}, + {"value": -1.0, "type": "fp32"}, + ], + "result": {"value": 1.0, "type": "fp32"}, + }, + { + "group": "basic", + "args": [ + {"value": 2.0, "type": "fp64"}, + {"value": -2.0, "type": "fp64"}, + ], + "result": {"value": 0.25, "type": "fp64"}, + }, + { + "group": "basic", + "args": [ + {"value": 13, "type": "i64"}, + {"value": 10, "type": "i64"}, + ], + "result": {"value": 137858491849, "type": "i64"}, + }, + { + "group": { + "id": "floating_exception", + "description": "Examples demonstrating exceptional floating point cases", + }, + "args": [ + {"value": 1.5e100, "type": "fp64"}, + {"value": 1.5e208, "type": "fp64"}, + ], + "result": {"value": "inf", "type": "fp64"}, + }, + ], + }, + "https://github.com/substrait-io/substrait/blob/main/extensions/substrait", + False, + [ + "### SUBSTRAIT_SCALAR_TEST: v1.0\n", + "### SUBSTRAIT_INCLUDE: '/extensions/functions_arithmetic.yaml'\n", + "\n# basic: Basic examples without any special cases\n", + "power(8::i64, 2::i64) = 64::i64\n", + "power(1.0::fp32, -1.0::fp32) = 1.0::fp32\n", + "power(2.0::fp64, -2.0::fp64) = 0.25::fp64\n", + "power(13::i64, 10::i64) = 137858491849::i64\n", + "\n# floating_exception: Examples demonstrating exceptional floating point cases\n", + "power(1.5e+100::fp64, 1.5e+208::fp64) = inf::fp64\n", + ], + ), + ], +) +def test_convert_test_file_to_new_format( + input_data, prefix, is_aggregate, expected_output +): + result = convert_test_file_to_new_format(input_data, prefix, is_aggregate) + assert result == expected_output diff --git a/tools/convert_tests/test_convert_tests_to_old_format.py b/tools/convert_tests/test_convert_tests_to_old_format.py new file mode 100644 index 00000000..0fe7ffd4 --- /dev/null +++ b/tools/convert_tests/test_convert_tests_to_old_format.py @@ -0,0 +1,238 @@ +import pytest + +from convert_tests_to_old_format import convert_test_file_to_yaml +from tests.coverage.nodes import ( + TestFile, + TestCase, + CaseLiteral, + CaseGroup, +) + + +@pytest.mark.parametrize( + "test_file, expected_yaml", + [ + ( + TestFile( + path="test_path", + version="v1.0", + include="/extensions/functions_arithmetic.yaml", + testcases=[ + TestCase( + func_name="power", + base_uri="https://github.com/substrait-io/substrait", + group=CaseGroup( + name="basic: Basic examples without any special cases", + description="", + ), + options={}, + rows=None, + args=[ + CaseLiteral(value=8, type="i64"), + CaseLiteral(value=2, type="i64"), + ], + result=CaseLiteral(value=64, type="i64"), + comment="", + ), + TestCase( + func_name="power", + base_uri="https://github.com/substrait-io/substrait", + group=CaseGroup( + name="basic: Basic examples without any special cases", + description="", + ), + options={}, + rows=None, + args=[ + CaseLiteral(value=1.0, type="fp32"), + CaseLiteral(value=-1.0, type="fp32"), + ], + result=CaseLiteral(value=1.0, type="fp32"), + comment="", + ), + TestCase( + func_name="power", + base_uri="https://github.com/substrait-io/substrait", + group=CaseGroup( + name="basic: Basic examples without any special cases", + description="", + ), + options={}, + rows=None, + args=[ + CaseLiteral(value=2.0, type="fp64"), + CaseLiteral(value=-2.0, type="fp64"), + ], + result=CaseLiteral(value=0.25, type="fp64"), + comment="", + ), + TestCase( + func_name="power", + base_uri="https://github.com/substrait-io/substrait", + group=CaseGroup( + name="basic: Basic examples without any special cases", + description="", + ), + options={}, + rows=None, + args=[ + CaseLiteral(value=13, type="i64"), + CaseLiteral(value=10, type="i64"), + ], + result=CaseLiteral(value=137858491849, type="i64"), + comment="", + ), + TestCase( + func_name="power", + base_uri="https://github.com/substrait-io/substrait", + group=CaseGroup( + name="floating_exception: Examples demonstrating exceptional floating point cases", + description="", + ), + options={}, + rows=None, + args=[ + CaseLiteral(value=1.5e100, type="fp64"), + CaseLiteral(value=1.5e208, type="fp64"), + ], + result=CaseLiteral(value="inf", type="fp64"), + comment="", + ), + ], + ), + { + "base_uri": "https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic.yaml", + "function": "power", + "cases": [ + { + "group": { + "id": "basic", + "description": "Basic examples without any special cases", + }, + "args": [ + {"value": "8", "type": "i64"}, + {"value": "2", "type": "i64"}, + ], + "result": {"value": "64", "type": "i64"}, + }, + { + "group": "basic", + "args": [ + {"value": "1.0", "type": "fp32"}, + {"value": "-1.0", "type": "fp32"}, + ], + "result": {"value": "1.0", "type": "fp32"}, + }, + { + "group": "basic", + "args": [ + {"value": "2.0", "type": "fp64"}, + {"value": "-2.0", "type": "fp64"}, + ], + "result": {"value": "0.25", "type": "fp64"}, + }, + { + "group": "basic", + "args": [ + {"value": "13", "type": "i64"}, + {"value": "10", "type": "i64"}, + ], + "result": {"value": "137858491849", "type": "i64"}, + }, + { + "group": { + "id": "floating_exception", + "description": "Examples demonstrating exceptional floating point cases", + }, + "args": [ + {"value": "1.5e+100", "type": "fp64"}, + {"value": "1.5e+208", "type": "fp64"}, + ], + "result": {"value": "inf", "type": "fp64"}, + }, + ], + }, + ), + ( + TestFile( + path="test_path", + version="v1.0", + include="/extensions/functions_arithmetic.yaml", + testcases=[ + TestCase( + func_name="max", + base_uri="https://github.com/substrait-io/substrait", + group=CaseGroup( + name="basic: Basic examples without any special cases", + description="", + ), + options={}, + rows=None, + args=[ + CaseLiteral(value=20, type="i8"), + CaseLiteral(value=-3, type="i8"), + CaseLiteral(value=1, type="i8"), + CaseLiteral(value=-10, type="i8"), + CaseLiteral(value=0, type="i8"), + CaseLiteral(value=5, type="i8"), + ], + result=CaseLiteral(value=20, type="i8"), + comment="", + ), + TestCase( + func_name="max", + base_uri="https://github.com/substrait-io/substrait", + group=CaseGroup( + name="basic: Basic examples without any special cases", + description="", + ), + options={}, + rows=None, + args=[ + CaseLiteral(value=-32768, type="i16"), + CaseLiteral(value=32767, type="i16"), + CaseLiteral(value=20000, type="i16"), + CaseLiteral(value=-30000, type="i16"), + ], + result=CaseLiteral(value=32767, type="i16"), + comment="", + ), + ], + ), + { + "base_uri": "https://github.com/substrait-io/substrait/blob/main/extensions/substrait/extensions/functions_arithmetic.yaml", + "function": "max", + "cases": [ + { + "group": { + "id": "basic", + "description": "Basic examples without any special cases", + }, + "args": [ + {"value": "20", "type": "i8"}, + {"value": "-3", "type": "i8"}, + {"value": "1", "type": "i8"}, + {"value": "-10", "type": "i8"}, + {"value": "0", "type": "i8"}, + {"value": "5", "type": "i8"}, + ], + "result": {"value": "20", "type": "i8"}, + }, + { + "group": "basic", + "args": [ + {"value": "-32768", "type": "i16"}, + {"value": "32767", "type": "i16"}, + {"value": "20000", "type": "i16"}, + {"value": "-30000", "type": "i16"}, + ], + "result": {"value": "32767", "type": "i16"}, + }, + ], + }, + ), + ], +) +def test_convert_test_file_to_yaml(test_file, expected_yaml): + result = convert_test_file_to_yaml(test_file) + assert result == expected_yaml