Consolidate server's rpc state into the rpc::Server struct

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

server/src/auth.rs  |   41 -
server/src/main.rs  |    4 
server/src/rpc.rs   | 1130 +++++++++++++++++++++++-----------------------
server/src/tests.rs |    3 
4 files changed, 569 insertions(+), 609 deletions(-)

Detailed changes

server/src/auth.rs 🔗

@@ -2,7 +2,7 @@ use super::{
     db::{self, UserId},
     errors::TideResultExt,
 };
-use crate::{github, rpc, AppState, Request, RequestExt as _};
+use crate::{github, AppState, Request, RequestExt as _};
 use anyhow::{anyhow, Context};
 use async_trait::async_trait;
 pub use oauth2::basic::BasicClient as Client;
@@ -19,7 +19,7 @@ use serde::{Deserialize, Serialize};
 use std::{borrow::Cow, convert::TryFrom, sync::Arc};
 use surf::Url;
 use tide::Server;
-use zrpc::{auth as zed_auth, proto, Peer};
+use zrpc::auth as zed_auth;
 
 static CURRENT_GITHUB_USER: &'static str = "current_github_user";
 static GITHUB_AUTH_URL: &'static str = "https://github.com/login/oauth/authorize";
@@ -100,43 +100,6 @@ impl RequestExt for Request {
     }
 }
 
