Start moving towards using sea-query to construct queries

Antonio Scandurra created

Change summary

Cargo.lock                     |  34 +++++++++
Cargo.toml                     |   1 
crates/collab/Cargo.toml       |  15 +--
crates/collab/src/db.rs        | 134 ++++++++++++++++++++++-------------
crates/collab/src/db/schema.rs |  43 +++++++++++
crates/collab/src/db/tests.rs  |   2 
crates/collab/src/main.rs      |   2 
7 files changed, 168 insertions(+), 63 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -1065,6 +1065,8 @@ dependencies = [
  "reqwest",
  "rpc",
  "scrypt",
+ "sea-query",
+ "sea-query-binder",
  "serde",
  "serde_json",
  "settings",
@@ -5121,6 +5123,38 @@ dependencies = [
  "untrusted",
 ]
 
+[[package]]
+name = "sea-query"
+version = "0.27.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a4f0fc4d8e44e1d51c739a68d336252a18bc59553778075d5e32649be6ec92ed"
+dependencies = [
+ "sea-query-derive",
+]
+
+[[package]]
+name = "sea-query-binder"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9c2585b89c985cfacfe0ec9fc9e7bb055b776c1a2581c4e3c6185af2b8bf8865"
+dependencies = [
+ "sea-query",
+ "sqlx",
+]
+
+[[package]]
+name = "sea-query-derive"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "34cdc022b4f606353fe5dc85b09713a04e433323b70163e81513b141c6ae6eb5"
+dependencies = [
+ "heck 0.3.3",
+ "proc-macro2",
+ "quote",
+ "syn",
+ "thiserror",
+]
+
 [[package]]
 name = "seahash"
 version = "4.1.0"

Cargo.toml 🔗

@@ -67,6 +67,7 @@ rand = { version = "0.8" }
 [patch.crates-io]
 tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "366210ae925d7ea0891bc7a0c738f60c77c04d7b" }
 async-task = { git = "https://github.com/zed-industries/async-task", rev = "341b57d6de98cdfd7b418567b8de2022ca993a6e" }
+sqlx = { git = "https://github.com/launchbadge/sqlx", rev = "4b7053807c705df312bcb9b6281e184bf7534eb3" }
 
 # TODO - Remove when a version is released with this PR: https://github.com/servo/core-foundation-rs/pull/457
 cocoa = { git = "https://github.com/servo/core-foundation-rs", rev = "079665882507dd5e2ff77db3de5070c1f6c0fb85" }

crates/collab/Cargo.toml 🔗

@@ -36,9 +36,12 @@ prometheus = "0.13"
 rand = "0.8"
 reqwest = { version = "0.11", features = ["json"], optional = true }
 scrypt = "0.7"
+sea-query = { version = "0.27", features = ["derive"] }
+sea-query-binder = { version = "0.2", features = ["sqlx-postgres"] }
 serde = { version = "1.0", features = ["derive", "rc"] }
 serde_json = "1.0"
 sha-1 = "0.9"
