Start work on RpcClient

Max Brunsfeld created

Change summary

docs/collaboration.md |  6 ++
zed-rpc/src/proto.rs  | 75 +++++++++++++++++++++++++++++-----
zed/src/lib.rs        | 58 +++++++++++++++++++-------
zed/src/rpc_client.rs | 97 +++++++++++++++++++++++++++++++++++++++++++++
4 files changed, 207 insertions(+), 29 deletions(-)

Detailed changes

docs/collaboration.md 🔗

@@ -59,6 +59,12 @@ Any resource you can subscribe to is considered a *channel*, and all of its proc
 
 The client will interact with the server via a `api::Client` object. Model objects with remote behavior will interact directly with this client to communicate with the server. For example, `Worktree` will be changed to an enum type with `Local` and `Remote` variants. The local variant will have an optional `client` in order to stream local changes to the server when sharing. The remote variant will always have a client and implement all worktree operations in terms of it.
 
+```rs
+let mut client = Client::new(conn, cx.background_executor());
+let stream = client.subscribe(from_client::Variant::Auth(from_client::));
+client.close();
+```
+
 ```rs
 enum Worktree {
     Local {

zed-rpc/src/proto.rs 🔗

@@ -5,24 +5,72 @@ use std::io;
 
 include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 
-use from_client as client;
-use from_server as server;
+/// A message that the client can send to the server.
+pub trait ClientMessage: Sized {
+    fn to_variant(self) -> from_client::Variant;
+    fn from_variant(variant: from_client::Variant) -> Option<Self>;
+}
+
+/// A message that the server can send to the client.
+pub trait ServerMessage: Sized {
+    fn to_variant(self) -> from_server::Variant;
+    fn from_variant(variant: from_server::Variant) -> Option<Self>;
+}
+
+/// A message that the client can send to the server, where the server must respond with a single
+/// message of a certain type.
+pub trait RequestMessage: ClientMessage {
+    type Response: ServerMessage;
+}
+
+/// A message that the client can send to the server, where the server must respond with a series of
+/// messages of a certain type.
+pub trait SubscribeMessage: ClientMessage {
+    type Event: ServerMessage;
+}
+
+/// A message that the client can send to the server, where the server will not respond.
+pub trait SendMessage: ClientMessage {}
+
+macro_rules! directed_message {
+    ($name:ident, $direction_trait:ident, $direction_module:ident) => {
+        impl $direction_trait for $direction_module::$name {
+            fn to_variant(self) -> $direction_module::Variant {
+                $direction_module::Variant::$name(self)
+            }
 
-pub trait Request {
-    type Response;
+            fn from_variant(variant: $direction_module::Variant) -> Option<Self> {
+                if let $direction_module::Variant::$name(msg) = variant {
+                    Some(msg)
+                } else {
+                    None
+                }
+            }
+        }
+    };
 }
 
-macro_rules! request_response {
-    ($req:path, $resp:path) => {
-        impl Request for $req {
-            type Response = $resp;
+macro_rules! request_message {
+    ($req:ident, $resp:ident) => {
+        directed_message!($req, ClientMessage, from_client);
+        directed_message!($resp, ServerMessage, from_server);
+        impl RequestMessage for from_client::$req {
+            type Response = from_server::$resp;
         }
     };
 }
 
-request_response!(client::Auth, server::AuthResponse);
-request_response!(client::NewWorktree, server::NewWorktreeResponse);
-request_response!(client::ShareWorktree, server::ShareWorktreeResponse);
+macro_rules! send_message {
+    ($msg:ident) => {
+        directed_message!($msg, ClientMessage, from_client);
+        impl SendMessage for from_client::$msg {}
+    };
+}
+
+request_message!(Auth, AuthResponse);
+request_message!(NewWorktree, NewWorktreeResponse);
+request_message!(ShareWorktree, ShareWorktreeResponse);
+send_message!(UploadFile);
 
 /// A stream of protobuf messages.
 pub struct MessageStream<T> {
@@ -37,6 +85,10 @@ impl<T> MessageStream<T> {
             buffer: Default::default(),
         }
     }
+
+    pub fn inner_mut(&mut self) -> &mut T {
+        &mut self.byte_stream
+    }
 }
 
 impl<T> MessageStream<T>
@@ -59,7 +111,6 @@ where
     pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
         // Ensure the buffer is large enough to hold the maximum delimiter length
         const MAX_DELIMITER_LEN: usize = 10;
-        self.buffer.clear();
         self.buffer.resize(MAX_DELIMITER_LEN, 0);
 
         // Read until a complete length delimiter can be decoded.

zed/src/lib.rs 🔗

@@ -1,9 +1,11 @@
 use anyhow::{anyhow, Context, Result};
 use gpui::{AsyncAppContext, MutableAppContext, Task};
-use std::{convert::TryFrom, time::Duration};
+use rpc_client::RpcClient;
+use std::{convert::TryFrom, net::Shutdown, time::Duration};
 use tiny_http::{Header, Response, Server};
 use url::Url;
 use util::SurfResultExt;
+use zed_rpc::{proto, rest::CreateWorktreeResponse};
 
 pub mod assets;
 pub mod editor;
@@ -11,6 +13,7 @@ pub mod file_finder;
 pub mod language;
 pub mod menus;
 mod operation_queue;
+mod rpc_client;
 pub mod settings;
 mod sum_tree;
 #[cfg(test)]
@@ -33,7 +36,9 @@ pub fn init(cx: &mut MutableAppContext) {
 
 fn share_worktree(_: &(), cx: &mut MutableAppContext) {
     let zed_url = std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev".to_string());
-    cx.spawn::<_, _, surf::Result<()>>(|cx| async move {
+    let executor = cx.background_executor().clone();
+
+    let task = cx.spawn::<_, _, surf::Result<()>>(|cx| async move {
         let (user_id, access_token) = login(zed_url.clone(), &cx).await?;
 
         let mut response = surf::post(format!("{}/api/worktrees", &zed_url))
@@ -44,28 +49,47 @@ fn share_worktree(_: &(), cx: &mut MutableAppContext) {
             .await
             .context("")?;
 
-        let body = response
-            .body_json::<zed_rpc::rest::CreateWorktreeResponse>()
-            .await?;
+        let CreateWorktreeResponse {
+            worktree_id,
+            rpc_address,
+        } = response.body_json().await?;
+
+        eprintln!("got worktree response: {:?} {:?}", worktree_id, rpc_address);
 
         // TODO - If the `ZED_SERVER_URL` uses https, then wrap this stream in
         // a TLS stream using `native-tls`.
-        let stream = smol::net::TcpStream::connect(body.rpc_address).await?;
-
-        let mut message_stream = zed_rpc::proto::MessageStream::new(stream);
-        message_stream
-            .write_message(&zed_rpc::proto::FromClient {
-                id: 0,
-                variant: Some(zed_rpc::proto::from_client::Variant::Auth(
-                    zed_rpc::proto::from_client::Auth {
-                        user_id: user_id.parse::<i32>()?,
-                        access_token,
-                    },
-                )),
+        let stream = smol::net::TcpStream::connect(rpc_address).await?;
+
+        let mut rpc_client = RpcClient::new(stream, executor, |stream| {
+            stream.shutdown(Shutdown::Read).ok();
+        });
+
+        let auth_response = rpc_client
+            .request(proto::from_client::Auth {
+                user_id: user_id.parse::<i32>()?,
+                access_token,
+            })
+            .await?;
+        if !auth_response.credentials_valid {
+            Err(anyhow!("failed to authenticate with RPC server"))?;
+        }
+
+        let share_response = rpc_client
+            .request(proto::from_client::ShareWorktree {
+                worktree_id: worktree_id as u64,
+                files: Vec::new(),
             })
             .await?;
 
+        log::info!("sharing worktree {:?}", share_response);
+
         Ok(())
+    });
+
+    cx.spawn(|_| async move {
+        if let Err(e) = task.await {
+            log::error!("sharing failed: {}", e);
+        }
     })
     .detach();
 }

zed/src/rpc_client.rs 🔗

@@ -0,0 +1,97 @@
+use anyhow::{anyhow, Result};
+use gpui::executor::Background;
+use parking_lot::Mutex;
+use postage::{
+    oneshot,
+    prelude::{Sink, Stream},
+};
+use smol::prelude::{AsyncRead, AsyncWrite};
+use std::{collections::HashMap, sync::Arc};
+use zed_rpc::proto::{self, MessageStream, RequestMessage, SendMessage, ServerMessage};
+
+pub struct RpcClient<Conn, ShutdownFn>
+where
+    ShutdownFn: FnMut(&mut Conn),
+{
+    stream: MessageStream<Conn>,
+    response_channels: Arc<Mutex<HashMap<i32, oneshot::Sender<proto::from_server::Variant>>>>,
+    next_message_id: i32,
+    shutdown_fn: ShutdownFn,
+}
+
+impl<Conn, ShutdownFn> RpcClient<Conn, ShutdownFn>
+where
+    Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
+    ShutdownFn: FnMut(&mut Conn),
+{
+    pub fn new(conn: Conn, executor: Arc<Background>, shutdown_fn: ShutdownFn) -> Self {
+        let response_channels = Arc::new(Mutex::new(HashMap::new()));
+
+        let result = Self {
+            next_message_id: 0,
+            stream: MessageStream::new(conn.clone()),
+            response_channels: response_channels.clone(),
+            shutdown_fn,
+        };
+
+        executor
+            .spawn::<Result<()>, _>(async move {
+                let mut stream = MessageStream::new(conn);
+                loop {
+                    let message = stream.read_message::<proto::FromServer>().await?;
+                    if let Some(variant) = message.variant {
+                        if let Some(request_id) = message.request_id {
+                            let tx = response_channels.lock().remove(&request_id);
+                            if let Some(mut tx) = tx {
+                                tx.send(variant).await?;
+                            } else {
+                                log::warn!(
+                                    "received RPC response to unknown request id {}",
+                                    request_id
+                                );
+                            }
+                        }
+                    } else {
+                        log::warn!("received RPC message with no content");
+                    }
+                }
+            })
+            .detach();
+
+        result
+    }
+
+    pub async fn request<T: RequestMessage>(&mut self, req: T) -> Result<T::Response> {
+        let message_id = self.next_message_id;
+        self.next_message_id += 1;
+
+        let (tx, mut rx) = oneshot::channel();
+        self.response_channels.lock().insert(message_id, tx);
+
+        self.stream
+            .write_message(&proto::FromClient {
+                id: message_id,
+                variant: Some(req.to_variant()),
+            })
+            .await?;
+        let response = rx
+            .recv()
+            .await
+            .expect("response channel was unexpectedly dropped");
+        T::Response::from_variant(response)
+            .ok_or_else(|| anyhow!("received response of the wrong t"))
+    }
+
+    pub async fn send<T: SendMessage>(_: T) -> Result<()> {
+        todo!()
+    }
+}
+
+impl<Conn, ShutdownFn> Drop for RpcClient<Conn, ShutdownFn>
+where
+    ShutdownFn: FnMut(&mut Conn),
+{
+    fn drop(&mut self) {
+        (self.shutdown_fn)(self.stream.inner_mut())
+    }
+}