Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autodetect rockstar format #4997

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 118 additions & 2 deletions yt/frontends/rockstar/data_structures.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import glob
import os
from functools import cached_property
from typing import Any, Optional

import numpy as np

Expand All @@ -9,22 +11,58 @@
from yt.geometry.particle_geometry_handler import ParticleIndex
from yt.utilities import fortran_utils as fpu
from yt.utilities.cosmology import Cosmology
from yt.utilities.exceptions import YTFieldNotFound

from .definitions import header_dt
from .fields import RockstarFieldInfo


class RockstarBinaryFile(HaloCatalogFile):
header: dict
_position_offset: int
_member_offset: int
_Npart: "np.ndarray[Any, np.dtype[np.int64]]"
_ids_halos: list[int]
_file_size: int

def __init__(self, ds, io, filename, file_id, range):
with open(filename, "rb") as f:
self.header = fpu.read_cattrs(f, header_dt, "=")
self._position_offset = f.tell()
pcount = self.header["num_halos"]

halos = np.fromfile(f, dtype=io._halo_dt, count=pcount)
self._member_offset = f.tell()
self._ids_halos = list(halos["particle_identifier"])
self._Npart = halos["num_p"]

f.seek(0, os.SEEK_END)
self._file_size = f.tell()

expected_end = self._member_offset + 8 * self._Npart.sum()
if expected_end != self._file_size:
raise RuntimeError(
f"File size {self._file_size} does not match expected size {expected_end}."
)

super().__init__(ds, io, filename, file_id, range)

def _read_particle_positions(self, ptype, f=None):
def _read_member(
self, ihalo: int
) -> Optional["np.ndarray[Any, np.dtype[np.int64]]"]:
if ihalo not in self._ids_halos:
return None

ind_halo = self._ids_halos.index(ihalo)

ipos = self._member_offset + 8 * self._Npart[:ind_halo].sum()

with open(self.filename, "rb") as f:
f.seek(ipos, os.SEEK_SET)
ids = np.fromfile(f, dtype=np.int64, count=self._Npart[ind_halo])
return ids

def _read_particle_positions(self, ptype: str, f=None):
"""
Read all particle positions in this file.
"""
Expand All @@ -48,8 +86,18 @@ def _read_particle_positions(self, ptype, f=None):
return pos


class RockstarIndex(ParticleIndex):
def get_member(self, ihalo: int):
for df in self.data_files:
members = df._read_member(ihalo)
if members is not None:
return members

raise RuntimeError(f"Could not find halo {ihalo} in any data file.")


class RockstarDataset(ParticleDataset):
_index_class = ParticleIndex
_index_class = RockstarIndex
_file_class = RockstarBinaryFile
_field_info_class = RockstarFieldInfo
_suffix = ".bin"
Expand Down Expand Up @@ -122,3 +170,71 @@ def _is_valid(cls, filename: str, *args, **kwargs) -> bool:
return False
else:
return header["magic"] == 18077126535843729616

def halo(self, ptype, particle_identifier):
return RockstarHaloContainer(
ptype,
particle_identifier,
parent_ds=None,
halo_ds=self,
)


class RockstarHaloContainer:
def __init__(self, ptype, particle_identifier, *, parent_ds, halo_ds):
if ptype not in halo_ds.particle_types_raw:
raise RuntimeError(
f'Possible halo types are {halo_ds.particle_types_raw}, supplied "{ptype}".'
)

self.ds = parent_ds
self.halo_ds = halo_ds
self.ptype = ptype
self.particle_identifier = particle_identifier

def __repr__(self):
return "%s_%s_%09d" % (self.halo_ds, self.ptype, self.particle_identifier)

def __getitem__(self, key):
if isinstance(key, tuple):
ptype, field = key
else:
ptype = self.ptype
field = key

data = {
"mass": self.mass,
"position": self.position,
"velocity": self.velocity,
"member_ids": self.member_ids,
}
if ptype == "halos" and field in data:
return data[field]

raise YTFieldNotFound((ptype, field), dataset=self.ds)

@cached_property
def ihalo(self):
halo_id = self.particle_identifier
halo_ids = list(self.halo_ds.r["halos", "particle_identifier"].astype("i8"))
ihalo = halo_ids.index(halo_id)

assert halo_ids[ihalo] == halo_id

return ihalo

@property
def mass(self):
return self.halo_ds.r["halos", "particle_mass"][self.ihalo]

@property
def position(self):
return self.halo_ds.r["halos", "particle_position"][self.ihalo]

@property
def velocity(self):
return self.halo_ds.r["halos", "particle_velocity"][self.ihalo]

@property
def member_ids(self):
return self.halo_ds.index.get_member(self.particle_identifier)
42 changes: 40 additions & 2 deletions yt/frontends/rockstar/io.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,56 @@
import os
from collections.abc import Sequence

import numpy as np

from yt.utilities import fortran_utils as fpu
from yt.utilities.io_handler import BaseParticleIOHandler

from .definitions import halo_dts
from .definitions import halo_dts, header_dt


def _can_load_with_format(
filename: str, header_fmt: Sequence[tuple[str, int, str]], halo_format: np.dtype
) -> bool:
with open(filename, "rb") as f:
header = fpu.read_cattrs(f, header_fmt, "=")
Nhalos = header["num_halos"]
Nparttot = header["num_particles"]
halos = np.fromfile(f, dtype=halo_format, count=Nhalos)

# Make sure all masses are > 0
if np.any(halos["particle_mass"] <= 0):
return False
# Make sure number of particles sums to expected value
if halos["num_p"].sum() != Nparttot:
return False

return True


class IOHandlerRockstarBinary(BaseParticleIOHandler):
_dataset_type = "rockstar_binary"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._halo_dt = halo_dts[self.ds.parameters["format_revision"]]
self._halo_dt = self.detect_rockstar_format(
self.ds.filename,
self.ds.parameters["format_revision"],
)

@staticmethod
def detect_rockstar_format(
filename: str,
guess: int,
) -> np.dtype:
revisions: list[int] = list(halo_dts.keys())
if guess in revisions:
revisions.pop(revisions.index(guess))
revisions = [guess] + revisions
for revision in revisions:
if _can_load_with_format(filename, header_dt, halo_dts[revision]):
return halo_dts[revision]
raise RuntimeError(f"Could not detect Rockstar format for file {filename}")

def _read_fluid_selection(self, chunks, selector, fields, size):
raise NotImplementedError
Expand Down
20 changes: 20 additions & 0 deletions yt/frontends/rockstar/tests/test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,23 @@ def test_particle_selection():
ds = data_dir_load(r1)
psc = ParticleSelectionComparison(ds)
psc.run_defaults()


@requires_file(r1)
def test_halo_loading():
ds = data_dir_load(r1)

for halo_id, Npart in zip(
ds.r["halos", "particle_identifier"],
ds.r["halos", "num_p"],
):
halo = ds.halo("halos", halo_id)
assert halo is not None

# Try accessing properties
halo.position
halo.velocity
halo.mass

# Make sure we can access the member particles
assert_equal(len(halo.member_ids), Npart)
Loading