Skip to content

Commit

Permalink
Improve callback API
Browse files Browse the repository at this point in the history
  • Loading branch information
touilleMan committed Mar 15, 2024
1 parent 8f9b655 commit 1408dc8
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 149 deletions.
83 changes: 50 additions & 33 deletions examples/memfs/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use std::{
sync::{Arc, Mutex},
};
use winfsp_wrs::{
filetime_now, u16cstr, u16str, CleanupFlags, CreateFileInfo, CreateOptions, FileAccessRights,
FileAttributes, FileInfo, FileSystem, FileSystemContext, PSecurityDescriptor, Params,
SecurityDescriptor, U16CStr, U16CString, U16Str, VolumeInfo, VolumeParams, WriteMode, NTSTATUS,
STATUS_ACCESS_DENIED, STATUS_DIRECTORY_NOT_EMPTY, STATUS_END_OF_FILE,
filetime_now, u16cstr, u16str, CleanupFlags, CreateFileInfo, CreateOptions, DirInfo,
FileAccessRights, FileAttributes, FileInfo, FileSystem, FileSystemContext, PSecurityDescriptor,
Params, SecurityDescriptor, U16CStr, U16CString, U16Str, VolumeInfo, VolumeParams, WriteMode,
NTSTATUS, STATUS_ACCESS_DENIED, STATUS_DIRECTORY_NOT_EMPTY, STATUS_END_OF_FILE,
STATUS_MEDIA_WRITE_PROTECTED, STATUS_NOT_A_DIRECTORY, STATUS_OBJECT_NAME_COLLISION,
STATUS_OBJECT_NAME_NOT_FOUND,
};
Expand Down Expand Up @@ -285,7 +285,7 @@ impl FileSystemContext for MemFs {
security_descriptor: SecurityDescriptor,
_buffer: &[u8],
_extra_buffer_is_reparse_point: bool,
) -> Result<Self::FileContext, NTSTATUS> {
) -> Result<(Self::FileContext, FileInfo), NTSTATUS> {
if self.read_only {
return Err(STATUS_MEDIA_WRITE_PROTECTED);
}
Expand All @@ -299,7 +299,7 @@ impl FileSystemContext for MemFs {
return Err(STATUS_OBJECT_NAME_COLLISION);
}

let file_obj = Arc::new(Mutex::new(
let file_context = Arc::new(Mutex::new(
if create_file_info
.create_options
.is(CreateOptions::FILE_DIRECTORY_FILE)
Expand All @@ -319,21 +319,27 @@ impl FileSystemContext for MemFs {
},
));

entries.insert(file_name, file_obj.clone());
entries.insert(file_name, file_context.clone());

Ok(file_obj)
let file_info = self.get_file_info(file_context.clone())?;

Ok((file_context, file_info))
}

fn open(
&self,
file_name: &U16CStr,
_create_options: CreateOptions,
_granted_access: FileAccessRights,
) -> Result<Self::FileContext, NTSTATUS> {
) -> Result<(Self::FileContext, FileInfo), NTSTATUS> {
let file_name = PathBuf::from(file_name.to_os_string());

match self.entries.lock().unwrap().get(&file_name) {
Some(entry) => Ok(entry.clone()),
Some(entry) => {
let file_context = entry.clone();
let file_info = self.get_file_info(file_context.clone())?;
Ok((file_context, file_info))
}
None => Err(STATUS_OBJECT_NAME_NOT_FOUND),
}
}
Expand All @@ -345,7 +351,7 @@ impl FileSystemContext for MemFs {
replace_file_attributes: bool,
allocation_size: u64,
_buffer: &[u8],
) -> Result<(), NTSTATUS> {
) -> Result<FileInfo, NTSTATUS> {
if self.read_only {
return Err(STATUS_MEDIA_WRITE_PROTECTED);
}
Expand All @@ -369,11 +375,11 @@ impl FileSystemContext for MemFs {
file_obj.info.set_last_access_time(now);
file_obj.info.set_last_write_time(now);
file_obj.info.set_change_time(now);

Ok(())
} else {
unreachable!()
}

self.get_file_info(file_context)
}

fn cleanup(
Expand All @@ -396,9 +402,9 @@ impl FileSystemContext for MemFs {

// Set archive bit
if flags.is(CleanupFlags::SET_ARCHIVE_BIT) {
file_obj.info.set_file_attributes(
FileAttributes::ARCHIVE | file_obj.info.file_attributes(),
);
file_obj
.info
.set_file_attributes(FileAttributes::ARCHIVE | file_obj.info.file_attributes());
}

let now = filetime_now();
Expand Down Expand Up @@ -455,29 +461,32 @@ impl FileSystemContext for MemFs {
&self,
file_context: Self::FileContext,
buffer: &[u8],
offset: u64,
mode: WriteMode,
) -> Result<usize, NTSTATUS> {
) -> Result<(usize, FileInfo), NTSTATUS> {
if self.read_only {
return Err(STATUS_MEDIA_WRITE_PROTECTED);
}

if let Obj::File(file_obj) = file_context.lock().unwrap().deref_mut() {
let written = if let Obj::File(file_obj) = file_context.lock().unwrap().deref_mut() {
match mode {
WriteMode::Constrained => Ok(file_obj.constrained_write(buffer, offset as usize)),
WriteMode::Normal => Ok(file_obj.write(buffer, offset as usize)),
WriteMode::StartEOF => {
WriteMode::Normal { offset } => file_obj.write(buffer, offset as usize),
WriteMode::ConstrainedIO { offset } => {
file_obj.constrained_write(buffer, offset as usize)
}
WriteMode::WriteToEOF => {
let offset = file_obj.info.file_size();
Ok(file_obj.write(buffer, offset as usize))
file_obj.write(buffer, offset as usize)
}
}
} else {
unreachable!()
}
};

Ok((written, self.get_file_info(file_context)?))
}

fn flush(&self, _file_context: Self::FileContext) -> Result<(), NTSTATUS> {
Ok(())
fn flush(&self, file_context: Self::FileContext) -> Result<FileInfo, NTSTATUS> {
self.get_file_info(file_context)
}

fn get_file_info(&self, file_context: Self::FileContext) -> Result<FileInfo, NTSTATUS> {
Expand All @@ -495,7 +504,7 @@ impl FileSystemContext for MemFs {
last_access_time: u64,
last_write_time: u64,
change_time: u64,
) -> Result<(), NTSTATUS> {
) -> Result<FileInfo, NTSTATUS> {
if self.read_only {
return Err(STATUS_MEDIA_WRITE_PROTECTED);
}
Expand Down Expand Up @@ -537,15 +546,15 @@ impl FileSystemContext for MemFs {
}
}

Ok(())
self.get_file_info(file_context)
}

fn set_file_size(
&self,
file_context: Self::FileContext,
new_size: u64,
set_allocation_size: bool,
) -> Result<(), NTSTATUS> {
) -> Result<FileInfo, NTSTATUS> {
if self.read_only {
return Err(STATUS_MEDIA_WRITE_PROTECTED);
}
Expand All @@ -563,7 +572,7 @@ impl FileSystemContext for MemFs {
}
}

Ok(())
self.get_file_info(file_context)
}

fn rename(
Expand Down Expand Up @@ -655,7 +664,8 @@ impl FileSystemContext for MemFs {
&self,
file_context: Self::FileContext,
marker: Option<&U16CStr>,
) -> Result<Vec<(U16CString, FileInfo)>, NTSTATUS> {
mut add_dir_info: impl FnMut(DirInfo) -> bool,
) -> Result<(), NTSTATUS> {
let entries = self.entries.lock().unwrap();

match file_context.lock().unwrap().deref() {
Expand Down Expand Up @@ -683,7 +693,7 @@ impl FileSystemContext for MemFs {
res_entries.push((
U16CString::from_os_str(entry_path.file_name().unwrap()).unwrap(),
FileInfo::from(entry_obj.deref()),
))
));
}

res_entries.sort_by(|x, y| y.0.cmp(&x.0));
Expand All @@ -697,7 +707,14 @@ impl FileSystemContext for MemFs {

res_entries.reverse();

Ok(res_entries)
for (file_name, file_info) in res_entries {
let dir_info = DirInfo::new(file_info, &file_name);
if !add_dir_info(dir_info) {
break;
}
}

Ok(())
}
}
}
Expand Down
17 changes: 10 additions & 7 deletions examples/minimal/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::sync::Arc;

use winfsp_wrs::{
filetime_now, u16cstr, u16str, CreateOptions, FileAccessRights, FileAttributes, FileInfo,
FileSystem, FileSystemContext, PSecurityDescriptor, Params, SecurityDescriptor, U16CStr,
U16CString, U16Str, VolumeInfo, VolumeParams, NTSTATUS,
filetime_now, u16cstr, u16str, CreateOptions, DirInfo, FileAccessRights, FileAttributes,
FileInfo, FileSystem, FileSystemContext, PSecurityDescriptor, Params, SecurityDescriptor,
U16CStr, U16Str, VolumeInfo, VolumeParams, NTSTATUS,
};

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -68,8 +68,10 @@ impl FileSystemContext for MemFs {
_file_name: &U16CStr,
_create_options: CreateOptions,
_granted_access: FileAccessRights,
) -> Result<Self::FileContext, NTSTATUS> {
Ok(Arc::new(self.file_context.clone()))
) -> Result<(Self::FileContext, FileInfo), NTSTATUS> {
let file_context = Arc::new(self.file_context.clone());
let file_info = self.file_context.info;
Ok((file_context, file_info))
}

fn get_file_info(&self, _file_context: Self::FileContext) -> Result<FileInfo, NTSTATUS> {
Expand All @@ -84,8 +86,9 @@ impl FileSystemContext for MemFs {
&self,
_file_context: Self::FileContext,
_marker: Option<&U16CStr>,
) -> Result<Vec<(U16CString, FileInfo)>, NTSTATUS> {
Ok(vec![])
_add_dir_info: impl FnMut(DirInfo) -> bool,
) -> Result<(), NTSTATUS> {
Ok(())
}
}

Expand Down
Loading

0 comments on commit 1408dc8

Please sign in to comment.