auth.rs

  1use super::errors::TideResultExt;
  2use crate::{github, rpc, AppState, DbPool, Request, RequestExt as _};
  3use anyhow::{anyhow, Context};
  4use async_std::stream::StreamExt;
  5use async_trait::async_trait;
  6pub use oauth2::basic::BasicClient as Client;
  7use oauth2::{
  8    AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl,
  9    TokenResponse as _, TokenUrl,
 10};
 11use rand::thread_rng;
 12use scrypt::{
 13    password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
 14    Scrypt,
 15};
 16use serde::{Deserialize, Serialize};
 17use sqlx::FromRow;
 18use std::{borrow::Cow, convert::TryFrom, sync::Arc};
 19use surf::Url;
 20use tide::Server;
 21use zed_rpc::{auth as zed_auth, proto, Peer};
 22
 23static CURRENT_GITHUB_USER: &'static str = "current_github_user";
 24static GITHUB_AUTH_URL: &'static str = "https://github.com/login/oauth/authorize";
 25static GITHUB_TOKEN_URL: &'static str = "https://github.com/login/oauth/access_token";
 26
 27#[derive(Serialize)]
 28pub struct User {
 29    pub github_login: String,
 30    pub avatar_url: String,
 31    pub is_insider: bool,
 32    pub is_admin: bool,
 33}
 34
 35pub struct VerifyToken;
 36
 37#[derive(Clone, Copy)]
 38pub struct UserId(pub i32);
 39
 40#[async_trait]
 41impl tide::Middleware<Arc<AppState>> for VerifyToken {
 42    async fn handle(
 43        &self,
 44        mut request: Request,
 45        next: tide::Next<'_, Arc<AppState>>,
 46    ) -> tide::Result {
 47        let mut auth_header = request
 48            .header("Authorization")
 49            .ok_or_else(|| anyhow!("no authorization header"))?
 50            .last()
 51            .as_str()
 52            .split_whitespace();
 53
 54        let user_id: i32 = auth_header
 55            .next()
 56            .ok_or_else(|| anyhow!("missing user id in authorization header"))?
 57            .parse()?;
 58        let access_token = auth_header
 59            .next()
 60            .ok_or_else(|| anyhow!("missing access token in authorization header"))?;
 61
 62        let state = request.state().clone();
 63
 64        let mut password_hashes =
 65            sqlx::query_scalar::<_, String>("SELECT hash FROM access_tokens WHERE user_id = $1")
 66                .bind(&user_id)
 67                .fetch_many(&state.db);
 68
 69        let mut credentials_valid = false;
 70        while let Some(password_hash) = password_hashes.next().await {
 71            if let either::Either::Right(password_hash) = password_hash? {
 72                if verify_access_token(&access_token, &password_hash)? {
 73                    credentials_valid = true;
 74                    break;
 75                }
 76            }
 77        }
 78
 79        if credentials_valid {
 80            request.set_ext(UserId(user_id));
 81            Ok(next.run(request).await)
 82        } else {
 83            Err(anyhow!("invalid credentials").into())
 84        }
 85    }
 86}
 87
 88#[async_trait]
 89pub trait RequestExt {
 90    async fn current_user(&self) -> tide::Result<Option<User>>;
 91}
 92
 93#[async_trait]
 94impl RequestExt for Request {
 95    async fn current_user(&self) -> tide::Result<Option<User>> {
 96        if let Some(details) = self.session().get::<github::User>(CURRENT_GITHUB_USER) {
 97            #[derive(FromRow)]
 98            struct UserRow {
 99                admin: bool,
100            }
101
102            let user_row: Option<UserRow> =
103                sqlx::query_as("SELECT admin FROM users WHERE github_login = $1")
104                    .bind(&details.login)
105                    .fetch_optional(self.db())
106                    .await?;
107
108            let is_insider = user_row.is_some();
109            let is_admin = user_row.map_or(false, |row| row.admin);
110
111            Ok(Some(User {
112                github_login: details.login,
113                avatar_url: details.avatar_url,
114                is_insider,
115                is_admin,
116            }))
117        } else {
118            Ok(None)
119        }
120    }
121}
122
123#[async_trait]
124pub trait PeerExt {
125    async fn sign_out(
126        self: &Arc<Self>,
127        connection_id: zed_rpc::ConnectionId,
128        state: &AppState,
129    ) -> tide::Result<()>;
130}
131
132#[async_trait]
133impl PeerExt for Peer {
134    async fn sign_out(
135        self: &Arc<Self>,
136        connection_id: zed_rpc::ConnectionId,
137        state: &AppState,
138    ) -> tide::Result<()> {
139        self.disconnect(connection_id).await;
140        let worktree_ids = state.rpc.write().await.remove_connection(connection_id);
141        for worktree_id in worktree_ids {
142            let state = state.rpc.read().await;
143            if let Some(worktree) = state.worktrees.get(&worktree_id) {
144                rpc::broadcast(connection_id, worktree.connection_ids(), |conn_id| {
145                    self.send(
146                        conn_id,
147                        proto::RemovePeer {
148                            worktree_id,
149                            peer_id: connection_id.0,
150                        },
151                    )
152                })
153                .await?;
154            }
155        }
156        Ok(())
157    }
158}
159
160pub fn build_client(client_id: &str, client_secret: &str) -> Client {
161    Client::new(
162        ClientId::new(client_id.to_string()),
163        Some(oauth2::ClientSecret::new(client_secret.to_string())),
164        AuthUrl::new(GITHUB_AUTH_URL.into()).unwrap(),
165        Some(TokenUrl::new(GITHUB_TOKEN_URL.into()).unwrap()),
166    )
167}
168
169pub fn add_routes(app: &mut Server<Arc<AppState>>) {
170    app.at("/sign_in").get(get_sign_in);
171    app.at("/sign_out").post(post_sign_out);
172    app.at("/auth_callback").get(get_auth_callback);
173}
174
/// Query parameters supplied when sign-in is started from the native app.
///
/// `get_sign_in` copies them onto the OAuth callback URL so that `get_auth_callback` can send
/// an encrypted access token back to the app.
#[derive(Debug, Deserialize)]
struct NativeAppSignInParams {
    /// Local port of the native app; the callback redirects the browser to `127.0.0.1:<port>`.
    native_app_port: String,
    /// Public key used to encrypt the access token before it is handed to the app.
    native_app_public_key: String,
}
180
181async fn get_sign_in(mut request: Request) -> tide::Result {
182    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
183
184    request
185        .session_mut()
186        .insert("pkce_verifier", pkce_verifier)?;
187
188    let mut redirect_url = Url::parse(&format!(
189        "{}://{}/auth_callback",
190        request
191            .header("X-Forwarded-Proto")
192            .and_then(|values| values.get(0))
193            .map(|value| value.as_str())
194            .unwrap_or("http"),
195        request.host().unwrap()
196    ))?;
197
198    let app_sign_in_params: Option<NativeAppSignInParams> = request.query().ok();
199    if let Some(query) = app_sign_in_params {
200        redirect_url
201            .query_pairs_mut()
202            .clear()
203            .append_pair("native_app_port", &query.native_app_port)
204            .append_pair("native_app_public_key", &query.native_app_public_key);
205    }
206
207    let (auth_url, csrf_token) = request
208        .state()
209        .auth_client
210        .authorize_url(CsrfToken::new_random)
211        .set_redirect_uri(Cow::Owned(RedirectUrl::from_url(redirect_url)))
212        .set_pkce_challenge(pkce_challenge)
213        .url();
214
215    request
216        .session_mut()
217        .insert("auth_csrf_token", csrf_token)?;
218
219    Ok(tide::Redirect::new(auth_url).into())
220}
221
222async fn get_auth_callback(mut request: Request) -> tide::Result {
223    #[derive(Debug, Deserialize)]
224    struct Query {
225        code: String,
226        state: String,
227
228        #[serde(flatten)]
229        native_app_sign_in_params: Option<NativeAppSignInParams>,
230    }
231
232    let query: Query = request.query()?;
233
234    let pkce_verifier = request
235        .session()
236        .get("pkce_verifier")
237        .ok_or_else(|| anyhow!("could not retrieve pkce_verifier from session"))?;
238
239    let csrf_token = request
240        .session()
241        .get::<CsrfToken>("auth_csrf_token")
242        .ok_or_else(|| anyhow!("could not retrieve auth_csrf_token from session"))?;
243
244    if &query.state != csrf_token.secret() {
245        return Err(anyhow!("csrf token does not match").into());
246    }
247
248    let github_access_token = request
249        .state()
250        .auth_client
251        .exchange_code(AuthorizationCode::new(query.code))
252        .set_pkce_verifier(pkce_verifier)
253        .request_async(oauth2_surf::http_client)
254        .await
255        .context("failed to exchange oauth code")?
256        .access_token()
257        .secret()
258        .clone();
259
260    let user_details = request
261        .state()
262        .github_client
263        .user(github_access_token)
264        .details()
265        .await
266        .context("failed to fetch user")?;
267
268    let user_id: Option<i32> = sqlx::query_scalar("SELECT id from users where github_login = $1")
269        .bind(&user_details.login)
270        .fetch_optional(request.db())
271        .await?;
272
273    request
274        .session_mut()
275        .insert(CURRENT_GITHUB_USER, user_details.clone())?;
276
277    // When signing in from the native app, generate a new access token for the current user. Return
278    // a redirect so that the user's browser sends this access token to the locally-running app.
279    if let Some((user_id, app_sign_in_params)) = user_id.zip(query.native_app_sign_in_params) {
280        let access_token = create_access_token(request.db(), user_id).await?;
281        let native_app_public_key =
282            zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone())
283                .context("failed to parse app public key")?;
284        let encrypted_access_token = native_app_public_key
285            .encrypt_string(&access_token)
286            .context("failed to encrypt access token with public key")?;
287
288        return Ok(tide::Redirect::new(&format!(
289            "http://127.0.0.1:{}?user_id={}&access_token={}",
290            app_sign_in_params.native_app_port, user_id, encrypted_access_token,
291        ))
292        .into());
293    }
294
295    Ok(tide::Redirect::new("/").into())
296}
297
298async fn post_sign_out(mut request: Request) -> tide::Result {
299    request.session_mut().remove(CURRENT_GITHUB_USER);
300    Ok(tide::Redirect::new("/").into())
301}
302
303pub async fn create_access_token(db: &DbPool, user_id: i32) -> tide::Result<String> {
304    let access_token = zed_auth::random_token();
305    let access_token_hash =
306        hash_access_token(&access_token).context("failed to hash access token")?;
307    sqlx::query("INSERT INTO access_tokens (user_id, hash) values ($1, $2)")
308        .bind(user_id)
309        .bind(access_token_hash)
310        .fetch_optional(db)
311        .await?;
312    Ok(access_token)
313}
314
315fn hash_access_token(token: &str) -> tide::Result<String> {
316    // Avoid slow hashing in debug mode.
317    let params = if cfg!(debug_assertions) {
318        scrypt::Params::new(1, 1, 1).unwrap()
319    } else {
320        scrypt::Params::recommended()
321    };
322
323    Ok(Scrypt
324        .hash_password(
325            token.as_bytes(),
326            None,
327            params,
328            &SaltString::generate(thread_rng()),
329        )?
330        .to_string())
331}
332
333pub fn verify_access_token(token: &str, hash: &str) -> tide::Result<bool> {
334    let hash = PasswordHash::new(hash)?;
335    Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
336}