Add server methods for creating chat domain objects

Max Brunsfeld created

Also, consolidate all sql into a `db` module

Change summary

server/src/admin.rs |  65 +---------
server/src/auth.rs  |  79 ++++--------
server/src/db.rs    | 276 +++++++++++++++++++++++++++++++++++++++++++++++
server/src/home.rs  |  10 -
server/src/main.rs  |  17 +-
server/src/rpc.rs   |  16 +-
server/src/tests.rs |  21 +-
7 files changed, 344 insertions(+), 140 deletions(-)

Detailed changes

server/src/admin.rs 🔗

@@ -1,7 +1,6 @@
-use crate::{auth::RequestExt as _, AppState, DbPool, LayoutData, Request, RequestExt as _};
+use crate::{auth::RequestExt as _, db, AppState, LayoutData, Request, RequestExt as _};
 use async_trait::async_trait;
 use serde::{Deserialize, Serialize};
-use sqlx::{Executor, FromRow};
 use std::sync::Arc;
 use surf::http::mime;
 
@@ -41,23 +40,8 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>) {
 struct AdminData {
     #[serde(flatten)]
     layout: Arc<LayoutData>,
-    users: Vec<User>,
-    signups: Vec<Signup>,
-}
-
-#[derive(Debug, FromRow, Serialize)]
-pub struct User {
-    pub id: i32,
-    pub github_login: String,
-    pub admin: bool,
-}
-
-#[derive(Debug, FromRow, Serialize)]
-pub struct Signup {
-    pub id: i32,
-    pub github_login: String,
-    pub email_address: String,
-    pub about: String,
+    users: Vec<db::User>,
+    signups: Vec<db::Signup>,
 }
 
 async fn get_admin_page(mut request: Request) -> tide::Result {
@@ -65,12 +49,8 @@ async fn get_admin_page(mut request: Request) -> tide::Result {
 
     let data = AdminData {
         layout: request.layout_data().await?,
-        users: sqlx::query_as("SELECT * FROM users ORDER BY github_login ASC")
-            .fetch_all(request.db())
-            .await?,
-        signups: sqlx::query_as("SELECT * FROM signups ORDER BY id DESC")
-            .fetch_all(request.db())
-            .await?,
+        users: request.db().get_all_users().await?,
+        signups: request.db().get_all_signups().await?,
     };
 
     Ok(tide::Response::builder(200)
@@ -96,7 +76,7 @@ async fn post_user(mut request: Request) -> tide::Result {
         .unwrap_or(&form.github_login);
 
     if !github_login.is_empty() {
-        create_user(request.db(), github_login, form.admin).await?;
+        request.db().create_user(github_login, form.admin).await?;
     }
 
     Ok(tide::Redirect::new("/admin").into())
@@ -116,11 +96,7 @@ async fn put_user(mut request: Request) -> tide::Result {
 
     request
         .db()
-        .execute(
-            sqlx::query("UPDATE users SET admin = $1 WHERE id = $2;")
-                .bind(body.admin)
-                .bind(user_id),
-        )
+        .set_user_is_admin(db::UserId(user_id), body.admin)
         .await?;
 
     Ok(tide::Response::builder(200).build())
@@ -128,33 +104,14 @@ async fn put_user(mut request: Request) -> tide::Result {
 
 async fn delete_user(request: Request) -> tide::Result {
     request.require_admin().await?;
-
-    let user_id = request.param("id")?.parse::<i32>()?;
-    request
-        .db()
-        .execute(sqlx::query("DELETE FROM users WHERE id = $1;").bind(user_id))
-        .await?;
-
+    let user_id = db::UserId(request.param("id")?.parse::<i32>()?);
+    request.db().delete_user(user_id).await?;
     Ok(tide::Redirect::new("/admin").into())
 }
 
-pub async fn create_user(db: &DbPool, github_login: &str, admin: bool) -> tide::Result<i32> {
-    let id: i32 =
-        sqlx::query_scalar("INSERT INTO users (github_login, admin) VALUES ($1, $2) RETURNING id;")
-            .bind(github_login)
-            .bind(admin)
-            .fetch_one(db)
-            .await?;
-    Ok(id)
-}
-
 async fn delete_signup(request: Request) -> tide::Result {
     request.require_admin().await?;
-    let signup_id = request.param("id")?.parse::<i32>()?;
-    request
-        .db()
-        .execute(sqlx::query("DELETE FROM signups WHERE id = $1;").bind(signup_id))
-        .await?;
-
+    let signup_id = db::SignupId(request.param("id")?.parse::<i32>()?);
+    request.db().delete_signup(signup_id).await?;
     Ok(tide::Redirect::new("/admin").into())
 }

server/src/auth.rs 🔗

@@ -1,7 +1,9 @@
-use super::errors::TideResultExt;
-use crate::{github, rpc, AppState, DbPool, Request, RequestExt as _};
+use super::{
+    db::{self, UserId},
+    errors::TideResultExt,
+};
+use crate::{github, rpc, AppState, Request, RequestExt as _};
 use anyhow::{anyhow, Context};
-use async_std::stream::StreamExt;
 use async_trait::async_trait;
 pub use oauth2::basic::BasicClient as Client;
 use oauth2::{
@@ -14,7 +16,6 @@ use scrypt::{
     Scrypt,
 };
 use serde::{Deserialize, Serialize};
-use sqlx::FromRow;
 use std::{borrow::Cow, convert::TryFrom, sync::Arc};
 use surf::Url;
 use tide::Server;
@@ -34,9 +35,6 @@ pub struct User {
 
 pub struct VerifyToken;
 
-#[derive(Clone, Copy)]
-pub struct UserId(pub i32);
-
 #[async_trait]
 impl tide::Middleware<Arc<AppState>> for VerifyToken {
     async fn handle(
@@ -51,33 +49,28 @@ impl tide::Middleware<Arc<AppState>> for VerifyToken {
             .as_str()
             .split_whitespace();
 
-        let user_id: i32 = auth_header
-            .next()
-            .ok_or_else(|| anyhow!("missing user id in authorization header"))?
-            .parse()?;
+        let user_id = UserId(
+            auth_header
+                .next()
+                .ok_or_else(|| anyhow!("missing user id in authorization header"))?
+                .parse()?,
+        );
         let access_token = auth_header
             .next()
             .ok_or_else(|| anyhow!("missing access token in authorization header"))?;
 
         let state = request.state().clone();
 
-        let mut password_hashes =
-            sqlx::query_scalar::<_, String>("SELECT hash FROM access_tokens WHERE user_id = $1")
-                .bind(&user_id)
-                .fetch_many(&state.db);
-
         let mut credentials_valid = false;
-        while let Some(password_hash) = password_hashes.next().await {
-            if let either::Either::Right(password_hash) = password_hash? {
-                if verify_access_token(&access_token, &password_hash)? {
-                    credentials_valid = true;
-                    break;
-                }
+        for password_hash in state.db.get_access_token_hashes(user_id).await? {
+            if verify_access_token(&access_token, &password_hash)? {
+                credentials_valid = true;
+                break;
             }
         }
 
         if credentials_valid {
-            request.set_ext(UserId(user_id));
+            request.set_ext(user_id);
             Ok(next.run(request).await)
         } else {
             Err(anyhow!("invalid credentials").into())
@@ -94,25 +87,12 @@ pub trait RequestExt {
 impl RequestExt for Request {
     async fn current_user(&self) -> tide::Result<Option<User>> {
         if let Some(details) = self.session().get::<github::User>(CURRENT_GITHUB_USER) {
-            #[derive(FromRow)]
-            struct UserRow {
-                admin: bool,
-            }
-
-            let user_row: Option<UserRow> =
-                sqlx::query_as("SELECT admin FROM users WHERE github_login = $1")
-                    .bind(&details.login)
-                    .fetch_optional(self.db())
-                    .await?;
-
-            let is_insider = user_row.is_some();
-            let is_admin = user_row.map_or(false, |row| row.admin);
-
+            let user = self.db().get_user_by_github_login(&details.login).await?;
             Ok(Some(User {
                 github_login: details.login,
                 avatar_url: details.avatar_url,
-                is_insider,
-                is_admin,
+                is_insider: user.is_some(),
+                is_admin: user.map_or(false, |user| user.admin),
             }))
         } else {
             Ok(None)
@@ -265,9 +245,9 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
         .await
         .context("failed to fetch user")?;
 
-    let user_id: Option<i32> = sqlx::query_scalar("SELECT id from users where github_login = $1")
-        .bind(&user_details.login)
-        .fetch_optional(request.db())
+    let user = request
+        .db()
+        .get_user_by_github_login(&user_details.login)
         .await?;
 
     request
@@ -276,8 +256,8 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
 
     // When signing in from the native app, generate a new access token for the current user. Return
     // a redirect so that the user's browser sends this access token to the locally-running app.
-    if let Some((user_id, app_sign_in_params)) = user_id.zip(query.native_app_sign_in_params) {
-        let access_token = create_access_token(request.db(), user_id).await?;
+    if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) {
+        let access_token = create_access_token(request.db(), user.id()).await?;
         let native_app_public_key =
             zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone())
                 .context("failed to parse app public key")?;
@@ -287,7 +267,9 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
 
         return Ok(tide::Redirect::new(&format!(
             "http://127.0.0.1:{}?user_id={}&access_token={}",
-            app_sign_in_params.native_app_port, user_id, encrypted_access_token,
+            app_sign_in_params.native_app_port,
+            user.id().0,
+            encrypted_access_token,
         ))
         .into());
     }
@@ -300,14 +282,11 @@ async fn post_sign_out(mut request: Request) -> tide::Result {
     Ok(tide::Redirect::new("/").into())
 }
 
-pub async fn create_access_token(db: &DbPool, user_id: i32) -> tide::Result<String> {
+pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result<String> {
     let access_token = zed_auth::random_token();
     let access_token_hash =
         hash_access_token(&access_token).context("failed to hash access token")?;
-    sqlx::query("INSERT INTO access_tokens (user_id, hash) values ($1, $2)")
-        .bind(user_id)
-        .bind(access_token_hash)
-        .fetch_optional(db)
+    db.create_access_token_hash(user_id, access_token_hash)
         .await?;
     Ok(access_token)
 }

server/src/db.rs 🔗

@@ -0,0 +1,276 @@
+use serde::Serialize;
+use sqlx::{FromRow, Result};
+
+pub use async_sqlx_session::PostgresSessionStore as SessionStore;
+pub use sqlx::postgres::PgPoolOptions as DbOptions;
+
+pub struct Db(pub sqlx::PgPool);
+
+#[derive(Debug, FromRow, Serialize)]
+pub struct User {
+    id: i32,
+    pub github_login: String,
+    pub admin: bool,
+}
+
+#[derive(Debug, FromRow, Serialize)]
+pub struct Signup {
+    id: i32,
+    pub github_login: String,
+    pub email_address: String,
+    pub about: String,
+}
+
+#[derive(Debug, FromRow)]
+pub struct ChannelMessage {
+    id: i32,
+    sender_id: i32,
+    body: String,
+    sent_at: i64,
+}
+
+#[derive(Clone, Copy)]
+pub struct UserId(pub i32);
+
+#[derive(Clone, Copy)]
+pub struct OrgId(pub i32);
+
+#[derive(Clone, Copy)]
+pub struct ChannelId(pub i32);
+
+#[derive(Clone, Copy)]
+pub struct SignupId(pub i32);
+
+#[derive(Clone, Copy)]
+pub struct MessageId(pub i32);
+
+impl Db {
+    // signups
+
+    pub async fn create_signup(
+        &self,
+        github_login: &str,
+        email_address: &str,
+        about: &str,
+    ) -> Result<SignupId> {
+        let query = "
+            INSERT INTO signups (github_login, email_address, about)
+            VALUES ($1, $2, $3)
+            RETURNING id
+        ";
+        sqlx::query_scalar(query)
+            .bind(github_login)
+            .bind(email_address)
+            .bind(about)
+            .fetch_one(&self.0)
+            .await
+            .map(SignupId)
+    }
+
+    pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
+        let query = "SELECT * FROM users ORDER BY github_login ASC";
+        sqlx::query_as(query).fetch_all(&self.0).await
+    }
+
+    pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
+        let query = "DELETE FROM signups WHERE id = $1";
+        sqlx::query(query)
+            .bind(id.0)
+            .execute(&self.0)
+            .await
+            .map(drop)
+    }
+
+    // users
+
+    pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
+        let query = "
+            INSERT INTO users (github_login, admin)
+            VALUES ($1, $2)
+            RETURNING id
+        ";
+        sqlx::query_scalar(query)
+            .bind(github_login)
+            .bind(admin)
+            .fetch_one(&self.0)
+            .await
+            .map(UserId)
+    }
+
+    pub async fn get_all_users(&self) -> Result<Vec<User>> {
+        let query = "SELECT * FROM users ORDER BY github_login ASC";
+        sqlx::query_as(query).fetch_all(&self.0).await
+    }
+
+    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
+        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
+        sqlx::query_as(query)
+            .bind(github_login)
+            .fetch_optional(&self.0)
+            .await
+    }
+
+    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
+        let query = "UPDATE users SET admin = $1 WHERE id = $2";
+        sqlx::query(query)
+            .bind(is_admin)
+            .bind(id.0)
+            .execute(&self.0)
+            .await
+            .map(drop)
+    }
+
+    pub async fn delete_user(&self, id: UserId) -> Result<()> {
+        let query = "DELETE FROM users WHERE id = $1;";
+        sqlx::query(query)
+            .bind(id.0)
+            .execute(&self.0)
+            .await
+            .map(drop)
+    }
+
+    // access tokens
+
+    pub async fn create_access_token_hash(
+        &self,
+        user_id: UserId,
+        access_token_hash: String,
+    ) -> Result<()> {
+        let query = "
+            INSERT INTO access_tokens (user_id, hash)
+            VALUES ($1, $2)
+        ";
+        sqlx::query(query)
+            .bind(user_id.0 as i32)
+            .bind(access_token_hash)
+            .execute(&self.0)
+            .await
+            .map(drop)
+    }
+
+    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
+        let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
+        sqlx::query_scalar::<_, String>(query)
+            .bind(user_id.0 as i32)
+            .fetch_all(&self.0)
+            .await
+    }
+
+    // orgs
+
+    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
+        let query = "
+            INSERT INTO orgs (name, slug)
+            VALUES ($1, $2)
+            RETURNING id
+        ";
+        sqlx::query_scalar(query)
+            .bind(name)
+            .bind(slug)
+            .fetch_one(&self.0)
+            .await
+            .map(OrgId)
+    }
+
+    pub async fn add_org_member(&self, org_id: OrgId, user_id: UserId) -> Result<()> {
+        let query = "
+            INSERT INTO org_memberships (org_id, user_id)
+            VALUES ($1, $2)
+        ";
+        sqlx::query(query)
+            .bind(org_id.0)
+            .bind(user_id.0)
+            .execute(&self.0)
+            .await
+            .map(drop)
+    }
+
+    // channels
+
+    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
+        let query = "
+            INSERT INTO channels (owner_id, owner_is_user, name)
+            VALUES ($1, false, $2)
+            RETURNING id
+        ";
+        sqlx::query_scalar(query)
+            .bind(org_id.0)
+            .bind(name)
+            .fetch_one(&self.0)
+            .await
+            .map(ChannelId)
+    }
+
+    pub async fn add_channel_member(
+        &self,
+        channel_id: ChannelId,
+        user_id: UserId,
+        is_admin: bool,
+    ) -> Result<()> {
+        let query = "
+            INSERT INTO channel_memberships (channel_id, user_id, admin)
+            VALUES ($1, $2, $3)
+        ";
+        sqlx::query(query)
+            .bind(channel_id.0)
+            .bind(user_id.0)
+            .bind(is_admin)
+            .execute(&self.0)
+            .await
+            .map(drop)
+    }
+
+    // messages
+
+    pub async fn create_channel_message(
+        &self,
+        channel_id: ChannelId,
+        sender_id: UserId,
+        body: &str,
+    ) -> Result<MessageId> {
+        let query = "
+            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
+            VALUES ($1, $2, $3, NOW()::timestamp)
+            RETURNING id
+        ";
+        sqlx::query_scalar(query)
+            .bind(channel_id.0)
+            .bind(sender_id.0)
+            .bind(body)
+            .fetch_one(&self.0)
+            .await
+            .map(MessageId)
+    }
+
+    pub async fn get_recent_channel_messages(
+        &self,
+        channel_id: ChannelId,
+        count: usize,
+    ) -> Result<Vec<ChannelMessage>> {
+        let query = "
+            SELECT id, sender_id, body, sent_at
+            FROM channel_messages
+            WHERE channel_id = $1
+            LIMIT $2
+        ";
+        sqlx::query_as(query)
+            .bind(channel_id.0)
+            .bind(count as i64)
+            .fetch_all(&self.0)
+            .await
+    }
+}
+
+impl std::ops::Deref for Db {
+    type Target = sqlx::PgPool;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+impl User {
+    pub fn id(&self) -> UserId {
+        UserId(self.id)
+    }
+}

server/src/home.rs 🔗

@@ -3,7 +3,6 @@ use crate::{
 };
 use comrak::ComrakOptions;
 use serde::{Deserialize, Serialize};
-use sqlx::Executor as _;
 use std::sync::Arc;
 use tide::{http::mime, log, Server};
 
@@ -76,14 +75,7 @@ async fn post_signup(mut request: Request) -> tide::Result {
     // Save signup in the database
     request
         .db()
-        .execute(
-            sqlx::query(
-                "INSERT INTO signups (github_login, email_address, about) VALUES ($1, $2, $3);",
-            )
-            .bind(&form.github_login)
-            .bind(&form.email_address)
-            .bind(&form.about),
-        )
+        .create_signup(&form.github_login, &form.email_address, &form.about)
         .await?;
 
     let layout_data = request.layout_data().await?;

server/src/main.rs 🔗

@@ -1,6 +1,7 @@
 mod admin;
 mod assets;
 mod auth;
+mod db;
 mod env;
 mod errors;
 mod expiring;
@@ -13,15 +14,14 @@ mod tests;
 
 use self::errors::TideResultExt as _;
 use anyhow::{Context, Result};
-use async_sqlx_session::PostgresSessionStore;
 use async_std::{net::TcpListener, sync::RwLock as AsyncRwLock};
 use async_trait::async_trait;
 use auth::RequestExt as _;
+use db::{Db, DbOptions};
 use handlebars::{Handlebars, TemplateRenderError};
 use parking_lot::RwLock;
 use rust_embed::RustEmbed;
 use serde::{Deserialize, Serialize};
-use sqlx::postgres::{PgPool, PgPoolOptions};
 use std::sync::Arc;
 use surf::http::cookies::SameSite;
 use tide::{log, sessions::SessionMiddleware};
@@ -29,7 +29,6 @@ use tide_compress::CompressMiddleware;
 use zrpc::Peer;
 
 type Request = tide::Request<Arc<AppState>>;
-type DbPool = PgPool;
 
 #[derive(RustEmbed)]
 #[folder = "templates"]
@@ -47,7 +46,7 @@ pub struct Config {
 }
 
 pub struct AppState {
-    db: sqlx::PgPool,
+    db: Db,
     handlebars: RwLock<Handlebars<'static>>,
     auth_client: auth::Client,
     github_client: Arc<github::AppClient>,
@@ -58,11 +57,11 @@ pub struct AppState {
 
 impl AppState {
     async fn new(config: Config) -> tide::Result<Arc<Self>> {
-        let db = PgPoolOptions::new()
+        let db = Db(DbOptions::new()
             .max_connections(5)
             .connect(&config.database_url)
             .await
-            .context("failed to connect to postgres database")?;
+            .context("failed to connect to postgres database")?);
 
         let github_client =
             github::AppClient::new(config.github_app_id, config.github_private_key.clone());
@@ -117,7 +116,7 @@ impl AppState {
 #[async_trait]
 trait RequestExt {
     async fn layout_data(&mut self) -> tide::Result<Arc<LayoutData>>;
-    fn db(&self) -> &DbPool;
+    fn db(&self) -> &Db;
 }
 
 #[async_trait]
@@ -131,7 +130,7 @@ impl RequestExt for Request {
         Ok(self.ext::<Arc<LayoutData>>().unwrap().clone())
     }
 
-    fn db(&self) -> &DbPool {
+    fn db(&self) -> &Db {
         &self.state().db
     }
 }
@@ -173,7 +172,7 @@ pub async fn run_server(
     web.with(CompressMiddleware::new());
     web.with(
         SessionMiddleware::new(
-            PostgresSessionStore::new_with_table_name(&state.config.database_url, "sessions")
+            db::SessionStore::new_with_table_name(&state.config.database_url, "sessions")
                 .await
                 .unwrap(),
             state.config.session_secret.as_bytes(),

server/src/rpc.rs 🔗

@@ -1,6 +1,8 @@
-use crate::auth::{self, UserId};
-
-use super::{auth::PeerExt as _, AppState};
+use super::{
+    auth::{self, PeerExt as _},
+    db::UserId,
+    AppState,
+};
 use anyhow::anyhow;
 use async_std::task;
 use async_tungstenite::{
@@ -37,7 +39,7 @@ pub struct State {
 }
 
 struct ConnectionState {
-    _user_id: i32,
+    _user_id: UserId,
     worktrees: HashSet<u64>,
 }
 
@@ -68,7 +70,7 @@ impl WorktreeState {
 
 impl State {
     // Add a new connection associated with a given user.
-    pub fn add_connection(&mut self, connection_id: ConnectionId, _user_id: i32) {
+    pub fn add_connection(&mut self, connection_id: ConnectionId, _user_id: UserId) {
         self.connections.insert(
             connection_id,
             ConnectionState {
@@ -291,7 +293,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
             let upgrade_receiver = http_res.recv_upgrade().await;
             let addr = request.remote().unwrap_or("unknown").to_string();
             let state = request.state().clone();
-            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?.0;
+            let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
             task::spawn(async move {
                 if let Some(stream) = upgrade_receiver.await {
                     let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
@@ -310,7 +312,7 @@ pub async fn handle_connection<Conn>(
     state: Arc<AppState>,
     addr: String,
     stream: Conn,
-    user_id: i32,
+    user_id: UserId,
 ) where
     Conn: 'static
         + futures::Sink<WebSocketMessage, Error = WebSocketError>

server/src/tests.rs 🔗

@@ -1,5 +1,5 @@
 use crate::{
-    admin, auth, github,
+    auth, db, github,
     rpc::{self, add_rpc_routes},
     AppState, Config,
 };
@@ -9,7 +9,6 @@ use rand::prelude::*;
 use serde_json::json;
 use sqlx::{
     migrate::{MigrateDatabase, Migrator},
-    postgres::PgPoolOptions,
     Executor as _, Postgres,
 };
 use std::{path::Path, sync::Arc};
@@ -499,9 +498,7 @@ impl TestServer {
     }
 
     async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> Client {
-        let user_id = admin::create_user(&self.app_state.db, name, false)
-            .await
-            .unwrap();
+        let user_id = self.app_state.db.create_user(name, false).await.unwrap();
         let lang_registry = Arc::new(LanguageRegistry::new());
         let client = Client::new(lang_registry.clone());
         let mut client_router = ForegroundRouter::new();
@@ -532,18 +529,20 @@ impl TestServer {
         config.database_url = format!("postgres://postgres@localhost/{}", db_name);
 
         Self::create_db(&config.database_url).await;
-        let db = PgPoolOptions::new()
-            .max_connections(5)
-            .connect(&config.database_url)
-            .await
-            .expect("failed to connect to postgres database");
+        let db = db::Db(
+            db::DbOptions::new()
+                .max_connections(5)
+                .connect(&config.database_url)
+                .await
+                .expect("failed to connect to postgres database"),
+        );
         let migrator = Migrator::new(Path::new(concat!(
             env!("CARGO_MANIFEST_DIR"),
             "/migrations"
         )))
         .await
         .unwrap();
-        migrator.run(&db).await.unwrap();
+        migrator.run(&db.0).await.unwrap();
 
         let github_client = github::AppClient::test();
         Arc::new(AppState {