@@ -6,7 +6,7 @@ use postage::{
prelude::{Sink, Stream},
};
use smol::{
- io::{BoxedWriter, ReadHalf},
+ io::BoxedWriter,
lock::Mutex,
prelude::{AsyncRead, AsyncWrite},
};
@@ -20,89 +20,103 @@ use std::{
};
use zed_rpc::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
+#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
+pub struct ConnectionId(u32);
+
pub struct RpcClient {
response_channels: Arc<Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>>,
- outgoing: Mutex<MessageStream<BoxedWriter>>,
+ outgoing: Arc<Mutex<HashMap<ConnectionId, MessageStream<BoxedWriter>>>>,
next_message_id: AtomicU32,
+ next_connection_id: AtomicU32,
_drop_tx: barrier::Sender,
+ drop_rx: barrier::Receiver,
}
impl RpcClient {
- pub fn new<Conn>(conn: Conn, executor: Arc<Background>) -> Arc<Self>
- where
- Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
- {
- let response_channels = Arc::new(Mutex::new(HashMap::new()));
- let (conn_rx, conn_tx) = smol::io::split(conn);
+ pub fn new() -> Arc<Self> {
let (_drop_tx, drop_rx) = barrier::channel();
-
- executor
- .spawn(Self::handle_incoming(
- conn_rx,
- drop_rx,
- response_channels.clone(),
- ))
- .detach();
-
Arc::new(Self {
- response_channels,
- outgoing: Mutex::new(MessageStream::new(Box::pin(conn_tx))),
+ response_channels: Default::default(),
+ outgoing: Default::default(),
+ next_message_id: Default::default(),
+ next_connection_id: Default::default(),
_drop_tx,
- next_message_id: AtomicU32::new(0),
+ drop_rx,
})
}
- async fn handle_incoming<Conn>(
- conn: ReadHalf<Conn>,
- mut drop_rx: barrier::Receiver,
- response_channels: Arc<Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>>,
- ) where
- Conn: AsyncRead + Unpin,
+ pub async fn connect<Conn>(&self, conn: Conn, executor: Arc<Background>) -> ConnectionId
+ where
+ Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
- let dropped = drop_rx.recv();
- smol::pin!(dropped);
+ let connection_id = ConnectionId(
+ self.next_connection_id
+ .fetch_add(1, atomic::Ordering::SeqCst),
+ );
+ let (conn_rx, conn_tx) = smol::io::split(conn);
+ let response_channels = self.response_channels.clone();
+ let mut drop_rx = self.drop_rx.clone();
+ let outgoing = self.outgoing.clone();
- let mut stream = MessageStream::new(conn);
- loop {
- let read_message = stream.read_message();
- smol::pin!(read_message);
+ executor
+ .spawn(async move {
+ let dropped = drop_rx.recv();
+ smol::pin!(dropped);
+
+ let mut stream = MessageStream::new(conn_rx);
+ loop {
+ let read_message = stream.read_message();
+ smol::pin!(read_message);
- match futures::future::select(read_message, &mut dropped).await {
- Either::Left((Ok(incoming), _)) => {
- if let Some(responding_to) = incoming.responding_to {
- let channel = response_channels.lock().await.remove(&responding_to);
- if let Some(mut tx) = channel {
- tx.send(incoming).await.ok();
- } else {
- log::warn!(
- "received RPC response to unknown request {}",
- responding_to
- );
+ match futures::future::select(read_message, &mut dropped).await {
+ Either::Left((Ok(incoming), _)) => {
+ if let Some(responding_to) = incoming.responding_to {
+ let channel = response_channels.lock().await.remove(&responding_to);
+ if let Some(mut tx) = channel {
+ tx.send(incoming).await.ok();
+ } else {
+ log::warn!(
+ "received RPC response to unknown request {}",
+ responding_to
+ );
+ }
+ } else {
+ // unprompted message from server
+ }
}
- } else {
- // unprompted message from server
+ Either::Left((Err(error), _)) => {
+ log::warn!("received invalid RPC message {:?}", error);
+ }
+ Either::Right(_) => break,
}
}
- Either::Left((Err(error), _)) => {
- log::warn!("received invalid RPC message {:?}", error);
- }
- Either::Right(_) => break,
- }
- }
+ })
+ .detach();
+
+ outgoing
+ .lock()
+ .await
+ .insert(connection_id, MessageStream::new(Box::pin(conn_tx)));
+
+ connection_id
}
pub fn request<T: RequestMessage>(
- self: &Arc<Self>,
+ &self,
+ connection_id: ConnectionId,
req: T,
) -> impl Future<Output = Result<T::Response>> {
- let this = self.clone();
+ let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
+ let outgoing = self.outgoing.clone();
+ let response_channels = self.response_channels.clone();
+ let (tx, mut rx) = oneshot::channel();
async move {
- let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
- let (tx, mut rx) = oneshot::channel();
- this.response_channels.lock().await.insert(message_id, tx);
- this.outgoing
+ response_channels.lock().await.insert(message_id, tx);
+ outgoing
.lock()
.await
+ .get_mut(&connection_id)
+ .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
.write_message(&req.into_envelope(message_id, None))
.await?;
let response = rx
@@ -115,15 +129,18 @@ impl RpcClient {
}
pub fn send<T: EnvelopedMessage>(
- self: &Arc<Self>,
+ &self,
+ connection_id: ConnectionId,
message: T,
) -> impl Future<Output = Result<()>> {
- let this = self.clone();
+ let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
+ let outgoing = self.outgoing.clone();
async move {
- let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
- this.outgoing
+ outgoing
.lock()
.await
+ .get_mut(&connection_id)
+ .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
.write_message(&message.into_envelope(message_id, None))
.await?;
Ok(())
@@ -152,12 +169,16 @@ mod tests {
let (server_conn, _) = listener.accept().await.unwrap();
let mut server_stream = MessageStream::new(server_conn);
- let client = RpcClient::new(client_conn, executor.clone());
+ let client = RpcClient::new();
+ let connection_id = client.connect(client_conn, executor.clone()).await;
- let client_req = client.request(proto::Auth {
- user_id: 42,
- access_token: "token".to_string(),
- });
+ let client_req = client.request(
+ connection_id,
+ proto::Auth {
+ user_id: 42,
+ access_token: "token".to_string(),
+ },
+ );
smol::pin!(client_req);
let server_req = send_recv(&mut client_req, server_stream.read_message())
.await
@@ -206,7 +227,8 @@ mod tests {
let client_conn = UnixStream::connect(&socket_path).await.unwrap();
let (mut server_conn, _) = listener.accept().await.unwrap();
- let client = RpcClient::new(client_conn, executor.clone());
+ let client = RpcClient::new();
+ client.connect(client_conn, executor.clone()).await;
drop(client);
// Try sending an empty payload over and over, until the client is dropped and hangs up.
@@ -231,12 +253,16 @@ mod tests {
let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
client_conn.close().await.unwrap();
- let client = RpcClient::new(client_conn, executor.clone());
+ let client = RpcClient::new();
+ let conn_id = client.connect(client_conn, executor.clone()).await;
let err = client
- .request(proto::Auth {
- user_id: 42,
- access_token: "token".to_string(),
- })
+ .request(
+ conn_id,
+ proto::Auth {
+ user_id: 42,
+ access_token: "token".to_string(),
+ },
+ )
.await
.unwrap_err();
assert_eq!(
@@ -670,13 +670,17 @@ impl Workspace {
// a TLS stream using `native-tls`.
let stream = smol::net::TcpStream::connect(rpc_address).await?;
- let rpc_client = RpcClient::new(stream, executor);
+ let rpc_client = RpcClient::new();
+ let connection_id = rpc_client.connect(stream, executor).await;
let auth_response = rpc_client
- .request(proto::Auth {
- user_id: user_id.parse::<u64>()?,
- access_token,
- })
+ .request(
+ connection_id,
+ proto::Auth {
+ user_id: user_id.parse::<u64>()?,
+ access_token,
+ },
+ )
.await?;
if !auth_response.credentials_valid {
Err(anyhow!("failed to authenticate with RPC server"))?;
@@ -684,7 +688,9 @@ impl Workspace {
let share_task = this.update(&mut cx, |this, cx| {
let worktree = this.worktrees.iter().next()?;
- Some(worktree.update(cx, |worktree, cx| worktree.share(rpc_client, cx)))
+ Some(worktree.update(cx, |worktree, cx| {
+ worktree.share(rpc_client, connection_id, cx)
+ }))
});
if let Some(share_task) = share_task {
@@ -4,7 +4,7 @@ mod ignore;
use crate::{
editor::{History, Rope},
- rpc_client::RpcClient,
+ rpc_client::{ConnectionId, RpcClient},
sum_tree::{self, Cursor, Edit, SumTree},
util::Bias,
};
@@ -229,6 +229,7 @@ impl Worktree {
pub fn share(
&mut self,
client: Arc<RpcClient>,
+ connection_id: ConnectionId,
cx: &mut ModelContext<Self>,
) -> Task<anyhow::Result<()>> {
self.rpc_client = Some(client.clone());
@@ -245,9 +246,12 @@ impl Worktree {
.await;
let share_response = client
- .request(proto::ShareWorktree {
- worktree: Some(proto::Worktree { paths }),
- })
+ .request(
+ connection_id,
+ proto::ShareWorktree {
+ worktree: Some(proto::Worktree { paths }),
+ },
+ )
.await?;
log::info!("sharing worktree {:?}", share_response);