From 7c6636a3cd7f53c8dc72bf107b86851df8b2abdb Mon Sep 17 00:00:00 2001 From: Nicolas Kagami Date: Wed, 1 Jan 2025 15:07:37 -0300 Subject: [PATCH 1/4] Add test for multiple custom extensions --- openssl/src/ssl/test/mod.rs | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/openssl/src/ssl/test/mod.rs b/openssl/src/ssl/test/mod.rs index 282558f80..77c090d4b 100644 --- a/openssl/src/ssl/test/mod.rs +++ b/openssl/src/ssl/test/mod.rs @@ -1268,7 +1268,8 @@ fn no_version_overlap() { #[test] #[cfg(ossl111)] fn custom_extensions() { - static FOUND_EXTENSION: AtomicBool = AtomicBool::new(false); + static FOUND_EXTENSION_1: AtomicBool = AtomicBool::new(false); + static FOUND_EXTENSION_2: AtomicBool = AtomicBool::new(false); let mut server = Server::builder(); server @@ -1278,7 +1279,19 @@ fn custom_extensions() { ExtensionContext::CLIENT_HELLO, |_, _, _| -> Result, _> { unreachable!() }, |_, _, data, _| { - FOUND_EXTENSION.store(data == b"hello", Ordering::SeqCst); + FOUND_EXTENSION_1.store(data == b"hello", Ordering::SeqCst); + Ok(()) + }, + ) + .unwrap(); + server + .ctx() + .add_custom_ext( + 23456, + ExtensionContext::CLIENT_HELLO, + |_, _, _| -> Result, _> { unreachable!() }, + |_, _, data, _| { + FOUND_EXTENSION_2.store(data == b"another hello", Ordering::SeqCst); Ok(()) }, ) @@ -1296,10 +1309,20 @@ fn custom_extensions() { |_, _, _, _| unreachable!(), ) .unwrap(); + client + .ctx() + .add_custom_ext( + 23456, + ssl::ExtensionContext::CLIENT_HELLO, + |_, _, _| Ok(Some(b"another hello")), + |_, _, _, _| unreachable!(), + ) + .unwrap(); client.connect(); - assert!(FOUND_EXTENSION.load(Ordering::SeqCst)); + assert!(FOUND_EXTENSION_1.load(Ordering::SeqCst)); + assert!(FOUND_EXTENSION_2.load(Ordering::SeqCst)); } fn _check_kinds() { From 58b418ad4835641c8df6d08aa417b43f2a65f1be Mon Sep 17 00:00:00 2001 From: Nicolas Kagami Date: Wed, 1 Jan 2025 15:08:14 -0300 Subject: [PATCH 2/4] Store custom extension callbacks per extension type --- openssl/src/ssl/callbacks.rs | 12 +++++++---- openssl/src/ssl/mod.rs | 40 ++++++++++++++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/openssl/src/ssl/callbacks.rs b/openssl/src/ssl/callbacks.rs index f7e51a5d3..44163231f 100644 --- a/openssl/src/ssl/callbacks.rs +++ b/openssl/src/ssl/callbacks.rs @@ -562,7 +562,7 @@ pub struct CustomExtAddState(Option); #[cfg(ossl111)] pub extern "C" fn raw_custom_ext_add( ssl: *mut ffi::SSL, - _: c_uint, + ext_type: c_uint, context: c_uint, out: *mut *const c_uchar, outlen: *mut size_t, @@ -580,9 +580,11 @@ where { unsafe { let ssl = SslRef::from_ptr_mut(ssl); + let cb_key = + SslContext::get_custom_ext_cb_key(ext_type as u16, super::CustomExtCbType::Add); let callback = ssl .ssl_context() - .ex_data(SslContext::cached_ex_index::()) + .ex_data(SslContext::cached_custom_ext_ex_index(cb_key)) .expect("BUG: custom ext add callback missing") as *const F; let ectx = ExtensionContext::from_bits_truncate(context); let cert = if ectx.contains(ExtensionContext::TLS1_3_CERTIFICATE) { @@ -640,7 +642,7 @@ pub extern "C" fn raw_custom_ext_free( #[cfg(ossl111)] pub extern "C" fn raw_custom_ext_parse( ssl: *mut ffi::SSL, - _: c_uint, + ext_type: c_uint, context: c_uint, input: *const c_uchar, inlen: size_t, @@ -657,9 +659,11 @@ where { unsafe { let ssl = SslRef::from_ptr_mut(ssl); + let cb_key = + SslContext::get_custom_ext_cb_key(ext_type as u16, super::CustomExtCbType::Parse); let callback = ssl .ssl_context() - .ex_data(SslContext::cached_ex_index::()) + .ex_data(SslContext::cached_custom_ext_ex_index(cb_key)) .expect("BUG: custom ext parse callback missing") as *const F; let ectx = ExtensionContext::from_bits_truncate(context); #[allow(clippy::unnecessary_cast)] diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index f5a696ab5..5e7595c8e 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -554,7 +554,14 @@ impl NameType { } } +enum CustomExtCbType { + Add, + Parse, +} + static INDEXES: Lazy>> = Lazy::new(|| Mutex::new(HashMap::new())); +static CUSTOM_EXT_INDEXES: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::new())); static SSL_INDEXES: Lazy>> = Lazy::new(|| Mutex::new(HashMap::new())); static SESSION_CTX_INDEX: OnceCell> = OnceCell::new(); @@ -1642,8 +1649,16 @@ impl SslContextBuilder { + Send, { let ret = unsafe { - self.set_ex_data(SslContext::cached_ex_index::(), add_cb); - self.set_ex_data(SslContext::cached_ex_index::(), parse_cb); + let add_cb_key = SslContext::get_custom_ext_cb_key(ext_type, CustomExtCbType::Add); + let parse_cb_key = SslContext::get_custom_ext_cb_key(ext_type, CustomExtCbType::Parse); + self.set_ex_data( + SslContext::cached_custom_ext_ex_index::(add_cb_key), + add_cb, + ); + self.set_ex_data( + SslContext::cached_custom_ext_ex_index::(parse_cb_key), + parse_cb, + ); ffi::SSL_CTX_add_custom_ext( self.as_ptr(), @@ -1827,6 +1842,27 @@ impl SslContext { } } + fn get_custom_ext_cb_key(ext_type: u16, cb_type: CustomExtCbType) -> u64 { + match cb_type { + CustomExtCbType::Add => ext_type as u64, + CustomExtCbType::Parse => ext_type as u64 | 0x8000_0000_0000_0000, + } + } + + fn cached_custom_ext_ex_index(key: u64) -> Index + where + T: 'static + Sync + Send, + { + unsafe { + let idx = *CUSTOM_EXT_INDEXES + .lock() + .unwrap_or_else(|e| e.into_inner()) + .entry(key) + .or_insert_with(|| SslContext::new_ex_index::().unwrap().as_raw()); + Index::from_raw(idx) + } + } + // FIXME should return a result? fn cached_ex_index() -> Index where From 2137c189ffec139c5fb7a7a962e3f25900d1dc71 Mon Sep 17 00:00:00 2001 From: Nicolas Kagami Date: Wed, 1 Jan 2025 16:24:20 -0300 Subject: [PATCH 3/4] Add ossl111 cfg flag --- openssl/src/ssl/mod.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index 5e7595c8e..2d1f488c2 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -554,16 +554,18 @@ impl NameType { } } +#[cfg(ossl111)] enum CustomExtCbType { Add, Parse, } static INDEXES: Lazy>> = Lazy::new(|| Mutex::new(HashMap::new())); -static CUSTOM_EXT_INDEXES: Lazy>> = - Lazy::new(|| Mutex::new(HashMap::new())); static SSL_INDEXES: Lazy>> = Lazy::new(|| Mutex::new(HashMap::new())); static SESSION_CTX_INDEX: OnceCell> = OnceCell::new(); +#[cfg(ossl111)] +static CUSTOM_EXT_INDEXES: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::new())); fn try_get_session_ctx_index() -> Result<&'static Index, ErrorStack> { SESSION_CTX_INDEX.get_or_try_init(Ssl::new_ex_index) @@ -1842,6 +1844,7 @@ impl SslContext { } } + #[cfg(ossl111)] fn get_custom_ext_cb_key(ext_type: u16, cb_type: CustomExtCbType) -> u64 { match cb_type { CustomExtCbType::Add => ext_type as u64, @@ -1849,6 +1852,7 @@ impl SslContext { } } + #[cfg(ossl111)] fn cached_custom_ext_ex_index(key: u64) -> Index where T: 'static + Sync + Send, From c5d0b8e6341f8be0037eb7b756c2ffd71c186588 Mon Sep 17 00:00:00 2001 From: Nicolas Kagami Date: Wed, 1 Jan 2025 23:19:20 -0300 Subject: [PATCH 4/4] Separate custom extensions callback conflict test --- openssl/src/ssl/test/mod.rs | 70 ++++++++++++++++++++++++++----------- 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/openssl/src/ssl/test/mod.rs b/openssl/src/ssl/test/mod.rs index 77c090d4b..4f54156f3 100644 --- a/openssl/src/ssl/test/mod.rs +++ b/openssl/src/ssl/test/mod.rs @@ -1268,8 +1268,7 @@ fn no_version_overlap() { #[test] #[cfg(ossl111)] fn custom_extensions() { - static FOUND_EXTENSION_1: AtomicBool = AtomicBool::new(false); - static FOUND_EXTENSION_2: AtomicBool = AtomicBool::new(false); + static FOUND_EXTENSION: AtomicBool = AtomicBool::new(false); let mut server = Server::builder(); server @@ -1279,19 +1278,7 @@ fn custom_extensions() { ExtensionContext::CLIENT_HELLO, |_, _, _| -> Result, _> { unreachable!() }, |_, _, data, _| { - FOUND_EXTENSION_1.store(data == b"hello", Ordering::SeqCst); - Ok(()) - }, - ) - .unwrap(); - server - .ctx() - .add_custom_ext( - 23456, - ExtensionContext::CLIENT_HELLO, - |_, _, _| -> Result, _> { unreachable!() }, - |_, _, data, _| { - FOUND_EXTENSION_2.store(data == b"another hello", Ordering::SeqCst); + FOUND_EXTENSION.store(data == b"hello", Ordering::SeqCst); Ok(()) }, ) @@ -1309,22 +1296,65 @@ fn custom_extensions() { |_, _, _, _| unreachable!(), ) .unwrap(); - client + + client.connect(); + + assert!(FOUND_EXTENSION.load(Ordering::SeqCst)); +} + +#[test] +#[cfg(ossl111)] +fn custom_extensions_with_same_callback_signature() { + static FOUND_EXTENSION_1: AtomicBool = AtomicBool::new(false); + static FOUND_EXTENSION_2: AtomicBool = AtomicBool::new(false); + + fn insert_custom_extension(builder: &mut SslContextBuilder, ext_type: u16, data: Vec) { + builder + .add_custom_ext( + ext_type, + ExtensionContext::CLIENT_HELLO, + move |_, _, _| Ok(Some(data.clone())), + |_, _, _, _| Ok(()), + ) + .expect("Failed to add custom extension"); + } + + let mut server = Server::builder(); + server + .ctx() + .add_custom_ext( + 12345, + ExtensionContext::CLIENT_HELLO, + |_, _, _| -> Result, _> { unreachable!() }, + move |_, _, data, _| { + FOUND_EXTENSION_1.store(data == b"some data".to_vec(), Ordering::SeqCst); + Ok(()) + }, + ) + .unwrap(); + server .ctx() .add_custom_ext( 23456, - ssl::ExtensionContext::CLIENT_HELLO, - |_, _, _| Ok(Some(b"another hello")), - |_, _, _, _| unreachable!(), + ExtensionContext::CLIENT_HELLO, + |_, _, _| -> Result, _> { unreachable!() }, + move |_, _, data, _| { + FOUND_EXTENSION_2.store(data == b"some other data".to_vec(), Ordering::SeqCst); + Ok(()) + }, ) .unwrap(); + let server = server.build(); + + let mut client = server.client(); + insert_custom_extension(client.ctx(), 12345, b"some data".to_vec()); + insert_custom_extension(client.ctx(), 23456, b"some other data".to_vec()); client.connect(); assert!(FOUND_EXTENSION_1.load(Ordering::SeqCst)); assert!(FOUND_EXTENSION_2.load(Ordering::SeqCst)); } - fn _check_kinds() { fn is_send() {} fn is_sync() {}