Skip to content

Commit

Permalink
Fix pandas to redshift cast feature
Browse files Browse the repository at this point in the history
  • Loading branch information
igorborgest committed Oct 20, 2019
1 parent 98a5e0d commit fccfdf7
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 29 deletions.
55 changes: 29 additions & 26 deletions awswrangler/pandas.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict, List, Tuple, Optional, Any
from io import BytesIO, StringIO
import multiprocessing as mp
import logging
Expand Down Expand Up @@ -854,20 +855,20 @@ def write_parquet_dataframe(dataframe, path, preserve_index, compression, fs, ca

def to_redshift(
self,
dataframe,
path,
connection,
schema,
table,
iam_role,
diststyle="AUTO",
distkey=None,
sortstyle="COMPOUND",
sortkey=None,
preserve_index=False,
mode="append",
cast_columns=None,
):
dataframe: pd.DataFrame,
path: str,
connection: Any,
schema: str,
table: str,
iam_role: str,
diststyle: str = "AUTO",
distkey: Optional[str] = None,
sortstyle: str = "COMPOUND",
sortkey: Optional[str] = None,
preserve_index: bool = False,
mode: str = "append",
cast_columns: Optional[Dict[str, str]] = None,
) -> None:
"""
Load Pandas Dataframe as a Table on Amazon Redshift
Expand All @@ -888,28 +889,30 @@ def to_redshift(
"""
if cast_columns is None:
cast_columns = {}
cast_columns_parquet = {}
cast_columns_parquet: Dict = {}
else:
cast_columns_parquet = data_types.convert_schema(func=data_types.redshift2athena, schema=cast_columns)
cast_columns_tuples: List[Tuple[str, str]] = [(k, v) for k, v in cast_columns.items()]
cast_columns_parquet = data_types.convert_schema(func=data_types.redshift2athena,
schema=cast_columns_tuples)
if path[-1] != "/":
path += "/"
self._session.s3.delete_objects(path=path)
num_rows = len(dataframe.index)
num_rows: int = len(dataframe.index)
logger.debug(f"Number of rows: {num_rows}")
if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
num_partitions = 1
num_partitions: int = 1
else:
num_slices = self._session.redshift.get_number_of_slices(redshift_conn=connection)
num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
logger.debug(f"Number of slices on Redshift: {num_slices}")
num_partitions = num_slices
logger.debug(f"Number of partitions calculated: {num_partitions}")
objects_paths = self.to_parquet(dataframe=dataframe,
path=path,
preserve_index=preserve_index,
mode="append",
procs_cpu_bound=num_partitions,
cast_columns=cast_columns_parquet)
manifest_path = f"{path}manifest.json"
objects_paths: List[str] = self.to_parquet(dataframe=dataframe,
path=path,
preserve_index=preserve_index,
mode="append",
procs_cpu_bound=num_partitions,
cast_columns=cast_columns_parquet)
manifest_path: str = f"{path}manifest.json"
self._session.redshift.write_load_manifest(manifest_path=manifest_path, objects_paths=objects_paths)
self._session.redshift.load_table(
dataframe=dataframe,
Expand Down
47 changes: 44 additions & 3 deletions testing/test_awswrangler/test_redshift.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import logging
from datetime import date, datetime

import pytest
import boto3
import pandas
import pandas as pd
from pyspark.sql import SparkSession
import pg8000

Expand Down Expand Up @@ -80,7 +81,7 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, m
dates = ["date"]
if sample_name == "nano":
dates = ["date", "time"]
dataframe = pandas.read_csv(f"data_samples/{sample_name}.csv", parse_dates=dates, infer_datetime_format=True)
dataframe = pd.read_csv(f"data_samples/{sample_name}.csv", parse_dates=dates, infer_datetime_format=True)
dataframe["date"] = dataframe["date"].dt.date
con = Redshift.generate_connection(
database="test",
Expand Down Expand Up @@ -113,6 +114,46 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, m
assert len(list(dataframe.columns)) + 1 == len(list(rows[0]))


def test_to_redshift_pandas_cast(session, bucket, redshift_parameters):
df = pd.DataFrame({
"id": [1, 2, 3],
"name": ["name1", "name2", "name3"],
"foo": [None, None, None],
"boo": [date(2020, 1, 1), None, None],
"bar": [datetime(2021, 1, 1), None, None]})
schema = {
"id": "BIGINT",
"name": "VARCHAR",
"foo": "REAL",
"boo": "DATE",
"bar": "TIMESTAMP"}
con = Redshift.generate_connection(
database="test",
host=redshift_parameters.get("RedshiftAddress"),
port=redshift_parameters.get("RedshiftPort"),
user="test",
password=redshift_parameters.get("RedshiftPassword"),
)
path = f"s3://{bucket}/redshift-load/"
session.pandas.to_redshift(dataframe=df,
path=path,
schema="public",
table="test",
connection=con,
iam_role=redshift_parameters.get("RedshiftRole"),
mode="overwrite",
preserve_index=False,
cast_columns=schema)
cursor = con.cursor()
cursor.execute("SELECT * from public.test")
rows = cursor.fetchall()
cursor.close()
con.close()
print(rows)
assert len(df.index) == len(rows)
assert len(list(df.columns)) == len(list(rows[0]))


@pytest.mark.parametrize(
"sample_name,mode,factor,diststyle,distkey,exc,sortstyle,sortkey",
[
Expand All @@ -125,7 +166,7 @@ def test_to_redshift_pandas(session, bucket, redshift_parameters, sample_name, m
)
def test_to_redshift_pandas_exceptions(session, bucket, redshift_parameters, sample_name, mode, factor, diststyle,
distkey, sortstyle, sortkey, exc):
dataframe = pandas.read_csv(f"data_samples/{sample_name}.csv")
dataframe = pd.read_csv(f"data_samples/{sample_name}.csv")
con = Redshift.generate_connection(
database="test",
host=redshift_parameters.get("RedshiftAddress"),
Expand Down

0 comments on commit fccfdf7

Please sign in to comment.