auth.rs

  1use super::{
  2    db::{self, UserId},
  3    errors::TideResultExt,
  4};
  5use crate::{github, AppState, Request, RequestExt as _};
  6use anyhow::{anyhow, Context};
  7use async_trait::async_trait;
  8pub use oauth2::basic::BasicClient as Client;
  9use oauth2::{
 10    AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl,
 11    TokenResponse as _, TokenUrl,
 12};
 13use rand::thread_rng;
 14use scrypt::{
 15    password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
 16    Scrypt,
 17};
 18use serde::{Deserialize, Serialize};
 19use std::{borrow::Cow, convert::TryFrom, sync::Arc};
 20use surf::{StatusCode, Url};
 21use tide::{log, Error, Server};
 22use rpc::auth as zed_auth;
 23
 24static CURRENT_GITHUB_USER: &'static str = "current_github_user";
 25static GITHUB_AUTH_URL: &'static str = "https://github.com/login/oauth/authorize";
 26static GITHUB_TOKEN_URL: &'static str = "https://github.com/login/oauth/access_token";
 27
 28#[derive(Serialize)]
 29pub struct User {
 30    pub github_login: String,
 31    pub avatar_url: String,
 32    pub is_insider: bool,
 33    pub is_admin: bool,
 34}
 35
 36pub async fn process_auth_header(request: &Request) -> tide::Result<UserId> {
 37    let mut auth_header = request
 38        .header("Authorization")
 39        .ok_or_else(|| {
 40            Error::new(
 41                StatusCode::BadRequest,
 42                anyhow!("missing authorization header"),
 43            )
 44        })?
 45        .last()
 46        .as_str()
 47        .split_whitespace();
 48    let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
 49        Error::new(
 50            StatusCode::BadRequest,
 51            anyhow!("missing user id in authorization header"),
 52        )
 53    })?);
 54    let access_token = auth_header.next().ok_or_else(|| {
 55        Error::new(
 56            StatusCode::BadRequest,
 57            anyhow!("missing access token in authorization header"),
 58        )
 59    })?;
 60
 61    let state = request.state().clone();
 62    let mut credentials_valid = false;
 63    for password_hash in state.db.get_access_token_hashes(user_id).await? {
 64        if verify_access_token(&access_token, &password_hash)? {
 65            credentials_valid = true;
 66            break;
 67        }
 68    }
 69
 70    if !credentials_valid {
 71        Err(Error::new(
 72            StatusCode::Unauthorized,
 73            anyhow!("invalid credentials"),
 74        ))?;
 75    }
 76
 77    Ok(user_id)
 78}
 79
 80#[async_trait]
 81pub trait RequestExt {
 82    async fn current_user(&self) -> tide::Result<Option<User>>;
 83}
 84
 85#[async_trait]
 86impl RequestExt for Request {
 87    async fn current_user(&self) -> tide::Result<Option<User>> {
 88        if let Some(details) = self.session().get::<github::User>(CURRENT_GITHUB_USER) {
 89            let user = self.db().get_user_by_github_login(&details.login).await?;
 90            Ok(Some(User {
 91                github_login: details.login,
 92                avatar_url: details.avatar_url,
 93                is_insider: user.is_some(),
 94                is_admin: user.map_or(false, |user| user.admin),
 95            }))
 96        } else {
 97            Ok(None)
 98        }
 99    }
100}
101
102pub fn build_client(client_id: &str, client_secret: &str) -> Client {
103    Client::new(
104        ClientId::new(client_id.to_string()),
105        Some(oauth2::ClientSecret::new(client_secret.to_string())),
106        AuthUrl::new(GITHUB_AUTH_URL.into()).unwrap(),
107        Some(TokenUrl::new(GITHUB_TOKEN_URL.into()).unwrap()),
108    )
109}
110
111pub fn add_routes(app: &mut Server<Arc<AppState>>) {
112    app.at("/sign_in").get(get_sign_in);
113    app.at("/sign_out").post(post_sign_out);
114    app.at("/auth_callback").get(get_auth_callback);
115}
116
117#[derive(Debug, Deserialize)]
118struct NativeAppSignInParams {
119    native_app_port: String,
120    native_app_public_key: String,
121    impersonate: Option<String>,
122}
123
124async fn get_sign_in(mut request: Request) -> tide::Result {
125    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
126
127    request
128        .session_mut()
129        .insert("pkce_verifier", pkce_verifier)?;
130
131    let mut redirect_url = Url::parse(&format!(
132        "{}://{}/auth_callback",
133        request
134            .header("X-Forwarded-Proto")
135            .and_then(|values| values.get(0))
136            .map(|value| value.as_str())
137            .unwrap_or("http"),
138        request.host().unwrap()
139    ))?;
140
141    let app_sign_in_params: Option<NativeAppSignInParams> = request.query().ok();
142    if let Some(query) = app_sign_in_params {
143        let mut redirect_query = redirect_url.query_pairs_mut();
144        redirect_query
145            .clear()
146            .append_pair("native_app_port", &query.native_app_port)
147            .append_pair("native_app_public_key", &query.native_app_public_key);
148
149        if let Some(impersonate) = &query.impersonate {
150            redirect_query.append_pair("impersonate", impersonate);
151        }
152    }
153
154    let (auth_url, csrf_token) = request
155        .state()
156        .auth_client
157        .authorize_url(CsrfToken::new_random)
158        .set_redirect_uri(Cow::Owned(RedirectUrl::from_url(redirect_url)))
159        .set_pkce_challenge(pkce_challenge)
160        .url();
161
162    request
163        .session_mut()
164        .insert("auth_csrf_token", csrf_token)?;
165
166    Ok(tide::Redirect::new(auth_url).into())
167}
168
169async fn get_auth_callback(mut request: Request) -> tide::Result {
170    #[derive(Debug, Deserialize)]
171    struct Query {
172        code: String,
173        state: String,
174
175        #[serde(flatten)]
176        native_app_sign_in_params: Option<NativeAppSignInParams>,
177    }
178
179    let query: Query = request.query()?;
180
181    let pkce_verifier = request
182        .session()
183        .get("pkce_verifier")
184        .ok_or_else(|| anyhow!("could not retrieve pkce_verifier from session"))?;
185
186    let csrf_token = request
187        .session()
188        .get::<CsrfToken>("auth_csrf_token")
189        .ok_or_else(|| anyhow!("could not retrieve auth_csrf_token from session"))?;
190
191    if &query.state != csrf_token.secret() {
192        return Err(anyhow!("csrf token does not match").into());
193    }
194
195    let github_access_token = request
196        .state()
197        .auth_client
198        .exchange_code(AuthorizationCode::new(query.code))
199        .set_pkce_verifier(pkce_verifier)
200        .request_async(oauth2_surf::http_client)
201        .await
202        .context("failed to exchange oauth code")?
203        .access_token()
204        .secret()
205        .clone();
206
207    let user_details = request
208        .state()
209        .github_client
210        .user(github_access_token)
211        .details()
212        .await
213        .context("failed to fetch user")?;
214
215    let user = request
216        .db()
217        .get_user_by_github_login(&user_details.login)
218        .await?;
219
220    request
221        .session_mut()
222        .insert(CURRENT_GITHUB_USER, user_details.clone())?;
223
224    // When signing in from the native app, generate a new access token for the current user. Return
225    // a redirect so that the user's browser sends this access token to the locally-running app.
226    if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) {
227        let mut user_id = user.id;
228        if let Some(impersonated_login) = app_sign_in_params.impersonate {
229            log::info!("attempting to impersonate user @{}", impersonated_login);
230            if let Some(user) = request.db().get_users_by_ids([user_id]).await?.first() {
231                if user.admin {
232                    user_id = request.db().create_user(&impersonated_login, false).await?;
233                    log::info!("impersonating user {}", user_id.0);
234                } else {
235                    log::info!("refusing to impersonate user");
236                }
237            }
238        }
239
240        let access_token = create_access_token(request.db(), user_id).await?;
241        let native_app_public_key =
242            zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone())
243                .context("failed to parse app public key")?;
244        let encrypted_access_token = native_app_public_key
245            .encrypt_string(&access_token)
246            .context("failed to encrypt access token with public key")?;
247
248        return Ok(tide::Redirect::new(&format!(
249            "http://127.0.0.1:{}?user_id={}&access_token={}",
250            app_sign_in_params.native_app_port, user_id.0, encrypted_access_token,
251        ))
252        .into());
253    }
254
255    Ok(tide::Redirect::new("/").into())
256}
257
258async fn post_sign_out(mut request: Request) -> tide::Result {
259    request.session_mut().remove(CURRENT_GITHUB_USER);
260    Ok(tide::Redirect::new("/").into())
261}
262
263const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
264
265pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result<String> {
266    let access_token = zed_auth::random_token();
267    let access_token_hash =
268        hash_access_token(&access_token).context("failed to hash access token")?;
269    db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
270        .await?;
271    Ok(access_token)
272}
273
274fn hash_access_token(token: &str) -> tide::Result<String> {
275    // Avoid slow hashing in debug mode.
276    let params = if cfg!(debug_assertions) {
277        scrypt::Params::new(1, 1, 1).unwrap()
278    } else {
279        scrypt::Params::recommended()
280    };
281
282    Ok(Scrypt
283        .hash_password(
284            token.as_bytes(),
285            None,
286            params,
287            &SaltString::generate(thread_rng()),
288        )?
289        .to_string())
290}
291
292pub fn verify_access_token(token: &str, hash: &str) -> tide::Result<bool> {
293    let hash = PasswordHash::new(hash)?;
294    Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
295}