From 08beaa53f976be7dade7bc4386c87f2671092517 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Wed, 30 Oct 2024 19:53:28 +0000 Subject: [PATCH 1/5] make `PyErrState` thread-safe --- src/err/err_state.rs | 77 ++++++++++++++++++++++++++++++-------------- 1 file changed, 52 insertions(+), 25 deletions(-) diff --git a/src/err/err_state.rs b/src/err/err_state.rs index 2ba153b6ef8..f8ef457d95c 100644 --- a/src/err/err_state.rs +++ b/src/err/err_state.rs @@ -1,4 +1,8 @@ -use std::cell::UnsafeCell; +use std::{ + cell::UnsafeCell, + sync::{Mutex, Once}, + thread::ThreadId, +}; use crate::{ exceptions::{PyBaseException, PyTypeError}, @@ -11,13 +15,14 @@ use crate::{ pub(crate) struct PyErrState { // Safety: can only hand out references when in the "normalized" state. Will never change // after normalization. - // - // The state is temporarily removed from the PyErr during normalization, to avoid - // concurrent modifications. + normalized: Once, + // Guard against re-entrancy when normalizing the exception state. + normalizing_thread: Mutex>, inner: UnsafeCell>, } -// The inner value is only accessed through ways that require the gil is held. +// Safety: The inner value is protected by locking to ensure that only the normalized state is +// handed out as a reference. unsafe impl Send for PyErrState {} unsafe impl Sync for PyErrState {} @@ -48,17 +53,22 @@ impl PyErrState { fn from_inner(inner: PyErrStateInner) -> Self { Self { + normalized: Once::new(), + normalizing_thread: Mutex::new(None), inner: UnsafeCell::new(Some(inner)), } } #[inline] pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized { - if let Some(PyErrStateInner::Normalized(n)) = unsafe { - // Safety: self.inner will never be written again once normalized. - &*self.inner.get() - } { - return n; + if self.normalized.is_completed() { + match unsafe { + // Safety: self.inner will never be written again once normalized. + &*self.inner.get() + } { + Some(PyErrStateInner::Normalized(n)) => return n, + _ => unreachable!(), + } } self.make_normalized(py) @@ -69,25 +79,42 @@ impl PyErrState { // This process is safe because: // - Access is guaranteed not to be concurrent thanks to `Python` GIL token // - Write happens only once, and then never will change again. - // - State is set to None during the normalization process, so that a second - // concurrent normalization attempt will panic before changing anything. - // FIXME: this needs to be rewritten to deal with free-threaded Python - // see https://github.com/PyO3/pyo3/issues/4584 + // Guard against re-entrant normalization, because `Once` does not provide + // re-entrancy guarantees. + if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() { + if *thread == std::thread::current().id() { + panic!("Re-entrant normalization of PyErrState detected"); + } + } - let state = unsafe { - (*self.inner.get()) - .take() - .expect("Cannot normalize a PyErr while already normalizing it.") - }; + self.normalized.call_once(|| { + self.normalizing_thread + .lock() + .unwrap() + .replace(std::thread::current().id()); + + // Safety: no other thread can access the inner value while we are normalizing it. + let state = unsafe { + (*self.inner.get()) + .take() + .expect("Cannot normalize a PyErr while already normalizing it.") + }; - unsafe { - let self_state = &mut *self.inner.get(); - *self_state = Some(PyErrStateInner::Normalized(state.normalize(py))); - match self_state { - Some(PyErrStateInner::Normalized(n)) => n, - _ => unreachable!(), + let normalized_state = PyErrStateInner::Normalized(state.normalize(py)); + + // Safety: no other thread can access the inner value while we are normalizing it. + unsafe { + *self.inner.get() = Some(normalized_state); } + }); + + match unsafe { + // Safety: self.inner will never be written again once normalized. + &*self.inner.get() + } { + Some(PyErrStateInner::Normalized(n)) => return n, + _ => unreachable!(), } } } From 8421034726d65ce58f8e6587f7b9a0fb6c113f62 Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Thu, 31 Oct 2024 09:04:19 -0600 Subject: [PATCH 2/5] fix clippy --- src/err/err_state.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/err/err_state.rs b/src/err/err_state.rs index f8ef457d95c..5d930794014 100644 --- a/src/err/err_state.rs +++ b/src/err/err_state.rs @@ -83,9 +83,10 @@ impl PyErrState { // Guard against re-entrant normalization, because `Once` does not provide // re-entrancy guarantees. if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() { - if *thread == std::thread::current().id() { - panic!("Re-entrant normalization of PyErrState detected"); - } + assert!( + !(*thread == std::thread::current().id()), + "Re-entrant normalization of PyErrState detected" + ); } self.normalized.call_once(|| { @@ -113,7 +114,7 @@ impl PyErrState { // Safety: self.inner will never be written again once normalized. &*self.inner.get() } { - Some(PyErrStateInner::Normalized(n)) => return n, + Some(PyErrStateInner::Normalized(n)) => n, _ => unreachable!(), } } From b4d3a94f148754f1ec215ce977779a9d1b86e872 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Thu, 31 Oct 2024 20:52:52 +0000 Subject: [PATCH 3/5] add test of reentrancy, fix deadlock --- src/err/err_state.rs | 118 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 99 insertions(+), 19 deletions(-) diff --git a/src/err/err_state.rs b/src/err/err_state.rs index 5d930794014..be5918389c6 100644 --- a/src/err/err_state.rs +++ b/src/err/err_state.rs @@ -89,25 +89,29 @@ impl PyErrState { ); } - self.normalized.call_once(|| { - self.normalizing_thread - .lock() - .unwrap() - .replace(std::thread::current().id()); - - // Safety: no other thread can access the inner value while we are normalizing it. - let state = unsafe { - (*self.inner.get()) - .take() - .expect("Cannot normalize a PyErr while already normalizing it.") - }; - - let normalized_state = PyErrStateInner::Normalized(state.normalize(py)); - - // Safety: no other thread can access the inner value while we are normalizing it. - unsafe { - *self.inner.get() = Some(normalized_state); - } + // avoid deadlock of `.call_once` with the GIL + py.allow_threads(|| { + self.normalized.call_once(|| { + self.normalizing_thread + .lock() + .unwrap() + .replace(std::thread::current().id()); + + // Safety: no other thread can access the inner value while we are normalizing it. + let state = unsafe { + (*self.inner.get()) + .take() + .expect("Cannot normalize a PyErr while already normalizing it.") + }; + + let normalized_state = + Python::with_gil(|py| PyErrStateInner::Normalized(state.normalize(py))); + + // Safety: no other thread can access the inner value while we are normalizing it. + unsafe { + *self.inner.get() = Some(normalized_state); + } + }) }); match unsafe { @@ -349,3 +353,79 @@ fn raise_lazy(py: Python<'_>, lazy: Box) { } } } + +#[cfg(test)] +mod tests { + use std::sync::OnceLock; + + use crate::{exceptions::PyValueError, PyErr, PyErrArguments, PyObject, Python}; + + #[test] + #[should_panic(expected = "Re-entrant normalization of PyErrState detected")] + fn test_reentrant_normalization() { + static ERR: OnceLock = OnceLock::new(); + + struct RecursiveArgs; + + impl PyErrArguments for RecursiveArgs { + fn arguments(self, py: Python<'_>) -> PyObject { + // .value(py) triggers normalization + ERR.get() + .expect("is set just below") + .value(py) + .clone() + .into() + } + } + + ERR.set(PyValueError::new_err(RecursiveArgs)).unwrap(); + + Python::with_gil(|py| { + ERR.get().expect("is set just above").value(py); + }) + } + + #[test] + fn test_no_deadlock_thread_switch() { + static ERR: OnceLock = OnceLock::new(); + + struct GILSwitchArgs; + + impl PyErrArguments for GILSwitchArgs { + fn arguments(self, py: Python<'_>) -> PyObject { + // releasing the GIL potentially allows for other threads to deadlock + // with the normalization going on here + py.allow_threads(|| { + std::thread::sleep(std::time::Duration::from_millis(10)); + }); + py.None() + } + } + + ERR.set(PyValueError::new_err(GILSwitchArgs)).unwrap(); + + // Let many threads attempt to read the normalized value at the same time + let handles = (0..10) + .map(|_| { + std::thread::spawn(|| { + Python::with_gil(|py| { + ERR.get().expect("is set just above").value(py); + }); + }) + }) + .collect::>(); + + for handle in handles { + handle.join().unwrap(); + } + + // We should never have deadlocked, and should be able to run + // this assertion + Python::with_gil(|py| { + assert!(ERR + .get() + .expect("is set above") + .is_instance_of::(py)) + }); + } +} From 4a30ddec828745d9e52d3b02457ca6e184450183 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Thu, 31 Oct 2024 20:54:45 +0000 Subject: [PATCH 4/5] newsfragment --- newsfragments/4671.fixed.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 newsfragments/4671.fixed.md diff --git a/newsfragments/4671.fixed.md b/newsfragments/4671.fixed.md new file mode 100644 index 00000000000..9b0cd9d8f0c --- /dev/null +++ b/newsfragments/4671.fixed.md @@ -0,0 +1 @@ +Make `PyErr` internals thread-safe. From f5fa4522534321d26806033da4f08007f8bc70d2 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Thu, 31 Oct 2024 21:37:42 +0000 Subject: [PATCH 5/5] fix MSRV --- src/err/err_state.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/err/err_state.rs b/src/err/err_state.rs index be5918389c6..70d449c8b52 100644 --- a/src/err/err_state.rs +++ b/src/err/err_state.rs @@ -356,21 +356,22 @@ fn raise_lazy(py: Python<'_>, lazy: Box) { #[cfg(test)] mod tests { - use std::sync::OnceLock; - use crate::{exceptions::PyValueError, PyErr, PyErrArguments, PyObject, Python}; + use crate::{ + exceptions::PyValueError, sync::GILOnceCell, PyErr, PyErrArguments, PyObject, Python, + }; #[test] #[should_panic(expected = "Re-entrant normalization of PyErrState detected")] fn test_reentrant_normalization() { - static ERR: OnceLock = OnceLock::new(); + static ERR: GILOnceCell = GILOnceCell::new(); struct RecursiveArgs; impl PyErrArguments for RecursiveArgs { fn arguments(self, py: Python<'_>) -> PyObject { // .value(py) triggers normalization - ERR.get() + ERR.get(py) .expect("is set just below") .value(py) .clone() @@ -378,16 +379,15 @@ mod tests { } } - ERR.set(PyValueError::new_err(RecursiveArgs)).unwrap(); - Python::with_gil(|py| { - ERR.get().expect("is set just above").value(py); + ERR.set(py, PyValueError::new_err(RecursiveArgs)).unwrap(); + ERR.get(py).expect("is set just above").value(py); }) } #[test] fn test_no_deadlock_thread_switch() { - static ERR: OnceLock = OnceLock::new(); + static ERR: GILOnceCell = GILOnceCell::new(); struct GILSwitchArgs; @@ -402,14 +402,14 @@ mod tests { } } - ERR.set(PyValueError::new_err(GILSwitchArgs)).unwrap(); + Python::with_gil(|py| ERR.set(py, PyValueError::new_err(GILSwitchArgs)).unwrap()); // Let many threads attempt to read the normalized value at the same time let handles = (0..10) .map(|_| { std::thread::spawn(|| { Python::with_gil(|py| { - ERR.get().expect("is set just above").value(py); + ERR.get(py).expect("is set just above").value(py); }); }) }) @@ -423,7 +423,7 @@ mod tests { // this assertion Python::with_gil(|py| { assert!(ERR - .get() + .get(py) .expect("is set above") .is_instance_of::(py)) });