From 8aee9e291eca2ce8dfa58c3c05b20e1207a8efb8 Mon Sep 17 00:00:00 2001 From: Joakim Hulthe Date: Mon, 4 Mar 2024 14:54:20 +0100 Subject: [PATCH 01/10] Refactor forward_messages to use select macro --- test/test-rpc/src/transport.rs | 57 +++++++++++++++++----------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/test/test-rpc/src/transport.rs b/test/test-rpc/src/transport.rs index b8086b41456b..f5f461702688 100644 --- a/test/test-rpc/src/transport.rs +++ b/test/test-rpc/src/transport.rs @@ -1,5 +1,5 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; -use futures::{channel::mpsc, SinkExt, StreamExt}; +use futures::{channel::mpsc, FutureExt, SinkExt, StreamExt}; use serde::{de::DeserializeOwned, Serialize}; use std::{ fmt::Write, @@ -256,13 +256,12 @@ async fn forward_messages< let mut mullvad_daemon_forwarder = LengthDelimitedCodec::new().framed(mullvad_daemon_forwarder); loop { - match futures::future::select( - futures::future::select(serial_stream.next(), handshaker.1.next()), - futures::future::select(runner_forwarder.next(), mullvad_daemon_forwarder.next()), - ) - .await - { - futures::future::Either::Left((futures::future::Either::Left((Some(frame), _)), _)) => { + futures::select! { + frame = serial_stream.next().fuse() => { + let Some(frame) = frame else { + break Ok(()); + }; + let frame = frame.map_err(ForwardError::SerialConnection)?; // @@ -294,7 +293,12 @@ async fn forward_messages< } } } - futures::future::Either::Left((futures::future::Either::Right((Some(()), _)), _)) => { + + handshake = handshaker.1.next().fuse() => { + if handshake.is_none() { + break Ok(()); + } + log::trace!("shake: send"); // Ping the other end @@ -303,10 +307,12 @@ async fn forward_messages< .await .map_err(ForwardError::HandshakeError)?; } - futures::future::Either::Right(( - futures::future::Either::Left((Some(message), _)), - _, - )) => { + + message = runner_forwarder.next().fuse() => { + let Some(message) = message else { + break Ok(()); + }; + let message = message.map_err(ForwardError::TestRunnerChannel)?; // @@ -321,10 +327,16 @@ async fn forward_messages< .await .map_err(ForwardError::SerialConnection)?; } - futures::future::Either::Right(( - futures::future::Either::Right((Some(data), _)), - _, - )) => { + + data = mullvad_daemon_forwarder.next().fuse() => { + let Some(data) = data else { + // + // Force management interface socket to close + // + let _ = serial_stream.send(Frame::DaemonRpc(Bytes::new())).await; + break Ok(()); + }; + let data = data.map_err(ForwardError::DaemonChannel)?; // @@ -336,17 +348,6 @@ async fn forward_messages< .await .map_err(ForwardError::SerialConnection)?; } - futures::future::Either::Right((futures::future::Either::Right((None, _)), _)) => { - // - // Force management interface socket to close - // - let _ = serial_stream.send(Frame::DaemonRpc(Bytes::new())).await; - - break Ok(()); - } - _ => { - break Ok(()); - } } } } From aca51ed386c29ed5b02614139794439e3847321a Mon Sep 17 00:00:00 2001 From: Joakim Hulthe Date: Mon, 4 Mar 2024 16:07:03 +0100 Subject: [PATCH 02/10] Add basic split-tunnel test --- test/test-manager/src/tests/mod.rs | 3 +- test/test-manager/src/tests/split_tunnel.rs | 58 +++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 test/test-manager/src/tests/split_tunnel.rs diff --git a/test/test-manager/src/tests/mod.rs b/test/test-manager/src/tests/mod.rs index 0cf135769687..f14d5238d52c 100644 --- a/test/test-manager/src/tests/mod.rs +++ b/test/test-manager/src/tests/mod.rs @@ -6,6 +6,7 @@ mod helpers; mod install; mod settings; mod software; +mod split_tunnel; mod test_metadata; mod tunnel; mod tunnel_state; @@ -40,7 +41,7 @@ pub enum Error { Rpc(#[from] test_rpc::Error), #[error("geoip lookup failed")] - GeoipLookup(test_rpc::Error), + GeoipLookup(#[source] test_rpc::Error), #[error("Found running daemon unexpectedly")] DaemonRunning, diff --git a/test/test-manager/src/tests/split_tunnel.rs b/test/test-manager/src/tests/split_tunnel.rs new file mode 100644 index 000000000000..26ccd7794f5e --- /dev/null +++ b/test/test-manager/src/tests/split_tunnel.rs @@ -0,0 +1,58 @@ +use mullvad_management_interface::MullvadProxyClient; +use std::str; +use test_macro::test_function; +use test_rpc::{ExecResult, ServiceClient}; + +use super::{helpers, Error, TestContext}; + +#[test_function] +#[cfg(any(target_os = "linux", target_os = "windows"))] +pub async fn test_split_tunnel( + _: TestContext, + rpc: ServiceClient, + mut mullvad_client: MullvadProxyClient, +) -> Result<(), Error> { + let mut errored = false; + let parse_am_i_mullvad = |result: ExecResult| -> bool { + let stdout = str::from_utf8(&result.stdout).expect("am-i-mullvad output is UTF-8"); + + if stdout.contains("You are connected") { + true + } else if stdout.contains("You are not connected") { + false + } else { + panic!("Unexpected output from `am-i-mullvad`: {stdout}") + } + }; + + helpers::connect_and_wait(&mut mullvad_client).await?; + + let i_am_mullvad = parse_am_i_mullvad(rpc.exec("am-i-mullvad", []).await?); + if !i_am_mullvad { + log::error!("We should be connected, but `am-i-mullvad` reported that it was not connected to Mullvad."); + errored = true; + } + + let i_am_mullvad_while_split = + parse_am_i_mullvad(rpc.exec("mullvad-exclude", ["am-i-mullvad"]).await?); + if i_am_mullvad_while_split { + log::error!("`mullvad-exclude am-i-mullvad` reported that it was connected to Mullvad."); + log::error!("`am-i-mullvad` does not appear to have been split correctly."); + errored = true; + } + + helpers::disconnect_and_wait(&mut mullvad_client).await?; + + let i_am_mullvad_while_disconnected = parse_am_i_mullvad(rpc.exec("am-i-mullvad", []).await?); + if i_am_mullvad_while_disconnected { + log::error!("We should be disconnected, but `am-i-mullvad` reported that it was connected to Mullvad."); + log::error!("Host machine is probably connected to Mullvad. This may affect test results."); + errored = true; + } + + if errored { + panic!("test_split_tunnel failed, see logs for details."); + } + + Ok(()) +} From 1badc46007ab3a97b81ae2ef84df546c6435060c Mon Sep 17 00:00:00 2001 From: Joakim Hulthe Date: Mon, 4 Mar 2024 16:46:52 +0100 Subject: [PATCH 03/10] Make e2e tests accept anyhow errors --- test/test-manager/src/logging.rs | 2 +- test/test-manager/src/run_tests.rs | 6 ++---- test/test-manager/src/tests/mod.rs | 2 +- test/test-manager/src/tests/split_tunnel.rs | 20 ++++++++++---------- test/test-manager/test_macro/src/lib.rs | 8 ++++---- 5 files changed, 18 insertions(+), 20 deletions(-) diff --git a/test/test-manager/src/logging.rs b/test/test-manager/src/logging.rs index cd0bd4af2840..e85920b1cd9c 100644 --- a/test/test-manager/src/logging.rs +++ b/test/test-manager/src/logging.rs @@ -1,4 +1,4 @@ -use crate::tests::Error; +use anyhow::Error; use colored::Colorize; use std::sync::{Arc, Mutex}; use test_rpc::logging::{LogOutput, Output}; diff --git a/test/test-manager/src/run_tests.rs b/test/test-manager/src/run_tests.rs index 6af153656277..6b3da3713808 100644 --- a/test/test-manager/src/run_tests.rs +++ b/test/test-manager/src/run_tests.rs @@ -2,9 +2,7 @@ use crate::summary::{self, maybe_log_test_result}; use crate::tests::{config::TEST_CONFIG, TestContext}; use crate::{ logging::{panic_as_string, TestOutput}, - mullvad_daemon, tests, - tests::Error, - vm, + mullvad_daemon, tests, vm, }; use anyhow::{Context, Result}; use futures::FutureExt; @@ -187,7 +185,7 @@ pub async fn run_test( ) -> TestOutput where F: Fn(super::tests::TestContext, ServiceClient, MullvadClient) -> R, - R: Future>, + R: Future>, { let _flushed = runner_rpc.try_poll_output().await; diff --git a/test/test-manager/src/tests/mod.rs b/test/test-manager/src/tests/mod.rs index f14d5238d52c..48d75b9e3f96 100644 --- a/test/test-manager/src/tests/mod.rs +++ b/test/test-manager/src/tests/mod.rs @@ -33,7 +33,7 @@ pub type TestWrapperFunction = fn( TestContext, ServiceClient, Box, -) -> BoxFuture<'static, Result<(), Error>>; +) -> BoxFuture<'static, anyhow::Result<()>>; #[derive(thiserror::Error, Debug)] pub enum Error { diff --git a/test/test-manager/src/tests/split_tunnel.rs b/test/test-manager/src/tests/split_tunnel.rs index 26ccd7794f5e..4cd48cdce334 100644 --- a/test/test-manager/src/tests/split_tunnel.rs +++ b/test/test-manager/src/tests/split_tunnel.rs @@ -3,7 +3,7 @@ use std::str; use test_macro::test_function; use test_rpc::{ExecResult, ServiceClient}; -use super::{helpers, Error, TestContext}; +use super::{helpers, TestContext}; #[test_function] #[cfg(any(target_os = "linux", target_os = "windows"))] @@ -11,30 +11,30 @@ pub async fn test_split_tunnel( _: TestContext, rpc: ServiceClient, mut mullvad_client: MullvadProxyClient, -) -> Result<(), Error> { +) -> anyhow::Result<()> { let mut errored = false; - let parse_am_i_mullvad = |result: ExecResult| -> bool { + let parse_am_i_mullvad = |result: ExecResult| { let stdout = str::from_utf8(&result.stdout).expect("am-i-mullvad output is UTF-8"); - if stdout.contains("You are connected") { + Ok(if stdout.contains("You are connected") { true } else if stdout.contains("You are not connected") { false } else { - panic!("Unexpected output from `am-i-mullvad`: {stdout}") - } + anyhow::bail!("Unexpected output from `am-i-mullvad`: {stdout}") + }) }; helpers::connect_and_wait(&mut mullvad_client).await?; - let i_am_mullvad = parse_am_i_mullvad(rpc.exec("am-i-mullvad", []).await?); + let i_am_mullvad = parse_am_i_mullvad(rpc.exec("am-i-mullvad", []).await?)?; if !i_am_mullvad { log::error!("We should be connected, but `am-i-mullvad` reported that it was not connected to Mullvad."); errored = true; } let i_am_mullvad_while_split = - parse_am_i_mullvad(rpc.exec("mullvad-exclude", ["am-i-mullvad"]).await?); + parse_am_i_mullvad(rpc.exec("mullvad-exclude", ["am-i-mullvad"]).await?)?; if i_am_mullvad_while_split { log::error!("`mullvad-exclude am-i-mullvad` reported that it was connected to Mullvad."); log::error!("`am-i-mullvad` does not appear to have been split correctly."); @@ -43,7 +43,7 @@ pub async fn test_split_tunnel( helpers::disconnect_and_wait(&mut mullvad_client).await?; - let i_am_mullvad_while_disconnected = parse_am_i_mullvad(rpc.exec("am-i-mullvad", []).await?); + let i_am_mullvad_while_disconnected = parse_am_i_mullvad(rpc.exec("am-i-mullvad", []).await?)?; if i_am_mullvad_while_disconnected { log::error!("We should be disconnected, but `am-i-mullvad` reported that it was connected to Mullvad."); log::error!("Host machine is probably connected to Mullvad. This may affect test results."); @@ -51,7 +51,7 @@ pub async fn test_split_tunnel( } if errored { - panic!("test_split_tunnel failed, see logs for details."); + anyhow::bail!("test_split_tunnel failed, see log output for details."); } Ok(()) diff --git a/test/test-manager/test_macro/src/lib.rs b/test/test-manager/test_macro/src/lib.rs index d95c3f883211..387676d80a77 100644 --- a/test/test-manager/test_macro/src/lib.rs +++ b/test/test-manager/test_macro/src/lib.rs @@ -52,7 +52,7 @@ use syn::{AttributeArgs, Lit, Meta, NestedMeta}; /// pub async fn test_function( /// rpc: ServiceClient, /// mut mullvad_client: mullvad_management_interface::MullvadProxyClient, -/// ) -> Result<(), Error> { +/// ) -> anyhow::Result<()> { /// Ok(()) /// } /// ``` @@ -67,7 +67,7 @@ use syn::{AttributeArgs, Lit, Meta, NestedMeta}; /// pub async fn test_function( /// rpc: ServiceClient, /// mut mullvad_client: mullvad_management_interface::MullvadProxyClient, -/// ) -> Result<(), Error> { +/// ) -> anyhow::Result<()> { /// Ok(()) /// } /// ``` @@ -193,7 +193,7 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { use std::any::Any; let mullvad_client = mullvad_client.downcast::<#mullvad_client_type>().expect("invalid mullvad client"); Box::pin(async move { - #func_name(test_context, rpc, *mullvad_client).await + Ok(#func_name(test_context, rpc, *mullvad_client).await?) }) } } @@ -204,7 +204,7 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { rpc: test_rpc::ServiceClient, mullvad_client: Box| { Box::pin(async move { - #func_name(test_context, rpc).await + Ok(#func_name(test_context, rpc).await?) }) } } From bea5570b3e89acbe8c12d16d7ada54cb5a4fd5f0 Mon Sep 17 00:00:00 2001 From: Joakim Hulthe Date: Tue, 5 Mar 2024 15:20:28 +0100 Subject: [PATCH 04/10] Refactor test_macro error handling without panics --- test/Cargo.lock | 1 + test/test-manager/test_macro/Cargo.toml | 1 + test/test-manager/test_macro/src/lib.rs | 193 +++++++++++++----------- 3 files changed, 107 insertions(+), 88 deletions(-) diff --git a/test/Cargo.lock b/test/Cargo.lock index b974ef9217ce..60ab48547b9e 100644 --- a/test/Cargo.lock +++ b/test/Cargo.lock @@ -3229,6 +3229,7 @@ dependencies = [ "proc-macro2", "quote", "syn 1.0.109", + "test-rpc", ] [[package]] diff --git a/test/test-manager/test_macro/Cargo.toml b/test/test-manager/test_macro/Cargo.toml index a064b6d200f1..19a405d08fe9 100644 --- a/test/test-manager/test_macro/Cargo.toml +++ b/test/test-manager/test_macro/Cargo.toml @@ -14,3 +14,4 @@ proc-macro = true syn = "1.0" quote = "1.0" proc-macro2 = "1.0" +test-rpc = { path = "../../test-rpc" } diff --git a/test/test-manager/test_macro/src/lib.rs b/test/test-manager/test_macro/src/lib.rs index 387676d80a77..fdf7e5539cc9 100644 --- a/test/test-manager/test_macro/src/lib.rs +++ b/test/test-manager/test_macro/src/lib.rs @@ -1,6 +1,7 @@ use proc_macro::TokenStream; use quote::{quote, ToTokens}; -use syn::{AttributeArgs, Lit, Meta, NestedMeta}; +use syn::{AttributeArgs, Lit, Meta, NestedMeta, Result}; +use test_rpc::meta::Os; /// Register an `async` function to be run by `test-manager`. /// @@ -76,7 +77,10 @@ pub fn test_function(attributes: TokenStream, code: TokenStream) -> TokenStream let function: syn::ItemFn = syn::parse(code).unwrap(); let attributes = syn::parse_macro_input!(attributes as AttributeArgs); - let test_function = parse_marked_test_function(&attributes, &function); + let test_function = match parse_marked_test_function(&attributes, &function) { + Ok(tf) => tf, + Err(e) => return e.into_compile_error().into(), + }; let register_test = create_test(test_function); @@ -88,19 +92,31 @@ pub fn test_function(attributes: TokenStream, code: TokenStream) -> TokenStream .into() } -fn parse_marked_test_function(attributes: &AttributeArgs, function: &syn::ItemFn) -> TestFunction { - let macro_parameters = get_test_macro_parameters(attributes); +/// Shorthand for `return syn::Error::new(...)`. +macro_rules! bail { + ($span:expr, $($tt:tt)*) => {{ + return ::core::result::Result::Err(::syn::Error::new( + ::syn::spanned::Spanned::span(&$span), + ::core::format_args!($($tt)*), + )) + }}; +} - let function_parameters = get_test_function_parameters(&function.sig.inputs); +fn parse_marked_test_function( + attributes: &AttributeArgs, + function: &syn::ItemFn, +) -> Result { + let macro_parameters = get_test_macro_parameters(attributes)?; + let function_parameters = get_test_function_parameters(&function.sig.inputs)?; - TestFunction { + Ok(TestFunction { name: function.sig.ident.clone(), function_parameters, macro_parameters, - } + }) } -fn get_test_macro_parameters(attributes: &syn::AttributeArgs) -> MacroParameters { +fn get_test_macro_parameters(attributes: &syn::AttributeArgs) -> Result { let mut priority = None; let mut cleanup = true; let mut always_run = false; @@ -108,53 +124,57 @@ fn get_test_macro_parameters(attributes: &syn::AttributeArgs) -> MacroParameters let mut target_os = None; for attribute in attributes { - if let NestedMeta::Meta(Meta::NameValue(nv)) = attribute { - if nv.path.is_ident("priority") { - match &nv.lit { - Lit::Int(lit_int) => { - priority = Some(lit_int.base10_parse().unwrap()); - } - _ => panic!("'priority' should have an integer value"), - } - } else if nv.path.is_ident("always_run") { - match &nv.lit { - Lit::Bool(lit_bool) => { - always_run = lit_bool.value(); - } - _ => panic!("'always_run' should have a bool value"), - } - } else if nv.path.is_ident("must_succeed") { - match &nv.lit { - Lit::Bool(lit_bool) => { - must_succeed = lit_bool.value(); - } - _ => panic!("'must_succeed' should have a bool value"), - } - } else if nv.path.is_ident("cleanup") { - match &nv.lit { - Lit::Bool(lit_bool) => { - cleanup = lit_bool.value(); - } - _ => panic!("'cleanup' should have a bool value"), - } - } else if nv.path.is_ident("target_os") { - match &nv.lit { - Lit::Str(lit_str) => { - target_os = Some(lit_str.value()); - } - _ => panic!("'target_os' should have a string value"), - } + // we only use name-value attributes + let NestedMeta::Meta(Meta::NameValue(nv)) = attribute else { + bail!(attribute, "unknown attribute"); + }; + let lit = &nv.lit; + + if nv.path.is_ident("priority") { + match lit { + Lit::Int(lit_int) => priority = Some(lit_int.base10_parse().unwrap()), + _ => bail!(nv, "'priority' should have an integer value"), } + } else if nv.path.is_ident("always_run") { + match lit { + Lit::Bool(lit_bool) => always_run = lit_bool.value(), + _ => bail!(nv, "'always_run' should have a bool value"), + } + } else if nv.path.is_ident("must_succeed") { + match lit { + Lit::Bool(lit_bool) => must_succeed = lit_bool.value(), + _ => bail!(nv, "'must_succeed' should have a bool value"), + } + } else if nv.path.is_ident("cleanup") { + match lit { + Lit::Bool(lit_bool) => cleanup = lit_bool.value(), + _ => bail!(nv, "'cleanup' should have a bool value"), + } + } else if nv.path.is_ident("target_os") { + let Lit::Str(lit_str) = lit else { + bail!(nv, "'target_os' should have a string value"); + }; + + if target_os.is_some() { + bail!(nv, "can't specify multiple targets"); + } + + target_os = match lit_str.value().parse() { + Ok(os) => Some(os), + Err(e) => bail!(lit_str, "{e}"), + } + } else { + bail!(nv, "unknown attribute"); } } - MacroParameters { + Ok(MacroParameters { priority, cleanup, always_run, must_succeed, target_os, - } + }) } fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { @@ -162,15 +182,10 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { Some(priority) => quote! { Some(#priority) }, None => quote! { None }, }; - let target_os = match test_function.macro_parameters.target_os.as_deref() { - Some("linux") => quote! { Some(::test_rpc::meta::Os::Linux) }, - Some("macos") => quote! { Some(::test_rpc::meta::Os::Macos) }, - Some("windows") => quote! { Some(::test_rpc::meta::Os::Windows) }, - Some(target_os) => { - return quote! { - compile_error!("invalid target_os: {:?}", #target_os); - }; - } + let target_os = match test_function.macro_parameters.target_os { + Some(Os::Linux) => quote! { Some(::test_rpc::meta::Os::Linux) }, + Some(Os::Macos) => quote! { Some(::test_rpc::meta::Os::Macos) }, + Some(Os::Windows) => quote! { Some(::test_rpc::meta::Os::Windows) }, None => quote! { None }, }; let should_cleanup = test_function.macro_parameters.cleanup; @@ -193,7 +208,7 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { use std::any::Any; let mullvad_client = mullvad_client.downcast::<#mullvad_client_type>().expect("invalid mullvad client"); Box::pin(async move { - Ok(#func_name(test_context, rpc, *mullvad_client).await?) + #func_name(test_context, rpc, *mullvad_client).await.map_err(Into::into) }) } } @@ -202,9 +217,9 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { quote! { |test_context: crate::tests::TestContext, rpc: test_rpc::ServiceClient, - mullvad_client: Box| { + _mullvad_client: Box| { Box::pin(async move { - Ok(#func_name(test_context, rpc).await?) + #func_name(test_context, rpc).await.map_err(Into::into) }) } } @@ -237,7 +252,7 @@ struct MacroParameters { cleanup: bool, always_run: bool, must_succeed: bool, - target_os: Option, + target_os: Option, } enum MullvadClient { @@ -269,36 +284,38 @@ struct FunctionParameters { } fn get_test_function_parameters( - inputs: &syn::punctuated::Punctuated, -) -> FunctionParameters { - if inputs.len() > 2 { - match inputs[2].clone() { - syn::FnArg::Typed(pat_type) => { - let mullvad_client = match &*pat_type.ty { - syn::Type::Path(syn::TypePath { path, .. }) => { - match path.segments[0].ident.to_string().as_str() { - "mullvad_management_interface" | "MullvadProxyClient" => { - let mullvad_client_version = - quote! { test_rpc::mullvad_daemon::MullvadClientVersion::New }; - MullvadClient::New { - mullvad_client_type: pat_type.ty, - mullvad_client_version, - } - } - _ => panic!("cannot infer mullvad client type"), - } - } - _ => panic!("unexpected 'mullvad_client' type"), - }; - FunctionParameters { mullvad_client } - } - syn::FnArg::Receiver(_) => panic!("unexpected 'mullvad_client' arg"), - } - } else { - FunctionParameters { + args: &syn::punctuated::Punctuated, +) -> Result { + if args.len() <= 2 { + return Ok(FunctionParameters { mullvad_client: MullvadClient::None { - mullvad_client_version: quote! { test_rpc::mullvad_daemon::MullvadClientVersion::None }, + mullvad_client_version: quote! { + test_rpc::mullvad_daemon::MullvadClientVersion::None + }, }, - } + }); } + + let arg = args[2].clone(); + let syn::FnArg::Typed(pat_type) = arg else { + bail!(arg, "unexpected 'mullvad_client' arg"); + }; + + let syn::Type::Path(syn::TypePath { path, .. }) = &*pat_type.ty else { + bail!(pat_type, "unexpected 'mullvad_client' type"); + }; + + let mullvad_client = match path.segments[0].ident.to_string().as_str() { + "mullvad_management_interface" | "MullvadProxyClient" => { + let mullvad_client_version = + quote! { test_rpc::mullvad_daemon::MullvadClientVersion::New }; + MullvadClient::New { + mullvad_client_type: pat_type.ty, + mullvad_client_version, + } + } + _ => bail!(pat_type, "cannot infer mullvad client type"), + }; + + Ok(FunctionParameters { mullvad_client }) } From 611efc096c1aaf4105b840df0423eb7f0db5eb9c Mon Sep 17 00:00:00 2001 From: Joakim Hulthe Date: Tue, 5 Mar 2024 15:24:04 +0100 Subject: [PATCH 05/10] Use curl for split tunnel test --- test/test-manager/src/tests/split_tunnel.rs | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/test/test-manager/src/tests/split_tunnel.rs b/test/test-manager/src/tests/split_tunnel.rs index 4cd48cdce334..522af854863b 100644 --- a/test/test-manager/src/tests/split_tunnel.rs +++ b/test/test-manager/src/tests/split_tunnel.rs @@ -5,6 +5,8 @@ use test_rpc::{ExecResult, ServiceClient}; use super::{helpers, TestContext}; +const AM_I_MULLVAD: &str = "https://am.i.mullvad.net/connected"; + #[test_function] #[cfg(any(target_os = "linux", target_os = "windows"))] pub async fn test_split_tunnel( @@ -14,38 +16,41 @@ pub async fn test_split_tunnel( ) -> anyhow::Result<()> { let mut errored = false; let parse_am_i_mullvad = |result: ExecResult| { - let stdout = str::from_utf8(&result.stdout).expect("am-i-mullvad output is UTF-8"); + let stdout = str::from_utf8(&result.stdout).expect("curl output is UTF-8"); Ok(if stdout.contains("You are connected") { true } else if stdout.contains("You are not connected") { false } else { - anyhow::bail!("Unexpected output from `am-i-mullvad`: {stdout}") + anyhow::bail!("Unexpected output from `curl {AM_I_MULLVAD}`: {stdout}") }) }; helpers::connect_and_wait(&mut mullvad_client).await?; - let i_am_mullvad = parse_am_i_mullvad(rpc.exec("am-i-mullvad", []).await?)?; + let i_am_mullvad = parse_am_i_mullvad(rpc.exec("curl", [AM_I_MULLVAD]).await?)?; if !i_am_mullvad { - log::error!("We should be connected, but `am-i-mullvad` reported that it was not connected to Mullvad."); + log::error!("We should be connected, but `am.i.mullvad` reported that it was not connected to Mullvad."); errored = true; } let i_am_mullvad_while_split = - parse_am_i_mullvad(rpc.exec("mullvad-exclude", ["am-i-mullvad"]).await?)?; + parse_am_i_mullvad(rpc.exec("mullvad-exclude", ["curl", AM_I_MULLVAD]).await?)?; if i_am_mullvad_while_split { - log::error!("`mullvad-exclude am-i-mullvad` reported that it was connected to Mullvad."); + log::error!( + "`mullvad-exclude curl {AM_I_MULLVAD}` reported that it was connected to Mullvad." + ); log::error!("`am-i-mullvad` does not appear to have been split correctly."); errored = true; } helpers::disconnect_and_wait(&mut mullvad_client).await?; - let i_am_mullvad_while_disconnected = parse_am_i_mullvad(rpc.exec("am-i-mullvad", []).await?)?; + let i_am_mullvad_while_disconnected = + parse_am_i_mullvad(rpc.exec("curl", [AM_I_MULLVAD]).await?)?; if i_am_mullvad_while_disconnected { - log::error!("We should be disconnected, but `am-i-mullvad` reported that it was connected to Mullvad."); + log::error!("We should be disconnected, but `curl {AM_I_MULLVAD}` reported that it was connected to Mullvad."); log::error!("Host machine is probably connected to Mullvad. This may affect test results."); errored = true; } From 69d2c506d34cf944812ef65d695bb7c882e764ab Mon Sep 17 00:00:00 2001 From: Joakim Hulthe Date: Wed, 6 Mar 2024 14:00:36 +0100 Subject: [PATCH 06/10] Make OVMF paths configurable per VM --- test/test-manager/src/config.rs | 10 ++++++++++ test/test-manager/src/vm/qemu.rs | 34 ++++++++++++++++++++++++-------- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/test/test-manager/src/config.rs b/test/test-manager/src/config.rs index 1605661d5307..6921c0b33f3b 100644 --- a/test/test-manager/src/config.rs +++ b/test/test-manager/src/config.rs @@ -139,6 +139,16 @@ pub struct VmConfig { #[serde(default)] #[arg(long)] pub tpm: bool, + + /// Override the path to `OVMF_VARS.secboot.fd`. Requires `tpm`. + #[serde(default)] + #[arg(long, requires("tpm"))] + pub ovmf_vars_path: Option, + + /// Override the path to `OVMF_CODE.secboot.fd`. Requires `tpm`. + #[serde(default)] + #[arg(long, requires("tpm"))] + pub ovmf_code_path: Option, } impl VmConfig { diff --git a/test/test-manager/src/vm/qemu.rs b/test/test-manager/src/vm/qemu.rs index 5688f47101c9..62613d5e1daa 100644 --- a/test/test-manager/src/vm/qemu.rs +++ b/test/test-manager/src/vm/qemu.rs @@ -134,7 +134,7 @@ pub async fn run(config: &Config, vm_config: &VmConfig) -> Result // Configure OVMF. Currently, this is enabled implicitly if using a TPM let ovmf_handle = if vm_config.tpm { - let handle = OvmfHandle::new().await?; + let handle = OvmfHandle::new(vm_config).await?; handle.append_qemu_args(&mut qemu_cmd); Some(handle) } else { @@ -202,32 +202,50 @@ pub async fn run(config: &Config, vm_config: &VmConfig) -> Result /// Used to set up UEFI and append options to the QEMU command struct OvmfHandle { temp_vars: TempFile, + ovmf_code_path: String, } impl OvmfHandle { - pub async fn new() -> Result { - const OVMF_VARS_PATH: &str = "/usr/share/OVMF/OVMF_VARS.secboot.fd"; + pub async fn new(config: &VmConfig) -> Result { + const DEFAULT_OVMF_VARS_PATH: &str = "/usr/share/OVMF/OVMF_VARS.secboot.fd"; + const DEFAULT_OVMF_CODE_PATH: &str = "/usr/share/OVMF/OVMF_CODE.secboot.fd"; + + let ovmf_code_path = config + .ovmf_code_path + .as_deref() + .unwrap_or(DEFAULT_OVMF_CODE_PATH) + .to_owned(); + + let ovmf_vars_path = config + .ovmf_vars_path + .as_deref() + .unwrap_or(DEFAULT_OVMF_VARS_PATH); // Create a local copy of OVMF_VARS let temp_vars_path = random_tempfile_name(); - fs::copy(OVMF_VARS_PATH, &temp_vars_path) + fs::copy(ovmf_vars_path, &temp_vars_path) .await .map_err(Error::CopyOvmfVars)?; let temp_vars = TempFile::from_existing(temp_vars_path, async_tempfile::Ownership::Owned) .await .map_err(|_| Error::WrapOvmfVars)?; - Ok(OvmfHandle { temp_vars }) + + Ok(OvmfHandle { + temp_vars, + ovmf_code_path, + }) } pub fn append_qemu_args(&self, qemu_cmd: &mut Command) { - const OVMF_CODE_PATH: &str = "/usr/share/OVMF/OVMF_CODE.secboot.fd"; - qemu_cmd.args([ "-global", "driver=cfi.pflash01,property=secure,value=on", "-drive", - &format!("if=pflash,format=raw,unit=0,file={OVMF_CODE_PATH},readonly=on"), + &format!( + "if=pflash,format=raw,unit=0,file={},readonly=on", + self.ovmf_code_path + ), "-drive", &format!( "if=pflash,format=raw,unit=1,file={}", From 17814ac1ecd589aaaff66be816bf375c8f707b09 Mon Sep 17 00:00:00 2001 From: Joakim Hulthe Date: Wed, 6 Mar 2024 14:02:55 +0100 Subject: [PATCH 07/10] Add am-i-mullvad cli for testing split tunneling --- test/Cargo.lock | 168 ++++++++++++++++++++++++++++- test/Cargo.toml | 8 +- test/am-i-mullvad/Cargo.toml | 17 +++ test/am-i-mullvad/src/main.rs | 33 ++++++ test/build.sh | 5 +- test/scripts/build-runner-image.sh | 1 + 6 files changed, 225 insertions(+), 7 deletions(-) create mode 100644 test/am-i-mullvad/Cargo.toml create mode 100644 test/am-i-mullvad/src/main.rs diff --git a/test/Cargo.lock b/test/Cargo.lock index 60ab48547b9e..09ec93e782c4 100644 --- a/test/Cargo.lock +++ b/test/Cargo.lock @@ -61,6 +61,16 @@ dependencies = [ "memchr", ] +[[package]] +name = "am-i-mullvad" +version = "0.0.0" +dependencies = [ + "color-eyre", + "eyre", + "reqwest", + "serde", +] + [[package]] name = "android-tzdata" version = "0.1.1" @@ -463,6 +473,33 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961" +[[package]] +name = "color-eyre" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a667583cca8c4f8436db8de46ea8233c42a7d9ae424a82d338f2e4675229204" +dependencies = [ + "backtrace", + "color-spantrace", + "eyre", + "indenter", + "once_cell", + "owo-colors", + "tracing-error", +] + +[[package]] +name = "color-spantrace" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd6be1b2a7e382e2b98b43b2adcca6bb0e465af0bdd38123873ae61eb17a72c2" +dependencies = [ + "once_cell", + "owo-colors", + "tracing-core", + "tracing-error", +] + [[package]] name = "colorchoice" version = "1.0.0" @@ -745,6 +782,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "encoding_rs" +version = "0.8.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7268b386296a025e474d5140678f75d6de9493ae55a5d709eeb9dd08149945e1" +dependencies = [ + "cfg-if", +] + [[package]] name = "enum-as-inner" version = "0.6.0" @@ -821,6 +867,16 @@ dependencies = [ "libc", ] +[[package]] +name = "eyre" +version = "0.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cd915d99f24784cdc19fd37ef22b97e3ff0ae756c7e492e9fbfe897d61e2aec" +dependencies = [ + "indenter", + "once_cell", +] + [[package]] name = "fast-socks5" version = "0.9.5" @@ -1197,7 +1253,7 @@ dependencies = [ "rustls-native-certs", "tokio", "tokio-rustls", - "webpki-roots", + "webpki-roots 0.23.1", ] [[package]] @@ -1245,6 +1301,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "indenter" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" + [[package]] name = "indexmap" version = "1.9.3" @@ -1943,6 +2005,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "owo-colors" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" + [[package]] name = "p256" version = "0.11.1" @@ -2464,6 +2532,47 @@ version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" +[[package]] +name = "reqwest" +version = "0.11.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251" +dependencies = [ + "base64 0.21.4", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "hyper", + "hyper-rustls", + "ipnet", + "js-sys", + "log", + "mime", + "once_cell", + "percent-encoding", + "pin-project-lite", + "rustls", + "rustls-pemfile 1.0.3", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "system-configuration", + "tokio", + "tokio-rustls", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "webpki-roots 0.25.4", + "winreg", +] + [[package]] name = "resolv-conf" version = "0.7.0" @@ -2710,18 +2819,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.188" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" +checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.188" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" +checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", @@ -3006,6 +3115,27 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "talpid-dbus" version = "0.0.0" @@ -3553,6 +3683,16 @@ dependencies = [ "valuable", ] +[[package]] +name = "tracing-error" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d686ec1c0f384b1277f097b2f279a2ecc11afe8c133c1aabf036a27cb4cd206e" +dependencies = [ + "tracing", + "tracing-subscriber", +] + [[package]] name = "tracing-opentelemetry" version = "0.17.4" @@ -3793,6 +3933,18 @@ dependencies = [ "wasm-bindgen-shared", ] +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "wasm-bindgen-macro" version = "0.2.87" @@ -3841,6 +3993,12 @@ dependencies = [ "rustls-webpki 0.100.3", ] +[[package]] +name = "webpki-roots" +version = "0.25.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" + [[package]] name = "which" version = "4.4.2" diff --git a/test/Cargo.toml b/test/Cargo.toml index 977f9082d82b..4c23a55eb047 100644 --- a/test/Cargo.toml +++ b/test/Cargo.toml @@ -7,7 +7,13 @@ rust-version = "1.75.0" [workspace] resolver = "2" -members = ["test-manager", "test-runner", "test-rpc", "socks-server"] +members = [ + "test-manager", + "test-runner", + "test-rpc", + "socks-server", + "am-i-mullvad", +] [workspace.lints.rust] rust_2018_idioms = "deny" diff --git a/test/am-i-mullvad/Cargo.toml b/test/am-i-mullvad/Cargo.toml new file mode 100644 index 000000000000..c3bda1b1cbc7 --- /dev/null +++ b/test/am-i-mullvad/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "am-i-mullvad" +description = "Simple cli for testing Mullvad VPN connections" +authors.workspace = true +repository.workspace = true +license.workspace = true +edition.workspace = true +rust-version.workspace = true + +[lints] +workspace = true + +[dependencies] +color-eyre = "0.6.2" +eyre = "0.6.12" +reqwest = { version = "0.11.24", default-features = false, features = ["blocking", "rustls-tls", "json"] } +serde = { version = "1.0.197", features = ["derive"] } diff --git a/test/am-i-mullvad/src/main.rs b/test/am-i-mullvad/src/main.rs new file mode 100644 index 000000000000..c6cc272d308a --- /dev/null +++ b/test/am-i-mullvad/src/main.rs @@ -0,0 +1,33 @@ +use eyre::{eyre, Context}; +use reqwest::blocking::get; +use serde::Deserialize; +use std::process; + +#[derive(Debug, Deserialize)] +struct Response { + ip: String, + mullvad_exit_ip_hostname: Option, +} + +fn main() -> eyre::Result<()> { + color_eyre::install()?; + + let url = "https://am.i.mullvad.net/json"; + let response: Response = get(url) + .and_then(|r| r.json()) + .wrap_err_with(|| eyre!("Failed to GET {url}"))?; + + if let Some(server) = &response.mullvad_exit_ip_hostname { + println!( + "You are connected to Mullvad (server {}). Your IP address is {}", + server, response.ip + ); + Ok(()) + } else { + println!( + "You are not connected to Mullvad. Your IP address is {}", + response.ip + ); + process::exit(1) + } +} diff --git a/test/build.sh b/test/build.sh index 2a8f7c7063fa..1f0099ccf47c 100755 --- a/test/build.sh +++ b/test/build.sh @@ -19,7 +19,10 @@ if [[ $TARGET == x86_64-unknown-linux-gnu ]]; then mullvadvpn-app-tests \ /bin/bash -c "cd /src/test/; cargo build --bin test-runner --release --target ${TARGET}" else - cargo build --bin test-runner --release --target "${TARGET}" + cargo build \ + --bin test-runner \ + --bin am-i-mullvad \ + --release --target "${TARGET}" fi # Only build runner image for Windows diff --git a/test/scripts/build-runner-image.sh b/test/scripts/build-runner-image.sh index fe8077b33777..be0d6373234a 100755 --- a/test/scripts/build-runner-image.sh +++ b/test/scripts/build-runner-image.sh @@ -33,6 +33,7 @@ case $TARGET in mcopy \ -i "${TEST_RUNNER_IMAGE_PATH}" \ "${SCRIPT_DIR}/../target/$TARGET/release/test-runner.exe" \ + "${SCRIPT_DIR}/../target/$TARGET/release/am-i-mullvad.exe" \ "${PACKAGES_DIR}/"*.exe \ "${SCRIPT_DIR}/../openvpn.ca.crt" \ "::" From 8354142fd8c15680b62fec185eb9492d6dad20f8 Mon Sep 17 00:00:00 2001 From: Joakim Hulthe Date: Wed, 6 Mar 2024 14:05:11 +0100 Subject: [PATCH 08/10] Add split tunnel test for windows --- mullvad-management-interface/src/client.rs | 9 +- test/test-manager/src/tests/split_tunnel.rs | 121 ++++++++++++++++---- 2 files changed, 100 insertions(+), 30 deletions(-) diff --git a/mullvad-management-interface/src/client.rs b/mullvad-management-interface/src/client.rs index f30b61317146..0020ca696989 100644 --- a/mullvad-management-interface/src/client.rs +++ b/mullvad-management-interface/src/client.rs @@ -16,7 +16,6 @@ use mullvad_types::{ version::AppVersionInfo, wireguard::{PublicKey, QuantumResistantState, RotationInterval}, }; -#[cfg(target_os = "windows")] use std::path::Path; use std::str::FromStr; #[cfg(target_os = "windows")] @@ -632,7 +631,7 @@ impl MullvadProxyClient { Ok(()) } - #[cfg(target_os = "windows")] + //#[cfg(target_os = "windows")] pub async fn add_split_tunnel_app>(&mut self, path: P) -> Result<()> { let path = path.as_ref().to_str().ok_or(Error::PathMustBeUtf8)?; self.0 @@ -642,7 +641,7 @@ impl MullvadProxyClient { Ok(()) } - #[cfg(target_os = "windows")] + //#[cfg(target_os = "windows")] pub async fn remove_split_tunnel_app>(&mut self, path: P) -> Result<()> { let path = path.as_ref().to_str().ok_or(Error::PathMustBeUtf8)?; self.0 @@ -652,7 +651,7 @@ impl MullvadProxyClient { Ok(()) } - #[cfg(target_os = "windows")] + //#[cfg(target_os = "windows")] pub async fn clear_split_tunnel_apps(&mut self) -> Result<()> { self.0 .clear_split_tunnel_apps(()) @@ -661,7 +660,7 @@ impl MullvadProxyClient { Ok(()) } - #[cfg(target_os = "windows")] + //#[cfg(target_os = "windows")] pub async fn set_split_tunnel_state(&mut self, state: bool) -> Result<()> { self.0 .set_split_tunnel_state(state) diff --git a/test/test-manager/src/tests/split_tunnel.rs b/test/test-manager/src/tests/split_tunnel.rs index 522af854863b..cdb96a5dd2c1 100644 --- a/test/test-manager/src/tests/split_tunnel.rs +++ b/test/test-manager/src/tests/split_tunnel.rs @@ -5,52 +5,110 @@ use test_rpc::{ExecResult, ServiceClient}; use super::{helpers, TestContext}; -const AM_I_MULLVAD: &str = "https://am.i.mullvad.net/connected"; - -#[test_function] -#[cfg(any(target_os = "linux", target_os = "windows"))] -pub async fn test_split_tunnel( +#[test_function(target_os = "windows")] +pub async fn test_split_tunnel_windows( _: TestContext, rpc: ServiceClient, mut mullvad_client: MullvadProxyClient, ) -> anyhow::Result<()> { + const AM_I_MULLVAD_EXE: &str = "E:\\am-i-mullvad.exe"; + + async fn am_i_mullvad(rpc: &ServiceClient) -> anyhow::Result { + parse_am_i_mullvad(rpc.exec(AM_I_MULLVAD_EXE, []).await?) + } + let mut errored = false; - let parse_am_i_mullvad = |result: ExecResult| { - let stdout = str::from_utf8(&result.stdout).expect("curl output is UTF-8"); - Ok(if stdout.contains("You are connected") { - true - } else if stdout.contains("You are not connected") { - false + helpers::disconnect_and_wait(&mut mullvad_client).await?; + + if am_i_mullvad(&rpc).await? { + log::error!("We should be disconnected, but `{AM_I_MULLVAD_EXE}` reported that it was connected to Mullvad."); + log::error!("Host machine is probably connected to Mullvad, this will throw off results"); + errored = true + } + + helpers::connect_and_wait(&mut mullvad_client).await?; + + if !am_i_mullvad(&rpc).await? { + log::error!( + "We should be connected, but `{AM_I_MULLVAD_EXE}` reported no connection to Mullvad." + ); + errored = true + } + + mullvad_client + .add_split_tunnel_app(AM_I_MULLVAD_EXE) + .await?; + mullvad_client.set_split_tunnel_state(true).await?; + + if am_i_mullvad(&rpc).await? { + log::error!( + "`{AM_I_MULLVAD_EXE}` should have been split, but it reported a connection to Mullvad" + ); + errored = true + } + + helpers::disconnect_and_wait(&mut mullvad_client).await?; + + if am_i_mullvad(&rpc).await? { + log::error!( + "`{AM_I_MULLVAD_EXE}` reported a connection to Mullvad while split and disconnected" + ); + errored = true + } + + mullvad_client.set_split_tunnel_state(false).await?; + mullvad_client + .remove_split_tunnel_app(AM_I_MULLVAD_EXE) + .await?; + + if errored { + anyhow::bail!("test_split_tunnel failed, see log output for details."); + } + + Ok(()) +} + +#[test_function(target_os = "linux")] +pub async fn test_split_tunnel_linux( + _: TestContext, + rpc: ServiceClient, + mut mullvad_client: MullvadProxyClient, +) -> anyhow::Result<()> { + const AM_I_MULLVAD_URL: &str = "https://am.i.mullvad.net/connected"; + + async fn am_i_mullvad(rpc: &ServiceClient, split_tunnel: bool) -> anyhow::Result { + let result = if split_tunnel { + rpc.exec("mullvad-exclude", ["curl", AM_I_MULLVAD_URL]) + .await? } else { - anyhow::bail!("Unexpected output from `curl {AM_I_MULLVAD}`: {stdout}") - }) - }; + rpc.exec("curl", [AM_I_MULLVAD_URL]).await? + }; + + parse_am_i_mullvad(result) + } + + let mut errored = false; helpers::connect_and_wait(&mut mullvad_client).await?; - let i_am_mullvad = parse_am_i_mullvad(rpc.exec("curl", [AM_I_MULLVAD]).await?)?; - if !i_am_mullvad { + if !am_i_mullvad(&rpc, false).await? { log::error!("We should be connected, but `am.i.mullvad` reported that it was not connected to Mullvad."); errored = true; } - let i_am_mullvad_while_split = - parse_am_i_mullvad(rpc.exec("mullvad-exclude", ["curl", AM_I_MULLVAD]).await?)?; - if i_am_mullvad_while_split { + if am_i_mullvad(&rpc, true).await? { log::error!( - "`mullvad-exclude curl {AM_I_MULLVAD}` reported that it was connected to Mullvad." + "`mullvad-exclude curl {AM_I_MULLVAD_URL}` reported that it was connected to Mullvad." ); - log::error!("`am-i-mullvad` does not appear to have been split correctly."); + log::error!("`curl` does not appear to have been split correctly."); errored = true; } helpers::disconnect_and_wait(&mut mullvad_client).await?; - let i_am_mullvad_while_disconnected = - parse_am_i_mullvad(rpc.exec("curl", [AM_I_MULLVAD]).await?)?; - if i_am_mullvad_while_disconnected { - log::error!("We should be disconnected, but `curl {AM_I_MULLVAD}` reported that it was connected to Mullvad."); + if am_i_mullvad(&rpc, false).await? { + log::error!("We should be disconnected, but `curl {AM_I_MULLVAD_URL}` reported that it was connected to Mullvad."); log::error!("Host machine is probably connected to Mullvad. This may affect test results."); errored = true; } @@ -61,3 +119,16 @@ pub async fn test_split_tunnel( Ok(()) } + +/// Parse output from am-i-mullvad. Returns true if connected to Mullvad. +fn parse_am_i_mullvad(result: ExecResult) -> anyhow::Result { + let stdout = str::from_utf8(&result.stdout).expect("curl output is UTF-8"); + + Ok(if stdout.contains("You are connected") { + true + } else if stdout.contains("You are not connected") { + false + } else { + anyhow::bail!("Unexpected output from am-i-mullvad: {stdout:?}") + }) +} From e5487658fb6065f5f41acdc7970235b64641b6bd Mon Sep 17 00:00:00 2001 From: Joakim Hulthe Date: Wed, 6 Mar 2024 15:48:19 +0100 Subject: [PATCH 09/10] Group platform split tunnel tests under one test --- test/test-manager/src/tests/split_tunnel.rs | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/test/test-manager/src/tests/split_tunnel.rs b/test/test-manager/src/tests/split_tunnel.rs index cdb96a5dd2c1..9902dec231dd 100644 --- a/test/test-manager/src/tests/split_tunnel.rs +++ b/test/test-manager/src/tests/split_tunnel.rs @@ -1,11 +1,23 @@ use mullvad_management_interface::MullvadProxyClient; use std::str; use test_macro::test_function; -use test_rpc::{ExecResult, ServiceClient}; +use test_rpc::{meta::Os, ExecResult, ServiceClient}; -use super::{helpers, TestContext}; +use super::{config::TEST_CONFIG, helpers, TestContext}; + +#[test_function] +pub async fn test_split_tunnel( + ctx: TestContext, + rpc: ServiceClient, + mullvad_client: MullvadProxyClient, +) -> anyhow::Result<()> { + match TEST_CONFIG.os { + Os::Linux => test_split_tunnel_linux(ctx, rpc, mullvad_client).await, + Os::Windows => test_split_tunnel_windows(ctx, rpc, mullvad_client).await, + Os::Macos => todo!("MacOS"), + } +} -#[test_function(target_os = "windows")] pub async fn test_split_tunnel_windows( _: TestContext, rpc: ServiceClient, @@ -69,7 +81,6 @@ pub async fn test_split_tunnel_windows( Ok(()) } -#[test_function(target_os = "linux")] pub async fn test_split_tunnel_linux( _: TestContext, rpc: ServiceClient, From 6594c6de52763ab313d99edecf9231596a003e1f Mon Sep 17 00:00:00 2001 From: Joakim Hulthe Date: Tue, 12 Mar 2024 11:21:19 +0100 Subject: [PATCH 10/10] Test leaking TCP/UDP/ICMP packets in split tunnel --- mullvad-management-interface/src/client.rs | 8 - test/Cargo.lock | 35 +- test/Cargo.toml | 2 +- test/am-i-mullvad/src/main.rs | 33 -- test/build.sh | 4 +- .../Cargo.toml | 5 +- test/connection-checker/src/cli.rs | 36 ++ test/connection-checker/src/lib.rs | 2 + test/connection-checker/src/main.rs | 73 ++++ test/connection-checker/src/net.rs | 78 ++++ test/scripts/build-runner-image.sh | 2 +- test/scripts/ssh-setup.sh | 2 +- test/test-manager/src/tests/split_tunnel.rs | 402 +++++++++++++----- test/test-manager/src/tests/test_metadata.rs | 6 +- test/test-manager/src/vm/provision.rs | 5 + test/test-manager/test_macro/src/lib.rs | 36 +- test/test-rpc/src/client.rs | 22 + test/test-rpc/src/lib.rs | 47 ++ test/test-runner/Cargo.toml | 2 +- test/test-runner/src/main.rs | 222 +++++++++- test/test-runner/src/util.rs | 23 + 21 files changed, 864 insertions(+), 181 deletions(-) delete mode 100644 test/am-i-mullvad/src/main.rs rename test/{am-i-mullvad => connection-checker}/Cargo.toml (75%) create mode 100644 test/connection-checker/src/cli.rs create mode 100644 test/connection-checker/src/lib.rs create mode 100644 test/connection-checker/src/main.rs create mode 100644 test/connection-checker/src/net.rs create mode 100644 test/test-runner/src/util.rs diff --git a/mullvad-management-interface/src/client.rs b/mullvad-management-interface/src/client.rs index 0020ca696989..847150c47731 100644 --- a/mullvad-management-interface/src/client.rs +++ b/mullvad-management-interface/src/client.rs @@ -591,7 +591,6 @@ impl MullvadProxyClient { .map(drop) } - #[cfg(target_os = "linux")] pub async fn get_split_tunnel_processes(&mut self) -> Result> { use futures::TryStreamExt; @@ -604,7 +603,6 @@ impl MullvadProxyClient { procs.try_collect().await.map_err(Error::Rpc) } - #[cfg(target_os = "linux")] pub async fn add_split_tunnel_process(&mut self, pid: i32) -> Result<()> { self.0 .add_split_tunnel_process(pid) @@ -613,7 +611,6 @@ impl MullvadProxyClient { Ok(()) } - #[cfg(target_os = "linux")] pub async fn remove_split_tunnel_process(&mut self, pid: i32) -> Result<()> { self.0 .remove_split_tunnel_process(pid) @@ -622,7 +619,6 @@ impl MullvadProxyClient { Ok(()) } - #[cfg(target_os = "linux")] pub async fn clear_split_tunnel_processes(&mut self) -> Result<()> { self.0 .clear_split_tunnel_processes(()) @@ -631,7 +627,6 @@ impl MullvadProxyClient { Ok(()) } - //#[cfg(target_os = "windows")] pub async fn add_split_tunnel_app>(&mut self, path: P) -> Result<()> { let path = path.as_ref().to_str().ok_or(Error::PathMustBeUtf8)?; self.0 @@ -641,7 +636,6 @@ impl MullvadProxyClient { Ok(()) } - //#[cfg(target_os = "windows")] pub async fn remove_split_tunnel_app>(&mut self, path: P) -> Result<()> { let path = path.as_ref().to_str().ok_or(Error::PathMustBeUtf8)?; self.0 @@ -651,7 +645,6 @@ impl MullvadProxyClient { Ok(()) } - //#[cfg(target_os = "windows")] pub async fn clear_split_tunnel_apps(&mut self) -> Result<()> { self.0 .clear_split_tunnel_apps(()) @@ -660,7 +653,6 @@ impl MullvadProxyClient { Ok(()) } - //#[cfg(target_os = "windows")] pub async fn set_split_tunnel_state(&mut self, state: bool) -> Result<()> { self.0 .set_split_tunnel_state(state) diff --git a/test/Cargo.lock b/test/Cargo.lock index 09ec93e782c4..5a7771fb4701 100644 --- a/test/Cargo.lock +++ b/test/Cargo.lock @@ -61,16 +61,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "am-i-mullvad" -version = "0.0.0" -dependencies = [ - "color-eyre", - "eyre", - "reqwest", - "serde", -] - [[package]] name = "android-tzdata" version = "0.1.1" @@ -527,6 +517,19 @@ dependencies = [ "memchr", ] +[[package]] +name = "connection-checker" +version = "0.0.0" +dependencies = [ + "clap", + "color-eyre", + "eyre", + "ping", + "reqwest", + "serde", + "socket2 0.5.4", +] + [[package]] name = "const-oid" version = "0.9.5" @@ -2157,6 +2160,17 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "ping" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "122ee1f5a6843bec84fcbd5c6ba3622115337a6b8965b93a61aad347648f4e8d" +dependencies = [ + "rand 0.8.5", + "socket2 0.4.9", + "thiserror", +] + [[package]] name = "pkcs8" version = "0.9.0" @@ -3172,6 +3186,7 @@ dependencies = [ "base64 0.13.1", "ipnetwork 0.16.0", "jnix", + "log", "serde", "thiserror", "x25519-dalek", diff --git a/test/Cargo.toml b/test/Cargo.toml index 4c23a55eb047..0fd7b4a2e97e 100644 --- a/test/Cargo.toml +++ b/test/Cargo.toml @@ -12,7 +12,7 @@ members = [ "test-runner", "test-rpc", "socks-server", - "am-i-mullvad", + "connection-checker", ] [workspace.lints.rust] diff --git a/test/am-i-mullvad/src/main.rs b/test/am-i-mullvad/src/main.rs deleted file mode 100644 index c6cc272d308a..000000000000 --- a/test/am-i-mullvad/src/main.rs +++ /dev/null @@ -1,33 +0,0 @@ -use eyre::{eyre, Context}; -use reqwest::blocking::get; -use serde::Deserialize; -use std::process; - -#[derive(Debug, Deserialize)] -struct Response { - ip: String, - mullvad_exit_ip_hostname: Option, -} - -fn main() -> eyre::Result<()> { - color_eyre::install()?; - - let url = "https://am.i.mullvad.net/json"; - let response: Response = get(url) - .and_then(|r| r.json()) - .wrap_err_with(|| eyre!("Failed to GET {url}"))?; - - if let Some(server) = &response.mullvad_exit_ip_hostname { - println!( - "You are connected to Mullvad (server {}). Your IP address is {}", - server, response.ip - ); - Ok(()) - } else { - println!( - "You are not connected to Mullvad. Your IP address is {}", - response.ip - ); - process::exit(1) - } -} diff --git a/test/build.sh b/test/build.sh index 1f0099ccf47c..d3a3c174704e 100755 --- a/test/build.sh +++ b/test/build.sh @@ -17,11 +17,11 @@ if [[ $TARGET == x86_64-unknown-linux-gnu ]]; then -e CARGO_HOME=/root/.cargo/registry \ -e CARGO_TARGET_DIR=/src/test/target \ mullvadvpn-app-tests \ - /bin/bash -c "cd /src/test/; cargo build --bin test-runner --release --target ${TARGET}" + /bin/bash -c "cd /src/test/; cargo build --bin test-runner --bin connection-checker --release --target ${TARGET}" else cargo build \ --bin test-runner \ - --bin am-i-mullvad \ + --bin connection-checker \ --release --target "${TARGET}" fi diff --git a/test/am-i-mullvad/Cargo.toml b/test/connection-checker/Cargo.toml similarity index 75% rename from test/am-i-mullvad/Cargo.toml rename to test/connection-checker/Cargo.toml index c3bda1b1cbc7..d579510bd1e8 100644 --- a/test/am-i-mullvad/Cargo.toml +++ b/test/connection-checker/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "am-i-mullvad" +name = "connection-checker" description = "Simple cli for testing Mullvad VPN connections" authors.workspace = true repository.workspace = true @@ -11,7 +11,10 @@ rust-version.workspace = true workspace = true [dependencies] +clap = { workspace = true, features = ["derive"] } color-eyre = "0.6.2" eyre = "0.6.12" +ping = "0.5.2" reqwest = { version = "0.11.24", default-features = false, features = ["blocking", "rustls-tls", "json"] } serde = { version = "1.0.197", features = ["derive"] } +socket2 = { version = "0.5.4", features = ["all"] } diff --git a/test/connection-checker/src/cli.rs b/test/connection-checker/src/cli.rs new file mode 100644 index 000000000000..dddb348b255c --- /dev/null +++ b/test/connection-checker/src/cli.rs @@ -0,0 +1,36 @@ +use std::net::SocketAddr; + +use clap::Parser; + +/// CLI tool that queries to check if the machine is connected to +/// Mullvad VPN. +#[derive(Parser)] +pub struct Opt { + /// Interactive mode, press enter to check if you are Mullvad. + #[clap(short, long)] + pub interactive: bool, + + /// Timeout for network connection to am.i.mullvad (in millis). + #[clap(short, long, default_value = "3000")] + pub timeout: u64, + + /// Try to send some junk data over TCP to . + #[clap(long, requires = "leak")] + pub leak_tcp: bool, + + /// Try to send some junk data over UDP to . + #[clap(long, requires = "leak")] + pub leak_udp: bool, + + /// Try to send ICMP request to . + #[clap(long, requires = "leak")] + pub leak_icmp: bool, + + /// Target of , or . + #[clap(long)] + pub leak: Option, + + /// Timeout for leak check network connections (in millis). + #[clap(long, default_value = "1000")] + pub leak_timeout: u64, +} diff --git a/test/connection-checker/src/lib.rs b/test/connection-checker/src/lib.rs new file mode 100644 index 000000000000..cb36c236b0be --- /dev/null +++ b/test/connection-checker/src/lib.rs @@ -0,0 +1,2 @@ +pub mod cli; +pub mod net; diff --git a/test/connection-checker/src/main.rs b/test/connection-checker/src/main.rs new file mode 100644 index 000000000000..ed48999970ce --- /dev/null +++ b/test/connection-checker/src/main.rs @@ -0,0 +1,73 @@ +use clap::Parser; +use eyre::{eyre, Context}; +use reqwest::blocking::Client; +use serde::Deserialize; +use std::{io::stdin, time::Duration}; + +use connection_checker::cli::Opt; +use connection_checker::net::{send_ping, send_tcp, send_udp}; + +fn main() -> eyre::Result<()> { + let opt = Opt::parse(); + color_eyre::install()?; + + if opt.interactive { + let stdin = stdin(); + for line in stdin.lines() { + let _ = line.wrap_err("Failed to read from stdin")?; + test_connection(&opt)?; + } + } else { + test_connection(&opt)?; + } + + Ok(()) +} + +fn test_connection(opt: &Opt) -> eyre::Result { + if let Some(destination) = opt.leak { + if opt.leak_tcp { + let _ = send_tcp(opt, destination); + } + if opt.leak_udp { + let _ = send_udp(opt, destination); + } + if opt.leak_icmp { + let _ = send_ping(opt, destination.ip()); + } + } + am_i_mullvad(opt) +} + +/// Check if connected to Mullvad and print the result to stdout +fn am_i_mullvad(opt: &Opt) -> eyre::Result { + #[derive(Debug, Deserialize)] + struct Response { + ip: String, + mullvad_exit_ip_hostname: Option, + } + + let url = "https://am.i.mullvad.net/json"; + + let client = Client::new(); + let response: Response = client + .get(url) + .timeout(Duration::from_millis(opt.timeout)) + .send() + .and_then(|r| r.json()) + .wrap_err_with(|| eyre!("Failed to GET {url}"))?; + + if let Some(server) = &response.mullvad_exit_ip_hostname { + println!( + "You are connected to Mullvad (server {}). Your IP address is {}", + server, response.ip + ); + Ok(true) + } else { + println!( + "You are not connected to Mullvad. Your IP address is {}", + response.ip + ); + Ok(false) + } +} diff --git a/test/connection-checker/src/net.rs b/test/connection-checker/src/net.rs new file mode 100644 index 000000000000..6634be41b0c8 --- /dev/null +++ b/test/connection-checker/src/net.rs @@ -0,0 +1,78 @@ +use eyre::{eyre, Context}; +use std::{ + io::Write, + net::{IpAddr, Ipv4Addr, SocketAddr}, + time::Duration, +}; + +use crate::cli::Opt; + +pub fn send_tcp(opt: &Opt, destination: SocketAddr) -> eyre::Result<()> { + let bind_addr: SocketAddr = SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0); + + let family = match &destination { + SocketAddr::V4(_) => socket2::Domain::IPV4, + SocketAddr::V6(_) => socket2::Domain::IPV6, + }; + let sock = socket2::Socket::new(family, socket2::Type::STREAM, Some(socket2::Protocol::TCP)) + .wrap_err(eyre!("Failed to create TCP socket"))?; + + eprintln!("Leaking TCP packets to {destination}"); + + sock.bind(&socket2::SockAddr::from(bind_addr)) + .wrap_err(eyre!("Failed to bind TCP socket to {bind_addr}"))?; + + let timeout = Duration::from_millis(opt.leak_timeout); + sock.set_write_timeout(Some(timeout))?; + sock.set_read_timeout(Some(timeout))?; + + sock.connect_timeout(&socket2::SockAddr::from(destination), timeout) + .wrap_err(eyre!("Failed to connect to {destination}"))?; + + let mut stream = std::net::TcpStream::from(sock); + stream + .write_all(b"hello there") + .wrap_err(eyre!("Failed to send message to {destination}"))?; + + Ok(()) +} + +pub fn send_udp(_opt: &Opt, destination: SocketAddr) -> Result<(), eyre::Error> { + let bind_addr: SocketAddr = SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0); + + eprintln!("Leaking UDP packets to {destination}"); + + let family = match &destination { + SocketAddr::V4(_) => socket2::Domain::IPV4, + SocketAddr::V6(_) => socket2::Domain::IPV6, + }; + let sock = socket2::Socket::new(family, socket2::Type::DGRAM, Some(socket2::Protocol::UDP)) + .wrap_err("Failed to create UDP socket")?; + + sock.bind(&socket2::SockAddr::from(bind_addr)) + .wrap_err(eyre!("Failed to bind UDP socket to {bind_addr}"))?; + + //log::debug!("Send message from {bind_addr} to {destination}/UDP"); + + let std_socket = std::net::UdpSocket::from(sock); + std_socket + .send_to(b"Hello there!", destination) + .wrap_err(eyre!("Failed to send message to {destination}"))?; + + Ok(()) +} + +pub fn send_ping(opt: &Opt, destination: IpAddr) -> eyre::Result<()> { + eprintln!("Leaking IMCP packets to {destination}"); + + ping::ping( + destination, + Some(Duration::from_millis(opt.leak_timeout)), + None, + None, + None, + None, + )?; + + Ok(()) +} diff --git a/test/scripts/build-runner-image.sh b/test/scripts/build-runner-image.sh index be0d6373234a..30252d844510 100755 --- a/test/scripts/build-runner-image.sh +++ b/test/scripts/build-runner-image.sh @@ -33,7 +33,7 @@ case $TARGET in mcopy \ -i "${TEST_RUNNER_IMAGE_PATH}" \ "${SCRIPT_DIR}/../target/$TARGET/release/test-runner.exe" \ - "${SCRIPT_DIR}/../target/$TARGET/release/am-i-mullvad.exe" \ + "${SCRIPT_DIR}/../target/$TARGET/release/connection-checker.exe" \ "${PACKAGES_DIR}/"*.exe \ "${SCRIPT_DIR}/../openvpn.ca.crt" \ "::" diff --git a/test/scripts/ssh-setup.sh b/test/scripts/ssh-setup.sh index a3809e023036..b3d358f5a013 100644 --- a/test/scripts/ssh-setup.sh +++ b/test/scripts/ssh-setup.sh @@ -16,7 +16,7 @@ echo "Copying test-runner to $RUNNER_DIR" mkdir -p "$RUNNER_DIR" -for file in test-runner $CURRENT_APP $PREVIOUS_APP $UI_RUNNER openvpn.ca.crt; do +for file in test-runner connection-checker $CURRENT_APP $PREVIOUS_APP $UI_RUNNER openvpn.ca.crt; do echo "Moving $file to $RUNNER_DIR" cp -f "$SCRIPT_DIR/$file" "$RUNNER_DIR" done diff --git a/test/test-manager/src/tests/split_tunnel.rs b/test/test-manager/src/tests/split_tunnel.rs index 9902dec231dd..336ee5b5ab33 100644 --- a/test/test-manager/src/tests/split_tunnel.rs +++ b/test/test-manager/src/tests/split_tunnel.rs @@ -1,145 +1,357 @@ +use anyhow::{anyhow, bail, ensure, Context}; use mullvad_management_interface::MullvadProxyClient; -use std::str; +use pcap::Direction; +use pnet_packet::ip::IpNextHeaderProtocols; +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + str, + time::Duration, +}; use test_macro::test_function; -use test_rpc::{meta::Os, ExecResult, ServiceClient}; +use test_rpc::{meta::Os, ServiceClient, SpawnOpts}; +use tokio::time::{sleep, timeout}; + +use crate::network_monitor::{start_packet_monitor, MonitorOptions}; use super::{config::TEST_CONFIG, helpers, TestContext}; -#[test_function] +const CHECKER_FILENAME_WINDOWS: &str = "connection-checker.exe"; +const CHECKER_FILENAME_UNIX: &str = "connection-checker"; +const LEAK_DESTINATION: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 1337); + +/// Test that split tunneling works by asserting the following: +/// - Splitting a process shouldn't do anything if tunnel is not connected. +/// - A split process should never push traffic through the tunnel. +/// - Splitting/unsplitting should work regardless if process is running. +#[test_function(target_os = "linux", target_os = "windows")] pub async fn test_split_tunnel( - ctx: TestContext, + _ctx: TestContext, rpc: ServiceClient, - mullvad_client: MullvadProxyClient, + mut mullvad_client: MullvadProxyClient, ) -> anyhow::Result<()> { - match TEST_CONFIG.os { - Os::Linux => test_split_tunnel_linux(ctx, rpc, mullvad_client).await, - Os::Windows => test_split_tunnel_windows(ctx, rpc, mullvad_client).await, - Os::Macos => todo!("MacOS"), - } + let mut checker = ConnChecker::new(rpc.clone(), mullvad_client.clone()); + + // Test that program is behaving when we are disconnected + (checker.spawn().await?.assert_insecure().await) + .with_context(|| "Test disconnected and unsplit")?; + checker.split().await?; + (checker.spawn().await?.assert_insecure().await) + .with_context(|| "Test disconnected and split")?; + checker.unsplit().await?; + + // Test that program is behaving being split/unsplit while running and we are disconnected + let mut handle = checker.spawn().await?; + handle.split().await?; + (handle.assert_insecure().await) + .with_context(|| "Test disconnected and being split while running")?; + handle.unsplit().await?; + (handle.assert_insecure().await) + .with_context(|| "Test disconnected and being unsplit while running")?; + drop(handle); + + helpers::connect_and_wait(&mut mullvad_client).await?; + + // Test running an unsplit program + checker + .spawn() + .await? + .assert_secure() + .await + .with_context(|| "Test connected and unsplit")?; + + // Test running a split program + checker.split().await?; + checker + .spawn() + .await? + .assert_insecure() + .await + .with_context(|| "Test connected and split")?; + + checker.unsplit().await?; + + // Test splitting and unsplitting a program while it's running + let mut handle = checker.spawn().await?; + (handle.assert_secure().await).with_context(|| "Test connected and unsplit (again)")?; + handle.split().await?; + (handle.assert_insecure().await) + .with_context(|| "Test connected and being split while running")?; + handle.unsplit().await?; + (handle.assert_secure().await) + .with_context(|| "Test connected and being unsplit while running")?; + + Ok(()) } -pub async fn test_split_tunnel_windows( - _: TestContext, +/// This helper spawns a seperate process which checks if we are connected to Mullvad, and tries to +/// leak traffic outside the tunnel by sending TCP, UDP, and ICMP packets to [LEAK_DESTINATION]. +struct ConnChecker { rpc: ServiceClient, - mut mullvad_client: MullvadProxyClient, -) -> anyhow::Result<()> { - const AM_I_MULLVAD_EXE: &str = "E:\\am-i-mullvad.exe"; + mullvad_client: MullvadProxyClient, - async fn am_i_mullvad(rpc: &ServiceClient) -> anyhow::Result { - parse_am_i_mullvad(rpc.exec(AM_I_MULLVAD_EXE, []).await?) - } + /// Path to the process binary. + executable_path: String, - let mut errored = false; + /// Whether the process should be split when spawned. Needed on Linux. + split: bool, +} - helpers::disconnect_and_wait(&mut mullvad_client).await?; +struct ConnCheckerHandle<'a> { + checker: &'a mut ConnChecker, - if am_i_mullvad(&rpc).await? { - log::error!("We should be disconnected, but `{AM_I_MULLVAD_EXE}` reported that it was connected to Mullvad."); - log::error!("Host machine is probably connected to Mullvad, this will throw off results"); - errored = true - } + /// ID of the spawned process. + pid: u32, +} - helpers::connect_and_wait(&mut mullvad_client).await?; +struct ConnectionStatus { + /// True if reported we are connected. + am_i_mullvad: bool, - if !am_i_mullvad(&rpc).await? { - log::error!( - "We should be connected, but `{AM_I_MULLVAD_EXE}` reported no connection to Mullvad." - ); - errored = true - } + /// True if we sniffed TCP packets going outside the tunnel. + leaked_tcp: bool, - mullvad_client - .add_split_tunnel_app(AM_I_MULLVAD_EXE) - .await?; - mullvad_client.set_split_tunnel_state(true).await?; + /// True if we sniffed UDP packets going outside the tunnel. + leaked_udp: bool, - if am_i_mullvad(&rpc).await? { - log::error!( - "`{AM_I_MULLVAD_EXE}` should have been split, but it reported a connection to Mullvad" - ); - errored = true + /// True if we sniffed ICMP packets going outside the tunnel. + leaked_icmp: bool, +} + +impl ConnChecker { + pub fn new(rpc: ServiceClient, mullvad_client: MullvadProxyClient) -> Self { + let artifacts_dir = &TEST_CONFIG.artifacts_dir; + let executable_path = match TEST_CONFIG.os { + Os::Linux | Os::Macos => format!("{artifacts_dir}/{CHECKER_FILENAME_UNIX}"), + Os::Windows => format!("{artifacts_dir}\\{CHECKER_FILENAME_WINDOWS}"), + }; + + Self { + rpc, + mullvad_client, + split: false, + executable_path, + } } - helpers::disconnect_and_wait(&mut mullvad_client).await?; + /// Spawn the connecton checker process and return a handle to it. + /// + /// Dropping the handle will stop the process. + /// **NOTE**: The handle must be dropped from a tokio runtime context. + pub async fn spawn(&mut self) -> anyhow::Result> { + log::debug!("spawning connection checker"); + + let opts = SpawnOpts { + attach_stdin: true, + attach_stdout: true, + args: [ + "--interactive", + "--timeout", + "10000", + // try to leak traffic to LEAK_DESTINATION + "--leak", + &LEAK_DESTINATION.to_string(), + "--leak-timeout", + "500", + "--leak-tcp", + "--leak-udp", + "--leak-icmp", + ] + .map(String::from) + .to_vec(), + ..SpawnOpts::new(&self.executable_path) + }; + + let pid = self.rpc.spawn(opts).await?; - if am_i_mullvad(&rpc).await? { - log::error!( - "`{AM_I_MULLVAD_EXE}` reported a connection to Mullvad while split and disconnected" - ); - errored = true + if self.split && TEST_CONFIG.os == Os::Linux { + self.mullvad_client + .add_split_tunnel_process(pid as i32) + .await?; + } + + Ok(ConnCheckerHandle { pid, checker: self }) } - mullvad_client.set_split_tunnel_state(false).await?; - mullvad_client - .remove_split_tunnel_app(AM_I_MULLVAD_EXE) - .await?; + /// Enable split tunneling for the connection checker. + pub async fn split(&mut self) -> anyhow::Result<()> { + log::debug!("enable split tunnel"); + self.split = true; + + match TEST_CONFIG.os { + Os::Linux => { /* linux programs can't be split until they are spawned */ } + Os::Windows => { + self.mullvad_client + .add_split_tunnel_app(&self.executable_path) + .await?; + self.mullvad_client.set_split_tunnel_state(true).await?; + } + Os::Macos => unimplemented!("MacOS"), + } - if errored { - anyhow::bail!("test_split_tunnel failed, see log output for details."); + Ok(()) } - Ok(()) -} + /// Disable split tunneling for the connection checker. + pub async fn unsplit(&mut self) -> anyhow::Result<()> { + log::debug!("disable split tunnel"); + self.split = false; -pub async fn test_split_tunnel_linux( - _: TestContext, - rpc: ServiceClient, - mut mullvad_client: MullvadProxyClient, -) -> anyhow::Result<()> { - const AM_I_MULLVAD_URL: &str = "https://am.i.mullvad.net/connected"; - - async fn am_i_mullvad(rpc: &ServiceClient, split_tunnel: bool) -> anyhow::Result { - let result = if split_tunnel { - rpc.exec("mullvad-exclude", ["curl", AM_I_MULLVAD_URL]) - .await? - } else { - rpc.exec("curl", [AM_I_MULLVAD_URL]).await? - }; + match TEST_CONFIG.os { + Os::Linux => {} + Os::Windows => { + self.mullvad_client.set_split_tunnel_state(false).await?; + self.mullvad_client + .remove_split_tunnel_app(&self.executable_path) + .await?; + } + Os::Macos => unimplemented!("MacOS"), + } - parse_am_i_mullvad(result) + Ok(()) } +} - let mut errored = false; +impl ConnCheckerHandle<'_> { + pub async fn split(&mut self) -> anyhow::Result<()> { + if TEST_CONFIG.os == Os::Linux { + self.checker + .mullvad_client + .add_split_tunnel_process(self.pid as i32) + .await?; + } - helpers::connect_and_wait(&mut mullvad_client).await?; + self.checker.split().await + } + + pub async fn unsplit(&mut self) -> anyhow::Result<()> { + if TEST_CONFIG.os == Os::Linux { + self.checker + .mullvad_client + .remove_split_tunnel_process(self.pid as i32) + .await?; + } - if !am_i_mullvad(&rpc, false).await? { - log::error!("We should be connected, but `am.i.mullvad` reported that it was not connected to Mullvad."); - errored = true; + self.checker.unsplit().await } - if am_i_mullvad(&rpc, true).await? { - log::error!( - "`mullvad-exclude curl {AM_I_MULLVAD_URL}` reported that it was connected to Mullvad." - ); - log::error!("`curl` does not appear to have been split correctly."); - errored = true; + /// Assert that traffic is flowing through the Mullvad tunnel and that no packets are leaked. + pub async fn assert_secure(&mut self) -> anyhow::Result<()> { + log::info!("checking that connection is secure"); + let status = self.check_connection().await?; + ensure!(status.am_i_mullvad); + ensure!(!status.leaked_tcp); + ensure!(!status.leaked_udp); + ensure!(!status.leaked_icmp); + + Ok(()) } - helpers::disconnect_and_wait(&mut mullvad_client).await?; + /// Assert that traffic is NOT flowing through the Mullvad tunnel and that packets ARE leaked. + pub async fn assert_insecure(&mut self) -> anyhow::Result<()> { + log::info!("checking that connection is not secure"); + let status = self.check_connection().await?; + ensure!(!status.am_i_mullvad); + ensure!(status.leaked_tcp); + ensure!(status.leaked_udp); + ensure!(status.leaked_icmp); - if am_i_mullvad(&rpc, false).await? { - log::error!("We should be disconnected, but `curl {AM_I_MULLVAD_URL}` reported that it was connected to Mullvad."); - log::error!("Host machine is probably connected to Mullvad. This may affect test results."); - errored = true; + Ok(()) } - if errored { - anyhow::bail!("test_split_tunnel failed, see log output for details."); + async fn check_connection(&mut self) -> anyhow::Result { + // Monitor all pakets going to LEAK_DESTINATION during the check. + let monitor = start_packet_monitor( + |packet| packet.destination.ip() == LEAK_DESTINATION.ip(), + MonitorOptions { + direction: Some(Direction::In), + ..MonitorOptions::default() + }, + ) + .await; + + // Write a newline to the connection checker to prompt it to perform the check. + self.checker + .rpc + .write_child_stdin(self.pid, "Say the line, Bart!\r\n".into()) + .await?; + + // The checker responds when the check is complete. + let line = self.read_stdout_line().await?; + + let monitor_result = monitor + .into_result() + .await + .map_err(|_e| anyhow!("Packet monitor unexpectedly stopped"))?; + + Ok(ConnectionStatus { + am_i_mullvad: parse_am_i_mullvad(line)?, + + leaked_tcp: (monitor_result.packets.iter()) + .any(|pkt| pkt.protocol == IpNextHeaderProtocols::Tcp), + + leaked_udp: (monitor_result.packets.iter()) + .any(|pkt| pkt.protocol == IpNextHeaderProtocols::Udp), + + leaked_icmp: (monitor_result.packets.iter()) + .any(|pkt| pkt.protocol == IpNextHeaderProtocols::Icmp), + }) } - Ok(()) + /// Try to a single line of output from the spawned process + async fn read_stdout_line(&mut self) -> anyhow::Result { + // Add a timeout to avoid waiting forever. + timeout(Duration::from_secs(8), async { + let mut line = String::new(); + + // tarpc doesn't support streams, so we poll the checker process in a loop instead + loop { + let Some(output) = self.checker.rpc.read_child_stdout(self.pid).await? else { + bail!("got EOF from connection checker process"); + }; + + if output.is_empty() { + sleep(Duration::from_millis(500)).await; + continue; + } + + line.push_str(&output); + + if line.contains('\n') { + log::info!("output from child process: {output:?}"); + return Ok(line); + } + } + }) + .await + .with_context(|| "Timeout reading stdout from connection checker")? + } } -/// Parse output from am-i-mullvad. Returns true if connected to Mullvad. -fn parse_am_i_mullvad(result: ExecResult) -> anyhow::Result { - let stdout = str::from_utf8(&result.stdout).expect("curl output is UTF-8"); +impl Drop for ConnCheckerHandle<'_> { + fn drop(&mut self) { + let rpc = self.checker.rpc.clone(); + let pid = self.pid; + + let Ok(runtime_handle) = tokio::runtime::Handle::try_current() else { + log::error!("ConnCheckerHandle dropped outside of a tokio runtime."); + return; + }; + + runtime_handle.spawn(async move { + // Make sure child process is stopped when this handle is dropped. + // Closing stdin does the trick. + let _ = rpc.close_child_stdin(pid).await; + }); + } +} - Ok(if stdout.contains("You are connected") { +/// Parse output from connection-checker. Returns true if connected to Mullvad. +fn parse_am_i_mullvad(result: String) -> anyhow::Result { + Ok(if result.contains("You are connected") { true - } else if stdout.contains("You are not connected") { + } else if result.contains("You are not connected") { false } else { - anyhow::bail!("Unexpected output from am-i-mullvad: {stdout:?}") + bail!("Unexpected output from connection-checker: {result:?}") }) } diff --git a/test/test-manager/src/tests/test_metadata.rs b/test/test-manager/src/tests/test_metadata.rs index 3e28a4380b6a..d4ffa9bfd029 100644 --- a/test/test-manager/src/tests/test_metadata.rs +++ b/test/test-manager/src/tests/test_metadata.rs @@ -5,7 +5,7 @@ use test_rpc::mullvad_daemon::MullvadClientVersion; pub struct TestMetadata { pub name: &'static str, pub command: &'static str, - pub target_os: Option, + pub targets: &'static [Os], pub mullvad_client_version: MullvadClientVersion, pub func: TestWrapperFunction, pub priority: Option, @@ -16,9 +16,7 @@ pub struct TestMetadata { impl TestMetadata { pub fn should_run_on_os(&self, os: Os) -> bool { - self.target_os - .map(|target_os| target_os == os) - .unwrap_or(true) + self.targets.is_empty() || self.targets.contains(&os) } } diff --git a/test/test-manager/src/vm/provision.rs b/test/test-manager/src/vm/provision.rs index 5f01e8f192b9..8667b6c1338b 100644 --- a/test/test-manager/src/vm/provision.rs +++ b/test/test-manager/src/vm/provision.rs @@ -106,6 +106,11 @@ fn blocking_ssh( ssh_send_file_path(&session, &source, temp_dir) .context("Failed to send test runner to remote")?; + // Transfer connection-checker + let source = local_runner_dir.join("connection-checker"); + ssh_send_file_path(&session, &source, temp_dir) + .context("Failed to send connection-checker to remote")?; + // Transfer app packages ssh_send_file_path(&session, &local_app_manifest.current_app_path, temp_dir) .context("Failed to send current app package to remote")?; diff --git a/test/test-manager/test_macro/src/lib.rs b/test/test-manager/test_macro/src/lib.rs index fdf7e5539cc9..7cb8407230eb 100644 --- a/test/test-manager/test_macro/src/lib.rs +++ b/test/test-manager/test_macro/src/lib.rs @@ -121,7 +121,7 @@ fn get_test_macro_parameters(attributes: &syn::AttributeArgs) -> Result Result Some(os), + let target = match lit_str.value().parse() { + Ok(os) => os, Err(e) => bail!(lit_str, "{e}"), + }; + + if targets.contains(&target) { + bail!(nv, "Duplicate target"); } + + targets.push(target); } else { bail!(nv, "unknown attribute"); } @@ -173,7 +175,7 @@ fn get_test_macro_parameters(attributes: &syn::AttributeArgs) -> Result proc_macro2::TokenStream { Some(priority) => quote! { Some(#priority) }, None => quote! { None }, }; - let target_os = match test_function.macro_parameters.target_os { - Some(Os::Linux) => quote! { Some(::test_rpc::meta::Os::Linux) }, - Some(Os::Macos) => quote! { Some(::test_rpc::meta::Os::Macos) }, - Some(Os::Windows) => quote! { Some(::test_rpc::meta::Os::Windows) }, - None => quote! { None }, - }; + let targets: proc_macro2::TokenStream = (test_function.macro_parameters.targets.iter()) + .map(|&os| match os { + Os::Linux => quote! { ::test_rpc::meta::Os::Linux, }, + Os::Macos => quote! { ::test_rpc::meta::Os::Macos, }, + Os::Windows => quote! { ::test_rpc::meta::Os::Windows, }, + }) + .collect(); + let should_cleanup = test_function.macro_parameters.cleanup; let always_run = test_function.macro_parameters.always_run; let must_succeed = test_function.macro_parameters.must_succeed; @@ -230,7 +234,7 @@ fn create_test(test_function: TestFunction) -> proc_macro2::TokenStream { inventory::submit!(crate::tests::test_metadata::TestMetadata { name: stringify!(#func_name), command: stringify!(#func_name), - target_os: #target_os, + targets: &[#targets], mullvad_client_version: #function_mullvad_version, func: #wrapper_closure, priority: #test_function_priority, @@ -252,7 +256,7 @@ struct MacroParameters { cleanup: bool, always_run: bool, must_succeed: bool, - target_os: Option, + targets: Vec, } enum MullvadClient { diff --git a/test/test-rpc/src/client.rs b/test/test-rpc/src/client.rs index b4fb67f5c069..324669de3fc6 100644 --- a/test/test-rpc/src/client.rs +++ b/test/test-rpc/src/client.rs @@ -351,4 +351,26 @@ impl ServiceClient { .make_device_json_old(tarpc::context::current()) .await? } + + pub async fn spawn(&self, opts: SpawnOpts) -> Result { + self.client.spawn(tarpc::context::current(), opts).await? + } + + pub async fn read_child_stdout(&self, pid: u32) -> Result, Error> { + self.client + .read_child_stdout(tarpc::context::current(), pid) + .await? + } + + pub async fn write_child_stdin(&self, pid: u32, data: String) -> Result<(), Error> { + self.client + .write_child_stdin(tarpc::context::current(), pid, data) + .await? + } + + pub async fn close_child_stdin(&self, pid: u32) -> Result<(), Error> { + self.client + .close_child_stdin(tarpc::context::current(), pid) + .await? + } } diff --git a/test/test-rpc/src/lib.rs b/test/test-rpc/src/lib.rs index d1515206015f..e0088a67b50b 100644 --- a/test/test-rpc/src/lib.rs +++ b/test/test-rpc/src/lib.rs @@ -57,6 +57,10 @@ pub enum Error { Timeout, #[error("TCP forward error")] TcpForward, + #[error("Unknown process ID: {0}")] + UnknownPid(u32), + #[error("{0}")] + Other(String), } /// Response from am.i.mullvad.net @@ -80,6 +84,27 @@ impl ExecResult { } } +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct SpawnOpts { + pub path: String, + pub args: Vec, + pub env: BTreeMap, + pub attach_stdin: bool, + pub attach_stdout: bool, +} + +impl SpawnOpts { + pub fn new(path: impl Into) -> SpawnOpts { + SpawnOpts { + path: path.into(), + args: Default::default(), + env: Default::default(), + attach_stdin: Default::default(), + attach_stdout: Default::default(), + } + } +} + #[derive(Debug, Serialize, Deserialize)] pub enum AppTrace { Path(PathBuf), @@ -197,6 +222,28 @@ mod service { async fn reboot() -> Result<(), Error>; async fn make_device_json_old() -> Result<(), Error>; + + /// Spawn a child process and return the PID. + async fn spawn(opts: SpawnOpts) -> Result; + + /// Read from stdout of a process spawned through [Service::spawn]. + /// + /// Process must have been spawned with `attach_stdout`. + /// Returns `None` if process stdout is closed. + async fn read_child_stdout(pid: u32) -> Result, Error>; + + /// Write to stdin of a process spawned through [Service::spawn]. + /// + /// Process must have been spawned with `attach_stdin`. + async fn write_child_stdin(pid: u32, data: String) -> Result<(), Error>; + + /// Close stdin of a process spawned through [Service::spawn]. + /// + /// Process must have been spawned with `attach_stdin`. + async fn close_child_stdin(pid: u32) -> Result<(), Error>; + + /// Kill a process spawned through [Service::spawn]. + async fn kill_child(pid: u32) -> Result<(), Error>; } } diff --git a/test/test-runner/Cargo.toml b/test/test-runner/Cargo.toml index 8e2ae8cbf687..50f3ddda6a91 100644 --- a/test/test-runner/Cargo.toml +++ b/test/test-runner/Cargo.toml @@ -33,7 +33,7 @@ test-rpc = { path = "../test-rpc" } mullvad-paths = { path = "../../mullvad-paths" } talpid-platform-metadata = { path = "../../talpid-platform-metadata" } -socket2 = { version = "0.5", features = ["all"] } +socket2 = { version = "0.5.4", features = ["all"] } [target."cfg(target_os=\"windows\")".dependencies] talpid-windows = { path = "../../talpid-windows" } diff --git a/test/test-runner/src/main.rs b/test/test-runner/src/main.rs index 3511d78cec55..d864968bbee5 100644 --- a/test/test-runner/src/main.rs +++ b/test/test-runner/src/main.rs @@ -1,10 +1,14 @@ -use futures::{pin_mut, SinkExt, StreamExt}; +use futures::{pin_mut, select, select_biased, FutureExt, SinkExt, StreamExt}; use logging::LOGGER; use std::{ collections::{BTreeMap, HashMap}, net::{IpAddr, SocketAddr}, path::{Path, PathBuf}, + process::Stdio, + sync::Arc, + time::Duration, }; +use util::OnDrop; use tarpc::{context, server::Channel}; use test_rpc::{ @@ -12,12 +16,14 @@ use test_rpc::{ net::SockHandleId, package::Package, transport::GrpcForwarder, - AppTrace, Service, + AppTrace, Service, SpawnOpts, }; use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - process::Command, - sync::broadcast::error::TryRecvError, + io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}, + process::{ChildStdin, ChildStdout, Command}, + sync::{broadcast::error::TryRecvError, oneshot, Mutex}, + task, + time::sleep, }; use tokio_util::codec::{Decoder, LengthDelimitedCodec}; @@ -27,9 +33,23 @@ mod logging; mod net; mod package; mod sys; +mod util; -#[derive(Clone)] -pub struct TestServer(pub ()); +#[derive(Clone, Default)] +pub struct TestServer(Arc>); + +#[derive(Default)] +struct State { + spawned_procs: HashMap, +} + +struct SpawnedProcess { + stdout: Option, + stdin: Option, + + #[allow(dead_code)] + abort_handle: OnDrop, +} #[tarpc::server] impl Service for TestServer { @@ -319,6 +339,192 @@ impl Service for TestServer { async fn make_device_json_old(self, _: context::Context) -> Result<(), test_rpc::Error> { app::make_device_json_old().await } + + async fn spawn(self, _: context::Context, opts: SpawnOpts) -> Result { + let mut cmd = Command::new(&opts.path); + cmd.args(&opts.args); + + // Make sure that PATH is updated + // TODO: We currently do not need this on non-Windows + #[cfg(target_os = "windows")] + cmd.env("PATH", sys::get_system_path_var()?); + + cmd.envs(opts.env); + + if opts.attach_stdin { + cmd.stdin(Stdio::piped()); + } else { + cmd.stdin(Stdio::null()); + } + + if opts.attach_stdout { + cmd.stdout(Stdio::piped()); + } + + cmd.stderr(Stdio::piped()); + + let mut child = cmd.kill_on_drop(true).spawn().map_err(|error| { + log::error!("Failed to spawn {}: {error}", opts.path); + test_rpc::Error::Syscall + })?; + + let pid = child + .id() + .expect("Child hasn't been polled to completion yet"); + + log::info!("spawned {} (args={:?}) (pid={pid})", opts.path, opts.args); + + let (abort_tx, abort_rx) = oneshot::channel(); + let abort_handle = || { + let _ = abort_tx.send(()); + }; + + let spawned_process = SpawnedProcess { + stdout: child.stdout.take(), + stdin: child.stdin.take(), + abort_handle: OnDrop::new(Box::new(abort_handle)), + }; + + let mut state = self.0.lock().await; + state.spawned_procs.insert(pid, spawned_process); + drop(state); + + // spawn a task to log child stdout + if let Some(stderr) = child.stderr.take() { + task::spawn(async move { + let mut stderr = BufReader::new(stderr); + let mut line = String::new(); + loop { + match stderr.read_line(&mut line).await { + Ok(0) => break, + Ok(_) => { + let trimmed = line.trim_end_matches(&['\r', '\n']); + log::info!("child stderr (pid={pid}): {trimmed}"); + line.clear(); + } + Err(e) => { + log::error!("failed to read child stderr (pid={pid}): {e}"); + break; + } + } + } + }); + } + + // spawn a task to monitor if the child exits + task::spawn(async move { + select! { + result = child.wait().fuse() => match result { + Err(e) => { + log::error!("failed to await child process (pid={pid}): {e}"); + } + Ok(status) => { + log::info!("child process (pid={pid}) exited with status: {status}"); + } + }, + + _ = abort_rx.fuse() => { + if let Err(e) = child.kill().await { + log::error!("failed to kill child process (pid={pid}): {e}"); + } + } + } + + let mut state = self.0.lock().await; + state.spawned_procs.remove(&pid); + }); + + Ok(pid) + } + + async fn read_child_stdout( + self, + _: context::Context, + pid: u32, + ) -> Result, test_rpc::Error> { + let mut state = self.0.lock().await; + let child = state + .spawned_procs + .get_mut(&pid) + .ok_or(test_rpc::Error::UnknownPid(pid))?; + + let Some(stdout) = child.stdout.as_mut() else { + return Ok(None); + }; + + let mut buf = vec![0u8; 512]; + + let n = select_biased! { + result = stdout.read(&mut buf).fuse() => result + .map_err(|e| format!("Failed to read from child stdout: {e}")) + .map_err(test_rpc::Error::Other)?, + + _ = sleep(Duration::from_millis(500)).fuse() => return Ok(Some(String::new())), + }; + + // check for EOF + if n == 0 { + child.stdout = None; + return Ok(None); + } + + buf.truncate(n); + let output = String::from_utf8(buf) + .map_err(|_| test_rpc::Error::Other("Child wrote non UTF-8 to stdout".into()))?; + + Ok(Some(output)) + } + + async fn write_child_stdin( + self, + _: context::Context, + pid: u32, + data: String, + ) -> Result<(), test_rpc::Error> { + let mut state = self.0.lock().await; + let child = state + .spawned_procs + .get_mut(&pid) + .ok_or(test_rpc::Error::UnknownPid(pid))?; + + let Some(stdin) = child.stdin.as_mut() else { + return Err(test_rpc::Error::Other("Child stdin is closed.".into())); + }; + + stdin + .write_all(data.as_bytes()) + .await + .map_err(|e| format!("Error writing to child stdin: {e}")) + .map_err(test_rpc::Error::Other)?; + + log::debug!("wrote {} bytes to pid {pid}", data.len()); + + Ok(()) + } + + async fn close_child_stdin(self, _: context::Context, pid: u32) -> Result<(), test_rpc::Error> { + let mut state = self.0.lock().await; + let child = state + .spawned_procs + .get_mut(&pid) + .ok_or(test_rpc::Error::UnknownPid(pid))?; + + child.stdin = None; + + Ok(()) + } + + async fn kill_child(self, _: context::Context, pid: u32) -> Result<(), test_rpc::Error> { + let mut state = self.0.lock().await; + let child = state + .spawned_procs + .remove(&pid) + .ok_or(test_rpc::Error::UnknownPid(pid))?; + + drop(child); // I swear officer, it's not what you think! + + Ok(()) + } } fn get_pipe_status() -> ServiceStatus { @@ -364,7 +570,7 @@ async fn main() -> Result<(), Error> { )); let server = tarpc::server::BaseChannel::with_defaults(runner_transport); - server.execute(TestServer(()).serve()).await; + server.execute(TestServer::default().serve()).await; log::error!("Restarting server since it stopped"); } diff --git a/test/test-runner/src/util.rs b/test/test-runner/src/util.rs new file mode 100644 index 000000000000..03a334321412 --- /dev/null +++ b/test/test-runner/src/util.rs @@ -0,0 +1,23 @@ +/// Drop guard that executes the provided callback function when dropped. +pub struct OnDrop> +where + F: FnOnce() + Send, +{ + callback: Option, +} + +impl Drop for OnDrop { + fn drop(&mut self) { + if let Some(callback) = self.callback.take() { + callback(); + } + } +} + +impl OnDrop { + pub fn new(callback: F) -> Self { + Self { + callback: Some(callback), + } + } +}