Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mark set_print as unsafe #374

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion examples/capable/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,12 @@ fn main() -> Result<()> {

let mut skel_builder = CapableSkelBuilder::default();
if opts.debug {
skel_builder.obj_builder.debug(true);
unsafe {
// SAFETY:
// no other thread is running which could cause undefined behaviour due
// to this call.
skel_builder.obj_builder.debug(true);
}
}

bump_memlock_rlimit()?;
Expand Down
7 changes: 6 additions & 1 deletion examples/runqslower/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,12 @@ fn main() -> Result<()> {

let mut skel_builder = RunqslowerSkelBuilder::default();
if opts.verbose {
skel_builder.obj_builder.debug(true);
unsafe {
// SAFETY:
// no other thread is running which could cause undefined behaviour due
// to this call.
skel_builder.obj_builder.debug(true);
}
}

bump_memlock_rlimit()?;
Expand Down
7 changes: 6 additions & 1 deletion examples/tproxy/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ fn main() -> Result<()> {

let mut skel_builder = TproxySkelBuilder::default();
if opts.verbose {
skel_builder.obj_builder.debug(true);
unsafe {
// SAFETY:
// no other thread is running which could cause undefined behaviour due
// to this call.
skel_builder.obj_builder.debug(true);
}
}

// Set constants
Expand Down
16 changes: 10 additions & 6 deletions libbpf-rs/src/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,18 @@ impl ObjectBuilder {

/// Option to print debug output to stderr.
///
/// # Safety
///
/// See [`set_print`](set_print#safety).
///
/// Also it is guarenteed that the only parts of the `ObjectBuilder` that interacts with libbpf
/// are the [`open_file`](Self::open_file) and [`open_memory`](Self::open_memory) methods.
///
/// Note: This function uses [`set_print`] internally and will overwrite any callbacks
/// currently in use.
pub fn debug(&mut self, dbg: bool) -> &mut Self {
if dbg {
set_print(Some((PrintLevel::Debug, |_, s| print!("{s}"))));
} else {
set_print(None);
}
pub unsafe fn debug(&mut self, dbg: bool) -> &mut Self {
let callback = dbg.then(|| (PrintLevel::Debug, (|_, s| print!("{s}")) as PrintCallback));
unsafe { set_print(callback) };
self
}

Expand Down
48 changes: 22 additions & 26 deletions libbpf-rs/src/print.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ extern "C" fn outer_print_cb(
0 // return value is ignored by libbpf
}

// allow main to reinforce the idea that set print should be used before any other interaction
// with the library.
#[allow(clippy::needless_doctest_main)]
/// Set a callback to receive log messages from libbpf, instead of printing them to stderr.
///
/// # Arguments
Expand All @@ -78,6 +81,12 @@ extern "C" fn outer_print_cb(
///
/// This overrides (and is overridden by) [`ObjectBuilder::debug`]
///
/// # Safety
///
/// This function mutates a global variable without syncronization inside libbpf. This means that
/// it is not thread safe and **cannot** be called concurrently with any other function
/// in this library.
///
/// # Examples
///
/// To pass all messages to the `log` crate:
Expand All @@ -93,26 +102,26 @@ extern "C" fn outer_print_cb(
/// }
/// }
///
/// set_print(Some((PrintLevel::Debug, print_to_log)));
/// fn main() {
/// unsafe {
/// // SAFETY: this is being setup before interacting with any other part of the library
/// set_print(Some((PrintLevel::Debug, print_to_log)))
/// };
/// }
/// ```
///
/// To disable printing completely:
///
/// ```
/// use libbpf_rs::set_print;
/// set_print(None);
/// ```
///
/// To temporarliy suppress output:
///
/// ```
/// use libbpf_rs::set_print;
///
/// let prev = set_print(None);
/// // do things quietly
/// set_print(prev);
/// fn main() {
/// unsafe {
/// // SAFETY: this is being setup before interacting with any other part of the library
/// set_print(None)
/// };
/// }
/// ```
pub fn set_print(
pub unsafe fn set_print(
mut callback: Option<(PrintLevel, PrintCallback)>,
) -> Option<(PrintLevel, PrintCallback)> {
let real_cb: libbpf_sys::libbpf_print_fn_t = callback.as_ref().and(Some(outer_print_cb));
Expand All @@ -122,19 +131,6 @@ pub fn set_print(
}

/// Return the current print callback and level.
///
/// # Examples
///
/// To temporarily suppress output:
///
/// ```
/// use libbpf_rs::{get_print, set_print};
///
/// let prev = get_print();
/// set_print(None);
/// // do things quietly
/// set_print(prev);
/// ```
pub fn get_print() -> Option<(PrintLevel, PrintCallback)> {
*PRINT_CB.lock().unwrap()
}
12 changes: 6 additions & 6 deletions libbpf-rs/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub fn open_test_object(filename: &str) -> OpenObject {
// cargo test -- --nocapture
//
// To get all the output
builder.debug(true);
unsafe { builder.debug(true) };
builder.open_file(obj_path).expect("failed to open object")
}

Expand Down Expand Up @@ -112,8 +112,8 @@ fn test_object_build_from_memory() {
#[test]
fn test_object_load_invalid() {
let empty_file = NamedTempFile::new().unwrap();
let _err = ObjectBuilder::default()
.debug(true)
let mut builder = ObjectBuilder::default();
let _err = unsafe { builder.debug(true) }
.open_file(empty_file.path())
.unwrap_err();
}
Expand Down Expand Up @@ -488,7 +488,7 @@ fn test_object_reuse_pined_map() {
// Reuse the pinned map
let obj_path = get_test_object_path("runqslower.bpf.o");
let mut builder = ObjectBuilder::default();
builder.debug(true);
unsafe { builder.debug(true) };
let mut open_obj = builder.open_file(obj_path).expect("failed to open object");

let start = open_obj.map_mut("start").expect("failed to find map");
Expand Down Expand Up @@ -1035,8 +1035,8 @@ fn test_object_link_files() {
let () = linker.link().unwrap();

// Check that we can load the resulting object file.
let _object = ObjectBuilder::default()
.debug(true)
let mut builder = ObjectBuilder::default();
let _object = unsafe { builder.debug(true) }
.open_file(output_file.path())
.unwrap();
}
Expand Down
20 changes: 14 additions & 6 deletions libbpf-rs/tests/test_print.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ fn test_set_print() {
}
}

set_print(Some((PrintLevel::Debug, callback)));
unsafe {
set_print(Some((PrintLevel::Debug, callback)));
}
// expect_err requires that OpenObject implement Debug, which it does not.
let obj = ObjectBuilder::default().open_file("/dev/null");
assert!(obj.is_err(), "Successfully loaded /dev/null?");
Expand All @@ -46,11 +48,15 @@ fn test_set_restore_print() {
println!("two");
}

set_print(Some((PrintLevel::Warn, callback1)));
unsafe {
set_print(Some((PrintLevel::Warn, callback1)));
}
let prev = get_print();
assert_eq!(prev, Some((PrintLevel::Warn, callback1 as PrintCallback)));

set_print(Some((PrintLevel::Debug, callback2)));
unsafe {
set_print(Some((PrintLevel::Debug, callback2)));
}
let prev = get_print();
assert_eq!(prev, Some((PrintLevel::Debug, callback2 as PrintCallback)));
}
Expand All @@ -65,10 +71,12 @@ fn test_set_and_save_print() {
println!("two");
}

set_print(Some((PrintLevel::Warn, callback1)));
let prev = set_print(Some((PrintLevel::Debug, callback2)));
unsafe {
set_print(Some((PrintLevel::Warn, callback1)));
}
let prev = unsafe { set_print(Some((PrintLevel::Debug, callback2))) };
assert_eq!(prev, Some((PrintLevel::Warn, callback1 as PrintCallback)));

let prev = set_print(None);
let prev = unsafe { set_print(None) };
assert_eq!(prev, Some((PrintLevel::Debug, callback2 as PrintCallback)));
}