diff --git a/examples/forwarder_device.rs b/examples/forwarder_device.rs new file mode 100644 index 0000000..84cdaf2 --- /dev/null +++ b/examples/forwarder_device.rs @@ -0,0 +1,27 @@ +mod async_helpers; +use std::error::Error; +use zeromq::prelude::*; + +#[async_helpers::main] +async fn main() -> Result<(), Box> { + println!("Start forwarder"); + let mut frontend = zeromq::SubSocket::new(); + frontend.bind("tcp://127.0.0.1:30001").await?; + + let mut backend = zeromq::PubSocket::new(); + backend.bind("tcp://127.0.0.1:30002").await?; + + frontend.subscribe("").await?; + + let forward = async move { + loop { + let message = frontend.recv().await.unwrap(); + println!("passing message: {:?}", message); + backend.send(message).await.unwrap(); + } + }; + + forward.await; + + Ok(()) +} diff --git a/src/backend.rs b/src/backend.rs index 80bb0fa..9d64c8a 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -4,6 +4,7 @@ use crate::util::PeerIdentity; use crate::{ MultiPeerBackend, SocketBackend, SocketEvent, SocketOptions, SocketType, ZmqError, ZmqResult, }; +use async_trait::async_trait; use crossbeam::queue::SegQueue; use dashmap::DashMap; use futures::channel::mpsc; @@ -97,8 +98,9 @@ impl SocketBackend for GenericSocketBackend { } } +#[async_trait] impl MultiPeerBackend for GenericSocketBackend { - fn peer_connected(self: Arc, peer_id: &PeerIdentity, io: FramedIo) { + async fn peer_connected(self: Arc, peer_id: &PeerIdentity, io: FramedIo) { let (recv_queue, send_queue) = io.into_parts(); self.peers.insert(peer_id.clone(), Peer { send_queue }); self.round_robin.push(peer_id.clone()); diff --git a/src/lib.rs b/src/lib.rs index fecb572..09f26f5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -143,11 +143,12 @@ impl Default for SocketOptions { } } +#[async_trait] pub trait MultiPeerBackend: SocketBackend { /// This should not be public.. /// Find a better way of doing this - fn peer_connected(self: Arc, peer_id: &PeerIdentity, io: FramedIo); + async fn peer_connected(self: Arc, peer_id: &PeerIdentity, io: FramedIo); fn peer_disconnected(&self, peer_id: &PeerIdentity); } diff --git a/src/pub.rs b/src/pub.rs index d4b194f..5d240c6 100644 --- a/src/pub.rs +++ b/src/pub.rs @@ -99,8 +99,9 @@ impl SocketBackend for PubSocketBackend { } } +#[async_trait] impl MultiPeerBackend for PubSocketBackend { - fn peer_connected(self: Arc, peer_id: &PeerIdentity, io: FramedIo) { + async fn peer_connected(self: Arc, peer_id: &PeerIdentity, io: FramedIo) { let (mut recv_queue, send_queue) = io.into_parts(); // TODO provide handling for recv_queue let (sender, stop_receiver) = oneshot::channel(); diff --git a/src/rep.rs b/src/rep.rs index 0531b38..a8796d6 100644 --- a/src/rep.rs +++ b/src/rep.rs @@ -72,8 +72,9 @@ impl Socket for RepSocket { } } +#[async_trait] impl MultiPeerBackend for RepSocketBackend { - fn peer_connected(self: Arc, peer_id: &PeerIdentity, io: FramedIo) { + async fn peer_connected(self: Arc, peer_id: &PeerIdentity, io: FramedIo) { let (recv_queue, send_queue) = io.into_parts(); self.peers.insert( diff --git a/src/req.rs b/src/req.rs index e87e430..1382880 100644 --- a/src/req.rs +++ b/src/req.rs @@ -126,8 +126,9 @@ impl Socket for ReqSocket { } } +#[async_trait] impl MultiPeerBackend for ReqSocketBackend { - fn peer_connected(self: Arc, peer_id: &PeerIdentity, io: FramedIo) { + async fn peer_connected(self: Arc, peer_id: &PeerIdentity, io: FramedIo) { let (recv_queue, send_queue) = io.into_parts(); self.peers.insert( peer_id.clone(), diff --git a/src/sub.rs b/src/sub.rs index d1540d6..c4e8fb4 100644 --- a/src/sub.rs +++ b/src/sub.rs @@ -8,17 +8,114 @@ use crate::{ MultiPeerBackend, Socket, SocketBackend, SocketEvent, SocketOptions, SocketRecv, SocketType, }; -use crate::backend::GenericSocketBackend; +use crate::backend::Peer; use crate::fair_queue::FairQueue; +use crate::fair_queue::QueueInner; use async_trait::async_trait; use bytes::{BufMut, BytesMut}; +use crossbeam::queue::SegQueue; +use dashmap::DashMap; use futures::channel::mpsc; use futures::{SinkExt, StreamExt}; -use std::collections::HashMap; +use parking_lot::Mutex; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; +pub enum SubBackendMsgType { + UNSUBSCRIBE = 0, + SUBSCRIBE = 1, +} + +pub(crate) struct SubSocketBackend { + pub(crate) peers: DashMap, + fair_queue_inner: Option>>>, + pub(crate) round_robin: SegQueue, + socket_type: SocketType, + socket_options: SocketOptions, + pub(crate) socket_monitor: Mutex>>, + subs: Mutex>, +} + +impl SubSocketBackend { + pub(crate) fn with_options( + fair_queue_inner: Option>>>, + socket_type: SocketType, + options: SocketOptions, + ) -> Self { + Self { + peers: DashMap::new(), + fair_queue_inner, + round_robin: SegQueue::new(), + socket_type, + socket_options: options, + socket_monitor: Mutex::new(None), + subs: Mutex::new(HashSet::new()), + } + } + + pub fn create_subs_message(subscription: &str, msg_type: SubBackendMsgType) -> ZmqMessage { + let mut buf = BytesMut::with_capacity(subscription.len() + 1); + buf.put_u8(msg_type as u8); + buf.extend_from_slice(subscription.as_bytes()); + + buf.freeze().into() + } +} + +impl SocketBackend for SubSocketBackend { + fn socket_type(&self) -> SocketType { + self.socket_type + } + + fn socket_options(&self) -> &SocketOptions { + &self.socket_options + } + + fn shutdown(&self) { + self.peers.clear(); + } + + fn monitor(&self) -> &Mutex>> { + &self.socket_monitor + } +} + +#[async_trait] +impl MultiPeerBackend for SubSocketBackend { + async fn peer_connected(self: Arc, peer_id: &PeerIdentity, io: FramedIo) { + let (recv_queue, mut send_queue) = io.into_parts(); + + let subs_msgs: Vec = self + .subs + .lock() + .iter() + .map(|x| SubSocketBackend::create_subs_message(x, SubBackendMsgType::SUBSCRIBE)) + .collect(); + + for message in subs_msgs.iter() { + send_queue + .send(Message::Message(message.clone())) + .await + .unwrap(); + } + + self.peers.insert(peer_id.clone(), Peer { send_queue }); + self.round_robin.push(peer_id.clone()); + match &self.fair_queue_inner { + None => {} + Some(inner) => { + inner.lock().insert(peer_id.clone(), recv_queue); + } + }; + } + + fn peer_disconnected(&self, peer_id: &PeerIdentity) { + self.peers.remove(peer_id); + } +} + pub struct SubSocket { - backend: Arc, + backend: Arc, fair_queue: FairQueue, binds: HashMap, } @@ -31,24 +128,24 @@ impl Drop for SubSocket { impl SubSocket { pub async fn subscribe(&mut self, subscription: &str) -> ZmqResult<()> { - let mut buf = BytesMut::with_capacity(subscription.len() + 1); - buf.put_u8(1); - buf.extend_from_slice(subscription.as_bytes()); - // let message = format!("\0x1{}", subscription); - let message: ZmqMessage = ZmqMessage::from(buf.freeze()); - for mut peer in self.backend.peers.iter_mut() { - peer.send_queue - .send(Message::Message(message.clone())) - .await?; - } - Ok(()) + self.backend.subs.lock().insert(subscription.to_string()); + self.process_subs(subscription, SubBackendMsgType::SUBSCRIBE) + .await } pub async fn unsubscribe(&mut self, subscription: &str) -> ZmqResult<()> { - let mut buf = BytesMut::with_capacity(subscription.len() + 1); - buf.put_u8(0); - buf.extend_from_slice(subscription.as_bytes()); - let message = ZmqMessage::from(buf.freeze()); + self.backend.subs.lock().remove(subscription); + self.process_subs(subscription, SubBackendMsgType::UNSUBSCRIBE) + .await + } + + async fn process_subs( + &mut self, + subscription: &str, + msg_type: SubBackendMsgType, + ) -> ZmqResult<()> { + let message: ZmqMessage = SubSocketBackend::create_subs_message(subscription, msg_type); + for mut peer in self.backend.peers.iter_mut() { peer.send_queue .send(Message::Message(message.clone())) @@ -63,7 +160,7 @@ impl Socket for SubSocket { fn with_options(options: SocketOptions) -> Self { let fair_queue = FairQueue::new(true); Self { - backend: Arc::new(GenericSocketBackend::with_options( + backend: Arc::new(SubSocketBackend::with_options( Some(fair_queue.inner()), SocketType::SUB, options, diff --git a/src/util.rs b/src/util.rs index 5a4976c..a5c3a3c 100644 --- a/src/util.rs +++ b/src/util.rs @@ -191,7 +191,7 @@ pub(crate) async fn peer_connected( props = Some(connect_ops); } let peer_id = ready_exchange(&mut raw_socket, backend.socket_type(), props).await?; - backend.peer_connected(&peer_id, raw_socket); + backend.peer_connected(&peer_id, raw_socket).await; Ok(peer_id) }