@@ -30,6 +30,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources, SettingsStore};
use std::fmt;
+use std::pin::Pin;
use std::{
any::TypeId,
convert::TryFrom,
@@ -65,6 +66,13 @@ impl fmt::Display for DevServerToken {
lazy_static! {
static ref ZED_SERVER_URL: Option<String> = std::env::var("ZED_SERVER_URL").ok();
static ref ZED_RPC_URL: Option<String> = std::env::var("ZED_RPC_URL").ok();
+ /// An environment variable whose presence indicates that the development auth
+ /// provider should be used.
+ ///
+ /// Only works in development. Setting this environment variable in other release
+ /// channels is a no-op.
+ pub static ref ZED_DEVELOPMENT_AUTH: bool =
+ std::env::var("ZED_DEVELOPMENT_AUTH").map_or(false, |value| !value.is_empty());
pub static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
.ok()
.and_then(|s| if s.is_empty() { None } else { Some(s) });
@@ -161,6 +169,7 @@ pub struct Client {
peer: Arc<Peer>,
http: Arc<HttpClientWithUrl>,
telemetry: Arc<Telemetry>,
+ credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static>,
state: RwLock<ClientState>,
#[allow(clippy::type_complexity)]
@@ -298,6 +307,32 @@ impl Credentials {
}
}
+/// A provider for [`Credentials`].
+///
+/// Used to abstract over reading and writing credentials to some form of
+/// persistence (like the system keychain).
+trait CredentialsProvider {
+ /// Reads the credentials from the provider.
+ fn read_credentials<'a>(
+ &'a self,
+ cx: &'a AsyncAppContext,
+ ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>>;
+
+ /// Writes the credentials to the provider.
+ fn write_credentials<'a>(
+ &'a self,
+ user_id: u64,
+ access_token: String,
+ cx: &'a AsyncAppContext,
+ ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
+
+ /// Deletes the credentials from the provider.
+ fn delete_credentials<'a>(
+ &'a self,
+ cx: &'a AsyncAppContext,
+ ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>>;
+}
+
impl Default for ClientState {
fn default() -> Self {
Self {
@@ -443,11 +478,27 @@ impl Client {
http: Arc<HttpClientWithUrl>,
cx: &mut AppContext,
) -> Arc<Self> {
+ let use_zed_development_auth = match ReleaseChannel::try_global(cx) {
+ Some(ReleaseChannel::Dev) => *ZED_DEVELOPMENT_AUTH,
+ Some(ReleaseChannel::Nightly | ReleaseChannel::Preview | ReleaseChannel::Stable)
+ | None => false,
+ };
+
+ let credentials_provider: Arc<dyn CredentialsProvider + Send + Sync + 'static> =
+ if use_zed_development_auth {
+ Arc::new(DevelopmentCredentialsProvider {
+ path: util::paths::CONFIG_DIR.join("development_auth"),
+ })
+ } else {
+ Arc::new(KeychainCredentialsProvider)
+ };
+
Arc::new(Self {
id: AtomicU64::new(0),
peer: Peer::new(0),
telemetry: Telemetry::new(clock, http.clone(), cx),
http,
+ credentials_provider,
state: Default::default(),
#[cfg(any(test, feature = "test-support"))]
@@ -763,8 +814,11 @@ impl Client {
}
}
- pub async fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
- read_credentials_from_keychain(cx).await.is_some()
+ pub async fn has_credentials(&self, cx: &AsyncAppContext) -> bool {
+ self.credentials_provider
+ .read_credentials(cx)
+ .await
+ .is_some()
}
pub fn set_dev_server_token(&self, token: DevServerToken) -> &Self {
@@ -775,7 +829,7 @@ impl Client {
#[async_recursion(?Send)]
pub async fn authenticate_and_connect(
self: &Arc<Self>,
- try_keychain: bool,
+ try_provider: bool,
cx: &AsyncAppContext,
) -> anyhow::Result<()> {
let was_disconnected = match *self.status().borrow() {
@@ -796,12 +850,13 @@ impl Client {
self.set_status(Status::Reauthenticating, cx)
}
- let mut read_from_keychain = false;
+ let mut read_from_provider = false;
let mut credentials = self.state.read().credentials.clone();
- if credentials.is_none() && try_keychain {
- credentials = read_credentials_from_keychain(cx).await;
- read_from_keychain = credentials.is_some();
+ if credentials.is_none() && try_provider {
+ credentials = self.credentials_provider.read_credentials(cx).await;
+ read_from_provider = credentials.is_some();
}
+
if credentials.is_none() {
let mut status_rx = self.status();
let _ = status_rx.next().await;
@@ -838,9 +893,9 @@ impl Client {
match connection {
Ok(conn) => {
self.state.write().credentials = Some(credentials.clone());
- if !read_from_keychain && IMPERSONATE_LOGIN.is_none() {
+ if !read_from_provider && IMPERSONATE_LOGIN.is_none() {
if let Credentials::User{user_id, access_token} = credentials {
- write_credentials_to_keychain(user_id, access_token, cx).await.log_err();
+ self.credentials_provider.write_credentials(user_id, access_token, cx).await.log_err();
}
}
@@ -854,8 +909,8 @@ impl Client {
}
Err(EstablishConnectionError::Unauthorized) => {
self.state.write().credentials.take();
- if read_from_keychain {
- delete_credentials_from_keychain(cx).await.log_err();
+ if read_from_provider {
+ self.credentials_provider.delete_credentials(cx).await.log_err();
self.set_status(Status::SignedOut, cx);
self.authenticate_and_connect(false, cx).await
} else {
@@ -1264,8 +1319,11 @@ impl Client {
self.state.write().credentials = None;
self.disconnect(&cx);
- if self.has_keychain_credentials(cx).await {
- delete_credentials_from_keychain(cx).await.log_err();
+ if self.has_credentials(cx).await {
+ self.credentials_provider
+ .delete_credentials(cx)
+ .await
+ .log_err();
}
}
@@ -1465,41 +1523,128 @@ impl Client {
}
}
-async fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
- if IMPERSONATE_LOGIN.is_some() {
- return None;
+#[derive(Serialize, Deserialize)]
+struct DevelopmentCredentials {
+ user_id: u64,
+ access_token: String,
+}
+
+/// A credentials provider that stores credentials in a local file.
+///
+/// This MUST only be used in development, as this is not a secure way of storing
+/// credentials on user machines.
+///
+/// Its existence is purely to work around the annoyance of having to constantly
+/// re-allow access to the system keychain when developing Zed.
+struct DevelopmentCredentialsProvider {
+ path: PathBuf,
+}
+
+impl CredentialsProvider for DevelopmentCredentialsProvider {
+ fn read_credentials<'a>(
+ &'a self,
+ _cx: &'a AsyncAppContext,
+ ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
+ async move {
+ if IMPERSONATE_LOGIN.is_some() {
+ return None;
+ }
+
+ let json = std::fs::read(&self.path).log_err()?;
+
+ let credentials: DevelopmentCredentials = serde_json::from_slice(&json).log_err()?;
+
+ Some(Credentials::User {
+ user_id: credentials.user_id,
+ access_token: credentials.access_token,
+ })
+ }
+ .boxed_local()
}
- let (user_id, access_token) = cx
- .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
- .log_err()?
- .await
- .log_err()??;
+ fn write_credentials<'a>(
+ &'a self,
+ user_id: u64,
+ access_token: String,
+ _cx: &'a AsyncAppContext,
+ ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+ async move {
+ let json = serde_json::to_string(&DevelopmentCredentials {
+ user_id,
+ access_token,
+ })?;
- Some(Credentials::User {
- user_id: user_id.parse().ok()?,
- access_token: String::from_utf8(access_token).ok()?,
- })
-}
+ std::fs::write(&self.path, json)?;
-async fn write_credentials_to_keychain(
- user_id: u64,
- access_token: String,
- cx: &AsyncAppContext,
-) -> Result<()> {
- cx.update(move |cx| {
- cx.write_credentials(
- &ClientSettings::get_global(cx).server_url,
- &user_id.to_string(),
- access_token.as_bytes(),
- )
- })?
- .await
+ Ok(())
+ }
+ .boxed_local()
+ }
+
+ fn delete_credentials<'a>(
+ &'a self,
+ _cx: &'a AsyncAppContext,
+ ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+ async move { Ok(std::fs::remove_file(&self.path)?) }.boxed_local()
+ }
}
-async fn delete_credentials_from_keychain(cx: &AsyncAppContext) -> Result<()> {
- cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
- .await
+/// A credentials provider that stores credentials in the system keychain.
+struct KeychainCredentialsProvider;
+
+impl CredentialsProvider for KeychainCredentialsProvider {
+ fn read_credentials<'a>(
+ &'a self,
+ cx: &'a AsyncAppContext,
+ ) -> Pin<Box<dyn Future<Output = Option<Credentials>> + 'a>> {
+ async move {
+ if IMPERSONATE_LOGIN.is_some() {
+ return None;
+ }
+
+ let (user_id, access_token) = cx
+ .update(|cx| cx.read_credentials(&ClientSettings::get_global(cx).server_url))
+ .log_err()?
+ .await
+ .log_err()??;
+
+ Some(Credentials::User {
+ user_id: user_id.parse().ok()?,
+ access_token: String::from_utf8(access_token).ok()?,
+ })
+ }
+ .boxed_local()
+ }
+
+ fn write_credentials<'a>(
+ &'a self,
+ user_id: u64,
+ access_token: String,
+ cx: &'a AsyncAppContext,
+ ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+ async move {
+ cx.update(move |cx| {
+ cx.write_credentials(
+ &ClientSettings::get_global(cx).server_url,
+ &user_id.to_string(),
+ access_token.as_bytes(),
+ )
+ })?
+ .await
+ }
+ .boxed_local()
+ }
+
+ fn delete_credentials<'a>(
+ &'a self,
+ cx: &'a AsyncAppContext,
+ ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+ async move {
+ cx.update(move |cx| cx.delete_credentials(&ClientSettings::get_global(cx).server_url))?
+ .await
+ }
+ .boxed_local()
+ }
}
/// prefix for the zed:// url scheme