diff --git a/zed/src/rpc_client.rs b/zed/src/rpc_client.rs index 12ea77dae76ac73422d4ed0f806dcc6ede992563..b275ead6b00d1b2fd0a44508965bf8c78c91b969 100644 --- a/zed/src/rpc_client.rs +++ b/zed/src/rpc_client.rs @@ -1,13 +1,12 @@ use anyhow::{anyhow, Result}; use futures::future::Either; -use gpui::executor::Background; use postage::{ barrier, oneshot, prelude::{Sink, Stream}, }; use smol::{ io::BoxedWriter, - lock::Mutex, + lock::{Mutex, RwLock}, prelude::{AsyncRead, AsyncWrite}, }; use std::{ @@ -23,29 +22,27 @@ use zed_rpc::proto::{self, EnvelopedMessage, MessageStream, RequestMessage}; #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub struct ConnectionId(u32); -pub struct RpcClient { - response_channels: Arc>>>, - outgoing: Arc>>>, +struct RpcConnection { + writer: Mutex>, + response_channels: Mutex>>, next_message_id: AtomicU32, + _close_barrier: barrier::Sender, +} + +pub struct RpcClient { + connections: Arc>>>, next_connection_id: AtomicU32, - _drop_tx: barrier::Sender, - drop_rx: barrier::Receiver, } impl RpcClient { pub fn new() -> Arc { - let (_drop_tx, drop_rx) = barrier::channel(); Arc::new(Self { - response_channels: Default::default(), - outgoing: Default::default(), - next_message_id: Default::default(), + connections: Default::default(), next_connection_id: Default::default(), - _drop_tx, - drop_rx, }) } - pub async fn connect(&self, conn: Conn, executor: Arc) -> ConnectionId + pub async fn add_connection(&self, conn: Conn) -> (ConnectionId, impl Future) where Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static, { @@ -53,52 +50,63 @@ impl RpcClient { self.next_connection_id .fetch_add(1, atomic::Ordering::SeqCst), ); + let (close_tx, mut close_rx) = barrier::channel(); let (conn_rx, conn_tx) = smol::io::split(conn); - let response_channels = self.response_channels.clone(); - let mut drop_rx = self.drop_rx.clone(); - let outgoing = self.outgoing.clone(); + let connections = self.connections.clone(); + let connection = Arc::new(RpcConnection { + writer: Mutex::new(MessageStream::new(Box::pin(conn_tx))), + response_channels: Default::default(), + next_message_id: Default::default(), + _close_barrier: close_tx, + }); + + connections + .write() + .await + .insert(connection_id, connection.clone()); - executor - .spawn(async move { - let dropped = drop_rx.recv(); - smol::pin!(dropped); + let handler_future = async move { + let closed = close_rx.recv(); + smol::pin!(closed); - let mut stream = MessageStream::new(conn_rx); - loop { - let read_message = stream.read_message(); - smol::pin!(read_message); + let mut stream = MessageStream::new(conn_rx); + loop { + let read_message = stream.read_message(); + smol::pin!(read_message); - match futures::future::select(read_message, &mut dropped).await { - Either::Left((Ok(incoming), _)) => { - if let Some(responding_to) = incoming.responding_to { - let channel = response_channels.lock().await.remove(&responding_to); - if let Some(mut tx) = channel { - tx.send(incoming).await.ok(); - } else { - log::warn!( - "received RPC response to unknown request {}", - responding_to - ); - } + match futures::future::select(read_message, &mut closed).await { + Either::Left((Ok(incoming), _)) => { + if let Some(responding_to) = incoming.responding_to { + let channel = connection + .response_channels + .lock() + .await + .remove(&responding_to); + if let Some(mut tx) = channel { + tx.send(incoming).await.ok(); } else { - // unprompted message from server + log::warn!( + "received RPC response to unknown request {}", + responding_to + ); } + } else { + // unprompted message from server } - Either::Left((Err(error), _)) => { - log::warn!("received invalid RPC message {:?}", error); - } - Either::Right(_) => break, } + Either::Left((Err(error), _)) => { + log::warn!("received invalid RPC message {:?}", error); + } + Either::Right(_) => break, } - }) - .detach(); + } + }; - outgoing - .lock() - .await - .insert(connection_id, MessageStream::new(Box::pin(conn_tx))); + (connection_id, handler_future) + } - connection_id + pub async fn disconnect(&self, connection_id: ConnectionId) { + self.connections.write().await.remove(&connection_id); } pub fn request( @@ -106,17 +114,27 @@ impl RpcClient { connection_id: ConnectionId, req: T, ) -> impl Future> { - let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst); - let outgoing = self.outgoing.clone(); - let response_channels = self.response_channels.clone(); + let connections = self.connections.clone(); let (tx, mut rx) = oneshot::channel(); async move { - response_channels.lock().await.insert(message_id, tx); - outgoing - .lock() + let connection = connections + .read() .await - .get_mut(&connection_id) + .get(&connection_id) .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))? + .clone(); + let message_id = connection + .next_message_id + .fetch_add(1, atomic::Ordering::SeqCst); + connection + .response_channels + .lock() + .await + .insert(message_id, tx); + connection + .writer + .lock() + .await .write_message(&req.into_envelope(message_id, None)) .await?; let response = rx @@ -124,7 +142,7 @@ impl RpcClient { .await .expect("response channel was unexpectedly dropped"); T::Response::from_envelope(response) - .ok_or_else(|| anyhow!("received response of the wrong t")) + .ok_or_else(|| anyhow!("received response of the wrong type")) } } @@ -133,14 +151,21 @@ impl RpcClient { connection_id: ConnectionId, message: T, ) -> impl Future> { - let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst); - let outgoing = self.outgoing.clone(); + let connections = self.connections.clone(); async move { - outgoing - .lock() + let connection = connections + .read() .await - .get_mut(&connection_id) + .get(&connection_id) .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))? + .clone(); + let message_id = connection + .next_message_id + .fetch_add(1, atomic::Ordering::SeqCst); + connection + .writer + .lock() + .await .write_message(&message.into_envelope(message_id, None)) .await?; Ok(()) @@ -170,7 +195,8 @@ mod tests { let mut server_stream = MessageStream::new(server_conn); let client = RpcClient::new(); - let connection_id = client.connect(client_conn, executor.clone()).await; + let (connection_id, handler) = client.add_connection(client_conn).await; + executor.spawn(handler).detach(); let client_req = client.request( connection_id, @@ -219,7 +245,7 @@ mod tests { } #[gpui::test] - async fn test_drop_client(cx: gpui::TestAppContext) { + async fn test_disconnect(cx: gpui::TestAppContext) { let executor = cx.read(|app| app.background_executor().clone()); let socket_dir_path = TempDir::new("drop-client").unwrap(); let socket_path = socket_dir_path.path().join(".sock"); @@ -228,8 +254,9 @@ mod tests { let (mut server_conn, _) = listener.accept().await.unwrap(); let client = RpcClient::new(); - client.connect(client_conn, executor.clone()).await; - drop(client); + let (connection_id, handler) = client.add_connection(client_conn).await; + executor.spawn(handler).detach(); + client.disconnect(connection_id).await; // Try sending an empty payload over and over, until the client is dropped and hangs up. loop { @@ -254,10 +281,11 @@ mod tests { client_conn.close().await.unwrap(); let client = RpcClient::new(); - let conn_id = client.connect(client_conn, executor.clone()).await; + let (connection_id, handler) = client.add_connection(client_conn).await; + executor.spawn(handler).detach(); let err = client .request( - conn_id, + connection_id, proto::Auth { user_id: 42, access_token: "token".to_string(), diff --git a/zed/src/workspace.rs b/zed/src/workspace.rs index ef75aa702a9e26e070dae605b6791409cb6bda89..b1e2526f3ead16caf2f65f450b53f1a54374674f 100644 --- a/zed/src/workspace.rs +++ b/zed/src/workspace.rs @@ -671,7 +671,8 @@ impl Workspace { let stream = smol::net::TcpStream::connect(rpc_address).await?; let rpc_client = RpcClient::new(); - let connection_id = rpc_client.connect(stream, executor).await; + let (connection_id, handler) = rpc_client.add_connection(stream).await; + executor.spawn(handler).detach(); let auth_response = rpc_client .request(