1use super::{
2 db::{self, UserId},
3 errors::TideResultExt,
4};
5use crate::{github, AppState, Request, RequestExt as _};
6use anyhow::{anyhow, Context};
7use async_trait::async_trait;
8pub use oauth2::basic::BasicClient as Client;
9use oauth2::{
10 AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl,
11 TokenResponse as _, TokenUrl,
12};
13use rand::thread_rng;
14use scrypt::{
15 password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
16 Scrypt,
17};
18use serde::{Deserialize, Serialize};
19use std::{borrow::Cow, convert::TryFrom, sync::Arc};
20use surf::{StatusCode, Url};
21use tide::{log, Error, Server};
22use zrpc::auth as zed_auth;
23
/// Session key under which the signed-in GitHub user's details are stored.
static CURRENT_GITHUB_USER: &str = "current_github_user";
/// GitHub's OAuth authorization endpoint.
static GITHUB_AUTH_URL: &str = "https://github.com/login/oauth/authorize";
/// GitHub's OAuth access-token endpoint.
static GITHUB_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
27
28#[derive(Serialize)]
29pub struct User {
30 pub github_login: String,
31 pub avatar_url: String,
32 pub is_insider: bool,
33 pub is_admin: bool,
34}
35
36pub async fn process_auth_header(request: &Request) -> tide::Result<UserId> {
37 let mut auth_header = request
38 .header("Authorization")
39 .ok_or_else(|| {
40 Error::new(
41 StatusCode::BadRequest,
42 anyhow!("missing authorization header"),
43 )
44 })?
45 .last()
46 .as_str()
47 .split_whitespace();
48 let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
49 Error::new(
50 StatusCode::BadRequest,
51 anyhow!("missing user id in authorization header"),
52 )
53 })?);
54 let access_token = auth_header.next().ok_or_else(|| {
55 Error::new(
56 StatusCode::BadRequest,
57 anyhow!("missing access token in authorization header"),
58 )
59 })?;
60
61 let state = request.state().clone();
62 let mut credentials_valid = false;
63 for password_hash in state.db.get_access_token_hashes(user_id).await? {
64 if verify_access_token(&access_token, &password_hash)? {
65 credentials_valid = true;
66 break;
67 }
68 }
69
70 if !credentials_valid {
71 Err(Error::new(
72 StatusCode::Unauthorized,
73 anyhow!("invalid credentials"),
74 ))?;
75 }
76
77 Ok(user_id)
78}
79
/// Request extension for resolving the signed-in user of a request.
#[async_trait]
pub trait RequestExt {
    /// Returns the user recorded in the request's session, or `None` if there is none.
    async fn current_user(&self) -> tide::Result<Option<User>>;
}
84
85#[async_trait]
86impl RequestExt for Request {
87 async fn current_user(&self) -> tide::Result<Option<User>> {
88 if let Some(details) = self.session().get::<github::User>(CURRENT_GITHUB_USER) {
89 let user = self.db().get_user_by_github_login(&details.login).await?;
90 Ok(Some(User {
91 github_login: details.login,
92 avatar_url: details.avatar_url,
93 is_insider: user.is_some(),
94 is_admin: user.map_or(false, |user| user.admin),
95 }))
96 } else {
97 Ok(None)
98 }
99 }
100}
101
102pub fn build_client(client_id: &str, client_secret: &str) -> Client {
103 Client::new(
104 ClientId::new(client_id.to_string()),
105 Some(oauth2::ClientSecret::new(client_secret.to_string())),
106 AuthUrl::new(GITHUB_AUTH_URL.into()).unwrap(),
107 Some(TokenUrl::new(GITHUB_TOKEN_URL.into()).unwrap()),
108 )
109}
110
111pub fn add_routes(app: &mut Server<Arc<AppState>>) {
112 app.at("/sign_in").get(get_sign_in);
113 app.at("/sign_out").post(post_sign_out);
114 app.at("/auth_callback").get(get_auth_callback);
115}
116
/// Query parameters a native app passes to `/sign_in`. They are copied onto the OAuth redirect
/// URI and read back (flattened into the callback query) in `/auth_callback`.
#[derive(Debug, Deserialize)]
struct NativeAppSignInParams {
    /// Loopback port of the native app; the callback redirects to `http://127.0.0.1:<port>`.
    native_app_port: String,
    /// The app's public key, used to encrypt the access token sent back to it.
    native_app_public_key: String,
    /// Optional login to impersonate; the callback only honors it for admin users.
    impersonate: Option<String>,
}
123
124async fn get_sign_in(mut request: Request) -> tide::Result {
125 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
126
127 request
128 .session_mut()
129 .insert("pkce_verifier", pkce_verifier)?;
130
131 let mut redirect_url = Url::parse(&format!(
132 "{}://{}/auth_callback",
133 request
134 .header("X-Forwarded-Proto")
135 .and_then(|values| values.get(0))
136 .map(|value| value.as_str())
137 .unwrap_or("http"),
138 request.host().unwrap()
139 ))?;
140
141 let app_sign_in_params: Option<NativeAppSignInParams> = request.query().ok();
142 if let Some(query) = app_sign_in_params {
143 let mut redirect_query = redirect_url.query_pairs_mut();
144 redirect_query
145 .clear()
146 .append_pair("native_app_port", &query.native_app_port)
147 .append_pair("native_app_public_key", &query.native_app_public_key);
148
149 if let Some(impersonate) = &query.impersonate {
150 redirect_query.append_pair("impersonate", impersonate);
151 }
152 }
153
154 let (auth_url, csrf_token) = request
155 .state()
156 .auth_client
157 .authorize_url(CsrfToken::new_random)
158 .set_redirect_uri(Cow::Owned(RedirectUrl::from_url(redirect_url)))
159 .set_pkce_challenge(pkce_challenge)
160 .url();
161
162 request
163 .session_mut()
164 .insert("auth_csrf_token", csrf_token)?;
165
166 Ok(tide::Redirect::new(auth_url).into())
167}
168
169async fn get_auth_callback(mut request: Request) -> tide::Result {
170 #[derive(Debug, Deserialize)]
171 struct Query {
172 code: String,
173 state: String,
174
175 #[serde(flatten)]
176 native_app_sign_in_params: Option<NativeAppSignInParams>,
177 }
178
179 let query: Query = request.query()?;
180
181 let pkce_verifier = request
182 .session()
183 .get("pkce_verifier")
184 .ok_or_else(|| anyhow!("could not retrieve pkce_verifier from session"))?;
185
186 let csrf_token = request
187 .session()
188 .get::<CsrfToken>("auth_csrf_token")
189 .ok_or_else(|| anyhow!("could not retrieve auth_csrf_token from session"))?;
190
191 if &query.state != csrf_token.secret() {
192 return Err(anyhow!("csrf token does not match").into());
193 }
194
195 let github_access_token = request
196 .state()
197 .auth_client
198 .exchange_code(AuthorizationCode::new(query.code))
199 .set_pkce_verifier(pkce_verifier)
200 .request_async(oauth2_surf::http_client)
201 .await
202 .context("failed to exchange oauth code")?
203 .access_token()
204 .secret()
205 .clone();
206
207 let user_details = request
208 .state()
209 .github_client
210 .user(github_access_token)
211 .details()
212 .await
213 .context("failed to fetch user")?;
214
215 let user = request
216 .db()
217 .get_user_by_github_login(&user_details.login)
218 .await?;
219
220 request
221 .session_mut()
222 .insert(CURRENT_GITHUB_USER, user_details.clone())?;
223
224 // When signing in from the native app, generate a new access token for the current user. Return
225 // a redirect so that the user's browser sends this access token to the locally-running app.
226 if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) {
227 let mut user_id = user.id;
228 if let Some(impersonated_login) = app_sign_in_params.impersonate {
229 log::info!("attempting to impersonate user @{}", impersonated_login);
230 if let Some(user) = request.db().get_users_by_ids([user_id]).await?.first() {
231 if user.admin {
232 user_id = request.db().create_user(&impersonated_login, false).await?;
233 log::info!("impersonating user {}", user_id.0);
234 } else {
235 log::info!("refusing to impersonate user");
236 }
237 }
238 }
239
240 let access_token = create_access_token(request.db(), user_id).await?;
241 let native_app_public_key =
242 zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone())
243 .context("failed to parse app public key")?;
244 let encrypted_access_token = native_app_public_key
245 .encrypt_string(&access_token)
246 .context("failed to encrypt access token with public key")?;
247
248 return Ok(tide::Redirect::new(&format!(
249 "http://127.0.0.1:{}?user_id={}&access_token={}",
250 app_sign_in_params.native_app_port, user_id.0, encrypted_access_token,
251 ))
252 .into());
253 }
254
255 Ok(tide::Redirect::new("/").into())
256}
257
258async fn post_sign_out(mut request: Request) -> tide::Result {
259 request.session_mut().remove(CURRENT_GITHUB_USER);
260 Ok(tide::Redirect::new("/").into())
261}
262
263const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
264
265pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result<String> {
266 let access_token = zed_auth::random_token();
267 let access_token_hash =
268 hash_access_token(&access_token).context("failed to hash access token")?;
269 db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
270 .await?;
271 Ok(access_token)
272}
273
274fn hash_access_token(token: &str) -> tide::Result<String> {
275 // Avoid slow hashing in debug mode.
276 let params = if cfg!(debug_assertions) {
277 scrypt::Params::new(1, 1, 1).unwrap()
278 } else {
279 scrypt::Params::recommended()
280 };
281
282 Ok(Scrypt
283 .hash_password(
284 token.as_bytes(),
285 None,
286 params,
287 &SaltString::generate(thread_rng()),
288 )?
289 .to_string())
290}
291
292pub fn verify_access_token(token: &str, hash: &str) -> tide::Result<bool> {
293 let hash = PasswordHash::new(hash)?;
294 Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
295}