refactor(client): remove MaxSlots limit #1377

Merged
merged 2 commits on May 24, 2024
20 changes: 4 additions & 16 deletions client/http-client/src/client.rs
@@ -104,12 +104,6 @@ impl<L> HttpClientBuilder<L> {
self
}

/// Set max concurrent requests.
pub fn max_concurrent_requests(mut self, max: usize) -> Self {
self.max_concurrent_requests = max;
self
}

/// Force to use the rustls native certificate store.
///
/// Since multiple certificate stores can be optionally enabled, this option will
@@ -198,7 +192,6 @@ where
let Self {
max_request_size,
max_response_size,
max_concurrent_requests,
request_timeout,
certificate_store,
id_kind,
@@ -220,11 +213,7 @@
.build(target)
.map_err(|e| Error::Transport(e.into()))?;

Ok(HttpClient {
transport,
id_manager: Arc::new(RequestIdManager::new(max_concurrent_requests, id_kind)),
request_timeout,
})
Ok(HttpClient { transport, id_manager: Arc::new(RequestIdManager::new(id_kind)), request_timeout })
}
}

@@ -303,8 +292,7 @@ where
R: DeserializeOwned,
Params: ToRpcParams + Send,
{
let guard = self.id_manager.next_request_id()?;
let id = guard.inner();
let id = self.id_manager.next_request_id();
let params = params.to_rpc_params()?;

let request = RequestSer::borrowed(&id, &method, params.as_deref());
@@ -340,8 +328,8 @@ where
R: DeserializeOwned + fmt::Debug + 'a,
{
let batch = batch.build()?;
let guard = self.id_manager.next_request_id()?;
let id_range = generate_batch_id_range(&guard, batch.len() as u64)?;
let id = self.id_manager.next_request_id();
let id_range = generate_batch_id_range(id, batch.len() as u64)?;

let mut batch_request = Vec::with_capacity(batch.len());
for ((method, params), id) in batch.into_iter().zip(id_range.clone()) {
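For orientation, here is a minimal sketch of what building an HTTP client looks like once this diff lands. It is an illustration, not code from the PR: the crate paths, the `request_timeout` call, and the endpoint URL are assumptions, and the only point it makes is that `.max_concurrent_requests(n)` is no longer part of the builder.

```rust
// Sketch only: paths and the endpoint are placeholders, not taken from the PR.
use std::time::Duration;

use jsonrpsee::core::client::Error;
use jsonrpsee::http_client::{HttpClient, HttpClientBuilder};

fn build_client() -> Result<HttpClient, Error> {
    // After this change, chaining `.max_concurrent_requests(n)` here would
    // no longer compile; the builder keeps only its other knobs.
    HttpClientBuilder::default()
        .request_timeout(Duration::from_secs(60))
        .build("http://127.0.0.1:9944")
}
```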
15 changes: 7 additions & 8 deletions core/src/client/async_client/mod.rs
@@ -345,7 +345,7 @@ impl ClientBuilder {
to_back: to_back.clone(),
request_timeout: self.request_timeout,
error: ErrorFromBack::new(to_back, disconnect_reason),
id_manager: RequestIdManager::new(self.max_concurrent_requests, self.id_kind),
id_manager: RequestIdManager::new(self.id_kind),
max_log_length: self.max_log_length,
on_exit: Some(client_dropped_tx),
}
@@ -479,7 +479,7 @@ impl ClientT for Client {
Params: ToRpcParams + Send,
{
// NOTE: we use this to guard against max number of concurrent requests.
let _req_id = self.id_manager.next_request_id()?;
let _req_id = self.id_manager.next_request_id();
let params = params.to_rpc_params()?;
let notif = NotificationSer::borrowed(&method, params.as_deref());

@@ -505,8 +505,7 @@
Params: ToRpcParams + Send,
{
let (send_back_tx, send_back_rx) = oneshot::channel();
let guard = self.id_manager.next_request_id()?;
let id = guard.inner();
let id = self.id_manager.next_request_id();

let params = params.to_rpc_params()?;
let raw =
@@ -540,8 +539,8 @@
R: DeserializeOwned,
{
let batch = batch.build()?;
let guard = self.id_manager.next_request_id()?;
let id_range = generate_batch_id_range(&guard, batch.len() as u64)?;
let id = self.id_manager.next_request_id();
let id_range = generate_batch_id_range(id, batch.len() as u64)?;

let mut batches = Vec::with_capacity(batch.len());
for ((method, params), id) in batch.into_iter().zip(id_range.clone()) {
@@ -621,8 +620,8 @@ impl SubscriptionClientT for Client {
return Err(RegisterMethodError::SubscriptionNameConflict(unsubscribe_method.to_owned()).into());
}

let guard = self.id_manager.next_request_two_ids()?;
let (id_sub, id_unsub) = guard.inner();
let id_sub = self.id_manager.next_request_id();
let id_unsub = self.id_manager.next_request_id();
let params = params.to_rpc_params()?;

let raw = serde_json::to_string(&RequestSer::borrowed(&id_sub, &subscribe_method, params.as_deref()))
3 changes: 0 additions & 3 deletions core/src/client/error.rs
@@ -54,9 +54,6 @@ pub enum Error {
/// Request timeout
#[error("Request timeout")]
RequestTimeout,
/// Max number of request slots exceeded.
#[error("Max concurrent requests exceeded")]
MaxSlotsExceeded,
/// Custom error.
#[error("Custom error: {0}")]
Custom(String),
73 changes: 6 additions & 67 deletions core/src/client/mod.rs
@@ -455,10 +455,6 @@ impl<Notif> Drop for Subscription<Notif> {
#[derive(Debug)]
/// Keep track of request IDs.
pub struct RequestIdManager {
// Current pending requests.
current_pending: Arc<()>,
/// Max concurrent pending requests allowed.
max_concurrent_requests: usize,
/// Get the next request ID.
current_id: CurrentId,
/// Request ID type.
@@ -467,38 +463,15 @@ pub struct RequestIdManager {

impl RequestIdManager {
/// Create a new `RequestIdGuard` with the provided concurrency limit.
pub fn new(limit: usize, id_kind: IdKind) -> Self {
Self { current_pending: Arc::new(()), max_concurrent_requests: limit, current_id: CurrentId::new(), id_kind }
}

fn get_slot(&self) -> Result<Arc<()>, Error> {
// Strong count is 1 at start, so that's why we use `>` and not `>=`.
if Arc::strong_count(&self.current_pending) > self.max_concurrent_requests {
Err(Error::MaxSlotsExceeded)
} else {
Ok(self.current_pending.clone())
}
pub fn new(id_kind: IdKind) -> Self {
Self { current_id: CurrentId::new(), id_kind }
}

/// Attempts to get the next request ID.
///
/// Fails if request limit has been exceeded.
pub fn next_request_id(&self) -> Result<RequestIdGuard<Id<'static>>, Error> {
let rc = self.get_slot()?;
let id = self.id_kind.into_id(self.current_id.next());

Ok(RequestIdGuard { _rc: rc, id })
}

/// Attempts to get fetch two ids (used for subscriptions) but only
/// occupy one slot in the request guard.
///
/// Fails if request limit has been exceeded.
pub fn next_request_two_ids(&self) -> Result<RequestIdGuard<(Id<'static>, Id<'static>)>, Error> {
let rc = self.get_slot()?;
let id1 = self.id_kind.into_id(self.current_id.next());
let id2 = self.id_kind.into_id(self.current_id.next());
Ok(RequestIdGuard { _rc: rc, id: (id1, id2) })
pub fn next_request_id(&self) -> Id<'static> {
self.id_kind.into_id(self.current_id.next())
}

/// Get a handle to the `IdKind`.
@@ -507,21 +480,6 @@ impl RequestIdManager {
}
}

/// Reference counted request ID.
#[derive(Debug)]
pub struct RequestIdGuard<T: Clone> {
id: T,
/// Reference count decreased when dropped.
_rc: Arc<()>,
}

impl<T: Clone> RequestIdGuard<T> {
/// Get the actual ID or IDs.
pub fn inner(&self) -> T {
self.id.clone()
}
}

/// What certificate store to use
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
@@ -568,8 +526,8 @@ impl CurrentId {
}

/// Generate a range of IDs to be used in a batch request.
pub fn generate_batch_id_range(guard: &RequestIdGuard<Id>, len: u64) -> Result<Range<u64>, Error> {
let id_start = guard.inner().try_parse_inner_as_number()?;
pub fn generate_batch_id_range(id: Id, len: u64) -> Result<Range<u64>, Error> {
let id_start = id.try_parse_inner_as_number()?;
let id_end = id_start
.checked_add(len)
.ok_or_else(|| Error::Custom("BatchID range wrapped; restart the client or try again later".to_string()))?;
@@ -704,22 +662,3 @@ fn subscription_channel(max_buf_size: usize) -> (SubscriptionSender, SubscriptionReceiver) {

(SubscriptionSender { inner: tx, lagged: lagged_tx }, SubscriptionReceiver { inner: rx, lagged: lagged_rx })
}

#[cfg(test)]
mod tests {
use super::{IdKind, RequestIdManager};

#[test]
fn request_id_guard_works() {
let manager = RequestIdManager::new(2, IdKind::Number);
let _first = manager.next_request_id().unwrap();

{
let _second = manager.next_request_two_ids().unwrap();
assert!(manager.next_request_id().is_err());
// second dropped here.
}

assert!(manager.next_request_id().is_ok());
}
}
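Taken together, the `RequestIdManager` changes in this file reduce it to a much smaller surface: `new` takes only an `IdKind`, `next_request_id` hands back an `Id<'static>` without a `Result`, and `generate_batch_id_range` consumes an `Id` by value. A rough usage sketch, with the import paths assumed rather than taken from the PR:

```rust
// Sketch of the post-PR API; module paths are assumptions for illustration.
use jsonrpsee_core::client::{generate_batch_id_range, Error, IdKind, RequestIdManager};

fn demo() -> Result<(), Error> {
    let manager = RequestIdManager::new(IdKind::Number);

    // No slot accounting anymore: this cannot fail and never blocks.
    let id = manager.next_request_id();

    // A batch of three requests gets three consecutive numeric ids,
    // derived from the single starting id above.
    let id_range = generate_batch_id_range(id, 3)?;
    for request_id in id_range {
        let _ = request_id; // one id per batch entry
    }
    Ok(())
}
```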
38 changes: 0 additions & 38 deletions tests/tests/integration_tests.rs
@@ -275,23 +275,6 @@ async fn http_method_call_str_id_works() {
assert_eq!(&response, "hello");
}

#[tokio::test]
async fn http_concurrent_method_call_limits_works() {
init_logger();

let server_addr = server().await;
let uri = format!("http://{}", server_addr);
let client = HttpClientBuilder::default().max_concurrent_requests(1).build(&uri).unwrap();

let (first, second) = tokio::join!(
client.request::<String, ArrayParams>("say_hello", rpc_params!()),
client.request::<String, ArrayParams>("say_hello", rpc_params![]),
);

assert!(first.is_ok());
assert!(matches!(second, Err(Error::MaxSlotsExceeded)));
}

#[tokio::test]
async fn ws_subscription_several_clients() {
init_logger();
@@ -418,27 +401,6 @@ async fn ws_making_more_requests_than_allowed_should_not_deadlock() {
}
}

#[tokio::test]
async fn http_making_more_requests_than_allowed_should_not_deadlock() {
init_logger();

let server_addr = server().await;
let server_url = format!("http://{}", server_addr);
let client = HttpClientBuilder::default().max_concurrent_requests(2).build(&server_url).unwrap();
let client = Arc::new(client);

let mut requests = Vec::new();

for _ in 0..6 {
let c = client.clone();
requests.push(tokio::spawn(async move { c.request::<String, ArrayParams>("say_hello", rpc_params![]).await }));
}

for req in requests {
let _ = req.await.unwrap();
}
}

#[tokio::test]
async fn https_works() {
init_logger();
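With the client-side cap (and these two limit tests) gone, any backpressure has to live in the caller. One possible way to keep a bound on in-flight HTTP requests is a plain `tokio::sync::Semaphore` around the client; this is a sketch under that assumption, not something the PR itself adds, and the endpoint and method name are placeholders:

```rust
use std::sync::Arc;

use jsonrpsee::core::client::ClientT;
use jsonrpsee::http_client::HttpClientBuilder;
use jsonrpsee::rpc_params;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = Arc::new(HttpClientBuilder::default().build("http://127.0.0.1:9944")?);
    // Caller-side limit of two requests in flight, replacing the removed
    // `max_concurrent_requests` knob for callers that still want a cap.
    let limit = Arc::new(Semaphore::new(2));

    let mut handles = Vec::new();
    for _ in 0..6 {
        let client = client.clone();
        let limit = limit.clone();
        handles.push(tokio::spawn(async move {
            // Waits for a permit instead of erroring, so there is no
            // `MaxSlotsExceeded` analogue in this approach.
            let _permit = limit.acquire_owned().await.expect("semaphore never closed");
            client.request::<String, _>("say_hello", rpc_params![]).await
        }));
    }

    for handle in handles {
        let _ = handle.await?;
    }
    Ok(())
}
```

The shape mirrors the removed `http_making_more_requests_than_allowed_should_not_deadlock` test above, except the bound is now enforced by the caller rather than by the client.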