Merge pull request #1777 from zed-industries/impersonate-via-secret-token

Max Brunsfeld created

Impersonate via secret token

Change summary

crates/client/src/client.rs             | 117 +++++++++++++++++++++------
crates/collab/.env.toml                 |   2 
crates/collab/k8s/manifest.template.yml |  32 -------
crates/collab/src/api.rs                |   6 
crates/collab/src/auth.rs               |  25 +++--
crates/collab/src/integration_tests.rs  |   3 
crates/collab/src/main.rs               |  14 +--
crates/collab/src/rpc.rs                |   9 +
script/zed-with-local-servers           |   2 
9 files changed, 123 insertions(+), 87 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -13,11 +13,13 @@ use async_tungstenite::tungstenite::{
     http::{Request, StatusCode},
 };
 use db::Db;
-use futures::{future::LocalBoxFuture, FutureExt, SinkExt, StreamExt, TryStreamExt};
+use futures::{future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryStreamExt};
 use gpui::{
-    actions, serde_json::Value, AnyModelHandle, AnyViewHandle, AnyWeakModelHandle,
-    AnyWeakViewHandle, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle,
-    MutableAppContext, Task, View, ViewContext, ViewHandle,
+    actions,
+    serde_json::{self, Value},
+    AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AppContext,
+    AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext,
+    ViewHandle,
 };
 use http::HttpClient;
 use lazy_static::lazy_static;
@@ -25,6 +27,7 @@ use parking_lot::RwLock;
 use postage::watch;
 use rand::prelude::*;
 use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
