Skip to content

Commit

Permalink
feat(rust): reduce the api versions to the supported range
Browse files Browse the repository at this point in the history
  • Loading branch information
davide-baldo committed Nov 29, 2024
1 parent f906dc6 commit 079c376
Show file tree
Hide file tree
Showing 11 changed files with 257 additions and 383 deletions.
37 changes: 2 additions & 35 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 0 additions & 7 deletions NOTICE.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,19 +159,13 @@ This file contains attributions for any 3rd-party open source code used in this
| ctrlc | MIT, Apache-2.0 | https://crates.io/crates/ctrlc |
| curve25519-dalek | BSD-3-Clause | https://crates.io/crates/curve25519-dalek |
| curve25519-dalek-derive | MIT, Apache-2.0 | https://crates.io/crates/curve25519-dalek-derive |
| darling | MIT | https://crates.io/crates/darling |
| darling_core | MIT | https://crates.io/crates/darling_core |
| darling_macro | MIT | https://crates.io/crates/darling_macro |
| dashmap | MIT | https://crates.io/crates/dashmap |
| data-encoding | MIT | https://crates.io/crates/data-encoding |
| dbus | Apache-2.0, MIT | https://crates.io/crates/dbus |
| dbus-tokio | Apache-2.0, MIT | https://crates.io/crates/dbus-tokio |
| delegate | MIT, Apache-2.0 | https://crates.io/crates/delegate |
| der | Apache-2.0, MIT | https://crates.io/crates/der |
| deranged | MIT, Apache-2.0 | https://crates.io/crates/deranged |
| derive_builder | MIT, Apache-2.0 | https://crates.io/crates/derive_builder |
| derive_builder_core | MIT, Apache-2.0 | https://crates.io/crates/derive_builder_core |
| derive_builder_macro | MIT, Apache-2.0 | https://crates.io/crates/derive_builder_macro |
| dialoguer | MIT | https://crates.io/crates/dialoguer |
| digest | MIT, Apache-2.0 | https://crates.io/crates/digest |
| displaydoc | MIT, Apache-2.0 | https://crates.io/crates/displaydoc |
Expand Down Expand Up @@ -272,7 +266,6 @@ This file contains attributions for any 3rd-party open source code used in this
| icu_properties_data | Unicode-3.0 | https://crates.io/crates/icu_properties_data |
| icu_provider | Unicode-3.0 | https://crates.io/crates/icu_provider |
| icu_provider_macros | Unicode-3.0 | https://crates.io/crates/icu_provider_macros |
| ident_case | MIT, Apache-2.0 | https://crates.io/crates/ident_case |
| idna | MIT, Apache-2.0 | https://crates.io/crates/idna |
| idna_adapter | Apache-2.0, MIT | https://crates.io/crates/idna_adapter |
| image | MIT, Apache-2.0 | https://crates.io/crates/image |
Expand Down
2 changes: 1 addition & 1 deletion implementations/rust/ockam/ockam_api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ jaq-core = "1"
jaq-interpret = "1"
jaq-parse = "1"
jaq-std = "1"
kafka-protocol = "0.10"
kafka-protocol = "0.13"
log = "0.4"
miette = { version = "7.2.0", features = ["fancy-no-backtrace"] }
minicbor = { version = "0.25.1", default-features = false, features = ["alloc", "derive"] }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ use bytes::{Bytes, BytesMut};
use kafka_protocol::messages::fetch_request::FetchRequest;
use kafka_protocol::messages::produce_request::{PartitionProduceData, ProduceRequest};
use kafka_protocol::messages::request_header::RequestHeader;
use kafka_protocol::messages::{ApiKey, TopicName};
use kafka_protocol::messages::{ApiKey, ApiVersionsRequest, TopicName};
use kafka_protocol::protocol::buf::ByteBuf;
use kafka_protocol::protocol::Decodable;
use kafka_protocol::protocol::{Decodable, Message};
use kafka_protocol::records::{
Compression, RecordBatchDecoder, RecordBatchEncoder, RecordEncodeOptions,
};
Expand Down Expand Up @@ -67,13 +67,7 @@ impl KafkaMessageRequestInterceptor for InletInterceptorImpl {
match api_key {
ApiKey::ApiVersionsKey => {
debug!("api versions request: {:?}", header);
self.request_map.lock().unwrap().insert(
header.correlation_id,
RequestInfo {
request_api_key: api_key,
request_api_version: header.request_api_version,
},
);
return self.handle_api_version_request(&mut buffer, &header).await;
}

ApiKey::ProduceKey => {
Expand Down Expand Up @@ -116,6 +110,40 @@ impl KafkaMessageRequestInterceptor for InletInterceptorImpl {
}

impl InletInterceptorImpl {
async fn handle_api_version_request(
&self,
buffer: &mut Bytes,
header: &RequestHeader,
) -> Result<BytesMut, InterceptError> {
let request: ApiVersionsRequest = decode_body(buffer, header.request_api_version)?;
const MAX_SUPPORTED_VERSION: i16 = ApiVersionsRequest::VERSIONS.max;

let request_api_version = if header.request_api_version > MAX_SUPPORTED_VERSION {
warn!("api versions request with version > {MAX_SUPPORTED_VERSION} not supported, downgrading request to {MAX_SUPPORTED_VERSION}");
MAX_SUPPORTED_VERSION
} else {
header.request_api_version
};

self.request_map.lock().unwrap().insert(
header.correlation_id,
RequestInfo {
request_api_key: ApiKey::ApiVersionsKey,
request_api_version,
},
);

let mut header = header.clone();
header.request_api_version = request_api_version;

encode_request(
&header,
&request,
request_api_version,
ApiKey::ApiVersionsKey,
)
}

async fn handle_fetch_request(
&self,
context: &mut Context,
Expand Down Expand Up @@ -178,12 +206,15 @@ impl InletInterceptorImpl {
// the content can be set in multiple topics and partitions in a single message
// for each we wrap the content and add the secure channel identifier of
// the encrypted content
for (topic_name, topic) in request.topic_data.iter_mut() {
for topic in request.topic_data.iter_mut() {
for data in &mut topic.partition_data {
if let Some(content) = data.records.take() {
let mut content = BytesMut::from(content.as_ref());
let mut records = RecordBatchDecoder::decode(&mut content)
.map_err(|_| InterceptError::InvalidData)?;
let mut records = RecordBatchDecoder::decode(
&mut content,
None::<fn(&mut Bytes, Compression) -> Result<BytesMut, _>>,
)
.map_err(|_| InterceptError::InvalidData)?;

for record in records.iter_mut() {
if let Some(record_value) = record.value.take() {
Expand All @@ -192,13 +223,13 @@ impl InletInterceptorImpl {
// valid JSON map
self.encrypt_specific_fields(
context,
topic_name,
&topic.name,
data,
&record_value,
)
.await?
} else {
self.encrypt_whole_record(context, topic_name, data, record_value)
self.encrypt_whole_record(context, &topic.name, data, record_value)
.await?
};
record.value = Some(buffer.into());
Expand All @@ -213,6 +244,7 @@ impl InletInterceptorImpl {
version: 2,
compression: Compression::None,
},
None::<fn(&mut BytesMut, &mut BytesMut, Compression) -> Result<(), _>>,
)
.map_err(|_| InterceptError::InvalidData)?;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ use crate::kafka::protocol_aware::{
use crate::kafka::KafkaInletController;
use bytes::{Bytes, BytesMut};
use kafka_protocol::messages::{
ApiKey, ApiVersionsResponse, FetchResponse, FindCoordinatorResponse, MetadataResponse,
ResponseHeader,
ApiKey, ApiVersionsResponse, CreateTopicsResponse, FetchResponse, FindCoordinatorResponse,
ListOffsetsResponse, MetadataResponse, ProduceResponse, ResponseHeader,
};
use kafka_protocol::protocol::buf::ByteBuf;
use kafka_protocol::protocol::{Decodable, StrBytes};
use kafka_protocol::protocol::{Decodable, Message, StrBytes};
use kafka_protocol::records::{
Compression, RecordBatchDecoder, RecordBatchEncoder, RecordEncodeOptions,
};
Expand All @@ -31,12 +31,7 @@ impl KafkaMessageResponseInterceptor for InletInterceptorImpl {
// we can/need to decode only mapped requests
let correlation_id = buffer.peek_bytes(0..4).try_get_i32()?;

let result = self
.request_map
.lock()
.unwrap()
.get(&correlation_id)
.cloned();
let result = self.request_map.lock().unwrap().remove(&correlation_id);

if let Some(request_info) = result {
let result = ResponseHeader::decode(
Expand Down Expand Up @@ -64,10 +59,9 @@ impl KafkaMessageResponseInterceptor for InletInterceptorImpl {

match request_info.request_api_key {
ApiKey::ApiVersionsKey => {
let response: ApiVersionsResponse =
decode_body(&mut buffer, request_info.request_api_version)?;
debug!("api versions response header: {:?}", header);
debug!("api versions response: {:#?}", response);
return self
.handle_api_versions_response(&mut buffer, request_info, &header)
.await;
}

ApiKey::FetchKey => {
Expand Down Expand Up @@ -116,6 +110,75 @@ impl KafkaMessageResponseInterceptor for InletInterceptorImpl {
}

impl InletInterceptorImpl {
async fn handle_api_versions_response(
&self,
buffer: &mut Bytes,
request_info: RequestInfo,
header: &ResponseHeader,
) -> Result<BytesMut, InterceptError> {
let mut response: ApiVersionsResponse =
decode_body(buffer, request_info.request_api_version)?;
debug!("api versions response header: {:?}", header);
debug!("api versions response: {:#?}", response);

// We must ensure that every message is fully encrypted and never leaves the
// client unencrypted.
// To do that, we __can't allow unknown/unparsable request/response__ since
// it might be a new API to produce or consume messages.
// To avoid breakage every time a client or server is updated, we reduce the
// version of the protocol to the supported version for each api.

for api_version in response.api_keys.iter_mut() {
let result = ApiKey::try_from(api_version.api_key);
let api_key = match result {
Ok(api_key) => api_key,
Err(_) => {
warn!("unknown api key: {}", api_version.api_key);
return Err(InterceptError::InvalidData);
}
};

// Request and responses share the same api version range.
let ockam_supported_range = match api_key {
ApiKey::ProduceKey => ProduceResponse::VERSIONS,
ApiKey::FetchKey => FetchResponse::VERSIONS,
ApiKey::ListOffsetsKey => ListOffsetsResponse::VERSIONS,
ApiKey::MetadataKey => MetadataResponse::VERSIONS,
ApiKey::ApiVersionsKey => ApiVersionsResponse::VERSIONS,
ApiKey::CreateTopicsKey => CreateTopicsResponse::VERSIONS,
ApiKey::FindCoordinatorKey => FindCoordinatorResponse::VERSIONS,
_ => {
// we only need to check the APIs that we actually use
continue;
}
};

if ockam_supported_range.min <= api_version.min_version
&& ockam_supported_range.max >= api_version.max_version
{
continue;
}

info!(
"reducing api version range for api {api_key:?} from ({min_server},{max_server}) to ({min_ockam},{max_ockam})",
min_server = api_version.min_version,
max_server = api_version.max_version,
min_ockam = ockam_supported_range.min,
max_ockam = ockam_supported_range.max,
);

api_version.min_version = ockam_supported_range.min;
api_version.max_version = ockam_supported_range.max;
}

encode_response(
header,
&response,
request_info.request_api_version,
ApiKey::ApiVersionsKey,
)
}

// for metadata we want to replace broker address and port
// to dedicated tcp inlet ports
async fn handle_metadata_response(
Expand All @@ -131,9 +194,13 @@ impl InletInterceptorImpl {
// we need to keep a map of topic uuid to topic name since fetch
// operations only use uuid
if request_info.request_api_version >= 10 {
for (topic_name, topic) in &response.topics {
for topic in &response.topics {
let topic_id = topic.topic_id.to_string();
let topic_name = topic_name.to_string();
let topic_name = if let Some(name) = &topic.name {
name.to_string()
} else {
continue;
};

trace!("metadata adding to map: {topic_id} => {topic_name}");
self.uuid_to_name
Expand All @@ -145,19 +212,19 @@ impl InletInterceptorImpl {

trace!("metadata response before: {:?}", &response);

for (broker_id, info) in response.brokers.iter_mut() {
for broker in response.brokers.iter_mut() {
let inlet_address = inlet_map
.assert_inlet_for_broker(context, broker_id.0)
.assert_inlet_for_broker(context, broker.node_id.0)
.await?;

trace!(
"inlet_address: {} for broker {}",
&inlet_address,
broker_id.0
broker.node_id.0
);

info.host = StrBytes::from_string(inlet_address.hostname());
info.port = inlet_address.port() as i32;
broker.host = StrBytes::from_string(inlet_address.hostname());
broker.port = inlet_address.port() as i32;
}
trace!("metadata response after: {:?}", &response);

Expand Down Expand Up @@ -225,8 +292,11 @@ impl InletInterceptorImpl {
for partition in response.partitions.iter_mut() {
if let Some(content) = partition.records.take() {
let mut content = BytesMut::from(content.as_ref());
let mut records = RecordBatchDecoder::decode(&mut content)
.map_err(|_| InterceptError::InvalidData)?;
let mut records = RecordBatchDecoder::decode(
&mut content,
None::<fn(&mut Bytes, Compression) -> Result<BytesMut, _>>,
)
.map_err(|_| InterceptError::InvalidData)?;

for record in records.iter_mut() {
if let Some(record_value) = record.value.take() {
Expand All @@ -247,6 +317,7 @@ impl InletInterceptorImpl {
version: 2,
compression: Compression::None,
},
None::<fn(&mut BytesMut, &mut BytesMut, Compression) -> Result<(), _>>,
)
.map_err(|_| InterceptError::InvalidData)?;
partition.records = Some(encoded.freeze());
Expand Down
Loading

0 comments on commit 079c376

Please sign in to comment.