Detailed changes
@@ -9212,6 +9212,7 @@ dependencies = [
"vercel",
"workspace-hack",
"x_ai",
+ "zed_env_vars",
]
[[package]]
@@ -20677,6 +20678,7 @@ dependencies = [
name = "zed_env_vars"
version = "0.1.0"
dependencies = [
+ "gpui",
"workspace-hack",
]
@@ -44,8 +44,12 @@ impl AgentServer for Gemini {
cx.spawn(async move |cx| {
let mut extra_env = HashMap::default();
- if let Some(api_key) = cx.update(GoogleLanguageModelProvider::api_key)?.await.ok() {
- extra_env.insert("GEMINI_API_KEY".into(), api_key.key);
+ if let Some(api_key) = cx
+ .update(GoogleLanguageModelProvider::api_key_for_gemini_cli)?
+ .await
+ .ok()
+ {
+ extra_env.insert("GEMINI_API_KEY".into(), api_key);
}
let (mut command, root_dir, login) = store
.update(cx, |store, cx| {
@@ -61,6 +61,7 @@ util.workspace = true
vercel = { workspace = true, features = ["schemars"] }
workspace-hack.workspace = true
x_ai = { workspace = true, features = ["schemars"] }
+zed_env_vars.workspace = true
[dev-dependencies]
editor = { workspace = true, features = ["test-support"] }
@@ -0,0 +1,295 @@
+use anyhow::{Result, anyhow};
+use credentials_provider::CredentialsProvider;
+use futures::{FutureExt, future};
+use gpui::{AsyncApp, Context, SharedString, Task};
+use language_model::AuthenticateError;
+use std::{
+ fmt::{Display, Formatter},
+ sync::Arc,
+};
+use util::ResultExt as _;
+use zed_env_vars::EnvVar;
+
+/// Manages a single API key for a language model provider. API keys either come from environment
+/// variables or the system keychain.
+///
+/// Keys from the system keychain are associated with a provider URL, and this ensures that they are
+/// only used with that URL.
+pub struct ApiKeyState {
+ url: SharedString,
+ load_status: LoadStatus,
+ load_task: Option<future::Shared<Task<()>>>,
+}
+
+#[derive(Debug, Clone)]
+pub enum LoadStatus {
+ NotPresent,
+ Error(String),
+ Loaded(ApiKey),
+}
+
+#[derive(Debug, Clone)]
+pub struct ApiKey {
+ source: ApiKeySource,
+ key: Arc<str>,
+}
+
+impl ApiKeyState {
+ pub fn new(url: SharedString) -> Self {
+ Self {
+ url,
+ load_status: LoadStatus::NotPresent,
+ load_task: None,
+ }
+ }
+
+ pub fn has_key(&self) -> bool {
+ matches!(self.load_status, LoadStatus::Loaded { .. })
+ }
+
+ pub fn is_from_env_var(&self) -> bool {
+ match &self.load_status {
+ LoadStatus::Loaded(ApiKey {
+ source: ApiKeySource::EnvVar { .. },
+ ..
+ }) => true,
+ _ => false,
+ }
+ }
+
+ /// Get the stored API key, verifying that it is associated with the URL. Returns `None` if
+ /// there is no key or for URL mismatches, and the mismatch case is logged.
+ ///
+ /// To avoid URL mismatches, expects that `load_if_needed` or `handle_url_change` has been
+ /// called with this URL.
+ pub fn key(&self, url: &str) -> Option<Arc<str>> {
+ let api_key = match &self.load_status {
+ LoadStatus::Loaded(api_key) => api_key,
+ _ => return None,
+ };
+ if url == self.url.as_str() {
+ Some(api_key.key.clone())
+ } else if let ApiKeySource::EnvVar(var_name) = &api_key.source {
+ log::warn!(
+ "{} is now being used with URL {}, when initially it was used with URL {}",
+ var_name,
+ url,
+ self.url
+ );
+ Some(api_key.key.clone())
+ } else {
+ // bug case because load_if_needed should be called whenever the url may have changed
+ log::error!(
+ "bug: Attempted to use API key associated with URL {} instead with URL {}",
+ self.url,
+ url
+ );
+ None
+ }
+ }
+
+ /// Set or delete the API key in the system keychain.
+ pub fn store<Ent: 'static>(
+ &mut self,
+ url: SharedString,
+ key: Option<String>,
+ get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
+ cx: &Context<Ent>,
+ ) -> Task<Result<()>> {
+ if self.is_from_env_var() {
+ return Task::ready(Err(anyhow!(
+ "bug: attempted to store API key in system keychain when API key is from env var",
+ )));
+ }
+ let credentials_provider = <dyn CredentialsProvider>::global(cx);
+ cx.spawn(async move |ent, cx| {
+ if let Some(key) = &key {
+ credentials_provider
+ .write_credentials(&url, "Bearer", key.as_bytes(), cx)
+ .await
+ .log_err();
+ } else {
+ credentials_provider
+ .delete_credentials(&url, cx)
+ .await
+ .log_err();
+ }
+ ent.update(cx, |ent, cx| {
+ let this = get_this(ent);
+ this.url = url;
+ this.load_status = match &key {
+ Some(key) => LoadStatus::Loaded(ApiKey {
+ source: ApiKeySource::SystemKeychain,
+ key: key.as_str().into(),
+ }),
+ None => LoadStatus::NotPresent,
+ };
+ cx.notify();
+ })
+ })
+ }
+
+ /// Reloads the API key if the current API key is associated with a different URL.
+ ///
+ /// Note that it is not efficient to use this or `load_if_needed` with multiple URLs
+ /// interchangeably - URL change should correspond to some user initiated change.
+ pub fn handle_url_change<Ent: 'static>(
+ &mut self,
+ url: SharedString,
+ env_var: &EnvVar,
+ get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
+ cx: &mut Context<Ent>,
+ ) {
+ if url != self.url {
+ if !self.is_from_env_var() {
+ // loading will continue even though this result task is dropped
+ let _task = self.load_if_needed(url, env_var, get_this, cx);
+ }
+ }
+ }
+
+ /// If needed, loads the API key associated with the given URL from the system keychain. When a
+ /// non-empty environment variable is provided, it will be used instead. If called when an API
+ /// key was already loaded for a different URL, that key will be cleared before loading.
+ ///
+ /// Dropping the returned Task does not cancel key loading.
+ pub fn load_if_needed<Ent: 'static>(
+ &mut self,
+ url: SharedString,
+ env_var: &EnvVar,
+ get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static,
+ cx: &mut Context<Ent>,
+ ) -> Task<Result<(), AuthenticateError>> {
+ if let LoadStatus::Loaded { .. } = &self.load_status
+ && self.url == url
+ {
+ return Task::ready(Ok(()));
+ }
+
+ if let Some(key) = &env_var.value
+ && !key.is_empty()
+ {
+ let api_key = ApiKey::from_env(env_var.name.clone(), key);
+ self.url = url;
+ self.load_status = LoadStatus::Loaded(api_key);
+ self.load_task = None;
+ cx.notify();
+ return Task::ready(Ok(()));
+ }
+
+ let task = if let Some(load_task) = &self.load_task {
+ load_task.clone()
+ } else {
+ let load_task = Self::load(url.clone(), get_this.clone(), cx).shared();
+ self.url = url;
+ self.load_status = LoadStatus::NotPresent;
+ self.load_task = Some(load_task.clone());
+ cx.notify();
+ load_task
+ };
+
+ cx.spawn(async move |ent, cx| {
+ task.await;
+ ent.update(cx, |ent, _cx| {
+ get_this(ent).load_status.clone().into_authenticate_result()
+ })
+ .ok();
+ Ok(())
+ })
+ }
+
+ fn load<Ent: 'static>(
+ url: SharedString,
+ get_this: impl Fn(&mut Ent) -> &mut Self + 'static,
+ cx: &Context<Ent>,
+ ) -> Task<()> {
+ let credentials_provider = <dyn CredentialsProvider>::global(cx);
+ cx.spawn({
+ async move |ent, cx| {
+ let load_status =
+ ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx)
+ .await;
+ ent.update(cx, |ent, cx| {
+ let this = get_this(ent);
+ this.url = url;
+ this.load_status = load_status;
+ this.load_task = None;
+ cx.notify();
+ })
+ .ok();
+ }
+ })
+ }
+}
+
+impl ApiKey {
+ pub fn key(&self) -> &str {
+ &self.key
+ }
+
+ pub fn from_env(env_var_name: SharedString, key: &str) -> Self {
+ Self {
+ source: ApiKeySource::EnvVar(env_var_name),
+ key: key.into(),
+ }
+ }
+
+ pub async fn load_from_system_keychain(
+ url: &str,
+ credentials_provider: &dyn CredentialsProvider,
+ cx: &AsyncApp,
+ ) -> Result<Self, AuthenticateError> {
+ Self::load_from_system_keychain_impl(url, credentials_provider, cx)
+ .await
+ .into_authenticate_result()
+ }
+
+ async fn load_from_system_keychain_impl(
+ url: &str,
+ credentials_provider: &dyn CredentialsProvider,
+ cx: &AsyncApp,
+ ) -> LoadStatus {
+ if url.is_empty() {
+ return LoadStatus::NotPresent;
+ }
+ let read_result = credentials_provider.read_credentials(&url, cx).await;
+ let api_key = match read_result {
+ Ok(Some((_, api_key))) => api_key,
+ Ok(None) => return LoadStatus::NotPresent,
+ Err(err) => return LoadStatus::Error(err.to_string()),
+ };
+ let key = match str::from_utf8(&api_key) {
+ Ok(key) => key,
+ Err(_) => return LoadStatus::Error(format!("API key for URL {url} is not utf8")),
+ };
+ LoadStatus::Loaded(Self {
+ source: ApiKeySource::SystemKeychain,
+ key: key.into(),
+ })
+ }
+}
+
+impl LoadStatus {
+ fn into_authenticate_result(self) -> Result<ApiKey, AuthenticateError> {
+ match self {
+ LoadStatus::Loaded(api_key) => Ok(api_key),
+ LoadStatus::NotPresent => Err(AuthenticateError::CredentialsNotFound),
+ LoadStatus::Error(err) => Err(AuthenticateError::Other(anyhow!(err))),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+enum ApiKeySource {
+ EnvVar(SharedString),
+ SystemKeychain,
+}
+
+impl Display for ApiKeySource {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ match self {
+ ApiKeySource::EnvVar(var) => write!(f, "environment variable {}", var),
+ ApiKeySource::SystemKeychain => write!(f, "system keychain"),
+ }
+ }
+}
@@ -7,6 +7,7 @@ use gpui::{App, Context, Entity};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use provider::deepseek::DeepSeekLanguageModelProvider;
+mod api_key;
pub mod provider;
mod settings;
pub mod ui;
@@ -1,18 +1,15 @@
-use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;
+use crate::{AllLanguageModelSettings, api_key::ApiKeyState};
use anthropic::{
AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, ToolResultContent,
ToolResultPart, Usage,
};
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::Result;
use collections::{BTreeMap, HashMap};
-use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
-use gpui::{
- AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
-};
+use gpui::{AnyView, App, AsyncApp, Context, Entity, FontStyle, Task, TextStyle, WhiteSpace};
use http_client::HttpClient;
use language_model::{
AuthenticateError, ConfigurationViewTargetAgent, LanguageModel,
@@ -27,11 +24,12 @@ use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr;
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;
+use zed_env_vars::{EnvVar, env_var};
const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME;
@@ -97,91 +95,52 @@ pub struct AnthropicLanguageModelProvider {
state: gpui::Entity<State>,
}
-const ANTHROPIC_API_KEY_VAR: &str = "ANTHROPIC_API_KEY";
+const API_KEY_ENV_VAR_NAME: &str = "ANTHROPIC_API_KEY";
+static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
pub struct State {
- api_key: Option<String>,
- api_key_from_env: bool,
- _subscription: Subscription,
+ api_key_state: ApiKeyState,
}
impl State {
- fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .anthropic
- .api_url
- .clone();
- cx.spawn(async move |this, cx| {
- credentials_provider
- .delete_credentials(&api_url, cx)
- .await
- .ok();
- this.update(cx, |this, cx| {
- this.api_key = None;
- this.api_key_from_env = false;
- cx.notify();
- })
- })
- }
-
- fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .anthropic
- .api_url
- .clone();
- cx.spawn(async move |this, cx| {
- credentials_provider
- .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
- .await
- .ok();
-
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- cx.notify();
- })
- })
- }
-
fn is_authenticated(&self) -> bool {
- self.api_key.is_some()
+ self.api_key_state.has_key()
}
- fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
- if self.is_authenticated() {
- return Task::ready(Ok(()));
- }
-
- let key = AnthropicLanguageModelProvider::api_key(cx);
-
- cx.spawn(async move |this, cx| {
- let key = key.await?;
-
- this.update(cx, |this, cx| {
- this.api_key = Some(key.key);
- this.api_key_from_env = key.from_env;
- cx.notify();
- })?;
-
- Ok(())
- })
+ fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let api_url = AnthropicLanguageModelProvider::api_url(cx);
+ self.api_key_state
+ .store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
-}
-pub struct ApiKey {
- pub key: String,
- pub from_env: bool,
+ fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let api_url = AnthropicLanguageModelProvider::api_url(cx);
+ self.api_key_state.load_if_needed(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ )
+ }
}
impl AnthropicLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
- let state = cx.new(|cx| State {
- api_key: None,
- api_key_from_env: false,
- _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
+ let state = cx.new(|cx| {
+ cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let api_url = Self::api_url(cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ );
cx.notify();
- }),
+ })
+ .detach();
+ State {
+ api_key_state: ApiKeyState::new(Self::api_url(cx)),
+ }
});
Self { http_client, state }
@@ -197,30 +156,16 @@ impl AnthropicLanguageModelProvider {
})
}
- pub fn api_key(cx: &mut App) -> Task<Result<ApiKey, AuthenticateError>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .anthropic
- .api_url
- .clone();
-
- if let Ok(key) = std::env::var(ANTHROPIC_API_KEY_VAR) {
- Task::ready(Ok(ApiKey {
- key,
- from_env: true,
- }))
+ fn settings(cx: &App) -> &AnthropicSettings {
+ &AllLanguageModelSettings::get_global(cx).anthropic
+ }
+
+ fn api_url(cx: &App) -> SharedString {
+ let api_url = &Self::settings(cx).api_url;
+ if api_url.is_empty() {
+ anthropic::ANTHROPIC_API_URL.into()
} else {
- cx.spawn(async move |cx| {
- let (_, api_key) = credentials_provider
- .read_credentials(&api_url, cx)
- .await?
- .ok_or(AuthenticateError::CredentialsNotFound)?;
-
- Ok(ApiKey {
- key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
- from_env: false,
- })
- })
+ SharedString::new(api_url.as_str())
}
}
}
@@ -327,7 +272,8 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
- self.state.update(cx, |state, cx| state.reset_api_key(cx))
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -417,11 +363,16 @@ impl AnthropicModel {
> {
let http_client = self.http_client.clone();
- let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
- let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
- (state.api_key.clone(), settings.api_url.clone())
- }) else {
- return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
+ let api_key_and_url = self.state.read_with(cx, |state, cx| {
+ let api_url = AnthropicLanguageModelProvider::api_url(cx);
+ let api_key = state.api_key_state.key(&api_url);
+ (api_key, api_url)
+ });
+ let (api_key, api_url) = match api_key_and_url {
+ Ok(api_key_and_url) => api_key_and_url,
+ Err(err) => {
+ return futures::future::ready(Err(err.into())).boxed();
+ }
};
let beta_headers = self.model.beta_headers();
@@ -483,7 +434,10 @@ impl LanguageModel for AnthropicModel {
}
fn api_key(&self, cx: &App) -> Option<String> {
- self.state.read(cx).api_key.clone()
+ self.state.read_with(cx, |state, cx| {
+ let api_url = AnthropicLanguageModelProvider::api_url(cx);
+ state.api_key_state.key(&api_url).map(|key| key.to_string())
+ })
}
fn max_token_count(&self) -> u64 {
@@ -987,12 +941,10 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
- .update(cx, |state, cx| state.set_api_key(api_key, cx))?
+ .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -1001,11 +953,11 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
- state.update(cx, |state, cx| state.reset_api_key(cx))?.await
+ state
+ .update(cx, |state, cx| state.set_api_key(None, cx))?
+ .await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
@@ -1040,7 +992,7 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let env_var_set = self.state.read(cx).api_key_from_env;
+ let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials...")).into_any()
@@ -1079,7 +1031,7 @@ impl Render for ConfigurationView {
)
.child(
Label::new(
- format!("You can also assign the {ANTHROPIC_API_KEY_VAR} environment variable and restart Zed."),
+ format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
)
.size(LabelSize::Small)
.color(Color::Muted),
@@ -1099,7 +1051,7 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
- format!("API key set in {ANTHROPIC_API_KEY_VAR} environment variable.")
+ format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.")
} else {
"API key configured.".to_string()
})),
@@ -1112,7 +1064,7 @@ impl Render for ConfigurationView {
.icon_position(IconPosition::Start)
.disabled(env_var_set)
.when(env_var_set, |this| {
- this.tooltip(Tooltip::text(format!("To reset your API key, unset the {ANTHROPIC_API_KEY_VAR} environment variable.")))
+ this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)
@@ -1,12 +1,11 @@
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
-use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{
- AnyView, AppContext as _, AsyncApp, Entity, FontStyle, Subscription, Task, TextStyle,
- WhiteSpace,
+ AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace,
+ Window,
};
use http_client::HttpClient;
use language_model::{
@@ -21,16 +20,19 @@ use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr;
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
use theme::ThemeSettings;
use ui::{Icon, IconName, List, prelude::*};
use util::ResultExt;
+use zed_env_vars::{EnvVar, env_var};
-use crate::{AllLanguageModelSettings, ui::InstructionListItem};
+use crate::{AllLanguageModelSettings, api_key::ApiKeyState, ui::InstructionListItem};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");
-const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
+
+const API_KEY_ENV_VAR_NAME: &str = "DEEPSEEK_API_KEY";
+static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
#[derive(Default)]
struct RawToolCall {
@@ -59,95 +61,48 @@ pub struct DeepSeekLanguageModelProvider {
}
pub struct State {
- api_key: Option<String>,
- api_key_from_env: bool,
- _subscription: Subscription,
+ api_key_state: ApiKeyState,
}
impl State {
fn is_authenticated(&self) -> bool {
- self.api_key.is_some()
- }
-
- fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .deepseek
- .api_url
- .clone();
- cx.spawn(async move |this, cx| {
- credentials_provider
- .delete_credentials(&api_url, cx)
- .await
- .log_err();
- this.update(cx, |this, cx| {
- this.api_key = None;
- this.api_key_from_env = false;
- cx.notify();
- })
- })
+ self.api_key_state.has_key()
}
- fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .deepseek
- .api_url
- .clone();
- cx.spawn(async move |this, cx| {
- credentials_provider
- .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
- .await?;
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- cx.notify();
- })
- })
+ fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let api_url = SharedString::new(DeepSeekLanguageModelProvider::api_url(cx));
+ self.api_key_state
+ .store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
- fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
- if self.is_authenticated() {
- return Task::ready(Ok(()));
- }
-
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .deepseek
- .api_url
- .clone();
- cx.spawn(async move |this, cx| {
- let (api_key, from_env) = if let Ok(api_key) = std::env::var(DEEPSEEK_API_KEY_VAR) {
- (api_key, true)
- } else {
- let (_, api_key) = credentials_provider
- .read_credentials(&api_url, cx)
- .await?
- .ok_or(AuthenticateError::CredentialsNotFound)?;
- (
- String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
- false,
- )
- };
-
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- this.api_key_from_env = from_env;
- cx.notify();
- })?;
-
- Ok(())
- })
+ fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let api_url = SharedString::new(DeepSeekLanguageModelProvider::api_url(cx));
+ self.api_key_state.load_if_needed(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ )
}
}
impl DeepSeekLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
- let state = cx.new(|cx| State {
- api_key: None,
- api_key_from_env: false,
- _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
+ let state = cx.new(|cx| {
+ cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let api_url = SharedString::new(Self::api_url(cx));
+ this.api_key_state.handle_url_change(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ );
cx.notify();
- }),
+ })
+ .detach();
+ State {
+ api_key_state: ApiKeyState::new(SharedString::new(Self::api_url(cx))),
+ }
});
Self { http_client, state }
@@ -160,7 +115,15 @@ impl DeepSeekLanguageModelProvider {
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
- }) as Arc<dyn LanguageModel>
+ })
+ }
+
+ fn settings(cx: &App) -> &DeepSeekSettings {
+ &AllLanguageModelSettings::get_global(cx).deepseek
+ }
+
+ fn api_url(cx: &App) -> &str {
+ &Self::settings(cx).api_url
}
}
@@ -199,11 +162,7 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
models.insert("deepseek-chat", deepseek::Model::Chat);
models.insert("deepseek-reasoner", deepseek::Model::Reasoner);
- for available_model in AllLanguageModelSettings::get_global(cx)
- .deepseek
- .available_models
- .iter()
- {
+ for available_model in &Self::settings(cx).available_models {
models.insert(
&available_model.name,
deepseek::Model::Custom {
@@ -240,7 +199,8 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
- self.state.update(cx, |state, cx| state.reset_api_key(cx))
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -259,15 +219,25 @@ impl DeepSeekLanguageModel {
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<deepseek::StreamResponse>>>> {
let http_client = self.http_client.clone();
- let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
- let settings = &AllLanguageModelSettings::get_global(cx).deepseek;
- (state.api_key.clone(), settings.api_url.clone())
- }) else {
- return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+
+ let api_key_and_url = self.state.read_with(cx, |state, cx| {
+ let api_url = DeepSeekLanguageModelProvider::api_url(cx);
+ let api_key = state.api_key_state.key(api_url);
+ (api_key, api_url.to_string())
+ });
+ let (api_key, api_url) = match api_key_and_url {
+ Ok(api_key_and_url) => api_key_and_url,
+ Err(err) => {
+ return futures::future::ready(Err(err)).boxed();
+ }
};
let future = self.request_limiter.stream(async move {
- let api_key = api_key.context("Missing DeepSeek API Key")?;
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
let request =
deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
@@ -610,7 +580,7 @@ impl ConfigurationView {
}
fn save_api_key(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
- let api_key = self.api_key_editor.read(cx).text(cx);
+ let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
if api_key.is_empty() {
return;
}
@@ -618,12 +588,10 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn(async move |_, cx| {
state
- .update(cx, |state, cx| state.set_api_key(api_key, cx))?
+ .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -631,10 +599,12 @@ impl ConfigurationView {
.update(cx, |editor, cx| editor.set_text("", window, cx));
let state = self.state.clone();
- cx.spawn(async move |_, cx| state.update(cx, |state, cx| state.reset_api_key(cx))?.await)
- .detach_and_log_err(cx);
-
- cx.notify();
+ cx.spawn(async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(None, cx))?
+ .await
+ })
+ .detach_and_log_err(cx);
}
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
@@ -672,7 +642,7 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let env_var_set = self.state.read(cx).api_key_from_env;
+ let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials...")).into_any()
@@ -706,8 +676,7 @@ impl Render for ConfigurationView {
)
.child(
Label::new(format!(
- "Or set the {} environment variable.",
- DEEPSEEK_API_KEY_VAR
+ "Or set the {API_KEY_ENV_VAR_NAME} environment variable."
))
.size(LabelSize::Small)
.color(Color::Muted),
@@ -727,7 +696,7 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
- format!("API key set in {}", DEEPSEEK_API_KEY_VAR)
+ format!("API key set in {API_KEY_ENV_VAR_NAME}")
} else {
"API key configured".to_string()
})),
@@ -1,4 +1,4 @@
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::{Context as _, Result};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
@@ -8,7 +8,8 @@ use google_ai::{
ThinkingConfig, UsageMetadata,
};
use gpui::{
- AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
+ AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace,
+ Window,
};
use http_client::HttpClient;
use language_model::{
@@ -26,18 +27,18 @@ use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::sync::{
- Arc,
+ Arc, LazyLock,
atomic::{self, AtomicU64},
};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;
+use zed_env_vars::EnvVar;
-use crate::AllLanguageModelSettings;
+use crate::api_key::ApiKey;
use crate::ui::InstructionListItem;
-
-use super::anthropic::ApiKey;
+use crate::{AllLanguageModelSettings, api_key::ApiKeyState};
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
@@ -91,101 +92,56 @@ pub struct GoogleLanguageModelProvider {
}
pub struct State {
- api_key: Option<String>,
- api_key_from_env: bool,
- _subscription: Subscription,
+ api_key_state: ApiKeyState,
}
-const GEMINI_API_KEY_VAR: &str = "GEMINI_API_KEY";
-const GOOGLE_AI_API_KEY_VAR: &str = "GOOGLE_AI_API_KEY";
+const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
+const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";
+
+static API_KEY_ENV_VAR: LazyLock<EnvVar> = LazyLock::new(|| {
+ // Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY
+ EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into()))
+});
impl State {
fn is_authenticated(&self) -> bool {
- self.api_key.is_some()
+ self.api_key_state.has_key()
}
- fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .google
- .api_url
- .clone();
- cx.spawn(async move |this, cx| {
- credentials_provider
- .delete_credentials(&api_url, cx)
- .await
- .log_err();
- this.update(cx, |this, cx| {
- this.api_key = None;
- this.api_key_from_env = false;
- cx.notify();
- })
- })
+ fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let api_url = GoogleLanguageModelProvider::api_url(cx);
+ self.api_key_state
+ .store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
- fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .google
- .api_url
- .clone();
- cx.spawn(async move |this, cx| {
- credentials_provider
- .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
- .await?;
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- cx.notify();
- })
- })
- }
-
- fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
- if self.is_authenticated() {
- return Task::ready(Ok(()));
- }
-
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .google
- .api_url
- .clone();
-
- cx.spawn(async move |this, cx| {
- let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) {
- (api_key, true)
- } else if let Ok(api_key) = std::env::var(GEMINI_API_KEY_VAR) {
- (api_key, true)
- } else {
- let (_, api_key) = credentials_provider
- .read_credentials(&api_url, cx)
- .await?
- .ok_or(AuthenticateError::CredentialsNotFound)?;
- (
- String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
- false,
- )
- };
-
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- this.api_key_from_env = from_env;
- cx.notify();
- })?;
-
- Ok(())
- })
+ fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let api_url = GoogleLanguageModelProvider::api_url(cx);
+ self.api_key_state.load_if_needed(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ )
}
}
impl GoogleLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
- let state = cx.new(|cx| State {
- api_key: None,
- api_key_from_env: false,
- _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
+ let state = cx.new(|cx| {
+ cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let api_url = Self::api_url(cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ );
cx.notify();
- }),
+ })
+ .detach();
+ State {
+ api_key_state: ApiKeyState::new(Self::api_url(cx)),
+ }
});
Self { http_client, state }
@@ -201,30 +157,28 @@ impl GoogleLanguageModelProvider {
})
}
- pub fn api_key(cx: &mut App) -> Task<Result<ApiKey>> {
+ pub fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
+ if let Some(key) = API_KEY_ENV_VAR.value.clone() {
+ return Task::ready(Ok(key));
+ }
let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .google
- .api_url
- .clone();
-
- if let Ok(key) = std::env::var(GEMINI_API_KEY_VAR) {
- Task::ready(Ok(ApiKey {
- key,
- from_env: true,
- }))
- } else {
- cx.spawn(async move |cx| {
- let (_, api_key) = credentials_provider
- .read_credentials(&api_url, cx)
+ let api_url = Self::api_url(cx).to_string();
+ cx.spawn(async move |cx| {
+ Ok(
+ ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
.await?
- .ok_or(AuthenticateError::CredentialsNotFound)?;
+ .key()
+ .to_string(),
+ )
+ })
+ }
- Ok(ApiKey {
- key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
- from_env: false,
- })
- })
+ fn api_url(cx: &App) -> SharedString {
+ let api_url = &AllLanguageModelSettings::get_global(cx).google.api_url;
+ if api_url.is_empty() {
+ google_ai::API_URL.into()
+ } else {
+ SharedString::new(api_url.as_str())
}
}
}
@@ -317,7 +271,8 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
- self.state.update(cx, |state, cx| state.reset_api_key(cx))
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -340,11 +295,16 @@ impl GoogleLanguageModel {
> {
let http_client = self.http_client.clone();
- let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
- let settings = &AllLanguageModelSettings::get_global(cx).google;
- (state.api_key.clone(), settings.api_url.clone())
- }) else {
- return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+ let api_key_and_url = self.state.read_with(cx, |state, cx| {
+ let api_url = GoogleLanguageModelProvider::api_url(cx);
+ let api_key = state.api_key_state.key(&api_url);
+ (api_key, api_url)
+ });
+ let (api_key, api_url) = match api_key_and_url {
+ Ok(api_key_and_url) => api_key_and_url,
+ Err(err) => {
+ return futures::future::ready(Err(err)).boxed();
+ }
};
async move {
@@ -418,13 +378,16 @@ impl LanguageModel for GoogleLanguageModel {
let model_id = self.model.request_id().to_string();
let request = into_google(request, model_id, self.model.mode());
let http_client = self.http_client.clone();
- let api_key = self.state.read(cx).api_key.clone();
-
- let settings = &AllLanguageModelSettings::get_global(cx).google;
- let api_url = settings.api_url.clone();
+ let api_url = GoogleLanguageModelProvider::api_url(cx);
+ let api_key = self.state.read(cx).api_key_state.key(&api_url);
async move {
- let api_key = api_key.context("Missing Google API key")?;
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ }
+ .into());
+ };
let response = google_ai::count_tokens(
http_client.as_ref(),
&api_url,
@@ -852,7 +815,7 @@ impl ConfigurationView {
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
- let api_key = self.api_key_editor.read(cx).text(cx);
+ let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
if api_key.is_empty() {
return;
}
@@ -860,12 +823,10 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
- .update(cx, |state, cx| state.set_api_key(api_key, cx))?
+ .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -874,11 +835,11 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
- state.update(cx, |state, cx| state.reset_api_key(cx))?.await
+ state
+ .update(cx, |state, cx| state.set_api_key(None, cx))?
+ .await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
@@ -913,7 +874,7 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let env_var_set = self.state.read(cx).api_key_from_env;
+ let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials...")).into_any()
@@ -950,7 +911,7 @@ impl Render for ConfigurationView {
)
.child(
Label::new(
- format!("You can also assign the {GEMINI_API_KEY_VAR} environment variable and restart Zed."),
+ format!("You can also assign the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."),
)
.size(LabelSize::Small).color(Color::Muted),
)
@@ -969,7 +930,7 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
- format!("API key set in {GEMINI_API_KEY_VAR} environment variable.")
+ format!("API key set in {GEMINI_API_KEY_VAR_NAME} environment variable.")
} else {
"API key configured.".to_string()
})),
@@ -982,7 +943,7 @@ impl Render for ConfigurationView {
.icon_position(IconPosition::Start)
.disabled(env_var_set)
.when(env_var_set, |this| {
- this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR} and {GOOGLE_AI_API_KEY_VAR} environment variables are unset.")))
+ this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)
@@ -1,10 +1,10 @@
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::{Result, anyhow};
use collections::BTreeMap;
-use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{
- AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
+ AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace,
+ Window,
};
use http_client::HttpClient;
use language_model::{
@@ -21,17 +21,21 @@ use settings::{Settings, SettingsStore};
use std::collections::HashMap;
use std::pin::Pin;
use std::str::FromStr;
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;
+use zed_env_vars::{EnvVar, env_var};
-use crate::{AllLanguageModelSettings, ui::InstructionListItem};
+use crate::{api_key::ApiKeyState, ui::InstructionListItem};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
+const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY";
+static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
+
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
pub api_url: String,
@@ -56,96 +60,48 @@ pub struct MistralLanguageModelProvider {
}
pub struct State {
- api_key: Option<String>,
- api_key_from_env: bool,
- _subscription: Subscription,
+ api_key_state: ApiKeyState,
}
-const MISTRAL_API_KEY_VAR: &str = "MISTRAL_API_KEY";
-
impl State {
fn is_authenticated(&self) -> bool {
- self.api_key.is_some()
- }
-
- fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .mistral
- .api_url
- .clone();
- cx.spawn(async move |this, cx| {
- credentials_provider
- .delete_credentials(&api_url, cx)
- .await
- .log_err();
- this.update(cx, |this, cx| {
- this.api_key = None;
- this.api_key_from_env = false;
- cx.notify();
- })
- })
+ self.api_key_state.has_key()
}
- fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .mistral
- .api_url
- .clone();
- cx.spawn(async move |this, cx| {
- credentials_provider
- .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
- .await?;
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- cx.notify();
- })
- })
+ fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let api_url = MistralLanguageModelProvider::api_url(cx);
+ self.api_key_state
+ .store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
- fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
- if self.is_authenticated() {
- return Task::ready(Ok(()));
- }
-
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .mistral
- .api_url
- .clone();
- cx.spawn(async move |this, cx| {
- let (api_key, from_env) = if let Ok(api_key) = std::env::var(MISTRAL_API_KEY_VAR) {
- (api_key, true)
- } else {
- let (_, api_key) = credentials_provider
- .read_credentials(&api_url, cx)
- .await?
- .ok_or(AuthenticateError::CredentialsNotFound)?;
- (
- String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
- false,
- )
- };
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- this.api_key_from_env = from_env;
- cx.notify();
- })?;
-
- Ok(())
- })
+ fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let api_url = MistralLanguageModelProvider::api_url(cx);
+ self.api_key_state.load_if_needed(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ )
}
}
impl MistralLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
- let state = cx.new(|cx| State {
- api_key: None,
- api_key_from_env: false,
- _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
+ let state = cx.new(|cx| {
+ cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let api_url = Self::api_url(cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ );
cx.notify();
- }),
+ })
+ .detach();
+ State {
+ api_key_state: ApiKeyState::new(Self::api_url(cx)),
+ }
});
Self { http_client, state }
@@ -160,6 +116,19 @@ impl MistralLanguageModelProvider {
request_limiter: RateLimiter::new(4),
})
}
+
+ fn settings(cx: &App) -> &MistralSettings {
+ &crate::AllLanguageModelSettings::get_global(cx).mistral
+ }
+
+ fn api_url(cx: &App) -> SharedString {
+ let api_url = &Self::settings(cx).api_url;
+ if api_url.is_empty() {
+ mistral::MISTRAL_API_URL.into()
+ } else {
+ SharedString::new(api_url.as_str())
+ }
+ }
}
impl LanguageModelProviderState for MistralLanguageModelProvider {
@@ -202,10 +171,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
}
// Override with available models from settings
- for model in &AllLanguageModelSettings::get_global(cx)
- .mistral
- .available_models
- {
+ for model in &Self::settings(cx).available_models {
models.insert(
model.name.clone(),
mistral::Model::Custom {
@@ -254,7 +220,8 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
- self.state.update(cx, |state, cx| state.reset_api_key(cx))
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -276,15 +243,25 @@ impl MistralLanguageModel {
Result<futures::stream::BoxStream<'static, Result<mistral::StreamResponse>>>,
> {
let http_client = self.http_client.clone();
- let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
- let settings = &AllLanguageModelSettings::get_global(cx).mistral;
- (state.api_key.clone(), settings.api_url.clone())
- }) else {
- return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+
+ let api_key_and_url = self.state.read_with(cx, |state, cx| {
+ let api_url = MistralLanguageModelProvider::api_url(cx);
+ let api_key = state.api_key_state.key(&api_url);
+ (api_key, api_url)
+ });
+ let (api_key, api_url) = match api_key_and_url {
+ Ok(api_key_and_url) => api_key_and_url,
+ Err(err) => {
+ return futures::future::ready(Err(err)).boxed();
+ }
};
let future = self.request_limiter.stream(async move {
- let api_key = api_key.context("Missing Mistral API Key")?;
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
let request =
mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
@@ -780,7 +757,7 @@ impl ConfigurationView {
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
- let api_key = self.api_key_editor.read(cx).text(cx);
+ let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
if api_key.is_empty() {
return;
}
@@ -788,12 +765,10 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
- .update(cx, |state, cx| state.set_api_key(api_key, cx))?
+ .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -802,11 +777,11 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
- state.update(cx, |state, cx| state.reset_api_key(cx))?.await
+ state
+ .update(cx, |state, cx| state.set_api_key(None, cx))?
+ .await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
@@ -841,7 +816,7 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let env_var_set = self.state.read(cx).api_key_from_env;
+ let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials...")).into_any()
@@ -878,7 +853,7 @@ impl Render for ConfigurationView {
)
.child(
Label::new(
- format!("You can also assign the {MISTRAL_API_KEY_VAR} environment variable and restart Zed."),
+ format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
)
.size(LabelSize::Small).color(Color::Muted),
)
@@ -897,7 +872,7 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
- format!("API key set in {MISTRAL_API_KEY_VAR} environment variable.")
+ format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.")
} else {
"API key configured.".to_string()
})),
@@ -910,7 +885,7 @@ impl Render for ConfigurationView {
.icon_position(IconPosition::Start)
.disabled(env_var_set)
.when(env_var_set, |this| {
- this.tooltip(Tooltip::text(format!("To reset your API key, unset the {MISTRAL_API_KEY_VAR} environment variable.")))
+ this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)
@@ -1,10 +1,8 @@
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::{Result, anyhow};
use collections::{BTreeMap, HashMap};
-use credentials_provider::CredentialsProvider;
-
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
-use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
+use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
@@ -20,18 +18,21 @@ use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr as _;
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
-
use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;
+use zed_env_vars::{EnvVar, env_var};
-use crate::{AllLanguageModelSettings, ui::InstructionListItem};
+use crate::{AllLanguageModelSettings, api_key::ApiKeyState, ui::InstructionListItem};
const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
+const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY";
+static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
+
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
pub api_url: String,
@@ -54,132 +55,48 @@ pub struct OpenAiLanguageModelProvider {
}
pub struct State {
- api_key: Option<String>,
- api_key_from_env: bool,
- last_api_url: String,
- _subscription: Subscription,
+ api_key_state: ApiKeyState,
}
-const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY";
-
impl State {
fn is_authenticated(&self) -> bool {
- self.api_key.is_some()
- }
-
- fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .openai
- .api_url
- .clone();
- cx.spawn(async move |this, cx| {
- credentials_provider
- .delete_credentials(&api_url, cx)
- .await
- .log_err();
- this.update(cx, |this, cx| {
- this.api_key = None;
- this.api_key_from_env = false;
- cx.notify();
- })
- })
- }
-
- fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .openai
- .api_url
- .clone();
- cx.spawn(async move |this, cx| {
- credentials_provider
- .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
- .await
- .log_err();
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- cx.notify();
- })
- })
+ self.api_key_state.has_key()
}
- fn get_api_key(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .openai
- .api_url
- .clone();
- cx.spawn(async move |this, cx| {
- let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) {
- (api_key, true)
- } else {
- let (_, api_key) = credentials_provider
- .read_credentials(&api_url, cx)
- .await?
- .ok_or(AuthenticateError::CredentialsNotFound)?;
- (
- String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
- false,
- )
- };
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- this.api_key_from_env = from_env;
- cx.notify();
- })?;
-
- Ok(())
- })
+ fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let api_url = OpenAiLanguageModelProvider::api_url(cx);
+ self.api_key_state
+ .store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
- fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
- if self.is_authenticated() {
- return Task::ready(Ok(()));
- }
-
- self.get_api_key(cx)
+ fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let api_url = OpenAiLanguageModelProvider::api_url(cx);
+ self.api_key_state.load_if_needed(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ )
}
}
impl OpenAiLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
- let initial_api_url = AllLanguageModelSettings::get_global(cx)
- .openai
- .api_url
- .clone();
-
- let state = cx.new(|cx| State {
- api_key: None,
- api_key_from_env: false,
- last_api_url: initial_api_url.clone(),
- _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
- let current_api_url = AllLanguageModelSettings::get_global(cx)
- .openai
- .api_url
- .clone();
-
- if this.last_api_url != current_api_url {
- this.last_api_url = current_api_url;
- if !this.api_key_from_env {
- this.api_key = None;
- let spawn_task = cx.spawn(async move |handle, cx| {
- if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) {
- if let Err(_) = task.await {
- handle
- .update(cx, |this, _| {
- this.api_key = None;
- this.api_key_from_env = false;
- })
- .ok();
- }
- }
- });
- spawn_task.detach();
- }
- }
+ let state = cx.new(|cx| {
+ cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let api_url = Self::api_url(cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ );
cx.notify();
- }),
+ })
+ .detach();
+ State {
+ api_key_state: ApiKeyState::new(Self::api_url(cx)),
+ }
});
Self { http_client, state }
@@ -194,6 +111,19 @@ impl OpenAiLanguageModelProvider {
request_limiter: RateLimiter::new(4),
})
}
+
+ fn settings(cx: &App) -> &OpenAiSettings {
+ &AllLanguageModelSettings::get_global(cx).openai
+ }
+
+ fn api_url(cx: &App) -> SharedString {
+ let api_url = &Self::settings(cx).api_url;
+ if api_url.is_empty() {
+ open_ai::OPEN_AI_API_URL.into()
+ } else {
+ SharedString::new(api_url.as_str())
+ }
+ }
}
impl LanguageModelProviderState for OpenAiLanguageModelProvider {
@@ -278,7 +208,8 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
- self.state.update(cx, |state, cx| state.reset_api_key(cx))
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -298,11 +229,17 @@ impl OpenAiLanguageModel {
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
let http_client = self.http_client.clone();
- let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
- let settings = &AllLanguageModelSettings::get_global(cx).openai;
- (state.api_key.clone(), settings.api_url.clone())
- }) else {
- return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+
+ let api_key_and_url = self.state.read_with(cx, |state, cx| {
+ let api_url = OpenAiLanguageModelProvider::api_url(cx);
+ let api_key = state.api_key_state.key(&api_url);
+ (api_key, api_url)
+ });
+ let (api_key, api_url) = match api_key_and_url {
+ Ok(api_key_and_url) => api_key_and_url,
+ Err(err) => {
+ return futures::future::ready(Err(err)).boxed();
+ }
};
let future = self.request_limiter.stream(async move {
@@ -802,29 +739,18 @@ impl ConfigurationView {
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
- let api_key = self
- .api_key_editor
- .read(cx)
- .editor()
- .read(cx)
- .text(cx)
- .trim()
- .to_string();
-
- // Don't proceed if no API key is provided and we're not authenticated
- if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
+ let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
+ if api_key.is_empty() {
return;
}
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
- .update(cx, |state, cx| state.set_api_key(api_key, cx))?
+ .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -836,11 +762,11 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
- state.update(cx, |state, cx| state.reset_api_key(cx))?.await
+ state
+ .update(cx, |state, cx| state.set_api_key(None, cx))?
+ .await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
@@ -850,7 +776,7 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let env_var_set = self.state.read(cx).api_key_from_env;
+ let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
let api_key_section = if self.should_render_editor(cx) {
v_flex()
@@ -872,10 +798,11 @@ impl Render for ConfigurationView {
)
.child(self.api_key_editor.clone())
.child(
- Label::new(
- format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."),
- )
- .size(LabelSize::Small).color(Color::Muted),
+ Label::new(format!(
+ "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
+ ))
+ .size(LabelSize::Small)
+ .color(Color::Muted),
)
.child(
Label::new(
@@ -898,7 +825,7 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
- format!("API key set in {OPENAI_API_KEY_VAR} environment variable.")
+ format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.")
} else {
"API key configured.".to_string()
})),
@@ -911,7 +838,7 @@ impl Render for ConfigurationView {
.icon_position(IconPosition::Start)
.layer(ElevationIndex::ModalSurface)
.when(env_var_set, |this| {
- this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable.")))
+ this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)
@@ -1,9 +1,7 @@
-use anyhow::{Context as _, Result, anyhow};
-use credentials_provider::CredentialsProvider;
-
+use anyhow::Result;
use convert_case::{Case, Casing};
use futures::{FutureExt, StreamExt, future::BoxFuture};
-use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
+use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
@@ -17,13 +15,13 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
-
use ui::{ElevationIndex, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;
+use zed_env_vars::EnvVar;
-use crate::AllLanguageModelSettings;
use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai};
+use crate::{AllLanguageModelSettings, api_key::ApiKeyState};
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiCompatibleSettings {
@@ -70,82 +68,30 @@ pub struct OpenAiCompatibleLanguageModelProvider {
pub struct State {
id: Arc<str>,
- env_var_name: Arc<str>,
- api_key: Option<String>,
- api_key_from_env: bool,
+ api_key_env_var: EnvVar,
+ api_key_state: ApiKeyState,
settings: OpenAiCompatibleSettings,
- _subscription: Subscription,
}
impl State {
fn is_authenticated(&self) -> bool {
- self.api_key.is_some()
- }
-
- fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = self.settings.api_url.clone();
- cx.spawn(async move |this, cx| {
- credentials_provider
- .delete_credentials(&api_url, cx)
- .await
- .log_err();
- this.update(cx, |this, cx| {
- this.api_key = None;
- this.api_key_from_env = false;
- cx.notify();
- })
- })
- }
-
- fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = self.settings.api_url.clone();
- cx.spawn(async move |this, cx| {
- credentials_provider
- .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
- .await
- .log_err();
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- cx.notify();
- })
- })
+ self.api_key_state.has_key()
}
- fn get_api_key(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let env_var_name = self.env_var_name.clone();
- let api_url = self.settings.api_url.clone();
- cx.spawn(async move |this, cx| {
- let (api_key, from_env) = if let Ok(api_key) = std::env::var(env_var_name.as_ref()) {
- (api_key, true)
- } else {
- let (_, api_key) = credentials_provider
- .read_credentials(&api_url, cx)
- .await?
- .ok_or(AuthenticateError::CredentialsNotFound)?;
- (
- String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
- false,
- )
- };
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- this.api_key_from_env = from_env;
- cx.notify();
- })?;
-
- Ok(())
- })
+ fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let api_url = SharedString::new(self.settings.api_url.as_str());
+ self.api_key_state
+ .store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
- fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
- if self.is_authenticated() {
- return Task::ready(Ok(()));
- }
-
- self.get_api_key(cx)
+ fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let api_url = SharedString::new(self.settings.api_url.clone());
+ self.api_key_state.load_if_needed(
+ api_url,
+ &self.api_key_env_var,
+ |this| &mut this.api_key_state,
+ cx,
+ )
}
}
@@ -157,37 +103,32 @@ impl OpenAiCompatibleLanguageModelProvider {
.get(id)
}
- let state = cx.new(|cx| State {
- id: id.clone(),
- env_var_name: format!("{}_API_KEY", id).to_case(Case::Constant).into(),
- settings: resolve_settings(&id, cx).cloned().unwrap_or_default(),
- api_key: None,
- api_key_from_env: false,
- _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let api_key_env_var_name = format!("{}_API_KEY", id).to_case(Case::UpperSnake).into();
+ let state = cx.new(|cx| {
+ cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let Some(settings) = resolve_settings(&this.id, cx).cloned() else {
return;
};
if &this.settings != &settings {
- if settings.api_url != this.settings.api_url && !this.api_key_from_env {
- let spawn_task = cx.spawn(async move |handle, cx| {
- if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) {
- if let Err(_) = task.await {
- handle
- .update(cx, |this, _| {
- this.api_key = None;
- this.api_key_from_env = false;
- })
- .ok();
- }
- }
- });
- spawn_task.detach();
- }
-
+ let api_url = SharedString::new(settings.api_url.as_str());
+ this.api_key_state.handle_url_change(
+ api_url,
+ &this.api_key_env_var,
+ |this| &mut this.api_key_state,
+ cx,
+ );
this.settings = settings;
cx.notify();
}
- }),
+ })
+ .detach();
+ let settings = resolve_settings(&id, cx).cloned().unwrap_or_default();
+ State {
+ id: id.clone(),
+ api_key_env_var: EnvVar::new(api_key_env_var_name),
+ api_key_state: ApiKeyState::new(SharedString::new(settings.api_url.as_str())),
+ settings,
+ }
});
Self {
@@ -274,7 +215,8 @@ impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
- self.state.update(cx, |state, cx| state.reset_api_key(cx))
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -296,10 +238,17 @@ impl OpenAiCompatibleLanguageModel {
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
let http_client = self.http_client.clone();
- let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, _| {
- (state.api_key.clone(), state.settings.api_url.clone())
- }) else {
- return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+
+ let api_key_and_url = self.state.read_with(cx, |state, _cx| {
+ let api_url = &state.settings.api_url;
+ let api_key = state.api_key_state.key(api_url);
+ (api_key, state.settings.api_url.clone())
+ });
+ let (api_key, api_url) = match api_key_and_url {
+ Ok(api_key_and_url) => api_key_and_url,
+ Err(err) => {
+ return futures::future::ready(Err(err)).boxed();
+ }
};
let provider = self.provider_name.clone();
@@ -469,29 +418,18 @@ impl ConfigurationView {
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
- let api_key = self
- .api_key_editor
- .read(cx)
- .editor()
- .read(cx)
- .text(cx)
- .trim()
- .to_string();
-
- // Don't proceed if no API key is provided and we're not authenticated
- if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
+ let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
+ if api_key.is_empty() {
return;
}
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
- .update(cx, |state, cx| state.set_api_key(api_key, cx))?
+ .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -503,22 +441,23 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
- state.update(cx, |state, cx| state.reset_api_key(cx))?.await
+ state
+ .update(cx, |state, cx| state.set_api_key(None, cx))?
+ .await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
- fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
+ fn should_render_editor(&self, cx: &Context<Self>) -> bool {
!self.state.read(cx).is_authenticated()
}
}
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let env_var_set = self.state.read(cx).api_key_from_env;
- let env_var_name = self.state.read(cx).env_var_name.clone();
+ let state = self.state.read(cx);
+ let env_var_set = state.api_key_state.is_from_env_var();
+ let env_var_name = &state.api_key_env_var.name;
let api_key_section = if self.should_render_editor(cx) {
v_flex()
@@ -1,10 +1,9 @@
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::{Result, anyhow};
use collections::HashMap;
-use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use gpui::{
- AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
+ AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace,
};
use http_client::HttpClient;
use language_model::{
@@ -16,23 +15,26 @@ use language_model::{
};
use open_router::{
Model, ModelMode as OpenRouterModelMode, Provider, ResponseStreamEvent, list_models,
- stream_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::str::FromStr as _;
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;
+use zed_env_vars::{EnvVar, env_var};
-use crate::{AllLanguageModelSettings, ui::InstructionListItem};
+use crate::{AllLanguageModelSettings, api_key::ApiKeyState, ui::InstructionListItem};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");
+const API_KEY_ENV_VAR_NAME: &str = "OPENROUTER_API_KEY";
+static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
+
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenRouterSettings {
pub api_url: String,
@@ -90,93 +92,38 @@ pub struct OpenRouterLanguageModelProvider {
}
pub struct State {
- api_key: Option<String>,
- api_key_from_env: bool,
+ api_key_state: ApiKeyState,
http_client: Arc<dyn HttpClient>,
available_models: Vec<open_router::Model>,
fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
settings: OpenRouterSettings,
- _subscription: Subscription,
}
-const OPENROUTER_API_KEY_VAR: &str = "OPENROUTER_API_KEY";
-
impl State {
fn is_authenticated(&self) -> bool {
- self.api_key.is_some()
- }
-
- fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .open_router
- .api_url
- .clone();
- cx.spawn(async move |this, cx| {
- credentials_provider
- .delete_credentials(&api_url, cx)
- .await
- .log_err();
- this.update(cx, |this, cx| {
- this.api_key = None;
- this.api_key_from_env = false;
- cx.notify();
- })
- })
+ self.api_key_state.has_key()
}
- fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .open_router
- .api_url
- .clone();
- cx.spawn(async move |this, cx| {
- credentials_provider
- .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
- .await
- .log_err();
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- this.restart_fetch_models_task(cx);
- cx.notify();
- })
- })
+ fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let api_url = OpenRouterLanguageModelProvider::api_url(cx);
+ self.api_key_state
+ .store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
- fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
- if self.is_authenticated() {
- return Task::ready(Ok(()));
- }
-
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let api_url = AllLanguageModelSettings::get_global(cx)
- .open_router
- .api_url
- .clone();
+ fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let api_url = OpenRouterLanguageModelProvider::api_url(cx);
+ let task = self.api_key_state.load_if_needed(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ );
cx.spawn(async move |this, cx| {
- let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) {
- (api_key, true)
- } else {
- let (_, api_key) = credentials_provider
- .read_credentials(&api_url, cx)
- .await?
- .ok_or(AuthenticateError::CredentialsNotFound)?;
- (
- String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
- false,
- )
- };
-
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- this.api_key_from_env = from_env;
- this.restart_fetch_models_task(cx);
- cx.notify();
- })?;
-
- Ok(())
+ let result = task.await;
+ this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
+ .ok();
+ result
})
}
@@ -184,10 +131,9 @@ impl State {
&mut self,
cx: &mut Context<Self>,
) -> Task<Result<(), LanguageModelCompletionError>> {
- let settings = &AllLanguageModelSettings::get_global(cx).open_router;
let http_client = self.http_client.clone();
- let api_url = settings.api_url.clone();
- let Some(api_key) = self.api_key.clone() else {
+ let api_url = OpenRouterLanguageModelProvider::api_url(cx);
+ let Some(api_key) = self.api_key_state.key(&api_url) else {
return Task::ready(Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
}));
@@ -216,33 +162,45 @@ impl State {
if self.is_authenticated() {
let task = self.fetch_models(cx);
self.fetch_models_task.replace(task);
+ } else {
+ self.available_models = Vec::new();
}
}
}
impl OpenRouterLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
- let state = cx.new(|cx| State {
- api_key: None,
- api_key_from_env: false,
- http_client: http_client.clone(),
- available_models: Vec::new(),
- fetch_models_task: None,
- settings: OpenRouterSettings::default(),
- _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let state = cx.new(|cx| {
+ cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
let current_settings = &AllLanguageModelSettings::get_global(cx).open_router;
let settings_changed = current_settings != &this.settings;
if settings_changed {
this.settings = current_settings.clone();
- this.restart_fetch_models_task(cx);
+ this.authenticate(cx).detach();
}
cx.notify();
- }),
+ })
+ .detach();
+ State {
+ api_key_state: ApiKeyState::new(Self::api_url(cx)),
+ http_client: http_client.clone(),
+ available_models: Vec::new(),
+ fetch_models_task: None,
+ settings: OpenRouterSettings::default(),
+ }
});
Self { http_client, state }
}
+ fn settings(cx: &App) -> &OpenRouterSettings {
+ &AllLanguageModelSettings::get_global(cx).open_router
+ }
+
+ fn api_url(cx: &App) -> SharedString {
+ SharedString::new(Self::settings(cx).api_url.as_str())
+ }
+
fn create_language_model(&self, model: open_router::Model) -> Arc<dyn LanguageModel> {
Arc::new(OpenRouterLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
@@ -287,10 +245,7 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider {
let mut models_from_api = self.state.read(cx).available_models.clone();
let mut settings_models = Vec::new();
- for model in &AllLanguageModelSettings::get_global(cx)
- .open_router
- .available_models
- {
+ for model in &Self::settings(cx).available_models {
settings_models.push(open_router::Model {
name: model.name.clone(),
display_name: model.display_name.clone(),
@@ -338,7 +293,8 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
- self.state.update(cx, |state, cx| state.reset_api_key(cx))
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -366,14 +322,17 @@ impl OpenRouterLanguageModel {
>,
> {
let http_client = self.http_client.clone();
- let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
- let settings = &AllLanguageModelSettings::get_global(cx).open_router;
- (state.api_key.clone(), settings.api_url.clone())
- }) else {
- return futures::future::ready(Err(LanguageModelCompletionError::Other(anyhow!(
- "App state dropped"
- ))))
- .boxed();
+ let api_key_and_url = self.state.read_with(cx, |state, cx| {
+ let api_url = OpenRouterLanguageModelProvider::api_url(cx);
+ let api_key = state.api_key_state.key(&api_url);
+ (api_key, api_url)
+ });
+ let (api_key, api_url) = match api_key_and_url {
+ Ok(api_key_and_url) => api_key_and_url,
+ Err(err) => {
+ return futures::future::ready(Err(LanguageModelCompletionError::Other(err)))
+ .boxed();
+ }
};
async move {
@@ -382,7 +341,8 @@ impl OpenRouterLanguageModel {
provider: PROVIDER_NAME,
});
};
- let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+ let request =
+ open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
request.await.map_err(Into::into)
}
.boxed()
@@ -830,7 +790,7 @@ impl ConfigurationView {
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
- let api_key = self.api_key_editor.read(cx).text(cx);
+ let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
if api_key.is_empty() {
return;
}
@@ -838,12 +798,10 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
- .update(cx, |state, cx| state.set_api_key(api_key, cx))?
+ .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -852,11 +810,11 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
- state.update(cx, |state, cx| state.reset_api_key(cx))?.await
+ state
+ .update(cx, |state, cx| state.set_api_key(None, cx))?
+ .await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
@@ -891,7 +849,7 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let env_var_set = self.state.read(cx).api_key_from_env;
+ let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials...")).into_any()
@@ -928,7 +886,7 @@ impl Render for ConfigurationView {
)
.child(
Label::new(
- format!("You can also assign the {OPENROUTER_API_KEY_VAR} environment variable and restart Zed."),
+ format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
)
.size(LabelSize::Small).color(Color::Muted),
)
@@ -947,7 +905,7 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
- format!("API key set in {OPENROUTER_API_KEY_VAR} environment variable.")
+ format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.")
} else {
"API key configured.".to_string()
})),
@@ -960,7 +918,7 @@ impl Render for ConfigurationView {
.icon_position(IconPosition::Start)
.disabled(env_var_set)
.when(env_var_set, |this| {
- this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENROUTER_API_KEY_VAR} environment variable.")))
+ this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)
@@ -1,8 +1,7 @@
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::Result;
use collections::BTreeMap;
-use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
-use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
+use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
@@ -10,24 +9,26 @@ use language_model::{
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, RateLimiter, Role,
};
-use menu;
use open_ai::ResponseStreamEvent;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
-use vercel::Model;
-
use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;
+use vercel::Model;
+use zed_env_vars::{EnvVar, env_var};
-use crate::{AllLanguageModelSettings, ui::InstructionListItem};
+use crate::{api_key::ApiKeyState, ui::InstructionListItem};
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Vercel");
+const API_KEY_ENV_VAR_NAME: &str = "VERCEL_API_KEY";
+static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
+
#[derive(Default, Clone, Debug, PartialEq)]
pub struct VercelSettings {
pub api_url: String,
@@ -49,103 +50,48 @@ pub struct VercelLanguageModelProvider {
}
pub struct State {
- api_key: Option<String>,
- api_key_from_env: bool,
- _subscription: Subscription,
+ api_key_state: ApiKeyState,
}
-const VERCEL_API_KEY_VAR: &str = "VERCEL_API_KEY";
-
impl State {
fn is_authenticated(&self) -> bool {
- self.api_key.is_some()
+ self.api_key_state.has_key()
}
- fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let settings = &AllLanguageModelSettings::get_global(cx).vercel;
- let api_url = if settings.api_url.is_empty() {
- vercel::VERCEL_API_URL.to_string()
- } else {
- settings.api_url.clone()
- };
- cx.spawn(async move |this, cx| {
- credentials_provider
- .delete_credentials(&api_url, cx)
- .await
- .log_err();
- this.update(cx, |this, cx| {
- this.api_key = None;
- this.api_key_from_env = false;
- cx.notify();
- })
- })
+ fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let api_url = VercelLanguageModelProvider::api_url(cx);
+ self.api_key_state
+ .store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
- fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let settings = &AllLanguageModelSettings::get_global(cx).vercel;
- let api_url = if settings.api_url.is_empty() {
- vercel::VERCEL_API_URL.to_string()
- } else {
- settings.api_url.clone()
- };
- cx.spawn(async move |this, cx| {
- credentials_provider
- .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
- .await
- .log_err();
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- cx.notify();
- })
- })
- }
-
- fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
- if self.is_authenticated() {
- return Task::ready(Ok(()));
- }
-
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let settings = &AllLanguageModelSettings::get_global(cx).vercel;
- let api_url = if settings.api_url.is_empty() {
- vercel::VERCEL_API_URL.to_string()
- } else {
- settings.api_url.clone()
- };
- cx.spawn(async move |this, cx| {
- let (api_key, from_env) = if let Ok(api_key) = std::env::var(VERCEL_API_KEY_VAR) {
- (api_key, true)
- } else {
- let (_, api_key) = credentials_provider
- .read_credentials(&api_url, cx)
- .await?
- .ok_or(AuthenticateError::CredentialsNotFound)?;
- (
- String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
- false,
- )
- };
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- this.api_key_from_env = from_env;
- cx.notify();
- })?;
-
- Ok(())
- })
+ fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let api_url = VercelLanguageModelProvider::api_url(cx);
+ self.api_key_state.load_if_needed(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ )
}
}
impl VercelLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
- let state = cx.new(|cx| State {
- api_key: None,
- api_key_from_env: false,
- _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
+ let state = cx.new(|cx| {
+ cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let api_url = Self::api_url(cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ );
cx.notify();
- }),
+ })
+ .detach();
+ State {
+ api_key_state: ApiKeyState::new(Self::api_url(cx)),
+ }
});
Self { http_client, state }
@@ -160,6 +106,19 @@ impl VercelLanguageModelProvider {
request_limiter: RateLimiter::new(4),
})
}
+
+ fn settings(cx: &App) -> &VercelSettings {
+ &crate::AllLanguageModelSettings::get_global(cx).vercel
+ }
+
+ fn api_url(cx: &App) -> SharedString {
+ let api_url = &Self::settings(cx).api_url;
+ if api_url.is_empty() {
+ vercel::VERCEL_API_URL.into()
+ } else {
+ SharedString::new(api_url.as_str())
+ }
+ }
}
impl LanguageModelProviderState for VercelLanguageModelProvider {
@@ -200,10 +159,7 @@ impl LanguageModelProvider for VercelLanguageModelProvider {
}
}
- for model in &AllLanguageModelSettings::get_global(cx)
- .vercel
- .available_models
- {
+ for model in &Self::settings(cx).available_models {
models.insert(
model.name.clone(),
vercel::Model::Custom {
@@ -241,7 +197,8 @@ impl LanguageModelProvider for VercelLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
- self.state.update(cx, |state, cx| state.reset_api_key(cx))
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -261,16 +218,17 @@ impl VercelLanguageModel {
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
let http_client = self.http_client.clone();
- let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
- let settings = &AllLanguageModelSettings::get_global(cx).vercel;
- let api_url = if settings.api_url.is_empty() {
- vercel::VERCEL_API_URL.to_string()
- } else {
- settings.api_url.clone()
- };
- (state.api_key.clone(), api_url)
- }) else {
- return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+
+ let api_key_and_url = self.state.read_with(cx, |state, cx| {
+ let api_url = VercelLanguageModelProvider::api_url(cx);
+ let api_key = state.api_key_state.key(&api_url);
+ (api_key, api_url)
+ });
+ let (api_key, api_url) = match api_key_and_url {
+ Ok(api_key_and_url) => api_key_and_url,
+ Err(err) => {
+ return futures::future::ready(Err(err)).boxed();
+ }
};
let future = self.request_limiter.stream(async move {
@@ -466,29 +424,18 @@ impl ConfigurationView {
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
- let api_key = self
- .api_key_editor
- .read(cx)
- .editor()
- .read(cx)
- .text(cx)
- .trim()
- .to_string();
-
- // Don't proceed if no API key is provided and we're not authenticated
- if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
+ let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
+ if api_key.is_empty() {
return;
}
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
- .update(cx, |state, cx| state.set_api_key(api_key, cx))?
+ .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -500,11 +447,11 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
- state.update(cx, |state, cx| state.reset_api_key(cx))?.await
+ state
+ .update(cx, |state, cx| state.set_api_key(None, cx))?
+ .await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
@@ -514,7 +461,7 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let env_var_set = self.state.read(cx).api_key_from_env;
+ let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
let api_key_section = if self.should_render_editor(cx) {
v_flex()
@@ -534,7 +481,7 @@ impl Render for ConfigurationView {
.child(self.api_key_editor.clone())
.child(
Label::new(format!(
- "You can also assign the {VERCEL_API_KEY_VAR} environment variable and restart Zed."
+ "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
))
.size(LabelSize::Small)
.color(Color::Muted),
@@ -559,7 +506,7 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
- format!("API key set in {VERCEL_API_KEY_VAR} environment variable.")
+ format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.")
} else {
"API key configured.".to_string()
})),
@@ -572,7 +519,7 @@ impl Render for ConfigurationView {
.icon_position(IconPosition::Start)
.layer(ElevationIndex::ModalSurface)
.when(env_var_set, |this| {
- this.tooltip(Tooltip::text(format!("To reset your API key, unset the {VERCEL_API_KEY_VAR} environment variable.")))
+ this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)
@@ -1,8 +1,7 @@
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::Result;
use collections::BTreeMap;
-use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
-use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
+use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
@@ -10,23 +9,25 @@ use language_model::{
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, Role,
};
-use menu;
use open_ai::ResponseStreamEvent;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
use strum::IntoEnumIterator;
-use x_ai::Model;
-
use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;
+use x_ai::{Model, XAI_API_URL};
+use zed_env_vars::{EnvVar, env_var};
+
+use crate::{api_key::ApiKeyState, ui::InstructionListItem};
-use crate::{AllLanguageModelSettings, ui::InstructionListItem};
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
-const PROVIDER_ID: &str = "x_ai";
-const PROVIDER_NAME: &str = "xAI";
+const API_KEY_ENV_VAR_NAME: &str = "XAI_API_KEY";
+static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
#[derive(Default, Clone, Debug, PartialEq)]
pub struct XAiSettings {
@@ -49,103 +50,48 @@ pub struct XAiLanguageModelProvider {
}
pub struct State {
- api_key: Option<String>,
- api_key_from_env: bool,
- _subscription: Subscription,
+ api_key_state: ApiKeyState,
}
-const XAI_API_KEY_VAR: &str = "XAI_API_KEY";
-
impl State {
fn is_authenticated(&self) -> bool {
- self.api_key.is_some()
+ self.api_key_state.has_key()
}
- fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
- let api_url = if settings.api_url.is_empty() {
- x_ai::XAI_API_URL.to_string()
- } else {
- settings.api_url.clone()
- };
- cx.spawn(async move |this, cx| {
- credentials_provider
- .delete_credentials(&api_url, cx)
- .await
- .log_err();
- this.update(cx, |this, cx| {
- this.api_key = None;
- this.api_key_from_env = false;
- cx.notify();
- })
- })
+ fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let api_url = XAiLanguageModelProvider::api_url(cx);
+ self.api_key_state
+ .store(api_url, api_key, |this| &mut this.api_key_state, cx)
}
- fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
- let api_url = if settings.api_url.is_empty() {
- x_ai::XAI_API_URL.to_string()
- } else {
- settings.api_url.clone()
- };
- cx.spawn(async move |this, cx| {
- credentials_provider
- .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
- .await
- .log_err();
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- cx.notify();
- })
- })
- }
-
- fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
- if self.is_authenticated() {
- return Task::ready(Ok(()));
- }
-
- let credentials_provider = <dyn CredentialsProvider>::global(cx);
- let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
- let api_url = if settings.api_url.is_empty() {
- x_ai::XAI_API_URL.to_string()
- } else {
- settings.api_url.clone()
- };
- cx.spawn(async move |this, cx| {
- let (api_key, from_env) = if let Ok(api_key) = std::env::var(XAI_API_KEY_VAR) {
- (api_key, true)
- } else {
- let (_, api_key) = credentials_provider
- .read_credentials(&api_url, cx)
- .await?
- .ok_or(AuthenticateError::CredentialsNotFound)?;
- (
- String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
- false,
- )
- };
- this.update(cx, |this, cx| {
- this.api_key = Some(api_key);
- this.api_key_from_env = from_env;
- cx.notify();
- })?;
-
- Ok(())
- })
+ fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let api_url = XAiLanguageModelProvider::api_url(cx);
+ self.api_key_state.load_if_needed(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ )
}
}
impl XAiLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
- let state = cx.new(|cx| State {
- api_key: None,
- api_key_from_env: false,
- _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
+ let state = cx.new(|cx| {
+ cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
+ let api_url = Self::api_url(cx);
+ this.api_key_state.handle_url_change(
+ api_url,
+ &API_KEY_ENV_VAR,
+ |this| &mut this.api_key_state,
+ cx,
+ );
cx.notify();
- }),
+ })
+ .detach();
+ State {
+ api_key_state: ApiKeyState::new(Self::api_url(cx)),
+ }
});
Self { http_client, state }
@@ -160,6 +106,19 @@ impl XAiLanguageModelProvider {
request_limiter: RateLimiter::new(4),
})
}
+
+ fn settings(cx: &App) -> &XAiSettings {
+ &crate::AllLanguageModelSettings::get_global(cx).x_ai
+ }
+
+ fn api_url(cx: &App) -> SharedString {
+ let api_url = &Self::settings(cx).api_url;
+ if api_url.is_empty() {
+ XAI_API_URL.into()
+ } else {
+ SharedString::new(api_url.as_str())
+ }
+ }
}
impl LanguageModelProviderState for XAiLanguageModelProvider {
@@ -172,11 +131,11 @@ impl LanguageModelProviderState for XAiLanguageModelProvider {
impl LanguageModelProvider for XAiLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -200,10 +159,7 @@ impl LanguageModelProvider for XAiLanguageModelProvider {
}
}
- for model in &AllLanguageModelSettings::get_global(cx)
- .x_ai
- .available_models
- {
+ for model in &Self::settings(cx).available_models {
models.insert(
model.name.clone(),
x_ai::Model::Custom {
@@ -241,7 +197,8 @@ impl LanguageModelProvider for XAiLanguageModelProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
- self.state.update(cx, |state, cx| state.reset_api_key(cx))
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -261,20 +218,25 @@ impl XAiLanguageModel {
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
let http_client = self.http_client.clone();
- let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
- let settings = &AllLanguageModelSettings::get_global(cx).x_ai;
- let api_url = if settings.api_url.is_empty() {
- x_ai::XAI_API_URL.to_string()
- } else {
- settings.api_url.clone()
- };
- (state.api_key.clone(), api_url)
- }) else {
- return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+
+ let api_key_and_url = self.state.read_with(cx, |state, cx| {
+ let api_url = XAiLanguageModelProvider::api_url(cx);
+ let api_key = state.api_key_state.key(&api_url);
+ (api_key, api_url)
+ });
+ let (api_key, api_url) = match api_key_and_url {
+ Ok(api_key_and_url) => api_key_and_url,
+ Err(err) => {
+ return futures::future::ready(Err(err)).boxed();
+ }
};
let future = self.request_limiter.stream(async move {
- let api_key = api_key.context("Missing xAI API Key")?;
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
let request =
open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
@@ -295,11 +257,11 @@ impl LanguageModel for XAiLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -456,29 +418,18 @@ impl ConfigurationView {
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
- let api_key = self
- .api_key_editor
- .read(cx)
- .editor()
- .read(cx)
- .text(cx)
- .trim()
- .to_string();
-
- // Don't proceed if no API key is provided and we're not authenticated
- if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
+ let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
+ if api_key.is_empty() {
return;
}
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
- .update(cx, |state, cx| state.set_api_key(api_key, cx))?
+ .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
.await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
@@ -490,11 +441,11 @@ impl ConfigurationView {
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
- state.update(cx, |state, cx| state.reset_api_key(cx))?.await
+ state
+ .update(cx, |state, cx| state.set_api_key(None, cx))?
+ .await
})
.detach_and_log_err(cx);
-
- cx.notify();
}
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
@@ -504,7 +455,7 @@ impl ConfigurationView {
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let env_var_set = self.state.read(cx).api_key_from_env;
+ let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
let api_key_section = if self.should_render_editor(cx) {
v_flex()
@@ -524,7 +475,7 @@ impl Render for ConfigurationView {
.child(self.api_key_editor.clone())
.child(
Label::new(format!(
- "You can also assign the {XAI_API_KEY_VAR} environment variable and restart Zed."
+ "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
))
.size(LabelSize::Small)
.color(Color::Muted),
@@ -549,7 +500,7 @@ impl Render for ConfigurationView {
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
- format!("API key set in {XAI_API_KEY_VAR} environment variable.")
+ format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.")
} else {
"API key configured.".to_string()
})),
@@ -562,7 +513,7 @@ impl Render for ConfigurationView {
.icon_position(IconPosition::Start)
.layer(ElevationIndex::ModalSurface)
.when(env_var_set, |this| {
- this.tooltip(Tooltip::text(format!("To reset your API key, unset the {XAI_API_KEY_VAR} environment variable.")))
+ this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)
@@ -16,3 +16,4 @@ default = []
[dependencies]
workspace-hack.workspace = true
+gpui.workspace = true
@@ -1,6 +1,44 @@
+use gpui::SharedString;
use std::sync::LazyLock;
/// Whether Zed is running in stateless mode.
/// When true, Zed will use in-memory databases instead of persistent storage.
-pub static ZED_STATELESS: LazyLock<bool> =
- LazyLock::new(|| std::env::var("ZED_STATELESS").is_ok_and(|v| !v.is_empty()));
+pub static ZED_STATELESS: LazyLock<bool> = bool_env_var!("ZED_STATELESS");
+
+pub struct EnvVar {
+ pub name: SharedString,
+ /// Value of the environment variable. Also `None` when set to an empty string.
+ pub value: Option<String>,
+}
+
+impl EnvVar {
+ pub fn new(name: SharedString) -> Self {
+ let value = std::env::var(name.as_str()).ok();
+ if value.as_ref().is_some_and(|v| v.is_empty()) {
+ Self { name, value: None }
+ } else {
+ Self { name, value }
+ }
+ }
+
+ pub fn or(self, other: EnvVar) -> EnvVar {
+ if self.value.is_some() { self } else { other }
+ }
+}
+
+/// Creates a `LazyLock<EnvVar>` expression for use in a `static` declaration.
+#[macro_export]
+macro_rules! env_var {
+ ($name:expr) => {
+ LazyLock::new(|| $crate::EnvVar::new(($name).into()))
+ };
+}
+
+/// Generates a `LazyLock<bool>` expression for use in a `static` declaration. Checks if the
+/// environment variable exists and is non-empty.
+#[macro_export]
+macro_rules! bool_env_var {
+ ($name:expr) => {
+ LazyLock::new(|| $crate::EnvVar::new(($name).into()).value.is_some())
+ };
+}