Move `Store::update_diagnostic_summary` to `Db`

Antonio Scandurra created

Change summary

crates/collab/src/db.rs        | 115 +++++++++++++++++++++++++++--------
crates/collab/src/rpc.rs       |  22 ++----
crates/collab/src/rpc/store.rs |  25 -------
3 files changed, 97 insertions(+), 65 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -1724,6 +1724,81 @@ where
         .await
     }
 
+    pub async fn update_diagnostic_summary(
+        &self,
+        update: &proto::UpdateDiagnosticSummary,
+        connection_id: ConnectionId,
+    ) -> Result<Vec<ConnectionId>> {
+        self.transact(|mut tx| async {
+            let project_id = ProjectId::from_proto(update.project_id);
+            let worktree_id = WorktreeId::from_proto(update.worktree_id);
+            let summary = update
+                .summary
+                .as_ref()
+                .ok_or_else(|| anyhow!("invalid summary"))?;
+
+            // Ensure the update comes from the host.
+            sqlx::query(
+                "
+                SELECT 1
+                FROM projects
+                WHERE id = $1 AND host_connection_id = $2
+                ",
+            )
+            .bind(project_id)
+            .bind(connection_id.0 as i32)
+            .fetch_one(&mut tx)
+            .await?;
+
+            // Update summary.
+            sqlx::query(
+                "
+                INSERT INTO worktree_diagnostic_summaries (
+                    project_id,
+                    worktree_id,
+                    path,
+                    language_server_id,
+                    error_count,
+                    warning_count
+                )
+                VALUES ($1, $2, $3, $4, $5, $6)
+                ON CONFLICT (project_id, worktree_id, path) DO UPDATE SET
+                    language_server_id = excluded.language_server_id,
+                    error_count = excluded.error_count, 
+                    warning_count = excluded.warning_count
+                ",
+            )
+            .bind(project_id)
+            .bind(worktree_id)
+            .bind(&summary.path)
+            .bind(summary.language_server_id as i64)
+            .bind(summary.error_count as i32)
+            .bind(summary.warning_count as i32)
+            .execute(&mut tx)
+            .await?;
+
+            let connection_ids = sqlx::query_scalar::<_, i32>(
+                "
+                SELECT connection_id
+                FROM project_collaborators
+                WHERE project_id = $1 AND connection_id != $2
+                ",
+            )
+            .bind(project_id)
+            .bind(connection_id.0 as i32)
+            .fetch_all(&mut tx)
+            .await?;
+
+            tx.commit().await?;
+
+            Ok(connection_ids
+                .into_iter()
+                .map(|connection_id| ConnectionId(connection_id as u32))
+                .collect())
+        })
+        .await
+    }
+
     pub async fn join_project(
         &self,
         project_id: ProjectId,
@@ -1830,25 +1905,17 @@ where
                 })
                 .collect::<BTreeMap<_, _>>();
 
-            let mut params = "?,".repeat(worktrees.len());
-            if !worktrees.is_empty() {
-                params.pop();
-            }
-
             // Populate worktree entries.
             {
-                let query = format!(
+                let mut entries = sqlx::query_as::<_, WorktreeEntry>(
                     "
-                        SELECT *
-                        FROM worktree_entries
-                        WHERE project_id = ? AND worktree_id IN ({params})
+                    SELECT *
+                    FROM worktree_entries
+                    WHERE project_id = $1
                     ",
-                );
-                let mut entries = sqlx::query_as::<_, WorktreeEntry>(&query).bind(project_id);
-                for worktree_id in worktrees.keys() {
-                    entries = entries.bind(*worktree_id);
-                }
-                let mut entries = entries.fetch(&mut tx);
+                )
+                .bind(project_id)
+                .fetch(&mut tx);
                 while let Some(entry) = entries.next().await {
                     let entry = entry?;
                     if let Some(worktree) = worktrees.get_mut(&entry.worktree_id) {
@@ -1870,19 +1937,15 @@ where
 
             // Populate worktree diagnostic summaries.
             {
-                let query = format!(
+                let mut summaries = sqlx::query_as::<_, WorktreeDiagnosticSummary>(
                     "
-                        SELECT *
-                        FROM worktree_diagnostic_summaries
-                        WHERE project_id = $1 AND worktree_id IN ({params})
+                    SELECT *
+                    FROM worktree_diagnostic_summaries
+                    WHERE project_id = $1
                     ",
-                );
-                let mut summaries =
-                    sqlx::query_as::<_, WorktreeDiagnosticSummary>(&query).bind(project_id);
-                for worktree_id in worktrees.keys() {
-                    summaries = summaries.bind(*worktree_id);
-                }
-                let mut summaries = summaries.fetch(&mut tx);
+                )
+                .bind(project_id)
+                .fetch(&mut tx);
                 while let Some(summary) = summaries.next().await {
                     let summary = summary?;
                     if let Some(worktree) = worktrees.get_mut(&summary.worktree_id) {

crates/collab/src/rpc.rs 🔗

@@ -1103,7 +1103,7 @@ impl Server {
         request: Message<proto::UpdateWorktree>,
         response: Response<proto::UpdateWorktree>,
     ) -> Result<()> {
-        let connection_ids = self
+        let guest_connection_ids = self
             .app_state
             .db
             .update_worktree(&request.payload, request.sender_connection_id)
@@ -1111,7 +1111,7 @@ impl Server {
 
         broadcast(
             request.sender_connection_id,
-            connection_ids,
+            guest_connection_ids,
             |connection_id| {
                 self.peer.forward_send(
                     request.sender_connection_id,
@@ -1128,21 +1128,15 @@ impl Server {
         self: Arc<Server>,
         request: Message<proto::UpdateDiagnosticSummary>,
     ) -> Result<()> {
-        let summary = request
-            .payload
-            .summary
-            .clone()
-            .ok_or_else(|| anyhow!("invalid summary"))?;
-        let receiver_ids = self.store().await.update_diagnostic_summary(
-            ProjectId::from_proto(request.payload.project_id),
-            request.payload.worktree_id,
-            request.sender_connection_id,
-            summary,
-        )?;
+        let guest_connection_ids = self
+            .app_state
+            .db
+            .update_diagnostic_summary(&request.payload, request.sender_connection_id)
+            .await?;
 
         broadcast(
             request.sender_connection_id,
-            receiver_ids,
+            guest_connection_ids,
             |connection_id| {
                 self.peer.forward_send(
                     request.sender_connection_id,

crates/collab/src/rpc/store.rs 🔗

@@ -251,31 +251,6 @@ impl Store {
         }
     }
 
-    pub fn update_diagnostic_summary(
-        &mut self,
-        project_id: ProjectId,
-        worktree_id: u64,
-        connection_id: ConnectionId,
-        summary: proto::DiagnosticSummary,
-    ) -> Result<Vec<ConnectionId>> {
-        let project = self
-            .projects
-            .get_mut(&project_id)
-            .ok_or_else(|| anyhow!("no such project"))?;
-        if project.host_connection_id == connection_id {
-            let worktree = project
-                .worktrees
-                .get_mut(&worktree_id)
-                .ok_or_else(|| anyhow!("no such worktree"))?;
-            worktree
-                .diagnostic_summaries
-                .insert(summary.path.clone().into(), summary);
-            return Ok(project.connection_ids());
-        }
-
-        Err(anyhow!("no such worktree"))?
-    }
-
     pub fn start_language_server(
         &mut self,
         project_id: ProjectId,