Skip to content

Commit

Permalink
cast to np if contains numeric list
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin Ho authored and Colin Ho committed Jan 16, 2025
1 parent 35adce9 commit 43ea5cc
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 92 deletions.
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,7 @@ class PyDataType:
def is_image(self) -> builtins.bool: ...
def is_fixed_shape_image(self) -> builtins.bool: ...
def is_list(self) -> builtins.bool: ...
def contains_numeric_list(self) -> builtins.bool: ...
def is_tensor(self) -> builtins.bool: ...
def is_fixed_shape_tensor(self) -> builtins.bool: ...
def is_sparse_tensor(self) -> builtins.bool: ...
Expand Down
42 changes: 10 additions & 32 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,16 +260,12 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
def iter_rows(
self,
results_buffer_size: Union[Optional[int], Literal["num_cpus"]] = "num_cpus",
column_format: Literal["python", "arrow"] = "python",
) -> Iterator[Dict[str, Any]]:
"""Return an iterator of rows for this dataframe.
Each row will be a Python dictionary of the form { "key" : value, ... }. If you are instead looking to iterate over
entire partitions of data, see: :meth:`df.iter_partitions() <daft.DataFrame.iter_partitions>`.
By default, Daft will convert the columns to Python lists for easy consumption. However, for nested data such as List or Struct arrays, this can be expensive.
You may wish to set `column_format` to "arrow" such that the nested data is returned as an Arrow array.
.. NOTE::
A quick note on configuring asynchronous/parallel execution using `results_buffer_size`.
Expand All @@ -296,36 +292,20 @@ def iter_rows(
Args:
results_buffer_size: how many partitions to allow in the results buffer (defaults to the total number of CPUs
available on the machine).
column_format: the format of the columns to iterate over. One of "python", "arrow", or "numpy". Defaults to "python".
.. seealso::
:meth:`df.iter_partitions() <daft.DataFrame.iter_partitions>`: iterator over entire partitions instead of single rows
"""
if results_buffer_size == "num_cpus":
results_buffer_size = multiprocessing.cpu_count()

def arrow_iter_rows(table: "pyarrow.Table") -> Iterator[Dict[str, Any]]:
columns = table.columns
for i in range(len(table)):
row = {col._name: col[i] for col in columns}
yield row

def python_iter_rows(pydict: Dict[str, List[Any]], num_rows: int) -> Iterator[Dict[str, Any]]:
for i in range(num_rows):
row = {key: value[i] for (key, value) in pydict.items()}
yield row

