Skip to content

Commit

Permalink
Merge branch 'firewall-not-updated-when-api-override-is-enabled-550'
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusPettersson98 committed Jan 11, 2024
2 parents 75eb89c + 8b0fd0d commit b5decd1
Show file tree
Hide file tree
Showing 11 changed files with 266 additions and 200 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,5 @@ import kotlinx.parcelize.Parcelize
data class ApiEndpoint(
val address: InetSocketAddress,
val disableAddressCache: Boolean,
val disableTls: Boolean,
val forceDirectConnection: Boolean
val disableTls: Boolean
) : Parcelable
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@ data class CustomApiEndpointConfiguration(
val hostname: String,
val port: Int,
val disableAddressCache: Boolean = true,
val disableTls: Boolean = false,
val forceDirectConnection: Boolean = true
val disableTls: Boolean = false
) : ApiEndpointConfiguration {
override fun apiEndpoint() =
ApiEndpoint(
address = InetSocketAddress(hostname, port),
disableAddressCache = disableAddressCache,
disableTls = disableTls,
forceDirectConnection = forceDirectConnection
disableTls = disableTls
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ abstract class MockApiTest {
InetAddress.getLocalHost().hostName,
port,
disableAddressCache = true,
disableTls = true,
forceDirectConnection = true
disableTls = true
)
}
}
2 changes: 1 addition & 1 deletion mullvad-api/src/access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ struct AccountState {

impl AccessTokenStore {
pub(crate) fn new(service: RequestServiceHandle) -> Self {
let factory = rest::RequestFactory::new(&API.host, None);
let factory = rest::RequestFactory::new(API.host(), None);
let (tx, rx) = mpsc::unbounded();
tokio::spawn(Self::service_requests(rx, service, factory));
Self { tx }
Expand Down
4 changes: 2 additions & 2 deletions mullvad-api/src/address_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub struct AddressCache {
impl AddressCache {
/// Initialize cache using the hardcoded address, and write changes to `write_path`.
pub fn new(write_path: Option<Box<Path>>) -> Result<Self, Error> {
Self::new_inner(API.addr, write_path)
Self::new_inner(API.address(), write_path)
}

/// Initialize cache using `read_path`, and write changes to `write_path`.
Expand All @@ -53,7 +53,7 @@ impl AddressCache {

/// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`.
pub async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
if hostname.eq_ignore_ascii_case(&API.host) {
if hostname.eq_ignore_ascii_case(API.host()) {
Some(self.get_address().await)
} else {
None
Expand Down
206 changes: 140 additions & 66 deletions mullvad-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,95 +103,169 @@ impl<T> Deref for LazyManual<T> {
/// A hostname and socketaddr to reach the Mullvad REST API over.
#[derive(Debug)]
pub struct ApiEndpoint {
pub host: String,
pub addr: SocketAddr,
/// An overriden API hostname. Initialized with the value of the environment
/// variable `MULLVAD_API_HOST` if it has been set.
///
/// Use the associated function [`Self::host`] to read this value with a
/// default fallback if `MULLVAD_API_HOST` was not set.
pub host: Option<String>,
/// An overriden API address. Initialized with the value of the environment
/// variable `MULLVAD_API_ADDR` if it has been set.
///
/// Use the associated function [`Self::address()`] to read this value with
/// a default fallback if `MULLVAD_API_ADDR` was not set.
///
/// # Note
///
/// If [`Self::address`] is populated with [`Some(SocketAddr)`], it should
/// always be respected when establishing API connections.
pub address: Option<SocketAddr>,
#[cfg(feature = "api-override")]
pub disable_address_cache: bool,
#[cfg(feature = "api-override")]
pub disable_tls: bool,
#[cfg(feature = "api-override")]
pub force_direct_connection: bool,
}

impl ApiEndpoint {
const API_HOST_DEFAULT: &'static str = "api.mullvad.net";
const API_IP_DEFAULT: IpAddr = IpAddr::V4(Ipv4Addr::new(45, 83, 223, 196));
const API_PORT_DEFAULT: u16 = 443;

const API_HOST_VAR: &'static str = "MULLVAD_API_HOST";
const API_ADDR_VAR: &'static str = "MULLVAD_API_ADDR";
const DISABLE_TLS_VAR: &'static str = "MULLVAD_API_DISABLE_TLS";

/// Returns the endpoint to connect to the API over.
///
/// # Panics
///
/// Panics if `MULLVAD_API_ADDR` has invalid contents or if only one of
/// `MULLVAD_API_ADDR` or `MULLVAD_API_HOST` has been set but not the other.
/// Panics if `MULLVAD_API_ADDR`, `MULLVAD_API_HOST` or
/// `MULLVAD_API_DISABLE_TLS` has invalid contents.
#[cfg(feature = "api-override")]
pub fn from_env_vars() -> ApiEndpoint {
const API_HOST_DEFAULT: &str = "api.mullvad.net";
const API_IP_DEFAULT: IpAddr = IpAddr::V4(Ipv4Addr::new(45, 83, 223, 196));
const API_PORT_DEFAULT: u16 = 443;

fn read_var(key: &'static str) -> Option<String> {
use std::env;
match env::var(key) {
Ok(v) => Some(v),
Err(env::VarError::NotPresent) => None,
Err(env::VarError::NotUnicode(_)) => panic!("{key} does not contain valid UTF-8"),
}
}

let host_var = read_var("MULLVAD_API_HOST");
let address_var = read_var("MULLVAD_API_ADDR");
let disable_tls_var = read_var("MULLVAD_API_DISABLE_TLS");
let host_var = Self::read_var(ApiEndpoint::API_HOST_VAR);
let address_var = Self::read_var(ApiEndpoint::API_ADDR_VAR);
let disable_tls_var = Self::read_var(ApiEndpoint::DISABLE_TLS_VAR);

#[cfg_attr(not(feature = "api-override"), allow(unused_mut))]
let mut api = ApiEndpoint {
host: API_HOST_DEFAULT.to_owned(),
addr: SocketAddr::new(API_IP_DEFAULT, API_PORT_DEFAULT),
#[cfg(feature = "api-override")]
disable_address_cache: false,
#[cfg(feature = "api-override")]
host: host_var.clone(),
address: None,
disable_address_cache: true,
disable_tls: false,
#[cfg(feature = "api-override")]
force_direct_connection: false,
};

#[cfg(feature = "api-override")]
{
use std::net::ToSocketAddrs;

if host_var.is_none() && address_var.is_none() {
if disable_tls_var.is_some() {
log::warn!("MULLVAD_API_DISABLE_TLS is ignored since MULLVAD_API_HOST and MULLVAD_API_ADDR are not set");
}
return api;
}

let scheme = if let Some(disable_tls_var) = disable_tls_var {
api.disable_tls = disable_tls_var != "0";
"http://"
} else {
"https://"
};

if let Some(user_host) = host_var {
api.host = user_host;
api.address = match address_var {
Some(user_addr) => {
let addr = user_addr.parse().unwrap_or_else(|_| {
panic!(
"{api_addr}={user_addr} is not a valid socketaddr",
api_addr = ApiEndpoint::API_ADDR_VAR,
)
});
Some(addr)
}
if let Some(user_addr) = address_var {
api.addr = user_addr
.parse()
.expect("MULLVAD_API_ADDR is not a valid socketaddr");
} else {
log::warn!("Resolving API IP from MULLVAD_API_HOST");
api.addr = format!("{}:{}", api.host, API_PORT_DEFAULT)
None => {
use std::net::ToSocketAddrs;
log::debug!(
"{api_addr} not found. Resolving API IP from {api_host}",
api_addr = ApiEndpoint::API_ADDR_VAR,
api_host = ApiEndpoint::API_HOST_VAR
);
format!("{}:{}", api.host(), ApiEndpoint::API_PORT_DEFAULT)
.to_socket_addrs()
.expect("failed to resolve API host")
.next()
.expect("API host yielded 0 addresses");
}
api.disable_address_cache = true;
api.force_direct_connection = true;
log::debug!("Overriding API. Using {} at {scheme}{}", api.host, api.addr);
};

if api.host.is_none() && api.address.is_none() {
if disable_tls_var.is_some() {
log::warn!(
"{disable_tls} is ignored since {api_host} and {api_addr} are not set",
disable_tls = ApiEndpoint::DISABLE_TLS_VAR,
api_host = ApiEndpoint::API_HOST_VAR,
api_addr = ApiEndpoint::API_ADDR_VAR,
);
}
} else {
api.disable_tls = disable_tls_var
.as_ref()
.map(|disable_tls| disable_tls != "0")
.unwrap_or(api.disable_tls);

log::debug!(
"Overriding API. Using {host} at {scheme}{addr}",
host = api.host(),
addr = api.address(),
scheme = if api.disable_tls {
"http://"
} else {
"https://"
}
);
}
#[cfg(not(feature = "api-override"))]
api
}

/// Returns the endpoint to connect to the API over.
///
/// # Panics
///
/// Panics if `MULLVAD_API_ADDR`, `MULLVAD_API_HOST` or
/// `MULLVAD_API_DISABLE_TLS` has invalid contents.
#[cfg(not(feature = "api-override"))]
pub fn from_env_vars() -> ApiEndpoint {
let host_var = Self::read_var(ApiEndpoint::API_HOST_VAR);
let address_var = Self::read_var(ApiEndpoint::API_ADDR_VAR);
let disable_tls_var = Self::read_var(ApiEndpoint::DISABLE_TLS_VAR);

if host_var.is_some() || address_var.is_some() || disable_tls_var.is_some() {
log::warn!("These variables are ignored in production builds: MULLVAD_API_HOST, MULLVAD_API_ADDR, MULLVAD_API_DISABLE_TLS");
log::warn!(
"These variables are ignored in production builds: {api_host}, {api_addr}, {disable_tls}",
api_host = ApiEndpoint::API_HOST_VAR,
api_addr = ApiEndpoint::API_ADDR_VAR,
disable_tls = ApiEndpoint::DISABLE_TLS_VAR
);
}

ApiEndpoint {
host: None,
address: None,
}
}

/// Read the [`Self::host`] value, falling back to
/// [`Self::API_HOST_DEFAULT`] as default value if it does not exist.
pub fn host(&self) -> &str {
self.host
.as_deref()
.unwrap_or(ApiEndpoint::API_HOST_DEFAULT)
}

/// Read the [`Self::address`] value, falling back to
/// [`Self::API_IP_DEFAULT`]:[`Self::API_PORT_DEFAULT`] as default if it
/// does not exist.
pub fn address(&self) -> SocketAddr {
self.address.unwrap_or(SocketAddr::new(
ApiEndpoint::API_IP_DEFAULT,
ApiEndpoint::API_PORT_DEFAULT,
))
}

/// Try to read the value of an environment variable. Returns `None` if the
/// environment variable has not been set.
///
/// # Panics
///
/// Panics if the environment variable was found, but it did not contain
/// valid unicode data.
fn read_var(key: &'static str) -> Option<String> {
use std::env;
match env::var(key) {
Ok(v) => Some(v),
Err(env::VarError::NotPresent) => None,
Err(env::VarError::NotUnicode(_)) => panic!("{key} does not contain valid UTF-8"),
}
api
}
}

Expand Down Expand Up @@ -314,14 +388,14 @@ impl Runtime {
) -> rest::MullvadRestHandle {
let service = self
.new_request_service(
Some(API.host.clone()),
Some(API.host().to_string()),
proxy_provider,
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
)
.await;
let token_store = access::AccessTokenStore::new(service.clone());
let factory = rest::RequestFactory::new(&API.host, Some(token_store));
let factory = rest::RequestFactory::new(API.host(), Some(token_store));

rest::MullvadRestHandle::new(
service,
Expand Down
28 changes: 6 additions & 22 deletions mullvad-api/src/rest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ use std::{
};
use talpid_types::ErrorExt;

#[cfg(feature = "api-override")]
use crate::API;

pub use hyper::StatusCode;

const USER_AGENT: &str = "mullvad-app";
Expand Down Expand Up @@ -147,14 +144,7 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
socket_bypass_tx.clone(),
);

#[cfg(feature = "api-override")]
let force_direct_connection = API.force_direct_connection;
#[cfg(not(feature = "api-override"))]
let force_direct_connection = false;

if force_direct_connection {
log::debug!("API proxies are disabled");
} else if let Some(config) = proxy_config_provider.next().await {
if let Some(config) = proxy_config_provider.next().await {
connector_handle.set_connection_mode(config);
}

Expand Down Expand Up @@ -185,17 +175,9 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
self.connector_handle.reset();
}
RequestCommand::NextApiConfig(completion_tx) => {
#[cfg(feature = "api-override")]
let force_direct_connection = API.force_direct_connection;
#[cfg(not(feature = "api-override"))]
let force_direct_connection = false;

if force_direct_connection {
log::debug!("Ignoring API connection mode");
} else if let Some(connection_mode) = self.proxy_config_provider.next().await {
if let Some(connection_mode) = self.proxy_config_provider.next().await {
self.connector_handle.set_connection_mode(connection_mode);
}

let _ = completion_tx.send(Ok(()));
}
}
Expand Down Expand Up @@ -632,8 +614,10 @@ impl MullvadRestHandle {
availability,
};
#[cfg(feature = "api-override")]
if API.disable_address_cache {
return handle;
{
if crate::API.disable_address_cache {
return handle;
}
}
handle.spawn_api_address_fetcher(address_cache);
handle
Expand Down
Loading

0 comments on commit b5decd1

Please sign in to comment.