diff --git a/src/_internal_test_exports/fuzz.rs b/src/_internal_test_exports/fuzz.rs index 9efb9274..a24c64aa 100644 --- a/src/_internal_test_exports/fuzz.rs +++ b/src/_internal_test_exports/fuzz.rs @@ -4,8 +4,7 @@ use std::time::Duration; use std::time::Instant; use crate::change::{SdpAnswer, SdpOffer}; -use crate::crypto::KeyingMaterial; -use crate::crypto::SrtpProfile; +use crate::crypto::{CryptoProviderId, KeyingMaterial, SrtpProfile}; use crate::format::Codec; use crate::packet::{DepacketizingBuffer, RtpMeta}; use crate::rtp_::{Frequency, MediaTime, RtpHeader}; @@ -49,12 +48,14 @@ pub fn rtp_header(data: &[u8]) -> Option<()> { #[cfg(feature = "_internal_test_exports")] pub fn rtp_packet(data: &[u8]) -> Option<()> { use crate::Session; + let crypto_provider = CryptoProviderId::default().into(); let mut rng = Rng::new(data); let config = random_config(&mut rng)?; let mut session = Session::new(&config); session.set_keying_material( + crypto_provider, KeyingMaterial::new(rng.slice(16)?.to_vec()), SrtpProfile::PassThrough, rng.bool()?, diff --git a/src/crypto/dtls.rs b/src/crypto/dtls.rs index 0d15f521..3a732d51 100644 --- a/src/crypto/dtls.rs +++ b/src/crypto/dtls.rs @@ -2,19 +2,46 @@ use std::collections::VecDeque; use std::fmt; +use std::panic::UnwindSafe; use std::time::Instant; use crate::net::DatagramSend; -use super::{CryptoError, Fingerprint, KeyingMaterial, SrtpProfile}; +use super::{ + CryptoError, CryptoProvider, CryptoProviderId, Fingerprint, KeyingMaterial, SrtpProfile, +}; -// libWebRTC says "WebRTC" here when doing OpenSSL, for BoringSSL they seem -// to generate a random 8 characters. -// https://webrtc.googlesource.com/src/+/1568f1b1330f94494197696fe235094e6293b258/rtc_base/rtc_certificate_generator.cc#27 -// -// Pion also sets this to "WebRTC", maybe for compatibility reasons. -// https://github.com/pion/webrtc/blob/eed2bb2d3b9f204f9de1cd7e1046ca5d652778d2/constants.go#L31 -pub const DTLS_CERT_IDENTITY: &str = "WebRTC"; +pub(crate) trait DtlsIdentity: fmt::Debug { + fn fingerprint(&self) -> Fingerprint; + fn create_context(&self) -> Result, CryptoError>; + fn crypto_provider(&self) -> CryptoProvider; + fn boxed_clone(&self) -> Box; +} + +pub(crate) trait DtlsContext: UnwindSafe + Send + Sync { + // Returns the crypto context. + fn crypto_provider(&self) -> CryptoProvider; + + // Returns the local certificate fingerprint. + fn local_fingerprint(&self) -> Fingerprint; + + // DTLS session management + fn set_active(&mut self, active: bool) -> (); + fn is_active(&self) -> Option; + fn is_connected(&self) -> bool; + fn handle_handshake( + &mut self, + out_events: &mut VecDeque, + ) -> Result; + fn handle_receive( + &mut self, + datagram: &[u8], + out_events: &mut VecDeque, + ) -> Result<(), CryptoError>; + fn poll_datagram(&mut self) -> Option; + fn poll_timeout(&mut self, now: Instant) -> Option; + fn handle_input(&mut self, data: &[u8]) -> Result<(), CryptoError>; +} /// Events arising from a [`Dtls`] instance. pub enum DtlsEvent { @@ -34,190 +61,49 @@ pub enum DtlsEvent { } /// Certificate used for DTLS. -#[derive(Clone)] -pub struct DtlsCert(DtlsCertInner); +pub struct DtlsCert(Box); -#[derive(Debug, Clone)] -enum DtlsCertInner { - #[cfg(feature = "openssl")] - OpenSsl(super::ossl::OsslDtlsCert), - #[cfg(feature = "wincrypto")] - WinCrypto(super::wincrypto::WinCryptoDtlsCert), +impl Clone for DtlsCert { + fn clone(&self) -> Self { + Self(self.0.boxed_clone()) + } } impl DtlsCert { + /// Create a new DtlsCert using the given provider. + pub fn new(crypto_provider_id: CryptoProviderId) -> Self { + let crypto_provider: CryptoProvider = crypto_provider_id.into(); + DtlsCert(crypto_provider.create_dtls_identity()) + } + #[cfg(feature = "openssl")] /// Create a new OpenSSL variant of the certificate. pub fn new_openssl() -> Self { - let cert = super::ossl::OsslDtlsCert::new(); - DtlsCert(DtlsCertInner::OpenSsl(cert)) - } - - #[cfg(feature = "wincrypto")] - /// Create a new Windows Crypto variant of the certificate. - pub fn new_wincrypto() -> Self { - let cert = super::wincrypto::WinCryptoDtlsCert::new(); - DtlsCert(DtlsCertInner::WinCrypto(cert)) + Self::new(super::CryptoProviderId::default()) } /// Creates a fingerprint for this certificate. /// /// Fingerprints are used to verify a remote peer's certificate. pub fn fingerprint(&self) -> Fingerprint { - match &self.0 { - #[cfg(feature = "openssl")] - DtlsCertInner::OpenSsl(v) => v.fingerprint(), - #[cfg(feature = "wincrypto")] - DtlsCertInner::WinCrypto(v) => v.fingerprint(), - _ => unreachable!(), - } - } - - pub(crate) fn create_dtls_impl(&self) -> Result { - match &self.0 { - #[cfg(feature = "openssl")] - DtlsCertInner::OpenSsl(c) => Ok(DtlsImpl::OpenSsl(super::ossl::OsslDtlsImpl::new( - c.clone(), - )?)), - #[cfg(feature = "wincrypto")] - DtlsCertInner::WinCrypto(c) => Ok(DtlsImpl::WinCrypto( - super::wincrypto::WinCryptoDtls::new(c.clone())?, - )), - _ => unreachable!(), - } - } -} - -impl fmt::Debug for DtlsCert { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self.0 { - #[cfg(feature = "openssl")] - DtlsCertInner::OpenSsl(c) => c.fmt(f), - #[cfg(feature = "wincrypto")] - DtlsCertInner::WinCrypto(c) => c.fmt(f), - _ => unreachable!(), - } + self.0.fingerprint() } -} -pub trait DtlsInner: Sized { - /// Set whether this instance is active or passive. + /// Creates a DTLS context using this certificate as the identity. /// - /// i.e. initiating the client hello or not. This must be called - /// exactly once before starting to handshake (I/O). - fn set_active(&mut self, active: bool); - - /// Handle the handshake. Once this succeeds, it becomes a no-op. - fn handle_handshake(&mut self, o: &mut VecDeque) -> Result; - - /// If set_active, returns what was set. - fn is_active(&self) -> Option; - - /// Handles an incoming DTLS datagrams. - fn handle_receive(&mut self, m: &[u8], o: &mut VecDeque) -> Result<(), CryptoError>; - - /// Poll for the next datagram to send. - fn poll_datagram(&mut self) -> Option; - - /// Poll for next timeout. This is only used during DTLS handshake. - fn poll_timeout(&mut self, now: Instant) -> Option; - - /// Handling incoming data to be sent as DTLS datagrams. - fn handle_input(&mut self, data: &[u8]) -> Result<(), CryptoError>; - - /// Whether the DTLS connection is established. - fn is_connected(&self) -> bool; -} - -pub enum DtlsImpl { - #[cfg(feature = "openssl")] - OpenSsl(super::ossl::OsslDtlsImpl), - #[cfg(feature = "wincrypto")] - WinCrypto(super::wincrypto::WinCryptoDtls), -} - -impl DtlsImpl { - pub fn set_active(&mut self, active: bool) { - match self { - #[cfg(feature = "openssl")] - DtlsImpl::OpenSsl(i) => i.set_active(active), - #[cfg(feature = "wincrypto")] - DtlsImpl::WinCrypto(i) => i.set_active(active), - _ => unreachable!(), - } + /// Multiple contexts may be created using the same identity. + pub(crate) fn create_context(&self) -> Result, CryptoError> { + self.0.create_context() } - pub fn handle_handshake(&mut self, o: &mut VecDeque) -> Result { - match self { - #[cfg(feature = "openssl")] - DtlsImpl::OpenSsl(i) => i.handle_handshake(o), - #[cfg(feature = "wincrypto")] - DtlsImpl::WinCrypto(i) => i.handle_handshake(o), - _ => unreachable!(), - } - } - - pub fn is_active(&self) -> Option { - match self { - #[cfg(feature = "openssl")] - DtlsImpl::OpenSsl(i) => i.is_active(), - #[cfg(feature = "wincrypto")] - DtlsImpl::WinCrypto(i) => i.is_active(), - _ => unreachable!(), - } - } - - pub fn handle_receive( - &mut self, - m: &[u8], - o: &mut VecDeque, - ) -> Result<(), CryptoError> { - match self { - #[cfg(feature = "openssl")] - DtlsImpl::OpenSsl(i) => i.handle_receive(m, o), - #[cfg(feature = "wincrypto")] - DtlsImpl::WinCrypto(i) => i.handle_receive(m, o), - _ => unreachable!(), - } - } - - pub fn poll_datagram(&mut self) -> Option { - match self { - #[cfg(feature = "openssl")] - DtlsImpl::OpenSsl(i) => i.poll_datagram(), - #[cfg(feature = "wincrypto")] - DtlsImpl::WinCrypto(i) => i.poll_datagram(), - _ => unreachable!(), - } - } - - pub fn poll_timeout(&mut self, now: Instant) -> Option { - match self { - #[cfg(feature = "openssl")] - DtlsImpl::OpenSsl(i) => i.poll_timeout(now), - #[cfg(feature = "wincrypto")] - DtlsImpl::WinCrypto(i) => i.poll_timeout(now), - _ => unreachable!(), - } - } - - pub fn handle_input(&mut self, data: &[u8]) -> Result<(), CryptoError> { - match self { - #[cfg(feature = "openssl")] - DtlsImpl::OpenSsl(i) => i.handle_input(data), - #[cfg(feature = "wincrypto")] - DtlsImpl::WinCrypto(i) => i.handle_input(data), - _ => unreachable!(), - } + /// Obtains the CryptoProvider that this Cert was built with. + pub(crate) fn crypto_provider(&self) -> CryptoProvider { + self.0.crypto_provider() } +} - pub fn is_connected(&self) -> bool { - match self { - #[cfg(feature = "openssl")] - DtlsImpl::OpenSsl(i) => i.is_connected(), - #[cfg(feature = "wincrypto")] - DtlsImpl::WinCrypto(i) => i.is_connected(), - _ => unreachable!(), - } +impl fmt::Debug for DtlsCert { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) } } diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 7bca0c67..53b98c1a 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -7,10 +7,11 @@ use thiserror::Error; mod ossl; #[cfg(feature = "wincrypto")] -pub mod wincrypto; +mod wincrypto; mod dtls; -pub use dtls::{DtlsCert, DtlsEvent, DtlsImpl}; +pub use dtls::{DtlsCert, DtlsEvent}; +pub(crate) use dtls::{DtlsContext, DtlsIdentity}; mod finger; pub use finger::Fingerprint; @@ -19,56 +20,11 @@ mod keying; pub use keying::KeyingMaterial; mod srtp; -pub use srtp::{aead_aes_128_gcm, aes_128_cm_sha1_80, new_aead_aes_128_gcm}; -pub use srtp::{new_aes_128_cm_sha1_80, srtp_aes_128_ecb_round, SrtpProfile}; +pub use srtp::{aead_aes_128_gcm, aes_128_cm_sha1_80, SrtpProfile}; -#[cfg(all(feature = "openssl", feature = "wincrypto"))] -compile_error!("features `openssl` and `wincrypto` are mutually exclusive"); #[cfg(not(any(feature = "openssl", feature = "wincrypto")))] compile_error!("either `openssl` or `wincrypto` must be enabled"); -/// SHA1 HMAC as used for STUN and older SRTP. -/// If sha1 feature is enabled, it uses `rust-crypto` crate. -#[cfg(feature = "sha1")] -pub fn sha1_hmac(key: &[u8], payloads: &[&[u8]]) -> [u8; 20] { - use hmac::Hmac; - use hmac::Mac; - use sha1::Sha1; - - let mut hmac = Hmac::::new_from_slice(key).expect("hmac to normalize size to 20"); - - for payload in payloads { - hmac.update(payload); - } - - hmac.finalize().into_bytes().into() -} - -/// If openssl is enabled and sha1 is not, it uses `openssl` crate. -#[cfg(all(feature = "openssl", not(feature = "sha1")))] -pub fn sha1_hmac(key: &[u8], payloads: &[&[u8]]) -> [u8; 20] { - use openssl::hash::MessageDigest; - use openssl::pkey::PKey; - use openssl::sign::Signer; - - let key = PKey::hmac(key).expect("valid hmac key"); - let mut signer = Signer::new(MessageDigest::sha1(), &key).expect("valid signer"); - - for payload in payloads { - signer.update(payload).expect("signer update"); - } - - let mut hash = [0u8; 20]; - signer.sign(&mut hash).expect("sign to array"); - hash -} - -/// If wincrypto is enabled and sha1 is not, it uses `wincrypto` crate. -#[cfg(all(feature = "wincrypto", not(feature = "sha1")))] -pub fn sha1_hmac(key: &[u8], payloads: &[&[u8]]) -> [u8; 20] { - wincrypto::sha1_hmac(key, payloads) -} - /// Errors that can arise in DTLS. #[derive(Debug, Error)] pub enum CryptoError { @@ -86,3 +42,84 @@ pub enum CryptoError { #[error("{0}")] Io(#[from] io::Error), } + +/// An ID specifying which Crypto implementation to use. +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum CryptoProviderId { + #[cfg(feature = "openssl")] + /// Use OpenSSL + OpenSsl, + #[cfg(all(feature = "openssl", feature = "sha1"))] + /// Use OpenSSL for most ciphers, but use sha1 crate for SHA1 hashes. + OpenSslWithSha1Crate, + #[cfg(feature = "wincrypto")] + /// Use Windows Cryptography APIs + WinCrypto, +} + +impl Default for CryptoProviderId { + #[allow(unreachable_code)] + fn default() -> Self { + #[cfg(all(feature = "openssl", feature = "sha1"))] + return CryptoProviderId::OpenSslWithSha1Crate; + #[cfg(feature = "openssl")] + return CryptoProviderId::OpenSsl; + panic!("No default for CryptoProviderId!") + } +} + +impl From for CryptoProvider { + fn from(value: CryptoProviderId) -> Self { + match value { + #[cfg(feature = "openssl")] + CryptoProviderId::OpenSsl => ossl::create_crypto_provider(), + #[cfg(all(feature = "openssl", feature = "sha1"))] + CryptoProviderId::OpenSslWithSha1Crate => ossl::sha1_crate::create_crypto_provider(), + #[cfg(feature = "wincrypto")] + CryptoProviderId::WinCrypto => wincrypto::create_crypto_provider(), + } + } +} + +/// RTP/SRTP ciphers and hashes +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct CryptoProvider { + pub(crate) crypto_provider_id: CryptoProviderId, + pub(super) create_dtls_identity_impl: fn(CryptoProvider) -> Box, + pub(super) create_aes_128_cm_sha1_80_cipher_impl: + fn(&aes_128_cm_sha1_80::AesKey, bool) -> Box, + pub(super) create_aead_aes_128_gcm_cipher_impl: + fn(&aead_aes_128_gcm::AeadKey, bool) -> Box, + pub(super) srtp_aes_128_ecb_round_impl: fn(&[u8], &[u8], &mut [u8]) -> (), + pub(super) sha1_hmac_impl: fn(&[u8], &[&[u8]]) -> [u8; 20], +} + +impl CryptoProvider { + pub(super) fn create_dtls_identity(&self) -> Box { + (self.create_dtls_identity_impl)(*self) + } + + pub(crate) fn create_aes_128_cm_sha1_80_cipher( + &self, + key: &aes_128_cm_sha1_80::AesKey, + encrypt: bool, + ) -> Box { + (self.create_aes_128_cm_sha1_80_cipher_impl)(key, encrypt) + } + + pub(crate) fn create_aead_aes_128_gcm_cipher( + &self, + key: &aead_aes_128_gcm::AeadKey, + encrypt: bool, + ) -> Box { + (self.create_aead_aes_128_gcm_cipher_impl)(key, encrypt) + } + + pub(crate) fn srtp_aes_128_ecb_round(&self, key: &[u8], input: &[u8], output: &mut [u8]) { + (self.srtp_aes_128_ecb_round_impl)(key, input, output) + } + + pub(crate) fn sha1_hmac(&self, key: &[u8], payloads: &[&[u8]]) -> [u8; 20] { + (self.sha1_hmac_impl)(key, payloads) + } +} diff --git a/src/crypto/ossl/cert.rs b/src/crypto/ossl/cert.rs index 115cb897..be4f9586 100644 --- a/src/crypto/ossl/cert.rs +++ b/src/crypto/ossl/cert.rs @@ -8,29 +8,41 @@ use openssl::pkey::{PKey, Private}; use openssl::rsa::Rsa; use openssl::x509::{X509Name, X509}; -use crate::crypto::dtls::DTLS_CERT_IDENTITY; -use crate::crypto::Fingerprint; +use crate::crypto::dtls::{DtlsContext, DtlsIdentity}; +use crate::crypto::{CryptoProvider, Fingerprint}; +use super::dtls::DtlsContextImpl; use super::CryptoError; const RSA_F4: u32 = 0x10001; +// libWebRTC says "WebRTC" here when doing OpenSSL, for BoringSSL they seem +// to generate a random 8 characters. +// https://webrtc.googlesource.com/src/+/1568f1b1330f94494197696fe235094e6293b258/rtc_base/rtc_certificate_generator.cc#27 +// +// Pion also sets this to "WebRTC", maybe for compatibility reasons. +// https://github.com/pion/webrtc/blob/eed2bb2d3b9f204f9de1cd7e1046ca5d652778d2/constants.go#L31 +pub const DTLS_CERT_IDENTITY: &str = "WebRTC"; + +pub(super) fn create_dtls_identity_impl(crypto_ctx: CryptoProvider) -> Box { + let identity = + DtlsIdentityImpl::create_self_signed(crypto_ctx).expect("self-signed cert expected"); + Box::new(identity) +} + /// Certificate used for DTLS. #[derive(Debug, Clone)] -pub struct OsslDtlsCert { +pub struct DtlsIdentityImpl { + crypto_provider: CryptoProvider, pub(crate) pkey: PKey, pub(crate) x509: X509, } -impl OsslDtlsCert { +impl DtlsIdentityImpl { /// Creates a new (self signed) DTLS certificate. - pub fn new() -> Self { - Self::self_signed().expect("create dtls cert") - } - // The libWebRTC code we try to match is at: // https://webrtc.googlesource.com/src/+/1568f1b1330f94494197696fe235094e6293b258/rtc_base/openssl_certificate.cc#58 - fn self_signed() -> Result { + fn create_self_signed(crypto_provider: CryptoProvider) -> Result { let f4 = BigNum::from_u32(RSA_F4).unwrap(); let key = Rsa::generate_with_e(2048, &f4)?; let pkey = PKey::from_rsa(key)?; @@ -75,14 +87,20 @@ impl OsslDtlsCert { x509b.sign(&pkey, MessageDigest::sha1())?; let x509 = x509b.build(); - Ok(OsslDtlsCert { pkey, x509 }) + Ok(DtlsIdentityImpl { + crypto_provider, + pkey, + x509, + }) } +} +impl DtlsIdentity for DtlsIdentityImpl { /// Produce a (public) fingerprint of the cert. /// /// This is sent via SDP to the other peer to lock down the DTLS /// to this specific certificate. - pub fn fingerprint(&self) -> Fingerprint { + fn fingerprint(&self) -> Fingerprint { let digest: &[u8] = &self .x509 .digest(MessageDigest::sha256()) @@ -93,6 +111,18 @@ impl OsslDtlsCert { bytes: digest.to_vec(), } } + + fn create_context(&self) -> Result, CryptoError> { + Ok(Box::new(DtlsContextImpl::new(self.clone())?)) + } + + fn boxed_clone(&self) -> Box { + Box::new(self.clone()) + } + + fn crypto_provider(&self) -> CryptoProvider { + self.crypto_provider + } } // TODO: Refactor away this use of System::now, to instead go via InstantExt diff --git a/src/crypto/ossl/dtls.rs b/src/crypto/ossl/dtls.rs index d67f8678..61daa186 100644 --- a/src/crypto/ossl/dtls.rs +++ b/src/crypto/ossl/dtls.rs @@ -6,11 +6,11 @@ use openssl::ec::EcKey; use openssl::nid::Nid; use openssl::ssl::{Ssl, SslContext, SslContextBuilder, SslMethod, SslOptions, SslVerifyMode}; -use crate::crypto::dtls::DtlsInner; +use crate::crypto::dtls::{DtlsContext, DtlsIdentity}; use crate::crypto::{DtlsEvent, SrtpProfile}; use crate::io::{DATAGRAM_MTU, DATAGRAM_MTU_WARN}; -use super::cert::OsslDtlsCert; +use super::cert::DtlsIdentityImpl; use super::io_buf::IoBuffer; use super::stream::TlsStream; use super::CryptoError; @@ -18,9 +18,9 @@ use super::CryptoError; const DTLS_CIPHERS: &str = "EECDH+AESGCM:EDH+AESGCM:AES256+EECDH:AES256+EDH"; const DTLS_EC_CURVE: Nid = Nid::X9_62_PRIME256V1; -pub struct OsslDtlsImpl { +pub struct DtlsContextImpl { /// Certificate for the DTLS session. - _cert: OsslDtlsCert, + cert: DtlsIdentityImpl, /// Context belongs together with Fingerprint. /// @@ -32,19 +32,27 @@ pub struct OsslDtlsImpl { tls: TlsStream, } -impl OsslDtlsImpl { - pub fn new(cert: OsslDtlsCert) -> Result { +impl DtlsContextImpl { + pub fn new(cert: DtlsIdentityImpl) -> Result { let context = dtls_create_ctx(&cert)?; let ssl = dtls_ssl_create(&context)?; - Ok(OsslDtlsImpl { - _cert: cert, + Ok(DtlsContextImpl { + cert, _context: context, tls: TlsStream::new(ssl, IoBuffer::default()), }) } } -impl DtlsInner for OsslDtlsImpl { +impl DtlsContext for DtlsContextImpl { + fn crypto_provider(&self) -> crate::crypto::CryptoProvider { + self.cert.crypto_provider() + } + + fn local_fingerprint(&self) -> crate::crypto::Fingerprint { + self.cert.fingerprint() + } + fn set_active(&mut self, active: bool) { self.tls.set_active(active); } @@ -129,7 +137,7 @@ impl DtlsInner for OsslDtlsImpl { } } -pub fn dtls_create_ctx(cert: &OsslDtlsCert) -> Result { +pub fn dtls_create_ctx(cert: &DtlsIdentityImpl) -> Result { // TODO: Technically we want to disallow DTLS < 1.2, but that requires // us to use this commented out unsafe. We depend on browsers disallowing // it instead. diff --git a/src/crypto/ossl/mod.rs b/src/crypto/ossl/mod.rs index a3fb1899..57bd5516 100644 --- a/src/crypto/ossl/mod.rs +++ b/src/crypto/ossl/mod.rs @@ -1,18 +1,13 @@ //! OpenSSL implementation of cryptographic functions. -use super::{CryptoError, SrtpProfile}; +use super::{CryptoError, CryptoProvider, CryptoProviderId, SrtpProfile}; mod cert; -pub use cert::OsslDtlsCert; - -mod io_buf; -mod stream; - mod dtls; -pub use dtls::OsslDtlsImpl; - +mod io_buf; +mod sha1; mod srtp; -pub use srtp::OsslSrtpCryptoImpl; +mod stream; impl SrtpProfile { /// What this profile is called in OpenSSL parlance. @@ -25,3 +20,43 @@ impl SrtpProfile { } } } + +pub(crate) fn create_crypto_provider() -> CryptoProvider { + CryptoProvider { + crypto_provider_id: CryptoProviderId::OpenSsl, + create_dtls_identity_impl: cert::create_dtls_identity_impl, + create_aes_128_cm_sha1_80_cipher_impl: srtp::Aes128CmSha1_80Impl::new, + create_aead_aes_128_gcm_cipher_impl: srtp::AeadAes128GcmImpl::new, + srtp_aes_128_ecb_round_impl: srtp::srtp_aes_128_ecb_round, + sha1_hmac_impl: sha1::sha1_hmac, + } +} + +#[cfg(feature = "sha1")] +pub(super) mod sha1_crate { + use super::{cert, srtp, CryptoProvider, CryptoProviderId}; + use hmac::Hmac; + use hmac::Mac; + use sha1::Sha1; + + pub(super) fn sha1_hmac(key: &[u8], payloads: &[&[u8]]) -> [u8; 20] { + let mut hmac = Hmac::::new_from_slice(key).expect("hmac to normalize size to 20"); + + for payload in payloads { + hmac.update(payload); + } + + hmac.finalize().into_bytes().into() + } + + pub(crate) fn create_crypto_provider() -> CryptoProvider { + CryptoProvider { + crypto_provider_id: CryptoProviderId::OpenSslWithSha1Crate, + create_dtls_identity_impl: cert::create_dtls_identity_impl, + create_aes_128_cm_sha1_80_cipher_impl: srtp::Aes128CmSha1_80Impl::new, + create_aead_aes_128_gcm_cipher_impl: srtp::AeadAes128GcmImpl::new, + srtp_aes_128_ecb_round_impl: srtp::srtp_aes_128_ecb_round, + sha1_hmac_impl: sha1_hmac, + } + } +} diff --git a/src/crypto/ossl/sha1.rs b/src/crypto/ossl/sha1.rs new file mode 100644 index 00000000..ad5c7e41 --- /dev/null +++ b/src/crypto/ossl/sha1.rs @@ -0,0 +1,16 @@ +use openssl::hash::MessageDigest; +use openssl::pkey::PKey; +use openssl::sign::Signer; + +pub(super) fn sha1_hmac(key: &[u8], payloads: &[&[u8]]) -> [u8; 20] { + let key = PKey::hmac(key).expect("valid hmac key"); + let mut signer = Signer::new(MessageDigest::sha1(), &key).expect("valid signer"); + + for payload in payloads { + signer.update(payload).expect("signer update"); + } + + let mut hash = [0u8; 20]; + signer.sign(&mut hash).expect("sign to array"); + hash +} diff --git a/src/crypto/ossl/srtp.rs b/src/crypto/ossl/srtp.rs index 073f68d6..0be8c280 100644 --- a/src/crypto/ossl/srtp.rs +++ b/src/crypto/ossl/srtp.rs @@ -2,35 +2,27 @@ use openssl::cipher; use openssl::cipher_ctx::CipherCtx; use openssl::symm::{Cipher, Crypter, Mode}; -use crate::crypto::srtp::SrtpCryptoImpl; use crate::crypto::srtp::{aead_aes_128_gcm, aes_128_cm_sha1_80}; use crate::crypto::CryptoError; -pub struct OsslSrtpCryptoImpl; +pub(super) fn srtp_aes_128_ecb_round(key: &[u8], input: &[u8], output: &mut [u8]) { + let mut aes = + Crypter::new(Cipher::aes_128_ecb(), Mode::Encrypt, key, None).expect("AES deriver"); -impl SrtpCryptoImpl for OsslSrtpCryptoImpl { - type Aes128CmSha1_80 = OsslAes128CmSha1_80; - type AeadAes128Gcm = OsslAeadAes128Gcm; + // Run AES + let count = aes.update(input, output).expect("AES update"); + let rest = aes.finalize(&mut output[count..]).expect("AES finalize"); - fn srtp_aes_128_ecb_round(key: &[u8], input: &[u8], output: &mut [u8]) { - let mut aes = - Crypter::new(Cipher::aes_128_ecb(), Mode::Encrypt, key, None).expect("AES deriver"); - - // Run AES - let count = aes.update(input, output).expect("AES update"); - let rest = aes.finalize(&mut output[count..]).expect("AES finalize"); - - assert_eq!(count + rest, 16 + 16); // input len + block size - } + assert_eq!(count + rest, 16 + 16); // input len + block size } -pub struct OsslAes128CmSha1_80(CipherCtx); +pub(super) struct Aes128CmSha1_80Impl(CipherCtx); -impl aes_128_cm_sha1_80::CipherCtx for OsslAes128CmSha1_80 { - fn new(key: aes_128_cm_sha1_80::AesKey, encrypt: bool) -> Self - where - Self: Sized, - { +impl Aes128CmSha1_80Impl { + pub(super) fn new( + key: &aes_128_cm_sha1_80::AesKey, + encrypt: bool, + ) -> Box { let t = cipher::Cipher::aes_128_ctr(); let mut ctx = CipherCtx::new().expect("a reusable cipher context"); @@ -42,9 +34,11 @@ impl aes_128_cm_sha1_80::CipherCtx for OsslAes128CmSha1_80 { .expect("enc init"); } - OsslAes128CmSha1_80(ctx) + Box::new(Aes128CmSha1_80Impl(ctx)) } +} +impl aes_128_cm_sha1_80::CipherCtx for Aes128CmSha1_80Impl { fn encrypt( &mut self, iv: &aes_128_cm_sha1_80::RtpIv, @@ -70,10 +64,13 @@ impl aes_128_cm_sha1_80::CipherCtx for OsslAes128CmSha1_80 { } } -pub struct OsslAeadAes128Gcm(CipherCtx); +pub(super) struct AeadAes128GcmImpl(CipherCtx); -impl aead_aes_128_gcm::CipherCtx for OsslAeadAes128Gcm { - fn new(key: aead_aes_128_gcm::AeadKey, encrypt: bool) -> Self +impl AeadAes128GcmImpl { + pub(super) fn new( + key: &aead_aes_128_gcm::AeadKey, + encrypt: bool, + ) -> Box where Self: Sized, { @@ -81,19 +78,21 @@ impl aead_aes_128_gcm::CipherCtx for OsslAeadAes128Gcm { let mut ctx = CipherCtx::new().expect("a reusable cipher context"); if encrypt { - ctx.encrypt_init(Some(t), Some(&key), None) + ctx.encrypt_init(Some(t), Some(key), None) .expect("enc init"); ctx.set_iv_length(aead_aes_128_gcm::IV_LEN) .expect("IV length"); ctx.set_padding(false); } else { - ctx.decrypt_init(Some(t), Some(&key), None) + ctx.decrypt_init(Some(t), Some(key), None) .expect("dec init"); } - OsslAeadAes128Gcm(ctx) + Box::new(AeadAes128GcmImpl(ctx)) } +} +impl aead_aes_128_gcm::CipherCtx for AeadAes128GcmImpl { fn encrypt( &mut self, iv: &[u8; aead_aes_128_gcm::IV_LEN], diff --git a/src/crypto/srtp.rs b/src/crypto/srtp.rs index 8545aa27..b66ac7e2 100644 --- a/src/crypto/srtp.rs +++ b/src/crypto/srtp.rs @@ -1,8 +1,5 @@ use std::fmt; -use self::aead_aes_128_gcm::AeadKey; -use self::aes_128_cm_sha1_80::AesKey; - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum SrtpProfile { #[cfg(feature = "_internal_test_exports")] @@ -32,95 +29,10 @@ impl SrtpProfile { } } -// TODO: Can we avoice dynamic dispatch in this signature? The parameters are: -// 1. As few "touch points" beteen rtp/srtp.rs and here as possible. -// 2. Clear contract towards the actual impl. -// 3. Choice of impl passed all the way from RtcConfig. -#[allow(unused)] -pub fn new_aes_128_cm_sha1_80( - key: AesKey, - encrypt: bool, -) -> Box { - #[cfg(feature = "openssl")] - { - let ctx = super::ossl::OsslSrtpCryptoImpl::new_aes_128_cm_sha1_80(key, encrypt); - Box::new(ctx) - } - #[cfg(feature = "wincrypto")] - { - let ctx = super::wincrypto::WinCryptoSrtpCryptoImpl::new_aes_128_cm_sha1_80(key, encrypt); - Box::new(ctx) - } - #[cfg(not(any(feature = "openssl", feature = "wincrypto")))] - { - panic!("No SRTP implementation. Enable openssl feature"); - } -} - -// TODO: Can we avoice dynamic dispatch in this signature? The parameters are: -// 1. As few "touch points" beteen rtp/srtp.rs and here as possible. -// 2. Clear contract towards the actual impl. -// 3. Choice of impl passed all the way from RtcConfig. -#[allow(unused)] -pub fn new_aead_aes_128_gcm(key: AeadKey, encrypt: bool) -> Box { - /// TODO: The exact mechanism for passing which crypto to use from - /// RtcConfig to here. We're not going to instantiate openssl - /// automatically. - #[cfg(feature = "openssl")] - { - let ctx = super::ossl::OsslSrtpCryptoImpl::new_aead_aes_128_gcm(key, encrypt); - Box::new(ctx) - } - #[cfg(feature = "wincrypto")] - { - let ctx = super::wincrypto::WinCryptoSrtpCryptoImpl::new_aead_aes_128_gcm(key, encrypt); - Box::new(ctx) - } - #[cfg(not(any(feature = "openssl", feature = "wincrypto")))] - { - panic!("No SRTP implementation. Enable openssl feature"); - } -} - -#[allow(unused)] - -pub fn srtp_aes_128_ecb_round(key: &[u8], input: &[u8], output: &mut [u8]) { - /// TODO: The exact mechanism for passing which crypto to use from - /// RtcConfig to here. We're not going to instantiate openssl - /// automatically. - #[cfg(feature = "openssl")] - { - super::ossl::OsslSrtpCryptoImpl::srtp_aes_128_ecb_round(key, input, output) - } - #[cfg(feature = "wincrypto")] - { - super::wincrypto::WinCryptoSrtpCryptoImpl::srtp_aes_128_ecb_round(key, input, output) - } - #[cfg(not(any(feature = "openssl", feature = "wincrypto")))] - { - panic!("No SRTP implementation. Enable openssl feature"); - } -} - -pub trait SrtpCryptoImpl { - type Aes128CmSha1_80: aes_128_cm_sha1_80::CipherCtx; - type AeadAes128Gcm: aead_aes_128_gcm::CipherCtx; - - fn new_aes_128_cm_sha1_80(key: AesKey, encrypt: bool) -> Self::Aes128CmSha1_80 { - ::new(key, encrypt) - } - - fn new_aead_aes_128_gcm(key: AeadKey, encrypt: bool) -> Self::AeadAes128Gcm { - ::new(key, encrypt) - } - - fn srtp_aes_128_ecb_round(key: &[u8], input: &[u8], output: &mut [u8]); -} - pub mod aes_128_cm_sha1_80 { use std::panic::UnwindSafe; - use crate::crypto::CryptoError; + use crate::crypto::{CryptoError, CryptoProvider}; pub const KEY_LEN: usize = 16; pub const SALT_LEN: usize = 14; @@ -131,10 +43,6 @@ pub mod aes_128_cm_sha1_80 { pub type RtpIv = [u8; 16]; pub trait CipherCtx: UnwindSafe + Send + Sync { - fn new(key: AesKey, encrypt: bool) -> Self - where - Self: Sized; - fn encrypt( &mut self, iv: &RtpIv, @@ -150,15 +58,27 @@ pub mod aes_128_cm_sha1_80 { ) -> Result<(), CryptoError>; } - pub fn rtp_hmac(key: &[u8], buf: &mut [u8], srtp_index: u64, hmac_start: usize) { + pub fn rtp_hmac( + ctx: &CryptoProvider, + key: &[u8], + buf: &mut [u8], + srtp_index: u64, + hmac_start: usize, + ) { let roc = (srtp_index >> 16) as u32; - let tag = crate::crypto::sha1_hmac(key, &[&buf[..hmac_start], &roc.to_be_bytes()]); + let tag = ctx.sha1_hmac(key, &[&buf[..hmac_start], &roc.to_be_bytes()]); buf[hmac_start..(hmac_start + HMAC_TAG_LEN)].copy_from_slice(&tag[0..HMAC_TAG_LEN]); } - pub fn rtp_verify(key: &[u8], buf: &[u8], srtp_index: u64, cmp: &[u8]) -> bool { + pub fn rtp_verify( + ctx: &CryptoProvider, + key: &[u8], + buf: &[u8], + srtp_index: u64, + cmp: &[u8], + ) -> bool { let roc = (srtp_index >> 16) as u32; - let tag = crate::crypto::sha1_hmac(key, &[buf, &roc.to_be_bytes()]); + let tag = ctx.sha1_hmac(key, &[buf, &roc.to_be_bytes()]); &tag[0..HMAC_TAG_LEN] == cmp } @@ -176,14 +96,14 @@ pub mod aes_128_cm_sha1_80 { iv } - pub fn rtcp_hmac(key: &[u8], buf: &mut [u8], hmac_index: usize) { - let tag = crate::crypto::sha1_hmac(key, &[&buf[0..hmac_index]]); + pub fn rtcp_hmac(ctx: &CryptoProvider, key: &[u8], buf: &mut [u8], hmac_index: usize) { + let tag = ctx.sha1_hmac(key, &[&buf[0..hmac_index]]); buf[hmac_index..(hmac_index + HMAC_TAG_LEN)].copy_from_slice(&tag[0..HMAC_TAG_LEN]); } - pub fn rtcp_verify(key: &[u8], buf: &[u8], cmp: &[u8]) -> bool { - let tag = crate::crypto::sha1_hmac(key, &[buf]); + pub fn rtcp_verify(ctx: &CryptoProvider, key: &[u8], buf: &[u8], cmp: &[u8]) -> bool { + let tag = ctx.sha1_hmac(key, &[buf]); &tag[0..HMAC_TAG_LEN] == cmp } @@ -204,10 +124,6 @@ pub mod aead_aes_128_gcm { pub type RtpIv = [u8; SALT_LEN]; pub trait CipherCtx: UnwindSafe + Send + Sync { - fn new(key: AeadKey, encrypt: bool) -> Self - where - Self: Sized; - fn encrypt( &mut self, iv: &[u8; IV_LEN], diff --git a/src/crypto/wincrypto/cert.rs b/src/crypto/wincrypto/cert.rs index 88a91469..b78263a4 100644 --- a/src/crypto/wincrypto/cert.rs +++ b/src/crypto/wincrypto/cert.rs @@ -1,24 +1,43 @@ -use crate::crypto::dtls::DTLS_CERT_IDENTITY; -use crate::crypto::Fingerprint; use std::sync::Arc; use str0m_wincrypto::WinCryptoError; +use crate::crypto::dtls::{DtlsContext, DtlsIdentity}; +use crate::crypto::{CryptoError, CryptoProvider, Fingerprint}; + +use super::dtls::DtlsContextImpl; + +pub(super) fn create_dtls_identity_impl(crypto_ctx: CryptoProvider) -> Box { + let certificate = Arc::new( + str0m_wincrypto::Certificate::new_self_signed("CN=WebRTC") + .expect("Failed to create self-signed certificate"), + ); + Box::new(DtlsIdentityImpl { + certificate, + crypto_ctx, + }) +} + #[derive(Clone, Debug)] -pub struct WinCryptoDtlsCert { - pub(crate) certificate: Arc, +pub(super) struct DtlsIdentityImpl { + crypto_ctx: CryptoProvider, + pub(super) certificate: Arc, } -impl WinCryptoDtlsCert { - pub fn new() -> Self { - let certificate = Arc::new( - str0m_wincrypto::Certificate::new_self_signed(&format!("CN={}", DTLS_CERT_IDENTITY)) - .expect("Failed to create self-signed certificate"), - ); - Self { certificate } +impl DtlsIdentity for DtlsIdentityImpl { + fn fingerprint(&self) -> Fingerprint { + create_fingerprint(&self.certificate).expect("Failed to calculate fingerprint") + } + + fn create_context(&self) -> Result, CryptoError> { + Ok(DtlsContextImpl::new(self)?) } - pub fn fingerprint(&self) -> Fingerprint { - create_fingerprint(&self.certificate).expect("Failed to calculate fingerprint") + fn boxed_clone(&self) -> Box { + Box::new(self.clone()) + } + + fn crypto_provider(&self) -> CryptoProvider { + self.crypto_ctx } } diff --git a/src/crypto/wincrypto/dtls.rs b/src/crypto/wincrypto/dtls.rs index ef6fd1a5..cdda6d5c 100644 --- a/src/crypto/wincrypto/dtls.rs +++ b/src/crypto/wincrypto/dtls.rs @@ -1,35 +1,49 @@ -use std::collections::VecDeque; -use std::time::Instant; +use std::{collections::VecDeque, time::Instant}; -use crate::crypto::dtls::DtlsInner; -use crate::crypto::CryptoError; -use crate::crypto::DtlsEvent; -use crate::crypto::{KeyingMaterial, SrtpProfile}; +use crate::crypto::dtls::{DtlsContext, DtlsIdentity}; +use crate::crypto::{ + CryptoError, CryptoProvider, DtlsEvent, Fingerprint, KeyingMaterial, SrtpProfile, +}; use crate::io::DATAGRAM_MTU_WARN; -use super::cert::{create_sha256_fingerprint, WinCryptoDtlsCert}; +use super::cert::{create_sha256_fingerprint, DtlsIdentityImpl}; -pub struct WinCryptoDtls(str0m_wincrypto::Dtls); +pub(super) struct DtlsContextImpl { + crypto_provider: CryptoProvider, + dtls: str0m_wincrypto::Dtls, + fingerprint: Fingerprint, +} -impl WinCryptoDtls { - pub fn new(cert: WinCryptoDtlsCert) -> Result { - Ok(WinCryptoDtls(str0m_wincrypto::Dtls::new( - cert.certificate.clone(), - )?)) +impl DtlsContextImpl { + pub(super) fn new(cert: &DtlsIdentityImpl) -> Result, super::CryptoError> { + let fingerprint = cert.fingerprint(); + Ok(Box::new(DtlsContextImpl { + crypto_provider: cert.crypto_provider(), + dtls: str0m_wincrypto::Dtls::new(cert.certificate.clone())?, + fingerprint, + })) } } -impl DtlsInner for WinCryptoDtls { +impl DtlsContext for DtlsContextImpl { + fn crypto_provider(&self) -> CryptoProvider { + self.crypto_provider + } + + fn local_fingerprint(&self) -> Fingerprint { + self.fingerprint.clone() + } + fn set_active(&mut self, active: bool) { - self.0.set_as_client(active).expect("Set client failed"); + self.dtls.set_as_client(active).expect("Set client failed"); } fn is_active(&self) -> Option { - self.0.is_client() + self.dtls.is_client() } fn is_connected(&self) -> bool { - self.0.is_connected() + self.dtls.is_connected() } fn handle_receive( @@ -37,7 +51,7 @@ impl DtlsInner for WinCryptoDtls { datagram: &[u8], output_events: &mut VecDeque, ) -> Result<(), CryptoError> { - transform_dtls_event(self.0.handle_receive(Some(datagram))?, output_events); + transform_dtls_event(self.dtls.handle_receive(Some(datagram))?, output_events); Ok(()) } @@ -48,13 +62,13 @@ impl DtlsInner for WinCryptoDtls { if self.is_connected() || self.is_active().is_none() { return Ok(false); } - transform_dtls_event(self.0.handle_receive(None)?, output_events); - Ok(!self.0.is_connected()) + transform_dtls_event(self.dtls.handle_receive(None)?, output_events); + Ok(!self.dtls.is_connected()) } // This is DATA sent from client over SCTP/DTLS fn handle_input(&mut self, data: &[u8]) -> Result<(), CryptoError> { - match self.0.send_data(data) { + match self.dtls.send_data(data) { Ok(true) => Ok(()), Ok(false) => Err(std::io::Error::new( std::io::ErrorKind::WouldBlock, @@ -66,7 +80,7 @@ impl DtlsInner for WinCryptoDtls { } fn poll_datagram(&mut self) -> Option { - let datagram: Option = self.0.pull_datagram().map(|v| v.into()); + let datagram: Option = self.dtls.pull_datagram().map(|v| v.into()); if let Some(datagram) = &datagram { if datagram.len() > DATAGRAM_MTU_WARN { warn!("DTLS above MTU {}: {}", DATAGRAM_MTU_WARN, datagram.len()); @@ -77,7 +91,7 @@ impl DtlsInner for WinCryptoDtls { } fn poll_timeout(&mut self, now: Instant) -> Option { - self.0.next_timeout(now) + self.dtls.next_timeout(now) } } diff --git a/src/crypto/wincrypto/mod.rs b/src/crypto/wincrypto/mod.rs index 40dc2cff..0def237d 100644 --- a/src/crypto/wincrypto/mod.rs +++ b/src/crypto/wincrypto/mod.rs @@ -1,18 +1,21 @@ //! Windows SChannel + CNG implementation of cryptographic functions. -use super::CryptoError; +use super::{CryptoError, CryptoProvider, CryptoProviderId}; mod cert; -pub use cert::WinCryptoDtlsCert; - mod dtls; -pub use dtls::WinCryptoDtls; - +mod sha1; mod srtp; -pub use srtp::WinCryptoSrtpCryptoImpl; -mod sha1; -#[allow(unused_imports)] // If 'sha1' feature is enabled this is not used. -pub use sha1::sha1_hmac; +pub(crate) fn create_crypto_provider() -> CryptoProvider { + CryptoProvider { + crypto_provider_id: CryptoProviderId::WinCrypto, + create_dtls_identity_impl: cert::create_dtls_identity_impl, + create_aes_128_cm_sha1_80_cipher_impl: srtp::Aes128CmSha1_80Impl::new, + create_aead_aes_128_gcm_cipher_impl: srtp::AeadAes128GcmImpl::new, + srtp_aes_128_ecb_round_impl: srtp::srtp_aes_128_ecb_round, + sha1_hmac_impl: sha1::sha1_hmac, + } +} pub use str0m_wincrypto::WinCryptoError; diff --git a/src/crypto/wincrypto/sha1.rs b/src/crypto/wincrypto/sha1.rs index 85a32514..82aa5526 100644 --- a/src/crypto/wincrypto/sha1.rs +++ b/src/crypto/wincrypto/sha1.rs @@ -1,4 +1,4 @@ -pub fn sha1_hmac(key: &[u8], payloads: &[&[u8]]) -> [u8; 20] { +pub(super) fn sha1_hmac(key: &[u8], payloads: &[&[u8]]) -> [u8; 20] { match str0m_wincrypto::sha1_hmac(key, payloads) { Ok(hash) => hash, Err(e) => panic!("sha1_hmac failed in WinCrypto: {e}"), diff --git a/src/crypto/wincrypto/srtp.rs b/src/crypto/wincrypto/srtp.rs index b0aafad8..cf494907 100644 --- a/src/crypto/wincrypto/srtp.rs +++ b/src/crypto/wincrypto/srtp.rs @@ -1,42 +1,36 @@ -use crate::crypto::srtp::SrtpCryptoImpl; -use crate::crypto::srtp::{aead_aes_128_gcm, aes_128_cm_sha1_80}; -use crate::crypto::CryptoError; use str0m_wincrypto::{ - srtp_aead_aes_128_gcm_decrypt, srtp_aead_aes_128_gcm_encrypt, srtp_aes_128_cm, - srtp_aes_128_ecb_round, SrtpKey, + srtp_aead_aes_128_gcm_decrypt, srtp_aead_aes_128_gcm_encrypt, srtp_aes_128_cm, SrtpKey, }; -pub struct WinCryptoSrtpCryptoImpl; - -impl SrtpCryptoImpl for WinCryptoSrtpCryptoImpl { - type Aes128CmSha1_80 = WinCryptoAes128CmSha1_80; - type AeadAes128Gcm = WinCryptoAeadAes128Gcm; +use crate::crypto::srtp::{aead_aes_128_gcm, aes_128_cm_sha1_80}; +use crate::crypto::CryptoError; - fn srtp_aes_128_ecb_round(key: &[u8], input: &[u8], output: &mut [u8]) { - let key = SrtpKey::create_aes_ecb_key(key).expect("AES key"); - let count = srtp_aes_128_ecb_round(&key, input, output).expect("AES encrypt"); - assert_eq!(count, 16 + 16); // block size - } +pub(super) fn srtp_aes_128_ecb_round(key: &[u8], input: &[u8], output: &mut [u8]) { + let key = SrtpKey::new_aes_ecb_key(key).expect("AES key"); + let count = str0m_wincrypto::srtp_aes_128_ecb_round(&key, input, output).expect("AES encrypt"); + assert_eq!(count, 16 + 16); // block size } -pub struct WinCryptoAes128CmSha1_80 { +pub(super) struct Aes128CmSha1_80Impl { key: SrtpKey, } -impl aes_128_cm_sha1_80::CipherCtx for WinCryptoAes128CmSha1_80 { +impl Aes128CmSha1_80Impl { /// Create a new context for AES-128-CM-SHA1-80 encryption/decryption. /// /// The encrypt flag is ignored, since the same operation is used for both encryption and /// decryption. - fn new(key: aes_128_cm_sha1_80::AesKey, _encrypt: bool) -> Self - where - Self: Sized, - { - Self { - key: SrtpKey::create_aes_ctr_key(&key).expect("generate sym key"), - } + pub(super) fn new( + key: &aes_128_cm_sha1_80::AesKey, + _encrypt: bool, + ) -> Box { + Box::new(Self { + key: SrtpKey::new_aes_ctr_key(key).expect("generate sym key"), + }) } +} +impl aes_128_cm_sha1_80::CipherCtx for Aes128CmSha1_80Impl { fn encrypt( &mut self, iv: &aes_128_cm_sha1_80::RtpIv, @@ -58,24 +52,26 @@ impl aes_128_cm_sha1_80::CipherCtx for WinCryptoAes128CmSha1_80 { } } -pub struct WinCryptoAeadAes128Gcm { +pub(super) struct AeadAes128GcmImpl { key: SrtpKey, } -impl aead_aes_128_gcm::CipherCtx for WinCryptoAeadAes128Gcm { +impl AeadAes128GcmImpl { /// Create a new context for AES-128-GCM encryption/decryption. /// /// The encrypt flag is ignored, since it is not needed and the same /// key can be used for both encryption and decryption. - fn new(key: aead_aes_128_gcm::AeadKey, _encrypt: bool) -> Self - where - Self: Sized, - { - Self { - key: SrtpKey::create_aes_gcm_key(&key).expect("generate sym key"), - } + pub(super) fn new( + key: &aead_aes_128_gcm::AeadKey, + _encrypt: bool, + ) -> Box { + Box::new(Self { + key: SrtpKey::new_aes_gcm_key(key).expect("generate sym key"), + }) } +} +impl aead_aes_128_gcm::CipherCtx for AeadAes128GcmImpl { fn encrypt( &mut self, iv: &[u8; aead_aes_128_gcm::IV_LEN], diff --git a/src/dtls.rs b/src/dtls.rs index 21619a43..0bbe24a0 100644 --- a/src/dtls.rs +++ b/src/dtls.rs @@ -3,7 +3,7 @@ use std::time::Instant; use std::{fmt, io}; use thiserror::Error; -use crate::crypto::{CryptoError, DtlsImpl, Fingerprint}; +use crate::crypto::{CryptoError, DtlsContext, Fingerprint}; pub use crate::crypto::{DtlsCert, DtlsEvent}; use crate::net::DatagramSend; @@ -19,7 +19,7 @@ pub enum DtlsError { /// Some error from Windows Crypto layer (used for DTLS). #[error("{0}")] #[cfg(feature = "wincrypto")] - WinCrypto(#[from] crate::crypto::wincrypto::WinCryptoError), + WinCrypto(#[from] str0m_wincrypto::WinCryptoError), /// Other IO errors. #[error("{0}")] @@ -51,7 +51,7 @@ impl From for DtlsError { /// Encapsulation of DTLS. pub struct Dtls { - dtls_impl: DtlsImpl, + dtls_ctx: Box, /// The fingerprint of the certificate. fingerprint: Fingerprint, @@ -69,11 +69,11 @@ impl Dtls { /// `active` indicates whether this side should initiate the handshake or not. /// This in turn is governed by the `a=setup` SDP attribute. pub fn new(cert: DtlsCert) -> Result { - let dtls_impl = cert.create_dtls_impl()?; + let dtls_impl = cert.create_context()?; let fingerprint = cert.fingerprint(); Ok(Self { - dtls_impl, + dtls_ctx: dtls_impl, fingerprint, remote_fingerprint: None, events: VecDeque::new(), @@ -92,12 +92,12 @@ impl Dtls { /// i.e. initiating the client hello or not. This must be called /// exactly once before starting to handshake (I/O). pub fn set_active(&mut self, active: bool) { - self.dtls_impl.set_active(active) + self.dtls_ctx.set_active(active) } /// If set_active, returns what was set. pub fn is_active(&self) -> Option { - self.dtls_impl.is_active() + self.dtls_ctx.is_active() } /// The local fingerprint. @@ -114,12 +114,12 @@ impl Dtls { /// Poll for the next datagram to send. pub fn poll_datagram(&mut self) -> Option { - self.dtls_impl.poll_datagram() + self.dtls_ctx.poll_datagram() } /// Poll for a timeout. pub fn poll_timeout(&mut self, now: Instant) -> Option { - self.dtls_impl.poll_timeout(now) + self.dtls_ctx.poll_timeout(now) } /// Poll for an event. @@ -133,17 +133,17 @@ impl Dtls { /// Handling incoming data to be sent as DTLS datagrams. pub fn handle_input(&mut self, data: &[u8]) -> Result<(), DtlsError> { - Ok(self.dtls_impl.handle_input(data)?) + Ok(self.dtls_ctx.handle_input(data)?) } /// Handles an incoming DTLS datagrams. pub fn handle_receive(&mut self, message: &[u8]) -> Result<(), DtlsError> { - if self.dtls_impl.is_active().is_none() { + if self.dtls_ctx.is_active().is_none() { debug!("Ignoring DTLS datagram prior to DTLS start"); return Ok(()); } - Ok(self.dtls_impl.handle_receive(message, &mut self.events)?) + Ok(self.dtls_ctx.handle_receive(message, &mut self.events)?) } /// Handle handshaking. @@ -151,7 +151,7 @@ impl Dtls { /// Once handshaken, this becomes a noop. pub fn handle_handshake(&mut self) -> Result { let len_before = self.events.len(); - let result = self.dtls_impl.handle_handshake(&mut self.events)?; + let result = self.dtls_ctx.handle_handshake(&mut self.events)?; if self.remote_fingerprint.is_none() && self.events.len() > len_before { for ev in &self.events { @@ -165,7 +165,7 @@ impl Dtls { } pub(crate) fn is_connected(&self) -> bool { - self.dtls_impl.is_connected() + self.dtls_ctx.is_connected() } } diff --git a/src/ice/agent.rs b/src/ice/agent.rs index 5f20bcdf..ddcca10e 100644 --- a/src/ice/agent.rs +++ b/src/ice/agent.rs @@ -4,6 +4,7 @@ use std::time::{Duration, Instant}; use serde::{Deserialize, Serialize}; +use crate::crypto::{CryptoProvider, CryptoProviderId}; use crate::io::{Id, StunClass, StunMethod, StunTiming, DATAGRAM_MTU_WARN}; use crate::io::{Protocol, StunPacket}; use crate::io::{StunMessage, TransId}; @@ -20,6 +21,8 @@ use super::pair::{CandidatePair, CheckState, PairId}; /// each one. #[derive(Debug)] pub struct IceAgent { + crypto_provider: CryptoProvider, + /// Last time handle_timeout run (paced by timing_advance). /// /// This drives the state forward. @@ -251,12 +254,16 @@ impl IceAgent { /// Create a new [`IceAgent`] with randomly generated credentials. #[allow(unused)] pub fn new() -> Self { - Self::with_local_credentials(IceCreds::new()) + Self::with_local_credentials(CryptoProviderId::default().into(), IceCreds::new()) } /// Create a new [`IceAgent`] with a specific set of credentials. - pub fn with_local_credentials(local_credentials: IceCreds) -> Self { + pub fn with_local_credentials( + crypto_provider: CryptoProvider, + local_credentials: IceCreds, + ) -> Self { IceAgent { + crypto_provider, last_now: None, ice_lite: false, max_candidate_pairs: None, @@ -917,7 +924,7 @@ impl IceAgent { let do_integrity_check = |is_request: bool| -> bool { let (_, password) = self.stun_credentials(is_request); - let integrity_passed = message.check_integrity(&password); + let integrity_passed = message.check_integrity(&self.crypto_provider, &password); // The integrity is always the last thing we check if integrity_passed { @@ -1436,7 +1443,7 @@ impl IceAgent { let mut buf = vec![0_u8; DATAGRAM_MTU]; let n = reply - .to_bytes(&password, &mut buf) + .to_bytes(&self.crypto_provider, &password, &mut buf) .expect("IO error writing STUN reply"); buf.truncate(n); @@ -1483,7 +1490,7 @@ impl IceAgent { let mut buf = vec![0_u8; DATAGRAM_MTU]; let n = binding - .to_bytes(&password, &mut buf) + .to_bytes(&self.crypto_provider, &password, &mut buf) .expect("IO error writing STUN reply"); buf.truncate(n); @@ -2061,10 +2068,18 @@ mod test { let payload = Vec::from(agent.poll_transmit().unwrap().contents); let stun_message = StunMessage::parse(&payload).unwrap(); - let valid_reply = - make_authenticated_stun_reply(stun_message.trans_id(), ipv4_4(), &remote_creds.pass); - let fake_reply = - make_authenticated_stun_reply(TransId::new(), ipv4_4(), &remote_creds.pass); + let valid_reply = make_authenticated_stun_reply( + &agent.crypto_provider, + stun_message.trans_id(), + ipv4_4(), + &remote_creds.pass, + ); + let fake_reply = make_authenticated_stun_reply( + &agent.crypto_provider, + TransId::new(), + ipv4_4(), + &remote_creds.pass, + ); assert!(!agent.accepts_message(&StunMessage::parse(&fake_reply).unwrap())); assert!(agent.accepts_message(&StunMessage::parse(&valid_reply).unwrap())); @@ -2082,6 +2097,7 @@ mod test { agent.add_remote_candidate(remote_candidate); let serialized_req = make_serialized_binding_request( + &agent.crypto_provider, &agent.local_credentials, &remote_creds, !agent.controlling(), @@ -2120,8 +2136,13 @@ mod test { let remote_creds = IceCreds::new(); agent.set_remote_credentials(remote_creds.clone()); - let request = - make_serialized_binding_request(&agent.local_credentials, &remote_creds, false, 0); + let request = make_serialized_binding_request( + &agent.crypto_provider, + &agent.local_credentials, + &remote_creds, + false, + 0, + ); agent.handle_packet( Instant::now(), @@ -2137,6 +2158,7 @@ mod test { } fn make_serialized_binding_request( + crypto_provider: &CryptoProvider, local_creds: &IceCreds, remote_creds: &IceCreds, controlling: bool, @@ -2145,21 +2167,30 @@ mod test { let username = format!("{}:{}", local_creds.ufrag, remote_creds.ufrag); let binding_req = StunMessage::binding_request(&username, TransId::new(), controlling, 0, prio, false); - serialize_stun_msg(binding_req, &local_creds.pass) + serialize_stun_msg(crypto_provider, binding_req, &local_creds.pass) } - fn make_authenticated_stun_reply(tx_id: TransId, addr: SocketAddr, password: &str) -> Vec { + fn make_authenticated_stun_reply( + crypto_provider: &CryptoProvider, + tx_id: TransId, + addr: SocketAddr, + password: &str, + ) -> Vec { let reply = StunMessage::reply(tx_id, addr); - serialize_stun_msg(reply, password) + serialize_stun_msg(crypto_provider, reply, password) } /// Serializing will calculate a message integrity for it. You can then re-parse to get a message /// that contains that correct integrity value. - fn serialize_stun_msg(msg: StunMessage<'_>, password: &str) -> Vec { + fn serialize_stun_msg( + crypto_provider: &CryptoProvider, + msg: StunMessage<'_>, + password: &str, + ) -> Vec { let mut buf = vec![0_u8; DATAGRAM_MTU]; let n = msg - .to_bytes(password, &mut buf) + .to_bytes(crypto_provider, password, &mut buf) .expect("IO error writing STUN message"); buf.truncate(n); diff --git a/src/io/stun.rs b/src/io/stun.rs index 2014ddf0..abe94e83 100644 --- a/src/io/stun.rs +++ b/src/io/stun.rs @@ -4,6 +4,7 @@ use std::net::IpAddr; use std::net::SocketAddr; use std::time::Duration; +use crate::crypto::CryptoProvider; use crc::{Crc, CRC_32_ISO_HDLC}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -268,9 +269,9 @@ impl<'a> StunMessage<'a> { /// Verify the integrity of this message against the provided password. #[must_use] - pub(crate) fn check_integrity(&self, password: &str) -> bool { + pub(crate) fn check_integrity(&self, ctx: &CryptoProvider, password: &str) -> bool { if let Some(integ) = self.attrs.message_integrity { - let comp = crate::crypto::sha1_hmac( + let comp = ctx.sha1_hmac( password.as_bytes(), &[ &self.integrity[..2], @@ -288,7 +289,12 @@ impl<'a> StunMessage<'a> { /// Serialize this message into the provided buffer, returning the final length of the message. /// /// The provided password is used to authenticate the message. - pub(crate) fn to_bytes(self, password: &str, buf: &mut [u8]) -> Result { + pub(crate) fn to_bytes( + self, + ctx: &CryptoProvider, + password: &str, + buf: &mut [u8], + ) -> Result { const MSG_HEADER_LEN: usize = 20; const MSG_INTEGRITY_LEN: usize = 20; const FPRINT_LEN: usize = 4; @@ -331,7 +337,7 @@ impl<'a> StunMessage<'a> { let buf = buf.into_inner(); // Compute and fill in message integrity - let hmac = crate::crypto::sha1_hmac( + let hmac = ctx.sha1_hmac( password.as_bytes(), &[&buf[0..(integrity_value_offset - ATTR_TLV_LENGTH)]], ); @@ -817,6 +823,8 @@ impl<'a> fmt::Debug for StunMessage<'a> { #[cfg(test)] mod test { + use crate::CryptoProviderId; + use super::*; use std::net::SocketAddrV4; use systemstat::Ipv4Addr; @@ -833,10 +841,11 @@ mod test { 0xaa, 0xf9, 0x83, 0x9c, 0xa0, 0x76, 0xc6, 0xd5, 0x80, 0x28, 0x00, 0x04, 0x36, 0x0e, 0x21, 0x9f, ]; + let crypto_provider = CryptoProviderId::default().into(); let packet = PACKET.to_vec(); let message = StunMessage::parse(&packet).unwrap(); - assert!(message.check_integrity("xJcE9AQAR7kczUDVOXRUCl")); + assert!(message.check_integrity(&crypto_provider, "xJcE9AQAR7kczUDVOXRUCl")); } #[test] diff --git a/src/lib.rs b/src/lib.rs index ff28480b..88cd4f92 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -596,7 +596,8 @@ use thiserror::Error; use util::InstantExt; mod crypto; -use crypto::Fingerprint; +pub use crypto::CryptoProviderId; +use crypto::{CryptoProvider, Fingerprint}; mod dtls; use dtls::DtlsCert; @@ -838,6 +839,7 @@ pub enum RtcError { /// ``` pub struct Rtc { alive: bool, + crypto_provider: CryptoProvider, ice: IceAgent, dtls: Dtls, sctp: RtcSctp, @@ -1098,31 +1100,28 @@ impl Rtc { pub(crate) fn new_from_config(config: RtcConfig) -> Self { let session = Session::new(&config); + let (crypto_provider, dtls_cert) = match config.crypto_config { + RtcCryptoConfig::None => { + panic!("Crypto provider must be set in RtcConfig"); + } + RtcCryptoConfig::CryptoProvider(crypto_provider) => ( + crypto_provider, + DtlsCert::new(crypto_provider.crypto_provider_id), + ), + RtcCryptoConfig::DtlsCert(ref dtls_cert) => { + (dtls_cert.crypto_provider(), dtls_cert.clone()) + } + }; + let local_creds = config.local_ice_credentials.unwrap_or_else(IceCreds::new); - let mut ice = IceAgent::with_local_credentials(local_creds); + let mut ice = IceAgent::with_local_credentials(crypto_provider, local_creds); if config.ice_lite { ice.set_ice_lite(config.ice_lite); } - let dtls_cert = if let Some(c) = config.dtls_cert { - c - } else { - #[cfg(feature = "openssl")] - { - DtlsCert::new_openssl() - } - #[cfg(feature = "wincrypto")] - { - DtlsCert::new_wincrypto() - } - #[cfg(not(any(feature = "openssl", feature = "wincrypto")))] - { - panic!("No DTLS implementation. Enable crypto feature"); - } - }; - Rtc { alive: true, + crypto_provider, ice, dtls: Dtls::new(dtls_cert).expect("DTLS to init without problem"), session, @@ -1435,7 +1434,12 @@ impl Rtc { srtp_profile ); let active = self.dtls.is_active().expect("DTLS must be inited by now"); - self.session.set_keying_material(mat, srtp_profile, active); + self.session.set_keying_material( + self.crypto_provider, + mat, + srtp_profile, + active, + ); } DtlsEvent::RemoteFingerprint(v1) => { debug!("DTLS verify remote fingerprint"); @@ -1813,6 +1817,25 @@ impl Rtc { } } +#[derive(Debug, Clone)] +enum RtcCryptoConfig { + None, + CryptoProvider(CryptoProvider), + DtlsCert(DtlsCert), +} + +impl Default for RtcCryptoConfig { + fn default() -> Self { + if cfg!(feature = "openssl") { + // When OpenSSL is enabled, we default to OpenSSL, this is for legacy + // compatibility. + RtcCryptoConfig::CryptoProvider(CryptoProviderId::default().into()) + } else { + RtcCryptoConfig::None + } + } +} + /// Customized config for creating an [`Rtc`] instance. /// /// ``` @@ -1827,7 +1850,7 @@ impl Rtc { #[derive(Debug, Clone)] pub struct RtcConfig { local_ice_credentials: Option, - dtls_cert: Option, + crypto_config: RtcCryptoConfig, fingerprint_verification: bool, ice_lite: bool, codec_config: CodecConfig, @@ -1869,6 +1892,32 @@ impl RtcConfig { self } + /// Get the ID of the configured Crypto Provider, if set. + /// + /// Returns [`None`] if no DTLS certificate is set. In such cases, + /// the certificate will be created on build and you can use the + /// direct API on an [`Rtc`] instance to obtain the local + /// DTLS fingerprint. + /// + /// ``` + /// # use str0m::RtcConfig; + /// let fingerprint = RtcConfig::default() + /// .build() + /// .direct_api() + /// .local_dtls_fingerprint(); + /// ``` + pub fn crypto_provider_id(&self) -> Option { + match self.crypto_config { + RtcCryptoConfig::None => None, + RtcCryptoConfig::CryptoProvider(ref crypto_provider) => { + Some(crypto_provider.crypto_provider_id) + } + RtcCryptoConfig::DtlsCert(ref dtls_cert) => { + Some(dtls_cert.crypto_provider().crypto_provider_id) + } + } + } + /// Get the configured DTLS certificate, if set. /// /// Returns [`None`] if no DTLS certificate is set. In such cases, @@ -1884,7 +1933,20 @@ impl RtcConfig { /// .local_dtls_fingerprint(); /// ``` pub fn dtls_cert(&self) -> Option<&DtlsCert> { - self.dtls_cert.as_ref() + if let RtcCryptoConfig::DtlsCert(ref cert) = self.crypto_config { + Some(cert) + } else { + None + } + } + + /// Set the Crypto Provider to use. + /// + /// Either this, or the DTLS certificate should be set but not both, as the DTLS + /// certificate implies a provider. + pub fn set_crypto_provider_id(mut self, crypto_provider_id: CryptoProviderId) -> Self { + self.crypto_config = RtcCryptoConfig::CryptoProvider(crypto_provider_id.into()); + self } /// Set the DTLS certificate for secure communication. @@ -1899,8 +1961,11 @@ impl RtcConfig { /// /// let rtc_config = RtcConfig::default() /// .set_dtls_cert(dtls_cert); + /// + /// Either this, or the DTLS certificate should be set but not both, as the DTLS + /// certificate implies a provider. pub fn set_dtls_cert(mut self, dtls_cert: DtlsCert) -> Self { - self.dtls_cert = Some(dtls_cert); + self.crypto_config = RtcCryptoConfig::DtlsCert(dtls_cert); self } @@ -2305,7 +2370,7 @@ impl Default for RtcConfig { fn default() -> Self { Self { local_ice_credentials: None, - dtls_cert: None, + crypto_config: RtcCryptoConfig::default(), fingerprint_verification: true, ice_lite: false, codec_config: CodecConfig::new_with_defaults(), diff --git a/src/rtp/mod.rs b/src/rtp/mod.rs index e9099e00..7ae5c31e 100644 --- a/src/rtp/mod.rs +++ b/src/rtp/mod.rs @@ -46,7 +46,7 @@ pub enum RtpError { /// Some error from Windows Crypto layer (used for SRTP). #[error("{0}")] #[cfg(feature = "wincrypto")] - WinCrypto(#[from] crate::crypto::wincrypto::WinCryptoError), + WinCrypto(#[from] str0m_wincrypto::WinCryptoError), /// Other IO errors. #[error("{0}")] diff --git a/src/rtp/srtp.rs b/src/rtp/srtp.rs index ccd085e1..cf318e22 100644 --- a/src/rtp/srtp.rs +++ b/src/rtp/srtp.rs @@ -1,7 +1,8 @@ use std::fmt; -use crate::crypto::{self, new_aead_aes_128_gcm, new_aes_128_cm_sha1_80, KeyingMaterial}; -use crate::crypto::{aead_aes_128_gcm, aes_128_cm_sha1_80, SrtpProfile}; +use crate::crypto::{ + aead_aes_128_gcm, aes_128_cm_sha1_80, CryptoProvider, KeyingMaterial, SrtpProfile, +}; use super::header::RtpHeader; @@ -30,10 +31,16 @@ pub const SRTP_OVERHEAD: usize = MAX_TAG_LEN; impl SrtpContext { /// Create an SRTP context for the relevant profile using the provided keying material. - pub fn new(profile: SrtpProfile, mat: &KeyingMaterial, left: bool) -> Self { + pub fn new( + crypto_provider: CryptoProvider, + profile: SrtpProfile, + mat: &KeyingMaterial, + left: bool, + ) -> Self { match profile { #[cfg(feature = "_internal_test_exports")] SrtpProfile::PassThrough => SrtpContext { + crypto_provider, rtp: Derived::PassThrough, rtcp: Derived::PassThrough, srtcp_index: 0, @@ -41,11 +48,12 @@ impl SrtpContext { SrtpProfile::Aes128CmSha1_80 => { use aes_128_cm_sha1_80::{KEY_LEN, SALT_LEN}; - let key = SrtpKey::::new(mat, left); + let key = SrtpKey::::new(crypto_provider, mat, left); - let (rtp, rtcp) = Derived::aes_128_cm_sha1_80(&key); + let (rtp, rtcp) = Derived::aes_128_cm_sha1_80(&crypto_provider, &key); SrtpContext { + crypto_provider, rtp, rtcp, srtcp_index: 0, @@ -54,11 +62,12 @@ impl SrtpContext { SrtpProfile::AeadAes128Gcm => { use aead_aes_128_gcm::{KEY_LEN, SALT_LEN}; - let key = SrtpKey::::new(mat, left); + let key = SrtpKey::::new(crypto_provider, mat, left); - let (rtp, rtcp) = Derived::aead_aes_128_gcm(&key); + let (rtp, rtcp) = Derived::aead_aes_128_gcm(&crypto_provider, &key); SrtpContext { + crypto_provider, rtp, rtcp, srtcp_index: 0, @@ -69,6 +78,7 @@ impl SrtpContext { #[cfg(test)] fn new_aead_aes_128_gcm( + crypto_provider: CryptoProvider, rtp_key: [u8; aead_aes_128_gcm::KEY_LEN], rtp_salt: [u8; aead_aes_128_gcm::SALT_LEN], rtcp_key: [u8; aead_aes_128_gcm::KEY_LEN], @@ -76,15 +86,16 @@ impl SrtpContext { srtcp_index: u32, ) -> Self { Self { + crypto_provider, rtp: Derived::AeadAes128Gcm { salt: rtp_salt, - enc: new_aead_aes_128_gcm(rtp_key, true), - dec: new_aead_aes_128_gcm(rtp_key, false), + enc: crypto_provider.create_aead_aes_128_gcm_cipher(&rtp_key, true), + dec: crypto_provider.create_aead_aes_128_gcm_cipher(&rtp_key, false), }, rtcp: Derived::AeadAes128Gcm { salt: rtcp_salt, - enc: new_aead_aes_128_gcm(rtcp_key, true), - dec: new_aead_aes_128_gcm(rtcp_key, false), + enc: crypto_provider.create_aead_aes_128_gcm_cipher(&rtcp_key, true), + dec: crypto_provider.create_aead_aes_128_gcm_cipher(&rtcp_key, false), }, srtcp_index, } @@ -93,6 +104,8 @@ impl SrtpContext { #[derive(Debug)] pub struct SrtpContext { + /// Crypto context from where we call ciphers. + crypto_provider: CryptoProvider, /// Encryption/decryption derived from srtp_key for RTP. rtp: Derived, /// Encryption/decryption derived from srtp_key for RTCP. @@ -157,7 +170,13 @@ impl SrtpContext { output[..hlen].copy_from_slice(&buf[..hlen]); let hmac_start = buf.len(); - aes_128_cm_sha1_80::rtp_hmac(key, &mut output, srtp_index, hmac_start); + aes_128_cm_sha1_80::rtp_hmac( + &self.crypto_provider, + key, + &mut output, + srtp_index, + hmac_start, + ); output } @@ -199,6 +218,7 @@ impl SrtpContext { let hmac_start = buf.len() - HMAC_TAG_LEN; if !aes_128_cm_sha1_80::rtp_verify( + &self.crypto_provider, key, &buf[..hmac_start], srtp_index, @@ -295,7 +315,7 @@ impl SrtpContext { to[0..4].copy_from_slice(&e_and_si.to_be_bytes()); let hmac_index = output.len() - HMAC_TAG_LEN; - aes_128_cm_sha1_80::rtcp_hmac(key, &mut output, hmac_index); + aes_128_cm_sha1_80::rtcp_hmac(&self.crypto_provider, key, &mut output, hmac_index); output } @@ -346,7 +366,12 @@ impl SrtpContext { let hmac_start = buf.len() - HMAC_TAG_LEN; - if !aes_128_cm_sha1_80::rtcp_verify(key, &buf[..hmac_start], &buf[hmac_start..]) { + if !aes_128_cm_sha1_80::rtcp_verify( + &self.crypto_provider, + key, + &buf[..hmac_start], + &buf[hmac_start..], + ) { trace!("unprotect_rtcp hmac verify fail"); return None; } @@ -470,12 +495,13 @@ impl SrtpContext { /// SrtpKeys created from DTLS SrtpKeyMaterial. #[derive(Debug)] struct SrtpKey { + crypto_provider: CryptoProvider, master: [u8; ML], salt: [u8; SL], } impl SrtpKey { - pub fn new(mat: &KeyingMaterial, left: bool) -> Self { + pub fn new(crypto_provider: CryptoProvider, mat: &KeyingMaterial, left: bool) -> Self { // layout in SrtpKeyMaterial is [key_input, key_output, salt_input, salt_output] // Invariant @@ -493,7 +519,11 @@ impl SrtpKey { master[0..ML].copy_from_slice(&mat[o0..(o0 + ML)]); salt[0..SL].copy_from_slice(&mat[(ML + ML + o1)..(ML + ML + o1 + SL)]); - SrtpKey { master, salt } + SrtpKey { + crypto_provider, + master, + salt, + } } fn derive(&self, label: u8, out: &mut [u8]) { @@ -521,7 +551,8 @@ impl SrtpKey { input[14..].copy_from_slice(&round.to_be_bytes()[..]); // default key derivation function, which uses AES-128 in Counter Mode - crypto::srtp_aes_128_ecb_round(&self.master, &input[..], &mut buf[..]); + self.crypto_provider + .srtp_aes_128_ecb_round(&self.master, &input[..], &mut buf[..]); // Copy to output. Even if we get 32 bytes of output with AES 128 ECB, we // only use the first 16. That matches the tests in the RFC. @@ -557,6 +588,7 @@ enum Derived { impl Derived { fn aes_128_cm_sha1_80( + ctx: &CryptoProvider, srtp_key: &SrtpKey<{ aes_128_cm_sha1_80::KEY_LEN }, { aes_128_cm_sha1_80::SALT_LEN }>, ) -> (Self, Self) { use aes_128_cm_sha1_80::*; @@ -594,21 +626,22 @@ impl Derived { let rtp = Derived::Aes128CmSha1_80 { key: rtp_hmac, salt: rtp_salt, - enc: new_aes_128_cm_sha1_80(rtp_aes, true), - dec: new_aes_128_cm_sha1_80(rtp_aes, false), + enc: ctx.create_aes_128_cm_sha1_80_cipher(&rtp_aes, true), + dec: ctx.create_aes_128_cm_sha1_80_cipher(&rtp_aes, false), }; let rtcp = Derived::Aes128CmSha1_80 { key: rtcp_hmac, salt: rtcp_salt, - enc: new_aes_128_cm_sha1_80(rtcp_aes, true), - dec: new_aes_128_cm_sha1_80(rtcp_aes, false), + enc: ctx.create_aes_128_cm_sha1_80_cipher(&rtcp_aes, true), + dec: ctx.create_aes_128_cm_sha1_80_cipher(&rtcp_aes, false), }; (rtp, rtcp) } fn aead_aes_128_gcm( + crypto_provider: &CryptoProvider, srtp_key: &SrtpKey<{ aead_aes_128_gcm::KEY_LEN }, { aead_aes_128_gcm::SALT_LEN }>, ) -> (Derived, Derived) { use aead_aes_128_gcm::*; @@ -631,14 +664,14 @@ impl Derived { let rtp = Derived::AeadAes128Gcm { salt: rtp_salt, - enc: new_aead_aes_128_gcm(rtp_aes, true), - dec: new_aead_aes_128_gcm(rtp_aes, false), + enc: crypto_provider.create_aead_aes_128_gcm_cipher(&rtp_aes, true), + dec: crypto_provider.create_aead_aes_128_gcm_cipher(&rtp_aes, false), }; let rtcp = Derived::AeadAes128Gcm { salt: rtcp_salt, - enc: new_aead_aes_128_gcm(rtcp_aes, true), - dec: new_aead_aes_128_gcm(rtcp_aes, false), + enc: crypto_provider.create_aead_aes_128_gcm_cipher(&rtcp_aes, true), + dec: crypto_provider.create_aead_aes_128_gcm_cipher(&rtcp_aes, false), }; (rtp, rtcp) @@ -676,9 +709,12 @@ fn error_details(header: &RtpHeader, srtp_index: u64) -> String { #[cfg(test)] mod test { use super::*; + use crate::crypto::CryptoProviderId; #[test] fn derive_key() { + let crypto_provider = CryptoProviderId::default().into(); + // https://tools.ietf.org/html/rfc3711#appendix-B.3 // // Key Derivation Test Vectors. @@ -693,7 +729,11 @@ mod test { 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6, ]; - let sk = SrtpKey { master, salt }; + let sk = SrtpKey { + crypto_provider, + master, + salt, + }; // aes crypto key let mut out = [0_u8; 16]; @@ -770,8 +810,14 @@ mod test { #[test] fn unprotect_rtcp() { + let crypto_provider = CryptoProviderId::default().into(); let key_mat = KeyingMaterial::new(MAT.to_vec()); - let mut ctx_rx = SrtpContext::new(SrtpProfile::Aes128CmSha1_80, &key_mat, true); + let mut ctx_rx = SrtpContext::new( + crypto_provider, + SrtpProfile::Aes128CmSha1_80, + &key_mat, + true, + ); ctx_rx.srtcp_index = 1; let decrypted = ctx_rx.unprotect_rtcp(SRTCP).unwrap(); @@ -989,6 +1035,7 @@ mod test { fn make_rtp_context() -> SrtpContext { SrtpContext::new_aead_aes_128_gcm( + CryptoProviderId::default().into(), rfc7714::KEY, rfc7714::SALT, rfc7714::KEY, @@ -999,6 +1046,7 @@ mod test { fn make_rtcp_context() -> SrtpContext { SrtpContext::new_aead_aes_128_gcm( + CryptoProviderId::default().into(), rfc7714::KEY, rfc7714::SALT, rfc7714::KEY, diff --git a/src/session.rs b/src/session.rs index d99aa8f0..ce9151de 100644 --- a/src/session.rs +++ b/src/session.rs @@ -2,8 +2,8 @@ use std::collections::{HashMap, VecDeque}; use std::time::{Duration, Instant}; use crate::bwe::BweKind; -use crate::crypto::KeyingMaterial; use crate::crypto::SrtpProfile; +use crate::crypto::{CryptoProvider, KeyingMaterial}; use crate::format::CodecConfig; use crate::format::PayloadParams; use crate::io::{DatagramSend, DATAGRAM_MTU, DATAGRAM_MTU_WARN}; @@ -191,6 +191,7 @@ impl Session { pub fn set_keying_material( &mut self, + crypto_provider: CryptoProvider, mat: KeyingMaterial, srtp_profile: SrtpProfile, active: bool, @@ -200,8 +201,8 @@ impl Session { // hand side of the key material to derive input/output. let left = active; - self.srtp_rx = Some(SrtpContext::new(srtp_profile, &mat, !left)); - self.srtp_tx = Some(SrtpContext::new(srtp_profile, &mat, left)); + self.srtp_rx = Some(SrtpContext::new(crypto_provider, srtp_profile, &mat, !left)); + self.srtp_tx = Some(SrtpContext::new(crypto_provider, srtp_profile, &mat, left)); } pub fn handle_timeout(&mut self, now: Instant) -> Result<(), RtcError> { diff --git a/wincrypto/src/srtp.rs b/wincrypto/src/srtp.rs index e59e4c87..0a18bdd1 100644 --- a/wincrypto/src/srtp.rs +++ b/wincrypto/src/srtp.rs @@ -23,13 +23,13 @@ unsafe impl Sync for SrtpKey {} impl SrtpKey { /// Creates a key from the given data for operating AES in Counter (CTR/CM) mode. - pub fn create_aes_ctr_key(key: &[u8]) -> Result { + pub fn new_aes_ctr_key(key: &[u8]) -> Result { // CTR mode is build on top of ECB mode, so we use the same key. - Self::create_aes_ecb_key(key) + Self::new_aes_ecb_key(key) } /// Creates a key from the given data for operating AES in ECB mode. - pub fn create_aes_ecb_key(key: &[u8]) -> Result { + pub fn new_aes_ecb_key(key: &[u8]) -> Result { let mut key_handle = BCRYPT_KEY_HANDLE::default(); // SAFETY: The key and key_handle will exist before and after this call. unsafe { @@ -45,7 +45,7 @@ impl SrtpKey { } /// Creates a key from the given data for operating AES in GCM mode. - pub fn create_aes_gcm_key(key: &[u8]) -> Result { + pub fn new_aes_gcm_key(key: &[u8]) -> Result { let mut key_handle = BCRYPT_KEY_HANDLE::default(); // SAFETY: The key and key_handle will exist before and after this call. unsafe { @@ -271,7 +271,7 @@ mod test { #[test] fn test_srtp_aes_128_ecb_round_test_vec_1() { let key = - SrtpKey::create_aes_ecb_key(&hex_to_vec("2b7e151628aed2a6abf7158809cf4f3c")).unwrap(); + SrtpKey::new_aes_ecb_key(&hex_to_vec("2b7e151628aed2a6abf7158809cf4f3c")).unwrap(); let mut out = [0u8; 32]; srtp_aes_128_ecb_round( &key, @@ -285,7 +285,7 @@ mod test { #[test] fn test_srtp_aes_128_ecb_round_test_vec_2() { let key = - SrtpKey::create_aes_ecb_key(&hex_to_vec("2b7e151628aed2a6abf7158809cf4f3c")).unwrap(); + SrtpKey::new_aes_ecb_key(&hex_to_vec("2b7e151628aed2a6abf7158809cf4f3c")).unwrap(); let mut out = [0u8; 32]; srtp_aes_128_ecb_round( &key, @@ -299,7 +299,7 @@ mod test { #[test] fn test_srtp_aes_128_ecb_round_test_vec_3() { let key = - SrtpKey::create_aes_ecb_key(&hex_to_vec("2b7e151628aed2a6abf7158809cf4f3c")).unwrap(); + SrtpKey::new_aes_ecb_key(&hex_to_vec("2b7e151628aed2a6abf7158809cf4f3c")).unwrap(); let mut out = [0u8; 32]; srtp_aes_128_ecb_round( &key, @@ -313,7 +313,7 @@ mod test { #[test] fn test_srtp_aes_128_ecb_round_test_vec_4() { let key = - SrtpKey::create_aes_ecb_key(&hex_to_vec("2b7e151628aed2a6abf7158809cf4f3c")).unwrap(); + SrtpKey::new_aes_ecb_key(&hex_to_vec("2b7e151628aed2a6abf7158809cf4f3c")).unwrap(); let mut out = [0u8; 32]; srtp_aes_128_ecb_round( &key,