Start reworking `join_project` to use the database

Antonio Scandurra created

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql |   3 
crates/collab/src/db.rs                                        | 152 +++
crates/collab/src/rpc.rs                                       |  43 
3 files changed, 164 insertions(+), 34 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -62,6 +62,9 @@ CREATE TABLE "worktrees" (
     "id" INTEGER NOT NULL,
     "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE,
     "root_name" VARCHAR NOT NULL,
+    "visible" BOOL NOT NULL,
+    "scan_id" INTEGER NOT NULL,
+    "is_complete" BOOL NOT NULL,
     PRIMARY KEY(project_id, id)
 );
 CREATE INDEX "index_worktrees_on_project_id" ON "worktrees" ("project_id");

crates/collab/src/db.rs 🔗

@@ -1,7 +1,7 @@
 use crate::{Error, Result};
 use anyhow::anyhow;
 use axum::http::StatusCode;
-use collections::HashMap;
+use collections::{BTreeMap, HashMap, HashSet};
 use futures::{future::BoxFuture, FutureExt, StreamExt};
 use rpc::{proto, ConnectionId};
 use serde::{Deserialize, Serialize};
@@ -10,7 +10,11 @@ use sqlx::{
     types::Uuid,
     FromRow,
 };
-use std::{future::Future, path::Path, time::Duration};
+use std::{
+    future::Future,
+    path::{Path, PathBuf},
+    time::Duration,
+};
 use time::{OffsetDateTime, PrimitiveDateTime};
 
 #[cfg(test)]
