diff --git a/jax_galsim/position.py b/jax_galsim/position.py index 3b84e3fe..4b5f0ad7 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -134,8 +134,8 @@ def shear(self, shear): shear_pos = jnp.dot(shear_mat, self._array) return PositionD(shear_pos[0], shear_pos[1]) - @implements(_galsim.Position.round) def round(self): + """Return the rounded-off PositionI version of this position.""" return PositionI(jnp.round(self.x), jnp.round(self.y)) def tree_flatten(self):