@@ -1,7 +1,7 @@
use anyhow::{anyhow, Context, Result};
use gpui::{AsyncAppContext, MutableAppContext, Task};
use rpc_client::RpcClient;
-use std::{convert::TryFrom, net::Shutdown, time::Duration};
+use std::{convert::TryFrom, time::Duration};
use tiny_http::{Header, Response, Server};
use url::Url;
use util::SurfResultExt;
@@ -60,9 +60,7 @@ fn share_worktree(_: &(), cx: &mut MutableAppContext) {
// a TLS stream using `native-tls`.
let stream = smol::net::TcpStream::connect(rpc_address).await?;
- let mut rpc_client = RpcClient::new(stream, executor, |stream| {
- stream.shutdown(Shutdown::Read).ok();
- });
+ let mut rpc_client = RpcClient::new(stream, executor);
let auth_response = rpc_client
.request(proto::from_client::Auth {
@@ -5,60 +5,81 @@ use postage::{
oneshot,
prelude::{Sink, Stream},
};
-use smol::prelude::{AsyncRead, AsyncWrite};
+use smol::{
+ future::FutureExt,
+ io::WriteHalf,
+ prelude::{AsyncRead, AsyncWrite},
+};
use std::{collections::HashMap, sync::Arc};
use zed_rpc::proto::{self, MessageStream, RequestMessage, SendMessage, ServerMessage};
-pub struct RpcClient<Conn, ShutdownFn>
-where
- ShutdownFn: FnMut(&mut Conn),
-{
- stream: MessageStream<Conn>,
+pub struct RpcClient<Conn> {
+ stream: MessageStream<WriteHalf<Conn>>,
response_channels: Arc<Mutex<HashMap<i32, oneshot::Sender<proto::from_server::Variant>>>>,
next_message_id: i32,
- shutdown_fn: ShutdownFn,
+ _drop_tx: oneshot::Sender<()>,
}
-impl<Conn, ShutdownFn> RpcClient<Conn, ShutdownFn>
+impl<Conn> RpcClient<Conn>
where
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
- ShutdownFn: FnMut(&mut Conn),
{
- pub fn new(conn: Conn, executor: Arc<Background>, shutdown_fn: ShutdownFn) -> Self {
+ pub fn new(conn: Conn, executor: Arc<Background>) -> Self {
+ let (conn_rx, conn_tx) = smol::io::split(conn);
+ let (drop_tx, mut drop_rx) = oneshot::channel();
let response_channels = Arc::new(Mutex::new(HashMap::new()));
-
- let result = Self {
+ let client = Self {
next_message_id: 0,
- stream: MessageStream::new(conn.clone()),
+ stream: MessageStream::new(conn_tx),
response_channels: response_channels.clone(),
- shutdown_fn,
+ _drop_tx: drop_tx,
};
executor
.spawn::<Result<()>, _>(async move {
- let mut stream = MessageStream::new(conn);
+ enum Message {
+ Message(proto::FromServer),
+ ClientDropped,
+ }
+
+ let mut stream = MessageStream::new(conn_rx);
+ let client_dropped = async move {
+ assert!(drop_rx.recv().await.is_none());
+ Ok(Message::ClientDropped) as Result<_>
+ };
+ smol::pin!(client_dropped);
loop {
- let message = stream.read_message::<proto::FromServer>().await?;
- if let Some(variant) = message.variant {
- if let Some(request_id) = message.request_id {
- let tx = response_channels.lock().remove(&request_id);
- if let Some(mut tx) = tx {
- tx.send(variant).await?;
+ let message = async {
+ Ok(Message::Message(
+ stream.read_message::<proto::FromServer>().await?,
+ ))
+ };
+
+ match message.race(&mut client_dropped).await? {
+ Message::Message(message) => {
+ if let Some(variant) = message.variant {
+ if let Some(request_id) = message.request_id {
+ let tx = response_channels.lock().remove(&request_id);
+ if let Some(mut tx) = tx {
+ tx.send(variant).await?;
+ } else {
+ log::warn!(
+ "received RPC response to unknown request id {}",
+ request_id
+ );
+ }
+ }
} else {
- log::warn!(
- "received RPC response to unknown request id {}",
- request_id
- );
+ log::warn!("received RPC message with no content");
}
}
- } else {
- log::warn!("received RPC message with no content");
+ Message::ClientDropped => break Ok(()),
}
}
})
.detach();
- result
+ client
}
pub async fn request<T: RequestMessage>(&mut self, req: T) -> Result<T::Response> {
@@ -87,11 +108,103 @@ where
}
}
-impl<Conn, ShutdownFn> Drop for RpcClient<Conn, ShutdownFn>
-where
- ShutdownFn: FnMut(&mut Conn),
-{
- fn drop(&mut self) {
- (self.shutdown_fn)(self.stream.inner_mut())
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use smol::{
+ future::poll_once,
+ io::AsyncWriteExt,
+ net::unix::{UnixListener, UnixStream},
+ };
+ use std::{future::Future, io};
+ use tempdir::TempDir;
+
+ #[gpui::test]
+ async fn test_request_response(cx: gpui::TestAppContext) {
+ let executor = cx.read(|app| app.background_executor().clone());
+ let socket_dir_path = TempDir::new("request-response-socket").unwrap();
+ let socket_path = socket_dir_path.path().join(".sock");
+ let listener = UnixListener::bind(&socket_path).unwrap();
+ let client_conn = UnixStream::connect(&socket_path).await.unwrap();
+ let (server_conn, _) = listener.accept().await.unwrap();
+
+ let mut server_stream = MessageStream::new(server_conn);
+ let mut client = RpcClient::new(client_conn, executor.clone());
+
+ let client_req = client.request(proto::from_client::Auth {
+ user_id: 42,
+ access_token: "token".to_string(),
+ });
+ smol::pin!(client_req);
+ let server_req = send_recv(
+ &mut client_req,
+ server_stream.read_message::<proto::FromClient>(),
+ )
+ .await
+ .unwrap();
+ assert_eq!(
+ server_req.variant,
+ Some(proto::from_client::Variant::Auth(
+ proto::from_client::Auth {
+ user_id: 42,
+ access_token: "token".to_string()
+ }
+ ))
+ );
+
+ server_stream
+ .write_message(&proto::FromServer {
+ request_id: Some(server_req.id),
+ variant: Some(proto::from_server::Variant::AuthResponse(
+ proto::from_server::AuthResponse {
+ credentials_valid: true,
+ },
+ )),
+ })
+ .await
+ .unwrap();
+ assert_eq!(
+ client_req.await.unwrap(),
+ proto::from_server::AuthResponse {
+ credentials_valid: true
+ }
+ );
+ }
+
+ #[gpui::test]
+ async fn test_drop_client(cx: gpui::TestAppContext) {
+ let executor = cx.read(|app| app.background_executor().clone());
+ let socket_dir_path = TempDir::new("request-response-socket").unwrap();
+ let socket_path = socket_dir_path.path().join(".sock");
+ let listener = UnixListener::bind(&socket_path).unwrap();
+ let client_conn = UnixStream::connect(&socket_path).await.unwrap();
+ let (mut server_conn, _) = listener.accept().await.unwrap();
+
+ let client = RpcClient::new(client_conn, executor.clone());
+ drop(client);
+
+ // Try sending an empty payload over and over, until the client is dropped and hangs up.
+ let error = loop {
+ match server_conn.write(&[0]).await {
+ Ok(_) => continue,
+ Err(err) => break err,
+ }
+ };
+ assert_eq!(error.kind(), io::ErrorKind::BrokenPipe);
+ }
+
+ async fn send_recv<S, R, O>(mut sender: S, receiver: R) -> O
+ where
+ S: Unpin + Future,
+ R: Future<Output = O>,
+ {
+ smol::pin!(receiver);
+ loop {
+ poll_once(&mut sender).await;
+ match poll_once(&mut receiver).await {
+ Some(message) => break message,
+ None => continue,
+ }
+ }
}
}