diff --git a/README.md b/README.md
index 528071e07..f6f5b82ef 100644
--- a/README.md
+++ b/README.md
@@ -22,11 +22,12 @@
 * Pandas -> Redshift (Parallel)
 * CSV (S3) -> Pandas (One shot or Batching)
 * Athena -> Pandas (One shot or Batching)
-* CloudWatch Logs Insights -> Pandas (NEW :star:)
-* Encrypt Pandas Dataframes on S3 with KMS keys (NEW :star:)
+* CloudWatch Logs Insights -> Pandas
+* Encrypt Pandas Dataframes on S3 with KMS keys
 
 ### PySpark
-* PySpark -> Redshift (Parallel) (NEW :star:)
+* PySpark -> Redshift (Parallel)
+* Register Glue table from Dataframe stored on S3 (NEW :star:)
 
 ### General
 * List S3 objects (Parallel)
@@ -35,7 +36,8 @@
 * Delete NOT listed S3 objects (Parallel)
 * Copy listed S3 objects (Parallel)
 * Get the size of S3 objects (Parallel)
-* Get CloudWatch Logs Insights query results (NEW :star:)
+* Get CloudWatch Logs Insights query results
+* Load partitions on Athena/Glue table (repair table) (NEW :star:)
 
 ## Installation
 
@@ -166,6 +168,23 @@ session.spark.to_redshift(
 )
 ```
+#### Register Glue table from Dataframe stored on S3
+
+```py3
+dataframe.write \
+    .mode("overwrite") \
+    .format("parquet") \
+    .partitionBy(["year", "month"]) \
+    .save(compression="gzip", path="s3://...")
+session = awswrangler.Session(spark_session=spark)
+session.spark.create_glue_table(dataframe=dataframe,
+                                file_format="parquet",
+                                partition_by=["year", "month"],
+                                path="s3://...",
+                                compression="gzip",
+                                database="my_database")
+```
+
 ### General
 
 #### Deleting a bunch of S3 objects (parallel :rocket:)
@@ -185,6 +204,13 @@ results = session.cloudwatchlogs.query(
 )
 ```
+#### Load partitions on Athena/Glue table (repair table)
+
+```py3
+session = awswrangler.Session()
+session.athena.repair_table(database="db_name", table="tbl_name")
+```
+
 ## Diving Deep
 
 ### Pandas to Redshift Flow
@@ -217,13 +243,13 @@ results = session.cloudwatchlogs.query(
 
 * Fork the AWS Data Wrangler repository and clone that into your development environment
 
-* Go to the project's directory create a Python's virtual environment for the project (**python -m venv venv && source source venv/bin/activate**)
+* Go to the project's directory and create a Python virtual environment for the project (**python -m venv venv && source venv/bin/activate**)
 
 * Run **./install-dev.sh**
 
 * Go to the *testing* directory
 
-* Configure the parameters.json file with your AWS environment infos (Make sure that your Redshift will not be open for the World!)
+* Configure the parameters.json file with your AWS environment information (make sure that your Redshift cluster will not be open to the world! Configure your security group to allow access only from your IP.)
 
 * Deploy the Cloudformation stack **./deploy-cloudformation.sh**
 
diff --git a/awswrangler/athena.py b/awswrangler/athena.py
index a7595b6e9..7196880cb 100644
--- a/awswrangler/athena.py
+++ b/awswrangler/athena.py
@@ -118,3 +118,25 @@ def wait_query(self, query_execution_id):
             raise QueryCancelled(
                 response["QueryExecution"]["Status"].get("StateChangeReason"))
         return response
+
+    def repair_table(self, database, table, s3_output=None):
+        """
+        Hive's metastore consistency check:
+        "MSCK REPAIR TABLE table;"
+        Recovers partitions and data associated with partitions.
+        Use this statement when you add partitions to the catalog.
+        It is possible it will take some time to add all partitions.
+        If this operation times out, it will be in an incomplete state
+        where only a few partitions are added to the catalog.
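+
+        Example (usage sketch; the database and table names are placeholders):
+
+            >>> session = awswrangler.Session()
+            >>> session.athena.repair_table(database="db_name", table="tbl_name")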
+
+        :param database: Glue database name
+        :param table: Glue table name
+        :param s3_output: AWS S3 path
+        :return: Query execution ID
+        """
+        query = f"MSCK REPAIR TABLE {table};"
+        query_id = self.run_query(query=query,
+                                  database=database,
+                                  s3_output=s3_output)
+        self.wait_query(query_execution_id=query_id)
+        return query_id
diff --git a/awswrangler/glue.py b/awswrangler/glue.py
index c9c45ae67..f3d2790d5 100644
--- a/awswrangler/glue.py
+++ b/awswrangler/glue.py
@@ -141,7 +141,7 @@ def metadata_to_glue(self,
             partition_cols=partition_cols,
             preserve_index=preserve_index,
             cast_columns=cast_columns)
-        table = table if table else Glue._parse_table_name(path)
+        table = table if table else Glue.parse_table_name(path)
         table = table.lower().replace(".", "_")
         if mode == "overwrite":
             self.delete_table_if_exists(database=database, table=table)
@@ -301,7 +301,7 @@ def _build_schema(dataframe,
         return schema_built, partition_cols_schema_built
 
     @staticmethod
-    def _parse_table_name(path):
+    def parse_table_name(path):
         if path[-1] == "/":
             path = path[:-1]
         return path.rpartition("/")[2]
diff --git a/awswrangler/pandas.py b/awswrangler/pandas.py
index 661f10cf3..6a56095cb 100644
--- a/awswrangler/pandas.py
+++ b/awswrangler/pandas.py
@@ -576,12 +576,15 @@ def to_s3(self,
         """
         if compression is not None:
             compression = compression.lower()
+        file_format = file_format.lower()
+        if file_format not in ["parquet", "csv"]:
+            raise UnsupportedFileFormat(file_format)
         if file_format == "csv":
             if compression not in Pandas.VALID_CSV_COMPRESSIONS:
                 raise InvalidCompression(
                     f"{compression} isn't a valid CSV compression. Try: {Pandas.VALID_CSV_COMPRESSIONS}"
                 )
-        if file_format == "parquet":
+        elif file_format == "parquet":
             if compression not in Pandas.VALID_PARQUET_COMPRESSIONS:
                 raise InvalidCompression(
                     f"{compression} isn't a valid PARQUET compression. Try: {Pandas.VALID_PARQUET_COMPRESSIONS}"
                 )
@@ -641,9 +644,6 @@ def data_to_s3(self,
         logger.debug(f"procs_io_bound: {procs_io_bound}")
         if path[-1] == "/":
             path = path[:-1]
-        file_format = file_format.lower()
-        if file_format not in ["parquet", "csv"]:
-            raise UnsupportedFileFormat(file_format)
         objects_paths = []
         if procs_cpu_bound > 1:
             bounders = _get_bounders(dataframe=dataframe,
diff --git a/awswrangler/spark.py b/awswrangler/spark.py
index bfd819981..33d03b70c 100644
--- a/awswrangler/spark.py
+++ b/awswrangler/spark.py
@@ -6,7 +6,7 @@
 from pyspark.sql.functions import floor, rand
 from pyspark.sql.types import TimestampType
 
-from awswrangler.exceptions import MissingBatchDetected
+from awswrangler.exceptions import MissingBatchDetected, UnsupportedFileFormat
 
 logger = logging.getLogger(__name__)
 
@@ -142,3 +142,70 @@ def write(pandas_dataframe):
         )
         dataframe.unpersist()
         self._session.s3.delete_objects(path=path)
+
+    def create_glue_table(self,
+                          database,
+                          path,
+                          dataframe,
+                          file_format,
+                          compression,
+                          table=None,
+                          serde=None,
+                          sep=",",
+                          partition_by=None,
+                          load_partitions=True,
+                          replace_if_exists=True):
+        """
+        Create a Glue metadata table pointing to a dataset stored on AWS S3.
+
+        :param dataframe: PySpark Dataframe
+        :param file_format: File format (e.g. "parquet", "csv")
+        :param partition_by: Columns used for partitioning
+        :param path: AWS S3 path
+        :param compression: Compression (e.g. gzip, snappy, lzo, etc)
+        :param sep: Separator token for CSV formats (e.g. ",", ";", "|")
+        :param serde: Serializer/Deserializer (e.g. "OpenCSVSerDe", "LazySimpleSerDe")
+        :param database: Glue database name
+        :param table: Glue table name. If not passed, extracted from the path
+        :param load_partitions: Load partitions after the table creation
+        :param replace_if_exists: Drop and recreate the table if it already exists
+        :return: None
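+
+        Example (usage sketch; assumes an existing SparkSession "spark" and
+        DataFrame "dataframe"; the S3 path and database name are placeholders):
+
+            >>> session = awswrangler.Session(spark_session=spark)
+            >>> session.spark.create_glue_table(dataframe=dataframe,
+            ...                                 file_format="parquet",
+            ...                                 partition_by=["year", "month"],
+            ...                                 path="s3://...",
+            ...                                 compression="gzip",
+            ...                                 database="my_database")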
+        """
+        file_format = file_format.lower()
+        if file_format not in ["parquet", "csv"]:
+            raise UnsupportedFileFormat(file_format)
+        table = table if table else self._session.glue.parse_table_name(path)
+        table = table.lower().replace(".", "_")
+        logger.debug(f"table: {table}")
+        full_schema = dataframe.dtypes
+        if partition_by is None:
+            partition_by = []
+        schema = [x for x in full_schema if x[0] not in partition_by]
+        partitions_schema_tmp = {
+            x[0]: x[1]
+            for x in full_schema if x[0] in partition_by
+        }
+        partitions_schema = [(x, partitions_schema_tmp[x])
+                             for x in partition_by]
+        logger.debug(f"schema: {schema}")
+        logger.debug(f"partitions_schema: {partitions_schema}")
+        if replace_if_exists:
+            self._session.glue.delete_table_if_exists(database=database,
+                                                       table=table)
+        extra_args = {}
+        if file_format == "csv":
+            extra_args["sep"] = sep
+            if serde is None:
+                serde = "OpenCSVSerDe"
+            extra_args["serde"] = serde
+        self._session.glue.create_table(
+            database=database,
+            table=table,
+            schema=schema,
+            partition_cols_schema=partitions_schema,
+            path=path,
+            file_format=file_format,
+            compression=compression,
+            extra_args=extra_args)
+        if load_partitions:
+            self._session.athena.repair_table(database=database, table=table)
diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst
index d5ca976e2..53bdacf5d 100644
--- a/docs/source/contributing.rst
+++ b/docs/source/contributing.rst
@@ -24,13 +24,13 @@ Step-by-step
 
 * Fork the AWS Data Wrangler repository and clone that into your development environment
 
-* Go to the project's directory create a Python's virtual environment for the project (**python -m venv venv && source source venv/bin/activate**)
+* Go to the project's directory and create a Python virtual environment for the project (**python -m venv venv && source venv/bin/activate**)
 
 * Run **./install-dev.sh**
 
 * Go to the *testing* directory
 
-* Configure the parameters.json file with your AWS environment infos (Make sure that your Redshift will not be open for the World!)
+* Configure the parameters.json file with your AWS environment information (make sure that your Redshift cluster will not be open to the world! Configure your security group to allow access only from your IP.)
 
 * Deploy the Cloudformation stack **./deploy-cloudformation.sh**
 
diff --git a/docs/source/examples.rst b/docs/source/examples.rst
index e36b366af..bc28d854f 100644
--- a/docs/source/examples.rst
+++ b/docs/source/examples.rst
@@ -139,6 +139,24 @@ Loading Pyspark Dataframe to Redshift
         mode="append",
     )
 
+Register Glue table from Dataframe stored on S3
+```````````````````````````````````````````````
+
+.. code-block:: python
+
+    dataframe.write \
+        .mode("overwrite") \
+        .format("parquet") \
+        .partitionBy(["year", "month"]) \
+        .save(compression="gzip", path="s3://...")
+    session = awswrangler.Session(spark_session=spark)
+    session.spark.create_glue_table(dataframe=dataframe,
+                                    file_format="parquet",
+                                    partition_by=["year", "month"],
+                                    path="s3://...",
+                                    compression="gzip",
+                                    database="my_database")
+
 General
 -------
 
@@ -160,3 +178,11 @@ Get CloudWatch Logs Insights query results
         log_group_names=[LOG_GROUP_NAME],
         query="fields @timestamp, @message | sort @timestamp desc | limit 5",
     )
+
+Load partitions on Athena/Glue table (repair table)
+```````````````````````````````````````````````````
+
+.. code-block:: python
+
+    session = awswrangler.Session()
+    session.athena.repair_table(database="db_name", table="tbl_name")
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 8298b42a6..63556d7b2 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -20,12 +20,13 @@ Pandas
 * Pandas -> Redshift (Parallel)
 * CSV (S3) -> Pandas (One shot or Batching)
 * Athena -> Pandas (One shot or Batching)
-* CloudWatch Logs Insights -> Pandas (NEW)
-* Encrypt Pandas Dataframes on S3 with KMS keys (NEW)
+* CloudWatch Logs Insights -> Pandas
+* Encrypt Pandas Dataframes on S3 with KMS keys
 
 PySpark
 ```````
-* PySpark -> Redshift (Parallel) (NEW)
+* PySpark -> Redshift (Parallel)
+* Register Glue table from Dataframe stored on S3 (NEW)
 
 General
 ```````
@@ -35,7 +36,8 @@ General
 * Delete NOT listed S3 objects (Parallel)
 * Copy listed S3 objects (Parallel)
 * Get the size of S3 objects (Parallel)
-* Get CloudWatch Logs Insights query results (NEW)
+* Get CloudWatch Logs Insights query results
+* Load partitions on Athena/Glue table (repair table) (NEW)
 
 
 Table Of Contents
diff --git a/testing/template.yaml b/testing/template.yaml
index 6bb5d0893..de2707e13 100644
--- a/testing/template.yaml
+++ b/testing/template.yaml
@@ -131,6 +131,7 @@ Resources:
     Properties:
       CatalogId: !Ref AWS::AccountId
       DatabaseInput:
+        Name: awswrangler_test
        Description: AWS Data Wrangler Test Arena - Glue Database
 
   LogGroup:
diff --git a/testing/test_awswrangler/test_spark.py b/testing/test_awswrangler/test_spark.py
index 1c13db86c..6abfb8cf2 100644
--- a/testing/test_awswrangler/test_spark.py
+++ b/testing/test_awswrangler/test_spark.py
@@ -3,6 +3,7 @@
 import pytest
 import boto3
 from pyspark.sql import SparkSession
+from pyspark.sql.functions import lit, array, create_map, struct
 
 from awswrangler import Session
 
@@ -39,6 +40,16 @@ def bucket(session, cloudformation_outputs):
     session.s3.delete_objects(path=f"s3://{bucket}/")
 
 
+@pytest.fixture(scope="module")
+def database(cloudformation_outputs):
+    if "GlueDatabaseName" in cloudformation_outputs:
+        database = cloudformation_outputs["GlueDatabaseName"]
+    else:
+        raise Exception(
+            "You must deploy the test infrastructure using Cloudformation!")
+    yield database
+
+
 @pytest.mark.parametrize(
     "sample_name",
     ["nano", "micro", "small"],
@@ -52,7 +63,7 @@ def test_read_csv(session, bucket, sample_name):
         schema = "id BIGINT, name STRING, date DATE"
         timestamp_format = "dd-MM-yy"
     elif sample_name == "nano":
-        schema = "id INTEGER, name STRING, value DOUBLE, date TIMESTAMP, time TIMESTAMP"
+        schema = "id INTEGER, name STRING, value DOUBLE, date DATE, time TIMESTAMP"
         timestamp_format = "yyyy-MM-dd"
     dataframe = session.spark.read_csv(path=path,
                                        schema=schema,
@@ -69,3 +80,87 @@ def test_read_csv(session, bucket, sample_name):
                                         header=True)
     assert dataframe.count() == dataframe2.count()
     assert len(list(dataframe.columns)) == len(list(dataframe2.columns))
+
+
+@pytest.mark.parametrize(
+    "compression, partition_by",
+    [("snappy", []), ("gzip", ["date", "value"]), ("none", ["time"])],
+)
+def test_create_glue_table_parquet(session, bucket, database, compression,
+                                   partition_by):
+    path = "data_samples/nano.csv"
+    schema = "id INTEGER, name STRING, value DOUBLE, date DATE, time TIMESTAMP"
+    timestamp_format = "yyyy-MM-dd"
+    dataframe = session.spark.read_csv(path=path,
+                                       schema=schema,
+                                       timestampFormat=timestamp_format,
+                                       dateFormat=timestamp_format,
+                                       header=True)
+    dataframe = dataframe \
+        .withColumn("my_array", array(lit(0), lit(1))) \
+        .withColumn("my_struct", struct(lit("text").alias("a"), lit(1).alias("b"))) \
+        .withColumn("my_map", create_map(lit("k0"), lit(1.0), lit("k1"), lit(2.0)))
+    s3_path = f"s3://{bucket}/test"
+    dataframe.write \
+        .mode("overwrite") \
+        .format("parquet") \
+        .partitionBy(partition_by) \
+        .save(compression=compression, path=s3_path)
+    session.spark.create_glue_table(dataframe=dataframe,
+                                    file_format="parquet",
+                                    partition_by=partition_by,
+                                    path=s3_path,
+                                    compression=compression,
+                                    database=database,
+                                    table="test",
+                                    replace_if_exists=True)
+    query = "select count(*) as counter from test"
+    pandas_df = session.pandas.read_sql_athena(sql=query, database=database)
+    assert pandas_df.iloc[0]["counter"] == 5
+    query = "select my_array[1] as foo, my_struct.a as boo, my_map['k0'] as bar from test limit 1"
+    pandas_df = session.pandas.read_sql_athena(sql=query, database=database)
+    assert pandas_df.iloc[0]["foo"] == 0
+    assert pandas_df.iloc[0]["boo"] == "text"
+    assert pandas_df.iloc[0]["bar"] == 1.0
+
+
+@pytest.mark.parametrize(
+    "compression, partition_by, serde",
+    [("gzip", [], None), ("gzip", ["date", "value"], None),
+     ("none", ["time"], "OpenCSVSerDe"), ("gzip", [], "LazySimpleSerDe"),
+     ("gzip", ["date", "value"], "LazySimpleSerDe"),
+     ("none", ["time"], "LazySimpleSerDe")],
+)
+def test_create_glue_table_csv(session, bucket, database, compression,
+                               partition_by, serde):
+    path = "data_samples/nano.csv"
+    schema = "id INTEGER, name STRING, value DOUBLE, date DATE, time TIMESTAMP"
+    timestamp_format = "yyyy-MM-dd"
+    dataframe = session.spark.read_csv(path=path,
+                                       schema=schema,
+                                       timestampFormat=timestamp_format,
+                                       dateFormat=timestamp_format,
+                                       header=True)
+    s3_path = f"s3://{bucket}/test"
+    dataframe.write \
+        .mode("overwrite") \
+        .format("csv") \
+        .partitionBy(partition_by) \
+        .save(compression=compression, path=s3_path)
+    session.spark.create_glue_table(dataframe=dataframe,
+                                    file_format="csv",
+                                    partition_by=partition_by,
+                                    path=s3_path,
+                                    compression=compression,
+                                    database=database,
+                                    table="test",
+                                    serde=serde,
+                                    replace_if_exists=True)
+    query = "select count(*) as counter from test"
+    pandas_df = session.pandas.read_sql_athena(sql=query, database=database)
+    assert pandas_df.iloc[0]["counter"] == 5
+    query = "select id, name, value from test where cast(id as varchar) = '4' limit 1"
+    pandas_df = session.pandas.read_sql_athena(sql=query, database=database)
+    assert int(pandas_df.iloc[0]["id"]) == 4
+    assert pandas_df.iloc[0]["name"] == "four"
+    assert float(pandas_df.iloc[0]["value"]) == 4.0