Skip to content

Commit

Permalink
Merge branch 'win-refactor-winapi' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
dlon committed Oct 20, 2023
2 parents 12167d2 + 8a5ad4b commit c9f9f6b
Show file tree
Hide file tree
Showing 32 changed files with 387 additions and 407 deletions.
13 changes: 7 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ members = [
"talpid-time",
"talpid-tunnel",
"talpid-tunnel-config-client",
"talpid-windows-net",
"talpid-windows",
"talpid-wireguard",
"mullvad-management-interface",
"tunnel-obfuscation",
Expand Down
1 change: 1 addition & 0 deletions mullvad-daemon/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ ctrlc = "3.0"
windows-service = "0.6.0"
winapi = { version = "0.3", features = ["winnt", "excpt"] }
dirs = "5.0.1"
talpid-windows = { path = "../talpid-windows" }

[target.'cfg(windows)'.dependencies.windows-sys]
workspace = true
Expand Down
96 changes: 6 additions & 90 deletions mullvad-daemon/src/exception_logging/win.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,25 @@
use mullvad_paths::log_dir;
use std::{
borrow::Cow,
ffi::{c_char, c_void, CStr},
ffi::c_void,
fmt::Write,
fs, io, mem,
fs, io,
os::windows::io::AsRawHandle,
path::{Path, PathBuf},
ptr,
};
use talpid_types::ErrorExt;
use talpid_windows::process::{ModuleEntry, ProcessSnapshot};
use winapi::{
um::winnt::{CONTEXT_CONTROL, CONTEXT_INTEGER, CONTEXT_SEGMENTS},
vc::excpt::EXCEPTION_EXECUTE_HANDLER,
};
use windows_sys::Win32::{
Foundation::{CloseHandle, BOOL, ERROR_NO_MORE_FILES, HANDLE, INVALID_HANDLE_VALUE},
Foundation::{BOOL, HANDLE},
System::{
Diagnostics::{
Debug::{SetUnhandledExceptionFilter, CONTEXT, EXCEPTION_POINTERS, EXCEPTION_RECORD},
ToolHelp::{
CreateToolhelp32Snapshot, Module32First, Module32Next, MODULEENTRY32,
TH32CS_SNAPMODULE,
},
ToolHelp::TH32CS_SNAPMODULE,
},
Threading::{GetCurrentProcess, GetCurrentProcessId, GetCurrentThreadId},
},
Expand Down Expand Up @@ -291,7 +289,7 @@ fn get_context_info(context: &CONTEXT) -> String {
}

/// Return module info for the current process and given memory address.
fn find_address_module(address: *mut c_void) -> io::Result<Option<ModuleInfo>> {
fn find_address_module(address: *mut c_void) -> io::Result<Option<ModuleEntry>> {
let snap = ProcessSnapshot::new(TH32CS_SNAPMODULE, 0)?;

for module in snap.modules() {
Expand All @@ -306,85 +304,3 @@ fn find_address_module(address: *mut c_void) -> io::Result<Option<ModuleInfo>> {

Ok(None)
}

struct ModuleInfo {
name: String,
base_address: *const u8,
size: usize,
}

struct ProcessSnapshot {
handle: HANDLE,
}

impl ProcessSnapshot {
fn new(flags: u32, process_id: u32) -> io::Result<ProcessSnapshot> {
let snap = unsafe { CreateToolhelp32Snapshot(flags, process_id) };

if snap == INVALID_HANDLE_VALUE {
Err(io::Error::last_os_error())
} else {
Ok(ProcessSnapshot { handle: snap })
}
}

fn handle(&self) -> HANDLE {
self.handle
}

fn modules(&self) -> ProcessSnapshotModules<'_> {
let mut entry: MODULEENTRY32 = unsafe { mem::zeroed() };
entry.dwSize = mem::size_of::<MODULEENTRY32>() as u32;

ProcessSnapshotModules {
snapshot: self,
iter_started: false,
temp_entry: entry,
}
}
}

impl Drop for ProcessSnapshot {
fn drop(&mut self) {
unsafe {
CloseHandle(self.handle);
}
}
}

struct ProcessSnapshotModules<'a> {
snapshot: &'a ProcessSnapshot,
iter_started: bool,
temp_entry: MODULEENTRY32,
}

impl Iterator for ProcessSnapshotModules<'_> {
type Item = io::Result<ModuleInfo>;

fn next(&mut self) -> Option<io::Result<ModuleInfo>> {
if self.iter_started {
if unsafe { Module32Next(self.snapshot.handle(), &mut self.temp_entry) } == 0 {
let last_error = io::Error::last_os_error();

return if last_error.raw_os_error().unwrap() as u32 == ERROR_NO_MORE_FILES {
None
} else {
Some(Err(last_error))
};
}
} else {
if unsafe { Module32First(self.snapshot.handle(), &mut self.temp_entry) } == 0 {
return Some(Err(io::Error::last_os_error()));
}
self.iter_started = true;
}

let cstr_ref = &self.temp_entry.szModule[0];
let cstr = unsafe { CStr::from_ptr(cstr_ref as *const u8 as *const c_char) };
Some(Ok(ModuleInfo {
name: cstr.to_string_lossy().into_owned(),
base_address: self.temp_entry.modBaseAddr,
size: self.temp_entry.modBaseSize as usize,
}))
}
}
2 changes: 1 addition & 1 deletion talpid-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ widestring = "1.0"
winreg = { version = "0.51", features = ["transactions"] }
memoffset = "0.6"
windows-service = "0.6.0"
talpid-windows-net = { path = "../talpid-windows-net" }
talpid-windows = { path = "../talpid-windows" }

[target.'cfg(windows)'.dependencies.windows-sys]
workspace = true
Expand Down
2 changes: 1 addition & 1 deletion talpid-core/src/dns/windows/iphlpapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::{
ptr,
};
use talpid_types::win32_err;
use talpid_windows_net::{guid_from_luid, luid_from_alias};
use talpid_windows::net::{guid_from_luid, luid_from_alias};
use windows_sys::{
core::GUID,
s, w,
Expand Down
2 changes: 1 addition & 1 deletion talpid-core/src/dns/windows/netsh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{
time::Duration,
};
use talpid_types::{net::IpVersion, ErrorExt};
use talpid_windows_net::{index_from_luid, luid_from_alias};
use talpid_windows::net::{index_from_luid, luid_from_alias};
use windows_sys::Win32::{
Foundation::{MAX_PATH, WAIT_OBJECT_0, WAIT_TIMEOUT},
System::{
Expand Down
2 changes: 1 addition & 1 deletion talpid-core/src/dns/windows/tcpip.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::dns::DnsMonitorT;
use std::{io, net::IpAddr};
use talpid_types::ErrorExt;
use talpid_windows_net::{guid_from_luid, luid_from_alias};
use talpid_windows::net::{guid_from_luid, luid_from_alias};
use windows_sys::{core::GUID, Win32::System::Com::StringFromGUID2};
use winreg::{
enums::{HKEY_LOCAL_MACHINE, KEY_SET_VALUE},
Expand Down
2 changes: 1 addition & 1 deletion talpid-core/src/offline/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::{
time::Duration,
};
use talpid_types::ErrorExt;
use talpid_windows_net::AddressFamily;
use talpid_windows::net::AddressFamily;

#[derive(err_derive::Error, Debug)]
pub enum Error {
Expand Down
9 changes: 5 additions & 4 deletions talpid-core/src/split_tunnel/windows/driver.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::windows::{
get_device_path, get_process_creation_time, get_process_device_path, open_process, Event,
Overlapped, ProcessAccess, ProcessSnapshot,
get_device_path, get_process_creation_time, get_process_device_path, open_process,
ProcessAccess,
};
use bitflags::bitflags;
use memoffset::offset_of;
Expand All @@ -22,6 +22,7 @@ use std::{
time::Duration,
};
use talpid_types::ErrorExt;
use talpid_windows::{io::Overlapped, process::ProcessSnapshot, sync::Event};
use windows_sys::Win32::{
Foundation::{
ERROR_ACCESS_DENIED, ERROR_FILE_NOT_FOUND, ERROR_INVALID_PARAMETER, ERROR_IO_PENDING,
Expand Down Expand Up @@ -485,7 +486,7 @@ fn build_process_tree() -> io::Result<Vec<ProcessInfo>> {
let mut process_info = HashMap::new();

let snap = ProcessSnapshot::new(TH32CS_SNAPPROCESS, 0)?;
for entry in snap.entries() {
for entry in snap.processes() {
let entry = entry?;

let process = match open_process(ProcessAccess::QueryLimitedInformation, false, entry.pid) {
Expand Down Expand Up @@ -877,7 +878,7 @@ pub fn get_overlapped_result(
let event = overlapped.get_event().unwrap();

// SAFETY: This is a valid event object.
unsafe { wait_for_single_object(event.as_handle(), None) }?;
unsafe { wait_for_single_object(event.as_raw(), None) }?;

// SAFETY: The handle and overlapped object are valid.
let mut returned_bytes = 0u32;
Expand Down
31 changes: 17 additions & 14 deletions talpid-core/src/split_tunnel/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ use std::{
};
use talpid_routing::{get_best_default_route, CallbackHandle, EventType, RouteManagerHandle};
use talpid_types::{split_tunnel::ExcludedProcess, tunnel::ErrorStateCause, ErrorExt};
use talpid_windows_net::{get_ip_address_for_interface, AddressFamily};
use talpid_windows::{
io::Overlapped,
net::{get_ip_address_for_interface, AddressFamily},
sync::Event,
};
use windows_sys::Win32::Foundation::ERROR_OPERATION_ABORTED;

const DRIVER_EVENT_BUFFER_SIZE: usize = 2048;
Expand Down Expand Up @@ -69,7 +73,7 @@ pub enum Error {

/// Failed to obtain an IP address given a network interface LUID
#[error(display = "Failed to obtain IP address for interface LUID")]
LuidToIp(#[error(source)] talpid_windows_net::Error),
LuidToIp(#[error(source)] talpid_windows::net::Error),

/// Failed to set up callback for monitoring default route changes
#[error(display = "Failed to register default route change callback")]
Expand Down Expand Up @@ -105,7 +109,7 @@ pub struct SplitTunnel {
runtime: tokio::runtime::Handle,
request_tx: RequestTx,
event_thread: Option<std::thread::JoinHandle<()>>,
quit_event: Arc<windows::Event>,
quit_event: Arc<Event>,
excluded_processes: Arc<RwLock<HashMap<usize, ExcludedProcess>>>,
_route_change_callback: Option<CallbackHandle>,
daemon_tx: Weak<mpsc::UnboundedSender<TunnelCommand>>,
Expand Down Expand Up @@ -191,14 +195,13 @@ impl SplitTunnel {
fn spawn_event_listener(
handle: Arc<driver::DeviceHandle>,
excluded_processes: Arc<RwLock<HashMap<usize, ExcludedProcess>>>,
) -> Result<(std::thread::JoinHandle<()>, Arc<windows::Event>), Error> {
let mut event_overlapped = windows::Overlapped::new(Some(
windows::Event::new(true, false).map_err(Error::EventThreadError)?,
) -> Result<(std::thread::JoinHandle<()>, Arc<Event>), Error> {
let mut event_overlapped = Overlapped::new(Some(
Event::new(true, false).map_err(Error::EventThreadError)?,
))
.map_err(Error::EventThreadError)?;

let quit_event =
Arc::new(windows::Event::new(true, false).map_err(Error::EventThreadError)?);
let quit_event = Arc::new(Event::new(true, false).map_err(Error::EventThreadError)?);
let quit_event_copy = quit_event.clone();

let event_thread = std::thread::spawn(move || {
Expand Down Expand Up @@ -237,11 +240,11 @@ impl SplitTunnel {

fn fetch_next_event(
device: &Arc<driver::DeviceHandle>,
quit_event: &windows::Event,
overlapped: &mut windows::Overlapped,
quit_event: &Event,
overlapped: &mut Overlapped,
data_buffer: &mut Vec<u8>,
) -> io::Result<EventResult> {
if unsafe { driver::wait_for_single_object(quit_event.as_handle(), Some(Duration::ZERO)) }
if unsafe { driver::wait_for_single_object(quit_event.as_raw(), Some(Duration::ZERO)) }
.is_ok()
{
return Ok(EventResult::Quit);
Expand All @@ -268,8 +271,8 @@ impl SplitTunnel {
})?;

let event_objects = [
overlapped.get_event().unwrap().as_handle(),
quit_event.as_handle(),
overlapped.get_event().unwrap().as_raw(),
quit_event.as_raw(),
];

let signaled_object =
Expand All @@ -283,7 +286,7 @@ impl SplitTunnel {
},
)?;

if signaled_object == quit_event.as_handle() {
if signaled_object == quit_event.as_raw() {
// Quit event was signaled
return Ok(EventResult::Quit);
}
Expand Down
Loading

0 comments on commit c9f9f6b

Please sign in to comment.