@@ -12,6 +12,7 @@ use smol::{
};
use std::{
collections::HashMap,
+ future::Future,
sync::{
atomic::{self, AtomicI32},
Arc,
@@ -32,7 +33,7 @@ impl<Conn> RpcClient<Conn>
where
Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
- pub fn new(conn: Conn, executor: Arc<Background>) -> Self {
+ pub fn new(conn: Conn, executor: Arc<Background>) -> Arc<Self> {
let response_channels = Arc::new(Mutex::new(HashMap::new()));
let (conn_rx, conn_tx) = smol::io::split(conn);
let (_drop_tx, drop_rx) = barrier::channel();
@@ -45,12 +46,12 @@ where
))
.detach();
- Self {
+ Arc::new(Self {
response_channels,
outgoing: Mutex::new(MessageStream::new(conn_tx)),
_drop_tx,
next_message_id: AtomicI32::new(0),
- }
+ })
}
async fn handle_incoming(
@@ -101,63 +102,75 @@ where
}
}
- pub async fn request<T: RequestMessage>(&self, req: T) -> Result<T::Response> {
- let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
- let (tx, mut rx) = mpsc::channel(1);
- self.response_channels
- .lock()
- .await
- .insert(message_id, (tx, true));
- self.outgoing
- .lock()
- .await
- .write_message(&proto::FromClient {
- id: message_id,
- variant: Some(req.to_variant()),
- })
- .await?;
- let response = rx
- .recv()
- .await
- .expect("response channel was unexpectedly dropped");
- T::Response::from_variant(response)
- .ok_or_else(|| anyhow!("received response of the wrong t"))
+ pub fn request<T: RequestMessage>(
+ self: &Arc<Self>,
+ req: T,
+ ) -> impl Future<Output = Result<T::Response>> {
+ let this = self.clone();
+ async move {
+ let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
+ let (tx, mut rx) = mpsc::channel(1);
+ this.response_channels
+ .lock()
+ .await
+ .insert(message_id, (tx, true));
+ this.outgoing
+ .lock()
+ .await
+ .write_message(&proto::FromClient {
+ id: message_id,
+ variant: Some(req.to_variant()),
+ })
+ .await?;
+ let response = rx
+ .recv()
+ .await
+ .expect("response channel was unexpectedly dropped");
+ T::Response::from_variant(response)
+ .ok_or_else(|| anyhow!("received response of the wrong t"))
+ }
}
- pub async fn send<T: SendMessage>(&self, message: T) -> Result<()> {
- let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
- self.outgoing
- .lock()
- .await
- .write_message(&proto::FromClient {
- id: message_id,
- variant: Some(message.to_variant()),
- })
- .await?;
- Ok(())
+ pub fn send<T: SendMessage>(self: &Arc<Self>, message: T) -> impl Future<Output = Result<()>> {
+ let this = self.clone();
+ async move {
+ let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
+ this.outgoing
+ .lock()
+ .await
+ .write_message(&proto::FromClient {
+ id: message_id,
+ variant: Some(message.to_variant()),
+ })
+ .await?;
+ Ok(())
+ }
}
- pub async fn subscribe<T: SubscribeMessage>(
- &self,
+ pub fn subscribe<T: SubscribeMessage>(
+ self: &Arc<Self>,
subscription: T,
- ) -> Result<impl Stream<Item = Result<T::Event>>> {
- let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
- let (tx, rx) = mpsc::channel(256);
- self.response_channels
- .lock()
- .await
- .insert(message_id, (tx, false));
- self.outgoing
- .lock()
- .await
- .write_message(&proto::FromClient {
- id: message_id,
- variant: Some(subscription.to_variant()),
- })
- .await?;
- Ok(rx.map(|event| {
- T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}"))
- }))
+ ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Event>>>> {
+ let this = self.clone();
+ async move {
+ let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
+ let (tx, rx) = mpsc::channel(256);
+ this.response_channels
+ .lock()
+ .await
+ .insert(message_id, (tx, false));
+ this.outgoing
+ .lock()
+ .await
+ .write_message(&proto::FromClient {
+ id: message_id,
+ variant: Some(subscription.to_variant()),
+ })
+ .await?;
+ Ok(rx.map(|event| {
+ T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}"))
+ }))
+ }
}
}