WIP - Allow RpcClient to register handlers for incoming messages

Max Brunsfeld created

Change summary

gpui/src/app.rs       |  10 ++
zed-rpc/src/proto.rs  |   5 +
zed/src/lib.rs        |  30 ++++++++
zed/src/main.rs       |   9 +
zed/src/rpc_client.rs | 150 +++++++++++++++++++++++++++++++++++++++------
zed/src/workspace.rs  |  15 +++-
6 files changed, 190 insertions(+), 29 deletions(-)

Detailed changes

gpui/src/app.rs 🔗

@@ -1340,10 +1340,18 @@ impl MutableAppContext {
         Fut: 'static + Future<Output = T>,
         T: 'static,
     {
-        let cx = AsyncAppContext(self.weak_self.as_ref().unwrap().upgrade().unwrap());
+        let cx = self.to_async();
         self.foreground.spawn(f(cx))
     }
 
+    pub fn to_async(&self) -> AsyncAppContext {
+        AsyncAppContext(self.weak_self.as_ref().unwrap().upgrade().unwrap())
+    }
+
+    pub fn to_background(&self) -> BackgroundAppContext {
+        //
+    }
+
     pub fn write_to_clipboard(&self, item: ClipboardItem) {
         self.platform.write_to_clipboard(item);
     }

zed-rpc/src/proto.rs 🔗

@@ -7,6 +7,7 @@ include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 
 pub trait EnvelopedMessage: Sized + Send + 'static {
     fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope;
+    fn matches_envelope(envelope: &Envelope) -> bool;
     fn from_envelope(envelope: Envelope) -> Option<Self>;
 }
 
@@ -25,6 +26,10 @@ macro_rules! message {
                 }
             }
 
