auth.rs

  1use super::{
  2    db::{self, UserId},
  3    errors::TideResultExt,
  4};
  5use crate::{github, rpc, AppState, Request, RequestExt as _};
  6use anyhow::{anyhow, Context};
  7use async_trait::async_trait;
  8pub use oauth2::basic::BasicClient as Client;
  9use oauth2::{
 10    AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl,
 11    TokenResponse as _, TokenUrl,
 12};
 13use rand::thread_rng;
 14use scrypt::{
 15    password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
 16    Scrypt,
 17};
 18use serde::{Deserialize, Serialize};
 19use std::{borrow::Cow, convert::TryFrom, sync::Arc};
 20use surf::Url;
 21use tide::Server;
 22use zrpc::{auth as zed_auth, proto, Peer};
 23
 24static CURRENT_GITHUB_USER: &'static str = "current_github_user";
 25static GITHUB_AUTH_URL: &'static str = "https://github.com/login/oauth/authorize";
 26static GITHUB_TOKEN_URL: &'static str = "https://github.com/login/oauth/access_token";
 27
/// Serializable view of the user signed in to the current browser session.
///
/// Built by `RequestExt::current_user` from the GitHub details stored in the session plus an
/// optional database lookup by GitHub login.
#[derive(Serialize)]
pub struct User {
    /// GitHub username of the signed-in user.
    pub github_login: String,
    /// GitHub avatar URL of the signed-in user.
    pub avatar_url: String,
    /// `true` when a database user with this GitHub login exists.
    pub is_insider: bool,
    /// The matching database user's `admin` flag; `false` when no such user exists.
    pub is_admin: bool,
}
 35
 36pub struct VerifyToken;
 37
 38#[async_trait]
 39impl tide::Middleware<Arc<AppState>> for VerifyToken {
 40    async fn handle(
 41        &self,
 42        mut request: Request,
 43        next: tide::Next<'_, Arc<AppState>>,
 44    ) -> tide::Result {
 45        let mut auth_header = request
 46            .header("Authorization")
 47            .ok_or_else(|| anyhow!("no authorization header"))?
 48            .last()
 49            .as_str()
 50            .split_whitespace();
 51
 52        let user_id = UserId(
 53            auth_header
 54                .next()
 55                .ok_or_else(|| anyhow!("missing user id in authorization header"))?
 56                .parse()?,
 57        );
 58        let access_token = auth_header
 59            .next()
 60            .ok_or_else(|| anyhow!("missing access token in authorization header"))?;
 61
 62        let state = request.state().clone();
 63
 64        let mut credentials_valid = false;
 65        for password_hash in state.db.get_access_token_hashes(user_id).await? {
 66            if verify_access_token(&access_token, &password_hash)? {
 67                credentials_valid = true;
 68                break;
 69            }
 70        }
 71
 72        if credentials_valid {
 73            request.set_ext(user_id);
 74            Ok(next.run(request).await)
 75        } else {
 76            Err(anyhow!("invalid credentials").into())
 77        }
 78    }
 79}
 80
 81#[async_trait]
 82pub trait RequestExt {
 83    async fn current_user(&self) -> tide::Result<Option<User>>;
 84}
 85
 86#[async_trait]
 87impl RequestExt for Request {
 88    async fn current_user(&self) -> tide::Result<Option<User>> {
 89        if let Some(details) = self.session().get::<github::User>(CURRENT_GITHUB_USER) {
 90            let user = self.db().get_user_by_github_login(&details.login).await?;
 91            Ok(Some(User {
 92                github_login: details.login,
 93                avatar_url: details.avatar_url,
 94                is_insider: user.is_some(),
 95                is_admin: user.map_or(false, |user| user.admin),
 96            }))
 97        } else {
 98            Ok(None)
 99        }
100    }
101}
102
/// Sign-out support for RPC peers.
#[async_trait]
pub trait PeerExt {
    /// Signs `connection_id` out.
    ///
    /// The implementations in this module disconnect the connection, remove it from the
    /// shared RPC state in `state`, and send `proto::RemovePeer` to the remaining connections
    /// of every worktree it was part of.
    async fn sign_out(
        self: &Arc<Self>,
        connection_id: zrpc::ConnectionId,
        state: &AppState,
    ) -> tide::Result<()>;
}
111
112#[async_trait]
113impl PeerExt for Peer {
114    async fn sign_out(
115        self: &Arc<Self>,
116        connection_id: zrpc::ConnectionId,
117        state: &AppState,
118    ) -> tide::Result<()> {
119        self.disconnect(connection_id).await;
120        let worktree_ids = state.rpc.write().await.remove_connection(connection_id);
121        for worktree_id in worktree_ids {
122            let state = state.rpc.read().await;
123            if let Some(worktree) = state.worktrees.get(&worktree_id) {
124                rpc::broadcast(connection_id, worktree.connection_ids(), |conn_id| {
125                    self.send(
126                        conn_id,
127                        proto::RemovePeer {
128                            worktree_id,
129                            peer_id: connection_id.0,
130                        },
131                    )
132                })
133                .await?;
134            }
135        }
136        Ok(())
137    }
138}
139
140#[async_trait]
141impl PeerExt for zrpc::peer2::Peer {
142    async fn sign_out(
143        self: &Arc<Self>,
144        connection_id: zrpc::ConnectionId,
145        state: &AppState,
146    ) -> tide::Result<()> {
147        self.disconnect(connection_id).await;
148        let worktree_ids = state.rpc.write().await.remove_connection(connection_id);
149        for worktree_id in worktree_ids {
150            let state = state.rpc.read().await;
151            if let Some(worktree) = state.worktrees.get(&worktree_id) {
152                rpc::broadcast(connection_id, worktree.connection_ids(), |conn_id| {
153                    self.send(
154                        conn_id,
155                        proto::RemovePeer {
156                            worktree_id,
157                            peer_id: connection_id.0,
158                        },
159                    )
160                })
161                .await?;
162            }
163        }
164        Ok(())
165    }
166}
167
168pub fn build_client(client_id: &str, client_secret: &str) -> Client {
169    Client::new(
170        ClientId::new(client_id.to_string()),
171        Some(oauth2::ClientSecret::new(client_secret.to_string())),
172        AuthUrl::new(GITHUB_AUTH_URL.into()).unwrap(),
173        Some(TokenUrl::new(GITHUB_TOKEN_URL.into()).unwrap()),
174    )
175}
176
177pub fn add_routes(app: &mut Server<Arc<AppState>>) {
178    app.at("/sign_in").get(get_sign_in);
179    app.at("/sign_out").post(post_sign_out);
180    app.at("/auth_callback").get(get_auth_callback);
181}
182
/// Query parameters supplied when the native app starts the sign-in flow.
///
/// `get_sign_in` copies them onto the OAuth callback URL so that `get_auth_callback` can
/// return a freshly created access token to the app.
#[derive(Debug, Deserialize)]
struct NativeAppSignInParams {
    /// Port the app listens on at `127.0.0.1`. Deserialized as an arbitrary, unvalidated
    /// string, so it must be checked before being used to build a URL.
    native_app_port: String,
    /// Public key used to encrypt the access token before it is put in the redirect URL.
    /// Its encoding is whatever `zed_auth::PublicKey::try_from` accepts (confirm there).
    native_app_public_key: String,
}
188
189async fn get_sign_in(mut request: Request) -> tide::Result {
190    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
191
192    request
193        .session_mut()
194        .insert("pkce_verifier", pkce_verifier)?;
195
196    let mut redirect_url = Url::parse(&format!(
197        "{}://{}/auth_callback",
198        request
199            .header("X-Forwarded-Proto")
200            .and_then(|values| values.get(0))
201            .map(|value| value.as_str())
202            .unwrap_or("http"),
203        request.host().unwrap()
204    ))?;
205
206    let app_sign_in_params: Option<NativeAppSignInParams> = request.query().ok();
207    if let Some(query) = app_sign_in_params {
208        redirect_url
209            .query_pairs_mut()
210            .clear()
211            .append_pair("native_app_port", &query.native_app_port)
212            .append_pair("native_app_public_key", &query.native_app_public_key);
213    }
214
215    let (auth_url, csrf_token) = request
216        .state()
217        .auth_client
218        .authorize_url(CsrfToken::new_random)
219        .set_redirect_uri(Cow::Owned(RedirectUrl::from_url(redirect_url)))
220        .set_pkce_challenge(pkce_challenge)
221        .url();
222
223    request
224        .session_mut()
225        .insert("auth_csrf_token", csrf_token)?;
226
227    Ok(tide::Redirect::new(auth_url).into())
228}
229
230async fn get_auth_callback(mut request: Request) -> tide::Result {
231    #[derive(Debug, Deserialize)]
232    struct Query {
233        code: String,
234        state: String,
235
236        #[serde(flatten)]
237        native_app_sign_in_params: Option<NativeAppSignInParams>,
238    }
239
240    let query: Query = request.query()?;
241
242    let pkce_verifier = request
243        .session()
244        .get("pkce_verifier")
245        .ok_or_else(|| anyhow!("could not retrieve pkce_verifier from session"))?;
246
247    let csrf_token = request
248        .session()
249        .get::<CsrfToken>("auth_csrf_token")
250        .ok_or_else(|| anyhow!("could not retrieve auth_csrf_token from session"))?;
251
252    if &query.state != csrf_token.secret() {
253        return Err(anyhow!("csrf token does not match").into());
254    }
255
256    let github_access_token = request
257        .state()
258        .auth_client
259        .exchange_code(AuthorizationCode::new(query.code))
260        .set_pkce_verifier(pkce_verifier)
261        .request_async(oauth2_surf::http_client)
262        .await
263        .context("failed to exchange oauth code")?
264        .access_token()
265        .secret()
266        .clone();
267
268    let user_details = request
269        .state()
270        .github_client
271        .user(github_access_token)
272        .details()
273        .await
274        .context("failed to fetch user")?;
275
276    let user = request
277        .db()
278        .get_user_by_github_login(&user_details.login)
279        .await?;
280
281    request
282        .session_mut()
283        .insert(CURRENT_GITHUB_USER, user_details.clone())?;
284
285    // When signing in from the native app, generate a new access token for the current user. Return
286    // a redirect so that the user's browser sends this access token to the locally-running app.
287    if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) {
288        let access_token = create_access_token(request.db(), user.id).await?;
289        let native_app_public_key =
290            zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone())
291                .context("failed to parse app public key")?;
292        let encrypted_access_token = native_app_public_key
293            .encrypt_string(&access_token)
294            .context("failed to encrypt access token with public key")?;
295
296        return Ok(tide::Redirect::new(&format!(
297            "http://127.0.0.1:{}?user_id={}&access_token={}",
298            app_sign_in_params.native_app_port, user.id.0, encrypted_access_token,
299        ))
300        .into());
301    }
302
303    Ok(tide::Redirect::new("/").into())
304}
305
306async fn post_sign_out(mut request: Request) -> tide::Result {
307    request.session_mut().remove(CURRENT_GITHUB_USER);
308    Ok(tide::Redirect::new("/").into())
309}
310
311pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result<String> {
312    let access_token = zed_auth::random_token();
313    let access_token_hash =
314        hash_access_token(&access_token).context("failed to hash access token")?;
315    db.create_access_token_hash(user_id, access_token_hash)
316        .await?;
317    Ok(access_token)
318}
319
320fn hash_access_token(token: &str) -> tide::Result<String> {
321    // Avoid slow hashing in debug mode.
322    let params = if cfg!(debug_assertions) {
323        scrypt::Params::new(1, 1, 1).unwrap()
324    } else {
325        scrypt::Params::recommended()
326    };
327
328    Ok(Scrypt
329        .hash_password(
330            token.as_bytes(),
331            None,
332            params,
333            &SaltString::generate(thread_rng()),
334        )?
335        .to_string())
336}
337
338pub fn verify_access_token(token: &str, hash: &str) -> tide::Result<bool> {
339    let hash = PasswordHash::new(hash)?;
340    Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
341}