Skip to content

Commit

Permalink
Add shutdown notifier to task scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
khvzak committed Jan 29, 2025
1 parent 48d011b commit 22dd3a7
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
2 changes: 1 addition & 1 deletion casper-server/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ impl AppContextInner {

// Start task scheduler
let max_background_tasks = self.config.main.max_background_tasks;
lua::tasks::start_task_scheduler(&lua, max_background_tasks);
lua::tasks::start_task_scheduler(lua, max_background_tasks);

// Enable sandboxing before loading user code
lua.sandbox(true)?;
Expand Down
33 changes: 26 additions & 7 deletions casper-server/src/lua/tasks.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::rc::Rc;
use std::result::Result as StdResult;
use std::sync::atomic::{AtomicU64, Ordering};

use mlua::{
AnyUserData, ExternalError, ExternalResult, Function, Lua, Result, Table, UserData, Value,
};
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio::sync::{oneshot, Notify};
use tokio::task::JoinHandle;
use tokio::time::{Duration, Instant};
use tracing::warn;
Expand All @@ -30,8 +31,12 @@ struct TaskHandle {
join_handle_rx: Option<oneshot::Receiver<TaskJoinHandle>>,
}

#[derive(Clone, Copy)]
struct MaxBackgroundTasks(Option<u64>);

#[derive(Clone)]
struct ShutdownNotifier(Rc<Notify>);

// Global task identifier
static NEXT_TASK_ID: AtomicU64 = AtomicU64::new(1);

Expand Down Expand Up @@ -71,7 +76,7 @@ impl UserData for TaskHandle {
}

fn spawn_task(lua: &Lua, arg: Value) -> Result<StdResult<TaskHandle, String>> {
let max_background_tasks = lua.app_data_ref::<MaxBackgroundTasks>().unwrap();
let max_background_tasks = *lua.app_data_ref::<MaxBackgroundTasks>().unwrap();
let current_tasks = tasks_counter_get!();

if let Some(max_tasks) = max_background_tasks.0 {
Expand Down Expand Up @@ -135,16 +140,26 @@ pub fn start_task_scheduler(lua: &Lua, max_background_tasks: Option<u64>) {

lua.set_app_data(MaxBackgroundTasks(max_background_tasks));

let shutdown_notifier = Rc::new(Notify::new());
lua.set_app_data(ShutdownNotifier(shutdown_notifier.clone()));

tokio::task::spawn_local(async move {
while let Some(task) = task_rx.recv().await {
let shutdown = shutdown_notifier.clone();
let join_handle = tokio::task::spawn_local(async move {
let start = Instant::now();
let _task_count_guard = tasks_counter_inc!();
let task_future = task.handler.call_async::<Value>(());

let result = match task.timeout {
Some(timeout) => ntex::time::timeout(timeout, task_future).await,
None => Ok(task_future.await),
Some(timeout) => tokio::select! {
_ = shutdown.notified() => Ok(Err("task scheduler shutdown".to_string().into_lua_err())),
result = ntex::time::timeout(timeout, task_future) => result,
},
None => tokio::select! {
_ = shutdown.notified() => Ok(Err("task scheduler shutdown".to_string().into_lua_err())),
result = task_future => Ok(result),
},
};
// Outer Result errors will always be timeouts
let result = result
Expand Down Expand Up @@ -178,7 +193,12 @@ pub fn start_task_scheduler(lua: &Lua, max_background_tasks: Option<u64>) {

pub fn stop_task_scheduler(lua: &Lua) {
lua.remove_app_data::<UnboundedSender<Task>>();
lua.remove_app_data::<UnboundedReceiver<Task>>();

// Notify all tasks to stop
lua.app_data_ref::<ShutdownNotifier>()
.unwrap()
.0
.notify_waiters();
}

pub fn create_module(lua: &Lua) -> Result<Table> {
Expand All @@ -192,14 +212,13 @@ pub fn create_module(lua: &Lua) -> Result<Table> {

#[cfg(test)]
mod tests {
use std::rc::Rc;
use std::time::Duration;

use mlua::{chunk, Lua, Result};

#[ntex::test]
async fn test_tasks() -> Result<()> {
let lua = Rc::new(Lua::new());
let lua = Lua::new();

lua.globals().set("tasks", super::create_module(&lua)?)?;
lua.globals().set(
Expand Down

0 comments on commit 22dd3a7

Please sign in to comment.