diff --git a/README.md b/README.md index 4ccc13e33..bf1c66381 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ > Utility belt to handle data on AWS. -[![Release](https://img.shields.io/badge/release-0.0.23-brightgreen.svg)](https://pypi.org/project/awswrangler/) +[![Release](https://img.shields.io/badge/release-0.0.24-brightgreen.svg)](https://pypi.org/project/awswrangler/) [![Downloads](https://img.shields.io/pypi/dm/awswrangler.svg)](https://pypi.org/project/awswrangler/) [![Python Version](https://img.shields.io/badge/python-3.6%20%7C%203.7-brightgreen.svg)](https://pypi.org/project/awswrangler/) [![Documentation Status](https://readthedocs.org/projects/aws-data-wrangler/badge/?version=latest)](https://aws-data-wrangler.readthedocs.io/en/latest/?badge=latest) diff --git a/awswrangler/__version__.py b/awswrangler/__version__.py index 09ab247c4..2dd7e260d 100644 --- a/awswrangler/__version__.py +++ b/awswrangler/__version__.py @@ -1,4 +1,4 @@ __title__ = "awswrangler" __description__ = "Utility belt to handle data on AWS." -__version__ = "0.0.23" +__version__ = "0.0.24" __license__ = "Apache License 2.0" diff --git a/awswrangler/athena.py b/awswrangler/athena.py index b55e59c3f..c9c9686a3 100644 --- a/awswrangler/athena.py +++ b/awswrangler/athena.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple, Optional, Any, Iterator, Union +from typing import Dict, List, Tuple, Optional, Any, Iterator from time import sleep import logging import re @@ -25,7 +25,6 @@ def get_query_columns_metadata(self, query_execution_id: str) -> Dict[str, str]: """ response: Dict = self._client_athena.get_query_results(QueryExecutionId=query_execution_id, MaxResults=1) col_info: List[Dict[str, str]] = response["ResultSet"]["ResultSetMetadata"]["ColumnInfo"] - logger.debug(f"col_info: {col_info}") return {x["Name"]: x["Type"] for x in col_info} def create_athena_bucket(self): @@ -42,7 +41,13 @@ def create_athena_bucket(self): s3_resource.Bucket(s3_output) return s3_output - def run_query(self, query: str, database: Optional[str] = None, s3_output: Optional[str] = None, workgroup: Optional[str] = None, encryption: Optional[str] = None, kms_key: Optional[str] = None) -> str: + def run_query(self, + query: str, + database: Optional[str] = None, + s3_output: Optional[str] = None, + workgroup: Optional[str] = None, + encryption: Optional[str] = None, + kms_key: Optional[str] = None) -> str: """ Run a SQL Query against AWS Athena P.S All default values will be inherited from the Session() @@ -55,7 +60,7 @@ def run_query(self, query: str, database: Optional[str] = None, s3_output: Optio :param kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID. :return: Query execution ID """ - args: Dict[str, Union[str, Dict[str, Union[str, Dict[str, str]]]]] = {"QueryString": query} + args: Dict[str, Any] = {"QueryString": query} # s3_output if s3_output is None: @@ -71,7 +76,9 @@ def run_query(self, query: str, database: Optional[str] = None, s3_output: Optio if kms_key is not None: args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = kms_key elif self._session.athena_encryption is not None: - args["ResultConfiguration"]["EncryptionConfiguration"] = {"EncryptionOption": self._session.athena_encryption} + args["ResultConfiguration"]["EncryptionConfiguration"] = { + "EncryptionOption": self._session.athena_encryption + } if self._session.athena_kms_key is not None: args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = self._session.athena_kms_key @@ -113,7 +120,13 @@ def wait_query(self, query_execution_id): raise QueryCancelled(response["QueryExecution"]["Status"].get("StateChangeReason")) return response - def repair_table(self, table: str, database: Optional[str] = None, s3_output: Optional[str] = None, workgroup: Optional[str] = None, encryption: Optional[str] = None, kms_key: Optional[str] = None): + def repair_table(self, + table: str, + database: Optional[str] = None, + s3_output: Optional[str] = None, + workgroup: Optional[str] = None, + encryption: Optional[str] = None, + kms_key: Optional[str] = None): """ Hive's metastore consistency check "MSCK REPAIR TABLE table;" @@ -133,7 +146,12 @@ def repair_table(self, table: str, database: Optional[str] = None, s3_output: Op :return: Query execution ID """ query = f"MSCK REPAIR TABLE {table};" - query_id = self.run_query(query=query, database=database, s3_output=s3_output, workgroup=workgroup, encryption=encryption, kms_key=kms_key) + query_id = self.run_query(query=query, + database=database, + s3_output=s3_output, + workgroup=workgroup, + encryption=encryption, + kms_key=kms_key) self.wait_query(query_execution_id=query_id) return query_id @@ -174,7 +192,13 @@ def get_results(self, query_execution_id: str) -> Iterator[Dict[str, Any]]: yield row next_token = res.get("NextToken") - def query(self, query: str, database: Optional[str] = None, s3_output: Optional[str] = None, workgroup: Optional[str] = None, encryption: Optional[str] = None, kms_key: Optional[str] = None) -> Iterator[Dict[str, Any]]: + def query(self, + query: str, + database: Optional[str] = None, + s3_output: Optional[str] = None, + workgroup: Optional[str] = None, + encryption: Optional[str] = None, + kms_key: Optional[str] = None) -> Iterator[Dict[str, Any]]: """ Run a SQL Query against AWS Athena and return the result as a Iterator of lists P.S All default values will be inherited from the Session() @@ -187,7 +211,12 @@ def query(self, query: str, database: Optional[str] = None, s3_output: Optional[ :param kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID. :return: Query execution ID """ - query_id: str = self.run_query(query=query, database=database, s3_output=s3_output, workgroup=workgroup, encryption=encryption, kms_key=kms_key) + query_id: str = self.run_query(query=query, + database=database, + s3_output=s3_output, + workgroup=workgroup, + encryption=encryption, + kms_key=kms_key) self.wait_query(query_execution_id=query_id) return self.get_results(query_execution_id=query_id) diff --git a/awswrangler/data_types.py b/awswrangler/data_types.py index 24945b053..bb6851b23 100644 --- a/awswrangler/data_types.py +++ b/awswrangler/data_types.py @@ -304,7 +304,8 @@ def convert_schema(func: Callable, schema: List[Tuple[str, str]]) -> Dict[str, s return {name: func(dtype) for name, dtype in schema} -def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame, preserve_index: bool, +def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame, + preserve_index: bool, indexes_position: str = "right") -> List[Tuple[str, str]]: """ Extract the related Pyarrow schema from any Pandas DataFrame diff --git a/awswrangler/emr.py b/awswrangler/emr.py index 589991a92..d6cb516a7 100644 --- a/awswrangler/emr.py +++ b/awswrangler/emr.py @@ -480,7 +480,10 @@ def submit_step(self, logger.info(f"response: \n{json.dumps(response, default=str, indent=4)}") return response["StepIds"][0] - def build_step(self, name: str, command: str, action_on_failure: str = "CONTINUE", + def build_step(self, + name: str, + command: str, + action_on_failure: str = "CONTINUE", script: bool = False) -> Dict[str, Collection[str]]: """ Build the Step dictionary diff --git a/awswrangler/pandas.py b/awswrangler/pandas.py index 17987fa93..6320354d3 100644 --- a/awswrangler/pandas.py +++ b/awswrangler/pandas.py @@ -7,6 +7,7 @@ import csv from datetime import datetime from decimal import Decimal +from ast import literal_eval from botocore.exceptions import ClientError, HTTPClientError # type: ignore import pandas as pd # type: ignore @@ -46,24 +47,24 @@ def _parse_path(path): return parts[0], parts[2] def read_csv( - self, - path, - max_result_size=None, - header="infer", - names=None, - usecols=None, - dtype=None, - sep=",", - thousands=None, - decimal=".", - lineterminator="\n", - quotechar='"', - quoting=csv.QUOTE_MINIMAL, - escapechar=None, - parse_dates: Union[bool, Dict, List] = False, - infer_datetime_format=False, - encoding="utf-8", - converters=None, + self, + path, + max_result_size=None, + header="infer", + names=None, + usecols=None, + dtype=None, + sep=",", + thousands=None, + decimal=".", + lineterminator="\n", + quotechar='"', + quoting=csv.QUOTE_MINIMAL, + escapechar=None, + parse_dates: Union[bool, Dict, List] = False, + infer_datetime_format=False, + encoding="utf-8", + converters=None, ): """ Read CSV file from AWS S3 using optimized strategies. @@ -137,25 +138,25 @@ def read_csv( @staticmethod def _read_csv_iterator( - client_s3, - bucket_name, - key_path, - max_result_size=200_000_000, # 200 MB - header="infer", - names=None, - usecols=None, - dtype=None, - sep=",", - thousands=None, - decimal=".", - lineterminator="\n", - quotechar='"', - quoting=csv.QUOTE_MINIMAL, - escapechar=None, - parse_dates: Union[bool, Dict, List] = False, - infer_datetime_format=False, - encoding="utf-8", - converters=None, + client_s3, + bucket_name, + key_path, + max_result_size=200_000_000, # 200 MB + header="infer", + names=None, + usecols=None, + dtype=None, + sep=",", + thousands=None, + decimal=".", + lineterminator="\n", + quotechar='"', + quoting=csv.QUOTE_MINIMAL, + escapechar=None, + parse_dates: Union[bool, Dict, List] = False, + infer_datetime_format=False, + encoding="utf-8", + converters=None, ): """ Read CSV file from AWS S3 using optimized strategies. @@ -350,24 +351,24 @@ def _find_terminator(body, sep, quoting, quotechar, lineterminator): @staticmethod def _read_csv_once( - client_s3, - bucket_name, - key_path, - header="infer", - names=None, - usecols=None, - dtype=None, - sep=",", - thousands=None, - decimal=".", - lineterminator="\n", - quotechar='"', - quoting=0, - escapechar=None, - parse_dates: Union[bool, Dict, List] = False, - infer_datetime_format=False, - encoding=None, - converters=None, + client_s3, + bucket_name, + key_path, + header="infer", + names=None, + usecols=None, + dtype=None, + sep=",", + thousands=None, + decimal=".", + lineterminator="\n", + quotechar='"', + quoting=0, + escapechar=None, + parse_dates: Union[bool, Dict, List] = False, + infer_datetime_format=False, + encoding=None, + converters=None, ): """ Read CSV file from AWS S3 using optimized strategies. @@ -420,9 +421,17 @@ def _read_csv_once( @staticmethod def _list_parser(value: str) -> List[Union[int, float, str, None]]: + # try resolve with a simple literal_eval + try: + return literal_eval(value) + except ValueError: + pass # keep trying + + # sanity check if len(value) <= 1: return [] - items: List[None, str] = [None if x == "null" else x for x in value[1:-1].split(", ")] + + items: List[Union[None, str]] = [None if x == "null" else x for x in value[1:-1].split(", ")] array_type: Optional[type] = None # check if all values are integers @@ -481,8 +490,14 @@ def _get_query_dtype(self, query_execution_id: str) -> Tuple[Dict[str, str], Lis logger.debug(f"converters: {converters}") return dtype, parse_timestamps, parse_dates, converters - def read_sql_athena(self, sql, database=None, s3_output=None, max_result_size=None, workgroup=None, - encryption=None, kms_key=None): + def read_sql_athena(self, + sql, + database=None, + s3_output=None, + max_result_size=None, + workgroup=None, + encryption=None, + kms_key=None): """ Executes any SQL query on AWS Athena and return a Dataframe of the result. P.S. If max_result_size is passed, then a iterator of Dataframes is returned. @@ -499,7 +514,12 @@ def read_sql_athena(self, sql, database=None, s3_output=None, max_result_size=No """ if not s3_output: s3_output = self._session.athena.create_athena_bucket() - query_execution_id = self._session.athena.run_query(query=sql, database=database, s3_output=s3_output, workgroup=workgroup, encryption=encryption, kms_key=kms_key) + query_execution_id = self._session.athena.run_query(query=sql, + database=database, + s3_output=s3_output, + workgroup=workgroup, + encryption=encryption, + kms_key=kms_key) query_response = self._session.athena.wait_query(query_execution_id=query_execution_id) if query_response["QueryExecution"]["Status"]["State"] in ["FAILED", "CANCELLED"]: reason = query_response["QueryExecution"]["Status"]["StateChangeReason"] @@ -532,19 +552,19 @@ def _apply_dates_to_generator(generator, parse_dates): yield df def to_csv( - self, - dataframe, - path, - sep=",", - serde="OpenCSVSerDe", - database: Optional[str] = None, - table=None, - partition_cols=None, - preserve_index=True, - mode="append", - procs_cpu_bound=None, - procs_io_bound=None, - inplace=True, + self, + dataframe, + path, + sep=",", + serde="OpenCSVSerDe", + database: Optional[str] = None, + table=None, + partition_cols=None, + preserve_index=True, + mode="append", + procs_cpu_bound=None, + procs_io_bound=None, + inplace=True, ): """ Write a Pandas Dataframe as CSV files on S3 @@ -806,7 +826,7 @@ def _data_to_s3_dataset_writer(dataframe, for keys, subgroup in dataframe.groupby(partition_cols): subgroup = subgroup.drop(partition_cols, axis="columns") if not isinstance(keys, tuple): - keys = (keys,) + keys = (keys, ) subdir = "/".join([f"{name}={val}" for name, val in zip(partition_cols, keys)]) prefix = "/".join([path, subdir]) object_path = Pandas._data_to_s3_object_writer(dataframe=subgroup, diff --git a/awswrangler/redshift.py b/awswrangler/redshift.py index c9de16321..319af0c51 100644 --- a/awswrangler/redshift.py +++ b/awswrangler/redshift.py @@ -117,8 +117,11 @@ def get_connection(self, glue_connection): conn = self.generate_connection(database=database, host=host, port=int(port), user=user, password=password) return conn - def write_load_manifest(self, manifest_path: str, objects_paths: List[str], procs_io_bound: Optional[int] = None - ) -> Dict[str, List[Dict[str, Union[str, bool, Dict[str, int]]]]]: + def write_load_manifest( + self, + manifest_path: str, + objects_paths: List[str], + procs_io_bound: Optional[int] = None) -> Dict[str, List[Dict[str, Union[str, bool, Dict[str, int]]]]]: objects_sizes: Dict[str, int] = self._session.s3.get_objects_sizes(objects_paths=objects_paths, procs_io_bound=procs_io_bound) manifest: Dict[str, List[Dict[str, Union[str, bool, Dict[str, int]]]]] = {"entries": []} diff --git a/awswrangler/session.py b/awswrangler/session.py index da136d3a3..f8051efe4 100644 --- a/awswrangler/session.py +++ b/awswrangler/session.py @@ -31,26 +31,24 @@ class Session: PROCS_IO_BOUND_FACTOR = 2 - def __init__( - self, - boto3_session=None, - profile_name: Optional[str] = None, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_session_token: Optional[str] = None, - region_name: Optional[str] = None, - botocore_max_retries: int = 40, - s3_additional_kwargs=None, - spark_context=None, - spark_session=None, - procs_cpu_bound: int = os.cpu_count(), - procs_io_bound: int = os.cpu_count() * PROCS_IO_BOUND_FACTOR, - athena_workgroup: str = "primary", - athena_s3_output: Optional[str] = None, - athena_encryption: Optional[str] = "SSE_S3", - athena_kms_key: Optional[str] = None, - athena_database: str = "default" - ): + def __init__(self, + boto3_session=None, + profile_name: Optional[str] = None, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + region_name: Optional[str] = None, + botocore_max_retries: int = 40, + s3_additional_kwargs=None, + spark_context=None, + spark_session=None, + procs_cpu_bound: Optional[int] = None, + procs_io_bound: Optional[int] = None, + athena_workgroup: str = "primary", + athena_s3_output: Optional[str] = None, + athena_encryption: Optional[str] = "SSE_S3", + athena_kms_key: Optional[str] = None, + athena_database: str = "default"): """ Most parameters inherit from Boto3 or Pyspark. https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html @@ -75,8 +73,10 @@ def __init__( :param athena_kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID. """ self._profile_name: Optional[str] = (boto3_session.profile_name if boto3_session else profile_name) - self._aws_access_key_id: Optional[str] = (boto3_session.get_credentials().access_key if boto3_session else aws_access_key_id) - self._aws_secret_access_key: Optional[str] = (boto3_session.get_credentials().secret_key if boto3_session else aws_secret_access_key) + self._aws_access_key_id: Optional[str] = (boto3_session.get_credentials().access_key + if boto3_session else aws_access_key_id) + self._aws_secret_access_key: Optional[str] = (boto3_session.get_credentials().secret_key + if boto3_session else aws_secret_access_key) self._botocore_max_retries: int = botocore_max_retries self._botocore_config = Config(retries={"max_attempts": self._botocore_max_retries}) self._aws_session_token: Optional[str] = aws_session_token @@ -84,8 +84,9 @@ def __init__( self._s3_additional_kwargs = s3_additional_kwargs self._spark_context = spark_context self._spark_session = spark_session - self._procs_cpu_bound: int = procs_cpu_bound - self._procs_io_bound: int = procs_io_bound + cpus = os.cpu_count() + self._procs_cpu_bound: int = 1 if cpus is None else cpus if procs_cpu_bound is None else procs_cpu_bound + self._procs_io_bound: int = 1 if cpus is None else cpus * Session.PROCS_IO_BOUND_FACTOR if procs_io_bound is None else procs_io_bound self._athena_workgroup: str = athena_workgroup self._athena_s3_output: Optional[str] = athena_s3_output self._athena_encryption: Optional[str] = athena_encryption @@ -281,24 +282,22 @@ class SessionPrimitives: It is required to "share" the session attributes to other processes. That must be "pickable"! """ - def __init__( - self, - profile_name=None, - aws_access_key_id=None, - aws_secret_access_key=None, - aws_session_token=None, - region_name=None, - botocore_max_retries=None, - s3_additional_kwargs=None, - botocore_config=None, - procs_cpu_bound=None, - procs_io_bound=None, - athena_workgroup: Optional[str] = None, - athena_s3_output: Optional[str] = None, - athena_encryption: Optional[str] = None, - athena_kms_key: Optional[str] = None, - athena_database: Optional[str] = None - ): + def __init__(self, + profile_name=None, + aws_access_key_id=None, + aws_secret_access_key=None, + aws_session_token=None, + region_name=None, + botocore_max_retries=None, + s3_additional_kwargs=None, + botocore_config=None, + procs_cpu_bound=None, + procs_io_bound=None, + athena_workgroup: Optional[str] = None, + athena_s3_output: Optional[str] = None, + athena_encryption: Optional[str] = None, + athena_kms_key: Optional[str] = None, + athena_database: Optional[str] = None): """ Most parameters inherit from Boto3. https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html @@ -376,7 +375,7 @@ def procs_io_bound(self): return self._procs_io_bound @property - def athena_workgroup(self) -> str: + def athena_workgroup(self) -> Optional[str]: return self._athena_workgroup @property @@ -392,7 +391,7 @@ def athena_kms_key(self) -> Optional[str]: return self._athena_kms_key @property - def athena_database(self) -> str: + def athena_database(self) -> Optional[str]: return self._athena_database @property diff --git a/awswrangler/spark.py b/awswrangler/spark.py index 58be7768f..4636634da 100644 --- a/awswrangler/spark.py +++ b/awswrangler/spark.py @@ -272,7 +272,8 @@ def _flatten_struct_column(path: str, dtype: str) -> List[Tuple[str, str]]: return cols @staticmethod - def _flatten_struct_dataframe(df: DataFrame, explode_outer: bool = True, + def _flatten_struct_dataframe(df: DataFrame, + explode_outer: bool = True, explode_pos: bool = True) -> List[Tuple[str, str, str]]: explode: str = "EXPLODE_OUTER" if explode_outer is True else "EXPLODE" explode = f"POS{explode}" if explode_pos is True else explode @@ -311,7 +312,9 @@ def _build_name(name: str, expr: str) -> str: return f"{name}_{suffix}".replace(".", "_") @staticmethod - def flatten(dataframe: DataFrame, explode_outer: bool = True, explode_pos: bool = True, + def flatten(dataframe: DataFrame, + explode_outer: bool = True, + explode_pos: bool = True, name: str = "root") -> Dict[str, DataFrame]: """ Convert a complex nested DataFrame in one (or many) flat DataFrames diff --git a/pytest.ini b/pytest.ini index d233cbf74..8e7a47ef1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,7 +1,7 @@ [pytest] addopts = --verbose - --capture=no + --capture=fd filterwarnings = ignore::DeprecationWarning ignore::UserWarning \ No newline at end of file diff --git a/testing/test_awswrangler/test_athena.py b/testing/test_awswrangler/test_athena.py index 31147cc5d..b9ea9c74f 100644 --- a/testing/test_awswrangler/test_athena.py +++ b/testing/test_awswrangler/test_athena.py @@ -1,5 +1,4 @@ import logging -from pprint import pprint import pytest import boto3 @@ -52,21 +51,20 @@ def workgroup_secondary(bucket): wkgs = client.list_work_groups() wkgs = [x["Name"] for x in wkgs["WorkGroups"]] if wkg_name not in wkgs: - response = client.create_work_group(Name=wkg_name, - Configuration={ - "ResultConfiguration": { - "OutputLocation": f"s3://{bucket}/athena_workgroup_secondary/", - "EncryptionConfiguration": { - "EncryptionOption": "SSE_S3", - } - }, - "EnforceWorkGroupConfiguration": True, - "PublishCloudWatchMetricsEnabled": True, - "BytesScannedCutoffPerQuery": 100_000_000, - "RequesterPaysEnabled": False - }, - Description="AWS Data Wrangler Test WorkGroup") - pprint(response) + client.create_work_group(Name=wkg_name, + Configuration={ + "ResultConfiguration": { + "OutputLocation": f"s3://{bucket}/athena_workgroup_secondary/", + "EncryptionConfiguration": { + "EncryptionOption": "SSE_S3", + } + }, + "EnforceWorkGroupConfiguration": True, + "PublishCloudWatchMetricsEnabled": True, + "BytesScannedCutoffPerQuery": 100_000_000, + "RequesterPaysEnabled": False + }, + Description="AWS Data Wrangler Test WorkGroup") yield wkg_name diff --git a/testing/test_awswrangler/test_emr.py b/testing/test_awswrangler/test_emr.py index 0e7f8078e..8a1d1e93a 100644 --- a/testing/test_awswrangler/test_emr.py +++ b/testing/test_awswrangler/test_emr.py @@ -79,7 +79,6 @@ def test_cluster(session, bucket, cloudformation_outputs): steps=steps) sleep(10) cluster_state = session.emr.get_cluster_state(cluster_id=cluster_id) - print(f"cluster_state: {cluster_state}") assert cluster_state == "STARTING" step_id = session.emr.submit_step(cluster_id=cluster_id, name="step_test", @@ -87,7 +86,6 @@ def test_cluster(session, bucket, cloudformation_outputs): script=True) sleep(10) step_state = session.emr.get_step_state(cluster_id=cluster_id, step_id=step_id) - print(f"step_state: {step_state}") assert step_state == "PENDING" session.emr.terminate_cluster(cluster_id=cluster_id) @@ -141,7 +139,6 @@ def test_cluster_single_node(session, bucket, cloudformation_outputs): }) sleep(10) cluster_state = session.emr.get_cluster_state(cluster_id=cluster_id) - print(f"cluster_state: {cluster_state}") assert cluster_state == "STARTING" steps = [] for cmd in ['echo "Hello"', "ls -la"]: diff --git a/testing/test_awswrangler/test_glue.py b/testing/test_awswrangler/test_glue.py index 17b0c7627..58c52e68a 100644 --- a/testing/test_awswrangler/test_glue.py +++ b/testing/test_awswrangler/test_glue.py @@ -46,9 +46,9 @@ def database(cloudformation_outputs): @pytest.fixture(scope="module") def table( - session, - bucket, - database, + session, + bucket, + database, ): dataframe = pd.read_csv("data_samples/micro.csv") path = f"s3://{bucket}/test/" diff --git a/testing/test_awswrangler/test_pandas.py b/testing/test_awswrangler/test_pandas.py index bb93f9ae8..cb3a21e86 100644 --- a/testing/test_awswrangler/test_pandas.py +++ b/testing/test_awswrangler/test_pandas.py @@ -223,15 +223,15 @@ def test_read_csv_thousands_and_decimal(session, bucket): ], ) def test_to_s3( - session, - bucket, - database, - mode, - file_format, - preserve_index, - partition_cols, - procs_cpu_bound, - factor, + session, + bucket, + database, + mode, + file_format, + preserve_index, + partition_cols, + procs_cpu_bound, + factor, ): dataframe = pd.read_csv("data_samples/micro.csv") func = session.pandas.to_csv if file_format == "csv" else session.pandas.to_parquet @@ -261,9 +261,9 @@ def test_to_s3( def test_to_parquet_with_cast_int( - session, - bucket, - database, + session, + bucket, + database, ): dataframe = pd.read_csv("data_samples/nano.csv", dtype={"id": "Int64"}, parse_dates=["date", "time"]) path = f"s3://{bucket}/test/" @@ -318,7 +318,6 @@ def test_read_sql_athena_iterator(session, bucket, database, sample, row_num, ma for dataframe in dataframe_iter: total_count += len(dataframe.index) assert len(list(dataframe.columns)) == len(list(dataframe_sample.columns)) - print(dataframe) if total_count == row_num: break session.s3.delete_objects(path=path) @@ -422,9 +421,9 @@ def test_etl_complex(session, bucket, database, max_result_size): def test_to_parquet_with_kms( - bucket, - database, - kms_key, + bucket, + database, + kms_key, ): extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key} session_inner = Session(s3_additional_kwargs=extra_args) @@ -562,9 +561,9 @@ def test_to_s3_types(session, bucket, database, file_format, serde, index, parti def test_to_csv_with_sep( - session, - bucket, - database, + session, + bucket, + database, ): dataframe = pd.read_csv("data_samples/nano.csv") session.pandas.to_csv(dataframe=dataframe, @@ -584,9 +583,9 @@ def test_to_csv_with_sep( def test_to_csv_serde_exception( - session, - bucket, - database, + session, + bucket, + database, ): dataframe = pd.read_csv("data_samples/nano.csv") with pytest.raises(InvalidSerDe): @@ -651,42 +650,10 @@ def test_to_parquet_lists(session, bucket, database): assert val == val2 -def test_to_parquet_cast(session, bucket, database): - dataframe = pd.DataFrame({ - "id": [0, 1], - "col_int": [[1, 2], [3, 4, 5]], - "col_float": [[1.0, 2.0, 3.0], [4.0, 5.0]], - "col_string": [["foo"], ["boo", "bar"]], - "col_timestamp": [[datetime(2019, 1, 1), datetime(2019, 1, 2)], [datetime(2019, 1, 3)]], - "col_date": [[date(2019, 1, 1), date(2019, 1, 2)], [date(2019, 1, 3)]], - "col_list_int": [[[1]], [[2, 3], [4, 5, 6]]], - "col_list_list_string": [[[["foo"]]], [[["boo", "bar"]]]], - }) - paths = session.pandas.to_parquet(dataframe=dataframe, - database=database, - path=f"s3://{bucket}/test/", - preserve_index=False, - mode="overwrite", - procs_cpu_bound=1) - assert len(paths) == 1 - dataframe2 = None - for counter in range(10): - sleep(1) - dataframe2 = session.pandas.read_sql_athena(sql="select id, col_int, col_float, col_list_int from test", - database=database) - if len(dataframe.index) == len(dataframe2.index): - break - assert len(dataframe.index) == len(dataframe2.index) - assert 4 == len(list(dataframe2.columns)) - val = dataframe[dataframe["id"] == 0].iloc[0]["col_list_int"] - val2 = dataframe2[dataframe2["id"] == 0].iloc[0]["col_list_int"] - assert val == val2 - - def test_to_parquet_with_cast_null( - session, - bucket, - database, + session, + bucket, + database, ): dataframe = pd.DataFrame({ "id": [0, 1], @@ -757,9 +724,9 @@ def test_normalize_columns_names_athena(): def test_to_parquet_with_normalize( - session, - bucket, - database, + session, + bucket, + database, ): dataframe = pd.DataFrame({ "CamelCase": [1, 2, 3], @@ -792,9 +759,9 @@ def test_to_parquet_with_normalize( def test_to_parquet_with_normalize_and_cast( - session, - bucket, - database, + session, + bucket, + database, ): dataframe = pd.DataFrame({ "CamelCase": [1, 2, 3], @@ -847,9 +814,9 @@ def test_drop_duplicated_columns(): def test_to_parquet_duplicated_columns( - session, - bucket, - database, + session, + bucket, + database, ): dataframe = pd.DataFrame({ "a": [1, 2, 3], @@ -871,9 +838,9 @@ def test_to_parquet_duplicated_columns( def test_to_parquet_with_pyarrow_null_type( - session, - bucket, - database, + session, + bucket, + database, ): dataframe = pd.DataFrame({ "a": [1, 2, 3], @@ -889,9 +856,9 @@ def test_to_parquet_with_pyarrow_null_type( def test_to_parquet_casting_to_string( - session, - bucket, - database, + session, + bucket, + database, ): dataframe = pd.DataFrame({ "a": [1, 2, 3], @@ -911,13 +878,12 @@ def test_to_parquet_casting_to_string( break assert len(dataframe.index) == len(dataframe2.index) assert (len(list(dataframe.columns)) + 1) == len(list(dataframe2.columns)) - print(dataframe2) def test_to_parquet_casting_with_null_object( - session, - bucket, - database, + session, + bucket, + database, ): dataframe = pd.DataFrame({ "a": [1, 2, 3], @@ -951,8 +917,6 @@ def test_read_sql_athena_with_nulls(session, bucket, database): if len(df.index) == len(df2.index): break assert len(df.index) == len(df2.index) - print(df2) - print(df2.dtypes) assert df2.dtypes[0] == "Int64" assert df2.dtypes[1] == "bool" assert df2.dtypes[2] == "bool" @@ -967,8 +931,6 @@ def test_partition_date(session, bucket, database): }) df["datecol"] = pd.to_datetime(df.datecol).dt.date df["partcol"] = pd.to_datetime(df.partcol).dt.date - print(df) - print(df.dtypes) path = f"s3://{bucket}/test/" session.pandas.to_parquet(dataframe=df, database=database, @@ -984,8 +946,6 @@ def test_partition_date(session, bucket, database): if len(df.index) == len(df2.index): break assert len(df.index) == len(df2.index) - print(df2) - print(df2.dtypes) assert df2.dtypes[0] == "object" assert df2.dtypes[1] == "object" assert df2.dtypes[2] == "object" @@ -998,8 +958,6 @@ def test_partition_cast_date(session, bucket, database): "datecol": ["2019-11-09", "2019-11-08"], "partcol": ["2019-11-09", "2019-11-08"] }) - print(df) - print(df.dtypes) path = f"s3://{bucket}/test/" schema = { "col1": "string", @@ -1021,8 +979,6 @@ def test_partition_cast_date(session, bucket, database): if len(df.index) == len(df2.index): break assert len(df.index) == len(df2.index) - print(df2) - print(df2.dtypes) assert df2.dtypes[0] == "object" assert df2.dtypes[1] == "object" assert df2.dtypes[2] == "object" @@ -1035,8 +991,6 @@ def test_partition_cast_timestamp(session, bucket, database): "datecol": ["2019-11-09", "2019-11-08"], "partcol": ["2019-11-09", "2019-11-08"] }) - print(df) - print(df.dtypes) path = f"s3://{bucket}/test/" schema = { "col1": "string", @@ -1058,8 +1012,6 @@ def test_partition_cast_timestamp(session, bucket, database): if len(df.index) == len(df2.index): break assert len(df.index) == len(df2.index) - print(df2) - print(df2.dtypes) assert str(df2.dtypes[0]) == "object" assert str(df2.dtypes[1]).startswith("datetime64") assert str(df2.dtypes[2]).startswith("datetime64") @@ -1074,8 +1026,6 @@ def test_partition_cast(session, bucket, database): "col_double": ["1.0", "1.1"], "col_bool": ["True", "False"], }) - print(df) - print(df.dtypes) path = f"s3://{bucket}/test/" schema = { "col1": "string", @@ -1099,8 +1049,6 @@ def test_partition_cast(session, bucket, database): if len(df.index) == len(df2.index): break assert len(df.index) == len(df2.index) - print(df2) - print(df2.dtypes) assert df2.dtypes[0] == "object" assert str(df2.dtypes[1]).startswith("datetime") assert str(df2.dtypes[2]).startswith("float") @@ -1151,7 +1099,6 @@ def test_partition_single_row(session, bucket, database, procs): assert len(list(df.columns)) == len(list(df2.columns)) if len(df.index) == len(df2.index): break - print(df2.dtypes) assert len(df.index) == len(df2.index) assert df2.dtypes[0] == "Int64" assert df2.dtypes[1] == "object" @@ -1163,7 +1110,6 @@ def test_partition_single_row(session, bucket, database, procs): def test_nan_cast(session, bucket, database, partition_cols): dtypes = {"col1": "object", "col2": "object", "col3": "object", "col4": "object", "pt": "object"} df = pd.read_csv("data_samples/nan.csv", dtype=dtypes) - print(df) schema = { "col1": "string", "col2": "string", @@ -1185,7 +1131,6 @@ def test_nan_cast(session, bucket, database, partition_cols): assert len(list(df.columns)) == len(list(df2.columns)) - 1 if len(df.index) == len(df2.index): break - print(df2.dtypes) assert len(df.index) == len(df2.index) assert df2.dtypes[0] == "object" assert df2.dtypes[1] == "object" @@ -1283,7 +1228,7 @@ def test_to_parquet_date_null_at_first(session, bucket, database): assert df[df.col1 == "val0"].iloc[0].datecol == df2[df2.col1 == "val0"].iloc[0].datecol is None -def test_to_parquet_array(session, bucket, database): +def test_to_parquet_lists2(session, bucket, database): df = pd.DataFrame({ "A": [1, 2, 3], "B": [[], [4.0, None, 6.0], []], @@ -1318,10 +1263,9 @@ def test_to_parquet_decimal(session, bucket, database): df = pd.DataFrame({ "id": [1, 2, 3], "decimal_2": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))], - "decimal_5": [Decimal((0, (1, 9, 9, 9, 9, 9), -5)), None, Decimal((0, (1, 9, 0, 0, 0, 0), -5))], + "decimal_5": [Decimal((0, (1, 9, 9, 9, 9, 9), -5)), None, + Decimal((0, (1, 9, 0, 0, 0, 0), -5))], }) - print(df) - print(df.dtypes) path = f"s3://{bucket}/test/" session.pandas.to_parquet(dataframe=df, database=database, diff --git a/testing/test_awswrangler/test_redshift.py b/testing/test_awswrangler/test_redshift.py index 0a62f56d2..2bc772567 100644 --- a/testing/test_awswrangler/test_redshift.py +++ b/testing/test_awswrangler/test_redshift.py @@ -76,6 +76,7 @@ def redshift_parameters(cloudformation_outputs): ) def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, mode, factor, diststyle, distkey, sortstyle, sortkey): + if sample_name == "micro": dates = ["date"] if sample_name == "small": @@ -146,7 +147,6 @@ def test_to_redshift_pandas_cast(session, bucket, redshift_parameters): rows = cursor.fetchall() cursor.close() con.close() - print(rows) assert len(df.index) == len(rows) assert len(list(df.columns)) == len(list(rows[0])) @@ -283,8 +283,6 @@ def test_to_redshift_spark_big(session, bucket, redshift_parameters): def test_to_redshift_spark_bool(session, bucket, redshift_parameters): dataframe = session.spark_session.createDataFrame(pd.DataFrame({"A": [1, 2, 3], "B": [True, False, True]})) - print(dataframe) - print(dataframe.dtypes) con = Redshift.generate_connection( database="test", host=redshift_parameters.get("RedshiftAddress"), @@ -428,7 +426,8 @@ def test_to_redshift_pandas_decimal(session, bucket, redshift_parameters): df = pd.DataFrame({ "id": [1, 2, 3], "decimal_2": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))], - "decimal_5": [Decimal((0, (1, 9, 9, 9, 9, 9), -5)), None, Decimal((0, (1, 9, 0, 0, 0, 0), -5))], + "decimal_5": [Decimal((0, (1, 9, 9, 9, 9, 9), -5)), None, + Decimal((0, (1, 9, 0, 0, 0, 0), -5))], }) con = Redshift.generate_connection( database="test", @@ -455,7 +454,6 @@ def test_to_redshift_pandas_decimal(session, bucket, redshift_parameters): con.close() assert len(df.index) == len(rows) assert len(list(df.columns)) == len(list(rows[0])) - print(rows) for row in rows: if row[0] == 1: assert row[1] == Decimal((0, (1, 9, 9), -2)) @@ -472,8 +470,10 @@ def test_to_redshift_spark_decimal(session, bucket, redshift_parameters): df = session.spark_session.createDataFrame(pd.DataFrame({ "id": [1, 2, 3], "decimal_2": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))], - "decimal_5": [Decimal((0, (1, 9, 9, 9, 9, 9), -5)), None, Decimal((0, (1, 9, 0, 0, 0, 0), -5))]}), - schema="id INTEGER, decimal_2 DECIMAL(3,2), decimal_5 DECIMAL(6,5)") + "decimal_5": [Decimal((0, (1, 9, 9, 9, 9, 9), -5)), None, + Decimal((0, (1, 9, 0, 0, 0, 0), -5))] + }), + schema="id INTEGER, decimal_2 DECIMAL(3,2), decimal_5 DECIMAL(6,5)") con = Redshift.generate_connection( database="test", host=redshift_parameters.get("RedshiftAddress"), @@ -498,7 +498,6 @@ def test_to_redshift_spark_decimal(session, bucket, redshift_parameters): con.close() assert df.count() == len(rows) assert len(list(df.columns)) == len(list(rows[0])) - print(rows) for row in rows: if row[0] == 1: assert row[1] == Decimal((0, (1, 9, 9), -2)) diff --git a/testing/test_awswrangler/test_spark.py b/testing/test_awswrangler/test_spark.py index fa0f58df4..9836e850b 100644 --- a/testing/test_awswrangler/test_spark.py +++ b/testing/test_awswrangler/test_spark.py @@ -165,7 +165,6 @@ def test_create_glue_table_csv(session, bucket, database, compression, partition def test_flatten_simple_struct(session): - print() pdf = pd.DataFrame({ "a": [1, 2], "b": [ @@ -199,7 +198,6 @@ def test_flatten_simple_struct(session): def test_flatten_complex_struct(session): - print() pdf = pd.DataFrame({ "a": [1, 2], "b": [ @@ -272,7 +270,6 @@ def test_flatten_complex_struct(session): def test_flatten_simple_map(session): - print() pdf = pd.DataFrame({ "a": [1, 2], "b": [ @@ -315,7 +312,6 @@ def test_flatten_simple_map(session): def test_flatten_simple_array(session): - print() pdf = pd.DataFrame({ "a": [1, 2], "b": [