Restructure Peer to handle connections' messages in order

Max Brunsfeld created

Change summary

Cargo.lock               |   1 
gpui/src/app.rs          |  16 +
zed-rpc/Cargo.toml       |   1 
zed-rpc/src/peer.rs      | 554 ++++++++++++++++++++++++++---------------
zed/src/editor.rs        |   3 
zed/src/editor/buffer.rs |   1 
zed/src/file_finder.rs   |  14 
zed/src/lib.rs           |   4 
zed/src/main.rs          |  14 
zed/src/menus.rs         |   5 
zed/src/rpc.rs           |  55 ++-
zed/src/test.rs          |   8 
zed/src/workspace.rs     |  48 ++-
zed/src/worktree.rs      |  31 +
14 files changed, 480 insertions(+), 275 deletions(-)

Detailed changes

Cargo.lock πŸ”—

@@ -4507,6 +4507,7 @@ dependencies = [
  "base64 0.13.0",
  "futures",
  "log",
+ "parking_lot",
  "postage",
  "prost 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "prost-build",

gpui/src/app.rs πŸ”—

@@ -104,8 +104,11 @@ pub enum MenuItem<'a> {
 #[derive(Clone)]
 pub struct App(Rc<RefCell<MutableAppContext>>);
 
+#[derive(Clone)]
 pub struct AsyncAppContext(Rc<RefCell<MutableAppContext>>);
 
+pub struct BackgroundAppContext(*const RefCell<MutableAppContext>);
+
 #[derive(Clone)]
 pub struct TestAppContext {
     cx: Rc<RefCell<MutableAppContext>>,
@@ -409,6 +412,15 @@ impl TestAppContext {
 }
 
 impl AsyncAppContext {
+    pub fn spawn<F, Fut, T>(&self, f: F) -> Task<T>
+    where
+        F: FnOnce(AsyncAppContext) -> Fut,
+        Fut: 'static + Future<Output = T>,
+        T: 'static,
+    {
+        self.0.borrow().foreground.spawn(f(self.clone()))
+    }
+
     pub fn read<T, F: FnOnce(&AppContext) -> T>(&mut self, callback: F) -> T {
         callback(self.0.borrow().as_ref())
     }
@@ -433,6 +445,10 @@ impl AsyncAppContext {
         self.0.borrow().platform()
     }
 
+    pub fn foreground(&self) -> Rc<executor::Foreground> {
+        self.0.borrow().foreground.clone()
+    }
+
     pub fn background(&self) -> Arc<executor::Background> {
         self.0.borrow().cx.background.clone()
     }

zed-rpc/Cargo.toml πŸ”—

@@ -14,6 +14,7 @@ async-tungstenite = "0.14"
 base64 = "0.13"
 futures = "0.3"
 log = "0.4"
+parking_lot = "0.11.1"
 postage = {version = "0.4.1", features = ["futures-traits"]}
 prost = "0.7"
 rand = "0.8"

zed-rpc/src/peer.rs πŸ”—

@@ -3,7 +3,7 @@ use anyhow::{anyhow, Context, Result};
 use async_lock::{Mutex, RwLock};
 use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
 use futures::{
-    future::BoxFuture,
+    future::{BoxFuture, LocalBoxFuture},
     stream::{SplitSink, SplitStream},
     FutureExt, StreamExt,
 };
@@ -30,9 +30,14 @@ pub struct ConnectionId(pub u32);
 pub struct PeerId(pub u32);
 
 type MessageHandler = Box<
-    dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
+    dyn Send
+        + Sync
+        + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<'static, ()>>,
 >;
 
+type ForegroundMessageHandler =
+    Box<dyn Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<LocalBoxFuture<'static, ()>>>;
+
 pub struct Receipt<T> {
     sender_id: ConnectionId,
     message_id: u32,
@@ -63,10 +68,15 @@ impl<T: RequestMessage> TypedEnvelope<T> {
     }
 }
 
+pub type Router = RouterInternal<MessageHandler>;
+pub type ForegroundRouter = RouterInternal<ForegroundMessageHandler>;
+pub struct RouterInternal<H> {
+    message_handlers: Vec<H>,
+    handler_types: HashSet<TypeId>,
+}
+
 pub struct Peer {
     connections: RwLock<HashMap<ConnectionId, Connection>>,
-    message_handlers: RwLock<Vec<MessageHandler>>,
-    handler_types: Mutex<HashSet<TypeId>>,
     next_connection_id: AtomicU32,
 }
 
@@ -74,73 +84,37 @@ pub struct Peer {
 struct Connection {
     outgoing_tx: mpsc::Sender<proto::Envelope>,
     next_message_id: Arc<AtomicU32>,
-    response_channels: ResponseChannels,
+    response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
 }
 
-pub struct ConnectionHandler<W, R> {
-    peer: Arc<Peer>,
+pub struct IOHandler<W, R> {
     connection_id: ConnectionId,
-    response_channels: ResponseChannels,
+    incoming_tx: mpsc::Sender<proto::Envelope>,
     outgoing_rx: mpsc::Receiver<proto::Envelope>,
     writer: MessageStream<W>,
     reader: MessageStream<R>,
 }
 
-type ResponseChannels = Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>;
-
 impl Peer {
     pub fn new() -> Arc<Self> {
         Arc::new(Self {
             connections: Default::default(),
-            message_handlers: Default::default(),
-            handler_types: Default::default(),
             next_connection_id: Default::default(),
         })
     }
 
-    pub async fn add_message_handler<T: EnvelopedMessage>(
-        &self,
-    ) -> mpsc::Receiver<TypedEnvelope<T>> {
-        if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
-            panic!("duplicate handler type");
-        }
-
-        let (tx, rx) = mpsc::channel(256);
-        self.message_handlers
-            .write()
-            .await
-            .push(Box::new(move |envelope, connection_id| {
-                if envelope.as_ref().map_or(false, T::matches_envelope) {
-                    let envelope = Option::take(envelope).unwrap();
-                    let mut tx = tx.clone();
-                    Some(
-                        async move {
-                            tx.send(TypedEnvelope {
-                                sender_id: connection_id,
-                                original_sender_id: envelope.original_sender_id.map(PeerId),
-                                message_id: envelope.id,
-                                payload: T::from_envelope(envelope).unwrap(),
-                            })
-                            .await
-                            .is_err()
-                        }
-                        .boxed(),
-                    )
-                } else {
-                    None
-                }
-            }));
-        rx
-    }
-
-    pub async fn add_connection<Conn>(
+    pub async fn add_connection<Conn, H, Fut>(
         self: &Arc<Self>,
         conn: Conn,
+        router: Arc<RouterInternal<H>>,
     ) -> (
         ConnectionId,
-        ConnectionHandler<SplitSink<Conn, WebSocketMessage>, SplitStream<Conn>>,
+        IOHandler<SplitSink<Conn, WebSocketMessage>, SplitStream<Conn>>,
+        impl Future<Output = anyhow::Result<()>>,
     )
     where
+        H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
+        Fut: Future<Output = ()>,
         Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
             + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
             + Unpin,
@@ -150,25 +124,45 @@ impl Peer {
             self.next_connection_id
                 .fetch_add(1, atomic::Ordering::SeqCst),
         );
+        let (incoming_tx, mut incoming_rx) = mpsc::channel(64);
         let (outgoing_tx, outgoing_rx) = mpsc::channel(64);
         let connection = Connection {
             outgoing_tx,
             next_message_id: Default::default(),
             response_channels: Default::default(),
         };
-        let handler = ConnectionHandler {
-            peer: self.clone(),
+        let handle_io = IOHandler {
             connection_id,
-            response_channels: connection.response_channels.clone(),
             outgoing_rx,
+            incoming_tx,
             writer: MessageStream::new(tx),
             reader: MessageStream::new(rx),
         };
+
+        let response_channels = connection.response_channels.clone();
+        let handle_messages = async move {
+            while let Some(message) = incoming_rx.recv().await {
+                if let Some(responding_to) = message.responding_to {
+                    let channel = response_channels.lock().await.remove(&responding_to);
+                    if let Some(mut tx) = channel {
+                        tx.send(message).await.ok();
+                    } else {
+                        log::warn!("received RPC response to unknown request {}", responding_to);
+                    }
+                } else {
+                    router.handle(connection_id, message).await;
+                }
+            }
+            response_channels.lock().await.clear();
+            Ok(())
+        };
+
         self.connections
             .write()
             .await
             .insert(connection_id, connection);
-        (connection_id, handler)
+
+        (connection_id, handle_io, handle_messages)
     }
 
     pub async fn disconnect(&self, connection_id: ConnectionId) {
@@ -177,8 +171,6 @@ impl Peer {
 
     pub async fn reset(&self) {
         self.connections.write().await.clear();
-        self.handler_types.lock().await.clear();
-        self.message_handlers.write().await.clear();
     }
 
     pub fn request<T: RequestMessage>(
@@ -302,7 +294,115 @@ impl Peer {
     }
 }
 
-impl<W, R> ConnectionHandler<W, R>
+impl<H, Fut> RouterInternal<H>
+where
+    H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
+    Fut: Future<Output = ()>,
+{
+    pub fn new() -> Self {
+        Self {
+            message_handlers: Default::default(),
+            handler_types: Default::default(),
+        }
+    }
+
+    async fn handle(&self, connection_id: ConnectionId, message: proto::Envelope) {
+        let mut envelope = Some(message);
+        for handler in self.message_handlers.iter() {
+            if let Some(future) = handler(&mut envelope, connection_id) {
+                future.await;
+                return;
+            }
+        }
+        log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
+    }
+}
+
+impl Router {
+    pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
+    where
+        T: EnvelopedMessage,
+        Fut: 'static + Send + Future<Output = Result<()>>,
+        F: 'static + Send + Sync + Fn(TypedEnvelope<T>) -> Fut,
+    {
+        if !self.handler_types.insert(TypeId::of::<T>()) {
+            panic!("duplicate handler type");
+        }
+
+        self.message_handlers
+            .push(Box::new(move |envelope, connection_id| {
+                if envelope.as_ref().map_or(false, T::matches_envelope) {
+                    let envelope = Option::take(envelope).unwrap();
+                    let message_id = envelope.id;
+                    let future = handler(TypedEnvelope {
+                        sender_id: connection_id,
+                        original_sender_id: envelope.original_sender_id.map(PeerId),
+                        message_id,
+                        payload: T::from_envelope(envelope).unwrap(),
+                    });
+                    Some(
+                        async move {
+                            if let Err(error) = future.await {
+                                log::error!(
+                                    "error handling message {} {}: {:?}",
+                                    T::NAME,
+                                    message_id,
+                                    error
+                                );
+                            }
+                        }
+                        .boxed(),
+                    )
+                } else {
+                    None
+                }
+            }));
+    }
+}
+
+impl ForegroundRouter {
+    pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
+    where
+        T: EnvelopedMessage,
+        Fut: 'static + Future<Output = Result<()>>,
+        F: 'static + Fn(TypedEnvelope<T>) -> Fut,
+    {
+        if !self.handler_types.insert(TypeId::of::<T>()) {
+            panic!("duplicate handler type");
+        }
+
+        self.message_handlers
+            .push(Box::new(move |envelope, connection_id| {
+                if envelope.as_ref().map_or(false, T::matches_envelope) {
+                    let envelope = Option::take(envelope).unwrap();
+                    let message_id = envelope.id;
+                    let future = handler(TypedEnvelope {
+                        sender_id: connection_id,
+                        original_sender_id: envelope.original_sender_id.map(PeerId),
+                        message_id,
+                        payload: T::from_envelope(envelope).unwrap(),
+                    });
+                    Some(
+                        async move {
+                            if let Err(error) = future.await {
+                                log::error!(
+                                    "error handling message {} {}: {:?}",
+                                    T::NAME,
+                                    message_id,
+                                    error
+                                );
+                            }
+                        }
+                        .boxed_local(),
+                    )
+                } else {
+                    None
+                }
+            }));
+    }
+}
+
+impl<W, R> IOHandler<W, R>
 where
     W: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
     R: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
@@ -315,18 +415,18 @@ where
                 futures::select_biased! {
                     incoming = read_message => match incoming {
                         Ok(incoming) => {
-                            Self::handle_incoming_message(incoming, &self.peer, self.connection_id, &self.response_channels).await;
+                            if self.incoming_tx.send(incoming).await.is_err() {
+                                return Ok(());
+                            }
                             break;
                         }
                         Err(error) => {
-                            self.response_channels.lock().await.clear();
                             Err(error).context("received invalid RPC message")?;
                         }
                     },
                     outgoing = self.outgoing_rx.recv().fuse() => match outgoing {
                         Some(outgoing) => {
                             if let Err(result) = self.writer.write_message(&outgoing).await {
-                                self.response_channels.lock().await.clear();
                                 Err(result).context("failed to write RPC message")?;
                             }
                         }
@@ -350,41 +450,6 @@ where
             payload,
         })
     }
-
-    async fn handle_incoming_message(
-        message: proto::Envelope,
-        peer: &Arc<Peer>,
-        connection_id: ConnectionId,
-        response_channels: &ResponseChannels,
-    ) {
-        if let Some(responding_to) = message.responding_to {
-            let channel = response_channels.lock().await.remove(&responding_to);
-            if let Some(mut tx) = channel {
-                tx.send(message).await.ok();
-            } else {
-                log::warn!("received RPC response to unknown request {}", responding_to);
-            }
-        } else {
-            let mut envelope = Some(message);
-            let mut handler_index = None;
-            let mut handler_was_dropped = false;
-            for (i, handler) in peer.message_handlers.read().await.iter().enumerate() {
-                if let Some(future) = handler(&mut envelope, connection_id) {
-                    handler_was_dropped = future.await;
-                    handler_index = Some(i);
-                    break;
-                }
-            }
-
-            if let Some(handler_index) = handler_index {
-                if handler_was_dropped {
-                    drop(peer.message_handlers.write().await.remove(handler_index));
-                }
-            } else {
-                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
-            }
-        }
-    }
 }
 
 impl<T> Clone for Receipt<T> {
@@ -415,7 +480,6 @@ impl fmt::Display for PeerId {
 mod tests {
     use super::*;
     use crate::test;
-    use postage::oneshot;
 
     #[test]
     fn test_request_response() {
@@ -425,127 +489,185 @@ mod tests {
             let client1 = Peer::new();
             let client2 = Peer::new();
 
+            let mut router = Router::new();
+            router.add_message_handler({
+                let server = server.clone();
+                move |envelope: TypedEnvelope<proto::Auth>| {
+                    let server = server.clone();
+                    async move {
+                        let receipt = envelope.receipt();
+                        let message = envelope.payload;
+                        server
+                            .respond(
+                                receipt,
+                                match message.user_id {
+                                    1 => {
+                                        assert_eq!(message.access_token, "access-token-1");
+                                        proto::AuthResponse {
+                                            credentials_valid: true,
+                                        }
+                                    }
+                                    2 => {
+                                        assert_eq!(message.access_token, "access-token-2");
+                                        proto::AuthResponse {
+                                            credentials_valid: false,
+                                        }
+                                    }
+                                    _ => {
+                                        panic!("unexpected user id {}", message.user_id);
+                                    }
+                                },
+                            )
+                            .await
+                    }
+                }
+            });
+
+            router.add_message_handler({
+                let server = server.clone();
+                move |envelope: TypedEnvelope<proto::OpenBuffer>| {
+                    let server = server.clone();
+                    async move {
+                        let receipt = envelope.receipt();
+                        let message = envelope.payload;
+                        server
+                            .respond(
+                                receipt,
+                                match message.path.as_str() {
+                                    "path/one" => {
+                                        assert_eq!(message.worktree_id, 1);
+                                        proto::OpenBufferResponse {
+                                            buffer: Some(proto::Buffer {
+                                                id: 101,
+                                                content: "path/one content".to_string(),
+                                                history: vec![],
+                                                selections: vec![],
+                                            }),
+                                        }
+                                    }
+                                    "path/two" => {
+                                        assert_eq!(message.worktree_id, 2);
+                                        proto::OpenBufferResponse {
+                                            buffer: Some(proto::Buffer {
+                                                id: 102,
+                                                content: "path/two content".to_string(),
+                                                history: vec![],
+                                                selections: vec![],
+                                            }),
+                                        }
+                                    }
+                                    _ => {
+                                        panic!("unexpected path {}", message.path);
+                                    }
+                                },
+                            )
+                            .await
+                    }
+                }
+            });
+            let router = Arc::new(router);
+
             let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
-            let (client1_conn_id, task1) = client1.add_connection(client1_to_server_conn).await;
-            let (_, task2) = server.add_connection(server_to_client_1_conn).await;
+            let (client1_conn_id, io_task1, msg_task1) = client1
+                .add_connection(client1_to_server_conn, router.clone())
+                .await;
+            let (_, io_task2, msg_task2) = server
+                .add_connection(server_to_client_1_conn, router.clone())
+                .await;
 
             let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
-            let (client2_conn_id, task3) = client2.add_connection(client2_to_server_conn).await;
-            let (_, task4) = server.add_connection(server_to_client_2_conn).await;
-
-            smol::spawn(task1.run()).detach();
-            smol::spawn(task2.run()).detach();
-            smol::spawn(task3.run()).detach();
-            smol::spawn(task4.run()).detach();
-
-            // define the expected requests and responses
-            let request1 = proto::Auth {
-                user_id: 1,
-                access_token: "token-1".to_string(),
-            };
-            let response1 = proto::AuthResponse {
-                credentials_valid: true,
-            };
-            let request2 = proto::Auth {
-                user_id: 2,
-                access_token: "token-2".to_string(),
-            };
-            let response2 = proto::AuthResponse {
-                credentials_valid: false,
-            };
-            let request3 = proto::OpenBuffer {
-                worktree_id: 1,
-                path: "path/two".to_string(),
-            };
-            let response3 = proto::OpenBufferResponse {
-                buffer: Some(proto::Buffer {
-                    id: 2,
-                    content: "path/two content".to_string(),
-                    history: vec![],
-                    selections: vec![],
-                }),
-            };
-            let request4 = proto::OpenBuffer {
-                worktree_id: 2,
-                path: "path/one".to_string(),
-            };
-            let response4 = proto::OpenBufferResponse {
-                buffer: Some(proto::Buffer {
-                    id: 1,
-                    content: "path/one content".to_string(),
-                    history: vec![],
-                    selections: vec![],
-                }),
-            };
-
-            // on the server, respond to two requests for each client
-            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
-            let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
-            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
-            smol::spawn({
-                let request1 = request1.clone();
-                let request2 = request2.clone();
-                let request3 = request3.clone();
-                let request4 = request4.clone();
-                let response1 = response1.clone();
-                let response2 = response2.clone();
-                let response3 = response3.clone();
-                let response4 = response4.clone();
-                async move {
-                    let msg = auth_rx.recv().await.unwrap();
-                    assert_eq!(msg.payload, request1);
-                    server
-                        .respond(msg.receipt(), response1.clone())
-                        .await
-                        .unwrap();
-
-                    let msg = auth_rx.recv().await.unwrap();
-                    assert_eq!(msg.payload, request2.clone());
-                    server
-                        .respond(msg.receipt(), response2.clone())
-                        .await
-                        .unwrap();
-
-                    let msg = open_buffer_rx.recv().await.unwrap();
-                    assert_eq!(msg.payload, request3.clone());
-                    server
-                        .respond(msg.receipt(), response3.clone())
-                        .await
-                        .unwrap();
-
-                    let msg = open_buffer_rx.recv().await.unwrap();
-                    assert_eq!(msg.payload, request4.clone());
-                    server
-                        .respond(msg.receipt(), response4.clone())
-                        .await
-                        .unwrap();
-
-                    server_done_tx.send(()).await.unwrap();
-                }
-            })
-            .detach();
+            let (client2_conn_id, io_task3, msg_task3) = client2
+                .add_connection(client2_to_server_conn, router.clone())
+                .await;
+            let (_, io_task4, msg_task4) = server
+                .add_connection(server_to_client_2_conn, router.clone())
+                .await;
+
+            smol::spawn(io_task1.run()).detach();
+            smol::spawn(io_task2.run()).detach();
+            smol::spawn(io_task3.run()).detach();
+            smol::spawn(io_task4.run()).detach();
+            smol::spawn(msg_task1).detach();
+            smol::spawn(msg_task2).detach();
+            smol::spawn(msg_task3).detach();
+            smol::spawn(msg_task4).detach();
 
             assert_eq!(
-                client1.request(client1_conn_id, request1).await.unwrap(),
-                response1
+                client1
+                    .request(
+                        client1_conn_id,
+                        proto::Auth {
+                            user_id: 1,
+                            access_token: "access-token-1".to_string(),
+                        },
+                    )
+                    .await
+                    .unwrap(),
+                proto::AuthResponse {
+                    credentials_valid: true,
+                }
             );
+
             assert_eq!(
-                client2.request(client2_conn_id, request2).await.unwrap(),
-                response2
+                client2
+                    .request(
+                        client2_conn_id,
+                        proto::Auth {
+                            user_id: 2,
+                            access_token: "access-token-2".to_string(),
+                        },
+                    )
+                    .await
+                    .unwrap(),
+                proto::AuthResponse {
+                    credentials_valid: false,
+                }
             );
+
             assert_eq!(
-                client2.request(client2_conn_id, request3).await.unwrap(),
-                response3
+                client1
+                    .request(
+                        client1_conn_id,
+                        proto::OpenBuffer {
+                            worktree_id: 1,
+                            path: "path/one".to_string(),
+                        },
+                    )
+                    .await
+                    .unwrap(),
+                proto::OpenBufferResponse {
+                    buffer: Some(proto::Buffer {
+                        id: 101,
+                        content: "path/one content".to_string(),
+                        history: vec![],
+                        selections: vec![],
+                    }),
+                }
             );
+
             assert_eq!(
-                client1.request(client1_conn_id, request4).await.unwrap(),
-                response4
+                client2
+                    .request(
+                        client2_conn_id,
+                        proto::OpenBuffer {
+                            worktree_id: 2,
+                            path: "path/two".to_string(),
+                        },
+                    )
+                    .await
+                    .unwrap(),
+                proto::OpenBufferResponse {
+                    buffer: Some(proto::Buffer {
+                        id: 102,
+                        content: "path/two content".to_string(),
+                        history: vec![],
+                        selections: vec![],
+                    }),
+                }
             );
 
             client1.disconnect(client1_conn_id).await;
             client2.disconnect(client1_conn_id).await;
-
-            server_done_rx.recv().await.unwrap();
         });
     }
 
@@ -555,17 +677,28 @@ mod tests {
             let (client_conn, mut server_conn) = test::Channel::bidirectional();
 
             let client = Peer::new();
-            let (connection_id, handler) = client.add_connection(client_conn).await;
-            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
-                postage::barrier::channel();
+            let router = Arc::new(Router::new());
+            let (connection_id, io_handler, message_handler) =
+                client.add_connection(client_conn, router).await;
+
+            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
+            smol::spawn(async move {
+                io_handler.run().await.ok();
+                io_ended_tx.send(()).await.unwrap();
+            })
+            .detach();
+
+            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
             smol::spawn(async move {
-                handler.run().await.ok();
-                incoming_messages_ended_tx.send(()).await.unwrap();
+                message_handler.await.ok();
+                messages_ended_tx.send(()).await.unwrap();
             })
             .detach();
+
             client.disconnect(connection_id).await;
 
-            incoming_messages_ended_rx.recv().await;
+            io_ended_rx.recv().await;
+            messages_ended_rx.recv().await;
             assert!(
                 futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                     .await
@@ -581,8 +714,11 @@ mod tests {
             drop(server_conn);
 
             let client = Peer::new();
-            let (connection_id, handler) = client.add_connection(client_conn).await;
-            smol::spawn(handler.run()).detach();
+            let router = Arc::new(Router::new());
+            let (connection_id, io_handler, message_handler) =
+                client.add_connection(client_conn, router).await;
+            smol::spawn(io_handler.run()).detach();
+            smol::spawn(message_handler).detach();
 
             let err = client
                 .request(

zed/src/editor.rs πŸ”—

@@ -4015,7 +4015,8 @@ mod tests {
             let history = History::new(text.into());
             Buffer::from_history(0, history, None, lang.cloned(), cx)
         });
-        let (_, view) = cx.add_window(|cx| Editor::for_buffer(buffer, app_state.settings, cx));
+        let (_, view) =
+            cx.add_window(|cx| Editor::for_buffer(buffer, app_state.settings.clone(), cx));
         view.condition(&cx, |view, cx| !view.buffer.read(cx).is_parsing())
             .await;
 

zed/src/editor/buffer.rs πŸ”—

@@ -719,6 +719,7 @@ impl Buffer {
         mtime: SystemTime,
         cx: &mut ModelContext<Self>,
     ) {
+        eprintln!("{} did_save {:?}", self.replica_id, version);
         self.saved_mtime = mtime;
         self.saved_version = version;
         cx.emit(Event::Saved);

zed/src/file_finder.rs πŸ”—

@@ -479,8 +479,12 @@ mod tests {
 
         let app_state = cx.read(build_app_state);
         let (window_id, workspace) = cx.add_window(|cx| {
-            let mut workspace =
-                Workspace::new(app_state.settings, app_state.languages, app_state.rpc, cx);
+            let mut workspace = Workspace::new(
+                app_state.settings.clone(),
+                app_state.languages.clone(),
+                app_state.rpc.clone(),
+                cx,
+            );
             workspace.add_worktree(tmp_dir.path(), cx);
             workspace
         });
@@ -559,7 +563,7 @@ mod tests {
         cx.read(|cx| workspace.read(cx).worktree_scans_complete(cx))
             .await;
         let (_, finder) =
-            cx.add_window(|cx| FileFinder::new(app_state.settings, workspace.clone(), cx));
+            cx.add_window(|cx| FileFinder::new(app_state.settings.clone(), workspace.clone(), cx));
 
         let query = "hi".to_string();
         finder
@@ -622,7 +626,7 @@ mod tests {
         cx.read(|cx| workspace.read(cx).worktree_scans_complete(cx))
             .await;
         let (_, finder) =
-            cx.add_window(|cx| FileFinder::new(app_state.settings, workspace.clone(), cx));
+            cx.add_window(|cx| FileFinder::new(app_state.settings.clone(), workspace.clone(), cx));
 
         // Even though there is only one worktree, that worktree's filename
         // is included in the matching, because the worktree is a single file.
@@ -681,7 +685,7 @@ mod tests {
             .await;
 
         let (_, finder) =
-            cx.add_window(|cx| FileFinder::new(app_state.settings, workspace.clone(), cx));
+            cx.add_window(|cx| FileFinder::new(app_state.settings.clone(), workspace.clone(), cx));
 
         // Run a search that matches two files with the same relative path.
         finder

zed/src/lib.rs πŸ”—

@@ -1,3 +1,5 @@
+use zed_rpc::ForegroundRouter;
+
 pub mod assets;
 pub mod editor;
 pub mod file_finder;
@@ -14,10 +16,10 @@ mod util;
 pub mod workspace;
 pub mod worktree;
 
-#[derive(Clone)]
 pub struct AppState {
     pub settings: postage::watch::Receiver<settings::Settings>,
     pub languages: std::sync::Arc<language::LanguageRegistry>,
+    pub rpc_router: std::sync::Arc<ForegroundRouter>,
     pub rpc: rpc::Client,
 }
 

zed/src/main.rs πŸ”—

@@ -10,6 +10,7 @@ use zed::{
     workspace::{self, OpenParams},
     worktree, AppState,
 };
+use zed_rpc::ForegroundRouter;
 
 fn main() {
     init_logger();
@@ -20,20 +21,27 @@ fn main() {
     let languages = Arc::new(language::LanguageRegistry::new());
     languages.set_theme(&settings.borrow().theme);
 
-    let app_state = AppState {
+    let mut app_state = AppState {
         languages: languages.clone(),
         settings,
+        rpc_router: Arc::new(ForegroundRouter::new()),
         rpc: rpc::Client::new(languages),
     };
 
     app.run(move |cx| {
-        cx.set_menus(menus::menus(app_state.clone()));
+        worktree::init(
+            cx,
+            &app_state.rpc,
+            Arc::get_mut(&mut app_state.rpc_router).unwrap(),
+        );
         zed::init(cx);
         workspace::init(cx);
-        worktree::init(cx, app_state.rpc.clone());
         editor::init(cx);
         file_finder::init(cx);
 
+        let app_state = Arc::new(app_state);
+        cx.set_menus(menus::menus(&app_state.clone()));
+
         if stdout_is_a_pty() {
             cx.platform().activate(true);
         }

zed/src/menus.rs πŸ”—

@@ -1,8 +1,9 @@
 use crate::AppState;
 use gpui::{Menu, MenuItem};
+use std::sync::Arc;
 
 #[cfg(target_os = "macos")]
-pub fn menus(state: AppState) -> Vec<Menu<'static>> {
+pub fn menus(state: &Arc<AppState>) -> Vec<Menu<'static>> {
     vec![
         Menu {
             name: "Zed",
@@ -48,7 +49,7 @@ pub fn menus(state: AppState) -> Vec<Menu<'static>> {
                     name: "Open…",
                     keystroke: Some("cmd-o"),
                     action: "workspace:open",
-                    arg: Some(Box::new(state)),
+                    arg: Some(Box::new(state.clone())),
                 },
             ],
         },

zed/src/rpc.rs πŸ”—

@@ -1,10 +1,8 @@
 use crate::{language::LanguageRegistry, worktree::Worktree};
 use anyhow::{anyhow, Context, Result};
 use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
-use gpui::executor::Background;
 use gpui::{AsyncAppContext, ModelHandle, Task, WeakModelHandle};
 use lazy_static::lazy_static;
-use postage::prelude::Stream;
 use smol::lock::RwLock;
 use std::collections::HashMap;
 use std::time::Duration;
@@ -13,7 +11,7 @@ use surf::Url;
 pub use zed_rpc::{proto, ConnectionId, PeerId, TypedEnvelope};
 use zed_rpc::{
     proto::{EnvelopedMessage, RequestMessage},
-    Peer, Receipt,
+    ForegroundRouter, Peer, Receipt,
 };
 
 lazy_static! {
@@ -63,24 +61,30 @@ impl Client {
         }
     }
 
-    pub fn on_message<H, M>(&self, handler: H, cx: &mut gpui::MutableAppContext)
-    where
-        H: 'static + for<'a> MessageHandler<'a, M>,
+    pub fn on_message<H, M>(
+        &self,
+        router: &mut ForegroundRouter,
+        handler: H,
+        cx: &mut gpui::MutableAppContext,
+    ) where
+        H: 'static + Clone + for<'a> MessageHandler<'a, M>,
         M: proto::EnvelopedMessage,
     {
         let this = self.clone();
-        let mut messages = smol::block_on(this.peer.add_message_handler::<M>());
-        cx.spawn(|mut cx| async move {
-            while let Some(message) = messages.recv().await {
-                if let Err(err) = handler.handle(message, &this, &mut cx).await {
-                    log::error!("error handling message: {:?}", err);
-                }
-            }
-        })
-        .detach();
+        let cx = cx.to_async();
+        router.add_message_handler(move |message| {
+            let this = this.clone();
+            let mut cx = cx.clone();
+            let handler = handler.clone();
+            async move { handler.handle(message, &this, &mut cx).await }
+        });
     }
 
-    pub async fn log_in_and_connect(&self, cx: &AsyncAppContext) -> surf::Result<()> {
+    pub async fn log_in_and_connect(
+        &self,
+        router: Arc<ForegroundRouter>,
+        cx: AsyncAppContext,
+    ) -> surf::Result<()> {
         if self.state.read().await.connection_id.is_some() {
             return Ok(());
         }
@@ -96,14 +100,14 @@ impl Client {
             .await
             .context("websocket handshake")?;
             log::info!("connected to rpc address {}", *ZED_SERVER_URL);
-            self.add_connection(stream, user_id, access_token, &cx.background())
+            self.add_connection(stream, user_id, access_token, router, cx)
                 .await?;
         } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
             let stream = smol::net::TcpStream::connect(host).await?;
             let (stream, _) =
                 async_tungstenite::client_async(format!("ws://{}/rpc", host), stream).await?;
             log::info!("connected to rpc address {}", *ZED_SERVER_URL);
-            self.add_connection(stream, user_id, access_token, &cx.background())
+            self.add_connection(stream, user_id, access_token, router, cx)
                 .await?;
         } else {
             return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
@@ -117,7 +121,8 @@ impl Client {
         conn: Conn,
         user_id: i32,
         access_token: String,
-        executor: &Arc<Background>,
+        router: Arc<ForegroundRouter>,
+        cx: AsyncAppContext,
     ) -> surf::Result<()>
     where
         Conn: 'static
@@ -126,10 +131,12 @@ impl Client {
             + Unpin
             + Send,
     {
-        let (connection_id, handler) = self.peer.add_connection(conn).await;
-        executor
+        let (connection_id, handle_io, handle_messages) =
+            self.peer.add_connection(conn, router).await;
+        cx.foreground().spawn(handle_messages).detach();
+        cx.background()
             .spawn(async move {
-                if let Err(error) = handler.run().await {
+                if let Err(error) = handle_io.run().await {
                     log::error!("connection error: {:?}", error);
                 }
             })
@@ -263,7 +270,7 @@ impl Client {
     }
 }
 
-pub trait MessageHandler<'a, M: proto::EnvelopedMessage> {
+pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
     type Output: 'a + Future<Output = anyhow::Result<()>>;
 
     fn handle(
@@ -277,7 +284,7 @@ pub trait MessageHandler<'a, M: proto::EnvelopedMessage> {
 impl<'a, M, F, Fut> MessageHandler<'a, M> for F
 where
     M: proto::EnvelopedMessage,
-    F: Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
+    F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
     Fut: 'a + Future<Output = anyhow::Result<()>>,
 {
     type Output = Fut;

zed/src/test.rs πŸ”—

@@ -5,6 +5,7 @@ use std::{
     sync::Arc,
 };
 use tempdir::TempDir;
+use zed_rpc::ForegroundRouter;
 
 #[cfg(feature = "test-support")]
 pub use zed_rpc::test::Channel;
@@ -143,12 +144,13 @@ fn write_tree(path: &Path, tree: serde_json::Value) {
     }
 }
 
-pub fn build_app_state(cx: &AppContext) -> AppState {
+pub fn build_app_state(cx: &AppContext) -> Arc<AppState> {
     let settings = settings::channel(&cx.font_cache()).unwrap().1;
     let languages = Arc::new(LanguageRegistry::new());
-    AppState {
+    Arc::new(AppState {
         settings,
         languages: languages.clone(),
+        rpc_router: Arc::new(ForegroundRouter::new()),
         rpc: rpc::Client::new(languages),
-    }
+    })
 }

zed/src/workspace.rs πŸ”—

@@ -45,10 +45,10 @@ pub fn init(cx: &mut MutableAppContext) {
 
 pub struct OpenParams {
     pub paths: Vec<PathBuf>,
-    pub app_state: AppState,
+    pub app_state: Arc<AppState>,
 }
 
-fn open(app_state: &AppState, cx: &mut MutableAppContext) {
+fn open(app_state: &Arc<AppState>, cx: &mut MutableAppContext) {
     let app_state = app_state.clone();
     cx.prompt_for_paths(
         PathPromptOptions {
@@ -101,7 +101,7 @@ fn open_paths(params: &OpenParams, cx: &mut MutableAppContext) {
     });
 }
 
-fn open_new(app_state: &AppState, cx: &mut MutableAppContext) {
+fn open_new(app_state: &Arc<AppState>, cx: &mut MutableAppContext) {
     cx.add_window(|cx| {
         let mut view = Workspace::new(
             app_state.settings.clone(),
@@ -700,12 +700,13 @@ impl Workspace {
         };
     }
 
-    fn share_worktree(&mut self, _: &(), cx: &mut ViewContext<Self>) {
+    fn share_worktree(&mut self, app_state: &Arc<AppState>, cx: &mut ViewContext<Self>) {
         let rpc = self.rpc.clone();
         let platform = cx.platform();
+        let router = app_state.rpc_router.clone();
 
         let task = cx.spawn(|this, mut cx| async move {
-            rpc.log_in_and_connect(&cx).await?;
+            rpc.log_in_and_connect(router, cx.clone()).await?;
 
             let share_task = this.update(&mut cx, |this, cx| {
                 let worktree = this.worktrees.iter().next()?;
@@ -732,12 +733,13 @@ impl Workspace {
         .detach();
     }
 
-    fn join_worktree(&mut self, _: &(), cx: &mut ViewContext<Self>) {
+    fn join_worktree(&mut self, app_state: &Arc<AppState>, cx: &mut ViewContext<Self>) {
         let rpc = self.rpc.clone();
         let languages = self.languages.clone();
+        let router = app_state.rpc_router.clone();
 
         let task = cx.spawn(|this, mut cx| async move {
-            rpc.log_in_and_connect(&cx).await?;
+            rpc.log_in_and_connect(router, cx.clone()).await?;
 
             let worktree_url = cx
                 .platform()
@@ -974,8 +976,12 @@ mod tests {
         let app_state = cx.read(build_app_state);
 
         let (_, workspace) = cx.add_window(|cx| {
-            let mut workspace =
-                Workspace::new(app_state.settings, app_state.languages, app_state.rpc, cx);
+            let mut workspace = Workspace::new(
+                app_state.settings.clone(),
+                app_state.languages.clone(),
+                app_state.rpc.clone(),
+                cx,
+            );
             workspace.add_worktree(dir.path(), cx);
             workspace
         });
@@ -1077,8 +1083,12 @@ mod tests {
 
         let app_state = cx.read(build_app_state);
         let (_, workspace) = cx.add_window(|cx| {
-            let mut workspace =
-                Workspace::new(app_state.settings, app_state.languages, app_state.rpc, cx);
+            let mut workspace = Workspace::new(
+                app_state.settings.clone(),
+                app_state.languages.clone(),
+                app_state.rpc.clone(),
+                cx,
+            );
             workspace.add_worktree(dir1.path(), cx);
             workspace
         });
@@ -1146,8 +1156,12 @@ mod tests {
 
         let app_state = cx.read(build_app_state);
         let (window_id, workspace) = cx.add_window(|cx| {
-            let mut workspace =
-                Workspace::new(app_state.settings, app_state.languages, app_state.rpc, cx);
+            let mut workspace = Workspace::new(
+                app_state.settings.clone(),
+                app_state.languages.clone(),
+                app_state.rpc.clone(),
+                cx,
+            );
             workspace.add_worktree(dir.path(), cx);
             workspace
         });
@@ -1315,8 +1329,12 @@ mod tests {
 
         let app_state = cx.read(build_app_state);
         let (window_id, workspace) = cx.add_window(|cx| {
-            let mut workspace =
-                Workspace::new(app_state.settings, app_state.languages, app_state.rpc, cx);
+            let mut workspace = Workspace::new(
+                app_state.settings.clone(),
+                app_state.languages.clone(),
+                app_state.rpc.clone(),
+                cx,
+            );
             workspace.add_worktree(dir.path(), cx);
             workspace
         });

zed/src/worktree.rs πŸ”—

@@ -48,21 +48,21 @@ use std::{
     },
     time::{Duration, SystemTime},
 };
-use zed_rpc::{PeerId, TypedEnvelope};
+use zed_rpc::{ForegroundRouter, PeerId, TypedEnvelope};
 
 lazy_static! {
     static ref GITIGNORE: &'static OsStr = OsStr::new(".gitignore");
 }
 
-pub fn init(cx: &mut MutableAppContext, rpc: rpc::Client) {
-    rpc.on_message(remote::add_peer, cx);
-    rpc.on_message(remote::remove_peer, cx);
-    rpc.on_message(remote::update_worktree, cx);
-    rpc.on_message(remote::open_buffer, cx);
-    rpc.on_message(remote::close_buffer, cx);
-    rpc.on_message(remote::update_buffer, cx);
-    rpc.on_message(remote::buffer_saved, cx);
-    rpc.on_message(remote::save_buffer, cx);
+pub fn init(cx: &mut MutableAppContext, rpc: &rpc::Client, router: &mut ForegroundRouter) {
+    rpc.on_message(router, remote::add_peer, cx);
+    rpc.on_message(router, remote::remove_peer, cx);
+    rpc.on_message(router, remote::update_worktree, cx);
+    rpc.on_message(router, remote::open_buffer, cx);
+    rpc.on_message(router, remote::close_buffer, cx);
+    rpc.on_message(router, remote::update_buffer, cx);
+    rpc.on_message(router, remote::buffer_saved, cx);
+    rpc.on_message(router, remote::save_buffer, cx);
 }
 
 #[async_trait::async_trait]
@@ -2861,6 +2861,8 @@ mod remote {
         rpc: &rpc::Client,
         cx: &mut AsyncAppContext,
     ) -> anyhow::Result<()> {
+        eprintln!("got update buffer message {:?}", envelope.payload);
+
         let message = envelope.payload;
         rpc.state
             .read()
@@ -2875,6 +2877,8 @@ mod remote {
         rpc: &rpc::Client,
         cx: &mut AsyncAppContext,
     ) -> anyhow::Result<()> {
+        eprintln!("got save buffer message {:?}", envelope.payload);
+
         let state = rpc.state.read().await;
         let worktree = state.shared_worktree(envelope.payload.worktree_id, cx)?;
         let sender_id = envelope.original_sender_id()?;
@@ -2905,6 +2909,8 @@ mod remote {
         rpc: &rpc::Client,
         cx: &mut AsyncAppContext,
     ) -> anyhow::Result<()> {
+        eprintln!("got buffer_saved {:?}", envelope.payload);
+
         rpc.state
             .read()
             .await
@@ -2993,7 +2999,7 @@ mod tests {
         let dir = temp_tree(json!({
             "file1": "the old contents",
         }));
-        let tree = cx.add_model(|cx| Worktree::local(dir.path(), app_state.languages, cx));
+        let tree = cx.add_model(|cx| Worktree::local(dir.path(), app_state.languages.clone(), cx));
         let buffer = tree
             .update(&mut cx, |tree, cx| tree.open_buffer("file1", cx))
             .await
@@ -3016,7 +3022,8 @@ mod tests {
         }));
         let file_path = dir.path().join("file1");
 
-        let tree = cx.add_model(|cx| Worktree::local(file_path.clone(), app_state.languages, cx));
+        let tree =
+            cx.add_model(|cx| Worktree::local(file_path.clone(), app_state.languages.clone(), cx));
         cx.read(|cx| tree.read(cx).as_local().unwrap().scan_complete())
             .await;
         cx.read(|cx| assert_eq!(tree.read(cx).file_count(), 1));