Detailed changes
@@ -2,28 +2,22 @@ use super::{
db::{self, UserId},
errors::TideResultExt,
};
-use crate::{github, AppState, Request, RequestExt as _};
+use crate::{github, Request, RequestExt as _};
use anyhow::{anyhow, Context};
use async_trait::async_trait;
pub use oauth2::basic::BasicClient as Client;
-use oauth2::{
- AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl,
- TokenResponse as _, TokenUrl,
-};
use rand::thread_rng;
use rpc::auth as zed_auth;
use scrypt::{
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Scrypt,
};
-use serde::{Deserialize, Serialize};
-use std::{borrow::Cow, convert::TryFrom, sync::Arc};
-use surf::{StatusCode, Url};
-use tide::{log, Error, Server};
+use serde::Serialize;
+use std::convert::TryFrom;
+use surf::StatusCode;
+use tide::Error;
static CURRENT_GITHUB_USER: &'static str = "current_github_user";
-static GITHUB_AUTH_URL: &'static str = "https://github.com/login/oauth/authorize";
-static GITHUB_TOKEN_URL: &'static str = "https://github.com/login/oauth/access_token";
#[derive(Serialize)]
pub struct User {
@@ -99,172 +93,6 @@ impl RequestExt for Request {
}
}
-pub fn build_client(client_id: &str, client_secret: &str) -> Client {
- Client::new(
- ClientId::new(client_id.to_string()),
- Some(oauth2::ClientSecret::new(client_secret.to_string())),
- AuthUrl::new(GITHUB_AUTH_URL.into()).unwrap(),
- Some(TokenUrl::new(GITHUB_TOKEN_URL.into()).unwrap()),
- )
-}
-
-pub fn add_routes(app: &mut Server<Arc<AppState>>) {
- app.at("/sign_in").get(get_sign_in);
- app.at("/sign_out").post(post_sign_out);
- app.at("/auth_callback").get(get_auth_callback);
- app.at("/native_app_signin").get(get_sign_in);
- app.at("/native_app_signin_succeeded")
- .get(get_app_signin_success);
-}
-
-#[derive(Debug, Deserialize)]
-struct NativeAppSignInParams {
- native_app_port: String,
- native_app_public_key: String,
- impersonate: Option<String>,
-}
-
-async fn get_sign_in(mut request: Request) -> tide::Result {
- let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
-
- request
- .session_mut()
- .insert("pkce_verifier", pkce_verifier)?;
-
- let mut redirect_url = Url::parse(&format!(
- "{}://{}/auth_callback",
- request
- .header("X-Forwarded-Proto")
- .and_then(|values| values.get(0))
- .map(|value| value.as_str())
- .unwrap_or("http"),
- request.host().unwrap()
- ))?;
-
- let app_sign_in_params: Option<NativeAppSignInParams> = request.query().ok();
- if let Some(query) = app_sign_in_params {
- let mut redirect_query = redirect_url.query_pairs_mut();
- redirect_query
- .clear()
- .append_pair("native_app_port", &query.native_app_port)
- .append_pair("native_app_public_key", &query.native_app_public_key);
-
- if let Some(impersonate) = &query.impersonate {
- redirect_query.append_pair("impersonate", impersonate);
- }
- }
-
- let (auth_url, csrf_token) = request
- .state()
- .auth_client
- .authorize_url(CsrfToken::new_random)
- .set_redirect_uri(Cow::Owned(RedirectUrl::from_url(redirect_url)))
- .set_pkce_challenge(pkce_challenge)
- .url();
-
- request
- .session_mut()
- .insert("auth_csrf_token", csrf_token)?;
-
- Ok(tide::Redirect::new(auth_url).into())
-}
-
-async fn get_app_signin_success(_: Request) -> tide::Result {
- Ok(tide::Redirect::new("/").into())
-}
-
-async fn get_auth_callback(mut request: Request) -> tide::Result {
- #[derive(Debug, Deserialize)]
- struct Query {
- code: String,
- state: String,
-
- #[serde(flatten)]
- native_app_sign_in_params: Option<NativeAppSignInParams>,
- }
-
- let query: Query = request.query()?;
-
- let pkce_verifier = request
- .session()
- .get("pkce_verifier")
- .ok_or_else(|| anyhow!("could not retrieve pkce_verifier from session"))?;
-
- let csrf_token = request
- .session()
- .get::<CsrfToken>("auth_csrf_token")
- .ok_or_else(|| anyhow!("could not retrieve auth_csrf_token from session"))?;
-
- if &query.state != csrf_token.secret() {
- return Err(anyhow!("csrf token does not match").into());
- }
-
- let github_access_token = request
- .state()
- .auth_client
- .exchange_code(AuthorizationCode::new(query.code))
- .set_pkce_verifier(pkce_verifier)
- .request_async(oauth2_surf::http_client)
- .await
- .context("failed to exchange oauth code")?
- .access_token()
- .secret()
- .clone();
-
- let user_details = request
- .state()
- .github_client
- .user(github_access_token)
- .details()
- .await
- .context("failed to fetch user")?;
-
- let user = request
- .db()
- .get_user_by_github_login(&user_details.login)
- .await?;
-
- request
- .session_mut()
- .insert(CURRENT_GITHUB_USER, user_details.clone())?;
-
- // When signing in from the native app, generate a new access token for the current user. Return
- // a redirect so that the user's browser sends this access token to the locally-running app.
- if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) {
- let mut user_id = user.id;
- if let Some(impersonated_login) = app_sign_in_params.impersonate {
- log::info!("attempting to impersonate user @{}", impersonated_login);
- if let Some(user) = request.db().get_users_by_ids(vec![user_id]).await?.first() {
- if user.admin {
- user_id = request.db().create_user(&impersonated_login, false).await?;
- log::info!("impersonating user {}", user_id.0);
- } else {
- log::info!("refusing to impersonate user");
- }
- }
- }
-
- let access_token = create_access_token(request.db().as_ref(), user_id).await?;
- let encrypted_access_token = encrypt_access_token(
- &access_token,
- app_sign_in_params.native_app_public_key.clone(),
- )?;
-
- return Ok(tide::Redirect::new(&format!(
- "http://127.0.0.1:{}?user_id={}&access_token={}",
- app_sign_in_params.native_app_port, user_id.0, encrypted_access_token,
- ))
- .into());
- }
-
- Ok(tide::Redirect::new("/").into())
-}
-
-async fn post_sign_out(mut request: Request) -> tide::Result {
- request.session_mut().remove(CURRENT_GITHUB_USER);
- Ok(tide::Redirect::new("/").into())
-}
-
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> tide::Result<String> {
@@ -1,47 +1,3 @@
-use crate::{AppState, LayoutData, Request, RequestExt};
-use async_trait::async_trait;
-use serde::Serialize;
-use std::sync::Arc;
-use tide::http::mime;
-
-pub struct Middleware;
-
-#[async_trait]
-impl tide::Middleware<Arc<AppState>> for Middleware {
- async fn handle(
- &self,
- mut request: Request,
- next: tide::Next<'_, Arc<AppState>>,
- ) -> tide::Result {
- let app = request.state().clone();
- let layout_data = request.layout_data().await?;
-
- let mut response = next.run(request).await;
-
- #[derive(Serialize)]
- struct ErrorData {
- #[serde(flatten)]
- layout: Arc<LayoutData>,
- status: u16,
- reason: &'static str,
- }
-
- if !response.status().is_success() {
- response.set_body(app.render_template(
- "error.hbs",
- &ErrorData {
- layout: layout_data,
- status: response.status().into(),
- reason: response.status().canonical_reason(),
- },
- )?);
- response.set_content_type(mime::HTML);
- }
-
- Ok(response)
- }
-}
-
// Allow tide Results to accept context like other Results do when
// using anyhow.
pub trait TideResultExt {
@@ -1,43 +1 @@
-use std::{future::Future, time::Instant};
-use async_std::sync::Mutex;
-
-#[derive(Default)]
-pub struct Expiring<T>(Mutex<Option<ExpiringState<T>>>);
-
-pub struct ExpiringState<T> {
- value: T,
- expires_at: Instant,
-}
-
-impl<T: Clone> Expiring<T> {
- pub async fn get_or_refresh<F, G>(&self, f: F) -> tide::Result<T>
- where
- F: FnOnce() -> G,
- G: Future<Output = tide::Result<(T, Instant)>>,
- {
- let mut state = self.0.lock().await;
-
- if let Some(state) = state.as_mut() {
- if Instant::now() >= state.expires_at {
- let (value, expires_at) = f().await?;
- state.value = value.clone();
- state.expires_at = expires_at;
- Ok(value)
- } else {
- Ok(state.value.clone())
- }
- } else {
- let (value, expires_at) = f().await?;
- *state = Some(ExpiringState {
- value: value.clone(),
- expires_at,
- });
- Ok(value)
- }
- }
-
- pub async fn clear(&self) {
- self.0.lock().await.take();
- }
-}
@@ -1,12 +1,4 @@
-use crate::expiring::Expiring;
-use anyhow::{anyhow, Context};
-use serde::{de::DeserializeOwned, Deserialize, Serialize};
-use std::{
- future::Future,
- sync::Arc,
- time::{Duration, Instant},
-};
-use surf::{http::Method, RequestBuilder, Url};
+use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize, Serialize)]
pub struct Release {
@@ -23,259 +15,14 @@ pub struct Asset {
pub url: String,
}
-pub struct AppClient {
- id: usize,
- private_key: String,
- jwt_bearer_header: Expiring<String>,
-}
-
#[derive(Deserialize)]
struct Installation {
#[allow(unused)]
id: usize,
}
-impl AppClient {
- #[cfg(test)]
- pub fn test() -> Arc<Self> {
- Arc::new(Self {
- id: Default::default(),
- private_key: Default::default(),
- jwt_bearer_header: Default::default(),
- })
- }
-
- pub fn new(id: usize, private_key: String) -> Arc<Self> {
- Arc::new(Self {
- id,
- private_key,
- jwt_bearer_header: Default::default(),
- })
- }
-
- pub async fn repo(self: &Arc<Self>, nwo: String) -> tide::Result<RepoClient> {
- let installation: Installation = self
- .request(
- Method::Get,
- &format!("/repos/{}/installation", &nwo),
- |refresh| self.bearer_header(refresh),
- )
- .await?;
-
- Ok(RepoClient {
- app: self.clone(),
- nwo,
- installation_id: installation.id,
- installation_token_header: Default::default(),
- })
- }
-
- pub fn user(self: &Arc<Self>, access_token: String) -> UserClient {
- UserClient {
- app: self.clone(),
- access_token,
- }
- }
-
- async fn request<T, F, G>(
- &self,
- method: Method,
- path: &str,
- get_auth_header: F,
- ) -> tide::Result<T>
- where
- T: DeserializeOwned,
- F: Fn(bool) -> G,
- G: Future<Output = tide::Result<String>>,
- {
- let mut retried = false;
-
- loop {
- let response = RequestBuilder::new(
- method,
- Url::parse(&format!("https://api.github.com{}", path))?,
- )
- .header("Accept", "application/vnd.github.v3+json")
- .header("Authorization", get_auth_header(retried).await?)
- .recv_json()
- .await;
-
- if let Err(error) = response.as_ref() {
- if error.status() == 401 && !retried {
- retried = true;
- continue;
- }
- }
-
- return response;
- }
- }
-
- async fn bearer_header(&self, refresh: bool) -> tide::Result<String> {
- if refresh {
- self.jwt_bearer_header.clear().await;
- }
-
- self.jwt_bearer_header
- .get_or_refresh(|| async {
- use jwt_simple::{algorithms::RS256KeyPair, prelude::*};
- use std::time;
-
- let key_pair = RS256KeyPair::from_pem(&self.private_key)
- .with_context(|| format!("invalid private key {:?}", self.private_key))?;
- let mut claims = Claims::create(Duration::from_mins(10));
- claims.issued_at = Some(Clock::now_since_epoch() - Duration::from_mins(1));
- claims.issuer = Some(self.id.to_string());
- let token = key_pair.sign(claims).context("failed to sign claims")?;
- let expires_at = time::Instant::now() + time::Duration::from_secs(9 * 60);
-
- Ok((format!("Bearer {}", token), expires_at))
- })
- .await
- }
-
- async fn installation_token_header(
- &self,
- header: &Expiring<String>,
- installation_id: usize,
- refresh: bool,
- ) -> tide::Result<String> {
- if refresh {
- header.clear().await;
- }
-
- header
- .get_or_refresh(|| async {
- #[derive(Debug, Deserialize)]
- struct AccessToken {
- token: String,
- }
-
- let access_token: AccessToken = self
- .request(
- Method::Post,
- &format!("/app/installations/{}/access_tokens", installation_id),
- |refresh| self.bearer_header(refresh),
- )
- .await?;
-
- let header = format!("Token {}", access_token.token);
- let expires_at = Instant::now() + Duration::from_secs(60 * 30);
-
- Ok((header, expires_at))
- })
- .await
- }
-}
-
-pub struct RepoClient {
- app: Arc<AppClient>,
- nwo: String,
- installation_id: usize,
- installation_token_header: Expiring<String>,
-}
-
-impl RepoClient {
- #[cfg(test)]
- pub fn test(app_client: &Arc<AppClient>) -> Self {
- Self {
- app: app_client.clone(),
- nwo: String::new(),
- installation_id: 0,
- installation_token_header: Default::default(),
- }
- }
-
- pub async fn releases(&self) -> tide::Result<Vec<Release>> {
- self.get(&format!("/repos/{}/releases?per_page=100", self.nwo))
- .await
- }
-
- pub async fn release_asset(&self, tag: &str, name: &str) -> tide::Result<surf::Body> {
- let release: Release = self
- .get(&format!("/repos/{}/releases/tags/{}", self.nwo, tag))
- .await?;
-
- let asset = release
- .assets
- .iter()
- .find(|asset| asset.name == name)
- .ok_or_else(|| anyhow!("no asset found with name {}", name))?;
-
- let request = surf::get(&asset.url)
- .header("Accept", "application/octet-stream'")
- .header(
- "Authorization",
- self.installation_token_header(false).await?,
- );
-
- let client = surf::client();
- let mut response = client.send(request).await?;
-
- // Avoid using `surf::middleware::Redirect` because that type forwards
- // the original request headers to the redirect URI. In this case, the
- // redirect will be to S3, which forbids us from supplying an
- // `Authorization` header.
- if response.status().is_redirection() {
- if let Some(url) = response.header("location") {
- let request = surf::get(url.as_str()).header("Accept", "application/octet-stream");
- response = client.send(request).await?;
- }
- }
-
- if !response.status().is_success() {
- Err(anyhow!("failed to fetch release asset {} {}", tag, name))?;
- }
-
- Ok(response.take_body())
- }
-
- async fn get<T: DeserializeOwned>(&self, path: &str) -> tide::Result<T> {
- self.request::<T>(Method::Get, path).await
- }
-
- async fn request<T: DeserializeOwned>(&self, method: Method, path: &str) -> tide::Result<T> {
- Ok(self
- .app
- .request(method, path, |refresh| {
- self.installation_token_header(refresh)
- })
- .await?)
- }
-
- async fn installation_token_header(&self, refresh: bool) -> tide::Result<String> {
- self.app
- .installation_token_header(
- &self.installation_token_header,
- self.installation_id,
- refresh,
- )
- .await
- }
-}
-
-pub struct UserClient {
- app: Arc<AppClient>,
- access_token: String,
-}
-
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct User {
pub login: String,
pub avatar_url: String,
}
-
-impl UserClient {
- pub async fn details(&self) -> tide::Result<User> {
- Ok(self
- .app
- .request(Method::Get, "/user", |_| async {
- Ok(self.access_token_header())
- })
- .await?)
- }
-
- fn access_token_header(&self) -> String {
- format!("Token {}", self.access_token)
- }
-}
@@ -8,17 +8,14 @@ mod expiring;
mod github;
mod rpc;
-use self::errors::TideResultExt as _;
use ::rpc::Peer;
-use anyhow::Result;
use async_std::net::TcpListener;
use async_trait::async_trait;
-use auth::RequestExt as _;
use db::{Db, PostgresDb};
-use handlebars::{Handlebars, TemplateRenderError};
+use handlebars::Handlebars;
use parking_lot::RwLock;
use rust_embed::RustEmbed;
-use serde::{Deserialize, Serialize};
+use serde::Deserialize;
use std::sync::Arc;
use surf::http::cookies::SameSite;
use tide::sessions::SessionMiddleware;
@@ -45,28 +42,16 @@ pub struct Config {
pub struct AppState {
db: Arc<dyn Db>,
handlebars: RwLock<Handlebars<'static>>,
- auth_client: auth::Client,
- github_client: Arc<github::AppClient>,
- repo_client: github::RepoClient,
config: Config,
}
impl AppState {
async fn new(config: Config) -> tide::Result<Arc<Self>> {
let db = PostgresDb::new(&config.database_url, 5).await?;
- let github_client =
- github::AppClient::new(config.github_app_id, config.github_private_key.clone());
- let repo_client = github_client
- .repo("zed-industries/zed".into())
- .await
- .context("failed to initialize github client")?;
let this = Self {
db: Arc::new(db),
handlebars: Default::default(),
- auth_client: auth::build_client(&config.github_client_id, &config.github_client_secret),
- github_client,
- repo_client,
config,
};
this.register_partials();
@@ -87,49 +72,20 @@ impl AppState {
}
}
}
-
- fn render_template(
- &self,
- path: &'static str,
- data: &impl Serialize,
- ) -> Result<String, TemplateRenderError> {
- #[cfg(debug_assertions)]
- self.register_partials();
-
- self.handlebars.read().render_template(
- std::str::from_utf8(&Templates::get(path).unwrap().data).unwrap(),
- data,
- )
- }
}
#[async_trait]
trait RequestExt {
- async fn layout_data(&mut self) -> tide::Result<Arc<LayoutData>>;
fn db(&self) -> &Arc<dyn Db>;
}
#[async_trait]
impl RequestExt for Request {
- async fn layout_data(&mut self) -> tide::Result<Arc<LayoutData>> {
- if self.ext::<Arc<LayoutData>>().is_none() {
- self.set_ext(Arc::new(LayoutData {
- current_user: self.current_user().await?,
- }));
- }
- Ok(self.ext::<Arc<LayoutData>>().unwrap().clone())
- }
-
fn db(&self) -> &Arc<dyn Db> {
&self.state().db
}
}
-#[derive(Serialize)]
-struct LayoutData {
- current_user: Option<auth::User>,
-}
-
#[async_std::main]
async fn main() -> tide::Result<()> {
if std::env::var("LOG_JSON").is_ok() {
@@ -173,9 +129,7 @@ pub async fn run_server(
)
.with_same_site_policy(SameSite::Lax), // Required obtain our session in /auth_callback
);
- web.with(errors::Middleware);
api::add_routes(&mut web);
- auth::add_routes(&mut web);
let mut assets = tide::new();
assets.with(CompressMiddleware::new());
@@ -1180,9 +1180,8 @@ fn header_contains_ignore_case<T>(
mod tests {
use super::*;
use crate::{
- auth,
db::{tests::TestDb, UserId},
- github, AppState, Config,
+ AppState, Config,
};
use ::rpc::Peer;
use client::{
@@ -5731,13 +5730,9 @@ mod tests {
let mut config = Config::default();
config.session_secret = "a".repeat(32);
config.database_url = test_db.url.clone();
- let github_client = github::AppClient::test();
Arc::new(AppState {
db: test_db.db().clone(),
handlebars: Default::default(),
- auth_client: auth::build_client("", ""),
- repo_client: github::RepoClient::test(&github_client),
- github_client,
config,
})
}