1use super::errors::TideResultExt;
2use crate::{github, rpc, AppState, DbPool, Request, RequestExt as _};
3use anyhow::{anyhow, Context};
4use async_std::stream::StreamExt;
5use async_trait::async_trait;
6pub use oauth2::basic::BasicClient as Client;
7use oauth2::{
8 AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl,
9 TokenResponse as _, TokenUrl,
10};
11use rand::thread_rng;
12use scrypt::{
13 password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
14 Scrypt,
15};
16use serde::{Deserialize, Serialize};
17use sqlx::FromRow;
18use std::{borrow::Cow, convert::TryFrom, sync::Arc};
19use surf::Url;
20use tide::Server;
21use zrpc::{auth as zed_auth, proto, Peer};
22
/// Session key under which the signed-in GitHub user's profile is stored.
static CURRENT_GITHUB_USER: &str = "current_github_user";
/// GitHub's OAuth authorization endpoint.
static GITHUB_AUTH_URL: &str = "https://github.com/login/oauth/authorize";
/// GitHub's OAuth code-for-token exchange endpoint.
static GITHUB_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
26
27#[derive(Serialize)]
28pub struct User {
29 pub github_login: String,
30 pub avatar_url: String,
31 pub is_insider: bool,
32 pub is_admin: bool,
33}
34
/// Tide middleware that authenticates requests carrying an
/// `Authorization: <user_id> <access_token>` header.
#[derive(Debug)]
pub struct VerifyToken;

