From 83e4e269896e84c43f64e523f02973a4c41f9673 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 19 Oct 2022 13:27:14 -0700 Subject: [PATCH] Allow setting ZED_SERVER_URL to URL of a collab server --- crates/client/src/client.rs | 68 ++++++++++++++++++++----------------- crates/collab/src/auth.rs | 11 +++--- 2 files changed, 43 insertions(+), 36 deletions(-) diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 9cfccba37f9963beaeea18e9db1a6e4f012cc7ff..64075472cdc7b39090045cd7af233f6eea39a91f 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -926,29 +926,34 @@ impl Client { } async fn get_rpc_url(http: Arc) -> Result { - let rpc_response = http - .get( - &(format!("{}/rpc", *ZED_SERVER_URL)), - Default::default(), - false, - ) - .await?; - if !rpc_response.status().is_redirection() { + let url = format!("{}/rpc", *ZED_SERVER_URL); + let response = http.get(&url, Default::default(), false).await?; + + // Normally, ZED_SERVER_URL is set to the URL of zed.dev website. + // The website's /rpc endpoint redirects to a collab server's /rpc endpoint, + // which requires authorization via an HTTP header. + // + // For testing purposes, ZED_SERVER_URL can also set to the direct URL of + // of a collab server. In that case, a request to the /rpc endpoint will + // return an 'unauthorized' response. + let collab_url = if response.status().is_redirection() { + response + .headers() + .get("Location") + .ok_or_else(|| anyhow!("missing location header in /rpc response"))? + .to_str() + .map_err(EstablishConnectionError::other)? + .to_string() + } else if response.status() == StatusCode::UNAUTHORIZED { + url + } else { Err(anyhow!( "unexpected /rpc response status {}", - rpc_response.status() + response.status() ))? - } - - let rpc_url = rpc_response - .headers() - .get("Location") - .ok_or_else(|| anyhow!("missing location header in /rpc response"))? - .to_str() - .map_err(EstablishConnectionError::other)? - .to_string(); + }; - Url::parse(&rpc_url).context("invalid rpc url") + Url::parse(&collab_url).context("invalid rpc url") } fn establish_websocket_connection( @@ -1105,6 +1110,18 @@ impl Client { login: String, mut api_token: String, ) -> Result { + #[derive(Deserialize)] + struct AuthenticatedUserResponse { + user: User, + } + + #[derive(Deserialize)] + struct User { + id: u64, + } + + // Use the collab server's admin API to retrieve the id + // of the impersonated user. let mut url = Self::get_rpc_url(http.clone()).await?; url.set_path("/user"); url.set_query(Some(&format!("github_login={login}"))); @@ -1115,7 +1132,6 @@ impl Client { let mut response = http.send(request).await?; let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; - if !response.status().is_success() { Err(anyhow!( "admin user request failed {} - {}", @@ -1123,19 +1139,9 @@ impl Client { body, ))?; } - - #[derive(Deserialize)] - struct AuthenticatedUserResponse { - user: User, - } - - #[derive(Deserialize)] - struct User { - id: u64, - } - let response: AuthenticatedUserResponse = serde_json::from_str(&body)?; + // Use the admin API token to authenticate as the impersonated user. api_token.insert_str(0, "ADMIN_TOKEN:"); Ok(Credentials { user_id: response.user.id, diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index e9e2855f1c2bb707d9c70eaffce810c4a0d49c81..9081fe1f1e793bab5e7825941ce198f8c0a14a67 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -1,7 +1,7 @@ -use std::sync::Arc; - -use super::db::{self, UserId}; -use crate::{AppState, Error, Result}; +use crate::{ + db::{self, UserId}, + AppState, Error, Result, +}; use anyhow::{anyhow, Context}; use axum::{ http::{self, Request, StatusCode}, @@ -13,6 +13,7 @@ use scrypt::{ password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, Scrypt, }; +use std::sync::Arc; pub async fn validate_header(mut req: Request, next: Next) -> impl IntoResponse { let mut auth_header = req @@ -21,7 +22,7 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into .and_then(|header| header.to_str().ok()) .ok_or_else(|| { Error::Http( - StatusCode::BAD_REQUEST, + StatusCode::UNAUTHORIZED, "missing authorization header".to_string(), ) })?