@@ -13,11 +13,13 @@ use async_tungstenite::tungstenite::{
http::{Request, StatusCode},
};
use db::Db;
-use futures::{future::LocalBoxFuture, FutureExt, SinkExt, StreamExt, TryStreamExt};
+use futures::{future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryStreamExt};
use gpui::{
- actions, serde_json::Value, AnyModelHandle, AnyViewHandle, AnyWeakModelHandle,
- AnyWeakViewHandle, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle,
- MutableAppContext, Task, View, ViewContext, ViewHandle,
+ actions,
+ serde_json::{self, Value},
+ AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AppContext,
+ AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext,
+ ViewHandle,
};
use http::HttpClient;
use lazy_static::lazy_static;
@@ -25,6 +27,7 @@ use parking_lot::RwLock;
use postage::watch;
use rand::prelude::*;
use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
+use serde::Deserialize;
use std::{
any::TypeId,
collections::HashMap,
@@ -50,6 +53,9 @@ lazy_static! {
pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
.ok()
.and_then(|s| if s.is_empty() { None } else { Some(s) });
+ pub static ref ADMIN_API_TOKEN: Option<String> = std::env::var("ZED_ADMIN_API_TOKEN")
+ .ok()
+ .and_then(|s| if s.is_empty() { None } else { Some(s) });
}
pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894";
@@ -919,6 +925,32 @@ impl Client {
self.establish_websocket_connection(credentials, cx)
}
+ async fn get_rpc_url(http: Arc<dyn HttpClient>) -> Result<Url> {
+ let rpc_response = http
+ .get(
+ &(format!("{}/rpc", *ZED_SERVER_URL)),
+ Default::default(),
+ false,
+ )
+ .await?;
+ if !rpc_response.status().is_redirection() {
+ Err(anyhow!(
+ "unexpected /rpc response status {}",
+ rpc_response.status()
+ ))?
+ }
+
+ let rpc_url = rpc_response
+ .headers()
+ .get("Location")
+ .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
+ .to_str()
+ .map_err(EstablishConnectionError::other)?
+ .to_string();
+
+ Url::parse(&rpc_url).context("invalid rpc url")
+ }
+
fn establish_websocket_connection(
self: &Arc<Self>,
credentials: &Credentials,
@@ -933,28 +965,7 @@ impl Client {
let http = self.http.clone();
cx.background().spawn(async move {
- let mut rpc_url = format!("{}/rpc", *ZED_SERVER_URL);
- let rpc_response = http.get(&rpc_url, Default::default(), false).await?;
- if rpc_response.status().is_redirection() {
- rpc_url = rpc_response
- .headers()
- .get("Location")
- .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
- .to_str()
- .map_err(EstablishConnectionError::other)?
- .to_string();
- }
- // Until we switch the zed.dev domain to point to the new Next.js app, there
- // will be no redirect required, and the app will connect directly to
- // wss://zed.dev/rpc.
- else if rpc_response.status() != StatusCode::UPGRADE_REQUIRED {
- Err(anyhow!(
- "unexpected /rpc response status {}",
- rpc_response.status()
- ))?
- }
-
- let mut rpc_url = Url::parse(&rpc_url).context("invalid rpc url")?;
+ let mut rpc_url = Self::get_rpc_url(http).await?;
let rpc_host = rpc_url
.host_str()
.zip(rpc_url.port_or_known_default())
@@ -997,6 +1008,7 @@ impl Client {
let platform = cx.platform();
let executor = cx.background();
let telemetry = self.telemetry.clone();
+ let http = self.http.clone();
executor.clone().spawn(async move {
// Generate a pair of asymmetric encryption keys. The public key will be used by the
// zed server to encrypt the user's access token, so that it can'be intercepted by
@@ -1006,6 +1018,10 @@ impl Client {
let public_key_string =
String::try_from(public_key).expect("failed to serialize public key for auth");
+ if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) {
+ return Self::authenticate_as_admin(http, login.clone(), token.clone()).await;
+ }
+
// Start an HTTP server to receive the redirect from Zed's sign-in page.
let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
let port = server.server_addr().port();
@@ -1084,6 +1100,49 @@ impl Client {
})
}
+ async fn authenticate_as_admin(
+ http: Arc<dyn HttpClient>,
+ login: String,
+ mut api_token: String,
+ ) -> Result<Credentials> {
+ let mut url = Self::get_rpc_url(http.clone()).await?;
+ url.set_path("/user");
+ url.set_query(Some(&format!("github_login={login}")));
+ let request = Request::get(url.as_str())
+ .header("Authorization", format!("token {api_token}"))
+ .body("".into())?;
+
+ let mut response = http.send(request).await?;
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ if !response.status().is_success() {
+ Err(anyhow!(
+ "admin user request failed {} - {}",
+ response.status().as_u16(),
+ body,
+ ))?;
+ }
+
+ #[derive(Deserialize)]
+ struct AuthenticatedUserResponse {
+ user: User,
+ }
+
+ #[derive(Deserialize)]
+ struct User {
+ id: u64,
+ }
+
+ let response: AuthenticatedUserResponse = serde_json::from_str(&body)?;
+
+ api_token.insert_str(0, "ADMIN_TOKEN:");
+ Ok(Credentials {
+ user_id: response.user.id,
+ access_token: api_token,
+ })
+ }
+
pub fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
let conn_id = self.connection_id()?;
self.peer.disconnect(conn_id);
@@ -88,7 +88,7 @@ pub async fn validate_api_token<B>(req: Request<B>, next: Next<B>) -> impl IntoR
#[derive(Debug, Deserialize)]
struct AuthenticatedUserParams {
- github_user_id: i32,
+ github_user_id: Option<i32>,
github_login: String,
}
@@ -104,7 +104,7 @@ async fn get_authenticated_user(
) -> Result<Json<AuthenticatedUserResponse>> {
let user = app
.db
- .get_user_by_github_account(¶ms.github_login, Some(params.github_user_id))
+ .get_user_by_github_account(¶ms.github_login, params.github_user_id)
.await?
.ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "user not found".into()))?;
let metrics_id = app.db.get_user_metrics_id(user.id).await?;