From 098958d0ad975c6c42b210a4a2b7ca8f319cf71a Mon Sep 17 00:00:00 2001 From: Davide Baldo Date: Mon, 23 Dec 2024 17:34:47 +0100 Subject: [PATCH] feat(rust): added zeroize on `LocalMessage` --- Cargo.lock | 1 + .../rust/ockam/ockam_api/src/echoer.rs | 7 +- .../ockam_api/src/proxy_vault/protocol.rs | 27 +-- .../ockam/ockam_api/tests/common/session.rs | 2 +- .../rust/ockam/ockam_core/Cargo.toml | 1 + .../ockam/ockam_core/src/cbor/cow_bytes.rs | 15 ++ .../rust/ockam/ockam_core/src/lib.rs | 2 + .../rust/ockam/ockam_core/src/message.rs | 5 +- .../src/routing/message/local_message.rs | 28 ++- .../rust/ockam/ockam_core/src/zeroize.rs | 176 ++++++++++++++++++ .../src/secure_channel/decryptor.rs | 8 +- .../src/secure_channel/encryptor_worker.rs | 31 ++- .../handshake/handshake_worker.rs | 2 +- .../src/secure_channel/message.rs | 6 +- .../ockam_node/src/context/send_message.rs | 41 +++- .../src/transport_message.rs | 12 +- .../ockam_transport_tcp/src/workers/sender.rs | 2 +- .../src/messages/routing_message.rs | 12 +- .../src/puncture/puncture/receiver.rs | 3 +- .../src/puncture/puncture/sender.rs | 2 +- .../ockam_transport_udp/src/workers/sender.rs | 2 +- 21 files changed, 322 insertions(+), 63 deletions(-) create mode 100644 implementations/rust/ockam/ockam_core/src/zeroize.rs diff --git a/Cargo.lock b/Cargo.lock index 19e798dcad1..4d2fa786c82 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4904,6 +4904,7 @@ dependencies = [ "tracing-opentelemetry", "tracing-subscriber", "utcnow", + "zeroize", ] [[package]] diff --git a/implementations/rust/ockam/ockam_api/src/echoer.rs b/implementations/rust/ockam/ockam_api/src/echoer.rs index 1f33e3f1665..bd76bf2ac26 100644 --- a/implementations/rust/ockam/ockam_api/src/echoer.rs +++ b/implementations/rust/ockam/ockam_api/src/echoer.rs @@ -13,7 +13,10 @@ impl Worker for Echoer { async fn handle_message(&mut self, ctx: &mut Context, msg: Routed) -> Result<()> { log::debug!(src = %msg.src_addr(), from = %msg.sender()?, to = %msg.return_route().next()?, "echoing back"); let msg = msg.into_local_message(); - ctx.send(msg.return_route, NeutralMessage::from(msg.payload)) - .await + ctx.send( + msg.return_route, + NeutralMessage::from(msg.payload.discard_zeroize()), + ) + .await } } diff --git a/implementations/rust/ockam/ockam_api/src/proxy_vault/protocol.rs b/implementations/rust/ockam/ockam_api/src/proxy_vault/protocol.rs index 0d4ef12da09..09316951c21 100644 --- a/implementations/rust/ockam/ockam_api/src/proxy_vault/protocol.rs +++ b/implementations/rust/ockam/ockam_api/src/proxy_vault/protocol.rs @@ -4,10 +4,11 @@ use minicbor::{CborLen, Decode, Encode}; use ockam::identity::{utils, TimestampInSeconds, Vault}; use ockam_core::errcode::{Kind, Origin}; use ockam_core::{ - async_trait, cbor_encode_preallocate, route, Address, NeutralMessage, Route, Routed, Worker, + async_trait, cbor_encode_preallocate, route, Address, NeutralMessage, OnDrop, Route, Routed, + Worker, }; use ockam_multiaddr::MultiAddr; -use ockam_node::Context; +use ockam_node::{Context, MessageSendReceiveOptions}; use std::sync::Arc; use tokio::sync::Mutex as AsyncMutex; @@ -245,11 +246,13 @@ impl SpecificClient { let response: NeutralMessage = self .client .context - .send_and_receive( + .send_and_receive_extended( route![route, self.destination.clone()], NeutralMessage::from(encoded), + MessageSendReceiveOptions::new().with_on_drop(OnDrop::Zeroize), ) - .await?; + .await? + .into_body()?; Ok(minicbor::decode::(&response.into_vec())?) } @@ -258,7 +261,7 @@ impl SpecificClient { mod vault_for_signing { use crate::proxy_vault::protocol::{ProxyError, SpecificClient}; use minicbor::{CborLen, Decode, Encode}; - use ockam_core::{async_trait, cbor_encode_preallocate}; + use ockam_core::{async_trait, cbor_encode_preallocate, MaybeZeroizeOnDrop}; use ockam_vault::{ Signature, SigningKeyType, SigningSecretKeyHandle, VaultForSigning, VerifyingPublicKey, }; @@ -296,7 +299,7 @@ mod vault_for_signing { pub(super) async fn handle_request( vault: &dyn VaultForSigning, - request: Vec, + request: MaybeZeroizeOnDrop>, ) -> ockam_core::Result> { let request: Request = minicbor::decode(&request)?; let response = match request { @@ -480,7 +483,7 @@ mod vault_for_signing { pub mod vault_for_secure_channels { use crate::proxy_vault::protocol::{ProxyError, SpecificClient}; use minicbor::{CborLen, Decode, Encode}; - use ockam_core::{async_trait, cbor_encode_preallocate}; + use ockam_core::{async_trait, cbor_encode_preallocate, MaybeZeroizeOnDrop}; use ockam_vault::{ AeadSecretKeyHandle, HKDFNumberOfOutputs, HashOutput, HkdfOutput, SecretBufferHandle, VaultForSecureChannels, X25519PublicKey, X25519SecretKeyHandle, @@ -488,7 +491,7 @@ pub mod vault_for_secure_channels { pub(super) async fn handle_request( vault: &dyn VaultForSecureChannels, - request: Vec, + request: MaybeZeroizeOnDrop>, ) -> ockam_core::Result> { let request: Request = minicbor::decode(&request)?; let response = match request { @@ -1084,12 +1087,12 @@ pub mod vault_for_secure_channels { pub mod vault_for_verify_signatures { use crate::proxy_vault::protocol::{ProxyError, SpecificClient}; use minicbor::{CborLen, Decode, Encode}; - use ockam_core::{async_trait, cbor_encode_preallocate}; + use ockam_core::{async_trait, cbor_encode_preallocate, MaybeZeroizeOnDrop}; use ockam_vault::{Sha256Output, Signature, VaultForVerifyingSignatures, VerifyingPublicKey}; pub(super) async fn handle_request( vault: &dyn VaultForVerifyingSignatures, - request: Vec, + request: MaybeZeroizeOnDrop>, ) -> ockam_core::Result> { let request: Request = minicbor::decode(&request)?; let response = match request { @@ -1179,12 +1182,12 @@ pub mod vault_for_verify_signatures { pub mod vault_for_encryption_at_rest { use crate::proxy_vault::protocol::{ProxyError, SpecificClient}; use minicbor::{CborLen, Decode, Encode}; - use ockam_core::{async_trait, cbor_encode_preallocate}; + use ockam_core::{async_trait, cbor_encode_preallocate, MaybeZeroizeOnDrop}; use ockam_vault::{AeadSecretKeyHandle, VaultForEncryptionAtRest}; pub(super) async fn handle_request( vault: &dyn VaultForEncryptionAtRest, - request: Vec, + request: MaybeZeroizeOnDrop>, ) -> ockam_core::Result> { let request: Request = minicbor::decode(&request)?; let response = match request { diff --git a/implementations/rust/ockam/ockam_api/tests/common/session.rs b/implementations/rust/ockam/ockam_api/tests/common/session.rs index f33debc0df9..c1a9e323092 100644 --- a/implementations/rust/ockam/ockam_api/tests/common/session.rs +++ b/implementations/rust/ockam/ockam_api/tests/common/session.rs @@ -53,7 +53,7 @@ impl Worker for MockEchoer { ctx.send( msg.return_route().clone(), - NeutralMessage::from(msg.into_payload()), + NeutralMessage::from(msg.into_payload().discard_zeroize()), ) .await?; info!("Echo message back"); diff --git a/implementations/rust/ockam/ockam_core/Cargo.toml b/implementations/rust/ockam/ockam_core/Cargo.toml index 17d11801368..6cede60e614 100644 --- a/implementations/rust/ockam/ockam_core/Cargo.toml +++ b/implementations/rust/ockam/ockam_core/Cargo.toml @@ -98,6 +98,7 @@ tracing-opentelemetry = { version = "0.27.0", optional = true } tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"], optional = true } # Wasn't tested on no_std utcnow = { version = "0.2.5", default-features = false, features = ["fallback"], optional = true } +zeroize = { version = "1.8", default-features = false } [dev-dependencies] cddl-cat = { version = "0.6.2" } diff --git a/implementations/rust/ockam/ockam_core/src/cbor/cow_bytes.rs b/implementations/rust/ockam/ockam_core/src/cbor/cow_bytes.rs index d2efb7ee8e8..52ec4913291 100644 --- a/implementations/rust/ockam/ockam_core/src/cbor/cow_bytes.rs +++ b/implementations/rust/ockam/ockam_core/src/cbor/cow_bytes.rs @@ -4,6 +4,7 @@ use crate::compat::vec::Vec; use core::ops::Deref; use minicbor::{CborLen, Decode, Encode}; use serde::{Deserialize, Serialize}; +use zeroize::Zeroize; /// A new type around `Cow<'_, [u8]>` that borrows from input. /// @@ -79,3 +80,17 @@ impl<'a> Deref for CowBytes<'a> { &self.0 } } + +impl Default for CowBytes<'_> { + fn default() -> Self { + CowBytes(Cow::Borrowed(&[])) + } +} + +impl Zeroize for CowBytes<'_> { + fn zeroize(&mut self) { + if !self.is_borrowed() { + self.0.to_mut().zeroize(); + } + } +} diff --git a/implementations/rust/ockam/ockam_core/src/lib.rs b/implementations/rust/ockam/ockam_core/src/lib.rs index a72fec734a6..c33954fa6c9 100644 --- a/implementations/rust/ockam/ockam_core/src/lib.rs +++ b/implementations/rust/ockam/ockam_core/src/lib.rs @@ -86,6 +86,7 @@ mod processor; mod routing; mod uint; mod worker; +mod zeroize; pub use access_control::*; pub use cbor::*; @@ -96,6 +97,7 @@ pub use processor::*; pub use routing::*; pub use uint::*; pub use worker::*; +pub use zeroize::*; #[cfg(all(not(feature = "std"), feature = "alloc"))] #[doc(hidden)] diff --git a/implementations/rust/ockam/ockam_core/src/message.rs b/implementations/rust/ockam/ockam_core/src/message.rs index ad08f0800e6..6334589b33c 100644 --- a/implementations/rust/ockam/ockam_core/src/message.rs +++ b/implementations/rust/ockam/ockam_core/src/message.rs @@ -1,3 +1,4 @@ +use crate::zeroize::MaybeZeroizeOnDrop; use crate::{ compat::{ string::{String, ToString}, @@ -244,7 +245,7 @@ impl Routed { /// Consume the message wrapper and return the original message. #[inline] pub fn into_body(self) -> Result { - M::decode(&self.into_payload()) + M::decode(self.payload()) } /// Consume the message wrapper and return the underlying local message. @@ -267,7 +268,7 @@ impl Routed { /// Consume the message wrapper and return the underlying transport message's binary payload. #[inline] - pub fn into_payload(self) -> Vec { + pub fn into_payload(self) -> MaybeZeroizeOnDrop> { self.local_msg.into_payload() } } diff --git a/implementations/rust/ockam/ockam_core/src/routing/message/local_message.rs b/implementations/rust/ockam/ockam_core/src/routing/message/local_message.rs index d387a74c589..55a89facdb3 100644 --- a/implementations/rust/ockam/ockam_core/src/routing/message/local_message.rs +++ b/implementations/rust/ockam/ockam_core/src/routing/message/local_message.rs @@ -1,7 +1,8 @@ #[cfg(feature = "std")] use crate::OpenTelemetryContext; -use crate::{compat::vec::Vec, route, Address, Message, Route, TransportMessage}; +use crate::{compat::vec::Vec, route, Address, Message, OnDrop, Route, TransportMessage}; +use crate::zeroize::MaybeZeroizeOnDrop; use crate::{LocalInfo, Result}; use cfg_if::cfg_if; use serde::{Deserialize, Serialize}; @@ -46,7 +47,7 @@ pub struct LocalMessage { /// Return message route. This field must be populated by routers handling this message along the way. pub return_route: Route, /// The message payload. - pub payload: Vec, + pub payload: MaybeZeroizeOnDrop>, /// Local information added by workers to give additional context to the message /// independently of its payload. For example this can be used to store the identifier that /// was used to encrypt the payload @@ -142,7 +143,7 @@ impl LocalMessage { } /// Return the message payload - pub fn into_payload(self) -> Vec { + pub fn into_payload(self) -> MaybeZeroizeOnDrop> { self.payload } @@ -153,12 +154,12 @@ impl LocalMessage { /// Return a mutable reference to the message payload pub fn payload_mut(&mut self) -> &mut Vec { - &mut self.payload + self.payload.as_mut() } /// Set the message payload pub fn set_payload(mut self, payload: Vec) -> Self { - self.payload = payload; + self.payload = MaybeZeroizeOnDrop::new(payload, OnDrop::NoZeroize); self } @@ -213,7 +214,7 @@ impl LocalMessage { 1, self.onward_route, self.return_route, - self.payload, + self.payload.discard_zeroize(), None, ); @@ -285,7 +286,7 @@ impl LocalMessage { LocalMessage { onward_route, return_route, - payload, + payload: MaybeZeroizeOnDrop::new(payload, OnDrop::NoZeroize), local_info, #[cfg(feature = "std")] tracing_context: OpenTelemetryContext::current(), @@ -316,7 +317,18 @@ impl LocalMessage { /// Specify the payload for the message pub fn with_payload(self, payload: Vec) -> Self { - Self { payload, ..self } + Self { + payload: MaybeZeroizeOnDrop::new(payload, OnDrop::NoZeroize), + ..self + } + } + + /// Specify the payload for the message with zeroization + pub fn with_payload_on_drop(self, payload: Vec, on_drop: OnDrop) -> Self { + Self { + payload: MaybeZeroizeOnDrop::new(payload, on_drop), + ..self + } } /// Specify the local information for the message diff --git a/implementations/rust/ockam/ockam_core/src/zeroize.rs b/implementations/rust/ockam/ockam_core/src/zeroize.rs new file mode 100644 index 00000000000..2ad7d844942 --- /dev/null +++ b/implementations/rust/ockam/ockam_core/src/zeroize.rs @@ -0,0 +1,176 @@ +use core::fmt::{Debug, Display, Formatter}; +use core::hash::{Hash, Hasher}; +use core::ops::Deref; +use minicbor::encode::{Error, Write}; +use minicbor::{CborLen, Decode, Encode, Encoder}; +use std::ops::DerefMut; +use zeroize::Zeroize; + +/// OnDrop is an enum to specify whether to zeroize the inner value when dropped. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Encode, Decode, CborLen)] +#[rustfmt::skip] +pub enum OnDrop { + /// Do not zeroize the inner value when dropped + #[n(0)] NoZeroize, + /// Zeroize the inner value when dropped + #[n(1)] Zeroize, +} + +/// MaybeZeroizeOnDrop will zeroize the inner value when dropped when zeroize_on_drop is true. +pub struct MaybeZeroizeOnDrop { + target: T, + on_drop: OnDrop, +} + +impl MaybeZeroizeOnDrop { + /// Create a new MaybeZeroizeOnDrop with the given target and zeroize_on_drop flag. + pub fn new(target: T, on_drop: OnDrop) -> Self { + Self { target, on_drop } + } + + /// Gets on_drop + pub fn on_drop(&self) -> OnDrop { + self.on_drop + } + + /// Sets on_drop + pub fn set_zeroize(&mut self, on_drop: OnDrop) { + self.on_drop = on_drop; + } +} + +impl MaybeZeroizeOnDrop { + /// Return the inner value regardless of the zeroize_on_drop flag. + /// The caller has the responsibility to ensure that the inner value is zeroized when necessary. + pub fn discard_zeroize(mut self) -> T { + std::mem::take(&mut self.target) + } +} + +impl Deref for MaybeZeroizeOnDrop { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.target + } +} + +impl DerefMut for MaybeZeroizeOnDrop { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.target + } +} + +impl Drop for MaybeZeroizeOnDrop { + fn drop(&mut self) { + if let OnDrop::Zeroize = self.on_drop { + self.target.zeroize(); + } + } +} + +impl Zeroize for MaybeZeroizeOnDrop { + fn zeroize(&mut self) { + self.target.zeroize(); + } +} + +impl Default for MaybeZeroizeOnDrop { + fn default() -> Self { + Self { + target: T::default(), + on_drop: OnDrop::NoZeroize, + } + } +} + +impl Debug for MaybeZeroizeOnDrop { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MaybeZeroizeOnDrop") + .field("target", &self.target) + .field("on_drop", &self.on_drop) + .finish() + } +} + +impl Display for MaybeZeroizeOnDrop { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.target.fmt(f) + } +} + +impl Clone for MaybeZeroizeOnDrop { + fn clone(&self) -> Self { + Self { + target: self.target.clone(), + on_drop: self.on_drop, + } + } +} + +impl PartialEq for MaybeZeroizeOnDrop { + fn eq(&self, other: &Self) -> bool { + self.target.eq(&other.target) + } +} + +impl Eq for MaybeZeroizeOnDrop {} + +impl PartialOrd for MaybeZeroizeOnDrop { + fn partial_cmp(&self, other: &Self) -> Option { + self.target.partial_cmp(&other.target) + } +} + +impl Ord for MaybeZeroizeOnDrop { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.target.cmp(&other.target) + } +} + +impl Hash for MaybeZeroizeOnDrop { + fn hash(&self, state: &mut H) { + self.target.hash(state) + } +} + +// serde serialize/deserialize impls +impl serde::Serialize for MaybeZeroizeOnDrop { + fn serialize(&self, serializer: S) -> Result { + self.target.serialize(serializer) + } +} + +impl<'de, T: Zeroize + serde::Deserialize<'de>> serde::Deserialize<'de> for MaybeZeroizeOnDrop { + fn deserialize>(deserializer: D) -> Result { + T::deserialize(deserializer).map(|target| Self { + target, + on_drop: OnDrop::NoZeroize, + }) + } +} + +impl> Encode for MaybeZeroizeOnDrop { + fn encode(&self, e: &mut Encoder, ctx: &mut C) -> Result<(), Error> { + self.target.encode(e, ctx) + } + + fn is_nil(&self) -> bool { + self.target.is_nil() + } +} + +impl<'b, C, T: Zeroize + Decode<'b, C>> Decode<'b, C> for MaybeZeroizeOnDrop { + fn decode(d: &mut minicbor::Decoder<'b>, ctx: &mut C) -> Result { + T::decode(d, ctx).map(|target| Self { + target, + on_drop: OnDrop::NoZeroize, + }) + } +} + +impl> CborLen for MaybeZeroizeOnDrop { + fn cbor_len(&self, ctx: &mut C) -> usize { + self.target.cbor_len(ctx) + } +} diff --git a/implementations/rust/ockam/ockam_identity/src/secure_channel/decryptor.rs b/implementations/rust/ockam/ockam_identity/src/secure_channel/decryptor.rs index e50b1e91646..857a1208795 100644 --- a/implementations/rust/ockam/ockam_identity/src/secure_channel/decryptor.rs +++ b/implementations/rust/ockam/ockam_identity/src/secure_channel/decryptor.rs @@ -1,6 +1,6 @@ use core::sync::atomic::Ordering; use ockam_core::compat::sync::Arc; -use ockam_core::{route, Any, Result, Route, Routed, SecureChannelLocalInfo}; +use ockam_core::{route, Any, OnDrop, Result, Route, Routed, SecureChannelLocalInfo}; use ockam_core::{Decodable, LocalMessage}; use ockam_node::Context; @@ -21,6 +21,7 @@ use ockam_core::errcode::{Kind, Origin}; use ockam_vault::{AeadSecretKeyHandle, VaultForSecureChannels}; use tracing::{debug, info, trace, warn}; use tracing_attributes::instrument; +use zeroize::Zeroize; pub(crate) struct DecryptorHandler { //for debug purposes only @@ -144,7 +145,7 @@ impl DecryptorHandler { let msg = LocalMessage::new() .with_onward_route(msg.onward_route) .with_return_route(return_route) - .with_payload(msg.payload.to_vec()) + .with_payload_on_drop(msg.payload.to_vec(), msg.on_drop) .with_local_info(local_info); match ctx @@ -217,6 +218,8 @@ impl DecryptorHandler { // Decode raw payload binary let mut payload = msg.payload; + // it might contain sensitive data, so we zeroize it in *case of an error* + payload.set_zeroize(OnDrop::Zeroize); // Decrypt the binary let (decrypted_payload, nonce) = self.decryptor.decrypt(payload.as_mut_slice()).await?; @@ -224,6 +227,7 @@ impl DecryptorHandler { match decrypted_msg.message { SecureChannelMessage::Payload(decrypted_msg) => { + payload.set_zeroize(decrypted_msg.on_drop); self.handle_payload(ctx, decrypted_msg, nonce, encrypted_msg_return_route) .await? } diff --git a/implementations/rust/ockam/ockam_identity/src/secure_channel/encryptor_worker.rs b/implementations/rust/ockam/ockam_identity/src/secure_channel/encryptor_worker.rs index 66ac4a425f6..1123fffd86a 100644 --- a/implementations/rust/ockam/ockam_identity/src/secure_channel/encryptor_worker.rs +++ b/implementations/rust/ockam/ockam_identity/src/secure_channel/encryptor_worker.rs @@ -1,17 +1,17 @@ use core::sync::atomic::{AtomicBool, Ordering}; -use tracing::{debug, error, info, warn}; -use tracing_attributes::instrument; - use ockam_core::compat::boxed::Box; use ockam_core::compat::sync::{Arc, RwLock}; use ockam_core::compat::vec::Vec; use ockam_core::errcode::{Kind, Origin}; use ockam_core::{ - async_trait, route, CowBytes, Decodable, Error, LocalMessage, NeutralMessage, Route, + async_trait, route, CowBytes, Decodable, Error, LocalMessage, MaybeZeroizeOnDrop, + NeutralMessage, OnDrop, Route, }; use ockam_core::{Any, Result, Routed, Worker}; use ockam_node::Context; +use tracing::{debug, error, info, warn}; +use tracing_attributes::instrument; use crate::models::CredentialAndPurposeKey; use crate::secure_channel::addresses::Addresses; @@ -98,13 +98,21 @@ impl EncryptorWorker { &mut self, ctx: &Context, msg: SecureChannelPaddedMessage<'static>, + on_drop: OnDrop, ) -> Result> { let expected_len = minicbor::len(&msg); - let mut destination = vec![0u8; NOISE_NONCE_LEN + expected_len + AES_GCM_TAGSIZE]; + let mut destination = MaybeZeroizeOnDrop::new( + vec![0u8; NOISE_NONCE_LEN + expected_len + AES_GCM_TAGSIZE], + on_drop, + ); minicbor::encode(&msg, &mut destination[NOISE_NONCE_LEN..])?; match self.encryptor.encrypt(&mut destination).await { - Ok(()) => Ok(destination), + Ok(()) => { + // the content of the destination is now encrypted, + // and we can safely return it as `Vec` + Ok(destination.discard_zeroize()) + } // If encryption failed, that means we have some internal error, // and we may be in an invalid state, it's better to stop the Worker Err(err) => { @@ -203,21 +211,24 @@ impl EncryptorWorker { let msg = msg.into_local_message(); let mut onward_route = msg.onward_route; let return_route = msg.return_route; + let on_drop = msg.payload.on_drop(); + let payload = + MaybeZeroizeOnDrop::new(CowBytes::from(msg.payload.discard_zeroize()), on_drop); // Remove our address let _ = onward_route.step(); - let payload = CowBytes::from(msg.payload); let msg = PlaintextPayloadMessage { onward_route, return_route, payload, + on_drop, }; let msg = SecureChannelMessage::Payload(msg); let msg = Self::add_padding(msg); - let payload = self.encrypt(ctx, msg).await?; + let payload = self.encrypt(ctx, msg, on_drop).await?; let remote_route = self.shared_state.remote_route.read().unwrap().route.clone(); // Decryptor doesn't need the return_route since it has `self.remote_route` as well @@ -288,7 +299,7 @@ impl EncryptorWorker { let msg = SecureChannelMessage::RefreshCredentials(msg); let msg = Self::add_padding(msg); - let msg = self.encrypt(ctx, msg).await?; + let msg = self.encrypt(ctx, msg, OnDrop::NoZeroize).await?; info!( "Sending credentials refresh for {}", @@ -314,7 +325,7 @@ impl EncryptorWorker { let msg = Self::add_padding(msg); // Encrypt the message - let msg = self.encrypt(ctx, msg).await?; + let msg = self.encrypt(ctx, msg, OnDrop::NoZeroize).await?; let remote_route = self.shared_state.remote_route.read().unwrap().route.clone(); // Send the message to the decryptor on the other side diff --git a/implementations/rust/ockam/ockam_identity/src/secure_channel/handshake/handshake_worker.rs b/implementations/rust/ockam/ockam_identity/src/secure_channel/handshake/handshake_worker.rs index 40c8b3383a2..0a6865ff2c9 100644 --- a/implementations/rust/ockam/ockam_identity/src/secure_channel/handshake/handshake_worker.rs +++ b/implementations/rust/ockam/ockam_identity/src/secure_channel/handshake/handshake_worker.rs @@ -272,7 +272,7 @@ impl HandshakeWorker { .state_machine .as_mut() .ok_or(IdentityError::HandshakeInternalError)? - .on_event(ReceivedMessage(payload)) + .on_event(ReceivedMessage(payload.discard_zeroize())) .await? { // set the remote route by taking the most up to date message return route diff --git a/implementations/rust/ockam/ockam_identity/src/secure_channel/message.rs b/implementations/rust/ockam/ockam_identity/src/secure_channel/message.rs index 9a58e0b08ad..6ede1285040 100644 --- a/implementations/rust/ockam/ockam_identity/src/secure_channel/message.rs +++ b/implementations/rust/ockam/ockam_identity/src/secure_channel/message.rs @@ -1,7 +1,7 @@ use crate::models::{ChangeHistory, CredentialAndPurposeKey}; use minicbor::{CborLen, Decode, Encode}; use ockam_core::compat::vec::Vec; -use ockam_core::{CowBytes, Route}; +use ockam_core::{CowBytes, MaybeZeroizeOnDrop, OnDrop, Route}; /// Secure Channel Message format. #[derive(Debug, Encode, Decode, CborLen, Clone)] @@ -34,7 +34,9 @@ pub struct PlaintextPayloadMessage<'a> { /// Return route of the message. #[n(1)] pub return_route: Route, /// Untyped binary payload. - #[b(2)] pub payload: CowBytes<'a>, + #[b(2)] pub payload: MaybeZeroizeOnDrop>, + /// Whether to Zeroize the payload on drop. + #[n(3)] pub on_drop: OnDrop, } /// Secure Channel Message format. diff --git a/implementations/rust/ockam/ockam_node/src/context/send_message.rs b/implementations/rust/ockam/ockam_node/src/context/send_message.rs index 5170f891c8b..f8f66733e5a 100644 --- a/implementations/rust/ockam/ockam_node/src/context/send_message.rs +++ b/implementations/rust/ockam/ockam_node/src/context/send_message.rs @@ -8,14 +8,15 @@ use ockam_core::compat::{sync::Arc, vec::Vec}; use ockam_core::{ errcode::{Kind, Origin}, route, Address, AllOutgoingAccessControl, AllowAll, AllowOnwardAddress, Error, - IncomingAccessControl, LocalMessage, Mailboxes, Message, OutgoingAccessControl, RelayMessage, - Result, Route, Routed, + IncomingAccessControl, LocalMessage, Mailboxes, Message, OnDrop, OutgoingAccessControl, + RelayMessage, Result, Route, Routed, }; use ockam_core::{LocalInfo, Mailbox}; /// Full set of options to `send_and_receive_extended` function pub struct MessageSendReceiveOptions { message_wait: MessageWait, + on_drop: OnDrop, incoming_access_control: Option>, outgoing_access_control: Option>, } @@ -31,6 +32,7 @@ impl MessageSendReceiveOptions { pub fn new() -> Self { Self { message_wait: MessageWait::Timeout(DEFAULT_TIMEOUT), + on_drop: OnDrop::NoZeroize, incoming_access_control: None, outgoing_access_control: None, } @@ -65,6 +67,12 @@ impl MessageSendReceiveOptions { self.outgoing_access_control = Some(outgoing_access_control); self } + + /// Set on drop behavior + pub fn with_on_drop(mut self, on_drop: OnDrop) -> Self { + self.on_drop = on_drop; + self + } } impl Context { @@ -150,7 +158,9 @@ impl Context { #[cfg(feature = "std")] child_ctx.set_tracing_context(self.tracing_context()); - child_ctx.send(route, msg).await?; + child_ctx + .send_from_address_impl(route, msg, self.address(), vec![], options.on_drop) + .await?; child_ctx .receive_extended::( MessageReceiveOptions::new().with_message_wait(options.message_wait), @@ -226,8 +236,14 @@ impl Context { R: Into, M: Message + Send + 'static, { - self.send_from_address_impl(route.into(), msg, self.address(), local_info) - .await + self.send_from_address_impl( + route.into(), + msg, + self.address(), + local_info, + OnDrop::NoZeroize, + ) + .await } /// Send a message to an address or via a fully-qualified route @@ -253,8 +269,14 @@ impl Context { R: Into, M: Message + Send + 'static, { - self.send_from_address_impl(route.into(), msg, sending_address, Vec::new()) - .await + self.send_from_address_impl( + route.into(), + msg, + sending_address, + Vec::new(), + OnDrop::NoZeroize, + ) + .await } async fn send_from_address_impl( @@ -263,6 +285,7 @@ impl Context { msg: M, sending_address: Address, local_info: Vec, + on_drop: OnDrop, ) -> Result<()> where M: Message + Send + 'static, @@ -305,13 +328,13 @@ impl Context { .with_tracing_context(self.tracing_context().update()) .with_onward_route(route) .with_return_route(route![sending_address.clone()]) - .with_payload(payload) + .with_payload_on_drop(payload, on_drop) .with_local_info(local_info); } else { let local_msg = LocalMessage::new() .with_onward_route(route) .with_return_route(route![sending_address.clone()]) - .with_payload(payload) + .with_payload_on_drop(payload, on_drop) .with_local_info(local_info); } } diff --git a/implementations/rust/ockam/ockam_transport_tcp/src/transport_message.rs b/implementations/rust/ockam/ockam_transport_tcp/src/transport_message.rs index 10ab6f94963..2d86d61effe 100644 --- a/implementations/rust/ockam/ockam_transport_tcp/src/transport_message.rs +++ b/implementations/rust/ockam/ockam_transport_tcp/src/transport_message.rs @@ -64,12 +64,14 @@ impl From> for LocalMessage { } } -impl From for TcpTransportMessage<'_> { - fn from(value: LocalMessage) -> Self { +impl TryFrom for TcpTransportMessage<'_> { + type Error = ockam_core::Error; + + fn try_from(value: LocalMessage) -> Result { let transport_message = Self::new( value.onward_route, value.return_route, - CowBytes::from(value.payload), + CowBytes::from(value.payload.discard_zeroize()), None, ); @@ -77,9 +79,9 @@ impl From for TcpTransportMessage<'_> { if #[cfg(feature = "std")] { // make sure to pass the latest tracing context let new_tracing_context = LocalMessage::start_new_tracing_context(value.tracing_context.update(), "TcpTransportMessage"); - transport_message.with_tracing_context(new_tracing_context) + Ok(transport_message.with_tracing_context(new_tracing_context)) } else { - transport_message + Ok(transport_message) } } } diff --git a/implementations/rust/ockam/ockam_transport_tcp/src/workers/sender.rs b/implementations/rust/ockam/ockam_transport_tcp/src/workers/sender.rs index 9f5cbb0e9a1..62bbc4758f1 100644 --- a/implementations/rust/ockam/ockam_transport_tcp/src/workers/sender.rs +++ b/implementations/rust/ockam/ockam_transport_tcp/src/workers/sender.rs @@ -120,7 +120,7 @@ impl TcpSendWorker { fn serialize_message(&mut self, local_message: LocalMessage) -> Result<()> { // Create a message buffer with prepended length - let transport_message = TcpTransportMessage::from(local_message); + let transport_message = TcpTransportMessage::try_from(local_message)?; let expected_payload_len = minicbor::len(&transport_message); diff --git a/implementations/rust/ockam/ockam_transport_udp/src/messages/routing_message.rs b/implementations/rust/ockam/ockam_transport_udp/src/messages/routing_message.rs index e6f9edb16f5..f55e25418dc 100644 --- a/implementations/rust/ockam/ockam_transport_udp/src/messages/routing_message.rs +++ b/implementations/rust/ockam/ockam_transport_udp/src/messages/routing_message.rs @@ -75,12 +75,14 @@ impl From> for LocalMessage { } } -impl From for UdpRoutingMessage<'_> { - fn from(value: LocalMessage) -> Self { +impl TryFrom for UdpRoutingMessage<'_> { + type Error = ockam_core::Error; + + fn try_from(value: LocalMessage) -> Result { let routing_message = Self::new( value.onward_route, value.return_route, - CowBytes::from(value.payload), + CowBytes::from(value.payload.discard_zeroize()), None, ); @@ -88,9 +90,9 @@ impl From for UdpRoutingMessage<'_> { if #[cfg(feature = "std")] { // make sure to pass the latest tracing context let new_tracing_context = LocalMessage::start_new_tracing_context(value.tracing_context.update(), "UdpRoutingMessage"); - routing_message.with_tracing_context(new_tracing_context) + Ok(routing_message.with_tracing_context(new_tracing_context)) } else { - routing_message + Ok(routing_message) } } } diff --git a/implementations/rust/ockam/ockam_transport_udp/src/puncture/puncture/receiver.rs b/implementations/rust/ockam/ockam_transport_udp/src/puncture/puncture/receiver.rs index a49cd0bf8ac..baa87568194 100644 --- a/implementations/rust/ockam/ockam_transport_udp/src/puncture/puncture/receiver.rs +++ b/implementations/rust/ockam/ockam_transport_udp/src/puncture/puncture/receiver.rs @@ -294,7 +294,8 @@ impl Worker for UdpPunctureReceiverWorker { if &addr == self.addresses.remote_address() { let msg = msg.into_local_message(); let return_route = msg.return_route; - self.handle_peer(ctx, msg.payload, &return_route).await?; + self.handle_peer(ctx, msg.payload.discard_zeroize(), &return_route) + .await?; } else if &addr == self.addresses.heartbeat_address() { self.handle_heartbeat(ctx).await?; } else { diff --git a/implementations/rust/ockam/ockam_transport_udp/src/puncture/puncture/sender.rs b/implementations/rust/ockam/ockam_transport_udp/src/puncture/puncture/sender.rs index b17e88cb56f..ecc38761fef 100644 --- a/implementations/rust/ockam/ockam_transport_udp/src/puncture/puncture/sender.rs +++ b/implementations/rust/ockam/ockam_transport_udp/src/puncture/puncture/sender.rs @@ -38,7 +38,7 @@ impl UdpPunctureSenderWorker { let wrapped_payload = PunctureMessage::Payload { onward_route, return_route, - payload: msg.payload, + payload: msg.payload.discard_zeroize(), }; let msg = LocalMessage::new() diff --git a/implementations/rust/ockam/ockam_transport_udp/src/workers/sender.rs b/implementations/rust/ockam/ockam_transport_udp/src/workers/sender.rs index ad696e2e5ee..4e9c8a0f1ba 100644 --- a/implementations/rust/ockam/ockam_transport_udp/src/workers/sender.rs +++ b/implementations/rust/ockam/ockam_transport_udp/src/workers/sender.rs @@ -117,7 +117,7 @@ struct TransportMessagesIterator { impl TransportMessagesIterator { fn new(current_routing_number: RoutingNumber, local_message: LocalMessage) -> Result { - let routing_message = UdpRoutingMessage::from(local_message); + let routing_message = UdpRoutingMessage::try_from(local_message)?; let routing_message = ockam_core::cbor_encode_preallocate(routing_message)?;