diff --git a/src/tools.rs b/src/tools.rs index 5e0a6c01..5f3fb897 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -20,6 +20,7 @@ use std::collections::HashMap; use serde_json::Value; +pub mod async_mutex; pub mod timer; pub mod wrapped_reader_stream; #[macro_use] diff --git a/src/tools/async_mutex.rs b/src/tools/async_mutex.rs new file mode 100644 index 00000000..12832146 --- /dev/null +++ b/src/tools/async_mutex.rs @@ -0,0 +1,40 @@ +use std::marker::PhantomData; + +use futures::Poll; +use futures::future::Future; +use tokio::sync::lock::Lock as TokioLock; +pub use tokio::sync::lock::LockGuard as AsyncLockGuard; + +pub struct AsyncMutex(TokioLock); + +unsafe impl Sync for AsyncMutex {} + +impl AsyncMutex { + pub fn new(value: T) -> Self { + Self(TokioLock::new(value)) + } + + // to allow any error type (we never error, so we have no error type of our own) + pub fn lock(&self) -> LockFuture { + LockFuture { + lock: self.0.clone(), + _error: PhantomData, + } + } +} + +/// Represents a lock to be held in the future: +pub struct LockFuture { + lock: TokioLock, + // We can't error and we don't want to enforce a specific error type either + _error: PhantomData, +} + +impl Future for LockFuture { + type Item = AsyncLockGuard; + type Error = E; + + fn poll(&mut self) -> Poll, E> { + Ok(self.lock.poll_lock()) + } +}