From 23e1619a7edc6c75f2ba7835c712109835660796 Mon Sep 17 00:00:00 2001 From: Igor Tavares Date: Thu, 19 Sep 2019 16:13:11 -0300 Subject: [PATCH] Add compression for Pandas.to_parquet --- awswrangler/exceptions.py | 4 ++ awswrangler/glue.py | 48 ++++++++++++++------- awswrangler/pandas.py | 56 +++++++++++++++++++++++-- testing/test_awswrangler/test_pandas.py | 23 ++++++++++ 4 files changed, 113 insertions(+), 18 deletions(-) diff --git a/awswrangler/exceptions.py b/awswrangler/exceptions.py index b87cc4bd3..e82e3c11a 100644 --- a/awswrangler/exceptions.py +++ b/awswrangler/exceptions.py @@ -72,3 +72,7 @@ class InvalidSerDe(Exception): class ApiError(Exception): pass + + +class InvalidCompression(Exception): + pass diff --git a/awswrangler/glue.py b/awswrangler/glue.py index d1597df20..c224a5ce0 100644 --- a/awswrangler/glue.py +++ b/awswrangler/glue.py @@ -133,12 +133,14 @@ def metadata_to_glue(self, partition_cols=None, preserve_index=True, mode="append", + compression=None, cast_columns=None, extra_args=None): schema, partition_cols_schema = Glue._build_schema( dataframe=dataframe, partition_cols=partition_cols, - preserve_index=preserve_index) + preserve_index=preserve_index, + cast_columns=cast_columns) table = table if table else Glue._parse_table_name(path) table = table.lower().replace(".", "_") if mode == "overwrite": @@ -151,6 +153,7 @@ def metadata_to_glue(self, partition_cols_schema=partition_cols_schema, path=path, file_format=file_format, + compression=compression, extra_args=extra_args) if partition_cols: partitions_tuples = Glue._parse_partitions_tuples( @@ -159,6 +162,7 @@ def metadata_to_glue(self, table=table, partition_paths=partitions_tuples, file_format=file_format, + compression=compression, extra_args=extra_args) def delete_table_if_exists(self, database, table): @@ -180,16 +184,18 @@ def create_table(self, schema, path, file_format, + compression, partition_cols_schema=None, extra_args=None): if file_format == "parquet": table_input = Glue.parquet_table_definition( - table, partition_cols_schema, schema, path) + table, partition_cols_schema, schema, path, compression) elif file_format == "csv": table_input = Glue.csv_table_definition(table, partition_cols_schema, schema, path, + compression, extra_args=extra_args) else: raise UnsupportedFileFormat(file_format) @@ -227,15 +233,21 @@ def get_connection_details(self, name): Name=name, HidePassword=False)["Connection"] @staticmethod - def _extract_pyarrow_schema(dataframe, preserve_index): + def _extract_pyarrow_schema(dataframe, preserve_index, cast_columns=None): cols = [] cols_dtypes = {} schema = [] + casted = [] + if cast_columns is not None: + casted = cast_columns.keys() + for name, dtype in dataframe.dtypes.to_dict().items(): dtype = str(dtype) - if str(dtype) == "Int64": + if dtype == "Int64": cols_dtypes[name] = "int64" + elif name in casted: + cols_dtypes[name] = cast_columns[name] else: cols.append(name) @@ -252,13 +264,18 @@ def _extract_pyarrow_schema(dataframe, preserve_index): return schema @staticmethod - def _build_schema(dataframe, partition_cols, preserve_index): + def _build_schema(dataframe, + partition_cols, + preserve_index, + cast_columns={}): logger.debug(f"dataframe.dtypes:\n{dataframe.dtypes}") if not partition_cols: partition_cols = [] pyarrow_schema = Glue._extract_pyarrow_schema( - dataframe=dataframe, preserve_index=preserve_index) + dataframe=dataframe, + preserve_index=preserve_index, + cast_columns=cast_columns) schema_built = [] partition_cols_types = {} @@ -285,9 +302,10 @@ def 
_parse_table_name(path): @staticmethod def csv_table_definition(table, partition_cols_schema, schema, path, - extra_args): + compression, extra_args): if not partition_cols_schema: partition_cols_schema = [] + compressed = False if compression is None else True sep = extra_args["sep"] if "sep" in extra_args else "," serde = extra_args.get("serde") if serde == "OpenCSVSerDe": @@ -322,7 +340,7 @@ def csv_table_definition(table, partition_cols_schema, schema, path, "EXTERNAL_TABLE", "Parameters": { "classification": "csv", - "compressionType": "none", + "compressionType": str(compression).lower(), "typeOfData": "file", "delimiter": sep, "columnsOrdered": "true", @@ -337,7 +355,7 @@ def csv_table_definition(table, partition_cols_schema, schema, path, "InputFormat": "org.apache.hadoop.mapred.TextInputFormat", "OutputFormat": "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat", - "Compressed": False, + "Compressed": True, "NumberOfBuckets": -1, "SerdeInfo": { "Parameters": param, @@ -347,7 +365,7 @@ def csv_table_definition(table, partition_cols_schema, schema, path, "SortColumns": [], "Parameters": { "classification": "csv", - "compressionType": "none", + "compressionType": str(compression).lower(), "typeOfData": "file", "delimiter": sep, "columnsOrdered": "true", @@ -386,9 +404,11 @@ def csv_partition_definition(partition, extra_args): } @staticmethod - def parquet_table_definition(table, partition_cols_schema, schema, path): + def parquet_table_definition(table, partition_cols_schema, schema, path, + compression): if not partition_cols_schema: partition_cols_schema = [] + compressed = False if compression is None else True return { "Name": table, @@ -400,7 +420,7 @@ def parquet_table_definition(table, partition_cols_schema, schema, path): "EXTERNAL_TABLE", "Parameters": { "classification": "parquet", - "compressionType": "none", + "compressionType": str(compression).lower(), "typeOfData": "file", }, "StorageDescriptor": { @@ -413,7 +433,7 @@ def parquet_table_definition(table, partition_cols_schema, schema, path): "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", "OutputFormat": "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", - "Compressed": False, + "Compressed": compressed, "NumberOfBuckets": -1, "SerdeInfo": { "SerializationLibrary": @@ -427,7 +447,7 @@ def parquet_table_definition(table, partition_cols_schema, schema, path): "Parameters": { "CrawlerSchemaDeserializerVersion": "1.0", "classification": "parquet", - "compressionType": "none", + "compressionType": str(compression).lower(), "typeOfData": "file", }, }, diff --git a/awswrangler/pandas.py b/awswrangler/pandas.py index 50846c92a..152a10996 100644 --- a/awswrangler/pandas.py +++ b/awswrangler/pandas.py @@ -11,7 +11,8 @@ from pyarrow import parquet from awswrangler.exceptions import UnsupportedWriteMode, UnsupportedFileFormat,\ - AthenaQueryError, EmptyS3Object, LineTerminatorNotFound, EmptyDataframe, InvalidSerDe + AthenaQueryError, EmptyS3Object, LineTerminatorNotFound, EmptyDataframe, \ + InvalidSerDe, InvalidCompression from awswrangler.utils import calculate_bounders from awswrangler import s3 @@ -28,6 +29,8 @@ def _get_bounders(dataframe, num_partitions): class Pandas: VALID_CSV_SERDES = ["OpenCSVSerDe", "LazySimpleSerDe"] + VALID_CSV_COMPRESSIONS = [None] + VALID_PARQUET_COMPRESSIONS = [None, "snappy", "gzip"] def __init__(self, session): self._session = session @@ -479,6 +482,7 @@ def to_csv( partition_cols=partition_cols, preserve_index=preserve_index, mode=mode, + compression=None, 
procs_cpu_bound=procs_cpu_bound, procs_io_bound=procs_io_bound, extra_args=extra_args) @@ -491,6 +495,7 @@ def to_parquet(self, partition_cols=None, preserve_index=True, mode="append", + compression="snappy", procs_cpu_bound=None, procs_io_bound=None, cast_columns=None): @@ -505,6 +510,7 @@ def to_parquet(self, :param partition_cols: List of columns names that will be partitions on S3 :param preserve_index: Should preserve index on S3? :param mode: "append", "overwrite", "overwrite_partitions" + :param compression: None, snappy, gzip, lzo :param procs_cpu_bound: Number of cores used for CPU bound tasks :param procs_io_bound: Number of cores used for I/O bound tasks :param cast_columns: Dictionary of columns names and Arrow types to be casted. (E.g. {"col name": "int64", "col2 name": "int32"}) @@ -518,6 +524,7 @@ def to_parquet(self, partition_cols=partition_cols, preserve_index=preserve_index, mode=mode, + compression=compression, procs_cpu_bound=procs_cpu_bound, procs_io_bound=procs_io_bound, cast_columns=cast_columns) @@ -531,6 +538,7 @@ def to_s3(self, partition_cols=None, preserve_index=True, mode="append", + compression=None, procs_cpu_bound=None, procs_io_bound=None, cast_columns=None, @@ -547,12 +555,25 @@ def to_s3(self, :param partition_cols: List of columns names that will be partitions on S3 :param preserve_index: Should preserve index on S3? :param mode: "append", "overwrite", "overwrite_partitions" + :param compression: None, gzip, snappy, etc :param procs_cpu_bound: Number of cores used for CPU bound tasks :param procs_io_bound: Number of cores used for I/O bound tasks :param cast_columns: Dictionary of columns indexes and Arrow types to be casted. (E.g. {2: "int64", 5: "int32"}) (Only for "parquet" file_format) :param extra_args: Extra arguments specific for each file formats (E.g. "sep" for CSV) :return: List of objects written on S3 """ + if compression is not None: + compression = compression.lower() + if file_format == "csv": + if compression not in Pandas.VALID_CSV_COMPRESSIONS: + raise InvalidCompression( + f"{compression} isn't a valid CSV compression. Try: {Pandas.VALID_CSV_COMPRESSIONS}" + ) + if file_format == "parquet": + if compression not in Pandas.VALID_PARQUET_COMPRESSIONS: + raise InvalidCompression( + f"{compression} isn't a valid PARQUET compression. 
Try: {Pandas.VALID_PARQUET_COMPRESSIONS}" + ) if dataframe.empty: raise EmptyDataframe() if not partition_cols: @@ -568,6 +589,7 @@ def to_s3(self, preserve_index=preserve_index, file_format=file_format, mode=mode, + compression=compression, procs_cpu_bound=procs_cpu_bound, procs_io_bound=procs_io_bound, cast_columns=cast_columns, @@ -582,6 +604,7 @@ def to_s3(self, preserve_index=preserve_index, file_format=file_format, mode=mode, + compression=compression, cast_columns=cast_columns, extra_args=extra_args) return objects_paths @@ -593,6 +616,7 @@ def data_to_s3(self, partition_cols=None, preserve_index=True, mode="append", + compression=None, procs_cpu_bound=None, procs_io_bound=None, cast_columns=None, @@ -619,7 +643,7 @@ def data_to_s3(self, proc = mp.Process( target=self._data_to_s3_dataset_writer_remote, args=(send_pipe, dataframe.iloc[bounder[0]:bounder[1], :], - path, partition_cols, preserve_index, + path, partition_cols, preserve_index, compression, self._session.primitives, file_format, cast_columns, extra_args), ) @@ -637,6 +661,7 @@ def data_to_s3(self, path=path, partition_cols=partition_cols, preserve_index=preserve_index, + compression=compression, session_primitives=self._session.primitives, file_format=file_format, cast_columns=cast_columns, @@ -658,6 +683,7 @@ def _data_to_s3_dataset_writer(dataframe, path, partition_cols, preserve_index, + compression, session_primitives, file_format, cast_columns=None, @@ -668,6 +694,7 @@ def _data_to_s3_dataset_writer(dataframe, dataframe=dataframe, path=path, preserve_index=preserve_index, + compression=compression, session_primitives=session_primitives, file_format=file_format, cast_columns=cast_columns, @@ -686,6 +713,7 @@ def _data_to_s3_dataset_writer(dataframe, dataframe=subgroup, path=prefix, preserve_index=preserve_index, + compression=compression, session_primitives=session_primitives, file_format=file_format, cast_columns=cast_columns, @@ -699,6 +727,7 @@ def _data_to_s3_dataset_writer_remote(send_pipe, path, partition_cols, preserve_index, + compression, session_primitives, file_format, cast_columns=None, @@ -709,6 +738,7 @@ def _data_to_s3_dataset_writer_remote(send_pipe, path=path, partition_cols=partition_cols, preserve_index=preserve_index, + compression=compression, session_primitives=session_primitives, file_format=file_format, cast_columns=cast_columns, @@ -719,6 +749,7 @@ def _data_to_s3_dataset_writer_remote(send_pipe, def _data_to_s3_object_writer(dataframe, path, preserve_index, + compression, session_primitives, file_format, cast_columns=None, @@ -726,10 +757,21 @@ def _data_to_s3_object_writer(dataframe, fs = s3.get_fs(session_primitives=session_primitives) fs = pyarrow.filesystem._ensure_filesystem(fs) s3.mkdir_if_not_exists(fs, path) + + if compression is None: + compression_end = "" + elif compression == "snappy": + compression_end = ".snappy" + elif compression == "gzip": + compression_end = ".gz" + else: + raise InvalidCompression(compression) + + guid = pyarrow.compat.guid() if file_format == "parquet": - outfile = pyarrow.compat.guid() + ".parquet" + outfile = f"{guid}.parquet{compression_end}" elif file_format == "csv": - outfile = pyarrow.compat.guid() + ".csv" + outfile = f"{guid}.csv{compression_end}" else: raise UnsupportedFileFormat(file_format) object_path = "/".join([path, outfile]) @@ -737,6 +779,7 @@ def _data_to_s3_object_writer(dataframe, Pandas.write_parquet_dataframe(dataframe=dataframe, path=object_path, preserve_index=preserve_index, + compression=compression, fs=fs, 
cast_columns=cast_columns, extra_args=extra_args) @@ -744,6 +787,7 @@ def _data_to_s3_object_writer(dataframe, Pandas.write_csv_dataframe(dataframe=dataframe, path=object_path, preserve_index=preserve_index, + compression=compression, fs=fs, extra_args=extra_args) return object_path @@ -752,6 +796,7 @@ def _data_to_s3_object_writer(dataframe, def write_csv_dataframe(dataframe, path, preserve_index, + compression, fs, extra_args=None): csv_extra_args = {} @@ -770,6 +815,7 @@ def write_csv_dataframe(dataframe, dataframe.to_csv(None, header=False, index=preserve_index, + compression=compression, **csv_extra_args), "utf-8") with fs.open(path, "wb") as f: f.write(csv_buffer) @@ -778,6 +824,7 @@ def write_csv_dataframe(dataframe, def write_parquet_dataframe(dataframe, path, preserve_index, + compression, fs, cast_columns, extra_args=None): @@ -804,6 +851,7 @@ def write_parquet_dataframe(dataframe, with fs.open(path, "wb") as f: parquet.write_table(table, f, + compression=compression, coerce_timestamps="ms", flavor="spark") for col in casted_in_pandas: diff --git a/testing/test_awswrangler/test_pandas.py b/testing/test_awswrangler/test_pandas.py index ad28b01cb..f1bd57dde 100644 --- a/testing/test_awswrangler/test_pandas.py +++ b/testing/test_awswrangler/test_pandas.py @@ -611,3 +611,26 @@ def test_to_csv_serde_exception( preserve_index=False, mode="overwrite", serde="foo") + + +@pytest.mark.parametrize("compression", [None, "snappy", "gzip"]) +def test_to_parquet_compressed(session, bucket, database, compression): + dataframe = pandas.read_csv("data_samples/small.csv") + session.pandas.to_parquet(dataframe=dataframe, + database=database, + path=f"s3://{bucket}/test/", + preserve_index=False, + mode="overwrite", + compression=compression, + procs_cpu_bound=1) + dataframe2 = None + for counter in range(10): + dataframe2 = session.pandas.read_sql_athena(sql="select * from test", + database=database) + if len(dataframe.index) == len(dataframe2.index): + break + sleep(2) + assert len(dataframe.index) == len(dataframe2.index) + assert len(list(dataframe.columns)) == len(list(dataframe2.columns)) + assert dataframe[dataframe["id"] == 1].iloc[0]["name"] == dataframe2[ + dataframe2["id"] == 1].iloc[0]["name"]
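
Usage sketch (not part of the patch): the call below mirrors the new test_to_parquet_compressed test and shows the compression argument this patch adds, assuming the pre-1.0 Session object that the test suite's session fixture wraps; the bucket and database names are placeholders.

    import pandas
    import awswrangler

    session = awswrangler.Session()
    dataframe = pandas.read_csv("data_samples/small.csv")

    # For Parquet the patch accepts None, "snappy" (the default) and "gzip".
    session.pandas.to_parquet(dataframe=dataframe,
                              database="my-db",             # placeholder Glue database
                              path="s3://my-bucket/test/",  # placeholder bucket
                              preserve_index=False,
                              mode="overwrite",
                              compression="gzip")

With compression="gzip", _data_to_s3_object_writer names each object "<guid>.parquet.gz"; "snappy" yields "<guid>.parquet.snappy", and None keeps the plain ".parquet" suffix.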
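The validation added to Pandas.to_s3 is strict: Parquet accepts only the codecs in VALID_PARQUET_COMPRESSIONS (None, "snappy", "gzip") and CSV currently accepts only None, so any other value, e.g. "lzo", raises the new InvalidCompression exception. A minimal sketch, reusing the session and dataframe from the example above:

    from awswrangler.exceptions import InvalidCompression

    try:
        session.pandas.to_parquet(dataframe=dataframe,
                                  database="my-db",
                                  path="s3://my-bucket/test_lzo/",  # placeholder path
                                  compression="lzo")
    except InvalidCompression as error:
        print(f"Unsupported codec rejected: {error}")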
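On the catalog side, parquet_table_definition now records the codec in the Glue table parameters ("compressionType") and sets StorageDescriptor.Compressed accordingly. A hedged way to inspect that after the write above, assuming boto3 is configured for the same account and region as the Session, with "my-db" and the path-derived table name "test" following the placeholders used earlier:

    import boto3

    glue = boto3.client("glue")
    table = glue.get_table(DatabaseName="my-db", Name="test")["Table"]
    print(table["Parameters"]["compressionType"])    # "gzip" for the write above
    print(table["StorageDescriptor"]["Compressed"])  # True when a codec is set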