Closes Bears-R-Us#2778: Add convert_categoricals flag to `dataframe.to_parquet` (Bears-R-Us#2780)

* Closes Bears-R-Us#2778: Add convert_categoricals flag to dataframe.to_parquet

This PR (closes Bears-R-Us#2778) adds a `convert_categoricals` flag to `df.to_parquet`

* flake8 fix

---------

Co-authored-by: Pierce Hayes <[email protected]>
stress-tess and Pierce Hayes authored Sep 20, 2023
1 parent 6651b27 commit b7bbe67
Showing 2 changed files with 49 additions and 19 deletions.
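
For context, the flag described in the commit message would be used roughly as follows. This is a minimal sketch rather than part of the change itself: the column contents and output path are illustrative, and a locally running arkouda_server is assumed.

    import arkouda as ak

    ak.connect()  # assumes a locally running arkouda_server

    df = ak.DataFrame({
        "id": ak.arange(4),
        "size": ak.Categorical(ak.array(["S", "M", "S", "L"])),
    })

    # Without convert_categoricals=True this write raises a ValueError,
    # since the Parquet writer does not accept Categorical columns directly.
    df.to_parquet("/tmp/example_df", convert_categoricals=True)
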
40 changes: 31 additions & 9 deletions arkouda/dataframe.py
@@ -632,8 +632,10 @@ def transfer(self, hostname, port):
msg_list = []
for col in self._columns:
if isinstance(self[col], Categorical):
msg_list.append(f"Categorical+{col}+{self[col].codes.name} \
+{self[col].categories.name}+{self[col]._akNAcode.name}")
msg_list.append(
f"Categorical+{col}+{self[col].codes.name} \
+{self[col].categories.name}+{self[col]._akNAcode.name}"
)
elif isinstance(self[col], SegArray):
msg_list.append(f"SegArray+{col}+{self[col].segments.name}+{self[col].values.name}")
elif isinstance(self[col], Strings):
@@ -654,11 +656,11 @@ def transfer(self, hostname, port):
generic_msg(
cmd="sendDataframe",
args={
"size" : len(msg_list),
"idx_name" : idx.name,
"columns" : msg_list,
"hostname" : hostname,
"port" : port
"size": len(msg_list),
"idx_name": idx.name,
"columns": msg_list,
"hostname": hostname,
"port": port,
},
),
)
@@ -1639,7 +1641,14 @@ def update_hdf(self, prefix_path: str, index=False, columns=None, repack: bool =
data = self._prep_data(index=index, columns=columns)
update_hdf(data, prefix_path=prefix_path, repack=repack)

def to_parquet(self, path, index=False, columns=None, compression: Optional[str] = None):
def to_parquet(
self,
path,
index=False,
columns=None,
compression: Optional[str] = None,
convert_categoricals: bool = False,
):
"""
Save DataFrame to disk as parquet, preserving column names.
@@ -1655,6 +1664,11 @@ def to_parquet(self, path, index=False, columns=None, compression: Optional[str]
Default None
Provide the compression type to use when writing the file.
Supported values: snappy, gzip, brotli, zstd, lz4
convert_categoricals: bool
Defaults to False
Parquet requires all columns to be the same size and Categoricals
don't satisfy that requirement.
If set, write the equivalent Strings in place of any Categorical columns.
Returns
-------
None
@@ -1674,7 +1688,15 @@ def to_parquet(self, path, index=False, columns=None, compression: Optional[str]
from arkouda.io import to_parquet

data = self._prep_data(index=index, columns=columns)
to_parquet(data, prefix_path=path, compression=compression)
if not convert_categoricals and any(isinstance(val, Categorical) for val in data.values()):
raise ValueError(
"to_parquet doesn't support Categorical columns. To write the equivalent "
"Strings in place of any Categorical columns, rerun with convert_categoricals "
"set to True."
)
to_parquet(
data, prefix_path=path, compression=compression, convert_categoricals=convert_categoricals
)

@typechecked
def to_csv(
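
With the guard added above, writing a DataFrame that still contains a Categorical column fails fast unless the flag is set. A minimal sketch of the expected behavior (the column data and output path are illustrative, and a running server connection is assumed):

    import arkouda as ak

    ak.connect()  # assumes a locally running arkouda_server

    df = ak.DataFrame({"color": ak.Categorical(ak.array(["red", "blue", "red"]))})

    try:
        df.to_parquet("/tmp/colors")  # convert_categoricals defaults to False
    except ValueError as err:
        print(err)  # message suggests rerunning with convert_categoricals=True

    # Categorical column is written as its equivalent Strings
    df.to_parquet("/tmp/colors", convert_categoricals=True)
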
28 changes: 18 additions & 10 deletions arkouda/io.py
@@ -965,7 +965,7 @@ def import_data(read_path: str, write_file: str = None, return_obj: bool = True,
from arkouda.dataframe import DataFrame

# verify file path
is_glob = False if os.path.isfile(read_path) else True
is_glob = not os.path.isfile(read_path)
file_list = glob.glob(read_path)
if len(file_list) == 0:
raise FileNotFoundError(f"Invalid read_path, {read_path}. No files found.")
@@ -1080,6 +1080,7 @@ def _bulk_write_prep(
List[Union[pdarray, Strings, SegArray, ArrayView]],
],
names: List[str] = None,
convert_categoricals: bool = False,
):
datasetNames = []
if names is not None:
@@ -1101,6 +1102,11 @@
if len(data) == 0:
raise RuntimeError("No data was found.")

if convert_categoricals:
for i, val in enumerate(data):
if isinstance(val, Categorical):
data[i] = val.categories[val.codes]

col_objtypes = [c.objType for c in data]

return datasetNames, data, col_objtypes
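
The conversion in `_bulk_write_prep` relies on gather indexing: a `Categorical` stores its distinct values in `categories` (a Strings object) and a per-row integer index in `codes`, so `categories[codes]` materializes the equivalent per-row Strings. A small sketch of that expression in isolation (values are illustrative, server connection assumed):

    import arkouda as ak

    ak.connect()  # assumes a locally running arkouda_server

    c = ak.Categorical(ak.array(["low", "high", "low", "mid"]))

    # Gather the distinct category strings back out by code,
    # reconstructing one string per original row.
    equivalent = c.categories[c.codes]
    print(equivalent)
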
@@ -1115,6 +1121,7 @@ def to_parquet(
names: List[str] = None,
mode: str = "truncate",
compression: Optional[str] = None,
convert_categoricals: bool = False,
) -> None:
"""
Save multiple named pdarrays to Parquet files.
@@ -1135,7 +1142,11 @@
Default None
Provide the compression type to use when writing the file.
Supported values: snappy, gzip, brotli, zstd, lz4
convert_categoricals: bool
Defaults to False
Parquet requires all columns to be the same size and Categoricals
don't satisfy that requirement.
If set, write the equivalent Strings in place of any Categorical columns.
Returns
-------
@@ -1176,15 +1187,14 @@
"""
if mode.lower() not in ["append", "truncate"]:
raise ValueError("Allowed modes are 'truncate' and 'append'")

if mode.lower() == "append":
warn(
"Append has been deprecated when writing Parquet files. "
"Please write all columns to the file at once.",
DeprecationWarning,
)

datasetNames, data, col_objtypes = _bulk_write_prep(columns, names)
datasetNames, data, col_objtypes = _bulk_write_prep(columns, names, convert_categoricals)
# append or single column use the old logic
if mode.lower() == "append" or len(data) == 1:
for arr, name in zip(data, cast(List[str], datasetNames)):
@@ -2009,7 +2019,7 @@ def restore(filename):
return read_hdf(sorted(restore_files))


def receive(hostname : str, port):
def receive(hostname: str, port):
"""
Receive a pdarray sent by `pdarray.transfer()`.
@@ -2043,13 +2053,12 @@ def receive(hostname : str, port):
Raised if other is not a pdarray or the pdarray.dtype is not
a supported dtype
"""
rep_msg = generic_msg(cmd="receiveArray", args={"hostname": hostname,
"port" : port})
rep_msg = generic_msg(cmd="receiveArray", args={"hostname": hostname, "port": port})
rep = json.loads(rep_msg)
return _build_objects(rep)


def receive_dataframe(hostname : str, port):
def receive_dataframe(hostname: str, port):
"""
Receive a pdarray sent by `dataframe.transfer()`.
@@ -2083,7 +2092,6 @@ def receive_dataframe(hostname : str, port):
Raised if other is not a pdarray or the pdarray.dtype is not
a supported dtype
"""
rep_msg = generic_msg(cmd="receiveDataframe", args={"hostname": hostname,
"port" : port})
rep_msg = generic_msg(cmd="receiveDataframe", args={"hostname": hostname, "port": port})
rep = json.loads(rep_msg)
return DataFrame(_build_objects(rep))
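
The module-level writer in `arkouda.io` accepts the same flag, so column collections assembled outside a DataFrame can be written the same way. A sketch with illustrative names and paths (server connection assumed):

    import arkouda as ak
    from arkouda.io import to_parquet

    ak.connect()  # assumes a locally running arkouda_server

    cols = {
        "id": ak.arange(3),
        "label": ak.Categorical(ak.array(["a", "b", "a"])),
    }
    to_parquet(cols, prefix_path="/tmp/cols", convert_categoricals=True)
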
