diff --git a/awswrangler/pandas.py b/awswrangler/pandas.py
index c119b512c..8c72252c4 100644
--- a/awswrangler/pandas.py
+++ b/awswrangler/pandas.py
@@ -1,3 +1,4 @@
+from typing import Dict, List, Tuple, Optional, Any
 from io import BytesIO, StringIO
 import multiprocessing as mp
 import logging
@@ -854,20 +855,20 @@ def write_parquet_dataframe(dataframe, path, preserve_index, compression, fs, ca
 
     def to_redshift(
             self,
-            dataframe,
-            path,
-            connection,
-            schema,
-            table,
-            iam_role,
-            diststyle="AUTO",
-            distkey=None,
-            sortstyle="COMPOUND",
-            sortkey=None,
-            preserve_index=False,
-            mode="append",
-            cast_columns=None,
-    ):
+            dataframe: pd.DataFrame,
+            path: str,
+            connection: Any,
+            schema: str,
+            table: str,
+            iam_role: str,
+            diststyle: str = "AUTO",
+            distkey: Optional[str] = None,
+            sortstyle: str = "COMPOUND",
+            sortkey: Optional[str] = None,
+            preserve_index: bool = False,
+            mode: str = "append",
+            cast_columns: Optional[Dict[str, str]] = None,
+    ) -> None:
         """
         Load Pandas Dataframe as a Table on Amazon Redshift
 
@@ -888,28 +889,30 @@ def to_redshift(
         """
         if cast_columns is None:
             cast_columns = {}
-            cast_columns_parquet = {}
+            cast_columns_parquet: Dict = {}
         else:
-            cast_columns_parquet = data_types.convert_schema(func=data_types.redshift2athena, schema=cast_columns)
+            cast_columns_tuples: List[Tuple[str, str]] = [(k, v) for k, v in cast_columns.items()]
+            cast_columns_parquet = data_types.convert_schema(func=data_types.redshift2athena,
+                                                             schema=cast_columns_tuples)
         if path[-1] != "/":
             path += "/"
         self._session.s3.delete_objects(path=path)
-        num_rows = len(dataframe.index)
+        num_rows: int = len(dataframe.index)
         logger.debug(f"Number of rows: {num_rows}")
         if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
-            num_partitions = 1
+            num_partitions: int = 1
         else:
-            num_slices = self._session.redshift.get_number_of_slices(redshift_conn=connection)
+            num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
             logger.debug(f"Number of slices on Redshift: {num_slices}")
             num_partitions = num_slices
         logger.debug(f"Number of partitions calculated: {num_partitions}")
-        objects_paths = self.to_parquet(dataframe=dataframe,
-                                        path=path,
-                                        preserve_index=preserve_index,
-                                        mode="append",
-                                        procs_cpu_bound=num_partitions,
-                                        cast_columns=cast_columns_parquet)
-        manifest_path = f"{path}manifest.json"
+        objects_paths: List[str] = self.to_parquet(dataframe=dataframe,
+                                                   path=path,
+                                                   preserve_index=preserve_index,
+                                                   mode="append",
+                                                   procs_cpu_bound=num_partitions,
+                                                   cast_columns=cast_columns_parquet)
+        manifest_path: str = f"{path}manifest.json"
         self._session.redshift.write_load_manifest(manifest_path=manifest_path, objects_paths=objects_paths)
         self._session.redshift.load_table(
             dataframe=dataframe,
diff --git a/testing/test_awswrangler/test_redshift.py b/testing/test_awswrangler/test_redshift.py
index 0ffdfdebf..020f7e130 100644
--- a/testing/test_awswrangler/test_redshift.py
+++ b/testing/test_awswrangler/test_redshift.py
@@ -1,9 +1,10 @@
 import json
 import logging
+from datetime import date, datetime
 
 import pytest
 import boto3
-import pandas
+import pandas as pd
 from pyspark.sql import SparkSession
 import pg8000
 
@@ -80,7 +81,7 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, m
     dates = ["date"]
     if sample_name == "nano":
         dates = ["date", "time"]
-    dataframe = pandas.read_csv(f"data_samples/{sample_name}.csv", parse_dates=dates, infer_datetime_format=True)
+    dataframe = pd.read_csv(f"data_samples/{sample_name}.csv", parse_dates=dates, infer_datetime_format=True)
     dataframe["date"] = dataframe["date"].dt.date
     con = Redshift.generate_connection(
         database="test",
@@ -113,6 +114,46 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, m
     assert len(list(dataframe.columns)) + 1 == len(list(rows[0]))
 
 
+def test_to_redshift_pandas_cast(session, bucket, redshift_parameters):
+    df = pd.DataFrame({
+        "id": [1, 2, 3],
+        "name": ["name1", "name2", "name3"],
+        "foo": [None, None, None],
+        "boo": [date(2020, 1, 1), None, None],
+        "bar": [datetime(2021, 1, 1), None, None]})
+    schema = {
+        "id": "BIGINT",
+        "name": "VARCHAR",
+        "foo": "REAL",
+        "boo": "DATE",
+        "bar": "TIMESTAMP"}
+    con = Redshift.generate_connection(
+        database="test",
+        host=redshift_parameters.get("RedshiftAddress"),
+        port=redshift_parameters.get("RedshiftPort"),
+        user="test",
+        password=redshift_parameters.get("RedshiftPassword"),
+    )
+    path = f"s3://{bucket}/redshift-load/"
+    session.pandas.to_redshift(dataframe=df,
+                               path=path,
+                               schema="public",
+                               table="test",
+                               connection=con,
+                               iam_role=redshift_parameters.get("RedshiftRole"),
+                               mode="overwrite",
+                               preserve_index=False,
+                               cast_columns=schema)
+    cursor = con.cursor()
+    cursor.execute("SELECT * from public.test")
+    rows = cursor.fetchall()
+    cursor.close()
+    con.close()
+    print(rows)
+    assert len(df.index) == len(rows)
+    assert len(list(df.columns)) == len(list(rows[0]))
+
+
 @pytest.mark.parametrize(
     "sample_name,mode,factor,diststyle,distkey,exc,sortstyle,sortkey",
     [
@@ -125,7 +166,7 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, m
 )
 def test_to_redshift_pandas_exceptions(session, bucket, redshift_parameters, sample_name, mode, factor, diststyle,
                                        distkey, sortstyle, sortkey, exc):
-    dataframe = pandas.read_csv(f"data_samples/{sample_name}.csv")
+    dataframe = pd.read_csv(f"data_samples/{sample_name}.csv")
     con = Redshift.generate_connection(
         database="test",
         host=redshift_parameters.get("RedshiftAddress"),
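
For reference, a minimal usage sketch of the new `cast_columns` parameter, mirroring the added `test_to_redshift_pandas_cast` test. The cluster endpoint, credentials, bucket, and IAM role below are placeholders, and the top-level `awswrangler.Session()` entry point is assumed from the pre-1.0 API exercised by these tests:

```python
from datetime import date

import pandas as pd
import awswrangler
from awswrangler import Redshift

# Columns that are entirely (or mostly) NULL cannot have their Redshift types
# inferred from the data, so they are declared explicitly via cast_columns.
df = pd.DataFrame({
    "id": [1, 2, 3],
    "foo": [None, None, None],
    "boo": [date(2020, 1, 1), None, None],
})

session = awswrangler.Session()  # assumed pre-1.0 entry point
con = Redshift.generate_connection(
    database="test",
    host="my-cluster.abc123.us-east-1.redshift.amazonaws.com",  # placeholder endpoint
    port=5439,
    user="test",
    password="...",  # placeholder
)
session.pandas.to_redshift(
    dataframe=df,
    path="s3://my-bucket/redshift-load/",  # placeholder S3 staging prefix
    schema="public",
    table="test",
    connection=con,
    iam_role="arn:aws:iam::123456789012:role/my-redshift-copy-role",  # placeholder
    mode="overwrite",
    preserve_index=False,
    cast_columns={"id": "BIGINT", "foo": "REAL", "boo": "DATE"},  # Redshift type per column
)
con.close()
```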