diff --git a/Cargo.lock b/Cargo.lock
index a452e05c12c..dd1f95f73ed 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2038,7 +2038,6 @@ dependencies = [
  "ident_case",
  "proc-macro2",
  "quote",
- "strsim",
  "syn 2.0.87",
 ]
@@ -2161,37 +2160,6 @@ dependencies = [
  "powerfmt",
 ]
 
-[[package]]
-name = "derive_builder"
-version = "0.20.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947"
-dependencies = [
- "derive_builder_macro",
-]
-
-[[package]]
-name = "derive_builder_core"
-version = "0.20.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8"
-dependencies = [
- "darling",
- "proc-macro2",
- "quote",
- "syn 2.0.87",
-]
-
-[[package]]
-name = "derive_builder_macro"
-version = "0.20.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c"
-dependencies = [
- "derive_builder_core",
- "syn 2.0.87",
-]
-
 [[package]]
 name = "deunicode"
 version = "1.6.0"
@@ -3849,15 +3817,14 @@ dependencies = [
 
 [[package]]
 name = "kafka-protocol"
-version = "0.10.2"
+version = "0.13.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2568bb6ab2e96399bbce79e33ddaf01f5db0827a182bb2cb5f95d6a9eae0ea68"
+checksum = "d1edaf2fc3ecebe689bbc4fd97a6921cacd4cd09df8ebeda348a8e23c9fd48d4"
 dependencies = [
  "anyhow",
  "bytes 1.8.0",
  "crc",
  "crc32c",
- "derive_builder",
  "flate2",
  "indexmap 2.6.0",
  "lz4",
diff --git a/NOTICE.md b/NOTICE.md
index 6e4e8082aba..072ebb3bde9 100644
--- a/NOTICE.md
+++ b/NOTICE.md
@@ -159,9 +159,6 @@ This file contains attributions for any 3rd-party open source code used in this
 | ctrlc | MIT, Apache-2.0 | https://crates.io/crates/ctrlc |
 | curve25519-dalek | BSD-3-Clause | https://crates.io/crates/curve25519-dalek |
 | curve25519-dalek-derive | MIT, Apache-2.0 | https://crates.io/crates/curve25519-dalek-derive |
-| darling | MIT | https://crates.io/crates/darling |
-| darling_core | MIT | https://crates.io/crates/darling_core |
-| darling_macro | MIT | https://crates.io/crates/darling_macro |
 | dashmap | MIT | https://crates.io/crates/dashmap |
 | data-encoding | MIT | https://crates.io/crates/data-encoding |
 | dbus | Apache-2.0, MIT | https://crates.io/crates/dbus |
@@ -169,9 +166,6 @@ This file contains attributions for any 3rd-party open source code used in this
 | delegate | MIT, Apache-2.0 | https://crates.io/crates/delegate |
 | der | Apache-2.0, MIT | https://crates.io/crates/der |
 | deranged | MIT, Apache-2.0 | https://crates.io/crates/deranged |
-| derive_builder | MIT, Apache-2.0 | https://crates.io/crates/derive_builder |
-| derive_builder_core | MIT, Apache-2.0 | https://crates.io/crates/derive_builder_core |
-| derive_builder_macro | MIT, Apache-2.0 | https://crates.io/crates/derive_builder_macro |
 | dialoguer | MIT | https://crates.io/crates/dialoguer |
 | digest | MIT, Apache-2.0 | https://crates.io/crates/digest |
 | displaydoc | MIT, Apache-2.0 | https://crates.io/crates/displaydoc |
@@ -272,7 +266,6 @@ This file contains attributions for any 3rd-party open source code used in this
 | icu_properties_data | Unicode-3.0 | https://crates.io/crates/icu_properties_data |
 | icu_provider | Unicode-3.0 | https://crates.io/crates/icu_provider |
 | icu_provider_macros | Unicode-3.0 | https://crates.io/crates/icu_provider_macros |
-| ident_case | MIT, Apache-2.0 | https://crates.io/crates/ident_case |
 | idna | MIT, Apache-2.0 | https://crates.io/crates/idna |
 | idna_adapter | Apache-2.0, MIT | https://crates.io/crates/idna_adapter |
 | image | MIT, Apache-2.0 | https://crates.io/crates/image |
diff --git a/implementations/rust/ockam/ockam_api/Cargo.toml b/implementations/rust/ockam/ockam_api/Cargo.toml
index 142bb1d7055..263e3bdba26 100644
--- a/implementations/rust/ockam/ockam_api/Cargo.toml
+++ b/implementations/rust/ockam/ockam_api/Cargo.toml
@@ -73,7 +73,7 @@ jaq-core = "1"
 jaq-interpret = "1"
 jaq-parse = "1"
 jaq-std = "1"
-kafka-protocol = "0.10"
+kafka-protocol = "0.13"
 log = "0.4"
 miette = { version = "7.2.0", features = ["fancy-no-backtrace"] }
 minicbor = { version = "0.25.1", default-features = false, features = ["alloc", "derive"] }
diff --git a/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/inlet/request.rs b/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/inlet/request.rs
index 3d37616c65b..e12667732a9 100644
--- a/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/inlet/request.rs
+++ b/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/inlet/request.rs
@@ -6,9 +6,9 @@ use bytes::{Bytes, BytesMut};
 use kafka_protocol::messages::fetch_request::FetchRequest;
 use kafka_protocol::messages::produce_request::{PartitionProduceData, ProduceRequest};
 use kafka_protocol::messages::request_header::RequestHeader;
-use kafka_protocol::messages::{ApiKey, TopicName};
+use kafka_protocol::messages::{ApiKey, ApiVersionsRequest, TopicName};
 use kafka_protocol::protocol::buf::ByteBuf;
-use kafka_protocol::protocol::Decodable;
+use kafka_protocol::protocol::{Decodable, Message};
 use kafka_protocol::records::{
     Compression, RecordBatchDecoder, RecordBatchEncoder, RecordEncodeOptions,
 };
@@ -67,13 +67,7 @@ impl KafkaMessageRequestInterceptor for InletInterceptorImpl {
         match api_key {
             ApiKey::ApiVersionsKey => {
                 debug!("api versions request: {:?}", header);
-                self.request_map.lock().unwrap().insert(
-                    header.correlation_id,
-                    RequestInfo {
-                        request_api_key: api_key,
-                        request_api_version: header.request_api_version,
-                    },
-                );
+                return self.handle_api_version_request(&mut buffer, &header).await;
             }
 
             ApiKey::ProduceKey => {
@@ -116,6 +110,40 @@ impl KafkaMessageRequestInterceptor for InletInterceptorImpl {
 }
 
 impl InletInterceptorImpl {
+    async fn handle_api_version_request(
+        &self,
+        buffer: &mut Bytes,
+        header: &RequestHeader,
+    ) -> Result<BytesMut, InterceptError> {
+        let request: ApiVersionsRequest = decode_body(buffer, header.request_api_version)?;
+        const MAX_SUPPORTED_VERSION: i16 = ApiVersionsRequest::VERSIONS.max;
+
+        let request_api_version = if header.request_api_version > MAX_SUPPORTED_VERSION {
+            warn!("api versions request with version > {MAX_SUPPORTED_VERSION} not supported, downgrading request to {MAX_SUPPORTED_VERSION}");
+            MAX_SUPPORTED_VERSION
+        } else {
+            header.request_api_version
+        };
+
+        self.request_map.lock().unwrap().insert(
+            header.correlation_id,
+            RequestInfo {
+                request_api_key: ApiKey::ApiVersionsKey,
+                request_api_version,
+            },
+        );
+
+        let mut header = header.clone();
+        header.request_api_version = request_api_version;
+
+        encode_request(
+            &header,
+            &request,
+            request_api_version,
+            ApiKey::ApiVersionsKey,
+        )
+    }
+
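The new handler above pins the ApiVersions exchange to a version kafka-protocol can actually decode before anything else is parsed. The rule in isolation, as a minimal sketch (a standalone function rather than the interceptor method; `ApiVersionsRequest::VERSIONS` comes from the `Message` trait imported in this diff):

```rust
use kafka_protocol::messages::ApiVersionsRequest;
use kafka_protocol::protocol::Message;

/// Clamp a client-requested ApiVersions version to the range that
/// kafka-protocol itself can encode and decode.
fn clamp_api_versions_version(requested: i16) -> i16 {
    let max_supported = ApiVersionsRequest::VERSIONS.max;
    if requested > max_supported {
        // The request is then re-encoded at the lower version, so the broker
        // never sees a version the interceptor could not parse on the way back.
        max_supported
    } else {
        requested
    }
}
```

Requests at or below the ceiling pass through unchanged.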
     async fn handle_fetch_request(
         &self,
         context: &mut Context,
@@ -178,12 +206,15 @@ impl InletInterceptorImpl {
         // the content can be set in multiple topics and partitions in a single message
         // for each we wrap the content and add the secure channel identifier of
         // the encrypted content
-        for (topic_name, topic) in request.topic_data.iter_mut() {
+        for topic in request.topic_data.iter_mut() {
             for data in &mut topic.partition_data {
                 if let Some(content) = data.records.take() {
                     let mut content = BytesMut::from(content.as_ref());
-                    let mut records = RecordBatchDecoder::decode(&mut content)
-                        .map_err(|_| InterceptError::InvalidData)?;
+                    let mut records = RecordBatchDecoder::decode(
+                        &mut content,
+                        None::<fn(&mut Bytes, Compression) -> Result<Bytes, _>>,
+                    )
+                    .map_err(|_| InterceptError::InvalidData)?;
 
                     for record in records.iter_mut() {
                         if let Some(record_value) = record.value.take() {
                             // valid JSON map
                             self.encrypt_specific_fields(
                                 context,
-                                topic_name,
+                                &topic.name,
                                 data,
                                 &record_value,
                             )
                             .await?
                         } else {
-                            self.encrypt_whole_record(context, topic_name, data, record_value)
+                            self.encrypt_whole_record(context, &topic.name, data, record_value)
                                 .await?
                         };
                         record.value = Some(buffer.into());
@@ -213,6 +244,7 @@
                             version: 2,
                             compression: Compression::None,
                         },
+                        None::<fn(&mut BytesMut, &mut BytesMut, Compression) -> Result<(), _>>,
                     )
                     .map_err(|_| InterceptError::InvalidData)?;
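Both hunks above track the kafka-protocol 0.13 change to the record-batch API: `RecordBatchDecoder::decode` and `RecordBatchEncoder::encode` now take an optional custom decompressor/compressor, and passing `None` needs an explicit `fn` turbofish so the otherwise-unconstrained type parameter can be inferred. A minimal round-trip sketch under those 0.13 signatures (the exact trait bounds live in the crate docs; `anyhow::Result` matches the crate's error type):

```rust
use anyhow::Result;
use bytes::{Bytes, BytesMut};
use kafka_protocol::records::{
    Compression, Record, RecordBatchDecoder, RecordBatchEncoder, RecordEncodeOptions,
};

/// Encode a batch and decode it back, exercising the 0.13 call shape used in
/// this diff. `None` keeps the crate's built-in compression handling.
fn round_trip(records: &[Record]) -> Result<Vec<Record>> {
    let mut encoded = BytesMut::new();
    RecordBatchEncoder::encode(
        &mut encoded,
        records.iter(),
        &RecordEncodeOptions {
            version: 2,
            compression: Compression::None,
        },
        None::<fn(&mut BytesMut, &mut BytesMut, Compression) -> Result<(), _>>,
    )?;

    // BytesMut satisfies the decoder's buffer bound, as in the interceptor.
    let mut buffer = encoded;
    RecordBatchDecoder::decode(
        &mut buffer,
        None::<fn(&mut Bytes, Compression) -> Result<Bytes, _>>,
    )
}
```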
diff --git a/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/inlet/response.rs b/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/inlet/response.rs
index a9a5c5d35af..c0b221cf503 100644
--- a/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/inlet/response.rs
+++ b/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/inlet/response.rs
@@ -6,11 +6,11 @@ use crate::kafka::protocol_aware::{
 use crate::kafka::KafkaInletController;
 use bytes::{Bytes, BytesMut};
 use kafka_protocol::messages::{
-    ApiKey, ApiVersionsResponse, FetchResponse, FindCoordinatorResponse, MetadataResponse,
-    ResponseHeader,
+    ApiKey, ApiVersionsResponse, CreateTopicsResponse, FetchResponse, FindCoordinatorResponse,
+    ListOffsetsResponse, MetadataResponse, ProduceResponse, ResponseHeader,
 };
 use kafka_protocol::protocol::buf::ByteBuf;
-use kafka_protocol::protocol::{Decodable, StrBytes};
+use kafka_protocol::protocol::{Decodable, Message, StrBytes};
 use kafka_protocol::records::{
     Compression, RecordBatchDecoder, RecordBatchEncoder, RecordEncodeOptions,
 };
@@ -31,12 +31,7 @@ impl KafkaMessageResponseInterceptor for InletInterceptorImpl {
         // we can/need to decode only mapped requests
         let correlation_id = buffer.peek_bytes(0..4).try_get_i32()?;
 
-        let result = self
-            .request_map
-            .lock()
-            .unwrap()
-            .get(&correlation_id)
-            .cloned();
+        let result = self.request_map.lock().unwrap().remove(&correlation_id);
 
         if let Some(request_info) = result {
             let result = ResponseHeader::decode(
@@ -64,10 +59,9 @@ impl KafkaMessageResponseInterceptor for InletInterceptorImpl {
 
             match request_info.request_api_key {
                 ApiKey::ApiVersionsKey => {
-                    let response: ApiVersionsResponse =
-                        decode_body(&mut buffer, request_info.request_api_version)?;
-                    debug!("api versions response header: {:?}", header);
-                    debug!("api versions response: {:#?}", response);
+                    return self
+                        .handle_api_versions_response(&mut buffer, request_info, &header)
+                        .await;
                 }
 
                 ApiKey::FetchKey => {
@@ -116,6 +110,75 @@ impl KafkaMessageResponseInterceptor for InletInterceptorImpl {
 }
 
 impl InletInterceptorImpl {
+    async fn handle_api_versions_response(
+        &self,
+        buffer: &mut Bytes,
+        request_info: RequestInfo,
+        header: &ResponseHeader,
+    ) -> Result<BytesMut, InterceptError> {
+        let mut response: ApiVersionsResponse =
+            decode_body(buffer, request_info.request_api_version)?;
+        debug!("api versions response header: {:?}", header);
+        debug!("api versions response: {:#?}", response);
+
+        // We must ensure that every message is fully encrypted and never leaves the
+        // client unencrypted.
+        // To do that, we __can't allow unknown/unparsable request/response__ since
+        // it might be a new API to produce or consume messages.
+        // To avoid breakage every time a client or server is updated, we reduce the
+        // version of the protocol to the supported version for each api.
+
+        for api_version in response.api_keys.iter_mut() {
+            let result = ApiKey::try_from(api_version.api_key);
+            let api_key = match result {
+                Ok(api_key) => api_key,
+                Err(_) => {
+                    warn!("unknown api key: {}", api_version.api_key);
+                    return Err(InterceptError::InvalidData);
+                }
+            };
+
+            // Requests and responses share the same api version range.
+            let ockam_supported_range = match api_key {
+                ApiKey::ProduceKey => ProduceResponse::VERSIONS,
+                ApiKey::FetchKey => FetchResponse::VERSIONS,
+                ApiKey::ListOffsetsKey => ListOffsetsResponse::VERSIONS,
+                ApiKey::MetadataKey => MetadataResponse::VERSIONS,
+                ApiKey::ApiVersionsKey => ApiVersionsResponse::VERSIONS,
+                ApiKey::CreateTopicsKey => CreateTopicsResponse::VERSIONS,
+                ApiKey::FindCoordinatorKey => FindCoordinatorResponse::VERSIONS,
+                _ => {
+                    // we only need to check the APIs that we actually use
+                    continue;
+                }
+            };
+
+            if ockam_supported_range.min <= api_version.min_version
+                && ockam_supported_range.max >= api_version.max_version
+            {
+                continue;
+            }
+
+            info!(
+                "reducing api version range for api {api_key:?} from ({min_server},{max_server}) to ({min_ockam},{max_ockam})",
+                min_server = api_version.min_version,
+                max_server = api_version.max_version,
+                min_ockam = ockam_supported_range.min,
+                max_ockam = ockam_supported_range.max,
+            );
+
+            api_version.min_version = ockam_supported_range.min;
+            api_version.max_version = ockam_supported_range.max;
+        }
+
+        encode_response(
+            header,
+            &response,
+            request_info.request_api_version,
+            ApiKey::ApiVersionsKey,
+        )
+    }
+
     // for metadata we want to replace broker address and port
     // to dedicated tcp inlet ports
     async fn handle_metadata_response(
@@ -131,9 +194,13 @@ impl InletInterceptorImpl {
         // we need to keep a map of topic uuid to topic name since fetch
         // operations only use uuid
         if request_info.request_api_version >= 10 {
-            for (topic_name, topic) in &response.topics {
+            for topic in &response.topics {
                 let topic_id = topic.topic_id.to_string();
-                let topic_name = topic_name.to_string();
+                let topic_name = if let Some(name) = &topic.name {
+                    name.to_string()
+                } else {
+                    continue;
+                };
 
                 trace!("metadata adding to map: {topic_id} => {topic_name}");
                 self.uuid_to_name
@@ -145,19 +212,19 @@ impl InletInterceptorImpl {
 
         trace!("metadata response before: {:?}", &response);
 
-        for (broker_id, info) in response.brokers.iter_mut() {
+        for broker in response.brokers.iter_mut() {
             let inlet_address = inlet_map
-                .assert_inlet_for_broker(context, broker_id.0)
+                .assert_inlet_for_broker(context, broker.node_id.0)
                 .await?;
 
             trace!(
                 "inlet_address: {} for broker {}",
                 &inlet_address,
-                broker_id.0
+                broker.node_id.0
             );
 
-            info.host = StrBytes::from_string(inlet_address.hostname());
-            info.port = inlet_address.port() as i32;
+            broker.host = StrBytes::from_string(inlet_address.hostname());
+            broker.port = inlet_address.port() as i32;
         }
 
         trace!("metadata response after: {:?}", &response);
@@ -225,8 +292,11 @@ impl InletInterceptorImpl {
         for partition in response.partitions.iter_mut() {
             if let Some(content) = partition.records.take() {
                 let mut content = BytesMut::from(content.as_ref());
-                let mut records = RecordBatchDecoder::decode(&mut content)
-                    .map_err(|_| InterceptError::InvalidData)?;
+                let mut records = RecordBatchDecoder::decode(
+                    &mut content,
+                    None::<fn(&mut Bytes, Compression) -> Result<Bytes, _>>,
+                )
+                .map_err(|_| InterceptError::InvalidData)?;
 
                 for record in records.iter_mut() {
                     if let Some(record_value) = record.value.take() {
@@ -247,6 +317,7 @@
                         version: 2,
                         compression: Compression::None,
                     },
+                    None::<fn(&mut BytesMut, &mut BytesMut, Compression) -> Result<(), _>>,
                 )
                 .map_err(|_| InterceptError::InvalidData)?;
                 partition.records = Some(encoded.freeze());
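The narrowing in `handle_api_versions_response` is deliberately not an intersection: whenever the broker's advertised range is not wholly contained in what kafka-protocol supports, the advertised range is replaced wholesale by ockam's. A worked sketch of that decision, with plain `(min, max)` pairs standing in for the crate's `VersionRange` (illustrative only, not the repo's API):

```rust
/// Mirror of the narrowing rule in handle_api_versions_response.
fn narrow(server: (i16, i16), supported: (i16, i16)) -> (i16, i16) {
    let (srv_min, srv_max) = server;
    let (sup_min, sup_max) = supported;
    if sup_min <= srv_min && sup_max >= srv_max {
        // The broker's whole range is already decodable; advertise it unchanged.
        server
    } else {
        // Otherwise advertise exactly what kafka-protocol can parse, so a newer
        // client/broker pair can never negotiate a version the interceptor
        // cannot read (and therefore cannot encrypt).
        supported
    }
}

// e.g. a broker advertising Produce as (0, 17) against crate support (0, 11)
// is rewritten to (0, 11); a broker advertising (3, 9) within (0, 11) is kept.
```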
diff --git a/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/inlet/tests.rs b/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/inlet/tests.rs
index 505319e3751..b7db0eb0119 100644
--- a/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/inlet/tests.rs
+++ b/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/inlet/tests.rs
@@ -4,15 +4,13 @@ use crate::kafka::protocol_aware::{
     utils, KafkaEncryptedContent, KafkaMessageRequestInterceptor, KafkaMessageResponseInterceptor,
 };
 use crate::kafka::KafkaInletController;
-use bytes::BytesMut;
-use indexmap::IndexMap;
+use bytes::{Bytes, BytesMut};
 use kafka_protocol::messages::fetch_response::{FetchableTopicResponse, PartitionData};
 use kafka_protocol::messages::produce_request::{PartitionProduceData, TopicProduceData};
-use kafka_protocol::messages::ApiKey::ProduceKey;
 use kafka_protocol::messages::{
     ApiKey, FetchResponse, ProduceRequest, RequestHeader, ResponseHeader, TopicName,
 };
-use kafka_protocol::protocol::{Builder, Decodable, StrBytes};
+use kafka_protocol::protocol::{Decodable, StrBytes};
 use kafka_protocol::records::{
     Compression, Record, RecordBatchDecoder, RecordBatchEncoder, RecordEncodeOptions, TimestampType,
 };
@@ -66,14 +64,11 @@ impl KafkaKeyExchangeController for MockKafkaKeyExchangeController {
 const TEST_KAFKA_API_VERSION: i16 = 13;
 
 pub fn create_kafka_produce_request(content: &[u8]) -> BytesMut {
-    let header = RequestHeader::builder()
-        .request_api_key(ApiKey::ProduceKey as i16)
-        .request_api_version(TEST_KAFKA_API_VERSION)
-        .correlation_id(1)
-        .client_id(Some(StrBytes::from_static_str("my-client-id")))
-        .unknown_tagged_fields(Default::default())
-        .build()
-        .unwrap();
+    let header = RequestHeader::default()
+        .with_request_api_key(ApiKey::ProduceKey as i16)
+        .with_request_api_version(TEST_KAFKA_API_VERSION)
+        .with_correlation_id(1)
+        .with_client_id(Some(StrBytes::from_static_str("my-client-id")));
 
     let mut encoded = BytesMut::new();
     RecordBatchEncoder::encode(
@@ -97,31 +92,16 @@ pub fn create_kafka_produce_request(content: &[u8]) -> BytesMut {
             version: 2,
             compression: Compression::None,
         },
+        None::<fn(&mut BytesMut, &mut BytesMut, Compression) -> Result<(), _>>,
     )
     .unwrap();
 
-    let mut topic_data = IndexMap::new();
-    topic_data.insert(
-        TopicName::from(StrBytes::from_static_str("topic-name")),
-        TopicProduceData::builder()
-            .partition_data(vec![PartitionProduceData::builder()
-                .index(1)
-                .records(Some(encoded.freeze()))
-                .unknown_tagged_fields(Default::default())
-                .build()
-                .unwrap()])
-            .unknown_tagged_fields(Default::default())
-            .build()
-            .unwrap(),
-    );
-    let request = ProduceRequest::builder()
-        .transactional_id(None)
-        .acks(0)
-        .timeout_ms(0)
-        .topic_data(topic_data)
-        .unknown_tagged_fields(Default::default())
-        .build()
-        .unwrap();
+    let topic_data = vec![TopicProduceData::default()
+        .with_name(TopicName::from(StrBytes::from_static_str("topic-name")))
+        .with_partition_data(vec![PartitionProduceData::default()
+            .with_index(1)
+            .with_records(Some(encoded.freeze()))])];
+    let request = ProduceRequest::default().with_topic_data(topic_data);
 
     utils::encode_request(
         &header,
@@ -133,11 +113,7 @@
 }
 
 pub fn create_kafka_fetch_response(content: &[u8]) -> BytesMut {
-    let header = ResponseHeader::builder()
-        .correlation_id(1)
-        .unknown_tagged_fields(Default::default())
-        .build()
-        .unwrap();
+    let header = ResponseHeader::default().with_correlation_id(1);
 
     let mut encoded = BytesMut::new();
     RecordBatchEncoder::encode(
@@ -161,37 +137,16 @@ pub fn create_kafka_fetch_response(content: &[u8]) -> BytesMut {
             version: 2,
             compression: Compression::None,
         },
+        None::<fn(&mut BytesMut, &mut BytesMut, Compression) -> Result<(), _>>,
     )
     .unwrap();
 
-    let response = FetchResponse::builder()
-        .throttle_time_ms(Default::default())
-        .error_code(Default::default())
-        .session_id(Default::default())
-        .responses(vec![FetchableTopicResponse::builder()
-            .topic(TopicName::from(StrBytes::from_static_str("topic-name")))
-            .topic_id(Default::default())
-            .partitions(vec![PartitionData::builder()
-                .partition_index(1)
-                .error_code(Default::default())
-                .high_watermark(Default::default())
-                .last_stable_offset(Default::default())
-                .log_start_offset(Default::default())
-                .diverging_epoch(Default::default())
-                .current_leader(Default::default())
-                .snapshot_id(Default::default())
-                .aborted_transactions(Default::default())
-                .preferred_read_replica(Default::default())
-                .records(Some(encoded.freeze()))
-                .unknown_tagged_fields(Default::default())
-                .build()
-                .unwrap()])
-            .unknown_tagged_fields(Default::default())
-            .build()
-            .unwrap()])
-        .unknown_tagged_fields(Default::default())
-        .build()
-        .unwrap();
+    let response = FetchResponse::default().with_responses(vec![FetchableTopicResponse::default()
+        .with_topic(TopicName::from(StrBytes::from_static_str("topic-name")))
+        .with_topic_id(Default::default())
+        .with_partitions(vec![PartitionData::default()
+            .with_partition_index(1)
+            .with_records(Some(encoded.freeze()))])]);
 
     utils::encode_response(&header, &response, TEST_KAFKA_API_VERSION, ApiKey::FetchKey).unwrap()
 }
@@ -200,7 +155,7 @@ pub fn parse_produce_request(content: &[u8]) -> ProduceRequest {
     let mut buffer = BytesMut::from(content);
     let _header = RequestHeader::decode(
         &mut buffer,
-        ProduceKey.request_header_version(TEST_KAFKA_API_VERSION),
+        ApiKey::ProduceKey.request_header_version(TEST_KAFKA_API_VERSION),
     )
     .unwrap();
     utils::decode_body(&mut buffer, TEST_KAFKA_API_VERSION).unwrap()
 }
@@ -290,10 +245,9 @@ pub async fn json_encrypt_specific_fields(context: &mut Context) -> ockam::Resul
 
     let request = parse_produce_request(&encrypted_response);
     let topic_data = request.topic_data.first().unwrap();
-    assert_eq!("topic-name", topic_data.0 .0.as_str());
+    assert_eq!("topic-name", topic_data.name.as_str());
 
     let mut batch_content = topic_data
-        .1
         .partition_data
         .first()
         .cloned()
         .unwrap()
         .records
         .unwrap();
@@ -301,7 +255,11 @@ pub async fn json_encrypt_specific_fields(context: &mut Context) -> ockam::Resul
 
-    let records = RecordBatchDecoder::decode(&mut batch_content).unwrap();
+    let records = RecordBatchDecoder::decode(
+        &mut batch_content,
+        None::<fn(&mut Bytes, Compression) -> Result<Bytes, _>>,
+    )
+    .unwrap();
     let record = records.first().unwrap();
 
     let record_content = record.value.clone().unwrap();
@@ -378,7 +336,11 @@ pub async fn json_decrypt_specific_fields(context: &mut Context) -> ockam::Resul
         .first()
         .unwrap();
     let mut records = partition_data.records.clone().unwrap();
-    let records = RecordBatchDecoder::decode(&mut records).unwrap();
+    let records = RecordBatchDecoder::decode(
+        &mut records,
+        None::<fn(&mut Bytes, Compression) -> Result<Bytes, _>>,
+    )
+    .unwrap();
     let record = records.first().unwrap();
 
     let value =
diff --git a/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/outlet/response.rs b/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/outlet/response.rs
index 78e16228c72..27b4661b422 100644
--- a/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/outlet/response.rs
+++ b/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/outlet/response.rs
@@ -58,11 +58,11 @@ impl KafkaMessageResponseInterceptor for OutletInterceptorImpl {
                 let response: MetadataResponse =
                     decode_body(&mut buffer, request_info.request_api_version)?;
 
-                for (broker_id, metadata) in response.brokers {
-                    let address = format!("{}:{}", metadata.host.as_str(), metadata.port);
+                for broker in response.brokers {
+                    let address = format!("{}:{}", broker.host.as_str(), broker.port);
                     let outlet_address = self
                         .outlet_controller
-                        .assert_outlet_for_broker(context, broker_id.0, address)
+                        .assert_outlet_for_broker(context, broker.node_id.0, address)
                         .await?;
 
                     // allow the interceptor to reach the outlet
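Two migration patterns recur throughout these hunks: the removed `derive_builder`-style builders (every field set explicitly, then `.build().unwrap()`) become `Default::default()` plus chained `with_*` setters for only the fields that matter, and broker collections go from an `IndexMap` keyed by `BrokerId` to a `Vec` whose elements carry their own `node_id`. A condensed before/after sketch (the host and port values are placeholders):

```rust
use kafka_protocol::messages::metadata_response::MetadataResponseBroker;
use kafka_protocol::messages::{BrokerId, MetadataResponse};
use kafka_protocol::protocol::StrBytes;

fn broker_fixture() -> MetadataResponse {
    // kafka-protocol 0.10 (removed): builder + IndexMap keyed by BrokerId:
    //
    //   let mut brokers = indexmap::IndexMap::new();
    //   brokers.insert(
    //       BrokerId(1),
    //       MetadataResponseBroker::builder()
    //           .host(StrBytes::from_static_str("localhost"))
    //           .port(9092)
    //           .rack(Default::default())
    //           .unknown_tagged_fields(Default::default())
    //           .build()
    //           .unwrap(),
    //   );
    //
    // kafka-protocol 0.13: plain structs with fluent setters; the broker id
    // moves into the struct itself as `node_id`.
    MetadataResponse::default().with_brokers(vec![MetadataResponseBroker::default()
        .with_node_id(BrokerId::from(1))
        .with_host(StrBytes::from_static_str("localhost"))
        .with_port(9092)])
}
```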
software")) + .with_client_software_version(StrBytes::from_static_str("1.0.0")), api_version, ApiKey::ApiVersionsKey, ) @@ -111,21 +104,8 @@ mod test { .intercept_response( context, encode_response( - &ResponseHeader::builder() - .correlation_id(correlation_id) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap(), - &ApiVersionsResponse::builder() - .error_code(0) - .api_keys(Default::default()) - .throttle_time_ms(0) - .supported_features(Default::default()) - .finalized_features_epoch(0) - .finalized_features(Default::default()) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap(), + &ResponseHeader::default().with_correlation_id(correlation_id), + &ApiVersionsResponse::default(), api_version, ApiKey::ApiVersionsKey, ) @@ -145,22 +125,11 @@ mod test { .intercept_request( context, encode_request( - &RequestHeader::builder() - .request_api_version(api_version) - .correlation_id(correlation_id) - .request_api_key(ApiKey::MetadataKey as i16) - .unknown_tagged_fields(Default::default()) - .client_id(None) - .build() - .unwrap(), - &MetadataRequest::builder() - .topics(None) - .allow_auto_topic_creation(true) - .include_cluster_authorized_operations(false) - .include_topic_authorized_operations(false) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap(), + &RequestHeader::default() + .with_request_api_version(api_version) + .with_correlation_id(correlation_id) + .with_request_api_key(ApiKey::MetadataKey as i16), + &MetadataRequest::default(), api_version, ApiKey::MetadataKey, ) @@ -176,21 +145,8 @@ mod test { .intercept_response( context, encode_response( - &ResponseHeader::builder() - .correlation_id(correlation_id) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap(), - &MetadataResponse::builder() - .throttle_time_ms(0) - .brokers(Default::default()) - .cluster_id(None) - .controller_id(BrokerId::from(0_i32)) - .cluster_authorized_operations(-2147483648) - .topics(Default::default()) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap(), + &ResponseHeader::default().with_correlation_id(correlation_id), + &MetadataResponse::default().with_controller_id(BrokerId::from(0_i32)), api_version, ApiKey::MetadataKey, ) diff --git a/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/utils.rs b/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/utils.rs index 3accde9dde3..1d1c090ace4 100644 --- a/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/utils.rs +++ b/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/utils.rs @@ -11,8 +11,9 @@ where { let response = match T::decode(buffer, api_version) { Ok(response) => response, - Err(_) => { - warn!("cannot decode kafka message"); + Err(error) => { + warn!("cannot decode kafka message, closing connection"); + debug!("error: {:?}", error); return Err(InterceptError::InvalidData); } }; diff --git a/implementations/rust/ockam/ockam_api/src/kafka/tests/integration_test.rs b/implementations/rust/ockam/ockam_api/src/kafka/tests/integration_test.rs index 4b3fead2383..016fc7f062d 100644 --- a/implementations/rust/ockam/ockam_api/src/kafka/tests/integration_test.rs +++ b/implementations/rust/ockam/ockam_api/src/kafka/tests/integration_test.rs @@ -1,17 +1,14 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; -use bytes::{Buf, BufMut, BytesMut}; -use indexmap::IndexMap; +use bytes::{Buf, BufMut, Bytes, BytesMut}; use kafka_protocol::messages::produce_request::{PartitionProduceData, TopicProduceData}; use 
kafka_protocol::messages::{ fetch_request::{FetchPartition, FetchTopic}, fetch_response::FetchableTopicResponse, fetch_response::PartitionData, - ApiKey, BrokerId, FetchRequest, FetchResponse, ProduceRequest, RequestHeader, ResponseHeader, - TopicName, + ApiKey, FetchRequest, FetchResponse, ProduceRequest, RequestHeader, ResponseHeader, TopicName, }; -use kafka_protocol::protocol::Builder; use kafka_protocol::protocol::Decodable as KafkaDecodable; use kafka_protocol::protocol::Encodable as KafkaEncodable; use kafka_protocol::protocol::StrBytes; @@ -191,11 +188,8 @@ async fn producer__flow_with_mock_kafka__content_encryption_and_decryption( let encrypted_body = request .topic_data - .iter() - .next() - .as_ref() + .first() .unwrap() - .1 .partition_data .first() .unwrap() @@ -204,7 +198,11 @@ async fn producer__flow_with_mock_kafka__content_encryption_and_decryption( .unwrap(); let mut encrypted_body = BytesMut::from(encrypted_body.as_ref()); - let records = RecordBatchDecoder::decode(&mut encrypted_body).unwrap(); + let records = RecordBatchDecoder::decode( + &mut encrypted_body, + None:: Result>, + ) + .unwrap(); // verify the message has been encrypted assert_ne!( @@ -246,7 +244,11 @@ async fn producer__flow_with_mock_kafka__content_encryption_and_decryption( .unwrap(); let mut plain_content = BytesMut::from(plain_content.as_ref()); - let records = RecordBatchDecoder::decode(&mut plain_content).unwrap(); + let records = RecordBatchDecoder::decode( + &mut plain_content, + None:: Result>, + ) + .unwrap(); assert_eq!( records.first().as_ref().unwrap().value.as_ref().unwrap(), @@ -275,14 +277,11 @@ async fn simulate_kafka_producer_and_read_request( } async fn send_kafka_produce_request(stream: &mut TcpStream) { - let header = RequestHeader::builder() - .request_api_key(ApiKey::ProduceKey as i16) - .request_api_version(TEST_KAFKA_API_VERSION) - .correlation_id(1) - .client_id(Some(StrBytes::from_static_str("my-client-id"))) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap(); + let header = RequestHeader::default() + .with_request_api_key(ApiKey::ProduceKey as i16) + .with_request_api_version(TEST_KAFKA_API_VERSION) + .with_correlation_id(1) + .with_client_id(Some(StrBytes::from_static_str("my-client-id"))); let mut encoded = BytesMut::new(); RecordBatchEncoder::encode( @@ -306,31 +305,16 @@ async fn send_kafka_produce_request(stream: &mut TcpStream) { version: 2, compression: Compression::None, }, + None:: Result<(), _>>, ) .unwrap(); - let mut topic_data = IndexMap::new(); - topic_data.insert( - TopicName::from(StrBytes::from_static_str("my-topic-name")), - TopicProduceData::builder() - .partition_data(vec![PartitionProduceData::builder() - .index(1) - .records(Some(encoded.freeze())) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap()]) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap(), - ); - let request = ProduceRequest::builder() - .transactional_id(None) - .acks(0) - .timeout_ms(0) - .topic_data(topic_data) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap(); + let topic_data = vec![TopicProduceData::default() + .with_name(TopicName::from(StrBytes::from_static_str("my-topic-name"))) + .with_partition_data(vec![PartitionProduceData::default() + .with_index(1) + .with_records(Some(encoded.freeze()))])]; + let request = ProduceRequest::default().with_topic_data(topic_data); send_kafka_request(stream, header, request, ApiKey::ProduceKey).await; } @@ -385,10 +369,9 @@ async fn send_kafka_fetch_response( stream: S, 
producer_request: &ProduceRequest, ) { - let topic_name = TopicName::from(StrBytes::from_static_str("my-topic-name")); let producer_content = producer_request .topic_data - .get(&topic_name) + .first() .unwrap() .partition_data .first() @@ -396,41 +379,15 @@ async fn send_kafka_fetch_response( .records .clone(); + let topic_name = TopicName::from(StrBytes::from_static_str("my-topic-name")); send_kafka_response( stream, - ResponseHeader::builder() - .correlation_id(1) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap(), - FetchResponse::builder() - .throttle_time_ms(Default::default()) - .error_code(Default::default()) - .session_id(Default::default()) - .responses(vec![FetchableTopicResponse::builder() - .topic(topic_name) - .topic_id(Default::default()) - .partitions(vec![PartitionData::builder() - .partition_index(1) - .error_code(Default::default()) - .high_watermark(Default::default()) - .last_stable_offset(Default::default()) - .log_start_offset(Default::default()) - .diverging_epoch(Default::default()) - .current_leader(Default::default()) - .snapshot_id(Default::default()) - .aborted_transactions(Default::default()) - .preferred_read_replica(Default::default()) - .records(producer_content) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap()]) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap()]) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap(), + ResponseHeader::default().with_correlation_id(1), + FetchResponse::default().with_responses(vec![FetchableTopicResponse::default() + .with_topic(topic_name) + .with_partitions(vec![PartitionData::default() + .with_partition_index(1) + .with_records(producer_content)])]), ApiKey::FetchKey, ) .await; @@ -439,44 +396,15 @@ async fn send_kafka_fetch_response( async fn send_kafka_fetch_request(stream: &mut TcpStream) { send_kafka_request( stream, - RequestHeader::builder() - .request_api_key(ApiKey::FetchKey as i16) - .request_api_version(TEST_KAFKA_API_VERSION) - .correlation_id(1) - .client_id(Some(StrBytes::from_static_str("my-client-id"))) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap(), - FetchRequest::builder() - .cluster_id(None) - .replica_id(BrokerId::default()) - .max_wait_ms(0) - .min_bytes(0) - .max_bytes(0) - .isolation_level(0) - .session_id(0) - .session_epoch(0) - .topics(vec![FetchTopic::builder() - .topic(TopicName::from(StrBytes::from_static_str("my-topic-name"))) - .topic_id(Uuid::from_slice(b"my-topic-name___").unwrap()) - .partitions(vec![FetchPartition::builder() - .partition(1) - .current_leader_epoch(0) - .fetch_offset(0) - .last_fetched_epoch(0) - .log_start_offset(0) - .partition_max_bytes(0) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap()]) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap()]) - .forgotten_topics_data(Default::default()) - .rack_id(Default::default()) - .unknown_tagged_fields(Default::default()) - .build() - .unwrap(), + RequestHeader::default() + .with_request_api_key(ApiKey::FetchKey as i16) + .with_request_api_version(TEST_KAFKA_API_VERSION) + .with_correlation_id(1) + .with_client_id(Some(StrBytes::from_static_str("my-client-id"))), + FetchRequest::default().with_topics(vec![FetchTopic::default() + .with_topic(TopicName::from(StrBytes::from_static_str("my-topic-name"))) + .with_topic_id(Uuid::from_slice(b"my-topic-name___").unwrap()) + .with_partitions(vec![FetchPartition::default().with_partition(1)])]), ApiKey::FetchKey, ) .await; diff --git 
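These tests lean on helpers such as `send_kafka_request`/`send_kafka_response`, which this diff leaves untouched and therefore does not show. For orientation, a hypothetical sketch of the framing such a helper performs — a big-endian `i32` length prefix (which excludes itself) followed by the encoded header and body, using the same `compute_size`/`encode`/`request_header_version` calls that appear in these tests. The helper name and tokio usage here are assumptions, not code from the repo:

```rust
use bytes::{BufMut, BytesMut};
use kafka_protocol::messages::{ApiKey, RequestHeader};
use kafka_protocol::protocol::Encodable;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;

/// Illustrative only: write one size-prefixed Kafka request frame.
async fn write_frame<T: Encodable>(
    stream: &mut TcpStream,
    header: &RequestHeader,
    body: &T,
    api_key: ApiKey,
    api_version: i16,
) -> anyhow::Result<()> {
    // The header wire format depends on the api key and version.
    let header_version = api_key.request_header_version(api_version);
    let size = header.compute_size(header_version)? + body.compute_size(api_version)?;

    let mut buffer = BytesMut::new();
    buffer.put_i32(size as i32); // length prefix, not counted in `size`
    header.encode(&mut buffer, header_version)?;
    body.encode(&mut buffer, api_version)?;

    stream.write_all(&buffer).await?;
    Ok(())
}
```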
diff --git a/implementations/rust/ockam/ockam_api/src/kafka/tests/interceptor_test.rs b/implementations/rust/ockam/ockam_api/src/kafka/tests/interceptor_test.rs
index 606200a9f51..06a1c315f8f 100644
--- a/implementations/rust/ockam/ockam_api/src/kafka/tests/interceptor_test.rs
+++ b/implementations/rust/ockam/ockam_api/src/kafka/tests/interceptor_test.rs
@@ -3,12 +3,10 @@ use std::str::FromStr;
 use std::time::Duration;
 
 use bytes::{Buf, BufMut, Bytes, BytesMut};
-use kafka_protocol::messages::metadata_request::MetadataRequestBuilder;
 use kafka_protocol::messages::metadata_response::MetadataResponseBroker;
 use kafka_protocol::messages::{
     ApiKey, BrokerId, MetadataRequest, MetadataResponse, RequestHeader, ResponseHeader,
 };
-use kafka_protocol::protocol::Builder;
 use kafka_protocol::protocol::Decodable;
 use kafka_protocol::protocol::Encodable as KafkaEncodable;
 use kafka_protocol::protocol::StrBytes;
@@ -167,14 +165,7 @@ async fn kafka_portal_worker__bigger_than_limit_kafka_message__error(
     encode(
         &mut request_buffer,
         create_request_header(ApiKey::MetadataKey),
-        MetadataRequestBuilder::default()
-            .topics(Default::default())
-            .include_cluster_authorized_operations(Default::default())
-            .include_topic_authorized_operations(Default::default())
-            .allow_auto_topic_creation(Default::default())
-            .unknown_tagged_fields(insanely_huge_tag)
-            .build()
-            .unwrap(),
+        MetadataRequest::default().with_unknown_tagged_fields(insanely_huge_tag),
     );
 
     let huge_payload = request_buffer.as_ref();
@@ -221,14 +212,7 @@ async fn kafka_portal_worker__almost_over_limit_than_limit_kafka_message__two_ka
     encode(
         &mut huge_outgoing_request,
         create_request_header(ApiKey::MetadataKey),
-        MetadataRequestBuilder::default()
-            .topics(Default::default())
-            .include_cluster_authorized_operations(Default::default())
-            .include_topic_authorized_operations(Default::default())
-            .allow_auto_topic_creation(Default::default())
-            .unknown_tagged_fields(insanely_huge_tag.clone())
-            .build()
-            .unwrap(),
+        MetadataRequest::default().with_unknown_tagged_fields(insanely_huge_tag.clone()),
     );
 
     let receiver = TcpPayloadReceiver {
@@ -366,14 +350,11 @@ where
 }
 
 fn create_request_header(api_key: ApiKey) -> RequestHeader {
-    RequestHeader::builder()
-        .request_api_key(api_key as i16)
-        .request_api_version(TEST_KAFKA_API_VERSION)
-        .correlation_id(1)
-        .client_id(Some(StrBytes::from_static_str("my-client-id")))
-        .unknown_tagged_fields(Default::default())
-        .build()
-        .unwrap()
+    RequestHeader::default()
+        .with_request_api_key(api_key as i16)
+        .with_request_api_version(TEST_KAFKA_API_VERSION)
+        .with_correlation_id(1)
+        .with_client_id(Some(StrBytes::from_static_str("my-client-id")))
 }
 
 #[allow(non_snake_case)]
@@ -477,31 +458,13 @@ async fn kafka_portal_worker__metadata_exchange__response_changed(
     let mut response_buffer = BytesMut::new();
     {
-        let response_header = ResponseHeader::builder()
-            .correlation_id(1)
-            .unknown_tagged_fields(Default::default())
-            .build()
-            .unwrap();
-
-        let metadata_response = MetadataResponse::builder()
-            .throttle_time_ms(Default::default())
-            .cluster_id(Default::default())
-            .cluster_authorized_operations(-2147483648)
-            .unknown_tagged_fields(Default::default())
-            .controller_id(BrokerId::from(1))
-            .topics(Default::default())
-            .brokers(indexmap::IndexMap::from_iter(vec![(
-                BrokerId(1),
-                MetadataResponseBroker::builder()
-                    .host(StrBytes::from_static_str("bad.remote.host.example.com"))
-                    .port(1234)
-                    .rack(Default::default())
-                    .unknown_tagged_fields(Default::default())
-                    .build()
-                    .unwrap(),
-            )]))
-            .build()
-            .unwrap();
+        let response_header = ResponseHeader::default().with_correlation_id(1);
+        let metadata_response = MetadataResponse::default()
+            .with_controller_id(BrokerId::from(1))
+            .with_brokers(vec![MetadataResponseBroker::default()
+                .with_node_id(BrokerId::from(1))
+                .with_host(StrBytes::from_static_str("bad.remote.host.example.com"))
+                .with_port(1234)]);
 
         let size = response_header
             .compute_size(TEST_KAFKA_API_VERSION)
@@ -541,7 +504,8 @@
     let response = MetadataResponse::decode(&mut buffer_received, TEST_KAFKA_API_VERSION).unwrap();
 
     assert_eq!(1, response.brokers.len());
-    let broker = response.brokers.get(&BrokerId::from(1)).unwrap();
+    let broker = response.brokers.first().unwrap();
+    assert_eq!(1, broker.node_id.0);
     assert_eq!("127.0.0.1", &broker.host.to_string());
     assert_eq!(0, broker.port);