-#[async_trait]
-pub trait PeerExt {
-    async fn sign_out(
-        self: &Arc<Self>,
-        connection_id: zrpc::ConnectionId,
-        state: &AppState,
-    ) -> tide::Result<()>;
-}
-
-#[async_trait]
-impl PeerExt for Peer {
-    async fn sign_out(
-        self: &Arc<Self>,
-        connection_id: zrpc::ConnectionId,
-        state: &AppState,
-    ) -> tide::Result<()> {
-        self.disconnect(connection_id).await;
-        let worktree_ids = state.rpc.write().await.remove_connection(connection_id);
-        for worktree_id in worktree_ids {
-            let state = state.rpc.read().await;
-            if let Some(worktree) = state.worktrees.get(&worktree_id) {
-                rpc::broadcast(connection_id, worktree.connection_ids(), |conn_id| {
-                    self.send(
-                        conn_id,
-                        proto::RemovePeer {
-                            worktree_id,
-                            peer_id: connection_id.0,
-                        },
-                    )
-                })
-                .await?;
-            }
-        }
-        Ok(())
-    }
-}
-
 pub fn build_client(client_id: &str, client_secret: &str) -> Client {
     Client::new(
         ClientId::new(client_id.to_string()),

server/src/main.rs 🔗

@@ -14,7 +14,7 @@ mod tests;
 
 use self::errors::TideResultExt as _;
 use anyhow::{Context, Result};
-use async_std::{net::TcpListener, sync::RwLock as AsyncRwLock};
+use async_std::net::TcpListener;
 use async_trait::async_trait;
 use auth::RequestExt as _;
 use db::{Db, DbOptions};
@@ -51,7 +51,6 @@ pub struct AppState {
     auth_client: auth::Client,
     github_client: Arc<github::AppClient>,
     repo_client: github::RepoClient,
-    rpc: AsyncRwLock<rpc::State>,
     config: Config,
 }
 
@@ -76,7 +75,6 @@ impl AppState {
             auth_client: auth::build_client(&config.github_client_id, &config.github_client_secret),
             github_client,
             repo_client,
-            rpc: Default::default(),
             config,
         };
         this.register_partials();

server/src/rpc.rs 🔗

@@ -1,10 +1,10 @@
 use super::{
-    auth::{self, PeerExt as _},
+    auth,
     db::{ChannelId, UserId},
     AppState,
 };
 use anyhow::anyhow;
-use async_std::task;
+use async_std::{sync::RwLock, task};
 use async_tungstenite::{
     tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
     WebSocketStream,
@@ -13,7 +13,7 @@ use futures::{future::BoxFuture, FutureExt};
 use postage::prelude::Stream as _;
 use sha1::{Digest as _, Sha1};
 use std::{
-    any::{Any, TypeId},
+    any::TypeId,
     collections::{HashMap, HashSet},
     future::Future,
     mem,
@@ -38,51 +38,90 @@ type ReplicaId = u16;
 type MessageHandler = Box<
     dyn Send
         + Sync
-        + Fn(Box<dyn AnyTypedEnvelope>, Arc<Server>) -> BoxFuture<'static, tide::Result<()>>,
+        + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
 >;
 
-#[derive(Default)]
-struct ServerBuilder {
+pub struct Server {
+    peer: Arc<Peer>,
+    state: RwLock<ServerState>,
+    app_state: Arc<AppState>,
     handlers: HashMap<TypeId, MessageHandler>,
 }
 
-impl ServerBuilder {
-    pub fn on_message<F, Fut, M>(mut self, handler: F) -> Self
+#[derive(Default)]
+struct ServerState {
+    connections: HashMap<ConnectionId, Connection>,
+    pub worktrees: HashMap<u64, Worktree>,
+    channels: HashMap<ChannelId, Channel>,
+    next_worktree_id: u64,
+}
+
+struct Connection {
+    user_id: UserId,
+    worktrees: HashSet<u64>,
+    channels: HashSet<ChannelId>,
+}
+
+struct Worktree {
+    host_connection_id: Option<ConnectionId>,
+    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
+    active_replica_ids: HashSet<ReplicaId>,
+    access_token: String,
+    root_name: String,
+    entries: HashMap<u64, proto::Entry>,
+}
+
+#[derive(Default)]
+struct Channel {
+    connection_ids: HashSet<ConnectionId>,
+}
+
+impl Server {
+    pub fn new(app_state: Arc<AppState>, peer: Arc<Peer>) -> Arc<Self> {
+        let mut server = Server {
+            peer,
+            app_state,
+            state: Default::default(),
+            handlers: Default::default(),
+        };
+
+        server
+            .add_handler(Server::share_worktree)
+            .add_handler(Server::join_worktree)
+            .add_handler(Server::update_worktree)
+            .add_handler(Server::close_worktree)
+            .add_handler(Server::open_buffer)
+            .add_handler(Server::close_buffer)
+            .add_handler(Server::update_buffer)
+            .add_handler(Server::buffer_saved)
+            .add_handler(Server::save_buffer)
+            .add_handler(Server::get_channels)
+            .add_handler(Server::get_users)
+            .add_handler(Server::join_channel)
+            .add_handler(Server::send_channel_message);
+
+        Arc::new(server)
+    }
+
+    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
     where
-        F: 'static + Send + Sync + Fn(Box<TypedEnvelope<M>>, Arc<Server>) -> Fut,
+        F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
         Fut: 'static + Send + Future<Output = tide::Result<()>>,
         M: EnvelopedMessage,
     {
         let prev_handler = self.handlers.insert(
             TypeId::of::<M>(),
-            Box::new(move |envelope, server| {
+            Box::new(move |server, envelope| {
                 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
-                (handler)(envelope, server).boxed()
+                (handler)(server, *envelope).boxed()
             }),
         );
         if prev_handler.is_some() {
             panic!("registered a handler for the same message twice");
         }
-
         self
     }
 
-    pub fn build(self, rpc: &Arc<Peer>, state: &Arc<AppState>) -> Arc<Server> {
-        Arc::new(Server {
-            rpc: rpc.clone(),
-            state: state.clone(),
-            handlers: self.handlers,
-        })
-    }
-}
-
-pub struct Server {
-    rpc: Arc<Peer>,
-    state: Arc<AppState>,
-    handlers: HashMap<TypeId, MessageHandler>,
-}
-
-impl Server {
     pub fn handle_connection<Conn>(
         self: &Arc<Self>,
         connection: Conn,
@@ -99,12 +138,8 @@ impl Server {
         let this = self.clone();
         async move {
             let (connection_id, handle_io, mut incoming_rx) =
-                this.rpc.add_connection(connection).await;
-            this.state
-                .rpc
-                .write()
-                .await
-                .add_connection(connection_id, user_id);
+                this.peer.add_connection(connection).await;
+            this.add_connection(connection_id, user_id).await;
 
             let handle_io = handle_io.fuse();
             futures::pin_mut!(handle_io);
@@ -117,7 +152,7 @@ impl Server {
                             let start_time = Instant::now();
                             log::info!("RPC message received: {}", message.payload_type_name());
                             if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
-                                if let Err(err) = (handler)(message, this.clone()).await {
+                                if let Err(err) = (handler)(this.clone(), message).await {
                                     log::error!("error handling message: {:?}", err);
                                 } else {
                                     log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
@@ -139,67 +174,36 @@ impl Server {
                 }
             }
 
-            if let Err(err) = this.rpc.sign_out(connection_id, &this.state).await {
+            if let Err(err) = this.sign_out(connection_id).await {
                 log::error!("error signing out connection {:?} - {:?}", addr, err);
             }
         }
     }
-}
-
-#[derive(Default)]
-pub struct State {
-    connections: HashMap<ConnectionId, Connection>,
-    pub worktrees: HashMap<u64, Worktree>,
-    channels: HashMap<ChannelId, Channel>,
-    next_worktree_id: u64,
-}
-
-struct Connection {
-    user_id: UserId,
-    worktrees: HashSet<u64>,
-    channels: HashSet<ChannelId>,
-}
-
-pub struct Worktree {
-    host_connection_id: Option<ConnectionId>,
-    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
-    active_replica_ids: HashSet<ReplicaId>,
-    access_token: String,
-    root_name: String,
-    entries: HashMap<u64, proto::Entry>,
-}
-
-#[derive(Default)]
-struct Channel {
-    connection_ids: HashSet<ConnectionId>,
-}
-
-impl Worktree {
-    pub fn connection_ids(&self) -> Vec<ConnectionId> {
-        self.guest_connection_ids
-            .keys()
-            .copied()
-            .chain(self.host_connection_id)
-            .collect()
-    }
-
-    fn host_connection_id(&self) -> tide::Result<ConnectionId> {
-        Ok(self
-            .host_connection_id
-            .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
-    }
-}
 
-impl Channel {
-    fn connection_ids(&self) -> Vec<ConnectionId> {
-        self.connection_ids.iter().copied().collect()
+    async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
+        self.peer.disconnect(connection_id).await;
+        let worktree_ids = self.remove_connection(connection_id).await;
+        for worktree_id in worktree_ids {
+            let state = self.state.read().await;
+            if let Some(worktree) = state.worktrees.get(&worktree_id) {
+                broadcast(connection_id, worktree.connection_ids(), |conn_id| {
+                    self.peer.send(
+                        conn_id,
+                        proto::RemovePeer {
+                            worktree_id,
+                            peer_id: connection_id.0,
+                        },
+                    )
+                })
+                .await?;
+            }
+        }
+        Ok(())
     }
-}
 
-impl State {
     // Add a new connection associated with a given user.
-    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
-        self.connections.insert(
+    async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
+        self.state.write().await.connections.insert(
             connection_id,
             Connection {
                 user_id,
@@ -210,16 +214,17 @@ impl State {
     }
 
     // Remove the given connection and its association with any worktrees.
-    pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Vec<u64> {
+    async fn remove_connection(&self, connection_id: ConnectionId) -> Vec<u64> {
         let mut worktree_ids = Vec::new();
-        if let Some(connection) = self.connections.remove(&connection_id) {
+        let mut state = self.state.write().await;
+        if let Some(connection) = state.connections.remove(&connection_id) {
             for channel_id in connection.channels {
-                if let Some(channel) = self.channels.get_mut(&channel_id) {
+                if let Some(channel) = state.channels.get_mut(&channel_id) {
                     channel.connection_ids.remove(&connection_id);
                 }
             }
             for worktree_id in connection.worktrees {
-                if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
+                if let Some(worktree) = state.worktrees.get_mut(&worktree_id) {
                     if worktree.host_connection_id == Some(connection_id) {
                         worktree_ids.push(worktree_id);
                     } else if let Some(replica_id) =
@@ -234,6 +239,444 @@ impl State {
         worktree_ids
     }
 
+    async fn share_worktree(
+        self: Arc<Server>,
+        mut request: TypedEnvelope<proto::ShareWorktree>,
+    ) -> tide::Result<()> {
+        let mut state = self.state.write().await;
+        let worktree_id = state.next_worktree_id;
+        state.next_worktree_id += 1;
+        let access_token = random_token();
+        let worktree = request
+            .payload
+            .worktree
+            .as_mut()
+            .ok_or_else(|| anyhow!("missing worktree"))?;
+        let entries = mem::take(&mut worktree.entries)
+            .into_iter()
+            .map(|entry| (entry.id, entry))
+            .collect();
+        state.worktrees.insert(
+            worktree_id,
+            Worktree {
+                host_connection_id: Some(request.sender_id),
+                guest_connection_ids: Default::default(),
+                active_replica_ids: Default::default(),
+                access_token: access_token.clone(),
+                root_name: mem::take(&mut worktree.root_name),
+                entries,
+            },
+        );
+
+        self.peer
+            .respond(
+                request.receipt(),
+                proto::ShareWorktreeResponse {
+                    worktree_id,
+                    access_token,
+                },
+            )
+            .await?;
+        Ok(())
+    }
+
+    async fn join_worktree(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::OpenWorktree>,
+    ) -> tide::Result<()> {
+        let worktree_id = request.payload.worktree_id;
+        let access_token = &request.payload.access_token;
+
+        let mut state = self.state.write().await;
+        if let Some((peer_replica_id, worktree)) =
+            state.join_worktree(request.sender_id, worktree_id, access_token)
+        {
+            let mut peers = Vec::new();
+            if let Some(host_connection_id) = worktree.host_connection_id {
+                peers.push(proto::Peer {
+                    peer_id: host_connection_id.0,
+                    replica_id: 0,
+                });
+            }
+            for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
+                if *peer_conn_id != request.sender_id {
+                    peers.push(proto::Peer {
+                        peer_id: peer_conn_id.0,
+                        replica_id: *peer_replica_id as u32,
+                    });
+                }
+            }
+
+            broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
+                self.peer.send(
+                    conn_id,
+                    proto::AddPeer {
+                        worktree_id,
+                        peer: Some(proto::Peer {
+                            peer_id: request.sender_id.0,
+                            replica_id: peer_replica_id as u32,
+                        }),
+                    },
+                )
+            })
+            .await?;
+            self.peer
+                .respond(
+                    request.receipt(),
+                    proto::OpenWorktreeResponse {
+                        worktree_id,
+                        worktree: Some(proto::Worktree {
+                            root_name: worktree.root_name.clone(),
+                            entries: worktree.entries.values().cloned().collect(),
+                        }),
+                        replica_id: peer_replica_id as u32,
+                        peers,
+                    },
+                )
+                .await?;
+        } else {
+            self.peer
+                .respond(
+                    request.receipt(),
+                    proto::OpenWorktreeResponse {
+                        worktree_id,
+                        worktree: None,
+                        replica_id: 0,
+                        peers: Vec::new(),
+                    },
+                )
+                .await?;
+        }
+
+        Ok(())
+    }
+
+    async fn update_worktree(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::UpdateWorktree>,
+    ) -> tide::Result<()> {
+        {
+            let mut state = self.state.write().await;
+            let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
+            for entry_id in &request.payload.removed_entries {
+                worktree.entries.remove(&entry_id);
+            }
+
+            for entry in &request.payload.updated_entries {
+                worktree.entries.insert(entry.id, entry.clone());
+            }
+        }
+
+        self.broadcast_in_worktree(request.payload.worktree_id, &request)
+            .await?;
+        Ok(())
+    }
+
+    async fn close_worktree(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::CloseWorktree>,
+    ) -> tide::Result<()> {
+        let connection_ids;
+        {
+            let mut state = self.state.write().await;
+            let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
+            connection_ids = worktree.connection_ids();
+            if worktree.host_connection_id == Some(request.sender_id) {
+                worktree.host_connection_id = None;
+            } else if let Some(replica_id) =
+                worktree.guest_connection_ids.remove(&request.sender_id)
+            {
+                worktree.active_replica_ids.remove(&replica_id);
+            }
+        }
+
+        broadcast(request.sender_id, connection_ids, |conn_id| {
+            self.peer.send(
+                conn_id,
+                proto::RemovePeer {
+                    worktree_id: request.payload.worktree_id,
+                    peer_id: request.sender_id.0,
+                },
+            )
+        })
+        .await?;
+
+        Ok(())
+    }
+
+    async fn open_buffer(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::OpenBuffer>,
+    ) -> tide::Result<()> {
+        let receipt = request.receipt();
+        let worktree_id = request.payload.worktree_id;
+        let host_connection_id = self
+            .state
+            .read()
+            .await
+            .read_worktree(worktree_id, request.sender_id)?
+            .host_connection_id()?;
+
+        let response = self
+            .peer
+            .forward_request(request.sender_id, host_connection_id, request.payload)
+            .await?;
+        self.peer.respond(receipt, response).await?;
+        Ok(())
+    }
+
+    async fn close_buffer(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::CloseBuffer>,
+    ) -> tide::Result<()> {
+        let host_connection_id = self
+            .state
+            .read()
+            .await
+            .read_worktree(request.payload.worktree_id, request.sender_id)?
+            .host_connection_id()?;
+
+        self.peer
+            .forward_send(request.sender_id, host_connection_id, request.payload)
+            .await?;
+
+        Ok(())
+    }
+
+    async fn save_buffer(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::SaveBuffer>,
+    ) -> tide::Result<()> {
+        let host;
+        let guests;
+        {
+            let state = self.state.read().await;
+            let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
+            host = worktree.host_connection_id()?;
+            guests = worktree
+                .guest_connection_ids
+                .keys()
+                .copied()
+                .collect::<Vec<_>>();
+        }
+
+        let sender = request.sender_id;
+        let receipt = request.receipt();
+        let response = self
+            .peer
+            .forward_request(sender, host, request.payload.clone())
+            .await?;
+
+        broadcast(host, guests, |conn_id| {
+            let response = response.clone();
+            let peer = &self.peer;
+            async move {
+                if conn_id == sender {
+                    peer.respond(receipt, response).await
+                } else {
+                    peer.forward_send(host, conn_id, response).await
+                }
+            }
+        })
+        .await?;
+
+        Ok(())
+    }
+
+    async fn update_buffer(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::UpdateBuffer>,
+    ) -> tide::Result<()> {
+        self.broadcast_in_worktree(request.payload.worktree_id, &request)
+            .await
+    }
+
+    async fn buffer_saved(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::BufferSaved>,
+    ) -> tide::Result<()> {
+        self.broadcast_in_worktree(request.payload.worktree_id, &request)
+            .await
+    }
+
+    async fn get_channels(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::GetChannels>,
+    ) -> tide::Result<()> {
+        let user_id = self
+            .state
+            .read()
+            .await
+            .user_id_for_connection(request.sender_id)?;
+        let channels = self.app_state.db.get_channels_for_user(user_id).await?;
+        self.peer
+            .respond(
+                request.receipt(),
+                proto::GetChannelsResponse {
+                    channels: channels
+                        .into_iter()
+                        .map(|chan| proto::Channel {
+                            id: chan.id.to_proto(),
+                            name: chan.name,
+                        })
+                        .collect(),
+                },
+            )
+            .await?;
+        Ok(())
+    }
+
+    async fn get_users(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::GetUsers>,
+    ) -> tide::Result<()> {
+        let user_id = self
+            .state
+            .read()
+            .await
+            .user_id_for_connection(request.sender_id)?;
+        let receipt = request.receipt();
+        let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
+        let users = self
+            .app_state
+            .db
+            .get_users_by_ids(user_id, user_ids)
+            .await?
+            .into_iter()
+            .map(|user| proto::User {
+                id: user.id.to_proto(),
+                github_login: user.github_login,
+                avatar_url: String::new(),
+            })
+            .collect();
+        self.peer
+            .respond(receipt, proto::GetUsersResponse { users })
+            .await?;
+        Ok(())
+    }
+
+    async fn join_channel(
+        self: Arc<Self>,
+        request: TypedEnvelope<proto::JoinChannel>,
+    ) -> tide::Result<()> {
+        let user_id = self
+            .state
+            .read()
+            .await
+            .user_id_for_connection(request.sender_id)?;
+        let channel_id = ChannelId::from_proto(request.payload.channel_id);
+        if !self
+            .app_state
+            .db
+            .can_user_access_channel(user_id, channel_id)
+            .await?
+        {
+            Err(anyhow!("access denied"))?;
+        }
+
+        self.state
+            .write()
+            .await
+            .join_channel(request.sender_id, channel_id);
+        let messages = self
+            .app_state
+            .db
+            .get_recent_channel_messages(channel_id, 50)
+            .await?
+            .into_iter()
+            .map(|msg| proto::ChannelMessage {
+                id: msg.id.to_proto(),
+                body: msg.body,
+                timestamp: msg.sent_at.unix_timestamp() as u64,
+                sender_id: msg.sender_id.to_proto(),
+            })
+            .collect();
+        self.peer
+            .respond(request.receipt(), proto::JoinChannelResponse { messages })
+            .await?;
+        Ok(())
+    }
+
+    async fn send_channel_message(
+        self: Arc<Self>,
+        request: TypedEnvelope<proto::SendChannelMessage>,
+    ) -> tide::Result<()> {
+        let channel_id = ChannelId::from_proto(request.payload.channel_id);
+        let user_id;
+        let connection_ids;
+        {
+            let state = self.state.read().await;
+            user_id = state.user_id_for_connection(request.sender_id)?;
+            if let Some(channel) = state.channels.get(&channel_id) {
+                connection_ids = channel.connection_ids();
+            } else {
+                return Ok(());
+            }
+        }
+
+        let timestamp = OffsetDateTime::now_utc();
+        let message_id = self
+            .app_state
+            .db
+            .create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
+            .await?;
+        let message = proto::ChannelMessageSent {
+            channel_id: channel_id.to_proto(),
+            message: Some(proto::ChannelMessage {
+                sender_id: user_id.to_proto(),
+                id: message_id.to_proto(),
+                body: request.payload.body,
+                timestamp: timestamp.unix_timestamp() as u64,
+            }),
+        };
+        broadcast(request.sender_id, connection_ids, |conn_id| {
+            self.peer.send(conn_id, message.clone())
+        })
+        .await?;
+
+        Ok(())
+    }
+
+    async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
+        &self,
+        worktree_id: u64,
+        message: &TypedEnvelope<T>,
+    ) -> tide::Result<()> {
+        let connection_ids = self
+            .state
+            .read()
+            .await
+            .read_worktree(worktree_id, message.sender_id)?
+            .connection_ids();
+
+        broadcast(message.sender_id, connection_ids, |conn_id| {
+            self.peer
+                .forward_send(message.sender_id, conn_id, message.payload.clone())
+        })
+        .await?;
+
+        Ok(())
+    }
+}
+
+pub async fn broadcast<F, T>(
+    sender_id: ConnectionId,
+    receiver_ids: Vec<ConnectionId>,
+    mut f: F,
+) -> anyhow::Result<()>
+where
+    F: FnMut(ConnectionId) -> T,
+    T: Future<Output = anyhow::Result<()>>,
+{
+    let futures = receiver_ids
+        .into_iter()
+        .filter(|id| *id != sender_id)
+        .map(|id| f(id));
+    futures::future::try_join_all(futures).await?;
+    Ok(())
+}
+
+impl ServerState {
     fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
         if let Some(connection) = self.connections.get_mut(&connection_id) {
             connection.channels.insert(channel_id);
@@ -245,8 +688,16 @@ impl State {
         }
     }
 
+    fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
+        Ok(self
+            .connections
+            .get(&connection_id)
+            .ok_or_else(|| anyhow!("unknown connection"))?
+            .user_id)
+    }
+
     // Add the given connection as a guest of the given worktree
-    pub fn join_worktree(
+    fn join_worktree(
         &mut self,
         connection_id: ConnectionId,
         worktree_id: u64,
@@ -275,14 +726,6 @@ impl State {
         }
     }
 
-    fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
-        Ok(self
-            .connections
-            .get(&connection_id)
-            .ok_or_else(|| anyhow!("unknown connection"))?
-            .user_id)
-    }
-
     fn read_worktree(
         &self,
         worktree_id: u64,
@@ -330,26 +773,30 @@ impl State {
     }
 }
 
-pub fn build_server(state: &Arc<AppState>, rpc: &Arc<Peer>) -> Arc<Server> {
-    ServerBuilder::default()
-        .on_message(share_worktree)
-        .on_message(join_worktree)
-        .on_message(update_worktree)
-        .on_message(close_worktree)
-        .on_message(open_buffer)
-        .on_message(close_buffer)
-        .on_message(update_buffer)
-        .on_message(buffer_saved)
-        .on_message(save_buffer)
-        .on_message(get_channels)
-        .on_message(get_users)
-        .on_message(join_channel)
-        .on_message(send_channel_message)
-        .build(rpc, state)
+impl Worktree {
+    pub fn connection_ids(&self) -> Vec<ConnectionId> {
+        self.guest_connection_ids
+            .keys()
+            .copied()
+            .chain(self.host_connection_id)
+            .collect()
+    }
+
+    fn host_connection_id(&self) -> tide::Result<ConnectionId> {
+        Ok(self
+            .host_connection_id
+            .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
+    }
+}
+
+impl Channel {
+    fn connection_ids(&self) -> Vec<ConnectionId> {
+        self.connection_ids.iter().copied().collect()
+    }
 }
 
 pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
-    let server = build_server(app.state(), rpc);
+    let server = Server::new(app.state().clone(), rpc.clone());
     app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
         let user_id = request.ext::<UserId>().copied();
         let server = server.clone();
@@ -392,453 +839,6 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
     });
 }
 
-async fn share_worktree(
-    mut request: Box<TypedEnvelope<proto::ShareWorktree>>,
-    server: Arc<Server>,
-) -> tide::Result<()> {
-    let mut state = server.state.rpc.write().await;
-    let worktree_id = state.next_worktree_id;
-    state.next_worktree_id += 1;
-    let access_token = random_token();
-    let worktree = request
-        .payload
-        .worktree
-        .as_mut()
-        .ok_or_else(|| anyhow!("missing worktree"))?;
-    let entries = mem::take(&mut worktree.entries)
-        .into_iter()
-        .map(|entry| (entry.id, entry))
-        .collect();
-    state.worktrees.insert(
-        worktree_id,
-        Worktree {
-            host_connection_id: Some(request.sender_id),
-            guest_connection_ids: Default::default(),
-            active_replica_ids: Default::default(),
-            access_token: access_token.clone(),
-            root_name: mem::take(&mut worktree.root_name),
-            entries,
-        },
-    );
-
-    server
-        .rpc
-        .respond(
-            request.receipt(),
-            proto::ShareWorktreeResponse {
-                worktree_id,
-                access_token,
-            },
-        )
-        .await?;
-    Ok(())
-}
-
-async fn join_worktree(
-    request: Box<TypedEnvelope<proto::OpenWorktree>>,
-    server: Arc<Server>,
-) -> tide::Result<()> {
-    let worktree_id = request.payload.worktree_id;
-    let access_token = &request.payload.access_token;
-
-    let mut state = server.state.rpc.write().await;
-    if let Some((peer_replica_id, worktree)) =
-        state.join_worktree(request.sender_id, worktree_id, access_token)
-    {
-        let mut peers = Vec::new();
-        if let Some(host_connection_id) = worktree.host_connection_id {
-            peers.push(proto::Peer {
-                peer_id: host_connection_id.0,
-                replica_id: 0,
-            });
-        }
-        for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
-            if *peer_conn_id != request.sender_id {
-                peers.push(proto::Peer {
-                    peer_id: peer_conn_id.0,
-                    replica_id: *peer_replica_id as u32,
-                });
-            }
-        }
-
-        broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
-            server.rpc.send(
-                conn_id,
-                proto::AddPeer {
-                    worktree_id,
-                    peer: Some(proto::Peer {
-                        peer_id: request.sender_id.0,
-                        replica_id: peer_replica_id as u32,
-                    }),
-                },
-            )
-        })
-        .await?;
-        server
-            .rpc
-            .respond(
-                request.receipt(),
-                proto::OpenWorktreeResponse {
-                    worktree_id,
-                    worktree: Some(proto::Worktree {
-                        root_name: worktree.root_name.clone(),
-                        entries: worktree.entries.values().cloned().collect(),
-                    }),
-                    replica_id: peer_replica_id as u32,
-                    peers,
-                },
-            )
-            .await?;
-    } else {
-        server
-            .rpc
-            .respond(
-                request.receipt(),
-                proto::OpenWorktreeResponse {
-                    worktree_id,
-                    worktree: None,
-                    replica_id: 0,
-                    peers: Vec::new(),
-                },
-            )
-            .await?;
-    }
-
-    Ok(())
-}
-
-async fn update_worktree(
-    request: Box<TypedEnvelope<proto::UpdateWorktree>>,
-    server: Arc<Server>,
-) -> tide::Result<()> {
-    {
-        let mut state = server.state.rpc.write().await;
-        let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
-        for entry_id in &request.payload.removed_entries {
-            worktree.entries.remove(&entry_id);
-        }
-
-        for entry in &request.payload.updated_entries {
-            worktree.entries.insert(entry.id, entry.clone());
-        }
-    }
-
-    broadcast_in_worktree(request.payload.worktree_id, &request, &server).await?;
-    Ok(())
-}
-
-async fn close_worktree(
-    request: Box<TypedEnvelope<proto::CloseWorktree>>,
-    server: Arc<Server>,
-) -> tide::Result<()> {
-    let connection_ids;
-    {
-        let mut state = server.state.rpc.write().await;
-        let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
-        connection_ids = worktree.connection_ids();
-        if worktree.host_connection_id == Some(request.sender_id) {
-            worktree.host_connection_id = None;
-        } else if let Some(replica_id) = worktree.guest_connection_ids.remove(&request.sender_id) {
-            worktree.active_replica_ids.remove(&replica_id);
-        }
-    }
-
-    broadcast(request.sender_id, connection_ids, |conn_id| {
-        server.rpc.send(
-            conn_id,
-            proto::RemovePeer {
-                worktree_id: request.payload.worktree_id,
-                peer_id: request.sender_id.0,
-            },
-        )
-    })
-    .await?;
-
-    Ok(())
-}
-
-async fn open_buffer(
-    request: Box<TypedEnvelope<proto::OpenBuffer>>,
-    server: Arc<Server>,
-) -> tide::Result<()> {
-    let receipt = request.receipt();
-    let worktree_id = request.payload.worktree_id;
-    let host_connection_id = server
-        .state
-        .rpc
-        .read()
-        .await
-        .read_worktree(worktree_id, request.sender_id)?
-        .host_connection_id()?;
-
-    let response = server
-        .rpc
-        .forward_request(request.sender_id, host_connection_id, request.payload)
-        .await?;
-    server.rpc.respond(receipt, response).await?;
-    Ok(())
-}
-
-async fn close_buffer(
-    request: Box<TypedEnvelope<proto::CloseBuffer>>,
-    server: Arc<Server>,
-) -> tide::Result<()> {
-    let host_connection_id = server
-        .state
-        .rpc
-        .read()
-        .await
-        .read_worktree(request.payload.worktree_id, request.sender_id)?
-        .host_connection_id()?;
-
-    server
-        .rpc
-        .forward_send(request.sender_id, host_connection_id, request.payload)
-        .await?;
-
-    Ok(())
-}
-
-async fn save_buffer(
-    request: Box<TypedEnvelope<proto::SaveBuffer>>,
-    server: Arc<Server>,
-) -> tide::Result<()> {
-    let host;
-    let guests;
-    {
-        let state = server.state.rpc.read().await;
-        let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
-        host = worktree.host_connection_id()?;
-        guests = worktree
-            .guest_connection_ids
-            .keys()
-            .copied()
-            .collect::<Vec<_>>();
-    }
-
-    let sender = request.sender_id;
-    let receipt = request.receipt();
-    let response = server
-        .rpc
-        .forward_request(sender, host, request.payload.clone())
-        .await?;
-
-    broadcast(host, guests, |conn_id| {
-        let response = response.clone();
-        let server = &server;
-        async move {
-            if conn_id == sender {
-                server.rpc.respond(receipt, response).await
-            } else {
-                server.rpc.forward_send(host, conn_id, response).await
-            }
-        }
-    })
-    .await?;
-
-    Ok(())
-}
-
-async fn update_buffer(
-    request: Box<TypedEnvelope<proto::UpdateBuffer>>,
-    server: Arc<Server>,
-) -> tide::Result<()> {
-    broadcast_in_worktree(request.payload.worktree_id, &request, &server).await
-}
-
-async fn buffer_saved(
-    request: Box<TypedEnvelope<proto::BufferSaved>>,
-    server: Arc<Server>,
-) -> tide::Result<()> {
-    broadcast_in_worktree(request.payload.worktree_id, &request, &server).await
-}
-
-async fn get_channels(
-    request: Box<TypedEnvelope<proto::GetChannels>>,
-    server: Arc<Server>,
-) -> tide::Result<()> {
-    let user_id = server
-        .state
-        .rpc
-        .read()
-        .await
-        .user_id_for_connection(request.sender_id)?;
-    let channels = server.state.db.get_channels_for_user(user_id).await?;
-    server
-        .rpc
-        .respond(
-            request.receipt(),
-            proto::GetChannelsResponse {
-                channels: channels
-                    .into_iter()
-                    .map(|chan| proto::Channel {
-                        id: chan.id.to_proto(),
-                        name: chan.name,
-                    })
-                    .collect(),
-            },
-        )
-        .await?;
-    Ok(())
-}
-
-async fn get_users(
-    request: Box<TypedEnvelope<proto::GetUsers>>,
-    server: Arc<Server>,
-) -> tide::Result<()> {
-    let user_id = server
-        .state
-        .rpc
-        .read()
-        .await
-        .user_id_for_connection(request.sender_id)?;
-    let receipt = request.receipt();
-    let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
-    let users = server
-        .state
-        .db
-        .get_users_by_ids(user_id, user_ids)
-        .await?
-        .into_iter()
-        .map(|user| proto::User {
-            id: user.id.to_proto(),
-            github_login: user.github_login,
-            avatar_url: String::new(),
-        })
-        .collect();
-    server
-        .rpc
-        .respond(receipt, proto::GetUsersResponse { users })
-        .await?;
-    Ok(())
-}
-
-async fn join_channel(
-    request: Box<TypedEnvelope<proto::JoinChannel>>,
-    server: Arc<Server>,
-) -> tide::Result<()> {
-    let user_id = server
-        .state
-        .rpc
-        .read()
-        .await
-        .user_id_for_connection(request.sender_id)?;
-    let channel_id = ChannelId::from_proto(request.payload.channel_id);
-    if !server
-        .state
-        .db
-        .can_user_access_channel(user_id, channel_id)
-        .await?
-    {
-        Err(anyhow!("access denied"))?;
-    }
-
-    server
-        .state
-        .rpc
-        .write()
-        .await
-        .join_channel(request.sender_id, channel_id);
-    let messages = server
-        .state
-        .db
-        .get_recent_channel_messages(channel_id, 50)
-        .await?
-        .into_iter()
-        .map(|msg| proto::ChannelMessage {
-            id: msg.id.to_proto(),
-            body: msg.body,
-            timestamp: msg.sent_at.unix_timestamp() as u64,
-            sender_id: msg.sender_id.to_proto(),
-        })
-        .collect();
-    server
-        .rpc
-        .respond(request.receipt(), proto::JoinChannelResponse { messages })
-        .await?;
-    Ok(())
-}
-
-async fn send_channel_message(
-    request: Box<TypedEnvelope<proto::SendChannelMessage>>,
-    server: Arc<Server>,
-) -> tide::Result<()> {
-    let channel_id = ChannelId::from_proto(request.payload.channel_id);
-    let user_id;
-    let connection_ids;
-    {
-        let state = server.state.rpc.read().await;
-        user_id = state.user_id_for_connection(request.sender_id)?;
-        if let Some(channel) = state.channels.get(&channel_id) {
-            connection_ids = channel.connection_ids();
-        } else {
-            return Ok(());
-        }
-    }
-
-    let timestamp = OffsetDateTime::now_utc();
-    let message_id = server
-        .state
-        .db
-        .create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
-        .await?;
-    let message = proto::ChannelMessageSent {
-        channel_id: channel_id.to_proto(),
-        message: Some(proto::ChannelMessage {
-            sender_id: user_id.to_proto(),
-            id: message_id.to_proto(),
-            body: request.payload.body,
-            timestamp: timestamp.unix_timestamp() as u64,
-        }),
-    };
-    broadcast(request.sender_id, connection_ids, |conn_id| {
-        server.rpc.send(conn_id, message.clone())
-    })
-    .await?;
-
-    Ok(())
-}
-
-async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
-    worktree_id: u64,
-    request: &TypedEnvelope<T>,
-    server: &Arc<Server>,
-) -> tide::Result<()> {
-    let connection_ids = server
-        .state
-        .rpc
-        .read()
-        .await
-        .read_worktree(worktree_id, request.sender_id)?
-        .connection_ids();
-
-    broadcast(request.sender_id, connection_ids, |conn_id| {
-        server
-            .rpc
-            .forward_send(request.sender_id, conn_id, request.payload.clone())
-    })
-    .await?;
-
-    Ok(())
-}
-
-pub async fn broadcast<F, T>(
-    sender_id: ConnectionId,
-    receiver_ids: Vec<ConnectionId>,
-    mut f: F,
-) -> anyhow::Result<()>
-where
-    F: FnMut(ConnectionId) -> T,
-    T: Future<Output = anyhow::Result<()>>,
-{
-    let futures = receiver_ids
-        .into_iter()
-        .filter(|id| *id != sender_id)
-        .map(|id| f(id));
-    futures::future::try_join_all(futures).await?;
-    Ok(())
-}
-
 fn header_contains_ignore_case<T>(
     request: &tide::Request<T>,
     header_name: HeaderName,

server/src/tests.rs 🔗

@@ -540,7 +540,7 @@ impl TestServer {
         let db_name = format!("zed-test-{}", rng.gen::<u128>());
         let app_state = Self::build_app_state(&db_name).await;
         let peer = Peer::new();
-        let server = rpc::build_server(&app_state, &peer);
+        let server = rpc::Server::new(app_state.clone(), peer.clone());
         Self {
             peer,
             app_state,
@@ -595,7 +595,6 @@ impl TestServer {
             auth_client: auth::build_client("", ""),
             repo_client: github::RepoClient::test(&github_client),
             github_client,
-            rpc: Default::default(),
             config,
         })
     }