Store the impersonator id on access tokens created via ZED_IMPERSONATE (#4108)

Max Brunsfeld created

* Use the impersonator id to prevent these tokens from counting against
the impersonated user when limiting the users' total of access tokens.
* When connecting using an access token with an impersonator add the
impersonator as a field to the tracing span that wraps the task for that
connection.
* Disallow impersonating users via the admin API token in production,
because when using the admin API token, we aren't able to identify the
impersonator.

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql                |   1 
crates/collab/migrations/20240117150300_add_impersonator_to_access_tokens.sql |   1 
crates/collab/src/api.rs                                                      |   9 
crates/collab/src/auth.rs                                                     |  99 
crates/collab/src/db/queries/access_tokens.rs                                 |   2 
crates/collab/src/db/tables/access_token.rs                                   |   1 
crates/collab/src/db/tests/db_tests.rs                                        | 104 
crates/collab/src/rpc.rs                                                      |  20 
crates/collab/src/tests/test_server.rs                                        |   1 
docs/src/developing_zed__building_zed.md                                      |   2 
script/lib/squawk.toml                                                        |   4 
script/squawk                                                                 |   5 
12 files changed, 202 insertions(+), 47 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -19,6 +19,7 @@ CREATE INDEX "index_users_on_github_user_id" ON "users" ("github_user_id");
 CREATE TABLE "access_tokens" (
     "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "user_id" INTEGER REFERENCES users (id),
+    "impersonated_user_id" INTEGER REFERENCES users (id),
     "hash" VARCHAR(128)
 );
 CREATE INDEX "index_access_tokens_user_id" ON "access_tokens" ("user_id");

crates/collab/src/api.rs 🔗

@@ -156,11 +156,11 @@ async fn create_access_token(
         .await?
         .ok_or_else(|| anyhow!("user not found"))?;
 
-    let mut user_id = user.id;
+    let mut impersonated_user_id = None;
     if let Some(impersonate) = params.impersonate {
         if user.admin {
             if let Some(impersonated_user) = app.db.get_user_by_github_login(&impersonate).await? {
-                user_id = impersonated_user.id;
+                impersonated_user_id = Some(impersonated_user.id);
             } else {
                 return Err(Error::Http(
                     StatusCode::UNPROCESSABLE_ENTITY,
@@ -175,12 +175,13 @@ async fn create_access_token(
         }
     }
 
-    let access_token = auth::create_access_token(app.db.as_ref(), user_id).await?;
+    let access_token =
+        auth::create_access_token(app.db.as_ref(), user_id, impersonated_user_id).await?;
     let encrypted_access_token =
         auth::encrypt_access_token(&access_token, params.public_key.clone())?;
 
     Ok(Json(CreateAccessTokenResponse {
-        user_id,
+        user_id: impersonated_user_id.unwrap_or(user_id),
         encrypted_access_token,
     }))
 }

crates/collab/src/auth.rs 🔗

@@ -27,6 +27,9 @@ lazy_static! {
     .unwrap();
 }
 
+#[derive(Clone, Debug, Default, PartialEq, Eq)]
+pub struct Impersonator(pub Option<db::User>);
+
 /// Validates the authorization header. This has two mechanisms, one for the ADMIN_TOKEN
 /// and one for the access tokens that we issue.
 pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
@@ -57,28 +60,50 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
     })?;
 
     let state = req.extensions().get::<Arc<AppState>>().unwrap();
-    let credentials_valid = if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
-        state.config.api_token == admin_token
+
+    // In development, allow impersonation using the admin API token.
+    // Don't allow this in production because we can't tell who is doing
+    // the impersonating.
+    let validate_result = if let (Some(admin_token), true) = (
+        access_token.strip_prefix("ADMIN_TOKEN:"),
+        state.config.is_development(),
+    ) {
+        Ok(VerifyAccessTokenResult {
+            is_valid: state.config.api_token == admin_token,
+            impersonator_id: None,
+        })
     } else {
-        verify_access_token(&access_token, user_id, &state.db)
-            .await
-            .unwrap_or(false)
+        verify_access_token(&access_token, user_id, &state.db).await
     };
 
-    if credentials_valid {
-        let user = state
-            .db
-            .get_user_by_id(user_id)
-            .await?
-            .ok_or_else(|| anyhow!("user {} not found", user_id))?;
-        req.extensions_mut().insert(user);
-        Ok::<_, Error>(next.run(req).await)
-    } else {
-        Err(Error::Http(
-            StatusCode::UNAUTHORIZED,
-            "invalid credentials".to_string(),
-        ))
+    if let Ok(validate_result) = validate_result {
+        if validate_result.is_valid {
+            let user = state
+                .db
+                .get_user_by_id(user_id)
+                .await?
+                .ok_or_else(|| anyhow!("user {} not found", user_id))?;
+
+            let impersonator = if let Some(impersonator_id) = validate_result.impersonator_id {
+                let impersonator = state
+                    .db
+                    .get_user_by_id(impersonator_id)
+                    .await?
+                    .ok_or_else(|| anyhow!("user {} not found", impersonator_id))?;
+                Some(impersonator)
+            } else {
+                None
+            };
+            req.extensions_mut().insert(user);
+            req.extensions_mut().insert(Impersonator(impersonator));
+            return Ok::<_, Error>(next.run(req).await);
+        }
     }
+
+    Err(Error::Http(
+        StatusCode::UNAUTHORIZED,
+        "invalid credentials".to_string(),
+    ))
 }
 
 const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
@@ -92,13 +117,22 @@ struct AccessTokenJson {
 
 /// Creates a new access token to identify the given user. before returning it, you should
 /// encrypt it with the user's public key.
-pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> {
+pub async fn create_access_token(
+    db: &db::Database,
+    user_id: UserId,
+    impersonated_user_id: Option<UserId>,
+) -> Result<String> {
     const VERSION: usize = 1;
     let access_token = rpc::auth::random_token();
     let access_token_hash =
         hash_access_token(&access_token).context("failed to hash access token")?;
     let id = db
-        .create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
+        .create_access_token(
+            user_id,
+            impersonated_user_id,
+            &access_token_hash,
+            MAX_ACCESS_TOKENS_TO_STORE,
+        )
         .await?;
     Ok(serde_json::to_string(&AccessTokenJson {
         version: VERSION,
@@ -137,12 +171,22 @@ pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<St
     Ok(encrypted_access_token)
 }
 
-/// verify access token returns true if the given token is valid for the given user.
-pub async fn verify_access_token(token: &str, user_id: UserId, db: &Arc<Database>) -> Result<bool> {
+pub struct VerifyAccessTokenResult {
+    pub is_valid: bool,
+    pub impersonator_id: Option<UserId>,
+}
+
+/// Checks that the given access token is valid for the given user.
+pub async fn verify_access_token(
+    token: &str,
+    user_id: UserId,
+    db: &Arc<Database>,
+) -> Result<VerifyAccessTokenResult> {
     let token: AccessTokenJson = serde_json::from_str(&token)?;
 
     let db_token = db.get_access_token(token.id).await?;
-    if db_token.user_id != user_id {
+    let token_user_id = db_token.impersonated_user_id.unwrap_or(db_token.user_id);
+    if token_user_id != user_id {
         return Err(anyhow!("no such access token"))?;
     }
 
@@ -154,5 +198,12 @@ pub async fn verify_access_token(token: &str, user_id: UserId, db: &Arc<Database
     let duration = t0.elapsed();
     log::info!("hashed access token in {:?}", duration);
     METRIC_ACCESS_TOKEN_HASHING_TIME.observe(duration.as_millis() as f64);
-    Ok(is_valid)
+    Ok(VerifyAccessTokenResult {
+        is_valid,
+        impersonator_id: if db_token.impersonated_user_id.is_some() {
+            Some(db_token.user_id)
+        } else {
+            None
+        },
+    })
 }

crates/collab/src/db/queries/access_tokens.rs 🔗

@@ -6,6 +6,7 @@ impl Database {
     pub async fn create_access_token(
         &self,
         user_id: UserId,
+        impersonated_user_id: Option<UserId>,
         access_token_hash: &str,
         max_access_token_count: usize,
     ) -> Result<AccessTokenId> {
@@ -14,6 +15,7 @@ impl Database {
 
             let token = access_token::ActiveModel {
                 user_id: ActiveValue::set(user_id),
+                impersonated_user_id: ActiveValue::set(impersonated_user_id),
                 hash: ActiveValue::set(access_token_hash.into()),
                 ..Default::default()
             }

crates/collab/src/db/tests/db_tests.rs 🔗

@@ -146,7 +146,7 @@ test_both_dbs!(
 );
 
 async fn test_create_access_tokens(db: &Arc<Database>) {
-    let user = db
+    let user_1 = db
         .create_user(
             "u1@example.com",
             false,
@@ -158,14 +158,27 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
         .await
         .unwrap()
         .user_id;
+    let user_2 = db
+        .create_user(
+            "u2@example.com",
+            false,
+            NewUserParams {
+                github_login: "u2".into(),
+                github_user_id: 2,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
 
-    let token_1 = db.create_access_token(user, "h1", 2).await.unwrap();
-    let token_2 = db.create_access_token(user, "h2", 2).await.unwrap();
+    let token_1 = db.create_access_token(user_1, None, "h1", 2).await.unwrap();
+    let token_2 = db.create_access_token(user_1, None, "h2", 2).await.unwrap();
     assert_eq!(
         db.get_access_token(token_1).await.unwrap(),
         access_token::Model {
             id: token_1,
-            user_id: user,
+            user_id: user_1,
+            impersonated_user_id: None,
             hash: "h1".into(),
         }
     );
@@ -173,17 +186,19 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
         db.get_access_token(token_2).await.unwrap(),
         access_token::Model {
             id: token_2,
-            user_id: user,
+            user_id: user_1,
+            impersonated_user_id: None,
             hash: "h2".into()
         }
     );
 
-    let token_3 = db.create_access_token(user, "h3", 2).await.unwrap();
+    let token_3 = db.create_access_token(user_1, None, "h3", 2).await.unwrap();
     assert_eq!(
         db.get_access_token(token_3).await.unwrap(),
         access_token::Model {
             id: token_3,
-            user_id: user,
+            user_id: user_1,
+            impersonated_user_id: None,
             hash: "h3".into()
         }
     );
@@ -191,18 +206,20 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
         db.get_access_token(token_2).await.unwrap(),
         access_token::Model {
             id: token_2,
-            user_id: user,
+            user_id: user_1,
+            impersonated_user_id: None,
             hash: "h2".into()
         }
     );
     assert!(db.get_access_token(token_1).await.is_err());
 
-    let token_4 = db.create_access_token(user, "h4", 2).await.unwrap();
+    let token_4 = db.create_access_token(user_1, None, "h4", 2).await.unwrap();
     assert_eq!(
         db.get_access_token(token_4).await.unwrap(),
         access_token::Model {
             id: token_4,
-            user_id: user,
+            user_id: user_1,
+            impersonated_user_id: None,
             hash: "h4".into()
         }
     );
@@ -210,12 +227,77 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
         db.get_access_token(token_3).await.unwrap(),
         access_token::Model {
             id: token_3,
-            user_id: user,
+            user_id: user_1,
+            impersonated_user_id: None,
             hash: "h3".into()
         }
     );
     assert!(db.get_access_token(token_2).await.is_err());
     assert!(db.get_access_token(token_1).await.is_err());
+
+    // An access token for user 2 impersonating user 1 does not
+    // count against user 1's access token limit (of 2).
+    let token_5 = db
+        .create_access_token(user_2, Some(user_1), "h5", 2)
+        .await
+        .unwrap();
+    assert_eq!(
+        db.get_access_token(token_5).await.unwrap(),
+        access_token::Model {
+            id: token_5,
+            user_id: user_2,
+            impersonated_user_id: Some(user_1),
+            hash: "h5".into()
+        }
+    );
+    assert_eq!(
+        db.get_access_token(token_3).await.unwrap(),
+        access_token::Model {
+            id: token_3,
+            user_id: user_1,
+            impersonated_user_id: None,
+            hash: "h3".into()
+        }
+    );
+
+    // Only a limited number (2) of access tokens are stored for user 2
+    // impersonating other users.
+    let token_6 = db
+        .create_access_token(user_2, Some(user_1), "h6", 2)
+        .await
+        .unwrap();
+    let token_7 = db
+        .create_access_token(user_2, Some(user_1), "h7", 2)
+        .await
+        .unwrap();
+    assert_eq!(
+        db.get_access_token(token_6).await.unwrap(),
+        access_token::Model {
+            id: token_6,
+            user_id: user_2,
+            impersonated_user_id: Some(user_1),
+            hash: "h6".into()
+        }
+    );
+    assert_eq!(
+        db.get_access_token(token_7).await.unwrap(),
+        access_token::Model {
+            id: token_7,
+            user_id: user_2,
+            impersonated_user_id: Some(user_1),
+            hash: "h7".into()
+        }
+    );
+    assert!(db.get_access_token(token_5).await.is_err());
+    assert_eq!(
+        db.get_access_token(token_3).await.unwrap(),
+        access_token::Model {
+            id: token_3,
+            user_id: user_1,
+            impersonated_user_id: None,
+            hash: "h3".into()
+        }
+    );
 }
 
 test_both_dbs!(

crates/collab/src/rpc.rs 🔗

@@ -1,7 +1,7 @@
 mod connection_pool;
 
 use crate::{
-    auth,
+    auth::{self, Impersonator},
     db::{
         self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
         CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
@@ -65,7 +65,7 @@ use std::{
 use time::OffsetDateTime;
 use tokio::sync::{watch, Semaphore};
 use tower::ServiceBuilder;
-use tracing::{info_span, instrument, Instrument};
+use tracing::{field, info_span, instrument, Instrument};
 
 pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
 pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
@@ -561,13 +561,17 @@ impl Server {
         connection: Connection,
         address: String,
         user: User,
+        impersonator: Option<User>,
         mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
         executor: Executor,
     ) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         let user_id = user.id;
         let login = user.github_login;
-        let span = info_span!("handle connection", %user_id, %login, %address);
+        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
+        if let Some(impersonator) = impersonator {
+            span.record("impersonator", &impersonator.github_login);
+        }
         let mut teardown = self.teardown.subscribe();
         async move {
             let (connection_id, handle_io, mut incoming_rx) = this
@@ -839,6 +843,7 @@ pub async fn handle_websocket_request(
     ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
     Extension(server): Extension<Arc<Server>>,
     Extension(user): Extension<User>,
+    Extension(impersonator): Extension<Impersonator>,
     ws: WebSocketUpgrade,
 ) -> axum::response::Response {
     if protocol_version != rpc::PROTOCOL_VERSION {
@@ -858,7 +863,14 @@ pub async fn handle_websocket_request(
         let connection = Connection::new(Box::pin(socket));
         async move {
             server
-                .handle_connection(connection, socket_address, user, None, Executor::Production)
+                .handle_connection(
+                    connection,
+                    socket_address,
+                    user,
+                    impersonator.0,
+                    None,
+                    Executor::Production,
+                )
                 .await
                 .log_err();
         }

docs/src/developing_zed__building_zed.md 🔗

@@ -14,7 +14,7 @@
 - Ensure that the Xcode command line tools are using your newly installed copy of Xcode:
 
     ```
-    sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer.
+    sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer
     ```
 
 * Install the Rust wasm toolchain:

script/squawk 🔗

@@ -8,13 +8,12 @@ set -e
 
 if [ -z "$GITHUB_BASE_REF" ]; then
   echo 'Not a pull request, skipping squawk modified migrations linting'
-  return 0
+  exit
 fi
 
 SQUAWK_VERSION=0.26.0
 SQUAWK_BIN="./target/squawk-$SQUAWK_VERSION"
-SQUAWK_ARGS="--assume-in-transaction"
-
+SQUAWK_ARGS="--assume-in-transaction --config script/lib/squawk.toml"
 
 if  [ ! -f "$SQUAWK_BIN" ]; then
   curl -L -o "$SQUAWK_BIN" "https://github.com/sbdchd/squawk/releases/download/v$SQUAWK_VERSION/squawk-darwin-x86_64"