WIP: add more trait bounds

Antonio Scandurra created

Change summary

crates/collab/src/db.rs | 119 ++++++++++++++++++++++++------------------
1 file changed, 67 insertions(+), 52 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -7,18 +7,11 @@ use serde::{Deserialize, Serialize};
 use sqlx::{
     migrate::{Migrate as _, Migration, MigrationSource},
     types::Uuid,
-    Encode, FromRow, QueryBuilder,
+    FromRow, QueryBuilder,
 };
 use std::{cmp, ops::Range, path::Path, time::Duration};
 use time::{OffsetDateTime, PrimitiveDateTime};
 
-#[cfg(any(test, debug_assertions))]
-pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> =
-    Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
-
-#[cfg(not(any(test, debug_assertions)))]
-pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> = None;
-
 pub struct Db<D: sqlx::Database> {
     pool: sqlx::Pool<D>,
 }
@@ -37,11 +30,27 @@ macro_rules! test_support {
     }};
 }
 
+trait RowsAffected {
+    fn rows_affected(&self) -> u64;
+}
+
+impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
+    fn rows_affected(&self) -> u64 {
+        self.rows_affected()
+    }
+}
+
+impl RowsAffected for sqlx::postgres::PgQueryResult {
+    fn rows_affected(&self) -> u64 {
+        self.rows_affected()
+    }
+}
+
 impl Db<sqlx::Sqlite> {
     #[cfg(test)]
-    pub async fn sqlite(url: &str) -> Result<Self> {
+    pub async fn sqlite(url: &str, max_connections: u32) -> Result<Self> {
         let pool = sqlx::sqlite::SqlitePoolOptions::new()
-            .max_connections(1)
+            .max_connections(max_connections)
             .connect(url)
             .await?;
         Ok(Self { pool })
@@ -49,9 +58,9 @@ impl Db<sqlx::Sqlite> {
 }
 
 impl Db<sqlx::Postgres> {
-    pub async fn postgres(url: &str, max_connection: u32) -> Result<Self> {
+    pub async fn postgres(url: &str, max_connections: u32) -> Result<Self> {
         let pool = sqlx::postgres::PgPoolOptions::new()
-            .max_connections(1)
+            .max_connections(max_connections)
             .connect(url)
             .await?;
         Ok(Self { pool })
@@ -62,19 +71,34 @@ impl<D> Db<D>
 where
     D: sqlx::Database + sqlx::migrate::MigrateDatabase,
     for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
-    D: for<'r> sqlx::database::HasValueRef<'r>,
-    D: for<'r> sqlx::database::HasArguments<'r>,
-    for<'a> &'a mut D::Connection: sqlx::Executor<'a>,
+    D: for<'a> sqlx::database::HasValueRef<'a>,
+    D: for<'a> sqlx::database::HasArguments<'a>,
+    D::Connection: sqlx::migrate::Migrate,
+    for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
+    for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
+    D::QueryResult: RowsAffected,
     String: sqlx::Type<D>,
     i32: sqlx::Type<D>,
+    i64: sqlx::Type<D>,
     bool: sqlx::Type<D>,
     str: sqlx::Type<D>,
-    for<'a> str: sqlx::Encode<'a, D>,
-    for<'a> &'a str: sqlx::Encode<'a, D>,
-    for<'a> String: sqlx::Encode<'a, D>,
-    for<'a> i32: sqlx::Encode<'a, D>,
-    for<'a> bool: sqlx::Encode<'a, D>,
-    for<'a> Option<String>: sqlx::Encode<'a, D>,
+    Uuid: sqlx::Type<D>,
+    sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
+    OffsetDateTime: sqlx::Type<D>,
+    PrimitiveDateTime: sqlx::Type<D>,
+    usize: sqlx::ColumnIndex<D::Row>,
+    for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
+    for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+    for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+    for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+    for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+    for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+    for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+    for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+    for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+    for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+    for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
+    for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
 {
     pub async fn migrate(
         &self,
@@ -1813,39 +1837,38 @@ mod test {
     use std::sync::Arc;
 
     pub struct TestDb {
-        pub db: Option<Arc<Db>>,
+        pub db: Option<Arc<Db<Sqlite>>>,
         pub url: String,
     }
 
     impl TestDb {
         #[allow(clippy::await_holding_lock)]
         pub async fn real() -> Self {
-            eprintln!("creating database...");
-            let start = std::time::Instant::now();
-            let mut rng = StdRng::from_entropy();
-            let url = format!("/tmp/zed-test-{}", rng.gen::<u128>());
-            Sqlite::create_database(&url).await.unwrap();
-            let db = Db::new(&url, 5).await.unwrap();
-            db.migrate(Path::new(DEFAULT_MIGRATIONS_PATH.unwrap()), false)
-                .await
-                .unwrap();
-
-            eprintln!("created database: {:?}", start.elapsed());
-            Self {
-                db: Some(Arc::new(db)),
-                url,
-            }
+            todo!()
+            // eprintln!("creating database...");
+            // let start = std::time::Instant::now();
+            // let mut rng = StdRng::from_entropy();
+            // let url = format!("/tmp/zed-test-{}", rng.gen::<u128>());
+            // Sqlite::create_database(&url).await.unwrap();
+            // let db = Db::new(&url, 5).await.unwrap();
+            // db.migrate(Path::new(DEFAULT_MIGRATIONS_PATH.unwrap()), false)
+            //     .await
+            //     .unwrap();
+
+            // eprintln!("created database: {:?}", start.elapsed());
+            // Self {
+            //     db: Some(Arc::new(db)),
+            //     url,
+            // }
         }
 
         pub async fn fake(background: Arc<Background>) -> Self {
             let start = std::time::Instant::now();
             let mut rng = StdRng::from_entropy();
-            let url = format!("/tmp/zed-test-{}", rng.gen::<u128>());
-            Sqlite::create_database(&url).await.unwrap();
-            let db = Db::new(&url, 5).await.unwrap();
-            db.migrate(Path::new(DEFAULT_MIGRATIONS_PATH.unwrap()), false)
-                .await
-                .unwrap();
+            let url = format!("file:db-{}?mode=memory&cache=shared", rng.gen::<u128>());
+            let db = Db::sqlite(&url, 5).await.unwrap();
+            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
+            db.migrate(Path::new(migrations_path), false).await.unwrap();
 
             Self {
                 db: Some(Arc::new(db)),
@@ -1853,16 +1876,8 @@ mod test {
             }
         }
 
-        pub fn db(&self) -> &Arc<Db> {
+        pub fn db(&self) -> &Arc<Db<Sqlite>> {
             self.db.as_ref().unwrap()
         }
     }
-
-    impl Drop for TestDb {
-        fn drop(&mut self) {
-            if let Some(db) = self.db.take() {
-                std::fs::remove_file(&self.url).ok();
-            }
-        }
-    }
 }