Fix retrieving recent channel messages

Antonio Scandurra created

Change summary

server/src/db.rs  | 126 +++++++++++++++++++++++++++++-------------------
server/src/rpc.rs |  19 +++----
2 files changed, 84 insertions(+), 61 deletions(-)

Detailed changes

server/src/db.rs 🔗

@@ -21,8 +21,9 @@ macro_rules! test_support {
     }};
 }
 
+#[derive(Clone)]
 pub struct Db {
-    db: sqlx::PgPool,
+    pool: sqlx::PgPool,
     test_mode: bool,
 }
 
@@ -57,13 +58,13 @@ pub struct ChannelMessage {
 
 impl Db {
     pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
-        let db = DbOptions::new()
+        let pool = DbOptions::new()
             .max_connections(max_connections)
             .connect(url)
             .await
             .context("failed to connect to postgres database")?;
         Ok(Self {
-            db,
+            pool,
             test_mode: false,
         })
     }
@@ -86,7 +87,7 @@ impl Db {
                 .bind(github_login)
                 .bind(email_address)
                 .bind(about)
-                .fetch_one(&self.db)
+                .fetch_one(&self.pool)
                 .await
                 .map(SignupId)
         })
@@ -95,7 +96,7 @@ impl Db {
     pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
         test_support!(self, {
             let query = "SELECT * FROM users ORDER BY github_login ASC";
-            sqlx::query_as(query).fetch_all(&self.db).await
+            sqlx::query_as(query).fetch_all(&self.pool).await
         })
     }
 
@@ -104,7 +105,7 @@ impl Db {
             let query = "DELETE FROM signups WHERE id = $1";
             sqlx::query(query)
                 .bind(id.0)
-                .execute(&self.db)
+                .execute(&self.pool)
                 .await
                 .map(drop)
         })
@@ -122,7 +123,7 @@ impl Db {
             sqlx::query_scalar(query)
                 .bind(github_login)
                 .bind(admin)
-                .fetch_one(&self.db)
+                .fetch_one(&self.pool)
                 .await
                 .map(UserId)
         })
@@ -131,7 +132,7 @@ impl Db {
     pub async fn get_all_users(&self) -> Result<Vec<User>> {
         test_support!(self, {
             let query = "SELECT * FROM users ORDER BY github_login ASC";
-            sqlx::query_as(query).fetch_all(&self.db).await
+            sqlx::query_as(query).fetch_all(&self.pool).await
         })
     }
 
@@ -159,7 +160,7 @@ impl Db {
             sqlx::query_as(query)
                 .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
                 .bind(requester_id)
-                .fetch_all(&self.db)
+                .fetch_all(&self.pool)
                 .await
         })
     }
@@ -169,7 +170,7 @@ impl Db {
             let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
             sqlx::query_as(query)
                 .bind(github_login)
-                .fetch_optional(&self.db)
+                .fetch_optional(&self.pool)
                 .await
         })
     }
@@ -180,7 +181,7 @@ impl Db {
             sqlx::query(query)
                 .bind(is_admin)
                 .bind(id.0)
-                .execute(&self.db)
+                .execute(&self.pool)
                 .await
                 .map(drop)
         })
@@ -191,7 +192,7 @@ impl Db {
             let query = "DELETE FROM users WHERE id = $1;";
             sqlx::query(query)
                 .bind(id.0)
-                .execute(&self.db)
+                .execute(&self.pool)
                 .await
                 .map(drop)
         })
@@ -212,7 +213,7 @@ impl Db {
             sqlx::query(query)
                 .bind(user_id.0)
                 .bind(access_token_hash)
-                .execute(&self.db)
+                .execute(&self.pool)
                 .await
                 .map(drop)
         })
@@ -223,7 +224,7 @@ impl Db {
             let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
             sqlx::query_scalar(query)
                 .bind(user_id.0)
-                .fetch_all(&self.db)
+                .fetch_all(&self.pool)
                 .await
         })
     }
@@ -241,7 +242,7 @@ impl Db {
             sqlx::query_scalar(query)
                 .bind(name)
                 .bind(slug)
-                .fetch_one(&self.db)
+                .fetch_one(&self.pool)
                 .await
                 .map(OrgId)
         })
@@ -263,7 +264,7 @@ impl Db {
                 .bind(org_id.0)
                 .bind(user_id.0)
                 .bind(is_admin)
-                .execute(&self.db)
+                .execute(&self.pool)
                 .await
                 .map(drop)
         })
@@ -282,7 +283,7 @@ impl Db {
             sqlx::query_scalar(query)
                 .bind(org_id.0)
                 .bind(name)
-                .fetch_one(&self.db)
+                .fetch_one(&self.pool)
                 .await
                 .map(ChannelId)
         })
@@ -301,7 +302,7 @@ impl Db {
             ";
             sqlx::query_as(query)
                 .bind(user_id.0)
-                .fetch_all(&self.db)
+                .fetch_all(&self.pool)
                 .await
         })
     }
@@ -321,7 +322,7 @@ impl Db {
             sqlx::query_scalar::<_, i32>(query)
                 .bind(user_id.0)
                 .bind(channel_id.0)
-                .fetch_optional(&self.db)
+                .fetch_optional(&self.pool)
                 .await
                 .map(|e| e.is_some())
         })
@@ -343,7 +344,7 @@ impl Db {
                 .bind(channel_id.0)
                 .bind(user_id.0)
                 .bind(is_admin)
-                .execute(&self.db)
+                .execute(&self.pool)
                 .await
                 .map(drop)
         })
