Skip to content

Commit

Permalink
fix: abort signal
Browse files Browse the repository at this point in the history
  • Loading branch information
atanmarko committed Nov 1, 2024
1 parent 6740678 commit 545620e
Showing 1 changed file with 27 additions and 17 deletions.
44 changes: 27 additions & 17 deletions paladin-core/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ use dashmap::{mapref::entry::Entry, DashMap};
use futures::{stream::BoxStream, Stream, StreamExt};
use serde::{Deserialize, Serialize};
use tokio::{select, task::JoinHandle, try_join};
use tracing::{debug_span, error, instrument, trace, warn, Instrument};
use tracing::{debug, debug_span, error, instrument, trace, warn, Instrument};

use self::dynamic_channel::{DynamicChannel, DynamicChannelFactory};
use crate::{
Expand Down Expand Up @@ -420,7 +420,6 @@ pub struct ExecutionErr<E> {
/// Command and error inter-process messages between leader and workers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CommandIpc {
ExecutionError { routing_key: String },
Abort { routing_key: String },
}

Expand Down Expand Up @@ -568,16 +567,18 @@ impl WorkerRuntime {
match strategy {
FatalStrategy::Ignore => Ok(()),
FatalStrategy::Terminate => {
// Notify other workers of the error.
// Notify leader of the error. Send abort command to all workers.
let (ipc, sender) = try_join!(
self.get_command_ipc_sender(),
self.get_result_sender(routing_key.clone())
)?;

let ipc_msg = CommandIpc::ExecutionError { routing_key };
let abort_ipc_msg = CommandIpc::Abort {
routing_key: COMMAND_IPC_ABORT_ALL_KEY.to_string(),
};
let sender_msg = AnyTaskResult::Err(err.to_string());

try_join!(ipc.publish(&ipc_msg), sender.publish(&sender_msg))?;
try_join!(ipc.publish(&abort_ipc_msg), sender.publish(&sender_msg))?;
try_join!(ipc.close(), sender.close())?;

Ok(())
Expand Down Expand Up @@ -682,6 +683,7 @@ impl WorkerRuntime {

// Create a watch channel for signaling IPC changes while processing a task.
let (ipc_sig_term_tx, ipc_sig_term_rx) = tokio::sync::watch::channel::<String>(identifier);
let abort_worker_execution = Arc::new(std::sync::atomic::AtomicBool::new(false));

// Spawn a task that will listen for IPC termination signals and mark jobs as
// terminated.
Expand All @@ -692,12 +694,10 @@ impl WorkerRuntime {
async move {
while let Some(ipc) = command_receiver.next().await {
match ipc {
CommandIpc::ExecutionError { routing_key }
| CommandIpc::Abort { routing_key } => {
CommandIpc::Abort { routing_key } => {
// Mark the job as terminated if it hasn't been already.
if mark_terminated(&terminated_jobs, routing_key.clone()) {
warn!(routing_key = %routing_key, "received IPC termination signal");
// Notify any currently executing tasks of the error.
// Notify any currently executing task about the termination.
ipc_sig_term_tx.send_replace(routing_key.clone());
}
}
Expand All @@ -723,7 +723,12 @@ impl WorkerRuntime {
}

while let Some((payload, acker)) = task_stream.next().await {
let abort = Arc::new(std::sync::atomic::AtomicBool::new(false));
// If abort condition was previously set for some reason, stop processing tasks.
if abort_worker_execution.load(Ordering::SeqCst) {

Check failure on line 727 in paladin-core/src/runtime/mod.rs

View workflow job for this annotation

GitHub Actions / clippy

borrow of moved value: `abort_worker_execution`
warn!("stopping worker execution due to abort flag");
break;
}

// Skip tasks associated with terminated jobs.
if terminated_jobs.contains_key(&payload.clone().routing_key) {
trace!(routing_key = %payload.clone().routing_key, "skipping terminated job");
Expand All @@ -736,7 +741,9 @@ impl WorkerRuntime {
let routing_key_clone = routing_key.clone();

let span = debug_span!("remote_execute", routing_key = %routing_key_clone);
let execution_task = payload.remote_execute(Some(abort.clone())).instrument(span);
let execution_task = payload
.remote_execute(Some(abort_worker_execution.clone()))
.instrument(span);

// Create a future that will wait for an IPC termination signal.
let ipc_sig_term = {
Expand All @@ -745,11 +752,11 @@ impl WorkerRuntime {
loop {
ipc_sig_term_rx.changed().await.expect("IPC channel closed");
let received_key = ipc_sig_term_rx.borrow().clone();
if received_key == routing_key_clone
|| received_key == COMMAND_IPC_ABORT_ALL_KEY
{
abort.store(true, Ordering::SeqCst);
tokio::time::sleep(ABORT_SIGNAL_SHUTDOWN_INTERVAL).await;
debug!(routing_key = %routing_key, "received IPC termination signal with key: {received_key}");

if received_key == COMMAND_IPC_ABORT_ALL_KEY {
warn!(routing_key = %routing_key, "worker abort signal received");
abort_worker_execution.store(true, Ordering::SeqCst);
return true;
}
}
Expand Down Expand Up @@ -785,8 +792,11 @@ impl WorkerRuntime {
}
}
_ = ipc_sig_term => {
warn!(routing_key = %routing_key, "task cancelled via IPC sigterm");
_ = acker.nack().await;
// Give time to the execution_task to finish gracefully on abort signal before shutting down.
tokio::time::sleep(ABORT_SIGNAL_SHUTDOWN_INTERVAL).await;
warn!(routing_key = %routing_key, "worker execution stopped via IPC sigterm");
break;
}
}
}
Expand Down

0 comments on commit 545620e

Please sign in to comment.