1use super::{
2 db::{self, UserId},
3 errors::TideResultExt,
4};
5use crate::{github, AppState, Request, RequestExt as _};
6use anyhow::{anyhow, Context};
7use async_trait::async_trait;
8pub use oauth2::basic::BasicClient as Client;
9use oauth2::{
10 AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl,
11 TokenResponse as _, TokenUrl,
12};
13use rand::thread_rng;
14use scrypt::{
15 password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
16 Scrypt,
17};
18use serde::{Deserialize, Serialize};
19use std::{borrow::Cow, convert::TryFrom, sync::Arc};
20use surf::{StatusCode, Url};
21use tide::{log, Server};
22use zrpc::auth as zed_auth;
23
/// Session key under which the signed-in user's GitHub profile is stored.
static CURRENT_GITHUB_USER: &str = "current_github_user";
/// GitHub's OAuth authorization endpoint.
static GITHUB_AUTH_URL: &str = "https://github.com/login/oauth/authorize";
/// GitHub's OAuth token-exchange endpoint.
static GITHUB_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
27
28#[derive(Serialize)]
29pub struct User {
30 pub github_login: String,
31 pub avatar_url: String,
32 pub is_insider: bool,
33 pub is_admin: bool,
34}
35
36pub struct VerifyToken;
37
38#[async_trait]
39impl tide::Middleware<Arc<AppState>> for VerifyToken {
40 async fn handle(
41 &self,
42 mut request: Request,
43 next: tide::Next<'_, Arc<AppState>>,
44 ) -> tide::Result {
45 let mut auth_header = request
46 .header("Authorization")
47 .ok_or_else(|| anyhow!("no authorization header"))?
48 .last()
49 .as_str()
50 .split_whitespace();
51
52 let user_id = UserId(
53 auth_header
54 .next()
55 .ok_or_else(|| anyhow!("missing user id in authorization header"))?
56 .parse()?,
57 );
58 let access_token = auth_header
59 .next()
60 .ok_or_else(|| anyhow!("missing access token in authorization header"))?;
61
62 let state = request.state().clone();
63
64 let mut credentials_valid = false;
65 for password_hash in state.db.get_access_token_hashes(user_id).await? {
66 if verify_access_token(&access_token, &password_hash)? {
67 credentials_valid = true;
68 break;
69 }
70 }
71
72 if credentials_valid {
73 request.set_ext(user_id);
74 Ok(next.run(request).await)
75 } else {
76 let mut response = tide::Response::new(StatusCode::Unauthorized);
77 response.set_body("invalid credentials");
78 Ok(response)
79 }
80 }
81}
82
/// Session-based authentication helpers for tide requests.
#[async_trait]
pub trait RequestExt {
    /// Returns the user signed in through the browser session, or `None` when
    /// the session holds no GitHub user.
    async fn current_user(&self) -> tide::Result<Option<User>>;
}
87
88#[async_trait]
89impl RequestExt for Request {
90 async fn current_user(&self) -> tide::Result<Option<User>> {
91 if let Some(details) = self.session().get::<github::User>(CURRENT_GITHUB_USER) {
92 let user = self.db().get_user_by_github_login(&details.login).await?;
93 Ok(Some(User {
94 github_login: details.login,
95 avatar_url: details.avatar_url,
96 is_insider: user.is_some(),
97 is_admin: user.map_or(false, |user| user.admin),
98 }))
99 } else {
100 Ok(None)
101 }
102 }
103}
104
105pub fn build_client(client_id: &str, client_secret: &str) -> Client {
106 Client::new(
107 ClientId::new(client_id.to_string()),
108 Some(oauth2::ClientSecret::new(client_secret.to_string())),
109 AuthUrl::new(GITHUB_AUTH_URL.into()).unwrap(),
110 Some(TokenUrl::new(GITHUB_TOKEN_URL.into()).unwrap()),
111 )
112}
113
/// Registers the browser sign-in routes: `GET /sign_in`, `POST /sign_out`, and
/// `GET /auth_callback`.
pub fn add_routes(app: &mut Server<Arc<AppState>>) {
    app.at("/sign_in").get(get_sign_in);
    app.at("/sign_out").post(post_sign_out);
    app.at("/auth_callback").get(get_auth_callback);
}
119
/// Query parameters sent by the native app when it starts a browser sign-in.
///
/// They are forwarded through the OAuth callback URL so that `get_auth_callback`
/// can hand an encrypted access token back to the app.
#[derive(Debug, Deserialize)]
struct NativeAppSignInParams {
    /// Port on `127.0.0.1` that the final redirect is sent to.
    native_app_port: String,
    /// App public key used to encrypt the access token before redirecting.
    native_app_public_key: String,
    /// GitHub login to impersonate; only honored when the signed-in user is an admin.
    impersonate: Option<String>,
}
126
127async fn get_sign_in(mut request: Request) -> tide::Result {
128 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
129
130 request
131 .session_mut()
132 .insert("pkce_verifier", pkce_verifier)?;
133
134 let mut redirect_url = Url::parse(&format!(
135 "{}://{}/auth_callback",
136 request
137 .header("X-Forwarded-Proto")
138 .and_then(|values| values.get(0))
139 .map(|value| value.as_str())
140 .unwrap_or("http"),
141 request.host().unwrap()
142 ))?;
143
144 let app_sign_in_params: Option<NativeAppSignInParams> = request.query().ok();
145 if let Some(query) = app_sign_in_params {
146 let mut redirect_query = redirect_url.query_pairs_mut();
147 redirect_query
148 .clear()
149 .append_pair("native_app_port", &query.native_app_port)
150 .append_pair("native_app_public_key", &query.native_app_public_key);
151
152 if let Some(impersonate) = &query.impersonate {
153 redirect_query.append_pair("impersonate", impersonate);
154 }
155 }
156
157 let (auth_url, csrf_token) = request
158 .state()
159 .auth_client
160 .authorize_url(CsrfToken::new_random)
161 .set_redirect_uri(Cow::Owned(RedirectUrl::from_url(redirect_url)))
162 .set_pkce_challenge(pkce_challenge)
163 .url();
164
165 request
166 .session_mut()
167 .insert("auth_csrf_token", csrf_token)?;
168
169 Ok(tide::Redirect::new(auth_url).into())
170}
171
172async fn get_auth_callback(mut request: Request) -> tide::Result {
173 #[derive(Debug, Deserialize)]
174 struct Query {
175 code: String,
176 state: String,
177
178 #[serde(flatten)]
179 native_app_sign_in_params: Option<NativeAppSignInParams>,
180 }
181
182 let query: Query = request.query()?;
183
184 let pkce_verifier = request
185 .session()
186 .get("pkce_verifier")
187 .ok_or_else(|| anyhow!("could not retrieve pkce_verifier from session"))?;
188
189 let csrf_token = request
190 .session()
191 .get::<CsrfToken>("auth_csrf_token")
192 .ok_or_else(|| anyhow!("could not retrieve auth_csrf_token from session"))?;
193
194 if &query.state != csrf_token.secret() {
195 return Err(anyhow!("csrf token does not match").into());
196 }
197
198 let github_access_token = request
199 .state()
200 .auth_client
201 .exchange_code(AuthorizationCode::new(query.code))
202 .set_pkce_verifier(pkce_verifier)
203 .request_async(oauth2_surf::http_client)
204 .await
205 .context("failed to exchange oauth code")?
206 .access_token()
207 .secret()
208 .clone();
209
210 let user_details = request
211 .state()
212 .github_client
213 .user(github_access_token)
214 .details()
215 .await
216 .context("failed to fetch user")?;
217
218 let user = request
219 .db()
220 .get_user_by_github_login(&user_details.login)
221 .await?;
222
223 request
224 .session_mut()
225 .insert(CURRENT_GITHUB_USER, user_details.clone())?;
226
227 // When signing in from the native app, generate a new access token for the current user. Return
228 // a redirect so that the user's browser sends this access token to the locally-running app.
229 if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) {
230 let mut user_id = user.id;
231 if let Some(impersonated_login) = app_sign_in_params.impersonate {
232 log::info!("attempting to impersonate user @{}", impersonated_login);
233 if let Some(user) = request.db().get_users_by_ids([user_id]).await?.first() {
234 if user.admin {
235 user_id = request.db().create_user(&impersonated_login, false).await?;
236 log::info!("impersonating user {}", user_id.0);
237 } else {
238 log::info!("refusing to impersonate user");
239 }
240 }
241 }
242
243 let access_token = create_access_token(request.db(), user_id).await?;
244 let native_app_public_key =
245 zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone())
246 .context("failed to parse app public key")?;
247 let encrypted_access_token = native_app_public_key
248 .encrypt_string(&access_token)
249 .context("failed to encrypt access token with public key")?;
250
251 return Ok(tide::Redirect::new(&format!(
252 "http://127.0.0.1:{}?user_id={}&access_token={}",
253 app_sign_in_params.native_app_port, user_id.0, encrypted_access_token,
254 ))
255 .into());
256 }
257
258 Ok(tide::Redirect::new("/").into())
259}
260
261async fn post_sign_out(mut request: Request) -> tide::Result {
262 request.session_mut().remove(CURRENT_GITHUB_USER);
263 Ok(tide::Redirect::new("/").into())
264}
265
266pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result<String> {
267 let access_token = zed_auth::random_token();
268 let access_token_hash =
269 hash_access_token(&access_token).context("failed to hash access token")?;
270 db.create_access_token_hash(user_id, access_token_hash)
271 .await?;
272 Ok(access_token)
273}
274
275fn hash_access_token(token: &str) -> tide::Result<String> {
276 // Avoid slow hashing in debug mode.
277 let params = if cfg!(debug_assertions) {
278 scrypt::Params::new(1, 1, 1).unwrap()
279 } else {
280 scrypt::Params::recommended()
281 };
282
283 Ok(Scrypt
284 .hash_password(
285 token.as_bytes(),
286 None,
287 params,
288 &SaltString::generate(thread_rng()),
289 )?
290 .to_string())
291}
292
293pub fn verify_access_token(token: &str, hash: &str) -> tide::Result<bool> {
294 let hash = PasswordHash::new(hash)?;
295 Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
296}