diff --git a/server/src/admin.rs b/server/src/admin.rs index 3f379ff56f9a1e2f2c5d34d41b20a24dfa683c7a..d6e3f8161589e6b2420cc9a7d54bd212ec8d76b1 100644 --- a/server/src/admin.rs +++ b/server/src/admin.rs @@ -1,7 +1,6 @@ -use crate::{auth::RequestExt as _, AppState, DbPool, LayoutData, Request, RequestExt as _}; +use crate::{auth::RequestExt as _, db, AppState, LayoutData, Request, RequestExt as _}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use sqlx::{Executor, FromRow}; use std::sync::Arc; use surf::http::mime; @@ -41,23 +40,8 @@ pub fn add_routes(app: &mut tide::Server>) { struct AdminData { #[serde(flatten)] layout: Arc, - users: Vec, - signups: Vec, -} - -#[derive(Debug, FromRow, Serialize)] -pub struct User { - pub id: i32, - pub github_login: String, - pub admin: bool, -} - -#[derive(Debug, FromRow, Serialize)] -pub struct Signup { - pub id: i32, - pub github_login: String, - pub email_address: String, - pub about: String, + users: Vec, + signups: Vec, } async fn get_admin_page(mut request: Request) -> tide::Result { @@ -65,12 +49,8 @@ async fn get_admin_page(mut request: Request) -> tide::Result { let data = AdminData { layout: request.layout_data().await?, - users: sqlx::query_as("SELECT * FROM users ORDER BY github_login ASC") - .fetch_all(request.db()) - .await?, - signups: sqlx::query_as("SELECT * FROM signups ORDER BY id DESC") - .fetch_all(request.db()) - .await?, + users: request.db().get_all_users().await?, + signups: request.db().get_all_signups().await?, }; Ok(tide::Response::builder(200) @@ -96,7 +76,7 @@ async fn post_user(mut request: Request) -> tide::Result { .unwrap_or(&form.github_login); if !github_login.is_empty() { - create_user(request.db(), github_login, form.admin).await?; + request.db().create_user(github_login, form.admin).await?; } Ok(tide::Redirect::new("/admin").into()) @@ -116,11 +96,7 @@ async fn put_user(mut request: Request) -> tide::Result { request .db() - .execute( - sqlx::query("UPDATE users SET admin = $1 WHERE id = $2;") - .bind(body.admin) - .bind(user_id), - ) + .set_user_is_admin(db::UserId(user_id), body.admin) .await?; Ok(tide::Response::builder(200).build()) @@ -128,33 +104,14 @@ async fn put_user(mut request: Request) -> tide::Result { async fn delete_user(request: Request) -> tide::Result { request.require_admin().await?; - - let user_id = request.param("id")?.parse::()?; - request - .db() - .execute(sqlx::query("DELETE FROM users WHERE id = $1;").bind(user_id)) - .await?; - + let user_id = db::UserId(request.param("id")?.parse::()?); + request.db().delete_user(user_id).await?; Ok(tide::Redirect::new("/admin").into()) } -pub async fn create_user(db: &DbPool, github_login: &str, admin: bool) -> tide::Result { - let id: i32 = - sqlx::query_scalar("INSERT INTO users (github_login, admin) VALUES ($1, $2) RETURNING id;") - .bind(github_login) - .bind(admin) - .fetch_one(db) - .await?; - Ok(id) -} - async fn delete_signup(request: Request) -> tide::Result { request.require_admin().await?; - let signup_id = request.param("id")?.parse::()?; - request - .db() - .execute(sqlx::query("DELETE FROM signups WHERE id = $1;").bind(signup_id)) - .await?; - + let signup_id = db::SignupId(request.param("id")?.parse::()?); + request.db().delete_signup(signup_id).await?; Ok(tide::Redirect::new("/admin").into()) } diff --git a/server/src/auth.rs b/server/src/auth.rs index 4a7107e550c2adc4838a703308554c93ea21464d..9dde8212ff2a848257dfafc2512443b4a0218c33 100644 --- a/server/src/auth.rs +++ b/server/src/auth.rs @@ -1,7 +1,9 @@ -use super::errors::TideResultExt; -use crate::{github, rpc, AppState, DbPool, Request, RequestExt as _}; +use super::{ + db::{self, UserId}, + errors::TideResultExt, +}; +use crate::{github, rpc, AppState, Request, RequestExt as _}; use anyhow::{anyhow, Context}; -use async_std::stream::StreamExt; use async_trait::async_trait; pub use oauth2::basic::BasicClient as Client; use oauth2::{ @@ -14,7 +16,6 @@ use scrypt::{ Scrypt, }; use serde::{Deserialize, Serialize}; -use sqlx::FromRow; use std::{borrow::Cow, convert::TryFrom, sync::Arc}; use surf::Url; use tide::Server; @@ -34,9 +35,6 @@ pub struct User { pub struct VerifyToken; -#[derive(Clone, Copy)] -pub struct UserId(pub i32); - #[async_trait] impl tide::Middleware> for VerifyToken { async fn handle( @@ -51,33 +49,28 @@ impl tide::Middleware> for VerifyToken { .as_str() .split_whitespace(); - let user_id: i32 = auth_header - .next() - .ok_or_else(|| anyhow!("missing user id in authorization header"))? - .parse()?; + let user_id = UserId( + auth_header + .next() + .ok_or_else(|| anyhow!("missing user id in authorization header"))? + .parse()?, + ); let access_token = auth_header .next() .ok_or_else(|| anyhow!("missing access token in authorization header"))?; let state = request.state().clone(); - let mut password_hashes = - sqlx::query_scalar::<_, String>("SELECT hash FROM access_tokens WHERE user_id = $1") - .bind(&user_id) - .fetch_many(&state.db); - let mut credentials_valid = false; - while let Some(password_hash) = password_hashes.next().await { - if let either::Either::Right(password_hash) = password_hash? { - if verify_access_token(&access_token, &password_hash)? { - credentials_valid = true; - break; - } + for password_hash in state.db.get_access_token_hashes(user_id).await? { + if verify_access_token(&access_token, &password_hash)? { + credentials_valid = true; + break; } } if credentials_valid { - request.set_ext(UserId(user_id)); + request.set_ext(user_id); Ok(next.run(request).await) } else { Err(anyhow!("invalid credentials").into()) @@ -94,25 +87,12 @@ pub trait RequestExt { impl RequestExt for Request { async fn current_user(&self) -> tide::Result> { if let Some(details) = self.session().get::(CURRENT_GITHUB_USER) { - #[derive(FromRow)] - struct UserRow { - admin: bool, - } - - let user_row: Option = - sqlx::query_as("SELECT admin FROM users WHERE github_login = $1") - .bind(&details.login) - .fetch_optional(self.db()) - .await?; - - let is_insider = user_row.is_some(); - let is_admin = user_row.map_or(false, |row| row.admin); - + let user = self.db().get_user_by_github_login(&details.login).await?; Ok(Some(User { github_login: details.login, avatar_url: details.avatar_url, - is_insider, - is_admin, + is_insider: user.is_some(), + is_admin: user.map_or(false, |user| user.admin), })) } else { Ok(None) @@ -265,9 +245,9 @@ async fn get_auth_callback(mut request: Request) -> tide::Result { .await .context("failed to fetch user")?; - let user_id: Option = sqlx::query_scalar("SELECT id from users where github_login = $1") - .bind(&user_details.login) - .fetch_optional(request.db()) + let user = request + .db() + .get_user_by_github_login(&user_details.login) .await?; request @@ -276,8 +256,8 @@ async fn get_auth_callback(mut request: Request) -> tide::Result { // When signing in from the native app, generate a new access token for the current user. Return // a redirect so that the user's browser sends this access token to the locally-running app. - if let Some((user_id, app_sign_in_params)) = user_id.zip(query.native_app_sign_in_params) { - let access_token = create_access_token(request.db(), user_id).await?; + if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) { + let access_token = create_access_token(request.db(), user.id()).await?; let native_app_public_key = zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone()) .context("failed to parse app public key")?; @@ -287,7 +267,9 @@ async fn get_auth_callback(mut request: Request) -> tide::Result { return Ok(tide::Redirect::new(&format!( "http://127.0.0.1:{}?user_id={}&access_token={}", - app_sign_in_params.native_app_port, user_id, encrypted_access_token, + app_sign_in_params.native_app_port, + user.id().0, + encrypted_access_token, )) .into()); } @@ -300,14 +282,11 @@ async fn post_sign_out(mut request: Request) -> tide::Result { Ok(tide::Redirect::new("/").into()) } -pub async fn create_access_token(db: &DbPool, user_id: i32) -> tide::Result { +pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result { let access_token = zed_auth::random_token(); let access_token_hash = hash_access_token(&access_token).context("failed to hash access token")?; - sqlx::query("INSERT INTO access_tokens (user_id, hash) values ($1, $2)") - .bind(user_id) - .bind(access_token_hash) - .fetch_optional(db) + db.create_access_token_hash(user_id, access_token_hash) .await?; Ok(access_token) } diff --git a/server/src/db.rs b/server/src/db.rs new file mode 100644 index 0000000000000000000000000000000000000000..9b610187012f781b384a6b7af50c4a54b790396e --- /dev/null +++ b/server/src/db.rs @@ -0,0 +1,276 @@ +use serde::Serialize; +use sqlx::{FromRow, Result}; + +pub use async_sqlx_session::PostgresSessionStore as SessionStore; +pub use sqlx::postgres::PgPoolOptions as DbOptions; + +pub struct Db(pub sqlx::PgPool); + +#[derive(Debug, FromRow, Serialize)] +pub struct User { + id: i32, + pub github_login: String, + pub admin: bool, +} + +#[derive(Debug, FromRow, Serialize)] +pub struct Signup { + id: i32, + pub github_login: String, + pub email_address: String, + pub about: String, +} + +#[derive(Debug, FromRow)] +pub struct ChannelMessage { + id: i32, + sender_id: i32, + body: String, + sent_at: i64, +} + +#[derive(Clone, Copy)] +pub struct UserId(pub i32); + +#[derive(Clone, Copy)] +pub struct OrgId(pub i32); + +#[derive(Clone, Copy)] +pub struct ChannelId(pub i32); + +#[derive(Clone, Copy)] +pub struct SignupId(pub i32); + +#[derive(Clone, Copy)] +pub struct MessageId(pub i32); + +impl Db { + // signups + + pub async fn create_signup( + &self, + github_login: &str, + email_address: &str, + about: &str, + ) -> Result { + let query = " + INSERT INTO signups (github_login, email_address, about) + VALUES ($1, $2, $3) + RETURNING id + "; + sqlx::query_scalar(query) + .bind(github_login) + .bind(email_address) + .bind(about) + .fetch_one(&self.0) + .await + .map(SignupId) + } + + pub async fn get_all_signups(&self) -> Result> { + let query = "SELECT * FROM users ORDER BY github_login ASC"; + sqlx::query_as(query).fetch_all(&self.0).await + } + + pub async fn delete_signup(&self, id: SignupId) -> Result<()> { + let query = "DELETE FROM signups WHERE id = $1"; + sqlx::query(query) + .bind(id.0) + .execute(&self.0) + .await + .map(drop) + } + + // users + + pub async fn create_user(&self, github_login: &str, admin: bool) -> Result { + let query = " + INSERT INTO users (github_login, admin) + VALUES ($1, $2) + RETURNING id + "; + sqlx::query_scalar(query) + .bind(github_login) + .bind(admin) + .fetch_one(&self.0) + .await + .map(UserId) + } + + pub async fn get_all_users(&self) -> Result> { + let query = "SELECT * FROM users ORDER BY github_login ASC"; + sqlx::query_as(query).fetch_all(&self.0).await + } + + pub async fn get_user_by_github_login(&self, github_login: &str) -> Result> { + let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1"; + sqlx::query_as(query) + .bind(github_login) + .fetch_optional(&self.0) + .await + } + + pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { + let query = "UPDATE users SET admin = $1 WHERE id = $2"; + sqlx::query(query) + .bind(is_admin) + .bind(id.0) + .execute(&self.0) + .await + .map(drop) + } + + pub async fn delete_user(&self, id: UserId) -> Result<()> { + let query = "DELETE FROM users WHERE id = $1;"; + sqlx::query(query) + .bind(id.0) + .execute(&self.0) + .await + .map(drop) + } + + // access tokens + + pub async fn create_access_token_hash( + &self, + user_id: UserId, + access_token_hash: String, + ) -> Result<()> { + let query = " + INSERT INTO access_tokens (user_id, hash) + VALUES ($1, $2) + "; + sqlx::query(query) + .bind(user_id.0 as i32) + .bind(access_token_hash) + .execute(&self.0) + .await + .map(drop) + } + + pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { + let query = "SELECT hash FROM access_tokens WHERE user_id = $1"; + sqlx::query_scalar::<_, String>(query) + .bind(user_id.0 as i32) + .fetch_all(&self.0) + .await + } + + // orgs + + pub async fn create_org(&self, name: &str, slug: &str) -> Result { + let query = " + INSERT INTO orgs (name, slug) + VALUES ($1, $2) + RETURNING id + "; + sqlx::query_scalar(query) + .bind(name) + .bind(slug) + .fetch_one(&self.0) + .await + .map(OrgId) + } + + pub async fn add_org_member(&self, org_id: OrgId, user_id: UserId) -> Result<()> { + let query = " + INSERT INTO org_memberships (org_id, user_id) + VALUES ($1, $2) + "; + sqlx::query(query) + .bind(org_id.0) + .bind(user_id.0) + .execute(&self.0) + .await + .map(drop) + } + + // channels + + pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result { + let query = " + INSERT INTO channels (owner_id, owner_is_user, name) + VALUES ($1, false, $2) + RETURNING id + "; + sqlx::query_scalar(query) + .bind(org_id.0) + .bind(name) + .fetch_one(&self.0) + .await + .map(ChannelId) + } + + pub async fn add_channel_member( + &self, + channel_id: ChannelId, + user_id: UserId, + is_admin: bool, + ) -> Result<()> { + let query = " + INSERT INTO channel_memberships (channel_id, user_id, admin) + VALUES ($1, $2, $3) + "; + sqlx::query(query) + .bind(channel_id.0) + .bind(user_id.0) + .bind(is_admin) + .execute(&self.0) + .await + .map(drop) + } + + // messages + + pub async fn create_channel_message( + &self, + channel_id: ChannelId, + sender_id: UserId, + body: &str, + ) -> Result { + let query = " + INSERT INTO channel_messages (channel_id, sender_id, body, sent_at) + VALUES ($1, $2, $3, NOW()::timestamp) + RETURNING id + "; + sqlx::query_scalar(query) + .bind(channel_id.0) + .bind(sender_id.0) + .bind(body) + .fetch_one(&self.0) + .await + .map(MessageId) + } + + pub async fn get_recent_channel_messages( + &self, + channel_id: ChannelId, + count: usize, + ) -> Result> { + let query = " + SELECT id, sender_id, body, sent_at + FROM channel_messages + WHERE channel_id = $1 + LIMIT $2 + "; + sqlx::query_as(query) + .bind(channel_id.0) + .bind(count as i64) + .fetch_all(&self.0) + .await + } +} + +impl std::ops::Deref for Db { + type Target = sqlx::PgPool; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl User { + pub fn id(&self) -> UserId { + UserId(self.id) + } +} diff --git a/server/src/home.rs b/server/src/home.rs index b4b8c24bf607302db7c15b109bd784bf38e667e2..25adde3a0f2fed3d7b11c107e23dcbfdf1d67bac 100644 --- a/server/src/home.rs +++ b/server/src/home.rs @@ -3,7 +3,6 @@ use crate::{ }; use comrak::ComrakOptions; use serde::{Deserialize, Serialize}; -use sqlx::Executor as _; use std::sync::Arc; use tide::{http::mime, log, Server}; @@ -76,14 +75,7 @@ async fn post_signup(mut request: Request) -> tide::Result { // Save signup in the database request .db() - .execute( - sqlx::query( - "INSERT INTO signups (github_login, email_address, about) VALUES ($1, $2, $3);", - ) - .bind(&form.github_login) - .bind(&form.email_address) - .bind(&form.about), - ) + .create_signup(&form.github_login, &form.email_address, &form.about) .await?; let layout_data = request.layout_data().await?; diff --git a/server/src/main.rs b/server/src/main.rs index ebd52b0a8bd0fd1e42beeaa505245852f545dd6f..ec153bea8fd535d513be81aab3f0b0bfc862d3b8 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,6 +1,7 @@ mod admin; mod assets; mod auth; +mod db; mod env; mod errors; mod expiring; @@ -13,15 +14,14 @@ mod tests; use self::errors::TideResultExt as _; use anyhow::{Context, Result}; -use async_sqlx_session::PostgresSessionStore; use async_std::{net::TcpListener, sync::RwLock as AsyncRwLock}; use async_trait::async_trait; use auth::RequestExt as _; +use db::{Db, DbOptions}; use handlebars::{Handlebars, TemplateRenderError}; use parking_lot::RwLock; use rust_embed::RustEmbed; use serde::{Deserialize, Serialize}; -use sqlx::postgres::{PgPool, PgPoolOptions}; use std::sync::Arc; use surf::http::cookies::SameSite; use tide::{log, sessions::SessionMiddleware}; @@ -29,7 +29,6 @@ use tide_compress::CompressMiddleware; use zrpc::Peer; type Request = tide::Request>; -type DbPool = PgPool; #[derive(RustEmbed)] #[folder = "templates"] @@ -47,7 +46,7 @@ pub struct Config { } pub struct AppState { - db: sqlx::PgPool, + db: Db, handlebars: RwLock>, auth_client: auth::Client, github_client: Arc, @@ -58,11 +57,11 @@ pub struct AppState { impl AppState { async fn new(config: Config) -> tide::Result> { - let db = PgPoolOptions::new() + let db = Db(DbOptions::new() .max_connections(5) .connect(&config.database_url) .await - .context("failed to connect to postgres database")?; + .context("failed to connect to postgres database")?); let github_client = github::AppClient::new(config.github_app_id, config.github_private_key.clone()); @@ -117,7 +116,7 @@ impl AppState { #[async_trait] trait RequestExt { async fn layout_data(&mut self) -> tide::Result>; - fn db(&self) -> &DbPool; + fn db(&self) -> &Db; } #[async_trait] @@ -131,7 +130,7 @@ impl RequestExt for Request { Ok(self.ext::>().unwrap().clone()) } - fn db(&self) -> &DbPool { + fn db(&self) -> &Db { &self.state().db } } @@ -173,7 +172,7 @@ pub async fn run_server( web.with(CompressMiddleware::new()); web.with( SessionMiddleware::new( - PostgresSessionStore::new_with_table_name(&state.config.database_url, "sessions") + db::SessionStore::new_with_table_name(&state.config.database_url, "sessions") .await .unwrap(), state.config.session_secret.as_bytes(), diff --git a/server/src/rpc.rs b/server/src/rpc.rs index 3c189833b252e2354e43683e63f4e2fba6db54c6..cd229a3c670f681a7c435635353af63bc2606305 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1,6 +1,8 @@ -use crate::auth::{self, UserId}; - -use super::{auth::PeerExt as _, AppState}; +use super::{ + auth::{self, PeerExt as _}, + db::UserId, + AppState, +}; use anyhow::anyhow; use async_std::task; use async_tungstenite::{ @@ -37,7 +39,7 @@ pub struct State { } struct ConnectionState { - _user_id: i32, + _user_id: UserId, worktrees: HashSet, } @@ -68,7 +70,7 @@ impl WorktreeState { impl State { // Add a new connection associated with a given user. - pub fn add_connection(&mut self, connection_id: ConnectionId, _user_id: i32) { + pub fn add_connection(&mut self, connection_id: ConnectionId, _user_id: UserId) { self.connections.insert( connection_id, ConnectionState { @@ -291,7 +293,7 @@ pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { let upgrade_receiver = http_res.recv_upgrade().await; let addr = request.remote().unwrap_or("unknown").to_string(); let state = request.state().clone(); - let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?.0; + let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?; task::spawn(async move { if let Some(stream) = upgrade_receiver.await { let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await; @@ -310,7 +312,7 @@ pub async fn handle_connection( state: Arc, addr: String, stream: Conn, - user_id: i32, + user_id: UserId, ) where Conn: 'static + futures::Sink diff --git a/server/src/tests.rs b/server/src/tests.rs index 66d904746772c5c4d0813ac85ab33b108c074207..d0257e9f4184027561fc15b78b4bc02b2998c541 100644 --- a/server/src/tests.rs +++ b/server/src/tests.rs @@ -1,5 +1,5 @@ use crate::{ - admin, auth, github, + auth, db, github, rpc::{self, add_rpc_routes}, AppState, Config, }; @@ -9,7 +9,6 @@ use rand::prelude::*; use serde_json::json; use sqlx::{ migrate::{MigrateDatabase, Migrator}, - postgres::PgPoolOptions, Executor as _, Postgres, }; use std::{path::Path, sync::Arc}; @@ -499,9 +498,7 @@ impl TestServer { } async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> Client { - let user_id = admin::create_user(&self.app_state.db, name, false) - .await - .unwrap(); + let user_id = self.app_state.db.create_user(name, false).await.unwrap(); let lang_registry = Arc::new(LanguageRegistry::new()); let client = Client::new(lang_registry.clone()); let mut client_router = ForegroundRouter::new(); @@ -532,18 +529,20 @@ impl TestServer { config.database_url = format!("postgres://postgres@localhost/{}", db_name); Self::create_db(&config.database_url).await; - let db = PgPoolOptions::new() - .max_connections(5) - .connect(&config.database_url) - .await - .expect("failed to connect to postgres database"); + let db = db::Db( + db::DbOptions::new() + .max_connections(5) + .connect(&config.database_url) + .await + .expect("failed to connect to postgres database"), + ); let migrator = Migrator::new(Path::new(concat!( env!("CARGO_MANIFEST_DIR"), "/migrations" ))) .await .unwrap(); - migrator.run(&db).await.unwrap(); + migrator.run(&db.0).await.unwrap(); let github_client = github::AppClient::test(); Arc::new(AppState {