diff --git a/zed/src/rpc_client.rs b/zed/src/rpc_client.rs index 0fa69d3ab31ef2a9faf84924a6d81a3627f1e5d6..69d041104f92e902b47cae26efb61b1b072b918d 100644 --- a/zed/src/rpc_client.rs +++ b/zed/src/rpc_client.rs @@ -1,18 +1,17 @@ use anyhow::{anyhow, Result}; use futures::future::Either; use gpui::executor::Background; -use parking_lot::Mutex; use postage::{ - barrier, mpsc, oneshot, + barrier, mpsc, prelude::{Sink, Stream}, }; use smol::{ io::{ReadHalf, WriteHalf}, + lock::Mutex, prelude::{AsyncRead, AsyncWrite}, }; use std::{ collections::HashMap, - io, sync::{ atomic::{self, AtomicI32}, Arc, @@ -22,21 +21,20 @@ use zed_rpc::proto::{ self, MessageStream, RequestMessage, SendMessage, ServerMessage, SubscribeMessage, }; -pub struct RpcClient { +pub struct RpcClient { response_channels: Arc, bool)>>>, - outgoing_tx: mpsc::Sender<(proto::FromClient, oneshot::Sender>)>, + outgoing: Mutex>>, next_message_id: AtomicI32, _drop_tx: barrier::Sender, } -impl RpcClient { - pub fn new(conn: Conn, executor: Arc) -> Self - where - Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static, - { +impl RpcClient +where + Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + pub fn new(conn: Conn, executor: Arc) -> Self { let response_channels = Arc::new(Mutex::new(HashMap::new())); let (conn_rx, conn_tx) = smol::io::split(conn); - let (outgoing_tx, outgoing_rx) = mpsc::channel(32); let (_drop_tx, drop_rx) = barrier::channel(); executor @@ -47,27 +45,21 @@ impl RpcClient { )) .detach(); - executor - .spawn(Self::handle_outgoing(conn_tx, outgoing_rx)) - .detach(); - Self { response_channels, - outgoing_tx, + outgoing: Mutex::new(MessageStream::new(conn_tx)), _drop_tx, next_message_id: AtomicI32::new(0), } } - async fn handle_incoming( + async fn handle_incoming( conn: ReadHalf, mut drop_rx: barrier::Receiver, response_channels: Arc< Mutex, bool)>>, >, - ) where - Conn: AsyncRead + Unpin, - { + ) { let mut stream = MessageStream::new(conn); loop { let read_message = stream.read_message::(); @@ -78,11 +70,14 @@ impl RpcClient { Either::Left((Ok(incoming), _)) => { if let Some(variant) = incoming.variant { if let Some(request_id) = incoming.request_id { - let channel = response_channels.lock().remove(&request_id); + let channel = response_channels.lock().await.remove(&request_id); if let Some((mut tx, oneshot)) = channel { if tx.send(variant).await.is_ok() { if !oneshot { - response_channels.lock().insert(request_id, (tx, false)); + response_channels + .lock() + .await + .insert(request_id, (tx, false)); } } } else { @@ -104,36 +99,21 @@ impl RpcClient { } } - async fn handle_outgoing( - conn: WriteHalf, - mut outgoing_rx: mpsc::Receiver<(proto::FromClient, oneshot::Sender>)>, - ) where - Conn: AsyncWrite + Unpin, - { - let mut stream = MessageStream::new(conn); - while let Some((message, mut result_tx)) = outgoing_rx.recv().await { - let result = stream.write_message(&message).await; - result_tx.send(result).await.unwrap(); - } - } - pub async fn request(&self, req: T) -> Result { let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst); - let (result_tx, mut result_rx) = oneshot::channel(); let (tx, mut rx) = mpsc::channel(1); - self.response_channels.lock().insert(message_id, (tx, true)); - self.outgoing_tx - .clone() - .send(( - proto::FromClient { - id: message_id, - variant: Some(req.to_variant()), - }, - result_tx, - )) + self.response_channels + .lock() + .await + .insert(message_id, (tx, true)); + self.outgoing + .lock() .await - .ok(); - result_rx.recv().await.unwrap()?; + .write_message(&proto::FromClient { + id: message_id, + variant: Some(req.to_variant()), + }) + .await?; let response = rx .recv() .await @@ -144,19 +124,14 @@ impl RpcClient { pub async fn send(&self, message: T) -> Result<()> { let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst); - let (result_tx, mut result_rx) = oneshot::channel(); - self.outgoing_tx - .clone() - .send(( - proto::FromClient { - id: message_id, - variant: Some(message.to_variant()), - }, - result_tx, - )) + self.outgoing + .lock() .await - .ok(); - result_rx.recv().await.unwrap()?; + .write_message(&proto::FromClient { + id: message_id, + variant: Some(message.to_variant()), + }) + .await?; Ok(()) } @@ -165,23 +140,19 @@ impl RpcClient { subscription: T, ) -> Result>> { let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst); - let (result_tx, mut result_rx) = oneshot::channel(); let (tx, rx) = mpsc::channel(256); self.response_channels .lock() + .await .insert(message_id, (tx, false)); - self.outgoing_tx - .clone() - .send(( - proto::FromClient { - id: message_id, - variant: Some(subscription.to_variant()), - }, - result_tx, - )) + self.outgoing + .lock() .await - .ok(); - result_rx.recv().await.unwrap()?; + .write_message(&proto::FromClient { + id: message_id, + variant: Some(subscription.to_variant()), + }) + .await?; Ok(rx.map(|event| { T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}")) }))