+sqlx = { version = "0.6", features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid"] }
 time = { version = "0.3", features = ["serde", "serde-well-known"] }
 tokio = { version = "1", features = ["full"] }
 tokio-tungstenite = "0.17"
@@ -49,11 +52,6 @@ tracing = "0.1.34"
 tracing-log = "0.1.3"
 tracing-subscriber = { version = "0.3.11", features = ["env-filter", "json"] }
 
-[dependencies.sqlx]
-git = "https://github.com/launchbadge/sqlx"
-rev = "4b7053807c705df312bcb9b6281e184bf7534eb3"
-features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid"]
-
 [dev-dependencies]
 collections = { path = "../collections", features = ["test-support"] }
 gpui = { path = "../gpui", features = ["test-support"] }
@@ -76,13 +74,10 @@ env_logger = "0.9"
 log = { version = "0.4.16", features = ["kv_unstable_serde"] }
 util = { path = "../util" }
 lazy_static = "1.4"
+sea-query-binder = { version = "0.2", features = ["sqlx-sqlite"] }
 serde_json = { version = "1.0", features = ["preserve_order"] }
+sqlx = { version = "0.6", features = ["sqlite"] }
 unindent = "0.1"
 
-[dev-dependencies.sqlx]
-git = "https://github.com/launchbadge/sqlx"
-rev = "4b7053807c705df312bcb9b6281e184bf7534eb3"
-features = ["sqlite"]
-
 [features]
 seed-support = ["clap", "lipsum", "reqwest"]

crates/collab/src/db.rs 🔗

@@ -1,3 +1,7 @@
+mod schema;
+#[cfg(test)]
+mod tests;
+
 use crate::{Error, Result};
 use anyhow::anyhow;
 use axum::http::StatusCode;
@@ -5,6 +9,8 @@ use collections::{BTreeMap, HashMap, HashSet};
 use dashmap::DashMap;
 use futures::{future::BoxFuture, FutureExt, StreamExt};
 use rpc::{proto, ConnectionId};
+use sea_query::{Expr, Query};
+use sea_query_binder::SqlxBinder;
 use serde::{Deserialize, Serialize};
 use sqlx::{
     migrate::{Migrate as _, Migration, MigrationSource},
@@ -89,6 +95,23 @@ impl BeginTransaction for Db<sqlx::Sqlite> {
     }
 }
 
+pub trait BuildQuery {
+    fn build_query<T: SqlxBinder>(&self, query: &T) -> (String, sea_query_binder::SqlxValues);
+}
+
+impl BuildQuery for Db<sqlx::Postgres> {
+    fn build_query<T: SqlxBinder>(&self, query: &T) -> (String, sea_query_binder::SqlxValues) {
+        query.build_sqlx(sea_query::PostgresQueryBuilder)
+    }
+}
+
+#[cfg(test)]
+impl BuildQuery for Db<sqlx::Sqlite> {
+    fn build_query<T: SqlxBinder>(&self, query: &T) -> (String, sea_query_binder::SqlxValues) {
+        query.build_sqlx(sea_query::SqliteQueryBuilder)
+    }
+}
+
 pub trait RowsAffected {
     fn rows_affected(&self) -> u64;
 }
@@ -595,10 +618,11 @@ impl Db<sqlx::Postgres> {
 
 impl<D> Db<D>
 where
-    Self: BeginTransaction<Database = D>,
+    Self: BeginTransaction<Database = D> + BuildQuery,
     D: sqlx::Database + sqlx::migrate::MigrateDatabase,
     D::Connection: sqlx::migrate::Migrate,
     for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
+    for<'a> sea_query_binder::SqlxValues: sqlx::IntoArguments<'a, D>,
     for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
     for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
     D::QueryResult: RowsAffected,
@@ -1537,63 +1561,66 @@ where
         worktrees: &[proto::WorktreeMetadata],
     ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
         self.transact(|mut tx| async move {
-            let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
-                "
-                SELECT room_id, user_id
-                FROM room_participants
-                WHERE answering_connection_id = $1
-                ",
-            )
-            .bind(connection_id.0 as i32)
-            .fetch_one(&mut tx)
-            .await?;
+            let (sql, values) = self.build_query(
+                Query::select()
+                    .columns([
+                        schema::room_participant::Definition::RoomId,
+                        schema::room_participant::Definition::UserId,
+                    ])
+                    .from(schema::room_participant::Definition::Table)
+                    .and_where(
+                        Expr::col(schema::room_participant::Definition::AnsweringConnectionId)
+                            .eq(connection_id.0),
+                    ),
+            );
+            let (room_id, user_id) = sqlx::query_as_with::<_, (RoomId, UserId), _>(&sql, values)
+                .fetch_one(&mut tx)
+                .await?;
             if room_id != expected_room_id {
                 return Err(anyhow!("shared project on unexpected room"))?;
             }
 
-            let project_id: ProjectId = sqlx::query_scalar(
-                "
-                INSERT INTO projects (room_id, host_user_id, host_connection_id)
-                VALUES ($1, $2, $3)
-                RETURNING id
-                ",
-            )
-            .bind(room_id)
-            .bind(user_id)
-            .bind(connection_id.0 as i32)
-            .fetch_one(&mut tx)
-            .await?;
+            let (sql, values) = self.build_query(
+                Query::insert()
+                    .into_table(schema::project::Definition::Table)
+                    .columns([
+                        schema::project::Definition::RoomId,
+                        schema::project::Definition::HostUserId,
+                        schema::project::Definition::HostConnectionId,
+                    ])
+                    .values_panic([room_id.into(), user_id.into(), connection_id.0.into()])
+                    .returning_col(schema::project::Definition::Id),
+            );
+            let project_id: ProjectId = sqlx::query_scalar_with(&sql, values)
+                .fetch_one(&mut tx)
+                .await?;
 
             if !worktrees.is_empty() {
-                let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len());
-                params.pop();
-                let query = format!(
-                    "
-                    INSERT INTO worktrees (
-                        project_id,
-                        id,
-                        root_name,
-                        abs_path,
-                        visible,
-                        scan_id,
-                        is_complete
-                    )
-                    VALUES {params}
-                    "
-                );
-
-                let mut query = sqlx::query(&query);
+                let mut query = Query::insert()
+                    .into_table(schema::worktree::Definition::Table)
+                    .columns([
+                        schema::worktree::Definition::ProjectId,
+                        schema::worktree::Definition::Id,
+                        schema::worktree::Definition::RootName,
+                        schema::worktree::Definition::AbsPath,
+                        schema::worktree::Definition::Visible,
+                        schema::worktree::Definition::ScanId,
+                        schema::worktree::Definition::IsComplete,
+                    ])
+                    .to_owned();
                 for worktree in worktrees {
-                    query = query
-                        .bind(project_id)
-                        .bind(worktree.id as i32)
-                        .bind(&worktree.root_name)
-                        .bind(&worktree.abs_path)
-                        .bind(worktree.visible)
-                        .bind(0)
-                        .bind(false);
+                    query.values_panic([
+                        project_id.into(),
+                        worktree.id.into(),
+                        worktree.root_name.clone().into(),
+                        worktree.abs_path.clone().into(),
+                        worktree.visible.into(),
+                        0.into(),
+                        false.into(),
+                    ]);
                 }
-                query.execute(&mut tx).await?;
+                let (sql, values) = self.build_query(&query);
+                sqlx::query_with(&sql, values).execute(&mut tx).await?;
             }
 
             sqlx::query(
@@ -2648,6 +2675,12 @@ macro_rules! id_type {
                 self.0.fmt(f)
             }
         }
+
+        impl From<$name> for sea_query::Value {
+            fn from(value: $name) -> Self {
+                sea_query::Value::Int(Some(value.0))
+            }
+        }
     };
 }
 
