diff --git a/newsfragments/4671.fixed.md b/newsfragments/4671.fixed.md
new file mode 100644
index 00000000000..9b0cd9d8f0c
--- /dev/null
+++ b/newsfragments/4671.fixed.md
@@ -0,0 +1 @@
+Make `PyErr` internals thread-safe.
diff --git a/src/err/err_state.rs b/src/err/err_state.rs
index 2ba153b6ef8..70d449c8b52 100644
--- a/src/err/err_state.rs
+++ b/src/err/err_state.rs
@@ -1,4 +1,8 @@
-use std::cell::UnsafeCell;
+use std::{
+    cell::UnsafeCell,
+    sync::{Mutex, Once},
+    thread::ThreadId,
+};
 
 use crate::{
     exceptions::{PyBaseException, PyTypeError},
@@ -11,13 +15,14 @@ use crate::{
 pub(crate) struct PyErrState {
     // Safety: can only hand out references when in the "normalized" state. Will never change
     // after normalization.
-    //
-    // The state is temporarily removed from the PyErr during normalization, to avoid
-    // concurrent modifications.
+    normalized: Once,
+    // Guard against re-entrancy when normalizing the exception state.
+    normalizing_thread: Mutex<Option<ThreadId>>,
     inner: UnsafeCell<Option<PyErrStateInner>>,
 }
 
-// The inner value is only accessed through ways that require the gil is held.
+// Safety: The inner value is protected by locking to ensure that only the normalized state is
+// handed out as a reference.
 unsafe impl Send for PyErrState {}
 unsafe impl Sync for PyErrState {}
 
@@ -48,17 +53,22 @@ impl PyErrState {
 
     fn from_inner(inner: PyErrStateInner) -> Self {
         Self {
+            normalized: Once::new(),
+            normalizing_thread: Mutex::new(None),
             inner: UnsafeCell::new(Some(inner)),
         }
     }
 
     #[inline]
     pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
-        if let Some(PyErrStateInner::Normalized(n)) = unsafe {
-            // Safety: self.inner will never be written again once normalized.
-            &*self.inner.get()
-        } {
-            return n;
+        if self.normalized.is_completed() {
+            match unsafe {
+                // Safety: self.inner will never be written again once normalized.
+                &*self.inner.get()
+            } {
+                Some(PyErrStateInner::Normalized(n)) => return n,
+                _ => unreachable!(),
+            }
         }
 
         self.make_normalized(py)
@@ -69,25 +79,47 @@ impl PyErrState {
         // This process is safe because:
         // - Access is guaranteed not to be concurrent thanks to `Python` GIL token
        // - Write happens only once, and then never will change again.
-        // - State is set to None during the normalization process, so that a second
-        //   concurrent normalization attempt will panic before changing anything.
-
-        // FIXME: this needs to be rewritten to deal with free-threaded Python
-        // see https://github.com/PyO3/pyo3/issues/4584
+
+        // Guard against re-entrant normalization, because `Once` does not provide
+        // re-entrancy guarantees.
+        if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() {
+            assert!(
+                !(*thread == std::thread::current().id()),
+                "Re-entrant normalization of PyErrState detected"
+            );
+        }
 
-        let state = unsafe {
-            (*self.inner.get())
-                .take()
-                .expect("Cannot normalize a PyErr while already normalizing it.")
-        };
+        // avoid deadlock of `.call_once` with the GIL
+        py.allow_threads(|| {
+            self.normalized.call_once(|| {
+                self.normalizing_thread
+                    .lock()
+                    .unwrap()
+                    .replace(std::thread::current().id());
 
-        unsafe {
-            let self_state = &mut *self.inner.get();
-            *self_state = Some(PyErrStateInner::Normalized(state.normalize(py)));
-            match self_state {
-                Some(PyErrStateInner::Normalized(n)) => n,
-                _ => unreachable!(),
-            }
+                // Safety: no other thread can access the inner value while we are normalizing it.
+                let state = unsafe {
+                    (*self.inner.get())
+                        .take()
+                        .expect("Cannot normalize a PyErr while already normalizing it.")
+                };
+
+                let normalized_state =
+                    Python::with_gil(|py| PyErrStateInner::Normalized(state.normalize(py)));
+
+                // Safety: no other thread can access the inner value while we are normalizing it.
+                unsafe {
+                    *self.inner.get() = Some(normalized_state);
+                }
+            })
+        });
+
+        match unsafe {
+            // Safety: self.inner will never be written again once normalized.
+            &*self.inner.get()
+        } {
+            Some(PyErrStateInner::Normalized(n)) => n,
+            _ => unreachable!(),
         }
     }
 }
@@ -321,3 +353,79 @@ fn raise_lazy(py: Python<'_>, lazy: Box<PyErrStateLazyFn>) {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+
+    use crate::{
+        exceptions::PyValueError, sync::GILOnceCell, PyErr, PyErrArguments, PyObject, Python,
+    };
+
+    #[test]
+    #[should_panic(expected = "Re-entrant normalization of PyErrState detected")]
+    fn test_reentrant_normalization() {
+        static ERR: GILOnceCell<PyErr> = GILOnceCell::new();
+
+        struct RecursiveArgs;
+
+        impl PyErrArguments for RecursiveArgs {
+            fn arguments(self, py: Python<'_>) -> PyObject {
+                // .value(py) triggers normalization
+                ERR.get(py)
+                    .expect("is set just below")
+                    .value(py)
+                    .clone()
+                    .into()
+            }
+        }
+
+        Python::with_gil(|py| {
+            ERR.set(py, PyValueError::new_err(RecursiveArgs)).unwrap();
+            ERR.get(py).expect("is set just above").value(py);
+        })
+    }
+
+    #[test]
+    fn test_no_deadlock_thread_switch() {
+        static ERR: GILOnceCell<PyErr> = GILOnceCell::new();
+
+        struct GILSwitchArgs;
+
+        impl PyErrArguments for GILSwitchArgs {
+            fn arguments(self, py: Python<'_>) -> PyObject {
+                // releasing the GIL potentially allows for other threads to deadlock
+                // with the normalization going on here
+                py.allow_threads(|| {
+                    std::thread::sleep(std::time::Duration::from_millis(10));
+                });
+                py.None()
+            }
+        }
+
+        Python::with_gil(|py| ERR.set(py, PyValueError::new_err(GILSwitchArgs)).unwrap());
+
+        // Let many threads attempt to read the normalized value at the same time
+        let handles = (0..10)
+            .map(|_| {
+                std::thread::spawn(|| {
+                    Python::with_gil(|py| {
+                        ERR.get(py).expect("is set just above").value(py);
+                    });
+                })
+            })
+            .collect::<Vec<_>>();
+
+        for handle in handles {
+            handle.join().unwrap();
+        }
+
+        // We should never have deadlocked, and should be able to run
+        // this assertion
+        Python::with_gil(|py| {
+            assert!(ERR
+                .get(py)
+                .expect("is set above")
+                .is_instance_of::<PyValueError>(py))
+        });
+    }
+}
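
Not part of the diff: a minimal standalone sketch of the re-entrancy-guard pattern the change relies on, a `Once` paired with a `Mutex<Option<ThreadId>>` that records which thread is running the initializer, reduced to a plain lazy cell so the mechanism is visible without PyO3's GIL handling. The names here (`ReentrantCheckedLazy`, `get_or_init`) are illustrative only and are not PyO3 API.

```rust
use std::{
    cell::UnsafeCell,
    sync::{Mutex, Once},
    thread::ThreadId,
};

struct ReentrantCheckedLazy<T> {
    initialized: Once,
    // Which thread is currently running the initializer; lets a re-entrant call on the
    // same thread fail fast instead of deadlocking inside `Once::call_once`.
    initializing_thread: Mutex<Option<ThreadId>>,
    value: UnsafeCell<Option<T>>,
}

// Safety: `value` is written exactly once, inside `call_once`, and only read after
// `call_once` has completed, mirroring the invariants used by `PyErrState`.
unsafe impl<T: Send> Send for ReentrantCheckedLazy<T> {}
unsafe impl<T: Send + Sync> Sync for ReentrantCheckedLazy<T> {}

impl<T> ReentrantCheckedLazy<T> {
    const fn new() -> Self {
        Self {
            initialized: Once::new(),
            initializing_thread: Mutex::new(None),
            value: UnsafeCell::new(None),
        }
    }

    fn get_or_init(&self, init: impl FnOnce() -> T) -> &T {
        // Re-entrancy guard: `Once` offers no protection if `init` calls back into
        // `get_or_init` on the same thread, so detect that case and panic.
        if let Some(thread) = self.initializing_thread.lock().unwrap().as_ref() {
            assert!(
                *thread != std::thread::current().id(),
                "re-entrant initialization detected"
            );
        }

        self.initialized.call_once(|| {
            self.initializing_thread
                .lock()
                .unwrap()
                .replace(std::thread::current().id());

            // Safety: only the thread that wins `call_once` reaches this write.
            unsafe { *self.value.get() = Some(init()) };
        });

        // Safety: `call_once` has completed, so `value` is `Some` and never written again.
        unsafe { (*self.value.get()).as_ref().unwrap() }
    }
}

fn main() {
    static CELL: ReentrantCheckedLazy<String> = ReentrantCheckedLazy::new();
    // First call runs the initializer; later calls (from any thread) just read the value.
    println!("{}", CELL.get_or_init(|| "normalized".to_string()));
    println!("{}", CELL.get_or_init(|| unreachable!()));
}
```

The sketch omits the extra step the diff takes of wrapping `call_once` in `py.allow_threads`: in PyO3, a second thread can block inside `call_once` while holding the GIL that the normalizing thread needs for `Python::with_gil`, which is the deadlock exercised by `test_no_deadlock_thread_switch`; a plain lazy cell has no such external lock.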