Closes Bears-R-Us#2708 - special_objType support for HDF5 (Bears-R-Us#2709)

* Adding support for special_objType cases of pdarray.

* Overwrite and testing.

* Correcting docstring

* Fixing flake8
Ethan-DeBandi99 authored Aug 28, 2023
1 parent fc07ea3 commit 9ec222c
Showing 9 changed files with 267 additions and 33 deletions.
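Taken together, the changes below let `ak.IPv4`, `ak.Datetime`, and `ak.Timedelta` survive an HDF5 round trip instead of coming back as plain pdarrays. A minimal sketch of the new behavior (assuming a locally running Arkouda server; the file path is illustrative):

```python
import arkouda as ak

ak.connect()  # assumes arkouda_server is running on the default host/port

ip = ak.IPv4(ak.arange(10))
ip.to_hdf("/tmp/ip_test")          # the special object type is now written to the HDF5 metadata
rd = ak.read_hdf("/tmp/ip_test*")  # glob matches the per-locale files

assert isinstance(rd, ak.IPv4)     # previously this came back as a plain ak.pdarray
assert ip.to_list() == rd.to_list()
```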
41 changes: 41 additions & 0 deletions PROTO_tests/tests/io_test.py
@@ -1276,6 +1276,47 @@ def test_index_save_and_load(self):
ret_idx = ak.read_hdf(f"{file_name}*")
assert (idx == ret_idx).all()

def test_special_objtype(self):
"""
This test is simply to ensure that the dtype is persisted through the io
operation. It ultimately uses the process of pdarray, but need to ensure
correct Arkouda Object Type is returned
"""
ip = ak.IPv4(ak.arange(10))
dt = ak.Datetime(ak.arange(10))
td = ak.Timedelta(ak.arange(10))
df = ak.DataFrame({
"ip": ip,
"datetime": dt,
"timedelta": td
})

with tempfile.TemporaryDirectory(dir=TestHDF5.hdf_test_base_tmp) as tmp_dirname:
ip.to_hdf(f"{tmp_dirname}/ip_test")
rd_ip = ak.read_hdf(f"{tmp_dirname}/ip_test*")
assert isinstance(rd_ip, ak.IPv4)
assert ip.to_list() == rd_ip.to_list()

dt.to_hdf(f"{tmp_dirname}/dt_test")
rd_dt = ak.read_hdf(f"{tmp_dirname}/dt_test*")
assert isinstance(rd_dt, ak.Datetime)
assert dt.to_list() == rd_dt.to_list()

td.to_hdf(f"{tmp_dirname}/td_test")
rd_td = ak.read_hdf(f"{tmp_dirname}/td_test*")
assert isinstance(rd_td, ak.Timedelta)
assert td.to_list() == rd_td.to_list()

df.to_hdf(f"{tmp_dirname}/df_test")
rd_df = ak.read_hdf(f"{tmp_dirname}/df_test*")

assert isinstance(rd_df["ip"], ak.IPv4)
assert isinstance(rd_df["datetime"], ak.Datetime)
assert isinstance(rd_df["timedelta"], ak.Timedelta)
assert df["ip"].to_list() == rd_df["ip"].to_list()
assert df["datetime"].to_list() == rd_df["datetime"].to_list()
assert df["timedelta"].to_list() == rd_df["timedelta"].to_list()


class TestCSV:
csv_test_base_tmp = f"{os.getcwd()}/csv_io_test"
63 changes: 63 additions & 0 deletions arkouda/client_dtypes.py
@@ -685,6 +685,69 @@ def register(self, user_defined_name):
self.registered_name = user_defined_name
return self

def to_hdf(
self,
prefix_path: str,
dataset: str = "array",
mode: str = "truncate",
file_type: str = "distribute",
):
"""
Override of the pdarray to_hdf to store the special object type
"""
from typing import cast as typecast

from arkouda.client import generic_msg
from arkouda.io import _file_type_to_int, _mode_str_to_int

return typecast(
str,
generic_msg(
cmd="tohdf",
args={
"values": self,
"dset": dataset,
"write_mode": _mode_str_to_int(mode),
"filename": prefix_path,
"dtype": self.dtype,
"objType": self.special_objType,
"file_format": _file_type_to_int(file_type),
},
),
)

def update_hdf(self, prefix_path: str, dataset: str = "array", repack: bool = True):
"""
Override the pdarray implementation so that the special object type will be used.
"""
from arkouda.client import generic_msg
from arkouda.io import (
_file_type_to_int,
_get_hdf_filetype,
_mode_str_to_int,
_repack_hdf,
)

# determine the format (single/distribute) that the file was saved in
file_type = _get_hdf_filetype(prefix_path + "*")

generic_msg(
cmd="tohdf",
args={
"values": self,
"dset": dataset,
"write_mode": _mode_str_to_int("append"),
"filename": prefix_path,
"dtype": self.dtype,
"objType": self.special_objType,
"file_format": _file_type_to_int(file_type),
"overwrite": True,
},
)

if repack:
_repack_hdf(prefix_path)


@typechecked
def is_ipv4(ip: Union[pdarray, IPv4], ip2: Optional[pdarray] = None) -> pdarray:
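The `to_hdf`/`update_hdf` overrides above differ from the base `pdarray` versions only in sending `self.special_objType` instead of the generic object type. A hedged usage sketch of the update path (paths illustrative):

```python
import arkouda as ak

ip = ak.IPv4(ak.arange(10))
ip.to_hdf("/tmp/ip_data")  # initial write; the IPv4 object type is recorded

# overwrite the "array" dataset in place; repack=True (the default) rewrites
# the files afterwards to reclaim space freed by the overwrite
ak.IPv4(ak.arange(10, 20)).update_hdf("/tmp/ip_data")

assert isinstance(ak.read_hdf("/tmp/ip_data*"), ak.IPv4)
```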
14 changes: 8 additions & 6 deletions arkouda/dataframe.py
@@ -1545,6 +1545,10 @@ def _to_hdf_snapshot(self, path, dataset="DataFrame", mode="truncate", file_type
str(obj.categories.dtype) if isinstance(obj, Categorical_) else str(obj.dtype)
for obj in self.values()
]
col_objTypes = [
obj.special_objType if hasattr(obj, "special_objType") else obj.objType
for obj in self.values()
]
return cast(
str,
generic_msg(
@@ -1557,7 +1561,7 @@
"objType": self.objType,
"num_cols": len(self.columns),
"column_names": self.columns,
"column_objTypes": [obj.objType for key, obj in self.items()],
"column_objTypes": col_objTypes,
"column_dtypes": dtypes,
"columns": column_data,
"index": self.index.values.name,
@@ -2239,11 +2243,9 @@ def register(self, user_defined_name: str) -> DataFrame:
if isinstance(obj, Categorical_)
else json.dumps({"segments": obj.segments.name, "values": obj.values.name})
if isinstance(obj, SegArray)
-                else json.dumps({ # BitVector Case
-                    "name": obj.name,
-                    "width": obj.width,
-                    "reverse": obj.reverse
-                })
+                else json.dumps(
+                    {"name": obj.name, "width": obj.width, "reverse": obj.reverse}  # BitVector Case
+                )
for obj in self.values()
]

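The `col_objTypes` collection added to `_to_hdf_snapshot` above is a plain attribute check: a column that defines `special_objType` reports it, and everything else falls back to its regular `objType`. A self-contained sketch of the same dispatch (the classes are illustrative stand-ins, not Arkouda types):

```python
class PlainColumn:  # stand-in for an ordinary pdarray column
    objType = "pdarray"


class IPv4Column(PlainColumn):  # stand-in for a special client type
    special_objType = "IPv4"


def column_obj_type(obj):
    # mirrors the comprehension in _to_hdf_snapshot
    return obj.special_objType if hasattr(obj, "special_objType") else obj.objType


assert column_obj_type(PlainColumn()) == "pdarray"
assert column_obj_type(IPv4Column()) == "IPv4"
```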
24 changes: 20 additions & 4 deletions arkouda/io.py
@@ -16,6 +16,8 @@
from arkouda.pdarraycreation import arange, array
from arkouda.segarray import SegArray
from arkouda.strings import Strings
from arkouda.client_dtypes import IPv4
from arkouda.timeclass import Datetime, Timedelta

__all__ = [
"get_filetype",
@@ -438,7 +440,7 @@ def _parse_errors(rep_msg, allow_errors: bool = False):

def _parse_obj(
obj: Dict,
-) -> Union[Strings, pdarray, ArrayView, SegArray, Categorical, DataFrame]:
+) -> Union[Strings, pdarray, ArrayView, SegArray, Categorical, DataFrame, IPv4, Datetime, Timedelta]:
"""
Helper function to create an Arkouda object from read response
@@ -462,6 +464,12 @@
return SegArray.from_return_msg(obj["created"])
elif pdarray.objType.upper() == obj["arkouda_type"]:
return create_pdarray(obj["created"])
elif IPv4.special_objType.upper() == obj["arkouda_type"]:
return IPv4(create_pdarray(obj["created"]))
elif Datetime.special_objType.upper() == obj["arkouda_type"]:
return Datetime(create_pdarray(obj["created"]))
elif Timedelta.special_objType.upper() == obj["arkouda_type"]:
return Timedelta(create_pdarray(obj["created"]))
elif ArrayView.objType.upper() == obj["arkouda_type"]:
components = obj["created"].split("+")
flat = create_pdarray(components[0])
@@ -525,7 +533,13 @@ def _build_objects(
ArrayView,
Categorical,
DataFrame,
-    Mapping[str, Union[Strings, pdarray, SegArray, ArrayView, Categorical, DataFrame]],
+    IPv4,
+    Datetime,
+    Timedelta,
+    Mapping[
+        str,
+        Union[Strings, pdarray, SegArray, ArrayView, Categorical, DataFrame, IPv4, Datetime, Timedelta],
+    ],
]:
"""
Helper function to create the Arkouda objects from a read operation
@@ -1591,8 +1605,10 @@ def load(
the extension is not required to be a specific format.
"""
if "*" in path_prefix:
raise ValueError("Glob expressions not supported by ak.load(). "
"To read files using a glob expression, please use ak.read()")
raise ValueError(
"Glob expressions not supported by ak.load(). "
"To read files using a glob expression, please use ak.read()"
)
prefix, extension = os.path.splitext(path_prefix)
globstr = f"{prefix}_LOCALE*{extension}"
try:
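On the read side, `_parse_obj` now recognizes the three new `arkouda_type` strings and wraps the freshly created pdarray in the matching client class. The sketch below condenses those branches into a lookup table; the actual implementation is the if/elif chain shown above:

```python
from arkouda.client_dtypes import IPv4
from arkouda.pdarrayclass import create_pdarray
from arkouda.timeclass import Datetime, Timedelta

# illustrative condensation of the new _parse_obj branches
_SPECIAL_WRAPPERS = {
    IPv4.special_objType.upper(): IPv4,
    Datetime.special_objType.upper(): Datetime,
    Timedelta.special_objType.upper(): Timedelta,
}


def parse_special(obj: dict):
    wrapper = _SPECIAL_WRAPPERS.get(obj["arkouda_type"])
    if wrapper is not None:
        return wrapper(create_pdarray(obj["created"]))
    return None  # fall through to the pre-existing object types
```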
2 changes: 1 addition & 1 deletion arkouda/pdarrayclass.py
@@ -1518,7 +1518,7 @@ def to_hdf(
"write_mode": _mode_str_to_int(mode),
"filename": prefix_path,
"dtype": self.dtype,
"objType": "pdarray",
"objType": self.objType,
"file_format": _file_type_to_int(file_type),
},
),
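Swapping the hard-coded "pdarray" string for `self.objType` means any pdarray subclass that does not override `to_hdf` now records its own object type automatically, rather than always being written back as a generic pdarray.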
63 changes: 63 additions & 0 deletions arkouda/timeclass.py
@@ -84,6 +84,8 @@ class _AbstractBaseTime(pdarray):
so that all resulting operations are transparent.
"""

special_objType = "Time"

def __init__(self, pda, unit: str = _BASE_UNIT): # type: ignore
if isinstance(pda, Datetime) or isinstance(pda, Timedelta):
self.unit: str = pda.unit
@@ -210,6 +212,67 @@ def to_list(self):
__doc__ = super().to_list().__doc__ # noqa
return self.to_ndarray().tolist()

def to_hdf(
self,
prefix_path: str,
dataset: str = "array",
mode: str = "truncate",
file_type: str = "distribute",
):
"""
Override of the pdarray to_hdf to store the special object type
"""
from typing import cast as typecast

from arkouda.io import _file_type_to_int, _mode_str_to_int

return typecast(
str,
generic_msg(
cmd="tohdf",
args={
"values": self,
"dset": dataset,
"write_mode": _mode_str_to_int(mode),
"filename": prefix_path,
"dtype": self.dtype,
"objType": self.special_objType,
"file_format": _file_type_to_int(file_type),
},
),
)

def update_hdf(self, prefix_path: str, dataset: str = "array", repack: bool = True):
"""
Override the pdarray implementation so that the special object type will be used.
"""
from arkouda.io import (
_file_type_to_int,
_get_hdf_filetype,
_mode_str_to_int,
_repack_hdf,
)

# determine the format (single/distribute) that the file was saved in
file_type = _get_hdf_filetype(prefix_path + "*")

generic_msg(
cmd="tohdf",
args={
"values": self,
"dset": dataset,
"write_mode": _mode_str_to_int("append"),
"filename": prefix_path,
"dtype": self.dtype,
"objType": self.special_objType,
"file_format": _file_type_to_int(file_type),
"overwrite": True,
},
)

if repack:
_repack_hdf(prefix_path)

def __str__(self):
from arkouda.client import pdarrayIterThresh

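Both `Datetime` and `Timedelta` inherit these overrides from `_AbstractBaseTime`; judging from the `_parse_obj` branches above, each subclass shadows the base class's `special_objType = "Time"` with its own value. A round-trip sketch (path illustrative):

```python
import arkouda as ak

td = ak.Timedelta(ak.arange(10))
td.to_hdf("/tmp/td_data")

rd = ak.read_hdf("/tmp/td_data*")
assert isinstance(rd, ak.Timedelta)  # previously a plain ak.pdarray
assert td.to_list() == rd.to_list()
```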
33 changes: 17 additions & 16 deletions src/GenSymIO.chpl
@@ -202,23 +202,24 @@ module GenSymIO {
item.add("dataset_name", dsetName.replace(Q, ESCAPED_QUOTES, -1));
item.add("arkouda_type", akType: string);
var create_str: string;
-            if akType == ObjType.ARRAYVIEW {
-                var (valName, segName) = id.splitMsgToTuple("+", 2);
-                create_str = "created " + st.attrib(valName) + "+created " + st.attrib(segName);
-            }
-            else if akType == ObjType.PDARRAY {
-                create_str = "created " + st.attrib(id);
-            }
-            else if akType == ObjType.STRINGS {
-                var (segName, nBytes) = id.splitMsgToTuple("+", 2);
-                create_str = "created " + st.attrib(segName) + "+created bytes.size " + nBytes;
-            }
-            else if (akType == ObjType.SEGARRAY || akType == ObjType.CATEGORICAL ||
-                     akType == ObjType.GROUPBY || akType == ObjType.DATAFRAME) {
-                create_str = id;
-            }
-            else {
-                continue;
-            }
+            select akType {
+                when ObjType.ARRAYVIEW {
+                    var (valName, segName) = id.splitMsgToTuple("+", 2);
+                    create_str = "created " + st.attrib(valName) + "+created " + st.attrib(segName);
+                }
+                when ObjType.PDARRAY, ObjType.IPV4, ObjType.DATETIME, ObjType.TIMEDELTA {
+                    create_str = "created " + st.attrib(id);
+                }
+                when ObjType.STRINGS {
+                    var (segName, nBytes) = id.splitMsgToTuple("+", 2);
+                    create_str = "created " + st.attrib(segName) + "+created bytes.size " + nBytes;
+                }
+                when ObjType.SEGARRAY, ObjType.CATEGORICAL, ObjType.GROUPBY, ObjType.DATAFRAME {
+                    create_str = id;
+                }
+                otherwise {
+                    continue;
+                }
+            }
item.add("created", create_str);
items.pushBack(item);