Detailed changes
@@ -1,7 +1,6 @@
-use crate::{auth::RequestExt as _, AppState, DbPool, LayoutData, Request, RequestExt as _};
+use crate::{auth::RequestExt as _, db, AppState, LayoutData, Request, RequestExt as _};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
-use sqlx::{Executor, FromRow};
use std::sync::Arc;
use surf::http::mime;
@@ -41,23 +40,8 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>) {
struct AdminData {
#[serde(flatten)]
layout: Arc<LayoutData>,
- users: Vec<User>,
- signups: Vec<Signup>,
-}
-
-#[derive(Debug, FromRow, Serialize)]
-pub struct User {
- pub id: i32,
- pub github_login: String,
- pub admin: bool,
-}
-
-#[derive(Debug, FromRow, Serialize)]
-pub struct Signup {
- pub id: i32,
- pub github_login: String,
- pub email_address: String,
- pub about: String,
+ users: Vec<db::User>,
+ signups: Vec<db::Signup>,
}
async fn get_admin_page(mut request: Request) -> tide::Result {
@@ -65,12 +49,8 @@ async fn get_admin_page(mut request: Request) -> tide::Result {
let data = AdminData {
layout: request.layout_data().await?,
- users: sqlx::query_as("SELECT * FROM users ORDER BY github_login ASC")
- .fetch_all(request.db())
- .await?,
- signups: sqlx::query_as("SELECT * FROM signups ORDER BY id DESC")
- .fetch_all(request.db())
- .await?,
+ users: request.db().get_all_users().await?,
+ signups: request.db().get_all_signups().await?,
};
Ok(tide::Response::builder(200)
@@ -96,7 +76,7 @@ async fn post_user(mut request: Request) -> tide::Result {
.unwrap_or(&form.github_login);
if !github_login.is_empty() {
- create_user(request.db(), github_login, form.admin).await?;
+ request.db().create_user(github_login, form.admin).await?;
}
Ok(tide::Redirect::new("/admin").into())
@@ -116,11 +96,7 @@ async fn put_user(mut request: Request) -> tide::Result {
request
.db()
- .execute(
- sqlx::query("UPDATE users SET admin = $1 WHERE id = $2;")
- .bind(body.admin)
- .bind(user_id),
- )
+ .set_user_is_admin(db::UserId(user_id), body.admin)
.await?;
Ok(tide::Response::builder(200).build())
@@ -128,33 +104,14 @@ async fn put_user(mut request: Request) -> tide::Result {
async fn delete_user(request: Request) -> tide::Result {
request.require_admin().await?;
-
- let user_id = request.param("id")?.parse::<i32>()?;
- request
- .db()
- .execute(sqlx::query("DELETE FROM users WHERE id = $1;").bind(user_id))
- .await?;
-
+ let user_id = db::UserId(request.param("id")?.parse::<i32>()?);
+ request.db().delete_user(user_id).await?;
Ok(tide::Redirect::new("/admin").into())
}
-pub async fn create_user(db: &DbPool, github_login: &str, admin: bool) -> tide::Result<i32> {
- let id: i32 =
- sqlx::query_scalar("INSERT INTO users (github_login, admin) VALUES ($1, $2) RETURNING id;")
- .bind(github_login)
- .bind(admin)
- .fetch_one(db)
- .await?;
- Ok(id)
-}
-
async fn delete_signup(request: Request) -> tide::Result {
request.require_admin().await?;
- let signup_id = request.param("id")?.parse::<i32>()?;
- request
- .db()
- .execute(sqlx::query("DELETE FROM signups WHERE id = $1;").bind(signup_id))
- .await?;
-
+ let signup_id = db::SignupId(request.param("id")?.parse::<i32>()?);
+ request.db().delete_signup(signup_id).await?;
Ok(tide::Redirect::new("/admin").into())
}
@@ -1,7 +1,9 @@
-use super::errors::TideResultExt;
-use crate::{github, rpc, AppState, DbPool, Request, RequestExt as _};
+use super::{
+ db::{self, UserId},
+ errors::TideResultExt,
+};
+use crate::{github, rpc, AppState, Request, RequestExt as _};
use anyhow::{anyhow, Context};
-use async_std::stream::StreamExt;
use async_trait::async_trait;
pub use oauth2::basic::BasicClient as Client;
use oauth2::{
@@ -14,7 +16,6 @@ use scrypt::{
Scrypt,
};
use serde::{Deserialize, Serialize};
-use sqlx::FromRow;
use std::{borrow::Cow, convert::TryFrom, sync::Arc};
use surf::Url;
use tide::Server;
@@ -34,9 +35,6 @@ pub struct User {
pub struct VerifyToken;
-#[derive(Clone, Copy)]
-pub struct UserId(pub i32);
-
#[async_trait]
impl tide::Middleware<Arc<AppState>> for VerifyToken {
async fn handle(
@@ -51,33 +49,28 @@ impl tide::Middleware<Arc<AppState>> for VerifyToken {
.as_str()
.split_whitespace();
- let user_id: i32 = auth_header
- .next()
- .ok_or_else(|| anyhow!("missing user id in authorization header"))?
- .parse()?;
+ let user_id = UserId(
+ auth_header
+ .next()
+ .ok_or_else(|| anyhow!("missing user id in authorization header"))?
+ .parse()?,
+ );
let access_token = auth_header
.next()
.ok_or_else(|| anyhow!("missing access token in authorization header"))?;
let state = request.state().clone();
- let mut password_hashes =
- sqlx::query_scalar::<_, String>("SELECT hash FROM access_tokens WHERE user_id = $1")
- .bind(&user_id)
- .fetch_many(&state.db);
-
let mut credentials_valid = false;
- while let Some(password_hash) = password_hashes.next().await {
- if let either::Either::Right(password_hash) = password_hash? {
- if verify_access_token(&access_token, &password_hash)? {
- credentials_valid = true;
- break;
- }
+ for password_hash in state.db.get_access_token_hashes(user_id).await? {
+ if verify_access_token(&access_token, &password_hash)? {
+ credentials_valid = true;
+ break;
}
}
if credentials_valid {
- request.set_ext(UserId(user_id));
+ request.set_ext(user_id);
Ok(next.run(request).await)
} else {
Err(anyhow!("invalid credentials").into())
@@ -94,25 +87,12 @@ pub trait RequestExt {
impl RequestExt for Request {
async fn current_user(&self) -> tide::Result<Option<User>> {
if let Some(details) = self.session().get::<github::User>(CURRENT_GITHUB_USER) {
- #[derive(FromRow)]
- struct UserRow {
- admin: bool,
- }
-
- let user_row: Option<UserRow> =
- sqlx::query_as("SELECT admin FROM users WHERE github_login = $1")
- .bind(&details.login)
- .fetch_optional(self.db())
- .await?;
-
- let is_insider = user_row.is_some();
- let is_admin = user_row.map_or(false, |row| row.admin);
-
+ let user = self.db().get_user_by_github_login(&details.login).await?;
Ok(Some(User {
github_login: details.login,
avatar_url: details.avatar_url,
- is_insider,
- is_admin,
+ is_insider: user.is_some(),
+ is_admin: user.map_or(false, |user| user.admin),
}))
} else {
Ok(None)
@@ -265,9 +245,9 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
.await
.context("failed to fetch user")?;
- let user_id: Option<i32> = sqlx::query_scalar("SELECT id from users where github_login = $1")
- .bind(&user_details.login)
- .fetch_optional(request.db())
+ let user = request
+ .db()
+ .get_user_by_github_login(&user_details.login)
.await?;
request
@@ -276,8 +256,8 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
// When signing in from the native app, generate a new access token for the current user. Return
// a redirect so that the user's browser sends this access token to the locally-running app.
- if let Some((user_id, app_sign_in_params)) = user_id.zip(query.native_app_sign_in_params) {
- let access_token = create_access_token(request.db(), user_id).await?;
+ if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) {
+ let access_token = create_access_token(request.db(), user.id()).await?;
let native_app_public_key =
zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone())
.context("failed to parse app public key")?;
@@ -287,7 +267,9 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
return Ok(tide::Redirect::new(&format!(
"http://127.0.0.1:{}?user_id={}&access_token={}",
- app_sign_in_params.native_app_port, user_id, encrypted_access_token,
+ app_sign_in_params.native_app_port,
+ user.id().0,
+ encrypted_access_token,
))
.into());
}
@@ -300,14 +282,11 @@ async fn post_sign_out(mut request: Request) -> tide::Result {
Ok(tide::Redirect::new("/").into())
}
-pub async fn create_access_token(db: &DbPool, user_id: i32) -> tide::Result<String> {
+pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result<String> {
let access_token = zed_auth::random_token();
let access_token_hash =
hash_access_token(&access_token).context("failed to hash access token")?;
- sqlx::query("INSERT INTO access_tokens (user_id, hash) values ($1, $2)")
- .bind(user_id)
- .bind(access_token_hash)
- .fetch_optional(db)
+ db.create_access_token_hash(user_id, access_token_hash)
.await?;
Ok(access_token)
}
@@ -0,0 +1,276 @@
+use serde::Serialize;
+use sqlx::{FromRow, Result};
+
+pub use async_sqlx_session::PostgresSessionStore as SessionStore;
+pub use sqlx::postgres::PgPoolOptions as DbOptions;
+
+pub struct Db(pub sqlx::PgPool);
+
+#[derive(Debug, FromRow, Serialize)]
+pub struct User {
+ id: i32,
+ pub github_login: String,
+ pub admin: bool,
+}
+
+#[derive(Debug, FromRow, Serialize)]
+pub struct Signup {
+ id: i32,
+ pub github_login: String,
+ pub email_address: String,
+ pub about: String,
+}
+
+#[derive(Debug, FromRow)]
+pub struct ChannelMessage {
+ id: i32,
+ sender_id: i32,
+ body: String,
+ sent_at: i64,
+}
+
+#[derive(Clone, Copy)]
+pub struct UserId(pub i32);
+
+#[derive(Clone, Copy)]
+pub struct OrgId(pub i32);
+
+#[derive(Clone, Copy)]
+pub struct ChannelId(pub i32);
+
+#[derive(Clone, Copy)]
+pub struct SignupId(pub i32);
+
+#[derive(Clone, Copy)]
+pub struct MessageId(pub i32);
+
+impl Db {
+ // signups
+
+ pub async fn create_signup(
+ &self,
+ github_login: &str,
+ email_address: &str,
+ about: &str,
+ ) -> Result<SignupId> {
+ let query = "
+ INSERT INTO signups (github_login, email_address, about)
+ VALUES ($1, $2, $3)
+ RETURNING id
+ ";
+ sqlx::query_scalar(query)
+ .bind(github_login)
+ .bind(email_address)
+ .bind(about)
+ .fetch_one(&self.0)
+ .await
+ .map(SignupId)
+ }
+
+ pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
+ let query = "SELECT * FROM users ORDER BY github_login ASC";
+ sqlx::query_as(query).fetch_all(&self.0).await
+ }
+
+ pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
+ let query = "DELETE FROM signups WHERE id = $1";
+ sqlx::query(query)
+ .bind(id.0)
+ .execute(&self.0)
+ .await
+ .map(drop)
+ }
+
+ // users
+
+ pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
+ let query = "
+ INSERT INTO users (github_login, admin)
+ VALUES ($1, $2)
+ RETURNING id
+ ";
+ sqlx::query_scalar(query)
+ .bind(github_login)
+ .bind(admin)
+ .fetch_one(&self.0)
+ .await
+ .map(UserId)
+ }
+
+ pub async fn get_all_users(&self) -> Result<Vec<User>> {
+ let query = "SELECT * FROM users ORDER BY github_login ASC";
+ sqlx::query_as(query).fetch_all(&self.0).await
+ }
+
+ pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
+ let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
+ sqlx::query_as(query)
+ .bind(github_login)
+ .fetch_optional(&self.0)
+ .await
+ }
+
+ pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
+ let query = "UPDATE users SET admin = $1 WHERE id = $2";
+ sqlx::query(query)
+ .bind(is_admin)
+ .bind(id.0)
+ .execute(&self.0)
+ .await
+ .map(drop)
+ }
+
+ pub async fn delete_user(&self, id: UserId) -> Result<()> {
+ let query = "DELETE FROM users WHERE id = $1;";
+ sqlx::query(query)
+ .bind(id.0)
+ .execute(&self.0)
+ .await
+ .map(drop)
+ }
+
+ // access tokens
+
+ pub async fn create_access_token_hash(
+ &self,
+ user_id: UserId,
+ access_token_hash: String,
+ ) -> Result<()> {
+ let query = "
+ INSERT INTO access_tokens (user_id, hash)
+ VALUES ($1, $2)
+ ";
+ sqlx::query(query)
+ .bind(user_id.0 as i32)
+ .bind(access_token_hash)
+ .execute(&self.0)
+ .await
+ .map(drop)
+ }
+
+ pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
+ let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
+ sqlx::query_scalar::<_, String>(query)
+ .bind(user_id.0 as i32)
+ .fetch_all(&self.0)
+ .await
+ }
+
+ // orgs
+
+ pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
+ let query = "
+ INSERT INTO orgs (name, slug)
+ VALUES ($1, $2)
+ RETURNING id
+ ";
+ sqlx::query_scalar(query)
+ .bind(name)
+ .bind(slug)
+ .fetch_one(&self.0)
+ .await
+ .map(OrgId)
+ }
+
+ pub async fn add_org_member(&self, org_id: OrgId, user_id: UserId) -> Result<()> {
+ let query = "
+ INSERT INTO org_memberships (org_id, user_id)
+ VALUES ($1, $2)
+ ";
+ sqlx::query(query)
+ .bind(org_id.0)
+ .bind(user_id.0)
+ .execute(&self.0)
+ .await
+ .map(drop)
+ }
+
+ // channels
+
+ pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
+ let query = "
+ INSERT INTO channels (owner_id, owner_is_user, name)
+ VALUES ($1, false, $2)
+ RETURNING id
+ ";
+ sqlx::query_scalar(query)
+ .bind(org_id.0)
+ .bind(name)
+ .fetch_one(&self.0)
+ .await
+ .map(ChannelId)
+ }
+
+ pub async fn add_channel_member(
+ &self,
+ channel_id: ChannelId,
+ user_id: UserId,
+ is_admin: bool,
+ ) -> Result<()> {
+ let query = "
+ INSERT INTO channel_memberships (channel_id, user_id, admin)
+ VALUES ($1, $2, $3)
+ ";
+ sqlx::query(query)
+ .bind(channel_id.0)
+ .bind(user_id.0)
+ .bind(is_admin)
+ .execute(&self.0)
+ .await
+ .map(drop)
+ }
+
+ // messages
+
+ pub async fn create_channel_message(
+ &self,
+ channel_id: ChannelId,
+ sender_id: UserId,
+ body: &str,
+ ) -> Result<MessageId> {
+ let query = "
+ INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
+ VALUES ($1, $2, $3, NOW()::timestamp)
+ RETURNING id
+ ";
+ sqlx::query_scalar(query)
+ .bind(channel_id.0)
+ .bind(sender_id.0)
+ .bind(body)
+ .fetch_one(&self.0)
+ .await
+ .map(MessageId)
+ }
+
+ pub async fn get_recent_channel_messages(
+ &self,
+ channel_id: ChannelId,
+ count: usize,
+ ) -> Result<Vec<ChannelMessage>> {
+ let query = "
+ SELECT id, sender_id, body, sent_at
+ FROM channel_messages
+ WHERE channel_id = $1
+ LIMIT $2
+ ";
+ sqlx::query_as(query)
+ .bind(channel_id.0)
+ .bind(count as i64)
+ .fetch_all(&self.0)
+ .await
+ }
+}
+
+impl std::ops::Deref for Db {
+ type Target = sqlx::PgPool;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl User {
+ pub fn id(&self) -> UserId {
+ UserId(self.id)
+ }
+}
@@ -3,7 +3,6 @@ use crate::{
};
use comrak::ComrakOptions;
use serde::{Deserialize, Serialize};
-use sqlx::Executor as _;
use std::sync::Arc;
use tide::{http::mime, log, Server};
@@ -76,14 +75,7 @@ async fn post_signup(mut request: Request) -> tide::Result {
// Save signup in the database
request
.db()
- .execute(
- sqlx::query(
- "INSERT INTO signups (github_login, email_address, about) VALUES ($1, $2, $3);",
- )
- .bind(&form.github_login)
- .bind(&form.email_address)
- .bind(&form.about),
- )
+ .create_signup(&form.github_login, &form.email_address, &form.about)
.await?;
let layout_data = request.layout_data().await?;
@@ -1,6 +1,7 @@
mod admin;
mod assets;
mod auth;
+mod db;
mod env;
mod errors;
mod expiring;
@@ -13,15 +14,14 @@ mod tests;
use self::errors::TideResultExt as _;
use anyhow::{Context, Result};
-use async_sqlx_session::PostgresSessionStore;
use async_std::{net::TcpListener, sync::RwLock as AsyncRwLock};
use async_trait::async_trait;
use auth::RequestExt as _;
+use db::{Db, DbOptions};
use handlebars::{Handlebars, TemplateRenderError};
use parking_lot::RwLock;
use rust_embed::RustEmbed;
use serde::{Deserialize, Serialize};
-use sqlx::postgres::{PgPool, PgPoolOptions};
use std::sync::Arc;
use surf::http::cookies::SameSite;
use tide::{log, sessions::SessionMiddleware};
@@ -29,7 +29,6 @@ use tide_compress::CompressMiddleware;
use zrpc::Peer;
type Request = tide::Request<Arc<AppState>>;
-type DbPool = PgPool;
#[derive(RustEmbed)]
#[folder = "templates"]
@@ -47,7 +46,7 @@ pub struct Config {
}
pub struct AppState {
- db: sqlx::PgPool,
+ db: Db,
handlebars: RwLock<Handlebars<'static>>,
auth_client: auth::Client,
github_client: Arc<github::AppClient>,
@@ -58,11 +57,11 @@ pub struct AppState {
impl AppState {
async fn new(config: Config) -> tide::Result<Arc<Self>> {
- let db = PgPoolOptions::new()
+ let db = Db(DbOptions::new()
.max_connections(5)
.connect(&config.database_url)
.await
- .context("failed to connect to postgres database")?;
+ .context("failed to connect to postgres database")?);
let github_client =
github::AppClient::new(config.github_app_id, config.github_private_key.clone());
@@ -117,7 +116,7 @@ impl AppState {
#[async_trait]
trait RequestExt {
async fn layout_data(&mut self) -> tide::Result<Arc<LayoutData>>;
- fn db(&self) -> &DbPool;
+ fn db(&self) -> &Db;
}
#[async_trait]
@@ -131,7 +130,7 @@ impl RequestExt for Request {
Ok(self.ext::<Arc<LayoutData>>().unwrap().clone())
}
- fn db(&self) -> &DbPool {
+ fn db(&self) -> &Db {
&self.state().db
}
}
@@ -173,7 +172,7 @@ pub async fn run_server(
web.with(CompressMiddleware::new());
web.with(
SessionMiddleware::new(
- PostgresSessionStore::new_with_table_name(&state.config.database_url, "sessions")
+ db::SessionStore::new_with_table_name(&state.config.database_url, "sessions")
.await
.unwrap(),
state.config.session_secret.as_bytes(),
@@ -1,6 +1,8 @@
-use crate::auth::{self, UserId};
-
-use super::{auth::PeerExt as _, AppState};
+use super::{
+ auth::{self, PeerExt as _},
+ db::UserId,
+ AppState,
+};
use anyhow::anyhow;
use async_std::task;
use async_tungstenite::{
@@ -37,7 +39,7 @@ pub struct State {
}
struct ConnectionState {
- _user_id: i32,
+ _user_id: UserId,
worktrees: HashSet<u64>,
}
@@ -68,7 +70,7 @@ impl WorktreeState {
impl State {
// Add a new connection associated with a given user.
- pub fn add_connection(&mut self, connection_id: ConnectionId, _user_id: i32) {
+ pub fn add_connection(&mut self, connection_id: ConnectionId, _user_id: UserId) {
self.connections.insert(
connection_id,
ConnectionState {
@@ -291,7 +293,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let upgrade_receiver = http_res.recv_upgrade().await;
let addr = request.remote().unwrap_or("unknown").to_string();
let state = request.state().clone();
- let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?.0;
+ let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
task::spawn(async move {
if let Some(stream) = upgrade_receiver.await {
let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
@@ -310,7 +312,7 @@ pub async fn handle_connection<Conn>(
state: Arc<AppState>,
addr: String,
stream: Conn,
- user_id: i32,
+ user_id: UserId,
) where
Conn: 'static
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
@@ -1,5 +1,5 @@
use crate::{
- admin, auth, github,
+ auth, db, github,
rpc::{self, add_rpc_routes},
AppState, Config,
};
@@ -9,7 +9,6 @@ use rand::prelude::*;
use serde_json::json;
use sqlx::{
migrate::{MigrateDatabase, Migrator},
- postgres::PgPoolOptions,
Executor as _, Postgres,
};
use std::{path::Path, sync::Arc};
@@ -499,9 +498,7 @@ impl TestServer {
}
async fn create_client(&mut self, cx: &mut TestAppContext, name: &str) -> Client {
- let user_id = admin::create_user(&self.app_state.db, name, false)
- .await
- .unwrap();
+ let user_id = self.app_state.db.create_user(name, false).await.unwrap();
let lang_registry = Arc::new(LanguageRegistry::new());
let client = Client::new(lang_registry.clone());
let mut client_router = ForegroundRouter::new();
@@ -532,18 +529,20 @@ impl TestServer {
config.database_url = format!("postgres://postgres@localhost/{}", db_name);
Self::create_db(&config.database_url).await;
- let db = PgPoolOptions::new()
- .max_connections(5)
- .connect(&config.database_url)
- .await
- .expect("failed to connect to postgres database");
+ let db = db::Db(
+ db::DbOptions::new()
+ .max_connections(5)
+ .connect(&config.database_url)
+ .await
+ .expect("failed to connect to postgres database"),
+ );
let migrator = Migrator::new(Path::new(concat!(
env!("CARGO_MANIFEST_DIR"),
"/migrations"
)))
.await
.unwrap();
- migrator.run(&db).await.unwrap();
+ migrator.run(&db.0).await.unwrap();
let github_client = github::AppClient::test();
Arc::new(AppState {