+            fn matches_envelope(envelope: &Envelope) -> bool {
+                matches!(&envelope.payload, Some(envelope::Payload::$name(_)))
+            }
+
             fn from_envelope(envelope: Envelope) -> Option<Self> {
                 if let Some(envelope::Payload::$name(msg)) = envelope.payload {
                     Some(msg)

zed/src/lib.rs 🔗

@@ -1,10 +1,16 @@
+use futures::Future;
+use gpui::MutableAppContext;
+use rpc_client::RpcClient;
+use std::sync::Arc;
+use zed_rpc::proto::RequestMessage;
+
 pub mod assets;
 pub mod editor;
 pub mod file_finder;
 pub mod language;
 pub mod menus;
 mod operation_queue;
-mod rpc_client;
+pub mod rpc_client;
 pub mod settings;
 mod sum_tree;
 #[cfg(test)]
@@ -18,6 +24,28 @@ mod worktree;
 pub struct AppState {
     pub settings: postage::watch::Receiver<settings::Settings>,
     pub language_registry: std::sync::Arc<language::LanguageRegistry>,
+    pub rpc_client: Arc<RpcClient>,
+}
+
+impl AppState {
+    pub async fn on_rpc_request<Req, F, Fut>(
+        &self,
+        cx: &mut MutableAppContext,
+        handler: F,
+    ) where
+        Req: RequestMessage,
+        F: 'static + Send + Sync + Fn(Req, &AppState, &mut MutableAppContext) -> Fut,
+        Fut: 'static + Send + Sync + Future<Output = Req::Response>,
+    {
+        let app_state = self.clone();
+        let cx = cx.to_background();
+        app_state
+            .rpc_client
+            .on_request(move |req| cx.update(|cx| async move {
+                handler(req, &app_state, cx)
+            })
+            .await
+    }
 }
 
 pub fn init(cx: &mut gpui::MutableAppContext) {

zed/src/main.rs 🔗

@@ -6,7 +6,9 @@ use log::LevelFilter;
 use simplelog::SimpleLogger;
 use std::{fs, path::PathBuf, sync::Arc};
 use zed::{
-    self, assets, editor, file_finder, language, menus, settings,
+    self, assets, editor, file_finder, language, menus,
+    rpc_client::RpcClient,
+    settings,
     workspace::{self, OpenParams},
     AppState,
 };
@@ -19,16 +21,17 @@ fn main() {
     let (_, settings) = settings::channel(&app.font_cache()).unwrap();
     let language_registry = Arc::new(language::LanguageRegistry::new());
     language_registry.set_theme(&settings.borrow().theme);
+
     let app_state = AppState {
         language_registry,
         settings,
+        rpc_client: Arc::new(RpcClient::new()),
     };
 
     app.run(move |cx| {
         cx.set_menus(menus::menus(app_state.clone()));
-
         zed::init(cx);
-        workspace::init(cx);
+        workspace::init(cx, &app_state);
         editor::init(cx);
         file_finder::init(cx);
 

zed/src/rpc_client.rs 🔗

@@ -1,5 +1,5 @@
 use anyhow::{anyhow, Result};
-use futures::future::Either;
+use futures::future::{BoxFuture, Either, FutureExt};
 use postage::{
     barrier, oneshot,
     prelude::{Sink, Stream},
@@ -10,7 +10,8 @@ use smol::{
     prelude::{AsyncRead, AsyncWrite},
 };
 use std::{
-    collections::HashMap,
+    any::TypeId,
+    collections::{HashMap, HashSet},
     future::Future,
     sync::{
         atomic::{self, AtomicU32},
@@ -29,20 +30,93 @@ struct RpcConnection {
     _close_barrier: barrier::Sender,
 }
 
+type RequestHandler = Box<
+    dyn Send
+        + Sync
+        + Fn(&mut Option<proto::Envelope>, &AtomicU32) -> Option<BoxFuture<'static, proto::Envelope>>,
+>;
+type MessageHandler =
+    Box<dyn Send + Sync + Fn(&mut Option<proto::Envelope>) -> Option<BoxFuture<'static, ()>>>;
+
 pub struct RpcClient {
-    connections: Arc<RwLock<HashMap<ConnectionId, Arc<RpcConnection>>>>,
+    connections: RwLock<HashMap<ConnectionId, Arc<RpcConnection>>>,
+    request_handlers: RwLock<Vec<RequestHandler>>,
+    message_handlers: RwLock<Vec<MessageHandler>>,
+    handler_types: RwLock<HashSet<TypeId>>,
     next_connection_id: AtomicU32,
 }
 
 impl RpcClient {
-    pub fn new() -> Arc<Self> {
-        Arc::new(Self {
+    pub fn new() -> Self {
+        Self {
+            request_handlers: Default::default(),
+            message_handlers: Default::default(),
+            handler_types: Default::default(),
             connections: Default::default(),
             next_connection_id: Default::default(),
-        })
+        }
     }
 
-    pub async fn add_connection<Conn>(&self, conn: Conn) -> (ConnectionId, impl Future<Output = ()>)
+    pub async fn on_request<Req, F, Fut>(&self, handler: F)
+    where
+        Req: RequestMessage,
+        F: 'static + Send + Sync + Fn(Req) -> Fut,
+        Fut: 'static + Send + Sync + Future<Output = Req::Response>,
+    {
+        if !self.handler_types.write().await.insert(TypeId::of::<Req>()) {
+            panic!("duplicate request handler type");
+        }
+
+        self.request_handlers
+            .write()
+            .await
+            .push(Box::new(move |envelope, next_message_id| {
+                if envelope.as_ref().map_or(false, Req::matches_envelope) {
+                    let envelope = Option::take(envelope).unwrap();
+                    let message_id = next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
+                    let responding_to = envelope.id;
+                    let request = Req::from_envelope(envelope).unwrap();
+                    Some(
+                        handler(request)
+                            .map(move |response| {
+                                response.into_envelope(message_id, Some(responding_to))
+                            })
+                            .boxed(),
+                    )
+                } else {
+                    None
+                }
+            }));
+    }
+
+    pub async fn on_message<M, F, Fut>(&self, handler: F)
+    where
+        M: EnvelopedMessage,
+        F: 'static + Send + Sync + Fn(M) -> Fut,
+        Fut: 'static + Send + Sync + Future<Output = ()>,
+    {
+        if !self.handler_types.write().await.insert(TypeId::of::<M>()) {
+            panic!("duplicate request handler type");
+        }
+
+        self.message_handlers
+            .write()
+            .await
+            .push(Box::new(move |envelope| {
+                if envelope.as_ref().map_or(false, M::matches_envelope) {
+                    let envelope = Option::take(envelope).unwrap();
+                    let request = M::from_envelope(envelope).unwrap();
+                    Some(handler(request).boxed())
+                } else {
+                    None
+                }
+            }));
+    }
+
+    pub async fn add_connection<Conn>(
+        self: &Arc<Self>,
+        conn: Conn,
+    ) -> (ConnectionId, impl Future<Output = ()>)
     where
         Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
     {
@@ -52,7 +126,6 @@ impl RpcClient {
         );
         let (close_tx, mut close_rx) = barrier::channel();
         let (conn_rx, conn_tx) = smol::io::split(conn);
-        let connections = self.connections.clone();
         let connection = Arc::new(RpcConnection {
             writer: Mutex::new(MessageStream::new(Box::pin(conn_tx))),
             response_channels: Default::default(),
@@ -60,11 +133,12 @@ impl RpcClient {
             _close_barrier: close_tx,
         });
 
-        connections
+        self.connections
             .write()
             .await
             .insert(connection_id, connection.clone());
 
+        let this = self.clone();
         let handler_future = async move {
             let closed = close_rx.recv();
             smol::pin!(closed);
@@ -91,11 +165,45 @@ impl RpcClient {
                                 );
                             }
                         } else {
-                            // unprompted message from server
+                            let mut handled = false;
+                            let mut envelope = Some(incoming);
+                            for handler in this.request_handlers.iter() {
+                                if let Some(future) =
+                                    handler(&mut envelope, &connection.next_message_id)
+                                {
+                                    let response = future.await;
+                                    if let Err(error) = connection
+                                        .writer
+                                        .lock()
+                                        .await
+                                        .write_message(&response)
+                                        .await
+                                    {
+                                        log::warn!("failed to write response: {}", error);
+                                        return;
+                                    }
+                                    handled = true;
+                                    break;
+                                }
+                            }
+
+                            if !handled {
+                                for handler in this.message_handlers.iter() {
+                                    if let Some(future) = handler(&mut envelope) {
+                                        future.await;
+                                        handled = true;
+                                        break;
+                                    }
+                                }
+                            }
+
+                            if !handled {
+                                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
+                            }
                         }
                     }
                     Either::Left((Err(error), _)) => {
-                        log::warn!("received invalid RPC message {:?}", error);
+                        log::warn!("received invalid RPC message: {}", error);
                     }
                     Either::Right(_) => break,
                 }
@@ -110,14 +218,15 @@ impl RpcClient {
     }
 
     pub fn request<T: RequestMessage>(
-        &self,
+        self: &Arc<Self>,
         connection_id: ConnectionId,
         req: T,
     ) -> impl Future<Output = Result<T::Response>> {
-        let connections = self.connections.clone();
+        let this = self.clone();
         let (tx, mut rx) = oneshot::channel();
         async move {
-            let connection = connections
+            let connection = this
+                .connections
                 .read()
                 .await
                 .get(&connection_id)
@@ -147,13 +256,14 @@ impl RpcClient {
     }
 
     pub fn send<T: EnvelopedMessage>(
-        &self,
+        self: &Arc<Self>,
         connection_id: ConnectionId,
         message: T,
     ) -> impl Future<Output = Result<()>> {
-        let connections = self.connections.clone();
+        let this = self.clone();
         async move {
-            let connection = connections
+            let connection = this
+                .connections
                 .read()
                 .await
                 .get(&connection_id)
@@ -194,7 +304,7 @@ mod tests {
         let (server_conn, _) = listener.accept().await.unwrap();
 
         let mut server_stream = MessageStream::new(server_conn);
-        let client = RpcClient::new();
+        let client = Arc::new(RpcClient::new());
         let (connection_id, handler) = client.add_connection(client_conn).await;
         executor.spawn(handler).detach();
 
@@ -253,7 +363,7 @@ mod tests {
         let client_conn = UnixStream::connect(&socket_path).await.unwrap();
         let (mut server_conn, _) = listener.accept().await.unwrap();
 
-        let client = RpcClient::new();
+        let client = Arc::new(RpcClient::new());
         let (connection_id, handler) = client.add_connection(client_conn).await;
         executor.spawn(handler).detach();
         client.disconnect(connection_id).await;
@@ -280,7 +390,7 @@ mod tests {
         let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
         client_conn.close().await.unwrap();
 
-        let client = RpcClient::new();
+        let client = Arc::new(RpcClient::new());
         let (connection_id, handler) = client.add_connection(client_conn).await;
         executor.spawn(handler).detach();
         let err = client

zed/src/workspace.rs 🔗

@@ -14,8 +14,8 @@ use crate::{
 use anyhow::{anyhow, Context as _};
 use gpui::{
     color::rgbu, elements::*, json::to_string_pretty, keymap::Binding, AnyViewHandle, AppContext,
-    ClipboardItem, Entity, ModelHandle, MutableAppContext, PathPromptOptions, PromptLevel, Task,
-    View, ViewContext, ViewHandle, WeakModelHandle,
+    AsyncAppContext, ClipboardItem, Entity, ModelHandle, MutableAppContext, PathPromptOptions,
+    PromptLevel, Task, View, ViewContext, ViewHandle, WeakModelHandle,
 };
 use log::error;
 pub use pane::*;
@@ -33,7 +33,7 @@ use std::{
 use surf::Url;
 use zed_rpc::{proto, rest::CreateWorktreeResponse};
 
-pub fn init(cx: &mut MutableAppContext) {
+pub fn init(cx: &mut MutableAppContext, rpc_client: &mut RpcClient) {
     cx.add_global_action("workspace:open", open);
     cx.add_global_action("workspace:open_paths", open_paths);
     cx.add_action("workspace:save", Workspace::save_active_item);
@@ -45,6 +45,9 @@ pub fn init(cx: &mut MutableAppContext) {
         Binding::new("cmd-alt-i", "workspace:debug_elements", None),
     ]);
     pane::init(cx);
+
+    let cx = cx.to_async();
+    rpc_client.on_request(move |req| handle_open_buffer(req, cx));
 }
 
 pub struct OpenParams {
@@ -105,6 +108,10 @@ fn open_paths(params: &OpenParams, cx: &mut MutableAppContext) {
     });
 }
 
+fn handle_open_buffer(request: zed_rpc::proto::OpenBuffer, cx: AsyncAppContext) {
+    //
+}
+
 pub trait Item: Entity + Sized {
     type View: ItemView;
 
@@ -670,7 +677,7 @@ impl Workspace {
             // a TLS stream using `native-tls`.
             let stream = smol::net::TcpStream::connect(rpc_address).await?;
 
-            let rpc_client = RpcClient::new();
+            let rpc_client = Arc::new(RpcClient::new());
             let (connection_id, handler) = rpc_client.add_connection(stream).await;
             executor.spawn(handler).detach();