diff --git a/test/test-manager/src/tests/helpers.rs b/test/test-manager/src/tests/helpers.rs index d4924c1956af..58bf759c16cc 100644 --- a/test/test-manager/src/tests/helpers.rs +++ b/test/test-manager/src/tests/helpers.rs @@ -1199,8 +1199,13 @@ fn parse_am_i_mullvad(result: String) -> anyhow::Result { pub mod custom_lists { use super::*; - use mullvad_relay_selector::query::builder::RelayQueryBuilder; - use mullvad_types::custom_list::{CustomList, Id}; + use mullvad_relay_selector::query::{ + builder::RelayQueryBuilder, BridgeQuery, OpenVpnRelayQuery, WireguardRelayQuery, + }; + use mullvad_types::{ + custom_list::{CustomList, Id}, + relay_constraints::BridgeConstraints, + }; use std::sync::{LazyLock, Mutex}; // Expose all custom list variants as a shorthand. @@ -1317,27 +1322,34 @@ pub mod custom_lists { pub async fn set_default_location( mullvad_client: &mut MullvadProxyClient, ) -> anyhow::Result<()> { - let mut query = RelayQueryBuilder::new() - .location(DEFAULT_LIST) - .wireguard() - .multihop() - .entry(DEFAULT_LIST) - .build(); - - // The typestate query builder cannot express OpenVPN bridge locations while specifying - // wireguard options, like multihop. So we need to create a new query for bridge - // locations and insert the OpenVPN constraints into the existing query. - let openvpn_constraints = RelayQueryBuilder::new() - .openvpn() - .bridge() - .bridge_location(DEFAULT_LIST) - .build() - .into_openvpn_constraints(); - query.set_openvpn_constraints(openvpn_constraints)?; - - apply_settings_from_relay_query(mullvad_client, query).await?; - Ok(()) + let relay_query = get_custom_list_location_query(DEFAULT_LIST); + apply_settings_from_relay_query(mullvad_client, relay_query) + .await + .context("Failed to apply default custom list location") + } + + fn get_custom_list_location_query(custom_list: List) -> RelayQuery { + let mut query = RelayQueryBuilder::new().location(custom_list).build(); + let wireguard_constraints = WireguardRelayQuery { + entry_location: Constraint::Only(custom_list.into()), + ..Default::default() + }; + + let openvpn_constraints = OpenVpnRelayQuery { + bridge_settings: BridgeQuery::Normal(BridgeConstraints { + location: Constraint::Only(custom_list.into()), + ..Default::default() + }), + ..Default::default() + }; + + query.set_openvpn_constraints(openvpn_constraints).unwrap(); + query + .set_wireguard_constraints(wireguard_constraints) + .unwrap(); + query } + /// Dig out a custom list from the daemon settings based on the custom list's name. /// There should be an rpc for this. async fn find_custom_list( @@ -1351,4 +1363,30 @@ pub mod custom_lists { .find(|list| list.name == name) .ok_or(anyhow!("List '{name}' not found")) } + + #[cfg(test)] + mod tests { + use super::*; + use mullvad_types::Intersection; + use std::str::FromStr; + + #[tokio::test] + /// Test that the default location doesn't contradict other queries. + async fn test_set_default_location() { + // Add mock custom list ID to map + IDS.lock().unwrap().insert( + DEFAULT_LIST, + Id::from_str("1a428244-92c8-496b-9bdc-d59a9566eaca").unwrap(), + ); + let default_query = get_custom_list_location_query(DEFAULT_LIST); + + log::info!("{default_query:#?}"); + + let test_valid = |query| assert!(default_query.clone().intersection(query).is_some()); + test_valid(RelayQueryBuilder::new().openvpn().build()); + test_valid(RelayQueryBuilder::new().openvpn().bridge().build()); + test_valid(RelayQueryBuilder::new().wireguard().build()); + test_valid(RelayQueryBuilder::new().wireguard().multihop().build()); + } + } }