diff --git a/corsikaio/file.py b/corsikaio/file.py index 1a63c29..42e5e4e 100644 --- a/corsikaio/file.py +++ b/corsikaio/file.py @@ -109,7 +109,7 @@ def __next__(self): raise IOError("File seems to be truncated") if self.parse_blocks: - event_end = parse_event_end(block)[0] + event_end = parse_event_end(block,self.version)[0] data = self.parse_data_blocks(data_bytes) longitudinal = parse_longitudinal(long_bytes) else: diff --git a/corsikaio/subblocks/__init__.py b/corsikaio/subblocks/__init__.py index 503b4eb..4be0f5c 100644 --- a/corsikaio/subblocks/__init__.py +++ b/corsikaio/subblocks/__init__.py @@ -5,7 +5,7 @@ from .run_end import run_end_dtype from .event_header import event_header_types -from .event_end import event_end_dtype +from .event_end import event_end_types from .data import cherenkov_photons_dtype, particle_data_dtype from .longitudinal import longitudinal_data_dtype @@ -38,8 +38,8 @@ def parse_event_header(event_header_bytes): return np.frombuffer(event_header_bytes, dtype=event_header_types[version]) -def parse_event_end(event_end_bytes): - return np.frombuffer(event_end_bytes, dtype=event_end_dtype) +def parse_event_end(event_end_bytes,version): + return np.frombuffer(event_end_bytes, dtype=event_end_types[float(str(version)[:3])]) def get_version(header_bytes, version_pos): diff --git a/corsikaio/subblocks/event_end.py b/corsikaio/subblocks/event_end.py index a5f4305..85da66d 100644 --- a/corsikaio/subblocks/event_end.py +++ b/corsikaio/subblocks/event_end.py @@ -1,7 +1,9 @@ -from .dtypes import build_dtype, Field +import warnings +from collections import defaultdict +from .dtypes import build_dtype, Field -event_end_fields = [ +event_end_fields_65 = [ Field(1, "event_end", dtype="S4"), Field(2, "event_number"), Field(3, "n_photons_weighted"), @@ -17,4 +19,44 @@ Field(267, "n_em_particles_preshower"), ] -event_end_dtype = build_dtype(event_end_fields) +event_end_fields_7x = [ + Field(1, "event_end", dtype="S4"), + Field(2, "event_number"), + Field(3, "n_photons_weighted"), + Field(4, "n_electrons_weighted"), + Field(5, "n_hadrons_weighted"), + Field(6, "n_muons_weighted"), + Field(7, "n_particles_written"), + Field(256, "longitudinal_fit_parameters", shape=6), + Field(262, "chi_square_longitudinal"), + Field(263, "n_photons_written"), + Field(264, "n_electrons_written"), + Field(265, "n_hadrons_written"), + Field(266, "n_muons_written"), + Field(267, "n_em_particles_preshower"), + ] + +event_end_dtype_65 = build_dtype(event_end_fields_65) +event_end_dtype_7x = build_dtype(event_end_fields_7x) + +def warn_dtype(): + warnings.warn("Version unknown, using default event end definition dtype of version 7.x") + return event_end_dtype_7x + +def warn_fields(): + warnings.warn("Version unknown, using default event end fields definition of version 7.x") + return event_end_fields_7x + +event_end_fields = defaultdict(warn_fields) +event_end_fields[6.5] = event_end_fields_65 +event_end_fields[7.4] = event_end_fields_7x +event_end_fields[7.5] = event_end_fields_7x +event_end_fields[7.6] = event_end_fields_7x +event_end_fields[7.7] = event_end_fields_7x + +event_end_types = defaultdict(warn_dtype) +event_end_types[6.5] = event_end_dtype_65 +event_end_types[7.4] = event_end_dtype_7x +event_end_types[7.5] = event_end_dtype_7x +event_end_types[7.6] = event_end_dtype_7x +event_end_types[7.7] = event_end_dtype_7x diff --git a/tests/resources/corsika_77500_particle b/tests/resources/corsika_77500_particle new file mode 100644 index 0000000..01b21ff Binary files /dev/null and b/tests/resources/corsika_77500_particle differ diff --git a/tests/test_file.py b/tests/test_file.py index c695b14..2a9d0e4 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -118,3 +118,19 @@ def test_truncated(tmp_path, size): with CorsikaParticleFile(path) as f: for _ in f: pass + + +def test_longitudinal_parameters(): + '''Test event end blocks contain longitudinal parameters''' + from corsikaio import CorsikaParticleFile + + path = "tests/resources/corsika_77500_particle" + + with CorsikaParticleFile(path) as f: + n_events = 0 + for event in f: + n_events += 1 + assert "longitudinal_fit_parameters" in event.end.dtype.names + parameters = event.end["longitudinal_fit_parameters"] + np.testing.assert_array_equal(parameters != 0, True) + assert n_events == 5 diff --git a/tests/test_units.py b/tests/test_units.py index d520e0e..61b691d 100644 --- a/tests/test_units.py +++ b/tests/test_units.py @@ -25,8 +25,10 @@ def test_new_field(): def test_event_end_units(): - - assert all([f.unit is None for f in event_end_fields]) + + for version in event_end_fields: + + assert all([f.unit is None for f in event_end_fields[version]]) def test_run_end_units():