diff --git a/newsfragments/4671.fixed.md b/newsfragments/4671.fixed.md
new file mode 100644
index 00000000000..9b0cd9d8f0c
--- /dev/null
+++ b/newsfragments/4671.fixed.md
@@ -0,0 +1 @@
+Make `PyErr` internals thread-safe.
diff --git a/src/err/err_state.rs b/src/err/err_state.rs
index 2ba153b6ef8..70d449c8b52 100644
--- a/src/err/err_state.rs
+++ b/src/err/err_state.rs
@@ -1,4 +1,8 @@
-use std::cell::UnsafeCell;
+use std::{
+ cell::UnsafeCell,
+ sync::{Mutex, Once},
+ thread::ThreadId,
+};
use crate::{
exceptions::{PyBaseException, PyTypeError},
@@ -11,13 +15,14 @@ use crate::{
pub(crate) struct PyErrState {
// Safety: can only hand out references when in the "normalized" state. Will never change
// after normalization.
- //
- // The state is temporarily removed from the PyErr during normalization, to avoid
- // concurrent modifications.
+ normalized: Once,
+ // Guard against re-entrancy when normalizing the exception state.
+ normalizing_thread: Mutex<Option<ThreadId>>,
inner: UnsafeCell<Option<PyErrStateInner>>,
}
-// The inner value is only accessed through ways that require the gil is held.
+// Safety: The inner value is protected by locking to ensure that only the normalized state is
+// handed out as a reference.
unsafe impl Send for PyErrState {}
unsafe impl Sync for PyErrState {}
@@ -48,17 +53,22 @@ impl PyErrState {
fn from_inner(inner: PyErrStateInner) -> Self {
Self {
+ normalized: Once::new(),
+ normalizing_thread: Mutex::new(None),
inner: UnsafeCell::new(Some(inner)),
}
}
#[inline]
pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
- if let Some(PyErrStateInner::Normalized(n)) = unsafe {
- // Safety: self.inner will never be written again once normalized.
- &*self.inner.get()
- } {
- return n;
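+ // Fast path: `normalized` completes only after `inner` has been set to the
+ // normalized state, so once it reports completion the value can be read here
+ // without taking any lock.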
+ if self.normalized.is_completed() {
+ match unsafe {
+ // Safety: self.inner will never be written again once normalized.
+ &*self.inner.get()
+ } {
+ Some(PyErrStateInner::Normalized(n)) => return n,
+ _ => unreachable!(),
+ }
}
self.make_normalized(py)
@@ -69,25 +79,47 @@ impl PyErrState {
// This process is safe because:
// - Access is guaranteed not to be concurrent thanks to `Python` GIL token
// - Write happens only once, and then never will change again.
- // - State is set to None during the normalization process, so that a second
- // concurrent normalization attempt will panic before changing anything.
- // FIXME: this needs to be rewritten to deal with free-threaded Python
- // see https://github.com/PyO3/pyo3/issues/4584
+ // Guard against re-entrant normalization, because `Once` does not provide
+ // re-entrancy guarantees.
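+ // The lock is only held briefly (here and when the normalizing thread records
+ // its id below), and this check only panics when the recorded thread is the
+ // current one.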
+ if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() {
+ assert!(
+ !(*thread == std::thread::current().id()),
+ "Re-entrant normalization of PyErrState detected"
+ );
+ }
- let state = unsafe {
- (*self.inner.get())
- .take()
- .expect("Cannot normalize a PyErr while already normalizing it.")
- };
+ // avoid deadlock of `.call_once` with the GIL
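+ // (a thread blocked inside `call_once` while holding the GIL would prevent
+ // the normalizing thread from re-acquiring it in `with_gil` below)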
+ py.allow_threads(|| {
+ self.normalized.call_once(|| {
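+ // Record which thread is normalizing, so that re-entrant calls on this
+ // thread panic in the guard above instead of deadlocking inside `call_once`.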
+ self.normalizing_thread
+ .lock()
+ .unwrap()
+ .replace(std::thread::current().id());
+
+ // Safety: no other thread can access the inner value while we are normalizing it.
+ let state = unsafe {
+ (*self.inner.get())
+ .take()
+ .expect("Cannot normalize a PyErr while already normalizing it.")
+ };
+
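+ // Re-acquire the GIL for the actual normalization; this may run arbitrary
+ // Python code (e.g. a lazy `PyErrArguments` implementation).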
+ let normalized_state =
+ Python::with_gil(|py| PyErrStateInner::Normalized(state.normalize(py)));
+
+ // Safety: no other thread can access the inner value while we are normalizing it.
+ unsafe {
+ *self.inner.get() = Some(normalized_state);
+ }
+ })
+ });
- unsafe {
- let self_state = &mut *self.inner.get();
- *self_state = Some(PyErrStateInner::Normalized(state.normalize(py)));
- match self_state {
- Some(PyErrStateInner::Normalized(n)) => n,
- _ => unreachable!(),
- }
+ match unsafe {
+ // Safety: self.inner will never be written again once normalized.
+ &*self.inner.get()
+ } {
+ Some(PyErrStateInner::Normalized(n)) => n,
+ _ => unreachable!(),
}
}
}
@@ -321,3 +353,79 @@ fn raise_lazy(py: Python<'_>, lazy: Box<PyErrStateLazyFn>) {
}
}
}
+
+#[cfg(test)]
+mod tests {
+
+ use crate::{
+ exceptions::PyValueError, sync::GILOnceCell, PyErr, PyErrArguments, PyObject, Python,
+ };
+
+ #[test]
+ #[should_panic(expected = "Re-entrant normalization of PyErrState detected")]
+ fn test_reentrant_normalization() {
+ static ERR: GILOnceCell<PyErr> = GILOnceCell::new();
+
+ struct RecursiveArgs;
+
+ impl PyErrArguments for RecursiveArgs {
+ fn arguments(self, py: Python<'_>) -> PyObject {
+ // .value(py) triggers normalization
+ ERR.get(py)
+ .expect("is set just below")
+ .value(py)
+ .clone()
+ .into()
+ }
+ }
+
+ Python::with_gil(|py| {
+ ERR.set(py, PyValueError::new_err(RecursiveArgs)).unwrap();
+ ERR.get(py).expect("is set just above").value(py);
+ })
+ }
+
+ #[test]
+ fn test_no_deadlock_thread_switch() {
+ static ERR: GILOnceCell<PyErr> = GILOnceCell::new();
+
+ struct GILSwitchArgs;
+
+ impl PyErrArguments for GILSwitchArgs {
+ fn arguments(self, py: Python<'_>) -> PyObject {
+ // releasing the GIL potentially allows for other threads to deadlock
+ // with the normalization going on here
+ py.allow_threads(|| {
+ std::thread::sleep(std::time::Duration::from_millis(10));
+ });
+ py.None()
+ }
+ }
+
+ Python::with_gil(|py| ERR.set(py, PyValueError::new_err(GILSwitchArgs)).unwrap());
+
+ // Let many threads attempt to read the normalized value at the same time
+ let handles = (0..10)
+ .map(|_| {
+ std::thread::spawn(|| {
+ Python::with_gil(|py| {
+ ERR.get(py).expect("is set just above").value(py);
+ });
+ })
+ })
+ .collect::<Vec<_>>();
+
+ for handle in handles {
+ handle.join().unwrap();
+ }
+
+ // We should never have deadlocked, and should be able to run
+ // this assertion
+ Python::with_gil(|py| {
+ assert!(ERR
+ .get(py)
+ .expect("is set above")
+ .is_instance_of::<PyValueError>(py))
+ });
+ }
+}