/// Id of the user authenticated by [`VerifyToken`], attached to the request as an extension.
#[derive(Clone, Copy, Debug)]
pub struct UserId(pub i32);
39
40#[async_trait]
41impl tide::Middleware<Arc<AppState>> for VerifyToken {
42 async fn handle(
43 &self,
44 mut request: Request,
45 next: tide::Next<'_, Arc<AppState>>,
46 ) -> tide::Result {
47 let mut auth_header = request
48 .header("Authorization")
49 .ok_or_else(|| anyhow!("no authorization header"))?
50 .last()
51 .as_str()
52 .split_whitespace();
53
54 let user_id: i32 = auth_header
55 .next()
56 .ok_or_else(|| anyhow!("missing user id in authorization header"))?
57 .parse()?;
58 let access_token = auth_header
59 .next()
60 .ok_or_else(|| anyhow!("missing access token in authorization header"))?;
61
62 let state = request.state().clone();
63
64 let mut password_hashes =
65 sqlx::query_scalar::<_, String>("SELECT hash FROM access_tokens WHERE user_id = $1")
66 .bind(&user_id)
67 .fetch_many(&state.db);
68
69 let mut credentials_valid = false;
70 while let Some(password_hash) = password_hashes.next().await {
71 if let either::Either::Right(password_hash) = password_hash? {
72 if verify_access_token(&access_token, &password_hash)? {
73 credentials_valid = true;
74 break;
75 }
76 }
77 }
78
79 if credentials_valid {
80 request.set_ext(UserId(user_id));
81 Ok(next.run(request).await)
82 } else {
83 Err(anyhow!("invalid credentials").into())
84 }
85 }
86}
87
/// Extension methods for resolving the browser-session user from a [`Request`].
#[async_trait]
pub trait RequestExt {
    /// Returns the user whose GitHub profile is stored in this request's session,
    /// or `Ok(None)` when the session has no signed-in GitHub user.
    async fn current_user(&self) -> tide::Result<Option<User>>;
}
92
93#[async_trait]
94impl RequestExt for Request {
95 async fn current_user(&self) -> tide::Result<Option<User>> {
96 if let Some(details) = self.session().get::<github::User>(CURRENT_GITHUB_USER) {
97 #[derive(FromRow)]
98 struct UserRow {
99 admin: bool,
100 }
101
102 let user_row: Option<UserRow> =
103 sqlx::query_as("SELECT admin FROM users WHERE github_login = $1")
104 .bind(&details.login)
105 .fetch_optional(self.db())
106 .await?;
107
108 let is_insider = user_row.is_some();
109 let is_admin = user_row.map_or(false, |row| row.admin);
110
111 Ok(Some(User {
112 github_login: details.login,
113 avatar_url: details.avatar_url,
114 is_insider,
115 is_admin,
116 }))
117 } else {
118 Ok(None)
119 }
120 }
121}
122
/// Extension methods for tearing down a client's RPC connection.
#[async_trait]
pub trait PeerExt {
    /// Disconnects `connection_id`, removes it from the shared RPC state, and sends a
    /// `RemovePeer` message to the connections of each worktree it belonged to.
    async fn sign_out(
        self: &Arc<Self>,
        connection_id: zrpc::ConnectionId,
        state: &AppState,
    ) -> tide::Result<()>;
}
131
132#[async_trait]
133impl PeerExt for Peer {
134 async fn sign_out(
135 self: &Arc<Self>,
136 connection_id: zrpc::ConnectionId,
137 state: &AppState,
138 ) -> tide::Result<()> {
139 self.disconnect(connection_id).await;
140 let worktree_ids = state.rpc.write().await.remove_connection(connection_id);
141 for worktree_id in worktree_ids {
142 let state = state.rpc.read().await;
143 if let Some(worktree) = state.worktrees.get(&worktree_id) {
144 rpc::broadcast(connection_id, worktree.connection_ids(), |conn_id| {
145 self.send(
146 conn_id,
147 proto::RemovePeer {
148 worktree_id,
149 peer_id: connection_id.0,
150 },
151 )
152 })
153 .await?;
154 }
155 }
156 Ok(())
157 }
158}
159
160pub fn build_client(client_id: &str, client_secret: &str) -> Client {
161 Client::new(
162 ClientId::new(client_id.to_string()),
163 Some(oauth2::ClientSecret::new(client_secret.to_string())),
164 AuthUrl::new(GITHUB_AUTH_URL.into()).unwrap(),
165 Some(TokenUrl::new(GITHUB_TOKEN_URL.into()).unwrap()),
166 )
167}
168
169pub fn add_routes(app: &mut Server<Arc<AppState>>) {
170 app.at("/sign_in").get(get_sign_in);
171 app.at("/sign_out").post(post_sign_out);
172 app.at("/auth_callback").get(get_auth_callback);
173}
174
/// Query parameters supplied when the native app starts a browser sign-in.
///
/// They are forwarded through the OAuth redirect so the callback can hand an
/// encrypted access token back to the app.
#[derive(Debug, Deserialize)]
struct NativeAppSignInParams {
    /// Port of the app's local listener on 127.0.0.1 that receives the final redirect.
    native_app_port: String,
    /// Public key the callback uses to encrypt the newly created access token.
    native_app_public_key: String,
}
180
181async fn get_sign_in(mut request: Request) -> tide::Result {
182 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
183
184 request
185 .session_mut()
186 .insert("pkce_verifier", pkce_verifier)?;
187
188 let mut redirect_url = Url::parse(&format!(
189 "{}://{}/auth_callback",
190 request
191 .header("X-Forwarded-Proto")
192 .and_then(|values| values.get(0))
193 .map(|value| value.as_str())
194 .unwrap_or("http"),
195 request.host().unwrap()
196 ))?;
197
198 let app_sign_in_params: Option<NativeAppSignInParams> = request.query().ok();
199 if let Some(query) = app_sign_in_params {
200 redirect_url
201 .query_pairs_mut()
202 .clear()
203 .append_pair("native_app_port", &query.native_app_port)
204 .append_pair("native_app_public_key", &query.native_app_public_key);
205 }
206
207 let (auth_url, csrf_token) = request
208 .state()
209 .auth_client
210 .authorize_url(CsrfToken::new_random)
211 .set_redirect_uri(Cow::Owned(RedirectUrl::from_url(redirect_url)))
212 .set_pkce_challenge(pkce_challenge)
213 .url();
214
215 request
216 .session_mut()
217 .insert("auth_csrf_token", csrf_token)?;
218
219 Ok(tide::Redirect::new(auth_url).into())
220}
221
222async fn get_auth_callback(mut request: Request) -> tide::Result {
223 #[derive(Debug, Deserialize)]
224 struct Query {
225 code: String,
226 state: String,
227
228 #[serde(flatten)]
229 native_app_sign_in_params: Option<NativeAppSignInParams>,
230 }
231
232 let query: Query = request.query()?;
233
234 let pkce_verifier = request
235 .session()
236 .get("pkce_verifier")
237 .ok_or_else(|| anyhow!("could not retrieve pkce_verifier from session"))?;
238
239 let csrf_token = request
240 .session()
241 .get::<CsrfToken>("auth_csrf_token")
242 .ok_or_else(|| anyhow!("could not retrieve auth_csrf_token from session"))?;
243
244 if &query.state != csrf_token.secret() {
245 return Err(anyhow!("csrf token does not match").into());
246 }
247
248 let github_access_token = request
249 .state()
250 .auth_client
251 .exchange_code(AuthorizationCode::new(query.code))
252 .set_pkce_verifier(pkce_verifier)
253 .request_async(oauth2_surf::http_client)
254 .await
255 .context("failed to exchange oauth code")?
256 .access_token()
257 .secret()
258 .clone();
259
260 let user_details = request
261 .state()
262 .github_client
263 .user(github_access_token)
264 .details()
265 .await
266 .context("failed to fetch user")?;
267
268 let user_id: Option<i32> = sqlx::query_scalar("SELECT id from users where github_login = $1")
269 .bind(&user_details.login)
270 .fetch_optional(request.db())
271 .await?;
272
273 request
274 .session_mut()
275 .insert(CURRENT_GITHUB_USER, user_details.clone())?;
276
277 // When signing in from the native app, generate a new access token for the current user. Return
278 // a redirect so that the user's browser sends this access token to the locally-running app.
279 if let Some((user_id, app_sign_in_params)) = user_id.zip(query.native_app_sign_in_params) {
280 let access_token = create_access_token(request.db(), user_id).await?;
281 let native_app_public_key =
282 zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone())
283 .context("failed to parse app public key")?;
284 let encrypted_access_token = native_app_public_key
285 .encrypt_string(&access_token)
286 .context("failed to encrypt access token with public key")?;
287
288 return Ok(tide::Redirect::new(&format!(
289 "http://127.0.0.1:{}?user_id={}&access_token={}",
290 app_sign_in_params.native_app_port, user_id, encrypted_access_token,
291 ))
292 .into());
293 }
294
295 Ok(tide::Redirect::new("/").into())
296}
297
298async fn post_sign_out(mut request: Request) -> tide::Result {
299 request.session_mut().remove(CURRENT_GITHUB_USER);
300 Ok(tide::Redirect::new("/").into())
301}
302
303pub async fn create_access_token(db: &DbPool, user_id: i32) -> tide::Result<String> {
304 let access_token = zed_auth::random_token();
305 let access_token_hash =
306 hash_access_token(&access_token).context("failed to hash access token")?;
307 sqlx::query("INSERT INTO access_tokens (user_id, hash) values ($1, $2)")
308 .bind(user_id)
309 .bind(access_token_hash)
310 .fetch_optional(db)
311 .await?;
312 Ok(access_token)
313}
314
315fn hash_access_token(token: &str) -> tide::Result<String> {
316 // Avoid slow hashing in debug mode.
317 let params = if cfg!(debug_assertions) {
318 scrypt::Params::new(1, 1, 1).unwrap()
319 } else {
320 scrypt::Params::recommended()
321 };
322
323 Ok(Scrypt
324 .hash_password(
325 token.as_bytes(),
326 None,
327 params,
328 &SaltString::generate(thread_rng()),
329 )?
330 .to_string())
331}
332
333pub fn verify_access_token(token: &str, hash: &str) -> tide::Result<bool> {
334 let hash = PasswordHash::new(hash)?;
335 Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
336}