From b4d3a94f148754f1ec215ce977779a9d1b86e872 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Thu, 31 Oct 2024 20:52:52 +0000 Subject: [PATCH] add test of reentrancy, fix deadlock --- src/err/err_state.rs | 118 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 99 insertions(+), 19 deletions(-) diff --git a/src/err/err_state.rs b/src/err/err_state.rs index 5d930794014..be5918389c6 100644 --- a/src/err/err_state.rs +++ b/src/err/err_state.rs @@ -89,25 +89,29 @@ impl PyErrState { ); } - self.normalized.call_once(|| { - self.normalizing_thread - .lock() - .unwrap() - .replace(std::thread::current().id()); - - // Safety: no other thread can access the inner value while we are normalizing it. - let state = unsafe { - (*self.inner.get()) - .take() - .expect("Cannot normalize a PyErr while already normalizing it.") - }; - - let normalized_state = PyErrStateInner::Normalized(state.normalize(py)); - - // Safety: no other thread can access the inner value while we are normalizing it. - unsafe { - *self.inner.get() = Some(normalized_state); - } + // avoid deadlock of `.call_once` with the GIL + py.allow_threads(|| { + self.normalized.call_once(|| { + self.normalizing_thread + .lock() + .unwrap() + .replace(std::thread::current().id()); + + // Safety: no other thread can access the inner value while we are normalizing it. + let state = unsafe { + (*self.inner.get()) + .take() + .expect("Cannot normalize a PyErr while already normalizing it.") + }; + + let normalized_state = + Python::with_gil(|py| PyErrStateInner::Normalized(state.normalize(py))); + + // Safety: no other thread can access the inner value while we are normalizing it. + unsafe { + *self.inner.get() = Some(normalized_state); + } + }) }); match unsafe { @@ -349,3 +353,79 @@ fn raise_lazy(py: Python<'_>, lazy: Box) { } } } + +#[cfg(test)] +mod tests { + use std::sync::OnceLock; + + use crate::{exceptions::PyValueError, PyErr, PyErrArguments, PyObject, Python}; + + #[test] + #[should_panic(expected = "Re-entrant normalization of PyErrState detected")] + fn test_reentrant_normalization() { + static ERR: OnceLock = OnceLock::new(); + + struct RecursiveArgs; + + impl PyErrArguments for RecursiveArgs { + fn arguments(self, py: Python<'_>) -> PyObject { + // .value(py) triggers normalization + ERR.get() + .expect("is set just below") + .value(py) + .clone() + .into() + } + } + + ERR.set(PyValueError::new_err(RecursiveArgs)).unwrap(); + + Python::with_gil(|py| { + ERR.get().expect("is set just above").value(py); + }) + } + + #[test] + fn test_no_deadlock_thread_switch() { + static ERR: OnceLock = OnceLock::new(); + + struct GILSwitchArgs; + + impl PyErrArguments for GILSwitchArgs { + fn arguments(self, py: Python<'_>) -> PyObject { + // releasing the GIL potentially allows for other threads to deadlock + // with the normalization going on here + py.allow_threads(|| { + std::thread::sleep(std::time::Duration::from_millis(10)); + }); + py.None() + } + } + + ERR.set(PyValueError::new_err(GILSwitchArgs)).unwrap(); + + // Let many threads attempt to read the normalized value at the same time + let handles = (0..10) + .map(|_| { + std::thread::spawn(|| { + Python::with_gil(|py| { + ERR.get().expect("is set just above").value(py); + }); + }) + }) + .collect::>(); + + for handle in handles { + handle.join().unwrap(); + } + + // We should never have deadlocked, and should be able to run + // this assertion + Python::with_gil(|py| { + assert!(ERR + .get() + .expect("is set above") + .is_instance_of::(py)) + }); + } +}