diff --git a/casper-server/src/context.rs b/casper-server/src/context.rs
index 1c40fb6..3905825 100644
--- a/casper-server/src/context.rs
+++ b/casper-server/src/context.rs
@@ -141,7 +141,7 @@ impl AppContextInner {
 
         // Start task scheduler
         let max_background_tasks = self.config.main.max_background_tasks;
-        lua::tasks::start_task_scheduler(&lua, max_background_tasks);
+        lua::tasks::start_task_scheduler(lua, max_background_tasks);
 
         // Enable sandboxing before loading user code
         lua.sandbox(true)?;
diff --git a/casper-server/src/lua/tasks.rs b/casper-server/src/lua/tasks.rs
index 13ed111..af62d4c 100644
--- a/casper-server/src/lua/tasks.rs
+++ b/casper-server/src/lua/tasks.rs
@@ -1,3 +1,4 @@
+use std::rc::Rc;
 use std::result::Result as StdResult;
 use std::sync::atomic::{AtomicU64, Ordering};
 
@@ -5,7 +6,7 @@ use mlua::{
     AnyUserData, ExternalError, ExternalResult, Function, Lua, Result, Table, UserData, Value,
 };
 use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
-use tokio::sync::oneshot;
+use tokio::sync::{oneshot, watch};
 use tokio::task::JoinHandle;
 use tokio::time::{Duration, Instant};
 use tracing::warn;
@@ -30,8 +31,12 @@ struct TaskHandle {
     join_handle_rx: Option<oneshot::Receiver<JoinHandle<Result<Value>>>>,
 }
 
+#[derive(Clone, Copy)]
 struct MaxBackgroundTasks(Option<u64>);
 
+#[derive(Clone)]
+struct ShutdownNotifier(watch::Sender<bool>);
+
 // Global task identifier
 static NEXT_TASK_ID: AtomicU64 = AtomicU64::new(1);
 
@@ -71,7 +76,7 @@ impl UserData for TaskHandle {
 }
 
 fn spawn_task(lua: &Lua, arg: Value) -> Result<TaskHandle> {
-    let max_background_tasks = lua.app_data_ref::<MaxBackgroundTasks>().unwrap();
+    let max_background_tasks = *lua.app_data_ref::<MaxBackgroundTasks>().unwrap();
     let current_tasks = tasks_counter_get!();
 
     if let Some(max_tasks) = max_background_tasks.0 {
@@ -128,23 +133,36 @@ fn spawn_task(lua: &Lua, arg: Value) -> Result<TaskHandle> {
 }
 
 pub fn start_task_scheduler(lua: &Lua, max_background_tasks: Option<u64>) {
-    let lua = lua.clone();
+    let lua = Rc::new(lua.clone());
     let mut task_rx = lua
         .remove_app_data::<UnboundedReceiver<Task>>()
         .expect("Failed to get task receiver");
 
     lua.set_app_data(MaxBackgroundTasks(max_background_tasks));
 
+    let (shutdown_tx, shutdown_rx) = watch::channel(false);
+    lua.set_app_data(ShutdownNotifier(shutdown_tx));
+
     tokio::task::spawn_local(async move {
         while let Some(task) = task_rx.recv().await {
+            let lua = lua.clone();
+            let mut shutdown = shutdown_rx.clone();
             let join_handle = tokio::task::spawn_local(async move {
                 let start = Instant::now();
                 let _task_count_guard = tasks_counter_inc!();
+                // Keep Lua instance alive while task is running
+                let _lua_guard = lua;
 
                 let task_future = task.handler.call_async::<Value>(());
                 let result = match task.timeout {
-                    Some(timeout) => ntex::time::timeout(timeout, task_future).await,
-                    None => Ok(task_future.await),
+                    Some(timeout) => tokio::select! {
+                        _ = shutdown.wait_for(|&x| x) => Ok(Err("task scheduler shutdown".to_string().into_lua_err())),
+                        result = ntex::time::timeout(timeout, task_future) => result,
+                    },
+                    None => tokio::select! {
+                        _ = shutdown.wait_for(|&x| x) => Ok(Err("task scheduler shutdown".to_string().into_lua_err())),
+                        result = task_future => Ok(result),
+                    },
                 };
                 // Outer Result errors will always be timeouts
                 let result = result
@@ -178,7 +196,9 @@ pub fn start_task_scheduler(lua: &Lua, max_background_tasks: Option<u64>) {
 
 pub fn stop_task_scheduler(lua: &Lua) {
     lua.remove_app_data::<UnboundedSender<Task>>();
-    lua.remove_app_data::<UnboundedReceiver<Task>>();
+
+    // Notify all tasks to stop
+    _ = lua.app_data_ref::<ShutdownNotifier>().unwrap().0.send(true);
 }
 
 pub fn create_module(lua: &Lua) -> Result<Table> {
@@ -191,14 +211,13 @@ pub fn create_module(lua: &Lua) -> Result<Table> {
 
 #[cfg(test)]
 mod tests {
-    use std::rc::Rc;
     use std::time::Duration;
 
     use mlua::{chunk, Lua, Result};
 
     #[ntex::test]
    async fn test_tasks() -> Result<()> {
-        let lua = Rc::new(Lua::new());
+        let lua = Lua::new();
         lua.globals().set("tasks", super::create_module(&lua)?)?;
 
         lua.globals().set(
@@ -331,6 +350,8 @@ mod tests {
         .await
         .unwrap();
 
+        super::stop_task_scheduler(&lua);
+
         Ok(())
     }
 }
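
For context: the cancellation machinery above reduces to tokio's watch-channel pattern. Each spawned task `select!`s between its own work and `shutdown.wait_for(|&x| x)`, and `stop_task_scheduler` flips the flag once to cancel every outstanding task. Below is a minimal, self-contained sketch of that pattern in plain tokio (it requires a tokio version that provides `watch::Receiver::wait_for`); the names are illustrative only, not part of the casper-server API.

```rust
use std::time::Duration;

use tokio::sync::watch;
use tokio::time::sleep;

#[tokio::main(flavor = "current_thread")]
async fn main() {
    // Shutdown flag starts as `false`; every task holds a cloned receiver.
    let (shutdown_tx, shutdown_rx) = watch::channel(false);

    let task = tokio::spawn({
        let mut shutdown = shutdown_rx.clone();
        async move {
            tokio::select! {
                // Resolves as soon as the watched value satisfies the predicate,
                // even if `true` was sent before this task was first polled.
                _ = shutdown.wait_for(|&stop| stop) => Err("shutdown"),
                // Stands in for the task's real work (the Lua handler call).
                _ = sleep(Duration::from_secs(60)) => Ok("done"),
            }
        }
    });

    // What `stop_task_scheduler` does: broadcast `true` to all receivers.
    let _ = shutdown_tx.send(true);

    assert_eq!(task.await.unwrap(), Err("shutdown"));
}
```

A `watch` channel fits this job better than `oneshot` (single consumer) or `broadcast` (every receiver must keep up with a queue): it stores one current value, receivers clone cheaply per task, and `wait_for` checks the current value before waiting, so a task spawned around the time of shutdown cannot miss the signal.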