@@ -2692,6 +2725,7 @@ id_type!(WorktreeId);
 #[derive(Clone, Debug, Default, FromRow, PartialEq)]
 struct WorktreeRow {
     pub id: WorktreeId,
+    pub project_id: ProjectId,
     pub abs_path: String,
     pub root_name: String,
     pub visible: bool,

crates/collab/src/db/schema.rs 🔗

@@ -0,0 +1,43 @@
+pub mod project {
+    use sea_query::Iden;
+
+    #[derive(Iden)]
+    pub enum Definition {
+        #[iden = "projects"]
+        Table,
+        Id,
+        RoomId,
+        HostUserId,
+        HostConnectionId,
+    }
+}
+
+pub mod worktree {
+    use sea_query::Iden;
+
+    #[derive(Iden)]
+    pub enum Definition {
+        #[iden = "worktrees"]
+        Table,
+        Id,
+        ProjectId,
+        AbsPath,
+        RootName,
+        Visible,
+        ScanId,
+        IsComplete,
+    }
+}
+
+pub mod room_participant {
+    use sea_query::Iden;
+
+    #[derive(Iden)]
+    pub enum Definition {
+        #[iden = "room_participants"]
+        Table,
+        RoomId,
+        UserId,
+        AnsweringConnectionId,
+    }
+}

crates/collab/src/db_tests.rs → crates/collab/src/db/tests.rs 🔗

@@ -1,4 +1,4 @@
-use super::db::*;
+use super::*;
 use gpui::executor::{Background, Deterministic};
 use std::sync::Arc;