Skip to content

Commit

Permalink
feat: Periodically check python signals ('CTRL-C' handling) (#20826)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jan 21, 2025
1 parent 30a7e34 commit f826c32
Show file tree
Hide file tree
Showing 15 changed files with 57 additions and 4 deletions.
3 changes: 3 additions & 0 deletions crates/polars-core/src/frame/group_by/into_groups.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use arrow::legacy::kernels::sort_partition::{create_clean_partitions, partition_to_groups};
use polars_error::check_signals;
use polars_utils::total_ord::{ToTotalOrd, TotalHash};

use super::*;
Expand Down Expand Up @@ -234,6 +235,7 @@ where
num_groups_proxy(ca, multithreaded, sorted)
},
};
check_signals()?;
Ok(out)
}
}
Expand Down Expand Up @@ -285,6 +287,7 @@ impl IntoGroupsType for BinaryChunked {
} else {
group_by(bh[0].iter(), sorted)
};
check_signals()?;
Ok(out)
}
}
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-error/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use std::fmt::{self, Display, Formatter, Write};
use std::ops::Deref;
use std::sync::{Arc, LazyLock};
use std::{env, io};
mod signals;

pub use signals::{check_signals, set_signals_function};
pub use warning::*;

enum ErrorStrategy {
Expand Down
23 changes: 23 additions & 0 deletions crates/polars-error/src/signals.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use crate::PolarsResult;

type SignalsFunction = fn() -> PolarsResult<()>;
static mut SIGNALS_FUNCTION: Option<SignalsFunction> = None;

/// Set the function that will be called check_signals.
/// This can be set on startup to enable stopping a query when user input like `ctrl-c` is called.
///
/// # Safety
/// The caller must ensure there is no other thread accessing this function
/// or calling `check_signals`.
pub unsafe fn set_signals_function(function: SignalsFunction) {
SIGNALS_FUNCTION = Some(function)
}

fn default() -> PolarsResult<()> {
Ok(())
}

pub fn check_signals() -> PolarsResult<()> {
let f = unsafe { SIGNALS_FUNCTION.unwrap_or(default) };
f()
}
2 changes: 2 additions & 0 deletions crates/polars-expr/src/state/execution_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::sync::{Mutex, RwLock};
use bitflags::bitflags;
use once_cell::sync::OnceCell;
use polars_core::config::verbose;
use polars_core::error::check_signals;
use polars_core::prelude::*;
use polars_ops::prelude::ChunkJoinOptIds;

Expand Down Expand Up @@ -149,6 +150,7 @@ impl ExecutionState {

// This is wrong when the U64 overflows which will never happen.
pub fn should_stop(&self) -> PolarsResult<()> {
check_signals()?;
polars_ensure!(!self.stop.load(Ordering::Relaxed), ComputeError: "query interrupted");
Ok(())
}
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-ops/src/frame/join/asof/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::cmp::Ordering;
use default::*;
pub use groups::AsofJoinBy;
use polars_core::prelude::*;
use polars_error::check_signals;
use polars_utils::pl_str::PlSmallStr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -312,6 +313,7 @@ pub trait AsofJoin: IntoDf {
join_asof_numeric(ca, &right_key, strategy, tolerance, allow_eq)
},
}?;
check_signals()?;

// Drop right join column.
let other = if coalesce && left_key.name() == right_key.name() {
Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/src/frame/join/cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ fn cross_join_dfs(
}
};
let (l_df, r_df) = if parallel {
check_signals()?;
POOL.install(|| rayon::join(create_left_df, create_right_df))
} else {
(create_left_df(), create_right_df())
Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/src/frame/join/dispatch_left_right.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ pub fn materialize_left_join_from_series(
} else {
right.drop(s_right.name()).unwrap()
};
check_signals()?;

#[cfg(feature = "chunked_ids")]
match (left_idx, right_idx) {
Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/src/frame/join/hash_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ pub trait JoinDispatch: IntoDf {
let (mut join_idx_l, mut join_idx_r) =
s_left.hash_join_outer(s_right, args.validation, args.join_nulls)?;

check_signals()?;
if let Some((offset, len)) = args.slice {
let (offset, len) = slice_offsets(offset, len, join_idx_l.len());
join_idx_l.slice(offset, len);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ where
} else {
build_tables(build, join_nulls)
};
check_signals()?;

let n_tables = hash_tbls.len();
let offsets = probe_to_offsets(&probe);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ where
} else {
build_tables(build, join_nulls)
};
check_signals()?;
let n_tables = hash_tbls.len();

// we determine the offset so that we later know which index to store in the join tuples
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ where
let (probe_hashes, _) = create_hash_and_keys_threaded_vectorized(probe, Some(random_state));

let n_tables = hash_tbls.len();
check_signals()?;

// probe the hash table.
// Note: indexes from b that are not matched will be None, Some(idx_b)
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-ops/src/frame/join/iejoin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use polars_core::prelude::*;
use polars_core::series::IsSorted;
use polars_core::utils::{_set_partition_size, split};
use polars_core::{with_match_physical_numeric_polars_type, POOL};
use polars_error::{polars_err, PolarsResult};
use polars_error::{check_signals, polars_err, PolarsResult};
use polars_utils::binary_search::ExponentialSearch;
use polars_utils::itertools::Itertools;
use polars_utils::total_ord::{TotalEq, TotalOrd};
Expand Down Expand Up @@ -362,6 +362,7 @@ unsafe fn materialize_join(
right_row_idx: &IdxCa,
suffix: Option<PlSmallStr>,
) -> PolarsResult<DataFrame> {
check_signals()?;
let (join_left, join_right) = {
POOL.join(
|| left.take_unchecked(left_row_idx),
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-ops/src/frame/join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ use polars_core::utils::slice_offsets;
#[allow(unused_imports)]
use polars_core::utils::slice_slice;
use polars_core::POOL;
use polars_error::check_signals;
use polars_utils::hashing::BytesHash;
use rayon::prelude::*;

Expand Down Expand Up @@ -562,6 +563,7 @@ trait DataFrameJoinOpsPrivate: IntoDf {
args.maintain_order,
MaintainOrderJoin::Left | MaintainOrderJoin::LeftRight
);
check_signals()?;
let (df_left, df_right) =
if args.maintain_order != MaintainOrderJoin::None && !already_left_sorted {
let mut df =
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-python/src/functions/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,5 @@ pub fn register_plugin_function(
#[pyfunction]
pub fn __register_startup_deps() {
#[cfg(feature = "object")]
crate::on_startup::register_startup_deps()
crate::on_startup::register_startup_deps(true)
}
16 changes: 14 additions & 2 deletions crates/polars-python/src/on_startup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use polars_core::chunked_array::object::builder::ObjectChunkedBuilder;
use polars_core::chunked_array::object::registry::AnonymousObjectBuilder;
use polars_core::chunked_array::object::{registry, set_polars_allow_extension};
use polars_core::error::PolarsError::ComputeError;
use polars_error::PolarsWarning;
use polars_error::{set_signals_function, PolarsWarning};
use pyo3::prelude::*;
use pyo3::{intern, IntoPyObjectExt};

Expand Down Expand Up @@ -68,7 +68,7 @@ fn warning_function(msg: &str, warning: PolarsWarning) {
});
}

pub fn register_startup_deps() {
pub fn register_startup_deps(check_python_signals: bool) {
set_polars_allow_extension(true);
if !registry::is_object_builder_registered() {
// Stack frames can get really large in debug mode.
Expand Down Expand Up @@ -100,6 +100,18 @@ pub fn register_startup_deps() {
unsafe { python_udf::CALL_DF_UDF_PYTHON = Some(python_function_caller_df) }
// register warning function for `polars_warn!`
unsafe { polars_error::set_warning_function(warning_function) };

if check_python_signals {
fn signals_function() -> PolarsResult<()> {
Python::with_gil(|py| {
py.check_signals()
.map_err(|err| polars_err!(ComputeError: "{err}"))
})
}

unsafe { set_signals_function(signals_function) };
}

Python::with_gil(|py| {
// init AnyValue LUT
crate::conversion::any_value::LUT
Expand Down

0 comments on commit f826c32

Please sign in to comment.