Skip to content

Commit

Permalink
Add more failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Dec 17, 2024
1 parent 1803b0c commit 60eed9d
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 2 deletions.
40 changes: 38 additions & 2 deletions tests/memory/test_udf_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@
from tests.memory.utils import run_wrapper_build_partitions


def format_bytes(bytes_value):
"""Format bytes into human readable string with appropriate unit."""
for unit in ["B", "KB", "MB", "GB"]:
if bytes_value < 1024:
return f"{bytes_value:.2f} {unit}"
bytes_value /= 1024
return f"{bytes_value:.2f} GB"


@daft.udf(return_dtype=str)
def to_arrow_identity(s):
data = s.to_arrow()
Expand Down Expand Up @@ -49,10 +58,37 @@ def to_pylist_identity_batched_arrow_return(s):
to_pylist_identity_batched_arrow_return,
],
)
def test_string_identity_projection(udf):
def test_short_string_identity_projection(udf):
instructions = [Project(ExpressionsProjection([udf(daft.col("a"))]))]
inputs = [{"a": [str(uuid.uuid4()) for _ in range(62500)]}]
_, memray_file = run_wrapper_build_partitions(inputs, instructions)
stats = compute_statistics(memray_file)

expected_peak_bytes = 100
assert stats.peak_memory_allocated < expected_peak_bytes, (
f"Peak memory ({format_bytes(stats.peak_memory_allocated)}) "
f"exceeded threshold ({format_bytes(expected_peak_bytes)})"
)


@pytest.mark.parametrize(
"udf",
[
to_arrow_identity,
to_pylist_identity,
to_arrow_identity_batched,
to_pylist_identity_batched,
to_pylist_identity_batched_arrow_return,
],
)
def test_long_string_identity_projection(udf):
instructions = [Project(ExpressionsProjection([udf(daft.col("a"))]))]
inputs = [{"a": [str(uuid.uuid4()) for _ in range(625000)]}]
_, memray_file = run_wrapper_build_partitions(inputs, instructions)
stats = compute_statistics(memray_file)

assert stats.peak_memory_allocated < 100
expected_peak_bytes = 100
assert stats.peak_memory_allocated < expected_peak_bytes, (
f"Peak memory ({format_bytes(stats.peak_memory_allocated)}) "
f"exceeded threshold ({format_bytes(expected_peak_bytes)})"
)
6 changes: 6 additions & 0 deletions tests/memory/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
import tempfile
import uuid
Expand All @@ -9,11 +10,16 @@
from daft.runners.ray_runner import build_partitions
from daft.table import MicroPartition

logger = logging.getLogger(__name__)


def run_wrapper_build_partitions(
input_partitions: list[dict], instructions: list[Instruction]
) -> tuple[list[MicroPartition], str]:
inputs = [MicroPartition.from_pydict(p) for p in input_partitions]

logger.info("Input total size: %s", sum(i.size_bytes() for i in inputs))

tmpdir = tempfile.gettempdir()
memray_path = os.path.join(tmpdir, f"memray-{uuid.uuid4()}.bin")
with memray.Tracker(memray_path, native_traces=True, follow_fork=True):
Expand Down

0 comments on commit 60eed9d

Please sign in to comment.