use std::future::Future; use std::io::IoSlice; use std::marker::Unpin; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use hyper::client::connect::{Connected, Connection}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::time::Sleep; use std::task::{Context, Poll}; use super::{RateLimiter, ShareableRateLimit}; type SharedRateLimit = Arc; pub type RateLimiterCallback = dyn Fn() -> (Option, Option) + Send; /// A rate limited stream using [RateLimiter] pub struct RateLimitedStream { read_limiter: Option, read_delay: Option>>, write_limiter: Option, write_delay: Option>>, update_limiter_cb: Option>, last_limiter_update: Instant, stream: S, } impl RateLimitedStream { /// Creates a new instance with reads and writes limited to the same `rate`. pub fn new(stream: S, rate: u64, bucket_size: u64) -> Self { let now = Instant::now(); let read_limiter = RateLimiter::with_start_time(rate, bucket_size, now); let read_limiter: SharedRateLimit = Arc::new(Mutex::new(read_limiter)); let write_limiter = RateLimiter::with_start_time(rate, bucket_size, now); let write_limiter: SharedRateLimit = Arc::new(Mutex::new(write_limiter)); Self::with_limiter(stream, Some(read_limiter), Some(write_limiter)) } /// Creates a new instance with specified [`RateLimiter`s](RateLimiter) for reads and writes. pub fn with_limiter( stream: S, read_limiter: Option, write_limiter: Option, ) -> Self { Self { read_limiter, read_delay: None, write_limiter, write_delay: None, update_limiter_cb: None, last_limiter_update: Instant::now(), stream, } } /// Creates a new instance with limiter update callback. /// /// The function is called every minute to update/change the used limiters. /// /// Note: This function is called within an async context, so it /// should be fast and must not block. pub fn with_limiter_update_cb< F: Fn() -> (Option, Option) + Send + 'static, >( stream: S, update_limiter_cb: F, ) -> Self { let (read_limiter, write_limiter) = update_limiter_cb(); Self { read_limiter, read_delay: None, write_limiter, write_delay: None, update_limiter_cb: Some(Box::new(update_limiter_cb)), last_limiter_update: Instant::now(), stream, } } fn update_limiters(&mut self) { if let Some(ref update_limiter_cb) = self.update_limiter_cb { if self.last_limiter_update.elapsed().as_secs() >= 5 { self.last_limiter_update = Instant::now(); let (read_limiter, write_limiter) = update_limiter_cb(); self.read_limiter = read_limiter; self.write_limiter = write_limiter; } } } pub fn inner(&self) -> &S { &self.stream } pub fn inner_mut(&mut self) -> &mut S { &mut self.stream } } fn register_traffic(limiter: &(dyn ShareableRateLimit), count: usize) -> Option>> { const MIN_DELAY: Duration = Duration::from_millis(10); let now = Instant::now(); let delay = limiter.register_traffic(now, count as u64); if delay >= MIN_DELAY { let sleep = tokio::time::sleep(delay); Some(Box::pin(sleep)) } else { None } } fn delay_is_ready(delay: &mut Option>>, ctx: &mut Context<'_>) -> bool { match delay { Some(ref mut future) => future.as_mut().poll(ctx).is_ready(), None => true, } } impl AsyncWrite for RateLimitedStream { fn poll_write( self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &[u8], ) -> Poll> { let this = self.get_mut(); let is_ready = delay_is_ready(&mut this.write_delay, ctx); if !is_ready { return Poll::Pending; } this.write_delay = None; this.update_limiters(); let result = Pin::new(&mut this.stream).poll_write(ctx, buf); if let Some(ref mut limiter) = this.write_limiter { if let Poll::Ready(Ok(count)) = result { this.write_delay = register_traffic(limiter.as_ref(), count); } } result } fn is_write_vectored(&self) -> bool { self.stream.is_write_vectored() } fn poll_write_vectored( self: Pin<&mut Self>, ctx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll> { let this = self.get_mut(); let is_ready = delay_is_ready(&mut this.write_delay, ctx); if !is_ready { return Poll::Pending; } this.write_delay = None; this.update_limiters(); let result = Pin::new(&mut this.stream).poll_write_vectored(ctx, bufs); if let Some(ref limiter) = this.write_limiter { if let Poll::Ready(Ok(count)) = result { this.write_delay = register_traffic(limiter.as_ref(), count); } } result } fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); Pin::new(&mut this.stream).poll_flush(ctx) } fn poll_shutdown( self: Pin<&mut Self>, ctx: &mut Context<'_>, ) -> Poll> { let this = self.get_mut(); Pin::new(&mut this.stream).poll_shutdown(ctx) } } impl AsyncRead for RateLimitedStream { fn poll_read( self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { let this = self.get_mut(); let is_ready = delay_is_ready(&mut this.read_delay, ctx); if !is_ready { return Poll::Pending; } this.read_delay = None; this.update_limiters(); let filled_len = buf.filled().len(); let result = Pin::new(&mut this.stream).poll_read(ctx, buf); if let Some(ref read_limiter) = this.read_limiter { if let Poll::Ready(Ok(())) = &result { let count = buf.filled().len() - filled_len; this.read_delay = register_traffic(read_limiter.as_ref(), count); } } result } } // we need this for the hyper http client impl Connection for RateLimitedStream { fn connected(&self) -> Connected { self.stream.connected() } }