Skip to content

Commit

Permalink
handling of errors
Browse files Browse the repository at this point in the history
  • Loading branch information
lfarv committed Jan 12, 2025
1 parent edceae6 commit 6c3034b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
20 changes: 17 additions & 3 deletions pyat/at/latticetools/observablelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"ObservableList",
]

from collections.abc import Iterable
from collections.abc import Iterable, Iterator
from functools import reduce
from typing import Callable

Expand All @@ -44,6 +44,17 @@ def _flatten(vals, order="F"):
return np.concatenate([np.reshape(v, -1, order=order) for v in vals])


class _ObsResIter(Iterator):
def __init__(self, obsiter):
self.base = obsiter

def __next__(self):
val = next(self.base)
if isinstance(val, Exception):
raise val
return val


class _ObsResults(tuple):
def __getitem__(self, item):
if isinstance(item, slice):
Expand All @@ -54,6 +65,9 @@ def __getitem__(self, item):
raise AtError(f"Evaluation failed: {val.args[0]}") from val
return val

def __iter__(self):
return _ObsResIter(super().__iter__())


class ObservableList(list):
"""Handles a list of Observables to be evaluated together.
Expand Down Expand Up @@ -265,7 +279,7 @@ def obseval(ring, obs):
"""Evaluate a single observable."""

def check_error(data, refpts):
return data if isinstance(data, AtError) else data[refpts]
return data if isinstance(data, Exception) else data[refpts]

obsneeds = obs.needs
obsrefs = getattr(obs, "_boolrefs", None)
Expand Down Expand Up @@ -385,7 +399,7 @@ def ringeval(
trajs, orbits, rgdata, eldata, emdata, mxdata, geodata = ringeval(
ring, dp=dp, dct=dct, df=df
)
return [obseval(ring, ob) for ob in self]
return _ObsResults(obseval(ring, ob) for ob in self)

def check(self) -> bool:
"""Check the evaluation
Expand Down
9 changes: 6 additions & 3 deletions pyat/at/latticetools/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,12 @@ def evaluate(self, *data, initial: bool = False):
"""
for d in data:
if isinstance(d, Exception):
self._value = d
errtype = type(d)
err = errtype(f"Evaluation of {self.name} failed: {d.args[0]}")
err.__cause__ = d
self._value = err
self._shape = None
return d
return err

val = np.asarray(self.fun(*data, *self.args, **self.kwargs))
if initial:
Expand All @@ -357,7 +360,7 @@ def value(self):
"""Value of the observable."""
val = self._value
if isinstance(val, Exception):
raise AtError(f"Evaluation of {self.name} failed: {val.args[0]}") from val
raise val
return val

@property
Expand Down

0 comments on commit 6c3034b

Please sign in to comment.