if self._result is not None:
# If the dataframe has already finished executing,
# use the precomputed results.
if column_format == "python":
yield from python_iter_rows(self.to_pydict(), len(self))
elif column_format == "arrow":
yield from arrow_iter_rows(self.to_arrow())
else:
raise ValueError(
f"Unsupported column_format: {column_format}, supported formats are 'python' and 'arrow'"
)
pydict = self.to_pydict()
for i in range(len(self)):
row = {key: value[i] for (key, value) in pydict.items()}
yield row
else:
# Execute the dataframe in a streaming fashion.
context = get_context()
Expand All @@ -335,14 +315,12 @@ def python_iter_rows(pydict: Dict[str, List[Any]], num_rows: int) -> Iterator[Di

# Iterate through partitions.
for partition in partitions_iter:
if column_format == "python":
yield from python_iter_rows(partition.to_pydict(), len(partition))
elif column_format == "arrow":
yield from arrow_iter_rows(partition.to_arrow())
else:
raise ValueError(
f"Unsupported column_format: {column_format}, supported formats are 'python' and 'arrow'"
)
pydict = partition.to_pydict()

# Yield individual rows from the partition.
for i in range(len(partition)):
row = {key: value[i] for (key, value) in pydict.items()}
yield row

@DataframePublicAPI
def to_arrow_iter(
Expand Down
3 changes: 3 additions & 0 deletions daft/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,9 @@ def _is_integer(self) -> builtins.bool:
def _is_list(self) -> builtins.bool:
return self._dtype.is_list()

def _contains_numeric_list(self) -> builtins.bool:
    """Return whether this dtype is, or transitively contains, a list of numeric values.

    Delegates the check to the native (Rust-backed) dtype.
    """
    contains: builtins.bool = self._dtype.contains_numeric_list()
    return contains

def _is_boolean(self) -> builtins.bool:
return self._dtype.is_boolean()

Expand Down
6 changes: 6 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,12 @@ def to_pylist(self) -> list:
return self._series.to_pylist()
elif self.datatype()._should_cast_to_python():
return self._series.cast(DataType.python()._dtype).to_pylist()
elif self.datatype()._contains_numeric_list():
try:
return self._series.to_arrow().to_numpy(False)
except Exception as e:
warnings.warn(f"Error converting series containing numeric list to numpy: {e}, falling back to pylist")
return self._series.to_arrow().to_pylist()
else:
return self._series.to_arrow().to_pylist()

Expand Down
11 changes: 11 additions & 0 deletions src/daft-schema/src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,17 @@ impl DataType {
matches!(self, Self::List(..))
}

#[inline]
pub fn contains_numeric_list(&self) -> bool {
    // True when this dtype is a (fixed-size) list of numeric values, or
    // transitively contains one through an extension or struct wrapper.
    match self {
        Self::List(inner) | Self::FixedSizeList(inner, _) => inner.is_numeric(),
        Self::Extension(_, inner, _) => inner.contains_numeric_list(),
        Self::Struct(fields) => fields.iter().any(|f| f.dtype.contains_numeric_list()),
        _ => false,
    }
}

#[inline]
pub fn is_string(&self) -> bool {
matches!(self, Self::Utf8)
Expand Down
4 changes: 4 additions & 0 deletions src/daft-schema/src/python/datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,10 @@ impl PyDataType {
Ok(self.dtype.is_list())
}

/// Python-facing wrapper: whether this dtype is, or transitively contains,
/// a list of numeric values. Infallible; always returns `Ok`.
pub fn contains_numeric_list(&self) -> PyResult<bool> {
    let contains = self.dtype.contains_numeric_list();
    Ok(contains)
}

pub fn is_boolean(&self) -> PyResult<bool> {
Ok(self.dtype.is_boolean())
}
Expand Down
105 changes: 45 additions & 60 deletions tests/dataframe/test_iter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

import numpy as np
import pyarrow as pa
import pytest

import daft
from daft.datatype import DataType
from tests.conftest import get_tests_daft_runner_name


Expand All @@ -26,67 +26,66 @@ def test_iter_rows(make_df, materialized):


@pytest.mark.parametrize(
"format, data, expected",
"dtype, data, expected",
[
### Ints
pytest.param("python", [1, 2, 3], [1, 2, 3], id="python_ints"),
### Lists
pytest.param(
"arrow",
[1, 2, 3],
[pa.scalar(1), pa.scalar(2), pa.scalar(3)],
id="arrow_ints",
DataType.list(DataType.int64()),
[[1, 2], [3, 4, 5]],
[np.array([1, 2]), np.array([3, 4, 5])],
id="int_list_to_numpy",
),
### Strings
pytest.param("python", ["a", "b", "c"], ["a", "b", "c"], id="python_strs"),
pytest.param(
"arrow",
["a", "b", "c"],
[
pa.scalar("a", pa.large_string()),
pa.scalar("b", pa.large_string()),
pa.scalar("c", pa.large_string()),
],
id="arrow_strs",
DataType.list(DataType.int64()),
[[1.0, 2.0], [3.0, 4.0, 5]],
[np.array([1.0, 2.0]), np.array([3.0, 4.0, 5.0])],
id="float_list_to_numpy",
),
### Lists
pytest.param("python", [[1, 2], [3, 4]], [[1, 2], [3, 4]], id="python_lists"),
pytest.param(
"arrow",
DataType.list(DataType.string()),
[["a", "b"], ["c", "d", "e"]],
[["a", "b"], ["c", "d", "e"]],
id="string_list_to_python",
),
### Fixed size lists
pytest.param(
DataType.fixed_size_list(DataType.int64(), 2),
[[1, 2], [3, 4]],
[
pa.scalar([1, 2], pa.large_list(pa.int64())),
pa.scalar([3, 4], pa.large_list(pa.int64())),
],
id="arrow_lists",
[np.array([1, 2]), np.array([3, 4])],
id="int_fixed_size_list_to_numpy",
),
pytest.param(
DataType.fixed_size_list(DataType.int64(), 2),
[[1.0, 2.0], [3.0, 4.0]],
[np.array([1.0, 2.0]), np.array([3.0, 4.0])],
id="float_fixed_size_list_to_numpy",
),
pytest.param(
DataType.fixed_size_list(DataType.string(), 2),
[["a", "b"], ["c", "d"]],
[["a", "b"], ["c", "d"]],
id="string_fixed_size_list_to_python",
),
### Structs
pytest.param(
"python",
[{"a": 1, "b": 2}, {"a": 3, "b": 4}],
[{"a": 1, "b": 2}, {"a": 3, "b": 4}],
id="python_structs",
DataType.struct({"a": DataType.int64(), "b": DataType.string()}),
[{"a": 1, "b": "a"}, {"a": 2, "b": "b"}],
[{"a": 1, "b": "a"}, {"a": 2, "b": "b"}],
id="struct_to_python",
),
### Structs with lists
pytest.param(
"arrow",
[{"a": 1, "b": 2}, {"a": 3, "b": 4}],
[
pa.scalar(
{"a": 1, "b": 2},
pa.struct([pa.field("a", pa.int64()), pa.field("b", pa.int64())]),
),
pa.scalar(
{"a": 3, "b": 4},
pa.struct([pa.field("a", pa.int64()), pa.field("b", pa.int64())]),
),
],
id="arrow_structs",
DataType.struct({"a": DataType.int64(), "list": DataType.list(DataType.int64())}),
[{"a": 1, "list": [1, 2]}, {"a": 2, "list": [3, 4]}],
[{"a": 1, "list": np.array([1, 2])}, {"a": 2, "list": np.array([3, 4])}],
id="struct_with_list_to_numpy",
),
],
)
def test_iter_rows_column_formats(make_df, format, data, expected):
df = make_df({"a": data})
def test_iter_rows_nested_dtypes(make_df, dtype, data, expected):
df = make_df({"a": data}).select(daft.col("a").cast(dtype))

rows = list(df.iter_rows(column_format=format))
rows = list(df.iter_rows())

def compare_values(v1, v2):
if isinstance(v1, np.ndarray) and isinstance(v2, np.ndarray):
Expand All @@ -101,20 +100,6 @@ def compare_values(v1, v2):
assert compare_values(actual_row, expected_row)


@pytest.mark.parametrize(
"format",
[
"arrow",
"numpy",
],
)
def test_iter_rows_column_format_not_compatible(format):
df = daft.from_pydict({"a": [object()]}) # Object type is not supported by arrow or numpy

with pytest.raises(ValueError):
list(df.iter_rows(column_format=format))


@pytest.mark.parametrize("materialized", [False, True])
def test_iter_partitions(make_df, materialized):
# Test that df.iter_partitions() produces partitions in the correct order.
Expand Down

0 comments on commit 43ea5cc

Please sign in to comment.