From a6d845ee8df178c4dcb18847bac1679ff5937352 Mon Sep 17 00:00:00 2001
From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com>
Date: Tue, 21 Jan 2025 23:45:15 +0100
Subject: [PATCH] test: more polars enabled tests

---
 python/tests/test_table_read.py | 192 ++++++++++++++++++++++++++++++++
 1 file changed, 192 insertions(+)

diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py
index 5fe0e413b3..fa07b0992d 100644
--- a/python/tests/test_table_read.py
+++ b/python/tests/test_table_read.py
@@ -57,12 +57,47 @@ def test_read_table_with_edge_timestamps():
     assert len(list(dataset.get_fragments(predicate))) == 1
 
 
+@pytest.mark.polars
+def test_read_table_with_edge_timestamps_polars():
+    os.environ["POLARS_NEW_MULTIFILE"] = "1"
+    import polars as pl
+
+    table_path = "../crates/test/tests/data/table_with_edge_timestamps"
+    dt = DeltaTable(table_path)
+    table = pl.scan_delta(dt).collect().to_arrow()
+    assert table.to_pydict() == {
+        "BIG_DATE": [
+            datetime(9999, 12, 31, 0, 0, 0, tzinfo=timezone.utc),
+            datetime(9999, 12, 30, 0, 0, 0, tzinfo=timezone.utc),
+        ],
+        "NORMAL_DATE": [
+            datetime(2022, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
+            datetime(2022, 2, 1, 0, 0, 0, tzinfo=timezone.utc),
+        ],
+        "SOME_VALUE": [1, 2],
+    }
+    # Filters on these edge timestamps are pushed down through the lazy scan
+    filtered = pl.scan_delta(dt).filter(
+        pl.col("BIG_DATE") == datetime(9999, 12, 31, 0, 0, 0, tzinfo=timezone.utc)
+    ).collect()
+    assert filtered.height == 1
+
+
 def test_read_simple_table_to_dict():
     table_path = "../crates/test/tests/data/simple_table"
     dt = DeltaTable(table_path)
     assert dt.to_pyarrow_dataset().to_table().to_pydict() == {"id": [5, 7, 9]}
 
 
+@pytest.mark.polars
+def test_read_simple_table_to_dict_polars():
+    import polars as pl
+
+    table_path = "../crates/test/tests/data/simple_table"
+    dt = DeltaTable(table_path)
+    assert pl.scan_delta(dt).collect().to_arrow().to_pydict() == {"id": [5, 7, 9]}
+
+
 class _SerializableException(BaseException):
     pass
 
@@ -85,6 +120,24 @@ def _recursively_read_simple_table(executor_class: Type[Executor], depth):
         future.result()
 
 
+def _recursively_read_simple_table_polars(executor_class: Type[Executor], depth):
+    try:
+        test_read_simple_table_to_dict_polars()
+    except BaseException as e:  # Ideally this would catch `pyo3_runtime.PanicException`, but it seems that is not possible.
+        # Re-raise as something that can be serialized and therefore sent back to parent processes.
+        raise _SerializableException(f"Serializable exception: {e}") from e
+
+    if depth == 0:
+        return
+    # We use concurrent.futures.Executors instead of `threading.Thread` or `multiprocessing.Process` so that errors are re-raised in the parent process/thread when we call `future.result()`.
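+    # Note: a forked child inherits the parent's live tokio runtime, which panics ("The tokio runtime does not support forked processes"); the parametrized test below asserts exactly that for the "fork" start method.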
+    with executor_class(max_workers=1) as executor:
+        future = executor.submit(
+            _recursively_read_simple_table_polars, executor_class, depth - 1
+        )
+        future.result()
+
+
 @pytest.mark.parametrize(
     "executor_class,multiprocessing_start_method,expect_panic",
     [
@@ -109,6 +162,42 @@ def test_read_simple_in_threads_and_processes(
     _recursively_read_simple_table(executor_class=executor_class, depth=5)
 
 
+@pytest.mark.polars
+@pytest.mark.parametrize(
+    "executor_class,multiprocessing_start_method,expect_panic",
+    [
+        (ThreadPoolExecutor, None, False),
+        (ProcessPoolExecutor, "forkserver", False),
+        (ProcessPoolExecutor, "spawn", False),
+        (ProcessPoolExecutor, "fork", True),
+    ],
+)
+def test_read_simple_in_threads_and_processes_polars(
+    executor_class, multiprocessing_start_method, expect_panic
+):
+    if multiprocessing_start_method is not None:
+        multiprocessing.set_start_method(multiprocessing_start_method, force=True)
+    if expect_panic:
+        with pytest.raises(
+            _SerializableException,
+            match="The tokio runtime does not support forked processes",
+        ):
+            _recursively_read_simple_table_polars(
+                executor_class=executor_class, depth=5
+            )
+    else:
+        _recursively_read_simple_table_polars(executor_class=executor_class, depth=5)
+
+
+@pytest.mark.polars
+def test_read_simple_table_by_version_to_dict_polars():
+    import polars as pl
+
+    table_path = "../crates/test/tests/data/delta-0.2.0"
+    dt = DeltaTable(table_path, version=2)
+    assert pl.scan_delta(dt).collect().to_arrow().to_pydict() == {"value": [1, 2, 3]}
+
+
 def test_read_simple_table_by_version_to_dict():
     table_path = "../crates/test/tests/data/delta-0.2.0"
     dt = DeltaTable(table_path, version=2)
@@ -218,6 +307,18 @@ def test_read_simple_table_update_incremental():
     assert dt.to_pyarrow_dataset().to_table().to_pydict() == {"id": [5, 7, 9]}
 
 
+@pytest.mark.polars
+def test_read_simple_table_update_incremental_polars():
+    import polars as pl
+
+    table_path = "../crates/test/tests/data/simple_table"
+    dt = DeltaTable(table_path, version=0)
+    data = pl.scan_delta(dt).collect().to_arrow()
+    assert data.to_pydict() == {"id": [0, 1, 2, 3, 4]}
+    dt.update_incremental()  # re-scan afterwards: `data` is a materialized snapshot of version 0
+    assert pl.scan_delta(dt).collect().to_arrow().to_pydict() == {"id": [5, 7, 9]}
+
+
 def test_read_simple_table_file_sizes_failure(mocker):
     table_path = "../crates/test/tests/data/simple_table"
     dt = DeltaTable(table_path)
@@ -235,6 +336,22 @@ def test_read_simple_table_file_sizes_failure(mocker):
     dt.to_pyarrow_dataset().to_table().to_pydict()
 
 
+@pytest.mark.polars
+def test_read_partitioned_table_to_dict_polars():
+    os.environ["POLARS_NEW_MULTIFILE"] = "1"
+    import polars as pl
+
+    table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
+    dt = DeltaTable(table_path)
+    expected = {
+        "value": ["1", "2", "3", "6", "7", "5", "4"],
+        "year": ["2020", "2020", "2020", "2021", "2021", "2021", "2021"],
+        "month": ["1", "2", "2", "12", "12", "12", "4"],
+        "day": ["1", "3", "5", "20", "20", "4", "5"],
+    }
+    assert pl.scan_delta(dt).collect().to_arrow().to_pydict() == expected
+
+
 def test_read_partitioned_table_to_dict():
     table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
     dt = DeltaTable(table_path)
@@ -261,6 +378,27 @@ def test_read_partitioned_table_with_partitions_filters_to_dict():
     assert dt.to_pyarrow_dataset(partitions).to_table().to_pydict() == expected
 
 
+@pytest.mark.polars
+def test_read_partitioned_table_with_filters_to_dict_polars():
+    os.environ["POLARS_NEW_MULTIFILE"] = "1"
+    import polars as pl
+
+    table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
+    dt = DeltaTable(table_path)
+    predicate = pl.col("year") == "2021"
partitions = pl.col("year") == "2021" + expected = { + "value": ["6", "7", "5", "4"], + "year": ["2021", "2021", "2021", "2021"], + "month": ["12", "12", "12", "4"], + "day": ["20", "20", "4", "5"], + } + + assert ( + pl.scan_delta(dt).filter(partitions).collect().to_arrow().to_pydict() + == expected + ) + + def test_read_empty_delta_table_after_delete(): table_path = "../crates/test/tests/data/delta-0.8-empty" dt = DeltaTable(table_path) @@ -269,6 +407,17 @@ def test_read_empty_delta_table_after_delete(): assert dt.to_pyarrow_dataset().to_table().to_pydict() == expected +@pytest.mark.polars +def test_read_empty_delta_table_after_delete_polars(): + import polars as pl + + table_path = "../crates/test/tests/data/delta-0.8-empty" + dt = DeltaTable(table_path) + expected = {"column": []} + + assert pl.scan_delta(dt).collect().to_arrow().to_pydict() == expected + + def test_read_table_with_column_subset(): table_path = "../crates/test/tests/data/delta-0.8.0-partitioned" dt = DeltaTable(table_path) @@ -282,6 +431,22 @@ def test_read_table_with_column_subset(): ) +@pytest.mark.polars +def test_read_table_with_column_subset_polars(): + import polars as pl + + table_path = "../crates/test/tests/data/delta-0.8.0-partitioned" + dt = DeltaTable(table_path) + expected = { + "value": ["1", "2", "3", "6", "7", "5", "4"], + "day": ["1", "3", "5", "20", "20", "4", "5"], + } + assert ( + pl.scan_delta(dt).select(["value", "day"]).collect().to_arrow().to_pydict() + == expected + ) + + def test_read_table_as_category(): table_path = "../crates/test/tests/data/delta-0.8.0-partitioned" dt = DeltaTable(table_path) @@ -359,6 +524,33 @@ def test_read_special_partition(): assert set(table["x"].to_pylist()) == {"A/A", "B B"} +@pytest.mark.polars +def test_read_special_partition_polars(): + os.environ["POLARS_NEW_MULTIFILE"] = "1" + import polars as pl + + table_path = "../crates/test/tests/data/delta-0.8.0-special-partition" + + dt = DeltaTable(table_path) + + file1 = ( + r"x=A%2FA/part-00007-b350e235-2832-45df-9918-6cab4f7578f7.c000.snappy.parquet" + ) + file2 = ( + r"x=B%20B/part-00015-e9abbc6f-85e9-457b-be8e-e9f5b8a22890.c000.snappy.parquet" + ) + + assert set(dt.files()) == {file1, file2} + + assert dt.files([("x", "=", "A/A")]) == [file1] + assert dt.files([("x", "=", "B B")]) == [file2] + assert dt.files([("x", "=", "c")]) == [] + + table = pl.scan_delta(dt).collect().to_arrow() + + assert set(table["x"].to_pylist()) == {"A/A", "B B"} + + def test_read_partitioned_table_metadata(): table_path = "../crates/test/tests/data/delta-0.8.0-partitioned" dt = DeltaTable(table_path)