Skip to content

Commit

Permalink
Add shutdown notifier to task scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
khvzak committed Jan 30, 2025
1 parent 48d011b commit 9c62110
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
2 changes: 1 addition & 1 deletion casper-server/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ impl AppContextInner {

// Start task scheduler
let max_background_tasks = self.config.main.max_background_tasks;
lua::tasks::start_task_scheduler(&lua, max_background_tasks);
lua::tasks::start_task_scheduler(lua, max_background_tasks);

// Enable sandboxing before loading user code
lua.sandbox(true)?;
Expand Down
37 changes: 29 additions & 8 deletions casper-server/src/lua/tasks.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::rc::Rc;
use std::result::Result as StdResult;
use std::sync::atomic::{AtomicU64, Ordering};

use mlua::{
AnyUserData, ExternalError, ExternalResult, Function, Lua, Result, Table, UserData, Value,
};
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio::sync::{oneshot, watch};
use tokio::task::JoinHandle;
use tokio::time::{Duration, Instant};
use tracing::warn;
Expand All @@ -30,8 +31,12 @@ struct TaskHandle {
join_handle_rx: Option<oneshot::Receiver<TaskJoinHandle>>,
}

#[derive(Clone, Copy)]
struct MaxBackgroundTasks(Option<u64>);

#[derive(Clone)]
struct ShutdownNotifier(watch::Sender<bool>);

// Global task identifier
static NEXT_TASK_ID: AtomicU64 = AtomicU64::new(1);

Expand Down Expand Up @@ -71,7 +76,7 @@ impl UserData for TaskHandle {
}

fn spawn_task(lua: &Lua, arg: Value) -> Result<StdResult<TaskHandle, String>> {
let max_background_tasks = lua.app_data_ref::<MaxBackgroundTasks>().unwrap();
let max_background_tasks = *lua.app_data_ref::<MaxBackgroundTasks>().unwrap();
let current_tasks = tasks_counter_get!();

if let Some(max_tasks) = max_background_tasks.0 {
Expand Down Expand Up @@ -128,23 +133,36 @@ fn spawn_task(lua: &Lua, arg: Value) -> Result<StdResult<TaskHandle, String>> {
}

pub fn start_task_scheduler(lua: &Lua, max_background_tasks: Option<u64>) {
let lua = lua.clone();
let lua = Rc::new(lua.clone());
let mut task_rx = lua
.remove_app_data::<UnboundedReceiver<Task>>()
.expect("Failed to get task receiver");

lua.set_app_data(MaxBackgroundTasks(max_background_tasks));

let (shutdown_tx, shutdown_rx) = watch::channel(false);
lua.set_app_data(ShutdownNotifier(shutdown_tx));

tokio::task::spawn_local(async move {
while let Some(task) = task_rx.recv().await {
let lua = lua.clone();
let mut shutdown = shutdown_rx.clone();
let join_handle = tokio::task::spawn_local(async move {
let start = Instant::now();
let _task_count_guard = tasks_counter_inc!();
// Keep Lua instance alive while task is running
let _lua_guard = lua;
let task_future = task.handler.call_async::<Value>(());

let result = match task.timeout {
Some(timeout) => ntex::time::timeout(timeout, task_future).await,
None => Ok(task_future.await),
Some(timeout) => tokio::select! {
_ = shutdown.wait_for(|&x| x) => Ok(Err("task scheduler shutdown".to_string().into_lua_err())),
result = ntex::time::timeout(timeout, task_future) => result,
},
None => tokio::select! {
_ = shutdown.wait_for(|&x| x) => Ok(Err("task scheduler shutdown".to_string().into_lua_err())),
result = task_future => Ok(result),
},
};
// Outer Result errors will always be timeouts
let result = result
Expand Down Expand Up @@ -178,7 +196,9 @@ pub fn start_task_scheduler(lua: &Lua, max_background_tasks: Option<u64>) {

pub fn stop_task_scheduler(lua: &Lua) {
lua.remove_app_data::<UnboundedSender<Task>>();
lua.remove_app_data::<UnboundedReceiver<Task>>();

// Notify all tasks to stop
_ = lua.app_data_ref::<ShutdownNotifier>().unwrap().0.send(true);
}

pub fn create_module(lua: &Lua) -> Result<Table> {
Expand All @@ -192,14 +212,13 @@ pub fn create_module(lua: &Lua) -> Result<Table> {

#[cfg(test)]
mod tests {
use std::rc::Rc;
use std::time::Duration;

use mlua::{chunk, Lua, Result};

#[ntex::test]
async fn test_tasks() -> Result<()> {
let lua = Rc::new(Lua::new());
let lua = Lua::new();

lua.globals().set("tasks", super::create_module(&lua)?)?;
lua.globals().set(
Expand Down Expand Up @@ -331,6 +350,8 @@ mod tests {
.await
.unwrap();

super::stop_task_scheduler(&lua);

Ok(())
}
}

0 comments on commit 9c62110

Please sign in to comment.