Implement `proto::GetCollaborators` request

Antonio Scandurra created

Change summary

server/src/bin/seed.rs |   8 +
server/src/db.rs       |  17 ----
server/src/rpc.rs      | 161 ++++++++++++++++++++++++++++++++++---------
zrpc/proto/zed.proto   |  19 +++++
zrpc/src/proto.rs      |   3 
5 files changed, 154 insertions(+), 54 deletions(-)

Detailed changes

server/src/bin/seed.rs 🔗

@@ -27,8 +27,12 @@ async fn main() {
     let zed_users = ["nathansobo", "maxbrunsfeld", "as-cii", "iamnbutler"];
     let mut zed_user_ids = Vec::<UserId>::new();
     for zed_user in zed_users {
-        if let Some(user_id) = db.get_user(zed_user).await.expect("failed to fetch user") {
-            zed_user_ids.push(user_id);
+        if let Some(user) = db
+            .get_user_by_github_login(zed_user)
+            .await
+            .expect("failed to fetch user")
+        {
+            zed_user_ids.push(user.id);
         } else {
             zed_user_ids.push(
                 db.create_user(zed_user, true)

server/src/db.rs 🔗

@@ -84,27 +84,12 @@ impl Db {
 
     // users
 
-    #[allow(unused)] // Help rust-analyzer
-    #[cfg(any(test, feature = "seed-support"))]
-    pub async fn get_user(&self, github_login: &str) -> Result<Option<UserId>> {
-        test_support!(self, {
-            let query = "
-                SELECT id
-                FROM users
-                WHERE github_login = $1
-            ";
-            sqlx::query_scalar(query)
-                .bind(github_login)
-                .fetch_optional(&self.pool)
-                .await
-        })
-    }
-
     pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
         test_support!(self, {
             let query = "
                 INSERT INTO users (github_login, admin)
                 VALUES ($1, $2)
+                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                 RETURNING id
             ";
             sqlx::query_scalar(query)

server/src/rpc.rs 🔗

@@ -48,8 +48,9 @@ pub struct Server {
 #[derive(Default)]
 struct ServerState {
     connections: HashMap<ConnectionId, ConnectionState>,
+    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
     pub worktrees: HashMap<u64, Worktree>,
-    visible_worktrees_by_github_login: HashMap<String, HashSet<u64>>,
+    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
     channels: HashMap<ChannelId, Channel>,
     next_worktree_id: u64,
 }
@@ -62,7 +63,7 @@ struct ConnectionState {
 
 struct Worktree {
     host_connection_id: ConnectionId,
-    collaborator_github_logins: Vec<String>,
+    collaborator_user_ids: Vec<UserId>,
     root_name: String,
     share: Option<WorktreeShare>,
 }
@@ -113,7 +114,8 @@ impl Server {
             .add_handler(Server::join_channel)
             .add_handler(Server::leave_channel)
             .add_handler(Server::send_channel_message)
-            .add_handler(Server::get_channel_messages);
+            .add_handler(Server::get_channel_messages)
+            .add_handler(Server::get_collaborators);
 
         Arc::new(server)
     }
@@ -215,7 +217,8 @@ impl Server {
 
     // Add a new connection associated with a given user.
     async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
-        self.state.write().await.connections.insert(
+        let mut state = self.state.write().await;
+        state.connections.insert(
             connection_id,
             ConnectionState {
                 user_id,
@@ -223,6 +226,11 @@ impl Server {
                 channels: Default::default(),
             },
         );
+        state
+            .connections_by_user_id
+            .entry(user_id)
+            .or_default()
+            .insert(connection_id);
     }
 
     // Remove the given connection and its association with any worktrees.
@@ -249,6 +257,15 @@ impl Server {
                     }
                 }
             }
+
+            let user_connections = state
+                .connections_by_user_id
+                .get_mut(&connection.user_id)
+                .unwrap();
+            user_connections.remove(&connection_id);
+            if user_connections.is_empty() {
+                state.connections_by_user_id.remove(&connection.user_id);
+            }
         }
         worktree_ids
     }
@@ -264,10 +281,24 @@ impl Server {
     ) -> tide::Result<()> {
         let receipt = request.receipt();
 
+        let mut collaborator_user_ids = Vec::new();
+        for github_login in request.payload.collaborator_logins {
+            match self.app_state.db.create_user(&github_login, false).await {
+                Ok(user_id) => collaborator_user_ids.push(user_id),
+                Err(err) => {
+                    let message = err.to_string();
+                    self.peer
+                        .respond_with_error(receipt, proto::Error { message })
+                        .await?;
+                    return Ok(());
+                }
+            }
+        }
+
         let mut state = self.state.write().await;
         let worktree_id = state.add_worktree(Worktree {
             host_connection_id: request.sender_id,
-            collaborator_github_logins: request.payload.collaborator_logins,
+            collaborator_user_ids,
             root_name: request.payload.root_name,
             share: None,
         });
@@ -351,12 +382,16 @@ impl Server {
         request: TypedEnvelope<proto::JoinWorktree>,
     ) -> tide::Result<()> {
         let worktree_id = request.payload.worktree_id;
-        let user = self.user_for_connection(request.sender_id).await?;
+        let user_id = self
+            .state
+            .read()
+            .await
+            .user_id_for_connection(request.sender_id)?;
 
         let response;
         let connection_ids;
         let mut state = self.state.write().await;
-        match state.join_worktree(request.sender_id, &user, worktree_id) {
+        match state.join_worktree(request.sender_id, user_id, worktree_id) {
             Ok((peer_replica_id, worktree)) => {
                 let share = worktree.share()?;
                 let peer_count = share.guest_connection_ids.len();
@@ -639,6 +674,66 @@ impl Server {
         Ok(())
     }
 
+    async fn get_collaborators(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::GetCollaborators>,
+    ) -> tide::Result<()> {
+        let mut collaborators = HashMap::new();
+        {
+            let state = self.state.read().await;
+            let user_id = state.user_id_for_connection(request.sender_id)?;
+            for worktree_id in state
+                .visible_worktrees_by_user_id
+                .get(&user_id)
+                .unwrap_or(&HashSet::new())
+            {
+                let worktree = &state.worktrees[worktree_id];
+
+                let mut participants = Vec::new();
+                for collaborator_user_id in &worktree.collaborator_user_ids {
+                    collaborators
+                        .entry(*collaborator_user_id)
+                        .or_insert_with(|| proto::Collaborator {
+                            user_id: collaborator_user_id.to_proto(),
+                            worktrees: Vec::new(),
+                            is_online: state.is_online(*collaborator_user_id),
+                        });
+
+                    if let Ok(share) = worktree.share() {
+                        let mut conn_ids = state.user_connection_ids(*collaborator_user_id);
+                        if conn_ids.any(|c| share.guest_connection_ids.contains_key(&c)) {
+                            participants.push(collaborator_user_id.to_proto());
+                        }
+                    }
+                }
+
+                let host_user_id = state.user_id_for_connection(worktree.host_connection_id)?;
+                let host =
+                    collaborators
+                        .entry(host_user_id)
+                        .or_insert_with(|| proto::Collaborator {
+                            user_id: host_user_id.to_proto(),
+                            worktrees: Vec::new(),
+                            is_online: true,
+                        });
+                host.worktrees.push(proto::CollaboratorWorktree {
+                    is_shared: worktree.share().is_ok(),
+                    participants,
+                });
+            }
+        }
+
+        self.peer
+            .respond(
+                request.receipt(),
+                proto::GetCollaboratorsResponse {
+                    collaborators: collaborators.into_values().collect(),
+                },
+            )
+            .await?;
+        Ok(())
+    }
+
     async fn join_channel(
         self: Arc<Self>,
         request: TypedEnvelope<proto::JoinChannel>,
@@ -856,24 +951,6 @@ impl Server {
         Ok(())
     }
 
-    async fn user_for_connection(&self, connection_id: ConnectionId) -> tide::Result<User> {
-        let user_id = self
-            .state
-            .read()
-            .await
-            .connections
-            .get(&connection_id)
-            .ok_or_else(|| anyhow!("no such connection"))?
-            .user_id;
-        Ok(self
-            .app_state
-            .db
-            .get_users_by_ids(user_id, Some(user_id).into_iter())
-            .await?
-            .pop()
-            .ok_or_else(|| anyhow!("no such user"))?)
-    }
-
     async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
         &self,
         worktree_id: u64,
@@ -945,11 +1022,26 @@ impl ServerState {
             .user_id)
     }
 
+    fn user_connection_ids<'a>(
+        &'a self,
+        user_id: UserId,
+    ) -> impl 'a + Iterator<Item = ConnectionId> {
+        self.connections_by_user_id
+            .get(&user_id)
+            .into_iter()
+            .flatten()
+            .copied()
+    }
+
+    fn is_online(&self, user_id: UserId) -> bool {
+        self.connections_by_user_id.contains_key(&user_id)
+    }
+
     // Add the given connection as a guest of the given worktree
     fn join_worktree(
         &mut self,
         connection_id: ConnectionId,
-        user: &User,
+        user_id: UserId,
         worktree_id: u64,
     ) -> tide::Result<(ReplicaId, &Worktree)> {
         let connection = self
@@ -960,10 +1052,7 @@ impl ServerState {
             .worktrees
             .get_mut(&worktree_id)
             .ok_or_else(|| anyhow!("no such worktree"))?;
-        if !worktree
-            .collaborator_github_logins
-            .contains(&user.github_login)
-        {
+        if !worktree.collaborator_user_ids.contains(&user_id) {
             Err(anyhow!("no such worktree"))?;
         }
 
@@ -1032,9 +1121,9 @@ impl ServerState {
 
     fn add_worktree(&mut self, worktree: Worktree) -> u64 {
         let worktree_id = self.next_worktree_id;
-        for collaborator_login in &worktree.collaborator_github_logins {
-            self.visible_worktrees_by_github_login
-                .entry(collaborator_login.clone())
+        for collaborator_user_id in &worktree.collaborator_user_ids {
+            self.visible_worktrees_by_user_id
+                .entry(*collaborator_user_id)
                 .or_default()
                 .insert(worktree_id);
         }
@@ -1055,10 +1144,10 @@ impl ServerState {
                 }
             }
         }
-        for collaborator_login in worktree.collaborator_github_logins {
+        for collaborator_user_id in worktree.collaborator_user_ids {
             if let Some(visible_worktrees) = self
-                .visible_worktrees_by_github_login
-                .get_mut(&collaborator_login)
+                .visible_worktrees_by_user_id
+                .get_mut(&collaborator_user_id)
             {
                 visible_worktrees.remove(&worktree_id);
             }

zrpc/proto/zed.proto 🔗

@@ -38,6 +38,8 @@ message Envelope {
         OpenWorktree open_worktree = 33;
         OpenWorktreeResponse open_worktree_response = 34;
         UnshareWorktree unshare_worktree = 35;
+        GetCollaborators get_collaborators = 36;
+        GetCollaboratorsResponse get_collaborators_response = 37;
     }
 }
 
@@ -184,6 +186,12 @@ message GetChannelMessagesResponse {
     bool done = 2;
 }
 
+message GetCollaborators {}
+
+message GetCollaboratorsResponse {
+    repeated Collaborator collaborators = 1;
+}
+
 // Entities
 
 message Peer {
@@ -326,3 +334,14 @@ message ChannelMessage {
     uint64 sender_id = 4;
     Nonce nonce = 5;
 }
+
+message Collaborator {
+    uint64 user_id = 1;
+    repeated CollaboratorWorktree worktrees = 2;
+    bool is_online = 3;
+}
+
+message CollaboratorWorktree {
+    bool is_shared = 1;
+    repeated uint64 participants = 2;
+}

zrpc/src/proto.rs 🔗

@@ -131,6 +131,8 @@ messages!(
     GetChannelMessagesResponse,
     GetChannels,
     GetChannelsResponse,
+    GetCollaborators,
+    GetCollaboratorsResponse,
     GetUsers,
     GetUsersResponse,
     JoinChannel,
@@ -168,6 +170,7 @@ request_messages!(
     (UnshareWorktree, Ack),
     (SendChannelMessage, SendChannelMessageResponse),
     (GetChannelMessages, GetChannelMessagesResponse),
+    (GetCollaborators, GetCollaboratorsResponse),
 );
 
 entity_messages!(