1use super::{
2 db::{self, UserId},
3 errors::TideResultExt,
4};
5use crate::{github, rpc, AppState, Request, RequestExt as _};
6use anyhow::{anyhow, Context};
7use async_trait::async_trait;
8pub use oauth2::basic::BasicClient as Client;
9use oauth2::{
10 AuthUrl, AuthorizationCode, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl,
11 TokenResponse as _, TokenUrl,
12};
13use rand::thread_rng;
14use scrypt::{
15 password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
16 Scrypt,
17};
18use serde::{Deserialize, Serialize};
19use std::{borrow::Cow, convert::TryFrom, sync::Arc};
20use surf::Url;
21use tide::Server;
22use zrpc::{auth as zed_auth, proto, Peer};
23
/// Session key under which the signed-in GitHub user's profile is stored.
static CURRENT_GITHUB_USER: &str = "current_github_user";
/// GitHub OAuth authorization endpoint the browser is sent to when signing in.
static GITHUB_AUTH_URL: &str = "https://github.com/login/oauth/authorize";
/// GitHub OAuth endpoint used to exchange an authorization code for an access token.
static GITHUB_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
27
28#[derive(Serialize)]
29pub struct User {
30 pub github_login: String,
31 pub avatar_url: String,
32 pub is_insider: bool,
33 pub is_admin: bool,
34}
35
36pub struct VerifyToken;
37
38#[async_trait]
39impl tide::Middleware<Arc<AppState>> for VerifyToken {
40 async fn handle(
41 &self,
42 mut request: Request,
43 next: tide::Next<'_, Arc<AppState>>,
44 ) -> tide::Result {
45 let mut auth_header = request
46 .header("Authorization")
47 .ok_or_else(|| anyhow!("no authorization header"))?
48 .last()
49 .as_str()
50 .split_whitespace();
51
52 let user_id = UserId(
53 auth_header
54 .next()
55 .ok_or_else(|| anyhow!("missing user id in authorization header"))?
56 .parse()?,
57 );
58 let access_token = auth_header
59 .next()
60 .ok_or_else(|| anyhow!("missing access token in authorization header"))?;
61
62 let state = request.state().clone();
63
64 let mut credentials_valid = false;
65 for password_hash in state.db.get_access_token_hashes(user_id).await? {
66 if verify_access_token(&access_token, &password_hash)? {
67 credentials_valid = true;
68 break;
69 }
70 }
71
72 if credentials_valid {
73 request.set_ext(user_id);
74 Ok(next.run(request).await)
75 } else {
76 Err(anyhow!("invalid credentials").into())
77 }
78 }
79}
80
81#[async_trait]
82pub trait RequestExt {
83 async fn current_user(&self) -> tide::Result<Option<User>>;
84}
85
86#[async_trait]
87impl RequestExt for Request {
88 async fn current_user(&self) -> tide::Result<Option<User>> {
89 if let Some(details) = self.session().get::<github::User>(CURRENT_GITHUB_USER) {
90 let user = self.db().get_user_by_github_login(&details.login).await?;
91 Ok(Some(User {
92 github_login: details.login,
93 avatar_url: details.avatar_url,
94 is_insider: user.is_some(),
95 is_admin: user.map_or(false, |user| user.admin),
96 }))
97 } else {
98 Ok(None)
99 }
100 }
101}
102
/// Extension trait for RPC peers that tears down a client connection and notifies the
/// remaining collaborators.
#[async_trait]
pub trait PeerExt {
    /// Disconnects `connection_id`, removes it from the shared RPC state held in `state`, and
    /// sends a `RemovePeer` message to the other connections of every worktree the removed
    /// connection belonged to.
    async fn sign_out(
        self: &Arc<Self>,
        connection_id: zrpc::ConnectionId,
        state: &AppState,
    ) -> tide::Result<()>;
}
111
112#[async_trait]
113impl PeerExt for Peer {
114 async fn sign_out(
115 self: &Arc<Self>,
116 connection_id: zrpc::ConnectionId,
117 state: &AppState,
118 ) -> tide::Result<()> {
119 self.disconnect(connection_id).await;
120 let worktree_ids = state.rpc.write().await.remove_connection(connection_id);
121 for worktree_id in worktree_ids {
122 let state = state.rpc.read().await;
123 if let Some(worktree) = state.worktrees.get(&worktree_id) {
124 rpc::broadcast(connection_id, worktree.connection_ids(), |conn_id| {
125 self.send(
126 conn_id,
127 proto::RemovePeer {
128 worktree_id,
129 peer_id: connection_id.0,
130 },
131 )
132 })
133 .await?;
134 }
135 }
136 Ok(())
137 }
138}
139
140#[async_trait]
141impl PeerExt for zrpc::peer2::Peer {
142 async fn sign_out(
143 self: &Arc<Self>,
144 connection_id: zrpc::ConnectionId,
145 state: &AppState,
146 ) -> tide::Result<()> {
147 self.disconnect(connection_id).await;
148 let worktree_ids = state.rpc.write().await.remove_connection(connection_id);
149 for worktree_id in worktree_ids {
150 let state = state.rpc.read().await;
151 if let Some(worktree) = state.worktrees.get(&worktree_id) {
152 rpc::broadcast(connection_id, worktree.connection_ids(), |conn_id| {
153 self.send(
154 conn_id,
155 proto::RemovePeer {
156 worktree_id,
157 peer_id: connection_id.0,
158 },
159 )
160 })
161 .await?;
162 }
163 }
164 Ok(())
165 }
166}
167
168pub fn build_client(client_id: &str, client_secret: &str) -> Client {
169 Client::new(
170 ClientId::new(client_id.to_string()),
171 Some(oauth2::ClientSecret::new(client_secret.to_string())),
172 AuthUrl::new(GITHUB_AUTH_URL.into()).unwrap(),
173 Some(TokenUrl::new(GITHUB_TOKEN_URL.into()).unwrap()),
174 )
175}
176
177pub fn add_routes(app: &mut Server<Arc<AppState>>) {
178 app.at("/sign_in").get(get_sign_in);
179 app.at("/sign_out").post(post_sign_out);
180 app.at("/auth_callback").get(get_auth_callback);
181}
182
/// Query parameters supplied when sign-in is initiated from the native app.
///
/// `get_sign_in` forwards them into the OAuth redirect URI so that `get_auth_callback` can
/// deliver an encrypted access token back to the app.
#[derive(Debug, Deserialize)]
struct NativeAppSignInParams {
    /// Port of the app's listener on `127.0.0.1`. Arrives as an untrusted string.
    native_app_port: String,
    /// Public key used to encrypt the access token before it is handed to the app.
    native_app_public_key: String,
}
188
189async fn get_sign_in(mut request: Request) -> tide::Result {
190 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
191
192 request
193 .session_mut()
194 .insert("pkce_verifier", pkce_verifier)?;
195
196 let mut redirect_url = Url::parse(&format!(
197 "{}://{}/auth_callback",
198 request
199 .header("X-Forwarded-Proto")
200 .and_then(|values| values.get(0))
201 .map(|value| value.as_str())
202 .unwrap_or("http"),
203 request.host().unwrap()
204 ))?;
205
206 let app_sign_in_params: Option<NativeAppSignInParams> = request.query().ok();
207 if let Some(query) = app_sign_in_params {
208 redirect_url
209 .query_pairs_mut()
210 .clear()
211 .append_pair("native_app_port", &query.native_app_port)
212 .append_pair("native_app_public_key", &query.native_app_public_key);
213 }
214
215 let (auth_url, csrf_token) = request
216 .state()
217 .auth_client
218 .authorize_url(CsrfToken::new_random)
219 .set_redirect_uri(Cow::Owned(RedirectUrl::from_url(redirect_url)))
220 .set_pkce_challenge(pkce_challenge)
221 .url();
222
223 request
224 .session_mut()
225 .insert("auth_csrf_token", csrf_token)?;
226
227 Ok(tide::Redirect::new(auth_url).into())
228}
229
230async fn get_auth_callback(mut request: Request) -> tide::Result {
231 #[derive(Debug, Deserialize)]
232 struct Query {
233 code: String,
234 state: String,
235
236 #[serde(flatten)]
237 native_app_sign_in_params: Option<NativeAppSignInParams>,
238 }
239
240 let query: Query = request.query()?;
241
242 let pkce_verifier = request
243 .session()
244 .get("pkce_verifier")
245 .ok_or_else(|| anyhow!("could not retrieve pkce_verifier from session"))?;
246
247 let csrf_token = request
248 .session()
249 .get::<CsrfToken>("auth_csrf_token")
250 .ok_or_else(|| anyhow!("could not retrieve auth_csrf_token from session"))?;
251
252 if &query.state != csrf_token.secret() {
253 return Err(anyhow!("csrf token does not match").into());
254 }
255
256 let github_access_token = request
257 .state()
258 .auth_client
259 .exchange_code(AuthorizationCode::new(query.code))
260 .set_pkce_verifier(pkce_verifier)
261 .request_async(oauth2_surf::http_client)
262 .await
263 .context("failed to exchange oauth code")?
264 .access_token()
265 .secret()
266 .clone();
267
268 let user_details = request
269 .state()
270 .github_client
271 .user(github_access_token)
272 .details()
273 .await
274 .context("failed to fetch user")?;
275
276 let user = request
277 .db()
278 .get_user_by_github_login(&user_details.login)
279 .await?;
280
281 request
282 .session_mut()
283 .insert(CURRENT_GITHUB_USER, user_details.clone())?;
284
285 // When signing in from the native app, generate a new access token for the current user. Return
286 // a redirect so that the user's browser sends this access token to the locally-running app.
287 if let Some((user, app_sign_in_params)) = user.zip(query.native_app_sign_in_params) {
288 let access_token = create_access_token(request.db(), user.id).await?;
289 let native_app_public_key =
290 zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone())
291 .context("failed to parse app public key")?;
292 let encrypted_access_token = native_app_public_key
293 .encrypt_string(&access_token)
294 .context("failed to encrypt access token with public key")?;
295
296 return Ok(tide::Redirect::new(&format!(
297 "http://127.0.0.1:{}?user_id={}&access_token={}",
298 app_sign_in_params.native_app_port, user.id.0, encrypted_access_token,
299 ))
300 .into());
301 }
302
303 Ok(tide::Redirect::new("/").into())
304}
305
306async fn post_sign_out(mut request: Request) -> tide::Result {
307 request.session_mut().remove(CURRENT_GITHUB_USER);
308 Ok(tide::Redirect::new("/").into())
309}
310
311pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result<String> {
312 let access_token = zed_auth::random_token();
313 let access_token_hash =
314 hash_access_token(&access_token).context("failed to hash access token")?;
315 db.create_access_token_hash(user_id, access_token_hash)
316 .await?;
317 Ok(access_token)
318}
319
320fn hash_access_token(token: &str) -> tide::Result<String> {
321 // Avoid slow hashing in debug mode.
322 let params = if cfg!(debug_assertions) {
323 scrypt::Params::new(1, 1, 1).unwrap()
324 } else {
325 scrypt::Params::recommended()
326 };
327
328 Ok(Scrypt
329 .hash_password(
330 token.as_bytes(),
331 None,
332 params,
333 &SaltString::generate(thread_rng()),
334 )?
335 .to_string())
336}
337
338pub fn verify_access_token(token: &str, hash: &str) -> tide::Result<bool> {
339 let hash = PasswordHash::new(hash)?;
340 Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
341}