@@ -1350,6 +1350,7 @@ checksum = "da9052a1a50244d8d5aa9bf55cbc2fb6f357c86cc52e46c62ed390a7180cf150"
dependencies = [
"futures-channel",
"futures-core",
+ "futures-executor",
"futures-io",
"futures-sink",
"futures-task",
@@ -1372,6 +1373,17 @@ version = "0.3.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79e5145dde8da7d1b3892dad07a9c98fc04bc39892b1ecc9692cf53e2b780a65"
+[[package]]
+name = "futures-executor"
+version = "0.3.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e9e59fdc009a4b3096bf94f740a0f2424c082521f20a9b08c5c07c48d90fd9b9"
+dependencies = [
+ "futures-core",
+ "futures-task",
+ "futures-util",
+]
+
[[package]]
name = "futures-io"
version = "0.3.12"
@@ -1423,6 +1435,7 @@ version = "0.3.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "632a8cd0f2a4b3fdea1657f08bde063848c3bd00f9bbf6e256b8be78802e624b"
dependencies = [
+ "futures-channel",
"futures-core",
"futures-io",
"futures-macro",
@@ -4304,7 +4317,7 @@ dependencies = [
"easy-parallel",
"env_logger",
"fsevent",
- "futures-core",
+ "futures",
"gpui",
"http-auth-basic",
"ignore",
@@ -1,106 +1,112 @@
use anyhow::{anyhow, Result};
+use futures::FutureExt;
use gpui::executor::Background;
use parking_lot::Mutex;
use postage::{
- mpsc, oneshot,
+ mpsc,
prelude::{Sink, Stream},
};
-use smol::{
- future::FutureExt,
- io::WriteHalf,
- prelude::{AsyncRead, AsyncWrite},
+use smol::prelude::{AsyncRead, AsyncWrite};
+use std::{
+ collections::HashMap,
+ io,
+ sync::{
+ atomic::{self, AtomicI32},
+ Arc,
+ },
};
-use std::{collections::HashMap, sync::Arc};
use zed_rpc::proto::{
self, MessageStream, RequestMessage, SendMessage, ServerMessage, SubscribeMessage,
};
-pub struct RpcClient<Conn> {
- stream: MessageStream<WriteHalf<Conn>>,
+pub struct RpcClient {
response_channels: Arc<Mutex<HashMap<i32, (mpsc::Sender<proto::from_server::Variant>, bool)>>>,
- next_message_id: i32,
- _drop_tx: oneshot::Sender<()>,
+ outgoing_tx: mpsc::Sender<proto::FromClient>,
+ next_message_id: AtomicI32,
}
-impl<Conn> RpcClient<Conn>
-where
- Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
-{
- pub fn new(conn: Conn, executor: Arc<Background>) -> Self {
- let (conn_rx, conn_tx) = smol::io::split(conn);
- let (drop_tx, mut drop_rx) = oneshot::channel();
+impl RpcClient {
+ pub fn new<Conn>(conn: Conn, executor: Arc<Background>) -> Self
+ where
+ Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
+ {
let response_channels = Arc::new(Mutex::new(HashMap::new()));
- let client = Self {
- next_message_id: 0,
- stream: MessageStream::new(conn_tx),
- response_channels: response_channels.clone(),
- _drop_tx: drop_tx,
- };
+ let (outgoing_tx, mut outgoing_rx) = mpsc::channel(32);
- executor
- .spawn::<Result<()>, _>(async move {
- enum Message {
- Message(proto::FromServer),
- ClientDropped,
- }
+ {
+ let response_channels = response_channels.clone();
+ executor
+ .spawn(async move {
+ let (conn_rx, conn_tx) = smol::io::split(conn);
+ let mut stream_tx = MessageStream::new(conn_tx);
+ let mut stream_rx = MessageStream::new(conn_rx);
+ loop {
+ futures::select! {
+ incoming = stream_rx.read_message::<proto::FromServer>().fuse() => {
+ Self::handle_incoming(incoming, &response_channels).await;
+ }
+ outgoing = outgoing_rx.recv().fuse() => {
+ if let Some(outgoing) = outgoing {
+ stream_tx.write_message(&outgoing).await;
+ } else {
+ break;
+ }
+ }
+ }
+ }
+ })
+ .detach();
+ }
- let mut stream = MessageStream::new(conn_rx);
- let client_dropped = async move {
- assert!(drop_rx.recv().await.is_none());
- Ok(Message::ClientDropped) as Result<_>
- };
- smol::pin!(client_dropped);
- loop {
- let message = async {
- Ok(Message::Message(
- stream.read_message::<proto::FromServer>().await?,
- ))
- };
+ Self {
+ response_channels,
+ outgoing_tx,
+ next_message_id: AtomicI32::new(0),
+ }
+ }
- match message.race(&mut client_dropped).await? {
- Message::Message(message) => {
- if let Some(variant) = message.variant {
- if let Some(request_id) = message.request_id {
- let channel = response_channels.lock().remove(&request_id);
- if let Some((mut tx, oneshot)) = channel {
- if tx.send(variant).await.is_ok() {
- if !oneshot {
- response_channels
- .lock()
- .insert(request_id, (tx, false));
- }
- }
- } else {
- log::warn!(
- "received RPC response to unknown request id {}",
- request_id
- );
- }
+ async fn handle_incoming(
+ incoming: io::Result<proto::FromServer>,
+ response_channels: &Mutex<HashMap<i32, (mpsc::Sender<proto::from_server::Variant>, bool)>>,
+ ) {
+ match incoming {
+ Ok(incoming) => {
+ if let Some(variant) = incoming.variant {
+ if let Some(request_id) = incoming.request_id {
+ let channel = response_channels.lock().remove(&request_id);
+ if let Some((mut tx, oneshot)) = channel {
+ if tx.send(variant).await.is_ok() {
+ if !oneshot {
+ response_channels.lock().insert(request_id, (tx, false));
}
- } else {
- log::warn!("received RPC message with no content");
}
+ } else {
+ log::warn!(
+ "received RPC response to unknown request id {}",
+ request_id
+ );
}
- Message::ClientDropped => break Ok(()),
}
+ } else {
+ log::warn!("received RPC message with no content");
}
- })
- .detach();
-
- client
+ }
+ Err(error) => log::warn!("invalid incoming RPC message {:?}", error),
+ }
}
- pub async fn request<T: RequestMessage>(&mut self, req: T) -> Result<T::Response> {
- let message_id = self.next_message_id;
- self.next_message_id += 1;
+ pub async fn request<T: RequestMessage>(&self, req: T) -> Result<T::Response> {
+ let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
let (tx, mut rx) = mpsc::channel(1);
self.response_channels.lock().insert(message_id, (tx, true));
- self.stream
- .write_message(&proto::FromClient {
+ self.outgoing_tx
+ .clone()
+ .send(proto::FromClient {
id: message_id,
variant: Some(req.to_variant()),
})
- .await?;
+ .await
+ .unwrap();
let response = rx
.recv()
.await
@@ -109,15 +115,16 @@ where
.ok_or_else(|| anyhow!("received response of the wrong t"))
}
- pub async fn send<T: SendMessage>(&mut self, message: T) -> Result<()> {
- let message_id = self.next_message_id;
- self.next_message_id += 1;
- self.stream
- .write_message(&proto::FromClient {
+ pub async fn send<T: SendMessage>(&self, message: T) -> Result<()> {
+ let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
+ self.outgoing_tx
+ .clone()
+ .send(proto::FromClient {
id: message_id,
variant: Some(message.to_variant()),
})
- .await?;
+ .await
+ .unwrap();
Ok(())
}
@@ -125,19 +132,19 @@ where
&mut self,
subscription: T,
) -> Result<impl Stream<Item = Result<T::Event>>> {
- let message_id = self.next_message_id;
- self.next_message_id += 1;
+ let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
let (tx, rx) = mpsc::channel(256);
self.response_channels
.lock()
.insert(message_id, (tx, false));
- self.stream
- .write_message(&proto::FromClient {
+ self.outgoing_tx
+ .clone()
+ .send(proto::FromClient {
id: message_id,
variant: Some(subscription.to_variant()),
})
- .await?;
-
+ .await
+ .unwrap();
Ok(rx.map(|event| {
T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}"))
}))
@@ -165,7 +172,7 @@ mod tests {
let (server_conn, _) = listener.accept().await.unwrap();
let mut server_stream = MessageStream::new(server_conn);
- let mut client = RpcClient::new(client_conn, executor.clone());
+ let client = RpcClient::new(client_conn, executor.clone());
let client_req = client.request(proto::from_client::Auth {
user_id: 42,
@@ -8,7 +8,6 @@ use crate::{
worktree::{FileHandle, Worktree, WorktreeHandle},
AppState,
};
-use futures_core::Future;
use gpui::{
color::rgbu, elements::*, json::to_string_pretty, keymap::Binding, AnyViewHandle, AppContext,
ClipboardItem, Entity, ModelHandle, MutableAppContext, PathPromptOptions, PromptLevel, Task,
@@ -19,10 +18,10 @@ pub use pane::*;
pub use pane_group::*;
use postage::watch;
use smol::prelude::*;
-use std::{collections::HashMap, path::PathBuf};
use std::{
- collections::{hash_map::Entry, HashSet},
- path::Path,
+ collections::{hash_map::Entry, HashMap, HashSet},
+ future::Future,
+ path::{Path, PathBuf},
sync::Arc,
};