diff --git a/src/lib.rs b/src/lib.rs index f68300e..65fe7d2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -243,7 +243,7 @@ pub trait Socket: Sized + Send { let backend = self.backend(); let endpoint = endpoint.try_into()?; - let result = match transport::connect(endpoint).await { + let result = match util::connect_forever(endpoint).await { Ok((socket, endpoint)) => match util::peer_connected(socket, backend).await { Ok(peer_id) => Ok((endpoint, peer_id)), Err(e) => Err(e), diff --git a/src/transport/ipc.rs b/src/transport/ipc.rs index a11d29f..14786cd 100644 --- a/src/transport/ipc.rs +++ b/src/transport/ipc.rs @@ -15,8 +15,8 @@ use crate::ZmqResult; use futures::{select, FutureExt}; use std::path::{Path, PathBuf}; -pub(crate) async fn connect(path: PathBuf) -> ZmqResult<(FramedIo, Endpoint)> { - let raw_socket = UnixStream::connect(&path).await?; +pub(crate) async fn connect(path: &PathBuf) -> ZmqResult<(FramedIo, Endpoint)> { + let raw_socket = UnixStream::connect(path).await?; let peer_addr = raw_socket.peer_addr()?; let peer_addr = peer_addr.as_pathname().map(|a| a.to_owned()); diff --git a/src/transport/mod.rs b/src/transport/mod.rs index e9e77d3..0a6b164 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -24,10 +24,10 @@ macro_rules! do_if_enabled { /// /// # Panics /// Panics if the requested endpoint uses a transport type that isn't enabled -pub(crate) async fn connect(endpoint: Endpoint) -> ZmqResult<(FramedIo, Endpoint)> { +pub(crate) async fn connect(endpoint: &Endpoint) -> ZmqResult<(FramedIo, Endpoint)> { match endpoint { Endpoint::Tcp(_host, _port) => { - do_if_enabled!("tcp-transport", tcp::connect(_host, _port).await) + do_if_enabled!("tcp-transport", tcp::connect(_host, *_port).await) } Endpoint::Ipc(_path) => do_if_enabled!( "ipc-transport", diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index 0ae8618..8e56eae 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -14,7 +14,7 @@ use crate::ZmqResult; use futures::{select, FutureExt}; -pub(crate) async fn connect(host: Host, port: Port) -> ZmqResult<(FramedIo, Endpoint)> { +pub(crate) async fn connect(host: &Host, port: Port) -> ZmqResult<(FramedIo, Endpoint)> { let raw_socket = TcpStream::connect((host.to_string().as_str(), port)).await?; let peer_addr = raw_socket.peer_addr()?; diff --git a/src/util.rs b/src/util.rs index 0b9e040..f4dff1f 100644 --- a/src/util.rs +++ b/src/util.rs @@ -5,6 +5,8 @@ use bytes::Bytes; use futures::stream::StreamExt; use futures::SinkExt; use futures_codec::FramedRead; +use num_traits::Pow; +use rand::Rng; use std::convert::{TryFrom, TryInto}; use std::sync::Arc; use uuid::Uuid; @@ -178,6 +180,27 @@ pub(crate) async fn peer_connected( Ok(peer_id) } +pub(crate) async fn connect_forever(endpoint: Endpoint) -> ZmqResult<(FramedIo, Endpoint)> { + let mut try_num: u64 = 0; + loop { + match transport::connect(&endpoint).await { + Ok(res) => return Ok(res), + Err(ZmqError::Network(e)) if e.kind() == std::io::ErrorKind::ConnectionRefused => { + if try_num < 5 { + try_num += 1; + } + let delay = { + let mut rng = rand::thread_rng(); + std::f64::consts::E.pow(try_num as f64 / 3.0) + rng.gen_range(0.0f64, 0.1f64) + }; + async_rt::task::sleep(std::time::Duration::from_secs_f64(delay)).await; + continue; + } + Err(e) => return Err(e), + } + } +} + #[cfg(test)] pub(crate) mod tests { use super::*;