avoid particle_index type cast
chrishavlin committed Sep 23, 2024
1 parent e648ef7 commit 9b5c823
Showing 3 changed files with 86 additions and 3 deletions.
3 changes: 2 additions & 1 deletion yt/data_objects/selection_objects/data_selection_objects.py
@@ -214,7 +214,8 @@ def get_data(self, fields=None):

         for f, v in read_particles.items():
             self.field_data[f] = self.ds.arr(v, units=finfos[f].units)
-            self.field_data[f].convert_to_units(finfos[f].output_units)
+            if finfos[f].units != finfos[f].output_units:
+                self.field_data[f].convert_to_units(finfos[f].output_units)

         fields_to_generate += gen_fluids + gen_particles
         self._generate_fields(fields_to_generate)
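Why the guard matters: unyt's convert_to_units works in place, and converting an integer array requires casting it to float, so the call upcast int64 fields like particle_index to float64 even when the conversion factor was 1. Skipping the call when input and output units already match is what preserves the integer dtype. A minimal standalone sketch of the same pattern, using a dimensionless unyt array as a stand-in for a particle_index field (illustration only, not code from this commit):

import numpy as np
import unyt

indices = unyt.unyt_array(np.arange(10, dtype=np.int64), "dimensionless")

# same guard as in the diff above: convert only when units actually differ,
# so the int64 data is never handed to convert_to_units
if indices.units != unyt.dimensionless:
    indices.convert_to_units("dimensionless")

assert indices.dtype == np.int64  # dtype preserved because no conversion ran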
65 changes: 65 additions & 0 deletions yt/frontends/stream/tests/test_stream_particles.py
@@ -2,6 +2,7 @@
 import pytest
 from numpy.testing import assert_equal

+import yt
 import yt.utilities.initial_conditions as ic
 from yt.loaders import load_amr_grids, load_particles, load_uniform_grid
 from yt.testing import fake_particle_ds, fake_sph_orientation_ds
@@ -404,3 +405,67 @@ def test_stream_non_cartesian_particles_amr():
     assert_equal(dd["all", "particle_position_r"].v, particle_position_r)
     assert_equal(dd["all", "particle_position_phi"].v, particle_position_phi)
     assert_equal(dd["all", "particle_position_theta"].v, particle_position_theta)
+
+
+@pytest.fixture
+def sph_dataset_with_integer_index():
+    num_particles = 100
+
+    data = {
+        ("gas", "particle_position_x"): np.linspace(0.0, 1.0, num_particles),
+        ("gas", "particle_position_y"): np.linspace(0.0, 1.0, num_particles),
+        ("gas", "particle_position_z"): np.linspace(0.0, 1.0, num_particles),
+        ("gas", "particle_mass"): np.ones(num_particles),
+        ("gas", "density"): np.ones(num_particles),
+        ("gas", "smoothing_length"): np.ones(num_particles) * 0.1,
+        ("gas", "particle_index"): np.arange(0, num_particles),
+    }
+
+    ds = load_particles(data)
+    return ds
+
+
+def test_particle_dtypes_selection(sph_dataset_with_integer_index):
+    # these operations will preserve data type
+    ds = sph_dataset_with_integer_index
+    ad = ds.all_data()
+    assert ad["gas", "particle_index"].dtype == np.int64
+
+    min_max = ad.quantities.extrema(("gas", "particle_index"))
+    assert min_max.dtype == np.int64
+
+    # check that subselections preserve type
+    le = ds.domain_center - ds.domain_width / 10.0
+    re = ds.domain_center + ds.domain_width / 10.0
+    reg = ds.region(ds.domain_center, le, re)
+    assert reg["gas", "particle_index"].dtype == np.int64
+
+    vals = ds.slice(0, ds.domain_center[0])["gas", "particle_index"]
+    assert vals.max() > 0
+    assert vals.dtype == np.int64
+
+
+def test_particle_dtypes_operations(sph_dataset_with_integer_index):
+    # these operations will not preserve dtype (will be cast to float64).
+    # note that the numerical outputs of these operations are not
+    # physical (projecting the particle index does not make any physical
+    # sense), but they do make sure the methods run in case any frontends
+    # start setting physical fields with different data types.
+
+    ds = sph_dataset_with_integer_index
+
+    field = ("gas", "particle_index")
+    frb = ds.proj(field, 2).to_frb(ds.domain_width[0], (64, 64))
+    image = frb["gas", "particle_index"]
+    assert image.max() > 0
+
+    off_axis_prj = yt.off_axis_projection(
+        ds,
+        ds.domain_center,
+        [0.5, 0.5, 0.5],
+        ds.domain_width,
+        (64, 64),
+        ("gas", "particle_index"),
+        weight=None,
+    )
+    assert off_axis_prj.max() > 0
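A note on the split between the two tests above (this is standard numpy/unyt promotion behavior, stated here as background rather than taken from the commit): only direct reads and pure selections can hand back the stored int64 values, because any floating-point arithmetic, which projections and other derived operations inevitably perform, promotes the result to float64:

import numpy as np
import unyt

idx = unyt.unyt_array(np.arange(4, dtype=np.int64), "dimensionless")
assert idx.dtype == np.int64  # a direct read keeps int64

weighted = idx * 0.5  # float math, e.g. applying a projection weight
assert weighted.dtype == np.float64  # result is promoted to float64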
21 changes: 19 additions & 2 deletions yt/utilities/io_handler.py
@@ -60,6 +60,22 @@ def push(self, grid, field, data):
             raise ValueError
         self.queue[grid][field] = data

+    @property
+    def _particle_dtypes(self) -> defaultdict[FieldKey, type]:
+        # returns a defaultdict of data types for particle fields.
+        # defaults for all fields are float64, except for the particle_index
+        # field (or its alias if it exists), which is set as int64. Important
+        # to note that the data type will only be preserved for direct reads
+        # and any operations will either implicitly (via unyt) or explicitly
+        # convert to float64.
+        dtypes: defaultdict[FieldKey, type] = defaultdict(lambda: np.float64)
+        for ptype in self.ds.particle_types:
+            p_index = (ptype, "particle_index")
+            if p_index in self.ds.field_info.field_aliases:
+                p_index = self.ds.field_info.field_aliases[p_index]
+            dtypes[p_index] = np.int64
+        return dtypes
+
     def _field_in_backup(self, grid, backup_file, field_name):
         if os.path.exists(backup_file):
             fhandle = h5py.File(backup_file, mode="r")
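The property leans entirely on defaultdict's fallback: a lookup for any field key that was never registered returns float64, while the registered particle_index keys (or their aliases) return int64. A toy illustration of just that lookup pattern; the ("io", ...) keys are invented for the example:

from collections import defaultdict

import numpy as np

# a FieldKey in yt is a (particle_type, field_name) tuple
dtypes = defaultdict(lambda: np.float64)
dtypes[("io", "particle_index")] = np.int64

assert dtypes[("io", "particle_mass")] is np.float64  # unregistered, falls back
assert dtypes[("io", "particle_index")] is np.int64  # registered override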
@@ -173,6 +189,7 @@ def _read_particle_selection(
         # field_maps stores fields, accounting for field unions
         ptf: defaultdict[str, list[str]] = defaultdict(list)
         field_maps: defaultdict[FieldKey, list[FieldKey]] = defaultdict(list)
+        p_dtypes = self._particle_dtypes

         # We first need a set of masks for each particle type
         chunks = list(chunks)
@@ -206,14 +223,14 @@
                 vals = data.pop(field_f)
                 # note: numpy.concatenate has a dtype argument that would avoid
                 # a copy using .astype(...), available in numpy>=1.20
-                rv[field_f] = np.concatenate(vals, axis=0).astype("float64")
+                rv[field_f] = np.concatenate(vals, axis=0).astype(p_dtypes[field_f])
             else:
                 shape = [0]
                 if field_f[1] in self._vector_fields:
                     shape.append(self._vector_fields[field_f[1]])
                 elif field_f[1] in self._array_fields:
                     shape.append(self._array_fields[field_f[1]])
-                rv[field_f] = np.empty(shape, dtype="float64")
+                rv[field_f] = np.empty(shape, dtype=p_dtypes[field_f])
         return rv

     def _read_particle_fields(self, chunks, ptf, selector):
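On the numpy>=1.20 alternative mentioned in the code comment (a sketch of the general numpy API, not a change made by this commit): passing dtype= to np.concatenate allocates the output directly in the target dtype, whereas concatenating and then calling .astype(...) makes a second copy by default:

import numpy as np

chunks = [np.arange(3, dtype=np.int64), np.arange(3, dtype=np.int64)]

# two allocations: the concatenated array, then the .astype(...) copy
out_copy = np.concatenate(chunks, axis=0).astype(np.int64)

# one allocation on numpy>=1.20: output is created with the requested dtype
out_direct = np.concatenate(chunks, axis=0, dtype=np.int64)

assert out_copy.dtype == out_direct.dtype == np.int64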
