Cargo.lock 🔗
@@ -4364,6 +4364,7 @@ dependencies = [
"rsa",
"serde 1.0.125",
"smol",
+ "tempdir",
]
[[package]]
Max Brunsfeld created
* Re-enable peer tests
* Enhance request/response unit test to exercise
peers interacting with each other end-to-end
Cargo.lock | 1
zed-rpc/Cargo.toml | 1
zed-rpc/src/peer.rs | 387 ++++++++++++++++++++++++++--------------------
3 files changed, 221 insertions(+), 168 deletions(-)
@@ -4364,6 +4364,7 @@ dependencies = [
"rsa",
"serde 1.0.125",
"smol",
+ "tempdir",
]
[[package]]
@@ -21,3 +21,4 @@ prost-build = { git="https://github.com/sfackler/prost", rev="082f3e65874fe91382
[dev-dependencies]
smol = "1.2.5"
+tempdir = "0.3.7"
@@ -29,7 +29,6 @@ struct Connection {
writer: Mutex<MessageStream<BoxedWriter>>,
response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
next_message_id: AtomicU32,
- _close_barrier: barrier::Sender,
}
type MessageHandler = Box<
@@ -53,7 +52,7 @@ impl<T> TypedEnvelope<T> {
}
pub struct Peer {
- connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
+ connections: RwLock<HashMap<ConnectionId, (Arc<Connection>, barrier::Sender)>>,
message_handlers: RwLock<Vec<MessageHandler>>,
handler_types: Mutex<HashSet<TypeId>>,
next_connection_id: AtomicU32,
@@ -106,7 +105,7 @@ impl Peer {
pub async fn add_connection<Conn>(
self: &Arc<Self>,
conn: Conn,
- ) -> (ConnectionId, impl Future<Output = ()>)
+ ) -> (ConnectionId, impl Future<Output = Result<()>>)
where
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
@@ -119,13 +118,12 @@ impl Peer {
writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
response_channels: Default::default(),
next_message_id: Default::default(),
- _close_barrier: close_tx,
});
self.connections
.write()
.await
- .insert(connection_id, connection.clone());
+ .insert(connection_id, (connection.clone(), close_tx));
let this = self.clone();
let handler_future = async move {
@@ -178,8 +176,9 @@ impl Peer {
}
Either::Left((Err(error), _)) => {
log::warn!("received invalid RPC message: {}", error);
+ Err(error)?;
}
- Either::Right(_) => break,
+ Either::Right(_) => return Ok(()),
}
}
};
@@ -199,13 +198,7 @@ impl Peer {
let this = self.clone();
let (tx, mut rx) = oneshot::channel();
async move {
- let connection = this
- .connections
- .read()
- .await
- .get(&connection_id)
- .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
- .clone();
+ let connection = this.connection(connection_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
@@ -236,13 +229,7 @@ impl Peer {
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
- let connection = this
- .connections
- .read()
- .await
- .get(&connection_id)
- .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
- .clone();
+ let connection = this.connection(connection_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
@@ -263,13 +250,7 @@ impl Peer {
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
- let connection = this
- .connections
- .read()
- .await
- .get(&request.connection_id)
- .ok_or_else(|| anyhow!("unknown connection: {}", request.connection_id.0))?
- .clone();
+ let connection = this.connection(request.connection_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
@@ -282,146 +263,216 @@ impl Peer {
Ok(())
}
}
+
+ async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
+ Ok(self
+ .connections
+ .read()
+ .await
+ .get(&id)
+ .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
+ .0
+ .clone())
+ }
}
-// #[cfg(test)]
-// mod tests {
-// use super::*;
-// use smol::{
-// future::poll_once,
-// io::AsyncWriteExt,
-// net::unix::{UnixListener, UnixStream},
-// };
-// use std::{future::Future, io};
-// use tempdir::TempDir;
-
-// #[gpui::test]
-// async fn test_request_response(cx: gpui::TestAppContext) {
-// let executor = cx.read(|app| app.background_executor().clone());
-// let socket_dir_path = TempDir::new("request-response").unwrap();
-// let socket_path = socket_dir_path.path().join(".sock");
-// let listener = UnixListener::bind(&socket_path).unwrap();
-// let client_conn = UnixStream::connect(&socket_path).await.unwrap();
-// let (server_conn, _) = listener.accept().await.unwrap();
-
-// let mut server_stream = MessageStream::new(server_conn);
-// let client = Peer::new();
-// let (connection_id, handler) = client.add_connection(client_conn).await;
-// executor.spawn(handler).detach();
-
-// let client_req = client.request(
-// connection_id,
-// proto::Auth {
-// user_id: 42,
-// access_token: "token".to_string(),
-// },
-// );
-// smol::pin!(client_req);
-// let server_req = send_recv(&mut client_req, server_stream.read_message())
-// .await
-// .unwrap();
-// assert_eq!(
-// server_req.payload,
-// Some(proto::envelope::Payload::Auth(proto::Auth {
-// user_id: 42,
-// access_token: "token".to_string()
-// }))
-// );
-
-// // Respond to another request to ensure requests are properly matched up.
-// server_stream
-// .write_message(
-// &proto::AuthResponse {
-// credentials_valid: false,
-// }
-// .into_envelope(1000, Some(999)),
-// )
-// .await
-// .unwrap();
-// server_stream
-// .write_message(
-// &proto::AuthResponse {
-// credentials_valid: true,
-// }
-// .into_envelope(1001, Some(server_req.id)),
-// )
-// .await
-// .unwrap();
-// assert_eq!(
-// client_req.await.unwrap(),
-// proto::AuthResponse {
-// credentials_valid: true
-// }
-// );
-// }
-
-// #[gpui::test]
-// async fn test_disconnect(cx: gpui::TestAppContext) {
-// let executor = cx.read(|app| app.background_executor().clone());
-// let socket_dir_path = TempDir::new("drop-client").unwrap();
-// let socket_path = socket_dir_path.path().join(".sock");
-// let listener = UnixListener::bind(&socket_path).unwrap();
-// let client_conn = UnixStream::connect(&socket_path).await.unwrap();
-// let (mut server_conn, _) = listener.accept().await.unwrap();
-
-// let client = Peer::new();
-// let (connection_id, handler) = client.add_connection(client_conn).await;
-// executor.spawn(handler).detach();
-// client.disconnect(connection_id).await;
-
-// // Try sending an empty payload over and over, until the client is dropped and hangs up.
-// loop {
-// match server_conn.write(&[]).await {
-// Ok(_) => {}
-// Err(err) => {
-// if err.kind() == io::ErrorKind::BrokenPipe {
-// break;
-// }
-// }
-// }
-// }
-// }
-
-// #[gpui::test]
-// async fn test_io_error(cx: gpui::TestAppContext) {
-// let executor = cx.read(|app| app.background_executor().clone());
-// let socket_dir_path = TempDir::new("io-error").unwrap();
-// let socket_path = socket_dir_path.path().join(".sock");
-// let _listener = UnixListener::bind(&socket_path).unwrap();
-// let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
-// client_conn.close().await.unwrap();
-
-// let client = Peer::new();
-// let (connection_id, handler) = client.add_connection(client_conn).await;
-// executor.spawn(handler).detach();
-// let err = client
-// .request(
-// connection_id,
-// proto::Auth {
-// user_id: 42,
-// access_token: "token".to_string(),
-// },
-// )
-// .await
-// .unwrap_err();
-// assert_eq!(
-// err.downcast_ref::<io::Error>().unwrap().kind(),
-// io::ErrorKind::BrokenPipe
-// );
-// }
-
-// async fn send_recv<S, R, O>(mut sender: S, receiver: R) -> O
-// where
-// S: Unpin + Future,
-// R: Future<Output = O>,
-// {
-// smol::pin!(receiver);
-// loop {
-// poll_once(&mut sender).await;
-// match poll_once(&mut receiver).await {
-// Some(message) => break message,
-// None => continue,
-// }
-// }
-// }
-// }
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use smol::{
+ io::AsyncWriteExt,
+ net::unix::{UnixListener, UnixStream},
+ };
+ use std::io;
+ use tempdir::TempDir;
+
+ #[test]
+ fn test_request_response() {
+ smol::block_on(async move {
+ // create socket
+ let socket_dir_path = TempDir::new("test-request-response").unwrap();
+ let socket_path = socket_dir_path.path().join("test.sock");
+ let listener = UnixListener::bind(&socket_path).unwrap();
+
+ // create 2 clients connected to 1 server
+ let server = Peer::new();
+ let client1 = Peer::new();
+ let client2 = Peer::new();
+ let (client1_conn_id, f1) = client1
+ .add_connection(UnixStream::connect(&socket_path).await.unwrap())
+ .await;
+ let (client2_conn_id, f2) = client2
+ .add_connection(UnixStream::connect(&socket_path).await.unwrap())
+ .await;
+ let (_, f3) = server
+ .add_connection(listener.accept().await.unwrap().0)
+ .await;
+ let (_, f4) = server
+ .add_connection(listener.accept().await.unwrap().0)
+ .await;
+ smol::spawn(f1).detach();
+ smol::spawn(f2).detach();
+ smol::spawn(f3).detach();
+ smol::spawn(f4).detach();
+
+ // define the expected requests and responses
+ let request1 = proto::OpenWorktree {
+ worktree_id: 101,
+ access_token: "first-worktree-access-token".to_string(),
+ };
+ let response1 = proto::OpenWorktreeResponse {
+ worktree: Some(proto::Worktree {
+ paths: vec![b"path/one".to_vec()],
+ }),
+ };
+ let request2 = proto::OpenWorktree {
+ worktree_id: 102,
+ access_token: "second-worktree-access-token".to_string(),
+ };
+ let response2 = proto::OpenWorktreeResponse {
+ worktree: Some(proto::Worktree {
+ paths: vec![b"path/two".to_vec(), b"path/three".to_vec()],
+ }),
+ };
+ let request3 = proto::OpenBuffer {
+ worktree_id: 102,
+ path: b"path/two".to_vec(),
+ };
+ let response3 = proto::OpenBufferResponse {
+ buffer: Some(proto::Buffer {
+ id: 1001,
+ path: b"path/two".to_vec(),
+ content: b"path/two content".to_vec(),
+ history: vec![],
+ }),
+ };
+ let request4 = proto::OpenBuffer {
+ worktree_id: 101,
+ path: b"path/one".to_vec(),
+ };
+ let response4 = proto::OpenBufferResponse {
+ buffer: Some(proto::Buffer {
+ id: 1002,
+ path: b"path/one".to_vec(),
+ content: b"path/one content".to_vec(),
+ history: vec![],
+ }),
+ };
+
+ // on the server, respond to two requests for each client
+ let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
+ let mut open_worktree_rx = server.add_message_handler::<proto::OpenWorktree>().await;
+ let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
+ smol::spawn({
+ let request1 = request1.clone();
+ let request2 = request2.clone();
+ let request3 = request3.clone();
+ let request4 = request4.clone();
+ let response1 = response1.clone();
+ let response2 = response2.clone();
+ let response3 = response3.clone();
+ let response4 = response4.clone();
+ async move {
+ let msg = open_worktree_rx.recv().await.unwrap();
+ assert_eq!(msg.payload, request1);
+ server.respond(msg, response1.clone()).await.unwrap();
+
+ let msg = open_worktree_rx.recv().await.unwrap();
+ assert_eq!(msg.payload, request2.clone());
+ server.respond(msg, response2.clone()).await.unwrap();
+
+ let msg = open_buffer_rx.recv().await.unwrap();
+ assert_eq!(msg.payload, request3.clone());
+ server.respond(msg, response3.clone()).await.unwrap();
+
+ let msg = open_buffer_rx.recv().await.unwrap();
+ assert_eq!(msg.payload, request4.clone());
+ server.respond(msg, response4.clone()).await.unwrap();
+
+ server_done_tx.send(()).await.unwrap();
+ }
+ })
+ .detach();
+
+ assert_eq!(
+ client1.request(client1_conn_id, request1).await.unwrap(),
+ response1
+ );
+ assert_eq!(
+ client2.request(client2_conn_id, request2).await.unwrap(),
+ response2
+ );
+ assert_eq!(
+ client2.request(client2_conn_id, request3).await.unwrap(),
+ response3
+ );
+ assert_eq!(
+ client1.request(client1_conn_id, request4).await.unwrap(),
+ response4
+ );
+
+ client1.disconnect(client1_conn_id).await;
+ client2.disconnect(client1_conn_id).await;
+
+ server_done_rx.recv().await.unwrap();
+ });
+ }
+
+ #[test]
+ fn test_disconnect() {
+ smol::block_on(async move {
+ let socket_dir_path = TempDir::new("drop-client").unwrap();
+ let socket_path = socket_dir_path.path().join(".sock");
+ let listener = UnixListener::bind(&socket_path).unwrap();
+ let client_conn = UnixStream::connect(&socket_path).await.unwrap();
+ let (mut server_conn, _) = listener.accept().await.unwrap();
+
+ let client = Peer::new();
+ let (connection_id, handler) = client.add_connection(client_conn).await;
+ smol::spawn(handler).detach();
+ client.disconnect(connection_id).await;
+
+ // Try sending an empty payload over and over, until the client is dropped and hangs up.
+ loop {
+ match server_conn.write(&[]).await {
+ Ok(_) => {}
+ Err(err) => {
+ if err.kind() == io::ErrorKind::BrokenPipe {
+ break;
+ }
+ }
+ }
+ }
+ });
+ }
+
+ #[test]
+ fn test_io_error() {
+ smol::block_on(async move {
+ let socket_dir_path = TempDir::new("io-error").unwrap();
+ let socket_path = socket_dir_path.path().join(".sock");
+ let _listener = UnixListener::bind(&socket_path).unwrap();
+ let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
+ client_conn.close().await.unwrap();
+
+ let client = Peer::new();
+ let (connection_id, handler) = client.add_connection(client_conn).await;
+ smol::spawn(handler).detach();
+
+ let err = client
+ .request(
+ connection_id,
+ proto::Auth {
+ user_id: 42,
+ access_token: "token".to_string(),
+ },
+ )
+ .await
+ .unwrap_err();
+ assert_eq!(
+ err.downcast_ref::<io::Error>().unwrap().kind(),
+ io::ErrorKind::BrokenPipe
+ );
+ });
+ }
+}