Skip to content

Commit

Permalink
Ensure correct dtypes in dependency table (#416)
Browse files Browse the repository at this point in the history
* Add failing test

* Fix dtypes

* Revert __eq__()

* Refine comment

* Improve code readability
  • Loading branch information
hagenw authored May 29, 2024
1 parent 44df511 commit 210441f
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 22 deletions.
50 changes: 30 additions & 20 deletions audb/core/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,8 @@ class Dependencies:
""" # noqa: E501

def __init__(self):
data = {}
for name, dtype in zip(
define.DEPEND_FIELD_NAMES.values(),
define.DEPEND_FIELD_DTYPES.values(),
):
data[name] = pd.Series(dtype=dtype)
self._df = pd.DataFrame(data)
self._df.index = self._df.index.astype(define.DEPEND_INDEX_DTYPE)
self._df = pd.DataFrame(columns=define.DEPEND_FIELD_NAMES.values())
self._df = self._set_dtypes(self._df)
# pyarrow schema
# used for reading and writing files
self._schema = pa.schema(
Expand Down Expand Up @@ -328,12 +322,10 @@ def load(self, path: str):
)
if extension == "pkl":
self._df = pd.read_pickle(path)
# Correct dtype of index
# Correct dtypes
# to make backward compatiple
# with old pickle files in cache
# that might use `string` as dtype
if self._df.index.dtype != define.DEPEND_INDEX_DTYPE:
self._df.index = self._df.index.astype(define.DEPEND_INDEX_DTYPE)
self._df = self._set_dtypes(self._df)

elif extension == "csv":
table = csv.read_csv(
Expand Down Expand Up @@ -483,8 +475,7 @@ def _add_media(
values,
columns=["file"] + list(define.DEPEND_FIELD_NAMES.values()),
).set_index("file")
df.index = df.index.astype(define.DEPEND_INDEX_DTYPE)

df = self._set_dtypes(df)
self._df = pd.concat([self._df, df])

def _add_meta(
Expand Down Expand Up @@ -583,6 +574,30 @@ def _remove(self, file: str):
"""
self._df.at[file, "removed"] = 1

@staticmethod
def _set_dtypes(df: pd.DataFrame) -> pd.DataFrame:
r"""Set dependency table dtypes.
Args:
df: dataframe representing dependency table
Returns:
dataframe representing dependency table
with correct dtypes
"""
# Check the dtype of index,
# to decide if we need to update dtypes,
# as dtype of index changed to `object`
# in version 1.7.0 of audb.
if df.index.dtype != define.DEPEND_INDEX_DTYPE:
df.index = df.index.astype(define.DEPEND_INDEX_DTYPE, copy=False)
columns = define.DEPEND_FIELD_NAMES.values()
dtypes = define.DEPEND_FIELD_DTYPES.values()
mapping = {column: dtype for column, dtype in zip(columns, dtypes)}
df = df.astype(mapping, copy=False)
return df

def _table_to_dataframe(self, table: pa.Table) -> pd.DataFrame:
r"""Convert pyarrow table to pandas dataframe.
Expand Down Expand Up @@ -639,12 +654,7 @@ def _update_media(
values,
columns=["file"] + list(define.DEPEND_FIELD_NAMES.values()),
).set_index("file")
df.index = df.index.astype(define.DEPEND_INDEX_DTYPE)
for name, dtype in zip(
define.DEPEND_FIELD_NAMES.values(),
define.DEPEND_FIELD_DTYPES.values(),
):
df[name] = df[name].astype(dtype)
df = self._set_dtypes(df)
self._df.loc[df.index] = df

def _update_media_version(
Expand Down
55 changes: 53 additions & 2 deletions tests/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,15 +262,66 @@ def test_load_save_backward_compatibility(tmpdir, deps):
we need to make sure this is corrected
when loading old cache files.
Old behaviour (audb<1.7):
archive string[python]
bit_depth int32
channels int32
checksum string[python]
duration float64
format string[python]
removed int32
sampling_rate int32
type int32
version string[python]
New behaviour (audb>=1.7):
archive string[pyarrow]
bit_depth int32[pyarrow]
channels int32[pyarrow]
checksum string[pyarrow]
duration double[pyarrow]
format string[pyarrow]
removed int32[pyarrow]
sampling_rate int32[pyarrow]
type int32[pyarrow]
version string[pyarrow]
"""
deps_file = audeer.path(tmpdir, "deps.pkl")

deps_old = audb.Dependencies()
deps_old._df = deps._df.copy()

# Change dtype of index from object to string
# to mimic previous behavior
deps._df.index = deps._df.index.astype("string")
deps.save(deps_file)
deps_old._df.index = deps_old._df.index.astype("string")
# Change dtype of columns
# to mimic previous behavior
deps_old._df = deps_old._df.astype(
{
"archive": "string",
"bit_depth": "int32",
"channels": "int32",
"checksum": "string",
"duration": "float64",
"format": "string",
"removed": "int32",
"sampling_rate": "int32",
"type": "int32",
"version": "string",
}
)
deps_old.save(deps_file)

# Check that we get the correct dtypes,
# when loading from cache
deps2 = audb.Dependencies()
deps2.load(deps_file)
assert deps2._df.index.dtype == audb.core.define.DEPEND_INDEX_DTYPE
pd.testing.assert_frame_equal(deps._df, deps2._df)
assert deps == deps2


def test_load_save_errors(deps):
Expand Down

0 comments on commit 210441f

Please sign in to comment.