Flatten protobuf message namespace

Max Brunsfeld and Nathan Sobo created

* Remove `FromClient`/`FromServer` distinction.
* Remove `subscribe` concept - clients will need to handle
  unprompted messages from the server.

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

zed-rpc/proto/zed.proto | 105 +++++++++++--------
zed-rpc/src/proto.rs    | 116 +++++++--------------
zed/src/rpc_client.rs   | 231 +++++++++---------------------------------
zed/src/workspace.rs    |   4 
zed/src/worktree.rs     |  14 -
5 files changed, 158 insertions(+), 312 deletions(-)

Detailed changes

zed-rpc/proto/zed.proto 🔗

@@ -1,65 +1,82 @@
 syntax = "proto3";
 package zed.messages;
 
-message FromClient {
-    int32 id = 1;
-
-    oneof variant {
-        Auth auth = 2;
-        NewWorktree new_worktree = 3;
-        ShareWorktree share_worktree = 4;
-        UploadFile upload_file = 5;
-        SubscribeToPathRequests subscribe_to_path_requests = 6;
+message Envelope {
+    uint32 id = 1;
+    optional uint32 responding_to = 2;
+    oneof payload {
+        Auth auth = 3;
+        AuthResponse auth_response = 4;
+        ShareWorktree share_worktree = 5;
+        ShareWorktreeResponse share_worktree_response = 6;
+        OpenWorktree open_worktree = 7;
+        OpenWorktreeResponse open_worktree_response = 8;
+        OpenBuffer open_buffer = 9;
+        OpenBufferResponse open_buffer_response = 10;
     }
+}
 
-    message Auth {
-        int32 user_id = 1;
-        string access_token = 2;
-    }
+message Auth {
+    uint64 user_id = 1;
+    string access_token = 2;
+}
 
-    message NewWorktree {}
+message AuthResponse {
+    bool credentials_valid = 1;
+}
 
-    message ShareWorktree {
-        uint64 worktree_id = 1;
-        repeated PathAndDigest files = 2;
-    }
+message ShareWorktree {
+    Worktree worktree = 1;
+}
 
-    message PathAndDigest {
-        bytes path = 1;
-        bytes digest = 2;
-    }
+message ShareWorktreeResponse {
+    uint64 worktree_id = 1;
+    string access_token = 2;
+}
 
-    message UploadFile {
-        bytes path = 1;
-        bytes content = 2;
-    }
+message OpenWorktree {
+    uint64 worktree_id = 1;
+    string access_token = 2;
+}
 
-    message SubscribeToPathRequests {}
+message OpenWorktreeResponse {
+    Worktree worktree = 1;
 }
 
-message FromServer {
-    optional int32 request_id = 1;
+message OpenBuffer {
+    uint64 worktree_id = 1;
+    bytes path = 2;
+}
 
-    oneof variant {
-        AuthResponse auth_response = 2;
-        NewWorktreeResponse new_worktree_response = 3;
-        ShareWorktreeResponse share_worktree_response = 4;
-        PathRequest path_request = 5;
-    }
+message OpenBufferResponse {
+    Buffer buffer = 1;
+}
 
-    message AuthResponse {
-        bool credentials_valid = 1;
-    }
+message Worktree {
+    repeated bytes paths = 1;
+}
 
-    message NewWorktreeResponse {
-        uint64 worktree_id = 1;
+message Buffer {
+    uint64 id = 1;
+    bytes path = 2;
+    bytes content = 3;
+    repeated Operation history = 4;
+}
+
+message Operation {
+    oneof variant {
+        Edit edit = 1;
     }
 
-    message ShareWorktreeResponse {
-        repeated int32 needed_file_indices = 1;
+    message Edit {
+        uint32 replica_id = 1;
+        uint32 local_timestamp = 2;
+        uint32 lamport_timestamp = 3;
+        repeated VectorClockEntry version = 4;
     }
 
-    message PathRequest {
-        bytes path = 1;
+    message VectorClockEntry {
+        uint32 replica_id = 1;
+        uint32 timestamp = 2;
     }
 }

zed-rpc/src/proto.rs 🔗

@@ -5,42 +5,28 @@ use std::{convert::TryInto, io};
 
 include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
 
-/// A message that the client can send to the server.
-pub trait ClientMessage: Sized {
-    fn to_variant(self) -> from_client::Variant;
-    fn from_variant(variant: from_client::Variant) -> Option<Self>;
+pub trait EnvelopedMessage: Sized {
+    fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope;
+    fn from_envelope(envelope: Envelope) -> Option<Self>;
 }
 
-/// A message that the server can send to the client.
-pub trait ServerMessage: Sized {
-    fn to_variant(self) -> from_server::Variant;
-    fn from_variant(variant: from_server::Variant) -> Option<Self>;
+pub trait RequestMessage: EnvelopedMessage {
+    type Response: EnvelopedMessage;
 }
 
-/// A message that the client can send to the server, where the server must respond with a single
-/// message of a certain type.
-pub trait RequestMessage: ClientMessage {
-    type Response: ServerMessage;
-}
-
-/// A message that the client can send to the server, where the server must respond with a series of
-/// messages of a certain type.
-pub trait SubscribeMessage: ClientMessage {
-    type Event: ServerMessage;
-}
-
-/// A message that the client can send to the server, where the server will not respond.
-pub trait SendMessage: ClientMessage {}
-
-macro_rules! directed_message {
-    ($name:ident, $direction_trait:ident, $direction_module:ident) => {
-        impl $direction_trait for $direction_module::$name {
-            fn to_variant(self) -> $direction_module::Variant {
-                $direction_module::Variant::$name(self)
+macro_rules! message {
+    ($name:ident) => {
+        impl EnvelopedMessage for $name {
+            fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope {
+                Envelope {
+                    id,
+                    responding_to,
+                    payload: Some(envelope::Payload::$name(self)),
+                }
             }
 
-            fn from_variant(variant: $direction_module::Variant) -> Option<Self> {
-                if let $direction_module::Variant::$name(msg) = variant {
+            fn from_envelope(envelope: Envelope) -> Option<Self> {
+                if let Some(envelope::Payload::$name(msg)) = envelope.payload {
                     Some(msg)
                 } else {
                     None
@@ -52,36 +38,18 @@ macro_rules! directed_message {
 
 macro_rules! request_message {
     ($req:ident, $resp:ident) => {
-        directed_message!($req, ClientMessage, from_client);
-        directed_message!($resp, ServerMessage, from_server);
-        impl RequestMessage for from_client::$req {
-            type Response = from_server::$resp;
-        }
-    };
-}
-
-macro_rules! send_message {
-    ($msg:ident) => {
-        directed_message!($msg, ClientMessage, from_client);
-        impl SendMessage for from_client::$msg {}
-    };
-}
-
-macro_rules! subscribe_message {
-    ($subscription:ident, $event:ident) => {
-        directed_message!($subscription, ClientMessage, from_client);
-        directed_message!($event, ServerMessage, from_server);
-        impl SubscribeMessage for from_client::$subscription {
-            type Event = from_server::$event;
+        message!($req);
+        message!($resp);
+        impl RequestMessage for $req {
+            type Response = $resp;
         }
     };
 }
 
 request_message!(Auth, AuthResponse);
-request_message!(NewWorktree, NewWorktreeResponse);
 request_message!(ShareWorktree, ShareWorktreeResponse);
-send_message!(UploadFile);
-subscribe_message!(SubscribeToPathRequests, PathRequest);
+request_message!(OpenWorktree, OpenWorktreeResponse);
+request_message!(OpenBuffer, OpenBufferResponse);
 
 /// A stream of protobuf messages.
 pub struct MessageStream<T> {
@@ -107,7 +75,7 @@ where
     T: AsyncWrite + Unpin,
 {
     /// Write a given protobuf message to the stream.
-    pub async fn write_message(&mut self, message: &impl Message) -> io::Result<()> {
+    pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
         let message_len: u32 = message
             .encoded_len()
             .try_into()
@@ -124,13 +92,13 @@ where
     T: AsyncRead + Unpin,
 {
     /// Read a protobuf message of the given type from the stream.
-    pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
+    pub async fn read_message(&mut self) -> futures_io::Result<Envelope> {
         let mut delimiter_buf = [0; 4];
         self.byte_stream.read_exact(&mut delimiter_buf).await?;
         let message_len = u32::from_be_bytes(delimiter_buf) as usize;
         self.buffer.resize(message_len, 0);
         self.byte_stream.read_exact(&mut self.buffer).await?;
-        Ok(M::decode(self.buffer.as_slice())?)
+        Ok(Envelope::decode(self.buffer.as_slice())?)
     }
 }
 
@@ -151,30 +119,24 @@ mod tests {
                 chunk_size: 3,
             };
 
-            let message1 = FromClient {
-                id: 3,
-                variant: Some(from_client::Variant::Auth(from_client::Auth {
-                    user_id: 5,
-                    access_token: "the-access-token".into(),
-                })),
-            };
-            let message2 = FromClient {
-                id: 4,
-                variant: Some(from_client::Variant::UploadFile(from_client::UploadFile {
-                    path: Vec::new(),
-                    content: format!(
-                        "a {}long error message that requires a two-byte length delimiter",
-                        "very ".repeat(60)
-                    )
-                    .into(),
-                })),
-            };
+            let message1 = Auth {
+                user_id: 5,
+                access_token: "the-access-token".into(),
+            }
+            .into_envelope(3, None);
+
+            let message2 = ShareWorktree {
+                worktree: Some(Worktree {
+                    paths: vec![b"ok".to_vec()],
+                }),
+            }
+            .into_envelope(5, None);
 
             let mut message_stream = MessageStream::new(byte_stream);
             message_stream.write_message(&message1).await.unwrap();
             message_stream.write_message(&message2).await.unwrap();
-            let decoded_message1 = message_stream.read_message::<FromClient>().await.unwrap();
-            let decoded_message2 = message_stream.read_message::<FromClient>().await.unwrap();
+            let decoded_message1 = message_stream.read_message().await.unwrap();
+            let decoded_message2 = message_stream.read_message().await.unwrap();
             assert_eq!(decoded_message1, message1);
             assert_eq!(decoded_message2, message2);
         });

zed/src/rpc_client.rs 🔗

@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
 use futures::future::Either;
 use gpui::executor::Background;
 use postage::{
-    barrier, mpsc,
+    barrier, oneshot,
     prelude::{Sink, Stream},
 };
 use smol::{
@@ -14,18 +14,16 @@ use std::{
     collections::HashMap,
     future::Future,
     sync::{
-        atomic::{self, AtomicI32},
+        atomic::{self, AtomicU32},
         Arc,
     },
 };
-use zed_rpc::proto::{
-    self, MessageStream, RequestMessage, SendMessage, ServerMessage, SubscribeMessage,
-};
+use zed_rpc::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
 
 pub struct RpcClient {
-    response_channels: Arc<Mutex<HashMap<i32, (mpsc::Sender<proto::from_server::Variant>, bool)>>>,
+    response_channels: Arc<Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>>,
     outgoing: Mutex<MessageStream<BoxedWriter>>,
-    next_message_id: AtomicI32,
+    next_message_id: AtomicU32,
     _drop_tx: barrier::Sender,
 }
 
@@ -50,16 +48,14 @@ impl RpcClient {
             response_channels,
             outgoing: Mutex::new(MessageStream::new(Box::pin(conn_tx))),
             _drop_tx,
-            next_message_id: AtomicI32::new(0),
+            next_message_id: AtomicU32::new(0),
         })
     }
 
     async fn handle_incoming<Conn>(
         conn: ReadHalf<Conn>,
         mut drop_rx: barrier::Receiver,
-        response_channels: Arc<
-            Mutex<HashMap<i32, (mpsc::Sender<proto::from_server::Variant>, bool)>>,
-        >,
+        response_channels: Arc<Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>>,
     ) where
         Conn: AsyncRead + Unpin,
     {
@@ -68,36 +64,27 @@ impl RpcClient {
 
         let mut stream = MessageStream::new(conn);
         loop {
-            let read_message = stream.read_message::<proto::FromServer>();
+            let read_message = stream.read_message();
             smol::pin!(read_message);
 
             match futures::future::select(read_message, &mut dropped).await {
                 Either::Left((Ok(incoming), _)) => {
-                    if let Some(variant) = incoming.variant {
-                        if let Some(request_id) = incoming.request_id {
-                            let channel = response_channels.lock().await.remove(&request_id);
-                            if let Some((mut tx, oneshot)) = channel {
-                                if tx.send(variant).await.is_ok() {
-                                    if !oneshot {
-                                        response_channels
-                                            .lock()
-                                            .await
-                                            .insert(request_id, (tx, false));
-                                    }
-                                }
-                            } else {
-                                log::warn!(
-                                    "received RPC response to unknown request id {}",
-                                    request_id
-                                );
-                            }
+                    if let Some(responding_to) = incoming.responding_to {
+                        let channel = response_channels.lock().await.remove(&responding_to);
+                        if let Some(mut tx) = channel {
+                            tx.send(incoming).await.ok();
+                        } else {
+                            log::warn!(
+                                "received RPC response to unknown request {}",
+                                responding_to
+                            );
                         }
                     } else {
-                        log::warn!("received RPC message with no content");
+                        // unprompted message from server
                     }
                 }
                 Either::Left((Err(error), _)) => {
-                    log::warn!("invalid incoming RPC message {:?}", error);
+                    log::warn!("received invalid RPC message {:?}", error);
                 }
                 Either::Right(_) => break,
             }
@@ -111,67 +98,35 @@ impl RpcClient {
         let this = self.clone();
         async move {
             let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
-            let (tx, mut rx) = mpsc::channel(1);
-            this.response_channels
-                .lock()
-                .await
-                .insert(message_id, (tx, true));
+            let (tx, mut rx) = oneshot::channel();
+            this.response_channels.lock().await.insert(message_id, tx);
             this.outgoing
                 .lock()
                 .await
-                .write_message(&proto::FromClient {
-                    id: message_id,
-                    variant: Some(req.to_variant()),
-                })
+                .write_message(&req.into_envelope(message_id, None))
                 .await?;
             let response = rx
                 .recv()
                 .await
                 .expect("response channel was unexpectedly dropped");
-            T::Response::from_variant(response)
+            T::Response::from_envelope(response)
                 .ok_or_else(|| anyhow!("received response of the wrong t"))
         }
     }
 
-    pub fn send<T: SendMessage>(self: &Arc<Self>, message: T) -> impl Future<Output = Result<()>> {
-        let this = self.clone();
-        async move {
-            let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
-            this.outgoing
-                .lock()
-                .await
-                .write_message(&proto::FromClient {
-                    id: message_id,
-                    variant: Some(message.to_variant()),
-                })
-                .await?;
-            Ok(())
-        }
-    }
-
-    pub fn subscribe<T: SubscribeMessage>(
+    pub fn send<T: EnvelopedMessage>(
         self: &Arc<Self>,
-        subscription: T,
-    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Event>>>> {
+        message: T,
+    ) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         async move {
             let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
-            let (tx, rx) = mpsc::channel(256);
-            this.response_channels
-                .lock()
-                .await
-                .insert(message_id, (tx, false));
             this.outgoing
                 .lock()
                 .await
-                .write_message(&proto::FromClient {
-                    id: message_id,
-                    variant: Some(subscription.to_variant()),
-                })
+                .write_message(&message.into_envelope(message_id, None))
                 .await?;
-            Ok(rx.map(|event| {
-                T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}"))
-            }))
+            Ok(())
         }
     }
 }
@@ -199,133 +154,49 @@ mod tests {
         let mut server_stream = MessageStream::new(server_conn);
         let client = RpcClient::new(client_conn, executor.clone());
 
-        let client_req = client.request(proto::from_client::Auth {
+        let client_req = client.request(proto::Auth {
             user_id: 42,
             access_token: "token".to_string(),
         });
         smol::pin!(client_req);
-        let server_req = send_recv(
-            &mut client_req,
-            server_stream.read_message::<proto::FromClient>(),
-        )
-        .await
-        .unwrap();
+        let server_req = send_recv(&mut client_req, server_stream.read_message())
+            .await
+            .unwrap();
         assert_eq!(
-            server_req.variant,
-            Some(proto::from_client::Variant::Auth(
-                proto::from_client::Auth {
-                    user_id: 42,
-                    access_token: "token".to_string()
-                }
-            ))
+            server_req.payload,
+            Some(proto::envelope::Payload::Auth(proto::Auth {
+                user_id: 42,
+                access_token: "token".to_string()
+            }))
         );
 
         // Respond to another request to ensure requests are properly matched up.
         server_stream
-            .write_message(&proto::FromServer {
-                request_id: Some(999),
-                variant: Some(proto::from_server::Variant::AuthResponse(
-                    proto::from_server::AuthResponse {
-                        credentials_valid: false,
-                    },
-                )),
-            })
+            .write_message(
+                &proto::AuthResponse {
+                    credentials_valid: false,
+                }
+                .into_envelope(1000, Some(999)),
+            )
             .await
             .unwrap();
         server_stream
-            .write_message(&proto::FromServer {
-                request_id: Some(server_req.id),
-                variant: Some(proto::from_server::Variant::AuthResponse(
-                    proto::from_server::AuthResponse {
-                        credentials_valid: true,
-                    },
-                )),
-            })
+            .write_message(
+                &proto::AuthResponse {
+                    credentials_valid: true,
+                }
+                .into_envelope(1001, Some(server_req.id)),
+            )
             .await
             .unwrap();
         assert_eq!(
             client_req.await.unwrap(),
-            proto::from_server::AuthResponse {
+            proto::AuthResponse {
                 credentials_valid: true
             }
         );
     }
 
-    #[gpui::test]
-    async fn test_subscribe(cx: gpui::TestAppContext) {
-        let executor = cx.read(|app| app.background_executor().clone());
-        let socket_dir_path = TempDir::new("subscribe").unwrap();
-        let socket_path = socket_dir_path.path().join(".sock");
-        let listener = UnixListener::bind(&socket_path).unwrap();
-        let client_conn = UnixStream::connect(&socket_path).await.unwrap();
-        let (server_conn, _) = listener.accept().await.unwrap();
-
-        let mut server_stream = MessageStream::new(server_conn);
-        let client = RpcClient::new(client_conn, executor.clone());
-
-        let mut events = client
-            .subscribe(proto::from_client::SubscribeToPathRequests {})
-            .await
-            .unwrap();
-
-        let subscription = server_stream
-            .read_message::<proto::FromClient>()
-            .await
-            .unwrap();
-        assert_eq!(
-            subscription.variant,
-            Some(proto::from_client::Variant::SubscribeToPathRequests(
-                proto::from_client::SubscribeToPathRequests {}
-            ))
-        );
-        server_stream
-            .write_message(&proto::FromServer {
-                request_id: Some(subscription.id),
-                variant: Some(proto::from_server::Variant::PathRequest(
-                    proto::from_server::PathRequest {
-                        path: b"path-1".to_vec(),
-                    },
-                )),
-            })
-            .await
-            .unwrap();
-        server_stream
-            .write_message(&proto::FromServer {
-                request_id: Some(99999),
-                variant: Some(proto::from_server::Variant::PathRequest(
-                    proto::from_server::PathRequest {
-                        path: b"path-2".to_vec(),
-                    },
-                )),
-            })
-            .await
-            .unwrap();
-        server_stream
-            .write_message(&proto::FromServer {
-                request_id: Some(subscription.id),
-                variant: Some(proto::from_server::Variant::PathRequest(
-                    proto::from_server::PathRequest {
-                        path: b"path-3".to_vec(),
-                    },
-                )),
-            })
-            .await
-            .unwrap();
-
-        assert_eq!(
-            events.recv().await.unwrap().unwrap(),
-            proto::from_server::PathRequest {
-                path: b"path-1".to_vec()
-            }
-        );
-        assert_eq!(
-            events.recv().await.unwrap().unwrap(),
-            proto::from_server::PathRequest {
-                path: b"path-3".to_vec()
-            }
-        );
-    }
-
     #[gpui::test]
     async fn test_drop_client(cx: gpui::TestAppContext) {
         let executor = cx.read(|app| app.background_executor().clone());
@@ -362,7 +233,7 @@ mod tests {
 
         let client = RpcClient::new(client_conn, executor.clone());
         let err = client
-            .request(proto::from_client::Auth {
+            .request(proto::Auth {
                 user_id: 42,
                 access_token: "token".to_string(),
             })

zed/src/workspace.rs 🔗

@@ -673,8 +673,8 @@ impl Workspace {
             let rpc_client = RpcClient::new(stream, executor);
 
             let auth_response = rpc_client
-                .request(proto::from_client::Auth {
-                    user_id: user_id.parse::<i32>()?,
+                .request(proto::Auth {
+                    user_id: user_id.parse::<u64>()?,
                     access_token,
                 })
                 .await?;

zed/src/worktree.rs 🔗

@@ -32,7 +32,7 @@ use std::{
     sync::{Arc, Weak},
     time::{Duration, SystemTime, UNIX_EPOCH},
 };
-use zed_rpc::proto::{self, from_client::PathAndDigest};
+use zed_rpc::proto;
 
 use self::{char_bag::CharBag, ignore::IgnoreStack};
 
@@ -234,23 +234,19 @@ impl Worktree {
         self.rpc_client = Some(client.clone());
         let snapshot = self.snapshot();
         cx.spawn(|_this, cx| async move {
-            let files = cx
+            let paths = cx
                 .background_executor()
                 .spawn(async move {
                     snapshot
                         .paths()
-                        .map(|path| PathAndDigest {
-                            path: path.as_os_str().as_bytes().to_vec(),
-                            digest: Default::default(),
-                        })
+                        .map(|path| path.as_os_str().as_bytes().to_vec())
                         .collect()
                 })
                 .await;
 
             let share_response = client
-                .request(proto::from_client::ShareWorktree {
-                    worktree_id: 0,
-                    files,
+                .request(proto::ShareWorktree {
+                    worktree: Some(proto::Worktree { paths }),
                 })
                 .await?;