diff --git a/examples/memfs/src/main.rs b/examples/memfs/src/main.rs index c6cbc74..c5df1d4 100644 --- a/examples/memfs/src/main.rs +++ b/examples/memfs/src/main.rs @@ -5,9 +5,9 @@ use std::{ sync::{Arc, Mutex}, }; use winfsp_wrs::{ - filetime_now, u16cstr, CleanupFlags, CreateFileInfo, CreateOptions, FileAccessRights, + filetime_now, u16cstr, u16str, CleanupFlags, CreateFileInfo, CreateOptions, FileAccessRights, FileAttributes, FileInfo, FileSystem, FileSystemContext, PSecurityDescriptor, Params, - SecurityDescriptor, U16CStr, U16CString, VolumeInfo, VolumeParams, WriteMode, NTSTATUS, + SecurityDescriptor, U16CStr, U16CString, U16Str, VolumeInfo, VolumeParams, WriteMode, NTSTATUS, STATUS_ACCESS_DENIED, STATUS_DIRECTORY_NOT_EMPTY, STATUS_END_OF_FILE, STATUS_MEDIA_WRITE_PROTECTED, STATUS_NOT_A_DIRECTORY, STATUS_OBJECT_NAME_COLLISION, STATUS_OBJECT_NAME_NOT_FOUND, @@ -203,7 +203,7 @@ impl MemFs { const MAX_FILE_SIZE: u64 = 16 * 1024 * 1024; const FILE_NODES: u64 = 1; - fn new(volume_label: &U16CStr, read_only: bool) -> Self { + fn new(volume_label: &U16Str, read_only: bool) -> Self { let root_path = PathBuf::from("/"); let mut entries = HashMap::new(); @@ -220,11 +220,14 @@ impl MemFs { Self { entries: Arc::new(Mutex::new(entries)), - volume_info: Arc::new(Mutex::new(VolumeInfo::new( - Self::MAX_FILE_NODES * Self::MAX_FILE_SIZE, - (Self::MAX_FILE_NODES - Self::FILE_NODES) * Self::MAX_FILE_SIZE, - volume_label, - ))), + volume_info: Arc::new(Mutex::new( + VolumeInfo::new( + Self::MAX_FILE_NODES * Self::MAX_FILE_SIZE, + (Self::MAX_FILE_NODES - Self::FILE_NODES) * Self::MAX_FILE_SIZE, + volume_label, + ) + .expect("volume label too long"), + )), read_only, root_path, } @@ -235,15 +238,17 @@ impl FileSystemContext for MemFs { type FileContext = Arc>; fn get_volume_info(&self) -> Result { - Ok(*self.volume_info.lock().unwrap()) + Ok(self.volume_info.lock().unwrap().clone()) } - fn set_volume_label(&self, volume_label: &U16CStr) -> Result<(), NTSTATUS> { - self.volume_info - .lock() - .unwrap() - .set_volume_label(volume_label); - Ok(()) + fn set_volume_label(&self, volume_label: &U16CStr) -> Result { + let mut guard = self.volume_info.lock().unwrap(); + + guard + .set_volume_label(volume_label.as_ustr()) + .expect("volume label size already checked"); + + Ok(guard.clone()) } fn get_security_by_name( @@ -748,7 +753,7 @@ fn create_memory_file_system(mountpoint: &U16CStr) -> FileSystem { FileSystem::new( params, Some(mountpoint), - MemFs::new(u16cstr!("memfs"), false), + MemFs::new(u16str!("memfs"), false), ) .unwrap() } diff --git a/examples/memfs/tests/mod.rs b/examples/memfs/tests/mod.rs index ae452b3..a481ab8 100644 --- a/examples/memfs/tests/mod.rs +++ b/examples/memfs/tests/mod.rs @@ -4,6 +4,8 @@ use std::{ time::Duration, }; +use winfsp_wrs::{u16str, VolumeInfo}; + #[test] fn winfsp_tests() { let mut fs = Command::new("cargo") @@ -18,8 +20,8 @@ fn winfsp_tests() { std::thread::sleep(Duration::from_millis(100)) } - let exe = - std::env::var("WINFSP_TEST_EXE").expect("specify the path of winfsp_tests in TEST_EXE"); + let exe = std::env::var("WINFSP_TEST_EXE") + .expect("specify the path of winfsp_tests with `WINFSP_TEST_EXE` env var"); let mut tests = Command::new(exe) .args([ @@ -67,3 +69,24 @@ fn init_is_idempotent() { fs.kill().unwrap(); } + +#[test] +fn too_long_volume_label() { + let too_long = u16str!("012345678901234567890123456789123"); + assert_eq!(too_long.len(), 33); // Sanity check + let max_size = u16str!("01234567890123456789012345678912"); + assert_eq!(max_size.len(), 32); // Sanity check + + VolumeInfo::new(0, 0, &too_long).unwrap_err(); + + let mut vi = VolumeInfo::new(0, 0, &max_size).unwrap(); + assert_eq!(vi.volume_label(), max_size,); + + vi.set_volume_label(&too_long).unwrap_err(); + + vi.set_volume_label(&max_size).unwrap(); + + let small = u16str!("abc"); + vi.set_volume_label(&small).unwrap(); + assert_eq!(vi.volume_label(), &small,); +} diff --git a/examples/minimal/src/main.rs b/examples/minimal/src/main.rs index 7e74359..39b86cd 100644 --- a/examples/minimal/src/main.rs +++ b/examples/minimal/src/main.rs @@ -1,9 +1,9 @@ use std::sync::Arc; use winfsp_wrs::{ - filetime_now, u16cstr, CreateOptions, FileAccessRights, FileAttributes, FileInfo, FileSystem, - FileSystemContext, PSecurityDescriptor, Params, SecurityDescriptor, U16CStr, U16CString, - VolumeInfo, VolumeParams, NTSTATUS, + filetime_now, u16cstr, u16str, CreateOptions, FileAccessRights, FileAttributes, FileInfo, + FileSystem, FileSystemContext, PSecurityDescriptor, Params, SecurityDescriptor, U16CStr, + U16CString, U16Str, VolumeInfo, VolumeParams, NTSTATUS, }; #[derive(Debug, Clone)] @@ -23,7 +23,7 @@ impl MemFs { const MAX_FILE_SIZE: u64 = 16 * 1024 * 1024; const FILE_NODES: u64 = 1; - fn new(volume_label: &U16CStr) -> Self { + fn new(volume_label: &U16Str) -> Self { let now = filetime_now(); let mut info = FileInfo::default(); @@ -35,7 +35,8 @@ impl MemFs { Self::MAX_FILE_NODES * Self::MAX_FILE_SIZE, (Self::MAX_FILE_NODES - Self::FILE_NODES) * Self::MAX_FILE_SIZE, volume_label, - ), + ) + .expect("volume label too long"), file_context: Context { info, security_descriptor: SecurityDescriptor::from_wstr(u16cstr!( @@ -76,7 +77,7 @@ impl FileSystemContext for MemFs { } fn get_volume_info(&self) -> Result { - Ok(self.volume_info) + Ok(self.volume_info.clone()) } fn read_directory( @@ -102,7 +103,7 @@ fn create_memory_file_system(mountpoint: &U16CStr) -> FileSystem { ..Default::default() }; - FileSystem::new(params, Some(mountpoint), MemFs::new(u16cstr!("memfs"))).unwrap() + FileSystem::new(params, Some(mountpoint), MemFs::new(u16str!("memfs"))).unwrap() } fn main() { diff --git a/winfsp_wrs/src/callback.rs b/winfsp_wrs/src/callback.rs index 77b066f..3a46225 100644 --- a/winfsp_wrs/src/callback.rs +++ b/winfsp_wrs/src/callback.rs @@ -84,7 +84,7 @@ pub trait FileSystemContext { fn get_volume_info(&self) -> Result; /// Set volume label. - fn set_volume_label(&self, _volume_label: &U16CStr) -> Result<(), NTSTATUS> { + fn set_volume_label(&self, _volume_label: &U16CStr) -> Result { Err(STATUS_NOT_IMPLEMENTED) } @@ -380,7 +380,10 @@ impl Interface { match C::set_volume_label(fs, U16CStr::from_ptr_str(volume_label)) { Err(e) => e, - Ok(()) => Self::get_volume_info_ext::(file_system, volume_info), + Ok(vi) => { + *volume_info = vi.0; + STATUS_SUCCESS + } } } diff --git a/winfsp_wrs/src/file_system.rs b/winfsp_wrs/src/file_system.rs index 1ba2d24..73fe92d 100644 --- a/winfsp_wrs/src/file_system.rs +++ b/winfsp_wrs/src/file_system.rs @@ -354,6 +354,9 @@ impl FileSystem { let device_name = params.volume_params.device_path(); let res = FspFileSystemCreate( + // `device_name` contains const data, so this `cast_mut` is a bit scary ! + // However, it is only a limitation in the type system (we need to cast + // to `PWSTR`): in practice this parameter in never modified. device_name.as_ptr().cast_mut(), ¶ms.volume_params.0, interface, diff --git a/winfsp_wrs/src/info.rs b/winfsp_wrs/src/info.rs index 223b228..4e66d5f 100644 --- a/winfsp_wrs/src/info.rs +++ b/winfsp_wrs/src/info.rs @@ -1,4 +1,4 @@ -use widestring::U16CStr; +use widestring::{U16CStr, U16Str}; use crate::{ ext::{FSP_FSCTL_DIR_INFO, FSP_FSCTL_FILE_INFO, FSP_FSCTL_VOLUME_INFO}, @@ -117,43 +117,73 @@ impl FileInfo { } } -#[derive(Debug, Default, Copy, Clone)] +#[derive(Debug, Default, Clone)] pub struct VolumeInfo(pub(crate) FSP_FSCTL_VOLUME_INFO); -impl VolumeInfo { - const VOLUME_LABEL_MAX_LEN: usize = 31; +#[derive(Debug)] +pub struct VolumeLabelNameTooLong; - pub fn new(total_size: u64, free_size: u64, volume_label: &U16CStr) -> Self { - assert!(volume_label.len() <= Self::VOLUME_LABEL_MAX_LEN); +impl VolumeInfo { + // Max len correspond to the entire `FSP_FSCTL_VOLUME_INFO.VolumeLabel` buffer given + // there should be no null-terminator (`FSP_FSCTL_VOLUME_INFO.VolumeLabelLength` is + // used instead). + const VOLUME_LABEL_MAX_LEN: usize = 32; + + pub fn new( + total_size: u64, + free_size: u64, + volume_label: &U16Str, + ) -> Result { + if volume_label.len() > Self::VOLUME_LABEL_MAX_LEN { + return Err(VolumeLabelNameTooLong); + } - let mut vl = [0; Self::VOLUME_LABEL_MAX_LEN + 1]; + let mut vl = [0; Self::VOLUME_LABEL_MAX_LEN]; vl[..volume_label.len()].copy_from_slice(volume_label.as_slice()); - Self(FSP_FSCTL_VOLUME_INFO { + Ok(Self(FSP_FSCTL_VOLUME_INFO { TotalSize: total_size, FreeSize: free_size, + // It is unintuitive, but the length is in bytes, not in u16s VolumeLabelLength: (volume_label.len() * std::mem::size_of::()) as u16, VolumeLabel: vl, - }) + })) } pub fn total_size(&self) -> u64 { self.0.TotalSize } + pub fn set_total_size(&mut self, size: u64) { + self.0.TotalSize = size; + } + pub fn free_size(&self) -> u64 { self.0.FreeSize } - pub fn volume_label(&self) -> &U16CStr { - U16CStr::from_slice(&self.0.VolumeLabel[..self.0.VolumeLabelLength as usize]).unwrap() + pub fn set_free_size(&mut self, size: u64) { + self.0.FreeSize = size; } - pub fn set_volume_label(&mut self, volume_label: &U16CStr) { - assert!(volume_label.len() <= Self::VOLUME_LABEL_MAX_LEN); + pub fn volume_label(&self) -> &U16Str { + let len_in_u16s = self.0.VolumeLabelLength as usize / std::mem::size_of::(); + U16Str::from_slice(&self.0.VolumeLabel[..len_in_u16s]) + } + + pub fn set_volume_label( + &mut self, + volume_label: &U16Str, + ) -> Result<(), VolumeLabelNameTooLong> { + if volume_label.len() > Self::VOLUME_LABEL_MAX_LEN { + return Err(VolumeLabelNameTooLong); + } - self.0.VolumeLabelLength = volume_label.len() as u16; + // It is unintuitive, but the length is in bytes, not in u16s + self.0.VolumeLabelLength = (volume_label.len() * std::mem::size_of::()) as u16; self.0.VolumeLabel[..volume_label.len()].copy_from_slice(volume_label.as_slice()); + + Ok(()) } }