diff --git a/crates/krun-server/src/bin/krun-server.rs b/crates/krun-server/src/bin/krun-server.rs index e552d0c..a2b0bd4 100644 --- a/crates/krun-server/src/bin/krun-server.rs +++ b/crates/krun-server/src/bin/krun-server.rs @@ -2,7 +2,7 @@ use std::os::unix::process::ExitStatusExt as _; use anyhow::Result; use krun_server::cli_options::options; -use krun_server::server::{start_server, State}; +use krun_server::server::{Server, State}; use log::error; use tokio::net::TcpListener; use tokio::process::Command; @@ -17,13 +17,12 @@ async fn main() -> Result<()> { let options = options().run(); let listener = TcpListener::bind(format!("0.0.0.0:{}", options.server_port)).await?; - let (state_tx, state_rx) = watch::channel(State { - connection_idle: true, - child_processes: 0, - }); + let (state_tx, state_rx) = watch::channel(State::new()); - let server_handle = tokio::spawn(start_server(listener, state_tx)); - tokio::pin!(server_handle); + let mut server_handle = tokio::spawn(async move { + let mut server = Server::new(listener, state_tx); + server.run().await; + }); let command_status = Command::new(&options.command) .args(options.command_args) .status(); @@ -77,7 +76,7 @@ async fn main() -> Result<()> { command_exited = true; }, Some(state) = state_rx.next(), if command_exited => { - if state.connection_idle && state.child_processes == 0 { + if state.connection_idle() && state.child_processes() == 0 { // Server is idle (not currently handling an accepted // incoming connection) and no more child processes. // We're done. @@ -85,7 +84,7 @@ async fn main() -> Result<()> { } println!( "Waiting for {} other commands launched through this krun server to exit...", - state.child_processes + state.child_processes() ); println!("Press Ctrl+C to force quit"); }, diff --git a/crates/krun-server/src/server.rs b/crates/krun-server/src/server.rs index 4ead4e7..5299f99 100644 --- a/crates/krun-server/src/server.rs +++ b/crates/krun-server/src/server.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; -use std::env; use std::os::unix::process::ExitStatusExt as _; use std::path::PathBuf; -use std::process::Stdio; +use std::process::{ExitStatus, Stdio}; +use std::{env, io}; use anyhow::{anyhow, Context, Result}; use log::{debug, error}; @@ -10,116 +10,138 @@ use tokio::io::{AsyncBufReadExt as _, AsyncWriteExt as _, BufStream}; use tokio::net::{TcpListener, TcpStream}; use tokio::process::{Child, Command}; use tokio::sync::watch; -use tokio::task::JoinSet; +use tokio::task::{JoinError, JoinSet}; use tokio_stream::wrappers::TcpListenerStream; use tokio_stream::StreamExt as _; use utils::launch::Launch; use utils::stdio::make_stdout_stderr; +#[derive(Debug)] +pub struct Server { + listener_stream: TcpListenerStream, + state_tx: watch::Sender, + child_set: JoinSet<(PathBuf, ChildResult)>, +} + #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] pub struct State { - pub connection_idle: bool, - pub child_processes: usize, + connection_idle: bool, + child_processes: usize, } -pub async fn start_server(listener: TcpListener, state_tx: watch::Sender) { - let mut listener_stream = TcpListenerStream::new(listener); - let mut child_set = JoinSet::new(); +type ChildResult = Result; - loop { - tokio::select! { - Some(stream) = listener_stream.next() => { - state_tx.send_if_modified(|state| { - let connection_idle = false; - if state.connection_idle == connection_idle { - return false; - } - state.connection_idle = connection_idle; - true - }); - let stream = match stream { - Ok(stream) => stream, - Err(err) => { - eprintln!("Failed to accept incoming connection: {err}"); - state_tx.send_if_modified(|state| { - let connection_idle = true; - if state.connection_idle == connection_idle { - return false; - } - state.connection_idle = connection_idle; - true - }); - continue; - }, - }; - let stream = BufStream::new(stream); - - match handle_connection(stream).await { - Ok((command, mut child)) => { - child_set.spawn(async move { (command, child.wait().await) }); - state_tx.send_if_modified(|state| { - let child_processes = child_set.len(); - if state.child_processes == child_processes { - return false; - } - state.child_processes = child_processes; - true - }); - }, - Err(err) => { - eprintln!("Failed to process client request: {err:?}"); - }, - } - state_tx.send_if_modified(|state| { - let connection_idle = true; - if state.connection_idle == connection_idle { - return false; - } - state.connection_idle = connection_idle; - true - }); - }, - Some(res) = child_set.join_next() => { - match res { - Ok((command, res)) => match res { - Ok(status) => { - debug!(command:?; "child process exited"); - if !status.success() { - if let Some(code) = status.code() { - eprintln!( - "{command:?} process exited with status code: {code}" - ); - } else { - eprintln!( - "{command:?} process terminated by signal: {}", - status - .signal() - .expect( - "either one of status code or signal should be set" - ) - ); - } - } +impl Server { + pub fn new(listener: TcpListener, state_tx: watch::Sender) -> Self { + Server { + listener_stream: TcpListenerStream::new(listener), + state_tx, + child_set: JoinSet::new(), + } + } + + pub async fn run(&mut self) { + loop { + tokio::select! { + Some(stream) = self.listener_stream.next() => { + self.set_connection_idle(false); + let stream = match stream { + Ok(stream) => stream, + Err(err) => { + eprintln!("Failed to accept incoming connection: {err}"); + self.set_connection_idle(true); + continue; + }, + }; + let stream = BufStream::new(stream); + + match handle_connection(stream).await { + Ok((command, mut child)) => { + self.child_set.spawn(async move { (command, child.wait().await) }); + self.set_child_processes(self.child_set.len()); }, Err(err) => { - eprintln!("Failed to wait for {command:?} process to exit: {err}"); + eprintln!("Failed to process client request: {err:?}"); }, - }, - Err(err) => { - error!(err:% = err; "child task failed"); - }, - } - state_tx.send_if_modified(|state| { - let child_processes = child_set.len(); - if state.child_processes == child_processes { - return false; } - state.child_processes = child_processes; - true - }); + self.set_connection_idle(true); + }, + Some(res) = self.child_set.join_next() => self.handle_child_join(res), + } + } + } + + fn handle_child_join(&self, res: Result<(PathBuf, ChildResult), JoinError>) { + match res { + Ok((command, res)) => match res { + Ok(status) => { + debug!(command:?; "child process exited"); + if !status.success() { + if let Some(code) = status.code() { + eprintln!("{command:?} process exited with status code: {code}"); + } else { + eprintln!( + "{command:?} process terminated by signal: {}", + status + .signal() + .expect("either one of status code or signal should be set") + ); + } + } + }, + Err(err) => { + eprintln!("Failed to wait for {command:?} process to exit: {err}"); + }, }, + Err(err) => { + error!(err:% = err; "child task failed"); + }, + } + self.set_child_processes(self.child_set.len()); + } + + fn set_connection_idle(&self, connection_idle: bool) { + self.state_tx.send_if_modified(|state| { + if state.connection_idle == connection_idle { + return false; + } + state.connection_idle = connection_idle; + true + }); + } + + fn set_child_processes(&self, child_processes: usize) { + self.state_tx.send_if_modified(|state| { + if state.child_processes == child_processes { + return false; + } + state.child_processes = child_processes; + true + }); + } +} + +impl State { + pub fn new() -> Self { + Self { + connection_idle: true, + child_processes: 0, } } + + pub fn connection_idle(&self) -> bool { + self.connection_idle + } + + pub fn child_processes(&self) -> usize { + self.child_processes + } +} + +impl Default for State { + fn default() -> Self { + Self::new() + } } async fn read_request(stream: &mut BufStream) -> Result {