From 9510ed21299fc68bb405a9e6b04b09497a78f617 Mon Sep 17 00:00:00 2001 From: junderw Date: Sun, 24 Sep 2023 11:00:22 -0700 Subject: [PATCH] Fix: electrum server graceful shutdown doesn't work --- src/bin/electrs.rs | 17 +++++ src/electrum/server.rs | 150 ++++++++++++++++++++++++++++----------- src/elements/registry.rs | 2 +- src/new_index/fetch.rs | 5 +- src/rest.rs | 2 +- src/signal.rs | 3 +- src/util/mod.rs | 67 ++++++++++++++--- 7 files changed, 187 insertions(+), 59 deletions(-) diff --git a/src/bin/electrs.rs b/src/bin/electrs.rs index 5511593c8..b7d5f96d3 100644 --- a/src/bin/electrs.rs +++ b/src/bin/electrs.rs @@ -105,6 +105,20 @@ fn run_server(config: Arc) -> Result<()> { loop { if let Err(err) = signal.wait(Duration::from_secs(5), true) { info!("stopping server: {}", err); + + electrs::util::spawn_thread("shutdown-thread-checker", || { + let mut counter = 40; + let interval_ms = 500; + + while counter > 0 { + electrs::util::with_spawned_threads(|threads| { + debug!("Threads during shutdown: {:?}", threads); + }); + std::thread::sleep(std::time::Duration::from_millis(interval_ms)); + counter -= 1; + } + }); + rest_server.stop(); // the electrum server is stopped when dropped break; @@ -133,4 +147,7 @@ fn main() { error!("server failed: {}", e.display_chain()); process::exit(1); } + electrs::util::with_spawned_threads(|threads| { + debug!("Threads before closing: {:?}", threads); + }); } diff --git a/src/electrum/server.rs b/src/electrum/server.rs index 30092bd28..4b5418f50 100644 --- a/src/electrum/server.rs +++ b/src/electrum/server.rs @@ -1,7 +1,8 @@ use std::collections::HashMap; use std::io::{BufRead, BufReader, Write}; use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream}; -use std::sync::mpsc::{Sender, SyncSender, TrySendError}; +use std::sync::atomic::AtomicBool; +use std::sync::mpsc::{Receiver, Sender}; use std::sync::{Arc, Mutex}; use std::thread; @@ -100,6 +101,7 @@ struct Connection { chan: SyncChannel, stats: Arc, txs_limit: usize, + die_please: Option>, #[cfg(feature = "electrum-discovery")] discovery: Option>, } @@ -111,6 +113,7 @@ impl Connection { addr: SocketAddr, stats: Arc, txs_limit: usize, + die_please: Receiver<()>, #[cfg(feature = "electrum-discovery")] discovery: Option>, ) -> Connection { Connection { @@ -122,6 +125,7 @@ impl Connection { chan: SyncChannel::new(10), stats, txs_limit, + die_please: Some(die_please), #[cfg(feature = "electrum-discovery")] discovery, } @@ -501,40 +505,46 @@ impl Connection { Ok(()) } - fn handle_replies(&mut self) -> Result<()> { + fn handle_replies(&mut self, shutdown: crossbeam_channel::Receiver<()>) -> Result<()> { let empty_params = json!([]); loop { - let msg = self.chan.receiver().recv().chain_err(|| "channel closed")?; - trace!("RPC {:?}", msg); - match msg { - Message::Request(line) => { - let cmd: Value = from_str(&line).chain_err(|| "invalid JSON format")?; - let reply = match ( - cmd.get("method"), - cmd.get("params").unwrap_or_else(|| &empty_params), - cmd.get("id"), - ) { - ( - Some(&Value::String(ref method)), - &Value::Array(ref params), - Some(ref id), - ) => self.handle_command(method, params, id)?, - _ => bail!("invalid command: {}", cmd), - }; - self.send_values(&[reply])? - } - Message::PeriodicUpdate => { - let values = self - .update_subscriptions() - .chain_err(|| "failed to update subscriptions")?; - self.send_values(&values)? + crossbeam_channel::select! { + recv(self.chan.receiver()) -> msg => { + let msg = msg.chain_err(|| "channel closed")?; + trace!("RPC {:?}", msg); + match msg { + Message::Request(line) => { + let cmd: Value = from_str(&line).chain_err(|| "invalid JSON format")?; + let reply = match ( + cmd.get("method"), + cmd.get("params").unwrap_or(&empty_params), + cmd.get("id"), + ) { + (Some(Value::String(method)), Value::Array(params), Some(id)) => { + self.handle_command(method, params, id)? + } + _ => bail!("invalid command: {}", cmd), + }; + self.send_values(&[reply])? + } + Message::PeriodicUpdate => { + let values = self + .update_subscriptions() + .chain_err(|| "failed to update subscriptions")?; + self.send_values(&values)? + } + Message::Done => return Ok(()), + } } - Message::Done => return Ok(()), + recv(shutdown) -> _ => return Ok(()), } } } - fn handle_requests(mut reader: BufReader, tx: SyncSender) -> Result<()> { + fn handle_requests( + mut reader: BufReader, + tx: crossbeam_channel::Sender, + ) -> Result<()> { loop { let mut line = Vec::::new(); reader @@ -566,8 +576,18 @@ impl Connection { self.stats.clients.inc(); let reader = BufReader::new(self.stream.try_clone().expect("failed to clone TcpStream")); let tx = self.chan.sender(); + + let stream = self.stream.try_clone().expect("failed to clone TcpStream"); + let die_please = self.die_please.take().unwrap(); + let (reply_killer, reply_receiver) = crossbeam_channel::unbounded(); + spawn_thread("properly-die", move || { + let _ = die_please.recv(); + let _ = stream.shutdown(Shutdown::Both); + let _ = reply_killer.send(()); + }); + let child = spawn_thread("reader", || Connection::handle_requests(reader, tx)); - if let Err(e) = self.handle_replies() { + if let Err(e) = self.handle_replies(reply_receiver) { error!( "[{}] connection handling failed: {}", self.addr, @@ -633,8 +653,9 @@ struct Stats { impl RPC { fn start_notifier( notification: Channel, - senders: Arc>>>, + senders: Arc>>>, acceptor: Sender>, + acceptor_shutdown: Sender<()>, ) { spawn_thread("notification", move || { for msg in notification.receiver().iter() { @@ -642,7 +663,7 @@ impl RPC { match msg { Notification::Periodic => { for sender in senders.split_off(0) { - if let Err(TrySendError::Disconnected(_)) = + if let Err(crossbeam_channel::TrySendError::Disconnected(_)) = sender.try_send(Message::PeriodicUpdate) { continue; @@ -650,13 +671,20 @@ impl RPC { senders.push(sender); } } - Notification::Exit => acceptor.send(None).unwrap(), // mark acceptor as done + Notification::Exit => { + acceptor_shutdown.send(()).unwrap(); // Stop the acceptor itself + acceptor.send(None).unwrap(); // mark acceptor as done + break; + } } } }); } - fn start_acceptor(addr: SocketAddr) -> Channel> { + fn start_acceptor( + addr: SocketAddr, + shutdown_channel: Channel<()>, + ) -> Channel> { let chan = Channel::unbounded(); let acceptor = chan.sender(); spawn_thread("acceptor", move || { @@ -666,10 +694,29 @@ impl RPC { .set_nonblocking(false) .expect("cannot set nonblocking to false"); let listener = TcpListener::from(socket); + let local_addr = listener.local_addr().unwrap(); + let shutdown_bool = Arc::new(AtomicBool::new(false)); + + { + let shutdown_bool = Arc::clone(&shutdown_bool); + crate::util::spawn_thread("shutdown-acceptor", move || { + // Block until shutdown is sent. + let _ = shutdown_channel.receiver().recv(); + // Store the bool so after the next accept it will break the loop + shutdown_bool.store(true, std::sync::atomic::Ordering::Release); + // Connect to the socket to cause it to unblock + let _ = TcpStream::connect(local_addr); + }); + } info!("Electrum RPC server running on {}", addr); loop { let (stream, addr) = listener.accept().expect("accept failed"); + + if shutdown_bool.load(std::sync::atomic::Ordering::Acquire) { + break; + } + stream .set_nonblocking(false) .expect("failed to set connection as blocking"); @@ -726,10 +773,18 @@ impl RPC { RPC { notification: notification.sender(), server: Some(spawn_thread("rpc", move || { - let senders = Arc::new(Mutex::new(Vec::>::new())); - - let acceptor = RPC::start_acceptor(rpc_addr); - RPC::start_notifier(notification, senders.clone(), acceptor.sender()); + let senders = + Arc::new(Mutex::new(Vec::>::new())); + + let acceptor_shutdown = Channel::unbounded(); + let acceptor_shutdown_sender = acceptor_shutdown.sender(); + let acceptor = RPC::start_acceptor(rpc_addr, acceptor_shutdown); + RPC::start_notifier( + notification, + senders.clone(), + acceptor.sender(), + acceptor_shutdown_sender, + ); let mut threads = HashMap::new(); let (garbage_sender, garbage_receiver) = crossbeam_channel::unbounded(); @@ -740,6 +795,10 @@ impl RPC { let senders = Arc::clone(&senders); let stats = Arc::clone(&stats); let garbage_sender = garbage_sender.clone(); + + // Kill the peers properly + let (killer, peace_receiver) = std::sync::mpsc::channel(); + #[cfg(feature = "electrum-discovery")] let discovery = discovery.clone(); @@ -751,6 +810,7 @@ impl RPC { addr, stats, txs_limit, + peace_receiver, #[cfg(feature = "electrum-discovery")] discovery, ); @@ -761,24 +821,29 @@ impl RPC { }); trace!("[{}] spawned {:?}", addr, spawned.thread().id()); - threads.insert(spawned.thread().id(), spawned); + threads.insert(spawned.thread().id(), (spawned, killer)); while let Ok(id) = garbage_receiver.try_recv() { - if let Some(thread) = threads.remove(&id) { + if let Some((thread, killer)) = threads.remove(&id) { trace!("[{}] joining {:?}", addr, id); + let _ = killer.send(()); if let Err(error) = thread.join() { error!("failed to join {:?}: {:?}", id, error); } } } } + // Drop these + drop(acceptor); + drop(garbage_receiver); trace!("closing {} RPC connections", senders.lock().unwrap().len()); for sender in senders.lock().unwrap().iter() { - let _ = sender.send(Message::Done); + let _ = sender.try_send(Message::Done); } - for (id, thread) in threads { + for (id, (thread, killer)) in threads { trace!("joining {:?}", id); + let _ = killer.send(()); if let Err(error) = thread.join() { error!("failed to join {:?}: {:?}", id, error); } @@ -802,5 +867,8 @@ impl Drop for RPC { handle.join().unwrap(); } trace!("RPC server is stopped"); + crate::util::with_spawned_threads(|threads| { + trace!("Threads after dropping RPC: {:?}", threads); + }); } } diff --git a/src/elements/registry.rs b/src/elements/registry.rs index e0728320d..dbc9c06c9 100644 --- a/src/elements/registry.rs +++ b/src/elements/registry.rs @@ -102,7 +102,7 @@ impl AssetRegistry { } pub fn spawn_sync(asset_db: Arc>) -> thread::JoinHandle<()> { - thread::spawn(move || loop { + crate::util::spawn_thread("asset-registry", move || loop { if let Err(e) = asset_db.write().unwrap().fs_sync() { error!("registry fs_sync failed: {:?}", e); } diff --git a/src/new_index/fetch.rs b/src/new_index/fetch.rs index 20d0bfde0..89ed77af0 100644 --- a/src/new_index/fetch.rs +++ b/src/new_index/fetch.rs @@ -9,7 +9,6 @@ use std::collections::HashMap; use std::fs; use std::io::Cursor; use std::path::PathBuf; -use std::sync::mpsc::Receiver; use std::thread; use crate::chain::{Block, BlockHash}; @@ -44,12 +43,12 @@ pub struct BlockEntry { type SizedBlock = (Block, u32); pub struct Fetcher { - receiver: Receiver, + receiver: crossbeam_channel::Receiver, thread: thread::JoinHandle<()>, } impl Fetcher { - fn from(receiver: Receiver, thread: thread::JoinHandle<()>) -> Self { + fn from(receiver: crossbeam_channel::Receiver, thread: thread::JoinHandle<()>) -> Self { Fetcher { receiver, thread } } diff --git a/src/rest.rs b/src/rest.rs index 108215634..7b62bf833 100644 --- a/src/rest.rs +++ b/src/rest.rs @@ -594,7 +594,7 @@ pub fn start(config: Arc, query: Arc) -> Handle { Handle { tx, - thread: thread::spawn(move || { + thread: crate::util::spawn_thread("rest-server", move || { run_server(config, query, rx); }), } diff --git a/src/signal.rs b/src/signal.rs index 9bc30d9e3..c4ebc8e3c 100644 --- a/src/signal.rs +++ b/src/signal.rs @@ -1,6 +1,5 @@ use crossbeam_channel as channel; use crossbeam_channel::RecvTimeoutError; -use std::thread; use std::time::{Duration, Instant}; use signal_hook::consts::{SIGINT, SIGTERM, SIGUSR1}; @@ -16,7 +15,7 @@ fn notify(signals: &[i32]) -> channel::Receiver { let (s, r) = channel::bounded(1); let mut signals = signal_hook::iterator::Signals::new(signals).expect("failed to register signal hook"); - thread::spawn(move || { + crate::util::spawn_thread("signal-notifier", move || { for signal in signals.forever() { s.send(signal) .unwrap_or_else(|_| panic!("failed to send signal {}", signal)); diff --git a/src/util/mod.rs b/src/util/mod.rs index 233e3efea..03e9780b7 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -14,8 +14,10 @@ pub use self::transaction::{ }; use std::collections::HashMap; -use std::sync::mpsc::{channel, sync_channel, Receiver, Sender, SyncSender}; -use std::thread; +use std::sync::atomic::AtomicUsize; +use std::sync::mpsc::{channel, Receiver, Sender}; +use std::sync::Mutex; +use std::thread::{self, ThreadId}; use crate::chain::BlockHeader; use bitcoin::hashes::sha256d::Hash as Sha256dHash; @@ -35,25 +37,25 @@ pub fn full_hash(hash: &[u8]) -> FullHash { } pub struct SyncChannel { - tx: SyncSender, - rx: Receiver, + tx: crossbeam_channel::Sender, + rx: crossbeam_channel::Receiver, } impl SyncChannel { pub fn new(size: usize) -> SyncChannel { - let (tx, rx) = sync_channel(size); + let (tx, rx) = crossbeam_channel::bounded(size); SyncChannel { tx, rx } } - pub fn sender(&self) -> SyncSender { + pub fn sender(&self) -> crossbeam_channel::Sender { self.tx.clone() } - pub fn receiver(&self) -> &Receiver { + pub fn receiver(&self) -> &crossbeam_channel::Receiver { &self.rx } - pub fn into_receiver(self) -> Receiver { + pub fn into_receiver(self) -> crossbeam_channel::Receiver { self.rx } } @@ -82,15 +84,58 @@ impl Channel { } } -pub fn spawn_thread(name: &str, f: F) -> thread::JoinHandle +/// This static HashMap contains all the threads spawned with [`spawn_thread`] with their name +#[inline] +pub fn with_spawned_threads(f: F) +where + F: FnOnce(&mut HashMap), +{ + lazy_static! { + static ref SPAWNED_THREADS: Mutex> = Mutex::new(HashMap::new()); + } + let mut lock = match SPAWNED_THREADS.lock() { + Ok(threads) => threads, + // There's no possible broken state + Err(threads) => { + warn!("SPAWNED_THREADS is in a poisoned state! Be wary of incorrect logs!"); + threads.into_inner() + } + }; + f(&mut lock) +} + +pub fn spawn_thread(prefix: &str, do_work: F) -> thread::JoinHandle where F: FnOnce() -> T, F: Send + 'static, T: Send + 'static, { + static THREAD_COUNTER: AtomicUsize = AtomicUsize::new(0); + let counter = THREAD_COUNTER.fetch_add(1, std::sync::atomic::Ordering::AcqRel); thread::Builder::new() - .name(name.to_owned()) - .spawn(f) + .name(format!("{}-{}", prefix, counter)) + .spawn(move || { + let thread = std::thread::current(); + let name = thread.name().unwrap(); + let id = thread.id(); + + trace!("[THREAD] GETHASHMAP INSERT | {name} {id:?}"); + with_spawned_threads(|threads| { + threads.insert(id, name.to_owned()); + }); + trace!("[THREAD] START WORK | {name} {id:?}"); + + let result = do_work(); + + trace!("[THREAD] FINISHED WORK | {name} {id:?}"); + trace!("[THREAD] GETHASHMAP REMOVE | {name} {id:?}"); + with_spawned_threads(|threads| { + threads.remove(&id); + }); + trace!("[THREAD] HASHMAP REMOVED | {name} {id:?}"); + + result + }) .unwrap() }