@@ -1,13 +1,12 @@
use anyhow::{anyhow, Result};
use futures::future::Either;
-use gpui::executor::Background;
use postage::{
barrier, oneshot,
prelude::{Sink, Stream},
};
use smol::{
io::BoxedWriter,
- lock::Mutex,
+ lock::{Mutex, RwLock},
prelude::{AsyncRead, AsyncWrite},
};
use std::{
@@ -23,29 +22,27 @@ use zed_rpc::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(u32);
-pub struct RpcClient {
- response_channels: Arc<Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>>,
- outgoing: Arc<Mutex<HashMap<ConnectionId, MessageStream<BoxedWriter>>>>,
+struct RpcConnection {
+ writer: Mutex<MessageStream<BoxedWriter>>,
+ response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
next_message_id: AtomicU32,
+ _close_barrier: barrier::Sender,
+}
+
+pub struct RpcClient {
+ connections: Arc<RwLock<HashMap<ConnectionId, Arc<RpcConnection>>>>,
next_connection_id: AtomicU32,
- _drop_tx: barrier::Sender,
- drop_rx: barrier::Receiver,
}
impl RpcClient {
pub fn new() -> Arc<Self> {
- let (_drop_tx, drop_rx) = barrier::channel();
Arc::new(Self {
- response_channels: Default::default(),
- outgoing: Default::default(),
- next_message_id: Default::default(),
+ connections: Default::default(),
next_connection_id: Default::default(),
- _drop_tx,
- drop_rx,
})
}
- pub async fn connect<Conn>(&self, conn: Conn, executor: Arc<Background>) -> ConnectionId
+ pub async fn add_connection<Conn>(&self, conn: Conn) -> (ConnectionId, impl Future<Output = ()>)
where
Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
@@ -53,52 +50,63 @@ impl RpcClient {
self.next_connection_id
.fetch_add(1, atomic::Ordering::SeqCst),
);
+ let (close_tx, mut close_rx) = barrier::channel();
let (conn_rx, conn_tx) = smol::io::split(conn);
- let response_channels = self.response_channels.clone();
- let mut drop_rx = self.drop_rx.clone();
- let outgoing = self.outgoing.clone();
+ let connections = self.connections.clone();
+ let connection = Arc::new(RpcConnection {
+ writer: Mutex::new(MessageStream::new(Box::pin(conn_tx))),
+ response_channels: Default::default(),
+ next_message_id: Default::default(),
+ _close_barrier: close_tx,
+ });
+
+ connections
+ .write()
+ .await
+ .insert(connection_id, connection.clone());
- executor
- .spawn(async move {
- let dropped = drop_rx.recv();
- smol::pin!(dropped);
+ let handler_future = async move {
+ let closed = close_rx.recv();
+ smol::pin!(closed);
- let mut stream = MessageStream::new(conn_rx);
- loop {
- let read_message = stream.read_message();
- smol::pin!(read_message);
+ let mut stream = MessageStream::new(conn_rx);
+ loop {
+ let read_message = stream.read_message();
+ smol::pin!(read_message);
- match futures::future::select(read_message, &mut dropped).await {
- Either::Left((Ok(incoming), _)) => {
- if let Some(responding_to) = incoming.responding_to {
- let channel = response_channels.lock().await.remove(&responding_to);
- if let Some(mut tx) = channel {
- tx.send(incoming).await.ok();
- } else {
- log::warn!(
- "received RPC response to unknown request {}",
- responding_to
- );
- }
+ match futures::future::select(read_message, &mut closed).await {
+ Either::Left((Ok(incoming), _)) => {
+ if let Some(responding_to) = incoming.responding_to {
+ let channel = connection
+ .response_channels
+ .lock()
+ .await
+ .remove(&responding_to);
+ if let Some(mut tx) = channel {
+ tx.send(incoming).await.ok();
} else {
- // unprompted message from server
+ log::warn!(
+ "received RPC response to unknown request {}",
+ responding_to
+ );
}
+ } else {
+ // unprompted message from server
}
- Either::Left((Err(error), _)) => {
- log::warn!("received invalid RPC message {:?}", error);
- }
- Either::Right(_) => break,
}
+ Either::Left((Err(error), _)) => {
+ log::warn!("received invalid RPC message {:?}", error);
+ }
+ Either::Right(_) => break,
}
- })
- .detach();
+ }
+ };
- outgoing
- .lock()
- .await
- .insert(connection_id, MessageStream::new(Box::pin(conn_tx)));
+ (connection_id, handler_future)
+ }
- connection_id
+ pub async fn disconnect(&self, connection_id: ConnectionId) {
+ self.connections.write().await.remove(&connection_id);
}
pub fn request<T: RequestMessage>(
@@ -106,17 +114,27 @@ impl RpcClient {
connection_id: ConnectionId,
req: T,
) -> impl Future<Output = Result<T::Response>> {
- let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
- let outgoing = self.outgoing.clone();
- let response_channels = self.response_channels.clone();
+ let connections = self.connections.clone();
let (tx, mut rx) = oneshot::channel();
async move {
- response_channels.lock().await.insert(message_id, tx);
- outgoing
- .lock()
+ let connection = connections
+ .read()
.await
- .get_mut(&connection_id)
+ .get(&connection_id)
.ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
+ .clone();
+ let message_id = connection
+ .next_message_id
+ .fetch_add(1, atomic::Ordering::SeqCst);
+ connection
+ .response_channels
+ .lock()
+ .await
+ .insert(message_id, tx);
+ connection
+ .writer
+ .lock()
+ .await
.write_message(&req.into_envelope(message_id, None))
.await?;
let response = rx
@@ -124,7 +142,7 @@ impl RpcClient {
.await
.expect("response channel was unexpectedly dropped");
T::Response::from_envelope(response)
- .ok_or_else(|| anyhow!("received response of the wrong t"))
+ .ok_or_else(|| anyhow!("received response of the wrong type"))
}
}
@@ -133,14 +151,21 @@ impl RpcClient {
connection_id: ConnectionId,
message: T,
) -> impl Future<Output = Result<()>> {
- let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
- let outgoing = self.outgoing.clone();
+ let connections = self.connections.clone();
async move {
- outgoing
- .lock()
+ let connection = connections
+ .read()
.await
- .get_mut(&connection_id)
+ .get(&connection_id)
.ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
+ .clone();
+ let message_id = connection
+ .next_message_id
+ .fetch_add(1, atomic::Ordering::SeqCst);
+ connection
+ .writer
+ .lock()
+ .await
.write_message(&message.into_envelope(message_id, None))
.await?;
Ok(())
@@ -170,7 +195,8 @@ mod tests {
let mut server_stream = MessageStream::new(server_conn);
let client = RpcClient::new();
- let connection_id = client.connect(client_conn, executor.clone()).await;
+ let (connection_id, handler) = client.add_connection(client_conn).await;
+ executor.spawn(handler).detach();
let client_req = client.request(
connection_id,
@@ -219,7 +245,7 @@ mod tests {
}
#[gpui::test]
- async fn test_drop_client(cx: gpui::TestAppContext) {
+ async fn test_disconnect(cx: gpui::TestAppContext) {
let executor = cx.read(|app| app.background_executor().clone());
let socket_dir_path = TempDir::new("drop-client").unwrap();
let socket_path = socket_dir_path.path().join(".sock");
@@ -228,8 +254,9 @@ mod tests {
let (mut server_conn, _) = listener.accept().await.unwrap();
let client = RpcClient::new();
- client.connect(client_conn, executor.clone()).await;
- drop(client);
+ let (connection_id, handler) = client.add_connection(client_conn).await;
+ executor.spawn(handler).detach();
+ client.disconnect(connection_id).await;
// Try sending an empty payload over and over, until the client is dropped and hangs up.
loop {
@@ -254,10 +281,11 @@ mod tests {
client_conn.close().await.unwrap();
let client = RpcClient::new();
- let conn_id = client.connect(client_conn, executor.clone()).await;
+ let (connection_id, handler) = client.add_connection(client_conn).await;
+ executor.spawn(handler).detach();
let err = client
.request(
- conn_id,
+ connection_id,
proto::Auth {
user_id: 42,
access_token: "token".to_string(),