+use serde::Deserialize;
 use std::{
     any::TypeId,
     collections::HashMap,
@@ -50,6 +53,9 @@ lazy_static! {
     pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
         .ok()
         .and_then(|s| if s.is_empty() { None } else { Some(s) });
+    pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
+        .ok()
+        .and_then(|s| if s.is_empty() { None } else { Some(s) });
 }
 
 pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
@@ -919,6 +925,37 @@ impl Client {
         self.establish_websocket_connection(credentials, cx)
     }
 
+    async fn get_rpc_url(http: Arc<dyn HttpClient>) -> Result<Url> {
+        let url = format!("{}/rpc", *ZED_SERVER_URL);
+        let response = http.get(&url, Default::default(), false).await?;
+
+        // Normally, ZED_SERVER_URL is set to the URL of zed.dev website.
+        // The website's /rpc endpoint redirects to a collab server's /rpc endpoint,
+        // which requires authorization via an HTTP header.
+        //
+        // For testing purposes, ZED_SERVER_URL can also set to the direct URL of
+        // of a collab server. In that case, a request to the /rpc endpoint will
+        // return an 'unauthorized' response.
+        let collab_url = if response.status().is_redirection() {
+            response
+                .headers()
+                .get("Location")
+                .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
+                .to_str()
+                .map_err(EstablishConnectionError::other)?
+                .to_string()
+        } else if response.status() == StatusCode::UNAUTHORIZED {
+            url
+        } else {
+            Err(anyhow!(
+                "unexpected /rpc response status {}",
+                response.status()
+            ))?
+        };
+
+        Url::parse(&collab_url).context("invalid rpc url")
+    }
+
     fn establish_websocket_connection(
         self: &Arc<Self>,
         credentials: &Credentials,
@@ -933,28 +970,7 @@ impl Client {
 
         let http = self.http.clone();
         cx.background().spawn(async move {
-            let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
-            let rpc_response = http.get(&rpc_url, Default::default(), false).await?;
-            if rpc_response.status().is_redirection() {
-                rpc_url = rpc_response
-                    .headers()
-                    .get("Location")
-                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
-                    .to_str()
-                    .map_err(EstablishConnectionError::other)?
-                    .to_string();
-            }
-            // Until we switch the zed.dev domain to point to the new Next.js app, there
-            // will be no redirect required, and the app will connect directly to
-            // wss://zed.dev/rpc.
-            else if rpc_response.status() != StatusCode::UPGRADE_REQUIRED {
-                Err(anyhow!(
-                    "unexpected /rpc response status {}",
-                    rpc_response.status()
-                ))?
-            }
-
-            let mut rpc_url = Url::parse(&rpc_url).context("invalid rpc url")?;
+            let mut rpc_url = Self::get_rpc_url(http).await?;
             let rpc_host = rpc_url
                 .host_str()
                 .zip(rpc_url.port_or_known_default())
@@ -997,6 +1013,7 @@ impl Client {
         let platform = cx.platform();
         let executor = cx.background();
         let telemetry = self.telemetry.clone();
+        let http = self.http.clone();
         executor.clone().spawn(async move {
             // Generate a pair of asymmetric encryption keys. The public key will be used by the
             // zed server to encrypt the user's access token, so that it can'be intercepted by
@@ -1006,6 +1023,10 @@ impl Client {
             let public_key_string =
                 String::try_from(public_key).expect("failed to serialize public key for auth");
 
+            if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
+                return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
+            }
+
             // Start an HTTP server to receive the redirect from Zed's sign-in page.
             let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
             let port = server.server_addr().port();
@@ -1084,6 +1105,50 @@ impl Client {
         })
     }
 
+    async fn authenticate_as_admin(
+        http: Arc<dyn HttpClient>,
+        login: String,
+        mut api_token: String,
+    ) -> Result<Credentials> {
+        #[derive(Deserialize)]
+        struct AuthenticatedUserResponse {
+            user: User,
+        }
+
+        #[derive(Deserialize)]
+        struct User {
+            id: u64,
+        }
+
+        // Use the collab server's admin API to retrieve the id
+        // of the impersonated user.
+        let mut url = Self::get_rpc_url(http.clone()).await?;
+        url.set_path("/user");
+        url.set_query(Some(&format!("github_login={login}")));
+        let request = Request::get(url.as_str())
+            .header("Authorization", format!("token {api_token}"))
+            .body("".into())?;
+
+        let mut response = http.send(request).await?;
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+        if !response.status().is_success() {
+            Err(anyhow!(
+                "admin user request failed {} - {}",
+                response.status().as_u16(),
+                body,
+            ))?;
+        }
+        let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
+
+        // Use the admin API token to authenticate as the impersonated user.
+        api_token.insert_str(0, "ADMIN_TOKEN:");
+        Ok(Credentials {
+            user_id: response.user.id,
+            access_token: api_token,
+        })
+    }
+
     pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
         let conn_id = self.connection_id()?;
         self.peer.disconnect(conn_id);

crates/collab/.env.toml 🔗

@@ -3,7 +3,5 @@ HTTP_PORT = 8080
 API_TOKEN = "secret"
 INVITE_LINK_PREFIX = "http://localhost:3000/invites/"
 
-# HONEYCOMB_API_KEY=
-# HONEYCOMB_DATASET=
 # RUST_LOG=info
 # LOG_JSON=true

crates/collab/k8s/manifest.template.yml 🔗

@@ -65,31 +65,6 @@ spec:
                 secretKeyRef:
                   name: database
                   key: url
-            - name: SESSION_SECRET
-              valueFrom:
-                secretKeyRef:
-                  name: session
-                  key: secret
-            - name: GITHUB_APP_ID
-              valueFrom:
-                secretKeyRef:
-                  name: github
-                  key: appId
-            - name: GITHUB_CLIENT_ID
-              valueFrom:
-                secretKeyRef:
-                  name: github
-                  key: clientId
-            - name: GITHUB_CLIENT_SECRET
-              valueFrom:
-                secretKeyRef:
-                  name: github
-                  key: clientSecret
-            - name: GITHUB_PRIVATE_KEY
-              valueFrom:
-                secretKeyRef:
-                  name: github
-                  key: privateKey
             - name: API_TOKEN
               valueFrom:
                 secretKeyRef:
@@ -101,13 +76,6 @@ spec:
               value: ${RUST_LOG}
             - name: LOG_JSON
               value: "true"
-            - name: HONEYCOMB_DATASET
-              value: "collab"
-            - name: HONEYCOMB_API_KEY
-              valueFrom:
-                secretKeyRef:
-                  name: honeycomb
-                  key: apiKey
           securityContext:
             capabilities:
               # FIXME - Switch to the more restrictive `PERFMON` capability.

crates/collab/src/api.rs 🔗

@@ -76,7 +76,7 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
 
     let state = req.extensions().get::<Arc<AppState>>().unwrap();
 
-    if token != state.api_token {
+    if token != state.config.api_token {
         Err(Error::Http(
             StatusCode::UNAUTHORIZED,
             "invalid authorization token".to_string(),
@@ -88,7 +88,7 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
 
 #[derive(Debug, Deserialize)]
 struct AuthenticatedUserParams {
-    github_user_id: i32,
+    github_user_id: Option<i32>,
     github_login: String,
 }
 
@@ -104,7 +104,7 @@ async fn get_authenticated_user(
 ) -> Result<Json<AuthenticatedUserResponse>> {
     let user = app
         .db
-        .get_user_by_github_account(&params.github_login, Some(params.github_user_id))
+        .get_user_by_github_account(&params.github_login, params.github_user_id)
         .await?
         .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "user not found".into()))?;
     let metrics_id = app.db.get_user_metrics_id(user.id).await?;

crates/collab/src/auth.rs 🔗

@@ -1,7 +1,7 @@
-use std::sync::Arc;
-
-use super::db::{self, UserId};
-use crate::{AppState, Error, Result};
+use crate::{
+    db::{self, UserId},
+    AppState, Error, Result,
+};
 use anyhow::{anyhow, Context};
 use axum::{
     http::{self, Request, StatusCode},
@@ -13,6 +13,7 @@ use scrypt::{
     password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
     Scrypt,
 };
+use std::sync::Arc;
 
 pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
     let mut auth_header = req
@@ -21,7 +22,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
         .and_then(|header| header.to_str().ok())
         .ok_or_else(|| {
             Error::Http(
-                StatusCode::BAD_REQUEST,
+                StatusCode::UNAUTHORIZED,
                 "missing authorization header".to_string(),
             )
         })?
@@ -41,12 +42,18 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
         )
     })?;
 
-    let state = req.extensions().get::<Arc<AppState>>().unwrap();
     let mut credentials_valid = false;
-    for password_hash in state.db.get_access_token_hashes(user_id).await? {
-        if verify_access_token(access_token, &password_hash)? {
+    let state = req.extensions().get::<Arc<AppState>>().unwrap();
+    if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
+        if state.config.api_token == admin_token {
             credentials_valid = true;
-            break;
+        }
+    } else {
+        for password_hash in state.db.get_access_token_hashes(user_id).await? {
+            if verify_access_token(access_token, &password_hash)? {
+                credentials_valid = true;
+                break;
+            }
         }
     }
 

crates/collab/src/integration_tests.rs 🔗

@@ -6357,8 +6357,7 @@ impl TestServer {
     async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
         Arc::new(AppState {
             db: test_db.db().clone(),
-            api_token: Default::default(),
-            invite_link_prefix: Default::default(),
+            config: Default::default(),
         })
     }
 

crates/collab/src/main.rs 🔗

@@ -28,25 +28,21 @@ pub struct Config {
     pub database_url: String,
     pub api_token: String,
     pub invite_link_prefix: String,
-    pub honeycomb_api_key: Option<String>,
-    pub honeycomb_dataset: Option<String>,
     pub rust_log: Option<String>,
     pub log_json: Option<bool>,
 }
 
 pub struct AppState {
     db: Arc<dyn Db>,
-    api_token: String,
-    invite_link_prefix: String,
+    config: Config,
 }
 
 impl AppState {
-    async fn new(config: &Config) -> Result<Arc<Self>> {
+    async fn new(config: Config) -> Result<Arc<Self>> {
         let db = PostgresDb::new(&config.database_url, 5).await?;
         let this = Self {
             db: Arc::new(db),
-            api_token: config.api_token.clone(),
-            invite_link_prefix: config.invite_link_prefix.clone(),
+            config,
         };
         Ok(Arc::new(this))
     }
@@ -63,9 +59,9 @@ async fn main() -> Result<()> {
 
     let config = envy::from_env::<Config>().expect("error loading config");
     init_tracing(&config);
-    let state = AppState::new(&config).await?;
+    let state = AppState::new(config).await?;
 
-    let listener = TcpListener::bind(&format!("0.0.0.0:{}", config.http_port))
+    let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
         .expect("failed to bind TCP listener");
     let rpc_server = rpc::Server::new(state.clone(), None);
 

crates/collab/src/rpc.rs 🔗

@@ -397,7 +397,7 @@ impl Server {
 
                 if let Some((code, count)) = invite_code {
                     this.peer.send(connection_id, proto::UpdateInviteInfo {
-                        url: format!("{}{}", this.app_state.invite_link_prefix, code),
+                        url: format!("{}{}", this.app_state.config.invite_link_prefix, code),
                         count,
                     })?;
                 }
@@ -561,7 +561,7 @@ impl Server {
                     self.peer.send(
                         connection_id,
                         proto::UpdateInviteInfo {
-                            url: format!("{}{}", self.app_state.invite_link_prefix, &code),
+                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
                             count: user.invite_count as u32,
                         },
                     )?;
@@ -579,7 +579,10 @@ impl Server {
                     self.peer.send(
                         connection_id,
                         proto::UpdateInviteInfo {
-                            url: format!("{}{}", self.app_state.invite_link_prefix, invite_code),
+                            url: format!(
+                                "{}{}",
+                                self.app_state.config.invite_link_prefix, invite_code
+                            ),
                             count: user.invite_count as u32,
                         },
                     )?;

script/zed-with-local-servers 🔗

@@ -1 +1 @@
-ZED_SERVER_URL=http://localhost:3000 cargo run $@
+ZED_ADMIN_API_TOKEN=secret ZED_SERVER_URL=http://localhost:3000 cargo run $@