Enable authentication via the NextJS site

Nathan Sobo created

Change summary

crates/client/src/client.rs | 41 ++++++++++++++++---------------
crates/server/src/api.rs    | 50 +++++++++++++++++++++++++++++++++++---
crates/server/src/auth.rs   | 19 ++++++++++----
3 files changed, 79 insertions(+), 31 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -35,8 +35,10 @@ pub use rpc::*;
 pub use user::*;
 
 lazy_static! {
-    static ref ZED_SERVER_URL: String =
-        std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev:443".to_string());
+    static ref COLLAB_URL: String =
+        std::env::var("ZED_COLLAB_URL").unwrap_or("https://collab.zed.dev:443".to_string());
+    static ref SITE_URL: String =
+        std::env::var("ZED_SITE_URL").unwrap_or("https://zed.dev".to_string());
     static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
         .ok()
         .and_then(|s| if s.is_empty() { None } else { Some(s) });
@@ -403,7 +405,7 @@ impl Client {
 
         match self.establish_connection(&credentials, cx).await {
             Ok(conn) => {
-                log::info!("connected to rpc address {}", *ZED_SERVER_URL);
+                log::info!("connected to rpc address {}", *COLLAB_URL);
                 self.state.write().credentials = Some(credentials.clone());
                 if !used_keychain && IMPERSONATE_LOGIN.is_none() {
                     write_credentials_to_keychain(&credentials, cx).log_err();
@@ -414,7 +416,7 @@ impl Client {
             Err(EstablishConnectionError::Unauthorized) => {
                 self.state.write().credentials.take();
                 if used_keychain {
-                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
+                    cx.platform().delete_credentials(&COLLAB_URL).log_err();
                     self.set_status(Status::SignedOut, cx);
                     self.authenticate_and_connect(cx).await
                 } else {
@@ -522,19 +524,19 @@ impl Client {
             )
             .header("X-Zed-Protocol-Version", rpc::PROTOCOL_VERSION);
         cx.background().spawn(async move {
-            if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
+            if let Some(host) = COLLAB_URL.strip_prefix("https://") {
                 let stream = smol::net::TcpStream::connect(host).await?;
                 let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
                 let (stream, _) =
                     async_tungstenite::async_tls::client_async_tls(request, stream).await?;
                 Ok(Connection::new(stream))
-            } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
+            } else if let Some(host) = COLLAB_URL.strip_prefix("http://") {
                 let stream = smol::net::TcpStream::connect(host).await?;
                 let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
                 let (stream, _) = async_tungstenite::client_async(request, stream).await?;
                 Ok(Connection::new(stream))
             } else {
-                Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?
+                Err(anyhow!("invalid server url: {}", *COLLAB_URL))?
             }
         })
     }
@@ -561,8 +563,8 @@ impl Client {
             // Open the Zed sign-in page in the user's browser, with query parameters that indicate
             // that the user is signing in from a Zed app running on the same device.
             let mut url = format!(
-                "{}/sign_in?native_app_port={}&native_app_public_key={}",
-                *ZED_SERVER_URL, port, public_key_string
+                "{}/native_app_signin?native_app_port={}&native_app_public_key={}",
+                *SITE_URL, port, public_key_string
             );
 
             if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
@@ -592,9 +594,15 @@ impl Client {
                                 user_id = Some(value.to_string());
                             }
                         }
+
+                        let post_auth_url = format!("{}/native_app_signin_succeeded", *SITE_URL);
                         req.respond(
-                            tiny_http::Response::from_string(LOGIN_RESPONSE).with_header(
-                                tiny_http::Header::from_bytes("Content-Type", "text/html").unwrap(),
+                            tiny_http::Response::empty(302).with_header(
+                                tiny_http::Header::from_bytes(
+                                    &b"Location"[..],
+                                    post_auth_url.as_bytes(),
+                                )
+                                .unwrap(),
                             ),
                         )
                         .context("failed to respond to login http request")?;
@@ -660,7 +668,7 @@ fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
 
     let (user_id, access_token) = cx
         .platform()
-        .read_credentials(&ZED_SERVER_URL)
+        .read_credentials(&COLLAB_URL)
         .log_err()
         .flatten()?;
     Some(Credentials {
@@ -671,7 +679,7 @@ fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
 
 fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
     cx.platform().write_credentials(
-        &ZED_SERVER_URL,
+        &COLLAB_URL,
         &credentials.user_id.to_string(),
         credentials.access_token.as_bytes(),
     )
@@ -694,13 +702,6 @@ pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
     Some((id, access_token.to_string()))
 }
 
-const LOGIN_RESPONSE: &'static str = "
-<!DOCTYPE html>
-<html>
-<script>window.close();</script>
-</html>
-";
-
 #[cfg(test)]
 mod tests {
     use super::*;

crates/server/src/api.rs 🔗

@@ -1,7 +1,9 @@
 use crate::{auth, AppState, Request, RequestExt as _};
 use async_trait::async_trait;
+use serde::Deserialize;
 use serde_json::json;
 use std::sync::Arc;
+use surf::StatusCode;
 
 pub fn add_routes(app: &mut tide::Server<Arc<AppState>>) {
     app.at("/users/:github_login").get(get_user);
@@ -18,7 +20,7 @@ async fn get_user(request: Request) -> tide::Result {
         .await?
         .ok_or_else(|| surf::Error::from_str(404, "user not found"))?;
 
-    Ok(tide::Response::builder(200)
+    Ok(tide::Response::builder(StatusCode::Ok)
         .body(tide::Body::from_json(&user)?)
         .build())
 }
@@ -30,11 +32,49 @@ async fn create_access_token(request: Request) -> tide::Result {
         .db()
         .get_user_by_github_login(request.param("github_login")?)
         .await?
-        .ok_or_else(|| surf::Error::from_str(404, "user not found"))?;
-    let token = auth::create_access_token(request.db(), user.id).await?;
+        .ok_or_else(|| surf::Error::from_str(StatusCode::NotFound, "user not found"))?;
+    let access_token = auth::create_access_token(request.db(), user.id).await?;
+
+    #[derive(Deserialize)]
+    struct QueryParams {
+        public_key: String,
+        impersonate: Option<String>,
+    }
+
+    let query_params: QueryParams = request.query().map_err(|_| {
+        surf::Error::from_str(StatusCode::UnprocessableEntity, "invalid query params")
+    })?;
+
+    let encrypted_access_token =
+        auth::encrypt_access_token(&access_token, query_params.public_key.clone())?;
+
+    let mut user_id = user.id;
+    if let Some(impersonate) = query_params.impersonate {
+        if user.admin {
+            if let Some(impersonated_user) =
+                request.db().get_user_by_github_login(&impersonate).await?
+            {
+                user_id = impersonated_user.id;
+            } else {
+                return Ok(tide::Response::builder(StatusCode::UnprocessableEntity)
+                    .body(format!(
+                        "Can't impersonate non-existent user {}",
+                        impersonate
+                    ))
+                    .build());
+            }
+        } else {
+            return Ok(tide::Response::builder(StatusCode::Unauthorized)
+                .body(format!(
+                    "Can't impersonate user {} because the real user isn't an admin",
+                    impersonate
+                ))
+                .build());
+        }
+    }
 
-    Ok(tide::Response::builder(200)
-        .body(json!({"user_id": user.id, "access_token": token}))
+    Ok(tide::Response::builder(StatusCode::Ok)
+        .body(json!({"user_id": user_id, "encrypted_access_token": encrypted_access_token}))
         .build())
 }
 

crates/server/src/auth.rs 🔗

@@ -238,12 +238,10 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
         }
 
         let access_token = create_access_token(request.db(), user_id).await?;
-        let native_app_public_key =
-            zed_auth::PublicKey::try_from(app_sign_in_params.native_app_public_key.clone())
-                .context("failed to parse app public key")?;
-        let encrypted_access_token = native_app_public_key
-            .encrypt_string(&access_token)
-            .context("failed to encrypt access token with public key")?;
+        let encrypted_access_token = encrypt_access_token(
+            &access_token,
+            app_sign_in_params.native_app_public_key.clone(),
+        )?;
 
         return Ok(tide::Redirect::new(&format!(
             "http://127.0.0.1:{}?user_id={}&access_token={}",
@@ -289,6 +287,15 @@ fn hash_access_token(token: &str) -> tide::Result<String> {
         .to_string())
 }
 
+pub fn encrypt_access_token(access_token: &str, public_key: String) -> tide::Result<String> {
+    let native_app_public_key =
+        zed_auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
+    let encrypted_access_token = native_app_public_key
+        .encrypt_string(&access_token)
+        .context("failed to encrypt access token with public key")?;
+    Ok(encrypted_access_token)
+}
+
 pub fn verify_access_token(token: &str, hash: &str) -> tide::Result<bool> {
     let hash = PasswordHash::new(hash)?;
     Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())