use std::future::Future; use std::io::{self, Error, ErrorKind}; use std::pin::Pin; use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::Arc; use std::sync::Mutex; use std::task::{Context, Poll}; use std::time::Duration; use futures::{Sink, Stream}; use stream_reconnect::ReconnectOptions; use stream_reconnect::{ReconnectStream, UnderlyingStream}; #[derive(Default)] pub struct DummyStream { poll_read_results: PollReadResults, } #[derive(Default, Clone)] struct DummyCtor { connect_outcomes: ConnectOutcomes, poll_read_results: PollReadResults, } type ConnectOutcomes = Arc>>; type PollReadResults = Arc>, Vec)>>>; struct DummyStreamConnector; impl UnderlyingStream, io::Error> for DummyStreamConnector { type Stream = DummyStream; #[cfg(not(feature = "not-send"))] fn establish(ctor: DummyCtor) -> Pin> + Send>> { let mut connect_attempt_outcome_results = ctor.connect_outcomes.lock().unwrap(); let should_succeed = connect_attempt_outcome_results.remove(0); if should_succeed { let dummy_io = DummyStream { poll_read_results: ctor.poll_read_results.clone(), }; Box::pin(async { Ok(dummy_io) }) } else { Box::pin(async { Err(io::Error::new(ErrorKind::NotConnected, "So unfortunate")) }) } } #[cfg(feature = "not-send")] fn establish(ctor: DummyCtor) -> Pin>>> { let mut connect_attempt_outcome_results = ctor.connect_outcomes.lock().unwrap(); let should_succeed = connect_attempt_outcome_results.remove(0); if should_succeed { let dummy_io = DummyStream { poll_read_results: ctor.poll_read_results.clone(), }; Box::pin(async { Ok(dummy_io) }) } else { Box::pin(async { Err(io::Error::new(ErrorKind::NotConnected, "So unfortunate")) }) } } fn is_write_disconnect_error(err: &Error) -> bool { use std::io::ErrorKind::*; matches!( err.kind(), NotFound | PermissionDenied | ConnectionRefused | ConnectionReset | ConnectionAborted | NotConnected | AddrInUse | AddrNotAvailable | BrokenPipe | AlreadyExists ) } fn exhaust_err() -> Error { io::Error::new( ErrorKind::NotConnected, "Disconnected. Connection attempts have been exhausted.", ) } } type ReconnectDummy = ReconnectStream, io::Error>; impl Stream for DummyStream { type Item = Vec; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let cloned = self.poll_read_results.clone(); let mut poll_read_results = cloned.lock().unwrap(); let (result, bytes) = poll_read_results.remove(0); if let Poll::Ready(Err(e)) = result { if e.kind() == io::ErrorKind::WouldBlock { cx.waker().wake_by_ref(); Poll::Pending } else { Poll::Ready(None) } } else { Poll::Ready(Some(bytes)) } } } impl Sink> for DummyStream { type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { unreachable!() } fn start_send(self: Pin<&mut Self>, _item: Vec) -> Result<(), Self::Error> { unreachable!() } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { unreachable!() } fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } } #[cfg(test)] pub mod instantiating { use super::*; #[tokio::test] async fn should_be_connected_if_initial_connect_succeeds() { let connect_outcomes = Arc::new(Mutex::new(vec![true])); let ctor = DummyCtor { connect_outcomes, ..DummyCtor::default() }; let dummy = ReconnectDummy::connect(ctor).await; assert!(dummy.is_ok()); } #[tokio::test] async fn should_be_disconnected_if_initial_connect_fails_with_fail_on_first_enabled() { let connect_outcomes = Arc::new(Mutex::new(vec![false, true])); let ctor = DummyCtor { connect_outcomes, ..DummyCtor::default() }; let dummy = ReconnectDummy::connect(ctor).await; assert!(dummy.is_err()); } #[tokio::test] async fn should_be_disconnected_if_all_initial_connects_fail() { let connect_outcomes = Arc::new(Mutex::new(vec![false, false])); let ctor = DummyCtor { connect_outcomes, ..DummyCtor::default() }; let disconnect_counter = Arc::new(AtomicU8::new(0)); let disconnect_clone = disconnect_counter.clone(); let options = ReconnectOptions::new() .with_retries_generator(|| vec![Duration::from_millis(100)]) .with_exit_if_first_connect_fails(false) .with_on_connect_fail_callback(move || { disconnect_clone.fetch_add(1, Ordering::Relaxed); }); let dummy = ReconnectDummy::connect_with_options(ctor, options).await; assert_eq!(disconnect_counter.load(Ordering::Relaxed), 2); assert!(dummy.is_err()); } #[tokio::test] async fn should_be_connected_if_initial_connect_fails_but_then_other_succeeds() { let connect_outcomes = Arc::new(Mutex::new(vec![false, true])); let ctor = DummyCtor { connect_outcomes, ..DummyCtor::default() }; let disconnect_counter = Arc::new(AtomicU8::new(0)); let disconnect_clone = disconnect_counter.clone(); let options = ReconnectOptions::new() .with_exit_if_first_connect_fails(false) .with_retries_generator(|| vec![Duration::from_millis(100)]) .with_on_connect_fail_callback(move || { disconnect_clone.fetch_add(1, Ordering::Relaxed); }); let dummy = ReconnectDummy::connect_with_options(ctor, options).await; assert_eq!(disconnect_counter.load(Ordering::Relaxed), 1); assert!(dummy.is_ok()); } } #[cfg(test)] mod already_connected { use std::str::from_utf8; use futures::stream::StreamExt; use super::*; #[tokio::test] async fn should_ignore_non_fatal_errors_and_continue_as_connected() { let connect_outcomes = Arc::new(Mutex::new(vec![true])); let poll_read_results = Arc::new(Mutex::new(vec![ ( Poll::Ready(Err(io::Error::new( io::ErrorKind::WouldBlock, "good old fashioned async io msg", ))), vec![], ), (Poll::Ready(Ok(())), b"yother".to_vec()), (Poll::Ready(Ok(())), b"e\n".to_vec()), ])); let ctor = DummyCtor { connect_outcomes, poll_read_results, }; let mut dummy = ReconnectDummy::connect(ctor).await.unwrap(); let mut buf = vec![]; buf.extend(dummy.next().await.unwrap()); buf.extend(dummy.next().await.unwrap()); let msg = from_utf8(&buf).unwrap(); assert_eq!(msg, "yothere\n"); } #[tokio::test] async fn should_be_able_to_recover_after_disconnect() { let connect_outcomes = Arc::new(Mutex::new(vec![true, false, true])); let poll_read_results = Arc::new(Mutex::new(vec![ ( Poll::Ready(Err(io::Error::new( io::ErrorKind::ConnectionAborted, "fatal", ))), vec![], ), (Poll::Ready(Ok(())), b"e\n".to_vec()), ])); let ctor = DummyCtor { connect_outcomes, poll_read_results: poll_read_results.clone(), }; let disconnect_counter = Arc::new(AtomicU8::new(0)); let disconnect_clone = disconnect_counter.clone(); let options = ReconnectOptions::new() .with_on_disconnect_callback(move || { disconnect_clone.fetch_add(1, Ordering::Relaxed); }) .with_retries_generator(|| { vec![ Duration::from_millis(100), Duration::from_millis(100), Duration::from_millis(100), ] }); let mut dummy = ReconnectDummy::connect_with_options(ctor, options) .await .unwrap(); let msg = dummy.next().await.unwrap(); assert_eq!(msg, b"e\n".to_vec()); assert_eq!(disconnect_counter.load(Ordering::Relaxed), 1); } #[tokio::test] async fn should_give_up_when_all_attempts_exhausted() { let connect_outcomes = Arc::new(Mutex::new(vec![true, false, false, false])); let poll_read_results = Arc::new(Mutex::new(vec![ ( Poll::Ready(Err(io::Error::new( io::ErrorKind::ConnectionAborted, "fatal", ))), vec![], ), (Poll::Ready(Ok(())), b"e\n".to_vec()), ( Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "eof"))), vec![], ), ])); let ctor = DummyCtor { connect_outcomes, poll_read_results: poll_read_results.clone(), }; let options = ReconnectOptions::new().with_retries_generator(|| { vec![ Duration::from_millis(100), Duration::from_millis(100), Duration::from_millis(100), ] }); let mut dummy = ReconnectDummy::connect_with_options(ctor, options) .await .unwrap(); let mut buf = vec![]; while let Some(msg) = dummy.next().await { buf.extend(msg); } assert!(buf.is_empty()); } }