@@ -369,7 +370,7 @@ impl Db {
                 .bind(sender_id.0)
                 .bind(body)
                 .bind(timestamp)
-                .fetch_one(&self.db)
+                .fetch_one(&self.pool)
                 .await
                 .map(MessageId)
         })
@@ -382,36 +383,23 @@ impl Db {
     ) -> Result<Vec<ChannelMessage>> {
         test_support!(self, {
             let query = r#"
-                SELECT
-                    id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
-                FROM
-                    channel_messages
-                WHERE
-                    channel_id = $1
-                LIMIT $2
+                SELECT * FROM (
+                    SELECT
+                        id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
+                    FROM
+                        channel_messages
+                    WHERE
+                        channel_id = $1
+                    ORDER BY id DESC
+                    LIMIT $2
+                ) as recent_messages
+                ORDER BY id ASC
             "#;
             sqlx::query_as(query)
                 .bind(channel_id.0)
                 .bind(count as i64)
-                .fetch_all(&self.db)
-                .await
-        })
-    }
-
-    #[cfg(test)]
-    pub async fn close(&self, db_name: &str) {
-        test_support!(self, {
-            let query = "
-                SELECT pg_terminate_backend(pg_stat_activity.pid)
-                FROM pg_stat_activity
-                WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
-            ";
-            sqlx::query(query)
-                .bind(db_name)
-                .execute(&self.db)
+                .fetch_all(&self.pool)
                 .await
-                .unwrap();
-            self.db.close().await;
         })
     }
 }
@@ -454,12 +442,13 @@ pub mod tests {
     use std::path::Path;
 
     pub struct TestDb {
+        pub db: Db,
         pub name: String,
         pub url: String,
     }
 
     impl TestDb {
-        pub fn new() -> (Self, Db) {
+        pub fn new() -> Self {
             // Enable tests to run in parallel by serializing the creation of each test database.
             lazy_static::lazy_static! {
                 static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
@@ -479,17 +468,54 @@ pub mod tests {
                 let mut db = Db::new(&url, 5).await.unwrap();
                 db.test_mode = true;
                 let migrator = Migrator::new(migrations_path).await.unwrap();
-                migrator.run(&db.db).await.unwrap();
+                migrator.run(&db.pool).await.unwrap();
                 db
             });
 
-            (Self { name, url }, db)
+            Self { db, name, url }
+        }
+
+        pub fn db(&self) -> &Db {
+            &self.db
         }
     }
 
     impl Drop for TestDb {
         fn drop(&mut self) {
-            block_on(Postgres::drop_database(&self.url)).unwrap();
+            block_on(async {
+                let query = "
+                    SELECT pg_terminate_backend(pg_stat_activity.pid)
+                    FROM pg_stat_activity
+                    WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
+                ";
+                sqlx::query(query)
+                    .bind(&self.name)
+                    .execute(&self.db.pool)
+                    .await
+                    .unwrap();
+                self.db.pool.close().await;
+                Postgres::drop_database(&self.url).await.unwrap();
+            });
         }
     }
+
+    #[gpui::test]
+    async fn test_recent_channel_messages() {
+        let test_db = TestDb::new();
+        let db = test_db.db();
+        let user = db.create_user("user", false).await.unwrap();
+        let org = db.create_org("org", "org").await.unwrap();
+        let channel = db.create_org_channel(org, "channel").await.unwrap();
+        for i in 0..10 {
+            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc())
+                .await
+                .unwrap();
+        }
+
+        let messages = db.get_recent_channel_messages(channel, 5).await.unwrap();
+        assert_eq!(
+            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
+            ["5", "6", "7", "8", "9"]
+        );
+    }
 }

server/src/rpc.rs 🔗

@@ -919,7 +919,7 @@ mod tests {
     use super::*;
     use crate::{
         auth,
-        db::{tests::TestDb, Db, UserId},
+        db::{tests::TestDb, UserId},
         github, AppState, Config,
     };
     use async_std::{sync::RwLockReadGuard, task};
@@ -1529,14 +1529,14 @@ mod tests {
         peer: Arc<Peer>,
         app_state: Arc<AppState>,
         server: Arc<Server>,
-        test_db: TestDb,
         notifications: mpsc::Receiver<()>,
+        _test_db: TestDb,
     }
 
     impl TestServer {
         async fn start() -> Self {
-            let (test_db, db) = TestDb::new();
-            let app_state = Self::build_app_state(&test_db, db).await;
+            let test_db = TestDb::new();
+            let app_state = Self::build_app_state(&test_db).await;
             let peer = Peer::new();
             let notifications = mpsc::channel(128);
             let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
@@ -1544,8 +1544,8 @@ mod tests {
                 peer,
                 app_state,
                 server,
-                test_db,
                 notifications: notifications.1,
+                _test_db: test_db,
             }
         }
 
@@ -1570,13 +1570,13 @@ mod tests {
             (user_id, client)
         }
 
-        async fn build_app_state(test_db: &TestDb, db: Db) -> Arc<AppState> {
+        async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
             let mut config = Config::default();
             config.session_secret = "a".repeat(32);
             config.database_url = test_db.url.clone();
             let github_client = github::AppClient::test();
             Arc::new(AppState {
-                db,
+                db: test_db.db().clone(),
                 handlebars: Default::default(),
                 auth_client: auth::build_client("", ""),
                 repo_client: github::RepoClient::test(&github_client),
@@ -1605,10 +1605,7 @@ mod tests {
 
     impl Drop for TestServer {
         fn drop(&mut self) {
-            task::block_on(async {
-                self.peer.reset().await;
-                self.app_state.db.close(&self.test_db.name).await;
-            });
+            task::block_on(self.peer.reset());
         }
     }