Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let delay queue be mocked #485

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion plugins/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,12 @@ impl<'a> ServiceGenerator<'a> {
#vis fn new<T>(config: ::tarpc::client::Config, transport: T)
-> ::tarpc::client::NewClient<
Self,
::tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
::tarpc::client::RequestDispatch<
#request_ident,
#response_ident,
T,
::tarpc::util::delay_queue::DelayQueue<u64>
>
>
where
T: ::tarpc::Transport<::tarpc::ClientMessage<#request_ident>, ::tarpc::Response<#response_ident>>
Expand Down
50 changes: 39 additions & 11 deletions tarpc/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ pub mod stub;
use crate::{
cancellations::{cancellations, CanceledRequests, RequestCancellation},
context, trace,
util::TimeUntil,
util::{
delay_queue::{DelayQueue, DelayQueueLike},
TimeUntil,
},
ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport,
};
use futures::{prelude::*, ready, stream::Fuse, task::*};
Expand Down Expand Up @@ -237,9 +240,27 @@ impl<Resp> Drop for ResponseGuard<'_, Resp> {
pub fn new<Req, Resp, C>(
config: Config,
transport: C,
) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C>>
) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C, DelayQueue<u64>>>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
with_in_flight_requests(
config,
transport,
InFlightRequests::<_, DelayQueue<u64>>::default(),
)
}

/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the
/// channel.
pub fn with_in_flight_requests<Req, Resp, C, Deadline>(
config: Config,
transport: C,
in_flight_requests: InFlightRequests<Result<Resp, RpcError>, Deadline>,
) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C, Deadline>>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
Deadline: DelayQueueLike<u64>,
{
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
let (cancellation, canceled_requests) = cancellations();
Expand All @@ -254,7 +275,7 @@ where
config,
canceled_requests,
transport: transport.fuse(),
in_flight_requests: InFlightRequests::default(),
in_flight_requests,
pending_requests,
terminal_error: None,
},
Expand All @@ -266,7 +287,10 @@ where
#[must_use]
#[pin_project()]
#[derive(Debug)]
pub struct RequestDispatch<Req, Resp, C> {
pub struct RequestDispatch<Req, Resp, C, Deadline>
where
Deadline: DelayQueueLike<u64>,
{
/// Writes requests to the wire and reads responses off the wire.
#[pin]
transport: Fuse<C>,
Expand All @@ -275,7 +299,7 @@ pub struct RequestDispatch<Req, Resp, C> {
/// Requests that were dropped.
canceled_requests: CanceledRequests,
/// Requests already written to the wire that haven't yet received responses.
in_flight_requests: InFlightRequests<Result<Resp, RpcError>>,
in_flight_requests: InFlightRequests<Result<Resp, RpcError>, Deadline>,
/// Configures limits to prevent unlimited resource usage.
config: Config,
/// Produces errors that can be sent in response to any unprocessed requests at the time
Expand All @@ -285,13 +309,14 @@ pub struct RequestDispatch<Req, Resp, C> {
terminal_error: Option<ChannelError<dyn Any + Send + Sync + 'static>>,
}

impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
impl<Req, Resp, C, Deadline> RequestDispatch<Req, Resp, C, Deadline>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
Deadline: DelayQueueLike<u64>,
{
fn in_flight_requests<'a>(
self: &'a mut Pin<&mut Self>,
) -> &'a mut InFlightRequests<Result<Resp, RpcError>> {
) -> &'a mut InFlightRequests<Result<Resp, RpcError>, Deadline> {
self.as_mut().project().in_flight_requests
}

Expand Down Expand Up @@ -636,9 +661,10 @@ where
}
}

impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
impl<Req, Resp, C, Deadline> Future for RequestDispatch<Req, Resp, C, Deadline>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
Deadline: DelayQueueLike<u64>,
{
type Output = Result<(), ChannelError<C::Error>>;

Expand Down Expand Up @@ -685,6 +711,7 @@ mod tests {
client::{in_flight_requests::InFlightRequests, Config},
context::{self, current},
transport::{self, channel::UnboundedChannel},
util::delay_queue::DelayQueue,
ChannelError, ClientMessage, Response,
};
use assert_matches::assert_matches;
Expand Down Expand Up @@ -960,14 +987,14 @@ mod tests {
fn set_up_always_err(
cause: TransportError,
) -> (
Pin<Box<RequestDispatch<String, String, AlwaysErrorTransport<String>>>>,
Pin<Box<RequestDispatch<String, String, AlwaysErrorTransport<String>, DelayQueue<u64>>>>,
Channel<String, String>,
Context<'static>,
) {
let (to_dispatch, pending_requests) = mpsc::channel(1);
let (cancellation, canceled_requests) = cancellations();
let transport: AlwaysErrorTransport<String> = AlwaysErrorTransport(cause, PhantomData);
let dispatch = Box::pin(RequestDispatch::<String, String, _> {
let dispatch = Box::pin(RequestDispatch::<String, String, _, _> {
transport: transport.fuse(),
pending_requests,
canceled_requests,
Expand Down Expand Up @@ -1051,6 +1078,7 @@ mod tests {
String,
String,
UnboundedChannel<Response<String>, ClientMessage<String>>,
DelayQueue<u64>,
>,
>,
>,
Expand All @@ -1063,7 +1091,7 @@ mod tests {
let (cancellation, canceled_requests) = cancellations();
let (client_channel, server_channel) = transport::channel::unbounded();

let dispatch = RequestDispatch::<String, String, _> {
let dispatch = RequestDispatch::<String, String, _, _> {
transport: client_channel.fuse(),
pending_requests,
canceled_requests,
Expand Down
29 changes: 19 additions & 10 deletions tarpc/src/client/in_flight_requests.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
use crate::{
context,
util::{Compact, TimeUntil},
util::{delay_queue::DelayQueueLike, Compact, TimeUntil},
};
use fnv::FnvHashMap;
use std::{
collections::hash_map,
fmt::Debug,
task::{Context, Poll},
};
use tokio::sync::oneshot;
use tokio_util::time::delay_queue::{self, DelayQueue};
use tracing::Span;

/// Requests already written to the wire that haven't yet received responses.
#[derive(Debug)]
pub struct InFlightRequests<Resp> {
request_data: FnvHashMap<u64, RequestData<Resp>>,
deadlines: DelayQueue<u64>,
pub struct InFlightRequests<Resp, Deadline>
where
Deadline: DelayQueueLike<u64>,
{
request_data: FnvHashMap<u64, RequestData<Resp, Deadline::Key>>,
deadlines: Deadline,
}

impl<Resp> Default for InFlightRequests<Resp> {
impl<Resp, Deadline> Default for InFlightRequests<Resp, Deadline>
where
Deadline: DelayQueueLike<u64> + Default,
{
fn default() -> Self {
Self {
request_data: Default::default(),
Expand All @@ -28,20 +34,23 @@ impl<Resp> Default for InFlightRequests<Resp> {
}

#[derive(Debug)]
struct RequestData<Res> {
struct RequestData<Res, Key> {
ctx: context::Context,
span: Span,
response_completion: oneshot::Sender<Res>,
/// The key to remove the timer for the request's deadline.
deadline_key: delay_queue::Key,
deadline_key: Key,
}

/// An error returned when an attempt is made to insert a request with an ID that is already in
/// use.
#[derive(Debug)]
pub struct AlreadyExistsError;

impl<Res> InFlightRequests<Res> {
impl<Res, Deadline> InFlightRequests<Res, Deadline>
where
Deadline: DelayQueueLike<u64>,
{
/// Returns the number of in-flight requests.
pub fn len(&self) -> usize {
self.request_data.len()
Expand Down Expand Up @@ -124,7 +133,7 @@ impl<Res> InFlightRequests<Res> {
expired_error: impl Fn() -> Res,
) -> Poll<Option<u64>> {
self.deadlines.poll_expired(cx).map(|expired| {
let request_id = expired?.into_inner();
let request_id = expired?;
if let Some(request_data) = self.request_data.remove(&request_id) {
let _entered = request_data.span.enter();
tracing::error!("DeadlineExceeded");
Expand Down
4 changes: 3 additions & 1 deletion tarpc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ pub mod client;
pub mod context;
pub mod server;
pub mod transport;
pub(crate) mod util;

/// Utilities
pub mod util;

pub use crate::transport::sealed::Transport;

Expand Down
53 changes: 41 additions & 12 deletions tarpc/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

//! Provides a server that concurrently handles many connections sending multiplexed requests.

use crate::util::delay_queue::{DelayQueue, DelayQueueLike};
use crate::{
cancellations::{cancellations, CanceledRequests, RequestCancellation},
context::{self, SpanExt},
Expand Down Expand Up @@ -58,11 +59,15 @@ impl Default for Config {

impl Config {
/// Returns a channel backed by `transport` and configured with `self`.
pub fn channel<Req, Resp, T>(self, transport: T) -> BaseChannel<Req, Resp, T>
pub fn channel<Req, Resp, T, Deadline>(
self,
transport: T,
) -> BaseChannelImpl<Req, Resp, T, Deadline>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
Deadline: DelayQueueLike<u64> + Default,
{
BaseChannel::new(self, transport)
BaseChannelImpl::new(self, transport)
}
}

Expand Down Expand Up @@ -138,7 +143,10 @@ where
/// messages. Instead, it internally handles them by cancelling corresponding requests (removing
/// the corresponding in-flight requests and aborting their handlers).
#[pin_project]
pub struct BaseChannel<Req, Resp, T> {
pub struct BaseChannelImpl<Req, Resp, T, Deadline>
where
Deadline: DelayQueueLike<u64>,
{
config: Config,
/// Writes responses to the wire and reads requests off the wire.
#[pin]
Expand All @@ -149,19 +157,23 @@ pub struct BaseChannel<Req, Resp, T> {
/// Notifies `canceled_requests` when a request is canceled.
request_cancellation: RequestCancellation,
/// Holds data necessary to clean up in-flight requests.
in_flight_requests: InFlightRequests,
in_flight_requests: InFlightRequests<Deadline>,
/// Types the request and response.
ghost: PhantomData<(fn() -> Req, fn(Resp))>,
}

impl<Req, Resp, T> BaseChannel<Req, Resp, T>
///
pub type BaseChannel<Req, Resp, T> = BaseChannelImpl<Req, Resp, T, DelayQueue<u64>>;

impl<Req, Resp, T, Deadline> BaseChannelImpl<Req, Resp, T, Deadline>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
Deadline: DelayQueueLike<u64> + Default,
{
/// Creates a new channel backed by `transport` and configured with `config`.
pub fn new(config: Config, transport: T) -> Self {
let (request_cancellation, canceled_requests) = cancellations();
BaseChannel {
BaseChannelImpl {
config,
transport: transport.fuse(),
canceled_requests,
Expand All @@ -175,7 +187,13 @@ where
pub fn with_defaults(transport: T) -> Self {
Self::new(Config::default(), transport)
}
}

impl<Req, Resp, T, Deadline> BaseChannelImpl<Req, Resp, T, Deadline>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
Deadline: DelayQueueLike<u64>,
{
/// Returns the inner transport over which messages are sent and received.
pub fn get_ref(&self) -> &T {
self.transport.get_ref()
Expand All @@ -186,7 +204,9 @@ where
self.project().transport.get_pin_mut()
}

fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests {
fn in_flight_requests_mut<'a>(
self: &'a mut Pin<&mut Self>,
) -> &'a mut InFlightRequests<Deadline> {
self.as_mut().project().in_flight_requests
}

Expand Down Expand Up @@ -248,7 +268,10 @@ where
}
}

impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
impl<Req, Resp, T, Deadline> fmt::Debug for BaseChannelImpl<Req, Resp, T, Deadline>
where
Deadline: DelayQueueLike<u64>,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "BaseChannel")
}
Expand Down Expand Up @@ -418,9 +441,10 @@ where
}
}

impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
impl<Req, Resp, T, Deadline> Stream for BaseChannelImpl<Req, Resp, T, Deadline>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
Deadline: DelayQueueLike<u64>,
{
type Item = Result<TrackedRequest<Req>, ChannelError<T::Error>>;

Expand Down Expand Up @@ -525,10 +549,11 @@ where
}
}

impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
impl<Req, Resp, T, Deadline> Sink<Response<Resp>> for BaseChannelImpl<Req, Resp, T, Deadline>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
T::Error: Error,
Deadline: DelayQueueLike<u64>,
{
type Error = ChannelError<T::Error>;

Expand Down Expand Up @@ -572,15 +597,19 @@ where
}
}

impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
impl<Req, Resp, T, Deadline> AsRef<T> for BaseChannelImpl<Req, Resp, T, Deadline>
where
Deadline: DelayQueueLike<u64>,
{
fn as_ref(&self) -> &T {
self.transport.get_ref()
}
}

impl<Req, Resp, T> Channel for BaseChannel<Req, Resp, T>
impl<Req, Resp, T, Deadline> Channel for BaseChannelImpl<Req, Resp, T, Deadline>
where
T: Transport<Response<Resp>, ClientMessage<Req>>,
Deadline: DelayQueueLike<u64>,
{
type Req = Req;
type Resp = Resp;
Expand Down
Loading