@@ -1404,13 +1408,26 @@ where
 
     pub async fn share_project(
         &self,
-        room_id: RoomId,
-        user_id: UserId,
+        expected_room_id: RoomId,
         connection_id: ConnectionId,
         worktrees: &[proto::WorktreeMetadata],
     ) -> Result<(ProjectId, proto::Room)> {
         self.transact(|mut tx| async move {
-            let project_id = sqlx::query_scalar(
+            let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
+                "
+                SELECT room_id, user_id
+                FROM room_participants
+                WHERE answering_connection_id = $1
+                ",
+            )
+            .bind(connection_id.0 as i32)
+            .fetch_one(&mut tx)
+            .await?;
+            if room_id != expected_room_id {
+                return Err(anyhow!("shared project on unexpected room"))?;
+            }
+
+            let project_id: ProjectId = sqlx::query_scalar(
                 "
                 INSERT INTO projects (room_id, host_user_id, host_connection_id)
                 VALUES ($1, $2, $3)
@@ -1421,8 +1438,7 @@ where
             .bind(user_id)
             .bind(connection_id.0 as i32)
             .fetch_one(&mut tx)
-            .await
-            .map(ProjectId)?;
+            .await?;
 
             for worktree in worktrees {
                 sqlx::query(
@@ -1536,6 +1552,111 @@ where
         .await
     }
 
+    pub async fn join_project(
+        &self,
+        project_id: ProjectId,
+        connection_id: ConnectionId,
+    ) -> Result<(Project, i32)> {
+        self.transact(|mut tx| async move {
+            let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
+                "
+                SELECT room_id, user_id
+                FROM room_participants
+                WHERE answering_connection_id = $1
+                ",
+            )
+            .bind(connection_id.0 as i32)
+            .fetch_one(&mut tx)
+            .await?;
+
+            // Ensure project id was shared on this room.
+            sqlx::query(
+                "
+                SELECT 1
+                FROM projects
+                WHERE project_id = $1 AND room_id = $2
+                ",
+            )
+            .bind(project_id)
+            .bind(room_id)
+            .fetch_one(&mut tx)
+            .await?;
+
+            let replica_ids = sqlx::query_scalar::<_, i32>(
+                "
+                SELECT replica_id
+                FROM project_collaborators
+                WHERE project_id = $1
+                ",
+            )
+            .bind(project_id)
+            .fetch_all(&mut tx)
+            .await?;
+            let replica_ids = HashSet::from_iter(replica_ids);
+            let mut replica_id = 1;
+            while replica_ids.contains(&replica_id) {
+                replica_id += 1;
+            }
+
+            sqlx::query(
+                "
+                INSERT INTO project_collaborators (
+                    project_id,
+                    connection_id,
+                    user_id,
+                    replica_id,
+                    is_host
+                )
+                VALUES ($1, $2, $3, $4, $5)
+                ",
+            )
+            .bind(project_id)
+            .bind(connection_id.0 as i32)
+            .bind(user_id)
+            .bind(replica_id)
+            .bind(false)
+            .execute(&mut tx)
+            .await?;
+
+            tx.commit().await?;
+            todo!()
+        })
+        .await
+        // sqlx::query(
+        //     "
+        //     SELECT replica_id
+        //     FROM project_collaborators
+        //     WHERE project_id = $
+        //     ",
+        // )
+        // .bind(project_id)
+        // .bind(connection_id.0 as i32)
+        // .bind(user_id)
+        // .bind(0)
+        // .bind(true)
+        // .execute(&mut tx)
+        // .await?;
+        // sqlx::query(
+        //     "
+        //     INSERT INTO project_collaborators (
+        //         project_id,
+        //         connection_id,
+        //         user_id,
+        //         replica_id,
+        //         is_host
+        //     )
+        //     VALUES ($1, $2, $3, $4, $5)
+        //     ",
+        // )
+        // .bind(project_id)
+        // .bind(connection_id.0 as i32)
+        // .bind(user_id)
+        // .bind(0)
+        // .bind(true)
+        // .execute(&mut tx)
+        // .await?;
+    }
+
     pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> {
         todo!()
         // test_support!(self, {
@@ -1967,11 +2088,11 @@ pub struct Room {
 }
 
 id_type!(ProjectId);
-#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 pub struct Project {
     pub id: ProjectId,
-    pub host_user_id: UserId,
-    pub unregistered: bool,
+    pub collaborators: Vec<ProjectCollaborator>,
+    pub worktrees: BTreeMap<u64, Worktree>,
+    pub language_servers: Vec<proto::LanguageServer>,
 }
 
 #[derive(Clone, Debug, Default, FromRow, PartialEq)]
@@ -1983,6 +2104,17 @@ pub struct ProjectCollaborator {
     pub is_host: bool,
 }
 
+#[derive(Default)]
+pub struct Worktree {
+    pub abs_path: PathBuf,
+    pub root_name: String,
+    pub visible: bool,
+    pub entries: BTreeMap<u64, proto::Entry>,
+    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
+    pub scan_id: u64,
+    pub is_complete: bool,
+}
+
 pub struct LeftProject {
     pub id: ProjectId,
     pub host_user_id: UserId,

crates/collab/src/rpc.rs 🔗

@@ -862,7 +862,6 @@ impl Server {
             .db
             .share_project(
                 RoomId::from_proto(request.payload.room_id),
-                request.sender_user_id,
                 request.sender_connection_id,
                 &request.payload.worktrees,
             )
@@ -942,15 +941,21 @@ impl Server {
 
         tracing::info!(%project_id, %host_user_id, %host_connection_id, "join project");
 
-        let mut store = self.store().await;
-        let (project, replica_id) = store.join_project(request.sender_connection_id, project_id)?;
-        let peer_count = project.guests.len();
-        let mut collaborators = Vec::with_capacity(peer_count);
-        collaborators.push(proto::Collaborator {
-            peer_id: project.host_connection_id.0,
-            replica_id: 0,
-            user_id: project.host.user_id.to_proto(),
-        });
+        let (project, replica_id) = self
+            .app_state
+            .db
+            .join_project(project_id, request.sender_connection_id)
+            .await?;
+
+        let collaborators = project
+            .collaborators
+            .iter()
+            .map(|collaborator| proto::Collaborator {
+                peer_id: collaborator.connection_id as u32,
+                replica_id: collaborator.replica_id as u32,
+                user_id: collaborator.user_id.to_proto(),
+            })
+            .collect::<Vec<_>>();
         let worktrees = project
             .worktrees
             .iter()
@@ -962,22 +967,12 @@ impl Server {
             })
             .collect::<Vec<_>>();
 
-        // Add all guests other than the requesting user's own connections as collaborators
-        for (guest_conn_id, guest) in &project.guests {
-            if request.sender_connection_id != *guest_conn_id {
-                collaborators.push(proto::Collaborator {
-                    peer_id: guest_conn_id.0,
-                    replica_id: guest.replica_id as u32,
-                    user_id: guest.user_id.to_proto(),
-                });
-            }
-        }
-
-        for conn_id in project.connection_ids() {
-            if conn_id != request.sender_connection_id {
+        for collaborator in &project.collaborators {
+            let connection_id = ConnectionId(collaborator.connection_id as u32);
+            if connection_id != request.sender_connection_id {
                 self.peer
                     .send(
-                        conn_id,
+                        connection_id,
                         proto::AddProjectCollaborator {
                             project_id: project_id.to_proto(),
                             collaborator: Some(proto::Collaborator {