From 98edf1bf0b6e1b6ec62986d370763920dde8b890 Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Sun, 14 Sep 2025 21:36:24 -0600 Subject: [PATCH] Reload API keys when URLs configured for LLM providers change (#38163) Three motivations for this: * Changing provider URL could cause credentials for the prior URL to be sent to the new URL. * The UI is in a misleading state after URL change - it shows a configured API key, but on restart it will show no API key. * #34110 will add support for both URL and key configuration for Ollama. This is the first provider to have UI for setting the URL, and this makes these issues show up more directly as odd UI interactions. #37610 implemented something similar for the OpenAI and OpenAI compatible providers. This extracts out some shared code, uses it in all relevant providers, and adds more safety around key use. I haven't tested all providers, but the per-provider changes were pretty mechanical, so hopefully work properly. Release Notes: - Fixed handling of changes to LLM provider URL in settings to also load the associated API key. --- Cargo.lock | 2 + crates/agent_servers/src/gemini.rs | 8 +- crates/language_models/Cargo.toml | 1 + crates/language_models/src/api_key.rs | 295 ++++++++++++++++++ crates/language_models/src/language_models.rs | 1 + .../language_models/src/provider/anthropic.rs | 186 ++++------- .../language_models/src/provider/deepseek.rs | 183 +++++------ crates/language_models/src/provider/google.rs | 219 ++++++------- .../language_models/src/provider/mistral.rs | 181 +++++------ .../language_models/src/provider/open_ai.rs | 225 +++++-------- .../src/provider/open_ai_compatible.rs | 183 ++++------- .../src/provider/open_router.rs | 192 +++++------- crates/language_models/src/provider/vercel.rs | 201 +++++------- crates/language_models/src/provider/x_ai.rs | 219 +++++-------- crates/zed_env_vars/Cargo.toml | 1 + crates/zed_env_vars/src/zed_env_vars.rs | 42 ++- 16 files changed, 1030 insertions(+), 1109 deletions(-) create mode 100644 crates/language_models/src/api_key.rs diff --git a/Cargo.lock b/Cargo.lock index db61ba7b121773c6e6941dd940201ec0a9e82f73..22589ee11a4ffb657238091ab85a3e76d9b6bf32 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9212,6 +9212,7 @@ dependencies = [ "vercel", "workspace-hack", "x_ai", + "zed_env_vars", ] [[package]] @@ -20677,6 +20678,7 @@ dependencies = [ name = "zed_env_vars" version = "0.1.0" dependencies = [ + "gpui", "workspace-hack", ] diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index 01f15557899e1c7826e91d1555320996eccd0f45..6a668785943239f28b4bc9aafce9d37fdfa386d2 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -44,8 +44,12 @@ impl AgentServer for Gemini { cx.spawn(async move |cx| { let mut extra_env = HashMap::default(); - if let Some(api_key) = cx.update(GoogleLanguageModelProvider::api_key)?.await.ok() { - extra_env.insert("GEMINI_API_KEY".into(), api_key.key); + if let Some(api_key) = cx + .update(GoogleLanguageModelProvider::api_key_for_gemini_cli)? + .await + .ok() + { + extra_env.insert("GEMINI_API_KEY".into(), api_key); } let (mut command, root_dir, login) = store .update(cx, |store, cx| { diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index b5bfb870f643452bd5be248c9910d99f16a8101e..7dc0988d23579c4d1ab1ac2dde1f1413c5e751b8 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -61,6 +61,7 @@ util.workspace = true vercel = { workspace = true, features = ["schemars"] } workspace-hack.workspace = true x_ai = { workspace = true, features = ["schemars"] } +zed_env_vars.workspace = true [dev-dependencies] editor = { workspace = true, features = ["test-support"] } diff --git a/crates/language_models/src/api_key.rs b/crates/language_models/src/api_key.rs new file mode 100644 index 0000000000000000000000000000000000000000..122234b6ced6d0bf1b7a0d684683c841824ccd2d --- /dev/null +++ b/crates/language_models/src/api_key.rs @@ -0,0 +1,295 @@ +use anyhow::{Result, anyhow}; +use credentials_provider::CredentialsProvider; +use futures::{FutureExt, future}; +use gpui::{AsyncApp, Context, SharedString, Task}; +use language_model::AuthenticateError; +use std::{ + fmt::{Display, Formatter}, + sync::Arc, +}; +use util::ResultExt as _; +use zed_env_vars::EnvVar; + +/// Manages a single API key for a language model provider. API keys either come from environment +/// variables or the system keychain. +/// +/// Keys from the system keychain are associated with a provider URL, and this ensures that they are +/// only used with that URL. +pub struct ApiKeyState { + url: SharedString, + load_status: LoadStatus, + load_task: Option>>, +} + +#[derive(Debug, Clone)] +pub enum LoadStatus { + NotPresent, + Error(String), + Loaded(ApiKey), +} + +#[derive(Debug, Clone)] +pub struct ApiKey { + source: ApiKeySource, + key: Arc, +} + +impl ApiKeyState { + pub fn new(url: SharedString) -> Self { + Self { + url, + load_status: LoadStatus::NotPresent, + load_task: None, + } + } + + pub fn has_key(&self) -> bool { + matches!(self.load_status, LoadStatus::Loaded { .. }) + } + + pub fn is_from_env_var(&self) -> bool { + match &self.load_status { + LoadStatus::Loaded(ApiKey { + source: ApiKeySource::EnvVar { .. }, + .. + }) => true, + _ => false, + } + } + + /// Get the stored API key, verifying that it is associated with the URL. Returns `None` if + /// there is no key or for URL mismatches, and the mismatch case is logged. + /// + /// To avoid URL mismatches, expects that `load_if_needed` or `handle_url_change` has been + /// called with this URL. + pub fn key(&self, url: &str) -> Option> { + let api_key = match &self.load_status { + LoadStatus::Loaded(api_key) => api_key, + _ => return None, + }; + if url == self.url.as_str() { + Some(api_key.key.clone()) + } else if let ApiKeySource::EnvVar(var_name) = &api_key.source { + log::warn!( + "{} is now being used with URL {}, when initially it was used with URL {}", + var_name, + url, + self.url + ); + Some(api_key.key.clone()) + } else { + // bug case because load_if_needed should be called whenever the url may have changed + log::error!( + "bug: Attempted to use API key associated with URL {} instead with URL {}", + self.url, + url + ); + None + } + } + + /// Set or delete the API key in the system keychain. + pub fn store( + &mut self, + url: SharedString, + key: Option, + get_this: impl Fn(&mut Ent) -> &mut Self + 'static, + cx: &Context, + ) -> Task> { + if self.is_from_env_var() { + return Task::ready(Err(anyhow!( + "bug: attempted to store API key in system keychain when API key is from env var", + ))); + } + let credentials_provider = ::global(cx); + cx.spawn(async move |ent, cx| { + if let Some(key) = &key { + credentials_provider + .write_credentials(&url, "Bearer", key.as_bytes(), cx) + .await + .log_err(); + } else { + credentials_provider + .delete_credentials(&url, cx) + .await + .log_err(); + } + ent.update(cx, |ent, cx| { + let this = get_this(ent); + this.url = url; + this.load_status = match &key { + Some(key) => LoadStatus::Loaded(ApiKey { + source: ApiKeySource::SystemKeychain, + key: key.as_str().into(), + }), + None => LoadStatus::NotPresent, + }; + cx.notify(); + }) + }) + } + + /// Reloads the API key if the current API key is associated with a different URL. + /// + /// Note that it is not efficient to use this or `load_if_needed` with multiple URLs + /// interchangeably - URL change should correspond to some user initiated change. + pub fn handle_url_change( + &mut self, + url: SharedString, + env_var: &EnvVar, + get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static, + cx: &mut Context, + ) { + if url != self.url { + if !self.is_from_env_var() { + // loading will continue even though this result task is dropped + let _task = self.load_if_needed(url, env_var, get_this, cx); + } + } + } + + /// If needed, loads the API key associated with the given URL from the system keychain. When a + /// non-empty environment variable is provided, it will be used instead. If called when an API + /// key was already loaded for a different URL, that key will be cleared before loading. + /// + /// Dropping the returned Task does not cancel key loading. + pub fn load_if_needed( + &mut self, + url: SharedString, + env_var: &EnvVar, + get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static, + cx: &mut Context, + ) -> Task> { + if let LoadStatus::Loaded { .. } = &self.load_status + && self.url == url + { + return Task::ready(Ok(())); + } + + if let Some(key) = &env_var.value + && !key.is_empty() + { + let api_key = ApiKey::from_env(env_var.name.clone(), key); + self.url = url; + self.load_status = LoadStatus::Loaded(api_key); + self.load_task = None; + cx.notify(); + return Task::ready(Ok(())); + } + + let task = if let Some(load_task) = &self.load_task { + load_task.clone() + } else { + let load_task = Self::load(url.clone(), get_this.clone(), cx).shared(); + self.url = url; + self.load_status = LoadStatus::NotPresent; + self.load_task = Some(load_task.clone()); + cx.notify(); + load_task + }; + + cx.spawn(async move |ent, cx| { + task.await; + ent.update(cx, |ent, _cx| { + get_this(ent).load_status.clone().into_authenticate_result() + }) + .ok(); + Ok(()) + }) + } + + fn load( + url: SharedString, + get_this: impl Fn(&mut Ent) -> &mut Self + 'static, + cx: &Context, + ) -> Task<()> { + let credentials_provider = ::global(cx); + cx.spawn({ + async move |ent, cx| { + let load_status = + ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx) + .await; + ent.update(cx, |ent, cx| { + let this = get_this(ent); + this.url = url; + this.load_status = load_status; + this.load_task = None; + cx.notify(); + }) + .ok(); + } + }) + } +} + +impl ApiKey { + pub fn key(&self) -> &str { + &self.key + } + + pub fn from_env(env_var_name: SharedString, key: &str) -> Self { + Self { + source: ApiKeySource::EnvVar(env_var_name), + key: key.into(), + } + } + + pub async fn load_from_system_keychain( + url: &str, + credentials_provider: &dyn CredentialsProvider, + cx: &AsyncApp, + ) -> Result { + Self::load_from_system_keychain_impl(url, credentials_provider, cx) + .await + .into_authenticate_result() + } + + async fn load_from_system_keychain_impl( + url: &str, + credentials_provider: &dyn CredentialsProvider, + cx: &AsyncApp, + ) -> LoadStatus { + if url.is_empty() { + return LoadStatus::NotPresent; + } + let read_result = credentials_provider.read_credentials(&url, cx).await; + let api_key = match read_result { + Ok(Some((_, api_key))) => api_key, + Ok(None) => return LoadStatus::NotPresent, + Err(err) => return LoadStatus::Error(err.to_string()), + }; + let key = match str::from_utf8(&api_key) { + Ok(key) => key, + Err(_) => return LoadStatus::Error(format!("API key for URL {url} is not utf8")), + }; + LoadStatus::Loaded(Self { + source: ApiKeySource::SystemKeychain, + key: key.into(), + }) + } +} + +impl LoadStatus { + fn into_authenticate_result(self) -> Result { + match self { + LoadStatus::Loaded(api_key) => Ok(api_key), + LoadStatus::NotPresent => Err(AuthenticateError::CredentialsNotFound), + LoadStatus::Error(err) => Err(AuthenticateError::Other(anyhow!(err))), + } + } +} + +#[derive(Debug, Clone)] +enum ApiKeySource { + EnvVar(SharedString), + SystemKeychain, +} + +impl Display for ApiKeySource { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + ApiKeySource::EnvVar(var) => write!(f, "environment variable {}", var), + ApiKeySource::SystemKeychain => write!(f, "system keychain"), + } + } +} diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index 738b72b0c9a6dbb7c9606cc72707b27e66abf09c..61e1a794695310421397469515a43a4d5bf5deb8 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -7,6 +7,7 @@ use gpui::{App, Context, Entity}; use language_model::{LanguageModelProviderId, LanguageModelRegistry}; use provider::deepseek::DeepSeekLanguageModelProvider; +mod api_key; pub mod provider; mod settings; pub mod ui; diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index ca7763e2c5cda3c07c5cb51389cb3173a55865e2..f9c664182426dfd9523254867ba2f95b8f3dc7c4 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -1,18 +1,15 @@ -use crate::AllLanguageModelSettings; use crate::ui::InstructionListItem; +use crate::{AllLanguageModelSettings, api_key::ApiKeyState}; use anthropic::{ AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, ToolResultContent, ToolResultPart, Usage, }; -use anyhow::{Context as _, Result, anyhow}; +use anyhow::Result; use collections::{BTreeMap, HashMap}; -use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::Stream; use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; -use gpui::{ - AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, -}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, FontStyle, Task, TextStyle, WhiteSpace}; use http_client::HttpClient; use language_model::{ AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, @@ -27,11 +24,12 @@ use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::{Icon, IconName, List, Tooltip, prelude::*}; use util::ResultExt; +use zed_env_vars::{EnvVar, env_var}; const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME; @@ -97,91 +95,52 @@ pub struct AnthropicLanguageModelProvider { state: gpui::Entity, } -const ANTHROPIC_API_KEY_VAR: &str = "ANTHROPIC_API_KEY"; +const API_KEY_ENV_VAR_NAME: &str = "ANTHROPIC_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } impl State { - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .anthropic - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .ok(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) - } - - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .anthropic - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .ok(); - - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) - } - fn is_authenticated(&self) -> bool { - self.api_key.is_some() + self.api_key_state.has_key() } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let key = AnthropicLanguageModelProvider::api_key(cx); - - cx.spawn(async move |this, cx| { - let key = key.await?; - - this.update(cx, |this, cx| { - this.api_key = Some(key.key); - this.api_key_from_env = key.from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } -} -pub struct ApiKey { - pub key: String, - pub from_env: bool, + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) + } } impl AnthropicLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -197,30 +156,16 @@ impl AnthropicLanguageModelProvider { }) } - pub fn api_key(cx: &mut App) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .anthropic - .api_url - .clone(); - - if let Ok(key) = std::env::var(ANTHROPIC_API_KEY_VAR) { - Task::ready(Ok(ApiKey { - key, - from_env: true, - })) + fn settings(cx: &App) -> &AnthropicSettings { + &AllLanguageModelSettings::get_global(cx).anthropic + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + anthropic::ANTHROPIC_API_URL.into() } else { - cx.spawn(async move |cx| { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - - Ok(ApiKey { - key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - from_env: false, - }) - }) + SharedString::new(api_url.as_str()) } } } @@ -327,7 +272,8 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -417,11 +363,16 @@ impl AnthropicModel { > { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).anthropic; - (state.api_key.clone(), settings.api_url.clone()) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed(); + let api_key_and_url = self.state.read_with(cx, |state, cx| { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + let api_key = state.api_key_state.key(&api_url); + (api_key, api_url) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(err.into())).boxed(); + } }; let beta_headers = self.model.beta_headers(); @@ -483,7 +434,10 @@ impl LanguageModel for AnthropicModel { } fn api_key(&self, cx: &App) -> Option { - self.state.read(cx).api_key.clone() + self.state.read_with(cx, |state, cx| { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + state.api_key_state.key(&api_url).map(|key| key.to_string()) + }) } fn max_token_count(&self) -> u64 { @@ -987,12 +941,10 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -1001,11 +953,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { @@ -1040,7 +992,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials...")).into_any() @@ -1079,7 +1031,7 @@ impl Render for ConfigurationView { ) .child( Label::new( - format!("You can also assign the {ANTHROPIC_API_KEY_VAR} environment variable and restart Zed."), + format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."), ) .size(LabelSize::Small) .color(Color::Muted), @@ -1099,7 +1051,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {ANTHROPIC_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.") } else { "API key configured.".to_string() })), @@ -1112,7 +1064,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .disabled(env_var_set) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {ANTHROPIC_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 82bf067cd475fe031630767da9e4302afa4d78ec..e00d8bbf4b34560bdc55c982114bf96675e22d99 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -1,12 +1,11 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use collections::{BTreeMap, HashMap}; -use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::Stream; use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{ - AnyView, AppContext as _, AsyncApp, Entity, FontStyle, Subscription, Task, TextStyle, - WhiteSpace, + AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace, + Window, }; use http_client::HttpClient; use language_model::{ @@ -21,16 +20,19 @@ use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use theme::ThemeSettings; use ui::{Icon, IconName, List, prelude::*}; use util::ResultExt; +use zed_env_vars::{EnvVar, env_var}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +use crate::{AllLanguageModelSettings, api_key::ApiKeyState, ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek"); -const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY"; + +const API_KEY_ENV_VAR_NAME: &str = "DEEPSEEK_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); #[derive(Default)] struct RawToolCall { @@ -59,95 +61,48 @@ pub struct DeepSeekLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() - } - - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .deepseek - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + self.api_key_state.has_key() } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .deepseek - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await?; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = SharedString::new(DeepSeekLanguageModelProvider::api_url(cx)); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .deepseek - .api_url - .clone(); - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(DEEPSEEK_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = SharedString::new(DeepSeekLanguageModelProvider::api_url(cx)); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl DeepSeekLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_this: &mut State, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = SharedString::new(Self::api_url(cx)); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(SharedString::new(Self::api_url(cx))), + } }); Self { http_client, state } @@ -160,7 +115,15 @@ impl DeepSeekLanguageModelProvider { state: self.state.clone(), http_client: self.http_client.clone(), request_limiter: RateLimiter::new(4), - }) as Arc + }) + } + + fn settings(cx: &App) -> &DeepSeekSettings { + &AllLanguageModelSettings::get_global(cx).deepseek + } + + fn api_url(cx: &App) -> &str { + &Self::settings(cx).api_url } } @@ -199,11 +162,7 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider { models.insert("deepseek-chat", deepseek::Model::Chat); models.insert("deepseek-reasoner", deepseek::Model::Reasoner); - for available_model in AllLanguageModelSettings::get_global(cx) - .deepseek - .available_models - .iter() - { + for available_model in &Self::settings(cx).available_models { models.insert( &available_model.name, deepseek::Model::Custom { @@ -240,7 +199,8 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -259,15 +219,25 @@ impl DeepSeekLanguageModel { cx: &AsyncApp, ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).deepseek; - (state.api_key.clone(), settings.api_url.clone()) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + + let api_key_and_url = self.state.read_with(cx, |state, cx| { + let api_url = DeepSeekLanguageModelProvider::api_url(cx); + let api_key = state.api_key_state.key(api_url); + (api_key, api_url.to_string()) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(err)).boxed(); + } }; let future = self.request_limiter.stream(async move { - let api_key = api_key.context("Missing DeepSeek API Key")?; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + }); + }; let request = deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request); let response = request.await?; @@ -610,7 +580,7 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context) { - let api_key = self.api_key_editor.read(cx).text(cx); + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); if api_key.is_empty() { return; } @@ -618,12 +588,10 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn(async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -631,10 +599,12 @@ impl ConfigurationView { .update(cx, |editor, cx| editor.set_text("", window, cx)); let state = self.state.clone(); - cx.spawn(async move |_, cx| state.update(cx, |state, cx| state.reset_api_key(cx))?.await) - .detach_and_log_err(cx); - - cx.notify(); + cx.spawn(async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await + }) + .detach_and_log_err(cx); } fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { @@ -672,7 +642,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials...")).into_any() @@ -706,8 +676,7 @@ impl Render for ConfigurationView { ) .child( Label::new(format!( - "Or set the {} environment variable.", - DEEPSEEK_API_KEY_VAR + "Or set the {API_KEY_ENV_VAR_NAME} environment variable." )) .size(LabelSize::Small) .color(Color::Muted), @@ -727,7 +696,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {}", DEEPSEEK_API_KEY_VAR) + format!("API key set in {API_KEY_ENV_VAR_NAME}") } else { "API key configured".to_string() })), diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 939cf0ca60d92d713b90a5d62e8ec7f6dac7ec46..677d5775b3854245ad441acf37b0e1ed02f6bab5 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -1,4 +1,4 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Context as _, Result}; use collections::BTreeMap; use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; @@ -8,7 +8,8 @@ use google_ai::{ ThinkingConfig, UsageMetadata, }; use gpui::{ - AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, + AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace, + Window, }; use http_client::HttpClient; use language_model::{ @@ -26,18 +27,18 @@ use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::sync::{ - Arc, + Arc, LazyLock, atomic::{self, AtomicU64}, }; use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::{Icon, IconName, List, Tooltip, prelude::*}; use util::ResultExt; +use zed_env_vars::EnvVar; -use crate::AllLanguageModelSettings; +use crate::api_key::ApiKey; use crate::ui::InstructionListItem; - -use super::anthropic::ApiKey; +use crate::{AllLanguageModelSettings, api_key::ApiKeyState}; const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME; @@ -91,101 +92,56 @@ pub struct GoogleLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } -const GEMINI_API_KEY_VAR: &str = "GEMINI_API_KEY"; -const GOOGLE_AI_API_KEY_VAR: &str = "GOOGLE_AI_API_KEY"; +const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY"; +const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY"; + +static API_KEY_ENV_VAR: LazyLock = LazyLock::new(|| { + // Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY + EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into())) +}); impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() + self.api_key_state.has_key() } - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .google - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = GoogleLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .google - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await?; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) - } - - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .google - .api_url - .clone(); - - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) { - (api_key, true) - } else if let Ok(api_key) = std::env::var(GEMINI_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = GoogleLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl GoogleLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -201,30 +157,28 @@ impl GoogleLanguageModelProvider { }) } - pub fn api_key(cx: &mut App) -> Task> { + pub fn api_key_for_gemini_cli(cx: &mut App) -> Task> { + if let Some(key) = API_KEY_ENV_VAR.value.clone() { + return Task::ready(Ok(key)); + } let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .google - .api_url - .clone(); - - if let Ok(key) = std::env::var(GEMINI_API_KEY_VAR) { - Task::ready(Ok(ApiKey { - key, - from_env: true, - })) - } else { - cx.spawn(async move |cx| { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) + let api_url = Self::api_url(cx).to_string(); + cx.spawn(async move |cx| { + Ok( + ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx) .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; + .key() + .to_string(), + ) + }) + } - Ok(ApiKey { - key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - from_env: false, - }) - }) + fn api_url(cx: &App) -> SharedString { + let api_url = &AllLanguageModelSettings::get_global(cx).google.api_url; + if api_url.is_empty() { + google_ai::API_URL.into() + } else { + SharedString::new(api_url.as_str()) } } } @@ -317,7 +271,8 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -340,11 +295,16 @@ impl GoogleLanguageModel { > { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).google; - (state.api_key.clone(), settings.api_url.clone()) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + let api_key_and_url = self.state.read_with(cx, |state, cx| { + let api_url = GoogleLanguageModelProvider::api_url(cx); + let api_key = state.api_key_state.key(&api_url); + (api_key, api_url) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(err)).boxed(); + } }; async move { @@ -418,13 +378,16 @@ impl LanguageModel for GoogleLanguageModel { let model_id = self.model.request_id().to_string(); let request = into_google(request, model_id, self.model.mode()); let http_client = self.http_client.clone(); - let api_key = self.state.read(cx).api_key.clone(); - - let settings = &AllLanguageModelSettings::get_global(cx).google; - let api_url = settings.api_url.clone(); + let api_url = GoogleLanguageModelProvider::api_url(cx); + let api_key = self.state.read(cx).api_key_state.key(&api_url); async move { - let api_key = api_key.context("Missing Google API key")?; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + } + .into()); + }; let response = google_ai::count_tokens( http_client.as_ref(), &api_url, @@ -852,7 +815,7 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self.api_key_editor.read(cx).text(cx); + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); if api_key.is_empty() { return; } @@ -860,12 +823,10 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -874,11 +835,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { @@ -913,7 +874,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials...")).into_any() @@ -950,7 +911,7 @@ impl Render for ConfigurationView { ) .child( Label::new( - format!("You can also assign the {GEMINI_API_KEY_VAR} environment variable and restart Zed."), + format!("You can also assign the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."), ) .size(LabelSize::Small).color(Color::Muted), ) @@ -969,7 +930,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {GEMINI_API_KEY_VAR} environment variable.") + format!("API key set in {GEMINI_API_KEY_VAR_NAME} environment variable.") } else { "API key configured.".to_string() })), @@ -982,7 +943,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .disabled(env_var_set) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR} and {GOOGLE_AI_API_KEY_VAR} environment variables are unset."))) + this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index c9824bf89ea7a919f4517f492a5091a2cda7b43b..d2de056489d5fe10bab5eaa41780c0a5f312a181 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -1,10 +1,10 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use collections::BTreeMap; -use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{ - AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, + AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace, + Window, }; use http_client::HttpClient; use language_model::{ @@ -21,17 +21,21 @@ use settings::{Settings, SettingsStore}; use std::collections::HashMap; use std::pin::Pin; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::{Icon, IconName, List, Tooltip, prelude::*}; use util::ResultExt; +use zed_env_vars::{EnvVar, env_var}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +use crate::{api_key::ApiKeyState, ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral"); +const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); + #[derive(Default, Clone, Debug, PartialEq)] pub struct MistralSettings { pub api_url: String, @@ -56,96 +60,48 @@ pub struct MistralLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } -const MISTRAL_API_KEY_VAR: &str = "MISTRAL_API_KEY"; - impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() - } - - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .mistral - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + self.api_key_state.has_key() } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .mistral - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await?; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = MistralLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .mistral - .api_url - .clone(); - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(MISTRAL_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = MistralLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl MistralLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_this: &mut State, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -160,6 +116,19 @@ impl MistralLanguageModelProvider { request_limiter: RateLimiter::new(4), }) } + + fn settings(cx: &App) -> &MistralSettings { + &crate::AllLanguageModelSettings::get_global(cx).mistral + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + mistral::MISTRAL_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } } impl LanguageModelProviderState for MistralLanguageModelProvider { @@ -202,10 +171,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider { } // Override with available models from settings - for model in &AllLanguageModelSettings::get_global(cx) - .mistral - .available_models - { + for model in &Self::settings(cx).available_models { models.insert( model.name.clone(), mistral::Model::Custom { @@ -254,7 +220,8 @@ impl LanguageModelProvider for MistralLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -276,15 +243,25 @@ impl MistralLanguageModel { Result>>, > { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).mistral; - (state.api_key.clone(), settings.api_url.clone()) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + + let api_key_and_url = self.state.read_with(cx, |state, cx| { + let api_url = MistralLanguageModelProvider::api_url(cx); + let api_key = state.api_key_state.key(&api_url); + (api_key, api_url) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(err)).boxed(); + } }; let future = self.request_limiter.stream(async move { - let api_key = api_key.context("Missing Mistral API Key")?; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + }); + }; let request = mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request); let response = request.await?; @@ -780,7 +757,7 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self.api_key_editor.read(cx).text(cx); + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); if api_key.is_empty() { return; } @@ -788,12 +765,10 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -802,11 +777,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { @@ -841,7 +816,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials...")).into_any() @@ -878,7 +853,7 @@ impl Render for ConfigurationView { ) .child( Label::new( - format!("You can also assign the {MISTRAL_API_KEY_VAR} environment variable and restart Zed."), + format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."), ) .size(LabelSize::Small).color(Color::Muted), ) @@ -897,7 +872,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {MISTRAL_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.") } else { "API key configured.".to_string() })), @@ -910,7 +885,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .disabled(env_var_set) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {MISTRAL_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index fca1cf977cb5e3b32dc6f2335fb0d9188979bc9f..a7e8f64d6614d95e4e892aba569d441320108ca4 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -1,10 +1,8 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use collections::{BTreeMap, HashMap}; -use credentials_provider::CredentialsProvider; - use futures::Stream; use futures::{FutureExt, StreamExt, future::BoxFuture}; -use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -20,18 +18,21 @@ use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::str::FromStr as _; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; - use ui::{ElevationIndex, List, Tooltip, prelude::*}; use ui_input::SingleLineInput; use util::ResultExt; +use zed_env_vars::{EnvVar, env_var}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +use crate::{AllLanguageModelSettings, api_key::ApiKeyState, ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME; +const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); + #[derive(Default, Clone, Debug, PartialEq)] pub struct OpenAiSettings { pub api_url: String, @@ -54,132 +55,48 @@ pub struct OpenAiLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - last_api_url: String, - _subscription: Subscription, + api_key_state: ApiKeyState, } -const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY"; - impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() - } - - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) - } - - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) + self.api_key_state.has_key() } - fn get_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = OpenAiLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - self.get_api_key(cx) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = OpenAiLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl OpenAiLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let initial_api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - last_api_url: initial_api_url.clone(), - _subscription: cx.observe_global::(|this: &mut State, cx| { - let current_api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - - if this.last_api_url != current_api_url { - this.last_api_url = current_api_url; - if !this.api_key_from_env { - this.api_key = None; - let spawn_task = cx.spawn(async move |handle, cx| { - if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) { - if let Err(_) = task.await { - handle - .update(cx, |this, _| { - this.api_key = None; - this.api_key_from_env = false; - }) - .ok(); - } - } - }); - spawn_task.detach(); - } - } + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -194,6 +111,19 @@ impl OpenAiLanguageModelProvider { request_limiter: RateLimiter::new(4), }) } + + fn settings(cx: &App) -> &OpenAiSettings { + &AllLanguageModelSettings::get_global(cx).openai + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + open_ai::OPEN_AI_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } } impl LanguageModelProviderState for OpenAiLanguageModelProvider { @@ -278,7 +208,8 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -298,11 +229,17 @@ impl OpenAiLanguageModel { ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).openai; - (state.api_key.clone(), settings.api_url.clone()) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + + let api_key_and_url = self.state.read_with(cx, |state, cx| { + let api_url = OpenAiLanguageModelProvider::api_url(cx); + let api_key = state.api_key_state.key(&api_url); + (api_key, api_url) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(err)).boxed(); + } }; let future = self.request_limiter.stream(async move { @@ -802,29 +739,18 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self - .api_key_editor - .read(cx) - .editor() - .read(cx) - .text(cx) - .trim() - .to_string(); - - // Don't proceed if no API key is provided and we're not authenticated - if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { return; } let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -836,11 +762,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn should_render_editor(&self, cx: &mut Context) -> bool { @@ -850,7 +776,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); let api_key_section = if self.should_render_editor(cx) { v_flex() @@ -872,10 +798,11 @@ impl Render for ConfigurationView { ) .child(self.api_key_editor.clone()) .child( - Label::new( - format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."), - ) - .size(LabelSize::Small).color(Color::Muted), + Label::new(format!( + "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed." + )) + .size(LabelSize::Small) + .color(Color::Muted), ) .child( Label::new( @@ -898,7 +825,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {OPENAI_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.") } else { "API key configured.".to_string() })), @@ -911,7 +838,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .layer(ElevationIndex::ModalSurface) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index 4ebb11a07b66ec7054ca65437ec887a415fa3f5c..2b5e372a1d6ccd9b3bb13feb777f1091f0336b58 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -1,9 +1,7 @@ -use anyhow::{Context as _, Result, anyhow}; -use credentials_provider::CredentialsProvider; - +use anyhow::Result; use convert_case::{Case, Casing}; use futures::{FutureExt, StreamExt, future::BoxFuture}; -use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -17,13 +15,13 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::sync::Arc; - use ui::{ElevationIndex, Tooltip, prelude::*}; use ui_input::SingleLineInput; use util::ResultExt; +use zed_env_vars::EnvVar; -use crate::AllLanguageModelSettings; use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai}; +use crate::{AllLanguageModelSettings, api_key::ApiKeyState}; #[derive(Default, Clone, Debug, PartialEq)] pub struct OpenAiCompatibleSettings { @@ -70,82 +68,30 @@ pub struct OpenAiCompatibleLanguageModelProvider { pub struct State { id: Arc, - env_var_name: Arc, - api_key: Option, - api_key_from_env: bool, + api_key_env_var: EnvVar, + api_key_state: ApiKeyState, settings: OpenAiCompatibleSettings, - _subscription: Subscription, } impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() - } - - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = self.settings.api_url.clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) - } - - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = self.settings.api_url.clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) + self.api_key_state.has_key() } - fn get_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let env_var_name = self.env_var_name.clone(); - let api_url = self.settings.api_url.clone(); - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(env_var_name.as_ref()) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = SharedString::new(self.settings.api_url.as_str()); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - self.get_api_key(cx) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = SharedString::new(self.settings.api_url.clone()); + self.api_key_state.load_if_needed( + api_url, + &self.api_key_env_var, + |this| &mut this.api_key_state, + cx, + ) } } @@ -157,37 +103,32 @@ impl OpenAiCompatibleLanguageModelProvider { .get(id) } - let state = cx.new(|cx| State { - id: id.clone(), - env_var_name: format!("{}_API_KEY", id).to_case(Case::Constant).into(), - settings: resolve_settings(&id, cx).cloned().unwrap_or_default(), - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|this: &mut State, cx| { + let api_key_env_var_name = format!("{}_API_KEY", id).to_case(Case::UpperSnake).into(); + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { let Some(settings) = resolve_settings(&this.id, cx).cloned() else { return; }; if &this.settings != &settings { - if settings.api_url != this.settings.api_url && !this.api_key_from_env { - let spawn_task = cx.spawn(async move |handle, cx| { - if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) { - if let Err(_) = task.await { - handle - .update(cx, |this, _| { - this.api_key = None; - this.api_key_from_env = false; - }) - .ok(); - } - } - }); - spawn_task.detach(); - } - + let api_url = SharedString::new(settings.api_url.as_str()); + this.api_key_state.handle_url_change( + api_url, + &this.api_key_env_var, + |this| &mut this.api_key_state, + cx, + ); this.settings = settings; cx.notify(); } - }), + }) + .detach(); + let settings = resolve_settings(&id, cx).cloned().unwrap_or_default(); + State { + id: id.clone(), + api_key_env_var: EnvVar::new(api_key_env_var_name), + api_key_state: ApiKeyState::new(SharedString::new(settings.api_url.as_str())), + settings, + } }); Self { @@ -274,7 +215,8 @@ impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -296,10 +238,17 @@ impl OpenAiCompatibleLanguageModel { ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, _| { - (state.api_key.clone(), state.settings.api_url.clone()) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + + let api_key_and_url = self.state.read_with(cx, |state, _cx| { + let api_url = &state.settings.api_url; + let api_key = state.api_key_state.key(api_url); + (api_key, state.settings.api_url.clone()) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(err)).boxed(); + } }; let provider = self.provider_name.clone(); @@ -469,29 +418,18 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self - .api_key_editor - .read(cx) - .editor() - .read(cx) - .text(cx) - .trim() - .to_string(); - - // Don't proceed if no API key is provided and we're not authenticated - if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { return; } let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -503,22 +441,23 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } - fn should_render_editor(&self, cx: &mut Context) -> bool { + fn should_render_editor(&self, cx: &Context) -> bool { !self.state.read(cx).is_authenticated() } } impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; - let env_var_name = self.state.read(cx).env_var_name.clone(); + let state = self.state.read(cx); + let env_var_set = state.api_key_state.is_from_env_var(); + let env_var_name = &state.api_key_env_var.name; let api_key_section = if self.should_render_editor(cx) { v_flex() diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 0ebf379b393791558cba6ff36ab31d278162386e..efa6a1eabe33fe8ecb52da3b5c61122cfc26e036 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -1,10 +1,9 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use collections::HashMap; -use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::{FutureExt, Stream, StreamExt, future::BoxFuture}; use gpui::{ - AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, + AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace, }; use http_client::HttpClient; use language_model::{ @@ -16,23 +15,26 @@ use language_model::{ }; use open_router::{ Model, ModelMode as OpenRouterModelMode, Provider, ResponseStreamEvent, list_models, - stream_completion, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::str::FromStr as _; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use theme::ThemeSettings; use ui::{Icon, IconName, List, Tooltip, prelude::*}; use util::ResultExt; +use zed_env_vars::{EnvVar, env_var}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +use crate::{AllLanguageModelSettings, api_key::ApiKeyState, ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter"); +const API_KEY_ENV_VAR_NAME: &str = "OPENROUTER_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); + #[derive(Default, Clone, Debug, PartialEq)] pub struct OpenRouterSettings { pub api_url: String, @@ -90,93 +92,38 @@ pub struct OpenRouterLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, + api_key_state: ApiKeyState, http_client: Arc, available_models: Vec, fetch_models_task: Option>>, settings: OpenRouterSettings, - _subscription: Subscription, } -const OPENROUTER_API_KEY_VAR: &str = "OPENROUTER_API_KEY"; - impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() - } - - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .open_router - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + self.api_key_state.has_key() } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .open_router - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.restart_fetch_models_task(cx); - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .open_router - .api_url - .clone(); + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + let task = self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - this.restart_fetch_models_task(cx); - cx.notify(); - })?; - - Ok(()) + let result = task.await; + this.update(cx, |this, cx| this.restart_fetch_models_task(cx)) + .ok(); + result }) } @@ -184,10 +131,9 @@ impl State { &mut self, cx: &mut Context, ) -> Task> { - let settings = &AllLanguageModelSettings::get_global(cx).open_router; let http_client = self.http_client.clone(); - let api_url = settings.api_url.clone(); - let Some(api_key) = self.api_key.clone() else { + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + let Some(api_key) = self.api_key_state.key(&api_url) else { return Task::ready(Err(LanguageModelCompletionError::NoApiKey { provider: PROVIDER_NAME, })); @@ -216,33 +162,45 @@ impl State { if self.is_authenticated() { let task = self.fetch_models(cx); self.fetch_models_task.replace(task); + } else { + self.available_models = Vec::new(); } } } impl OpenRouterLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - http_client: http_client.clone(), - available_models: Vec::new(), - fetch_models_task: None, - settings: OpenRouterSettings::default(), - _subscription: cx.observe_global::(|this: &mut State, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { let current_settings = &AllLanguageModelSettings::get_global(cx).open_router; let settings_changed = current_settings != &this.settings; if settings_changed { this.settings = current_settings.clone(); - this.restart_fetch_models_task(cx); + this.authenticate(cx).detach(); } cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + http_client: http_client.clone(), + available_models: Vec::new(), + fetch_models_task: None, + settings: OpenRouterSettings::default(), + } }); Self { http_client, state } } + fn settings(cx: &App) -> &OpenRouterSettings { + &AllLanguageModelSettings::get_global(cx).open_router + } + + fn api_url(cx: &App) -> SharedString { + SharedString::new(Self::settings(cx).api_url.as_str()) + } + fn create_language_model(&self, model: open_router::Model) -> Arc { Arc::new(OpenRouterLanguageModel { id: LanguageModelId::from(model.id().to_string()), @@ -287,10 +245,7 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider { let mut models_from_api = self.state.read(cx).available_models.clone(); let mut settings_models = Vec::new(); - for model in &AllLanguageModelSettings::get_global(cx) - .open_router - .available_models - { + for model in &Self::settings(cx).available_models { settings_models.push(open_router::Model { name: model.name.clone(), display_name: model.display_name.clone(), @@ -338,7 +293,8 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -366,14 +322,17 @@ impl OpenRouterLanguageModel { >, > { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).open_router; - (state.api_key.clone(), settings.api_url.clone()) - }) else { - return futures::future::ready(Err(LanguageModelCompletionError::Other(anyhow!( - "App state dropped" - )))) - .boxed(); + let api_key_and_url = self.state.read_with(cx, |state, cx| { + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + let api_key = state.api_key_state.key(&api_url); + (api_key, api_url) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(LanguageModelCompletionError::Other(err))) + .boxed(); + } }; async move { @@ -382,7 +341,8 @@ impl OpenRouterLanguageModel { provider: PROVIDER_NAME, }); }; - let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let request = + open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request); request.await.map_err(Into::into) } .boxed() @@ -830,7 +790,7 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self.api_key_editor.read(cx).text(cx); + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); if api_key.is_empty() { return; } @@ -838,12 +798,10 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -852,11 +810,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { @@ -891,7 +849,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials...")).into_any() @@ -928,7 +886,7 @@ impl Render for ConfigurationView { ) .child( Label::new( - format!("You can also assign the {OPENROUTER_API_KEY_VAR} environment variable and restart Zed."), + format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."), ) .size(LabelSize::Small).color(Color::Muted), ) @@ -947,7 +905,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {OPENROUTER_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.") } else { "API key configured.".to_string() })), @@ -960,7 +918,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .disabled(env_var_set) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENROUTER_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs index 84f3175d1e5493fd55cafd2ea9c4a0604d2a97b4..ad28946e3bc78078d9d6510fafacbba34c1f98ca 100644 --- a/crates/language_models/src/provider/vercel.rs +++ b/crates/language_models/src/provider/vercel.rs @@ -1,8 +1,7 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::Result; use collections::BTreeMap; -use credentials_provider::CredentialsProvider; use futures::{FutureExt, StreamExt, future::BoxFuture}; -use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -10,24 +9,26 @@ use language_model::{ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, RateLimiter, Role, }; -use menu; use open_ai::ResponseStreamEvent; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; -use vercel::Model; - use ui::{ElevationIndex, List, Tooltip, prelude::*}; use ui_input::SingleLineInput; use util::ResultExt; +use vercel::Model; +use zed_env_vars::{EnvVar, env_var}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +use crate::{api_key::ApiKeyState, ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Vercel"); +const API_KEY_ENV_VAR_NAME: &str = "VERCEL_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); + #[derive(Default, Clone, Debug, PartialEq)] pub struct VercelSettings { pub api_url: String, @@ -49,103 +50,48 @@ pub struct VercelLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } -const VERCEL_API_KEY_VAR: &str = "VERCEL_API_KEY"; - impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() + self.api_key_state.has_key() } - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).vercel; - let api_url = if settings.api_url.is_empty() { - vercel::VERCEL_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = VercelLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).vercel; - let api_url = if settings.api_url.is_empty() { - vercel::VERCEL_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) - } - - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).vercel; - let api_url = if settings.api_url.is_empty() { - vercel::VERCEL_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(VERCEL_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = VercelLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl VercelLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_this: &mut State, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -160,6 +106,19 @@ impl VercelLanguageModelProvider { request_limiter: RateLimiter::new(4), }) } + + fn settings(cx: &App) -> &VercelSettings { + &crate::AllLanguageModelSettings::get_global(cx).vercel + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + vercel::VERCEL_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } } impl LanguageModelProviderState for VercelLanguageModelProvider { @@ -200,10 +159,7 @@ impl LanguageModelProvider for VercelLanguageModelProvider { } } - for model in &AllLanguageModelSettings::get_global(cx) - .vercel - .available_models - { + for model in &Self::settings(cx).available_models { models.insert( model.name.clone(), vercel::Model::Custom { @@ -241,7 +197,8 @@ impl LanguageModelProvider for VercelLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -261,16 +218,17 @@ impl VercelLanguageModel { ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).vercel; - let api_url = if settings.api_url.is_empty() { - vercel::VERCEL_API_URL.to_string() - } else { - settings.api_url.clone() - }; - (state.api_key.clone(), api_url) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + + let api_key_and_url = self.state.read_with(cx, |state, cx| { + let api_url = VercelLanguageModelProvider::api_url(cx); + let api_key = state.api_key_state.key(&api_url); + (api_key, api_url) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(err)).boxed(); + } }; let future = self.request_limiter.stream(async move { @@ -466,29 +424,18 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self - .api_key_editor - .read(cx) - .editor() - .read(cx) - .text(cx) - .trim() - .to_string(); - - // Don't proceed if no API key is provided and we're not authenticated - if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { return; } let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -500,11 +447,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn should_render_editor(&self, cx: &mut Context) -> bool { @@ -514,7 +461,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); let api_key_section = if self.should_render_editor(cx) { v_flex() @@ -534,7 +481,7 @@ impl Render for ConfigurationView { .child(self.api_key_editor.clone()) .child( Label::new(format!( - "You can also assign the {VERCEL_API_KEY_VAR} environment variable and restart Zed." + "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed." )) .size(LabelSize::Small) .color(Color::Muted), @@ -559,7 +506,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {VERCEL_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.") } else { "API key configured.".to_string() })), @@ -572,7 +519,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .layer(ElevationIndex::ModalSurface) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {VERCEL_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index bb17f22c7f3fdbb0296b1e0bb290fbce9a979ddf..d5c5848293da5170b609bfd1526c26429865184c 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -1,8 +1,7 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::Result; use collections::BTreeMap; -use credentials_provider::CredentialsProvider; use futures::{FutureExt, StreamExt, future::BoxFuture}; -use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, Window}; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -10,23 +9,25 @@ use language_model::{ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, Role, }; -use menu; use open_ai::ResponseStreamEvent; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; -use x_ai::Model; - use ui::{ElevationIndex, List, Tooltip, prelude::*}; use ui_input::SingleLineInput; use util::ResultExt; +use x_ai::{Model, XAI_API_URL}; +use zed_env_vars::{EnvVar, env_var}; + +use crate::{api_key::ApiKeyState, ui::InstructionListItem}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai"); +const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI"); -const PROVIDER_ID: &str = "x_ai"; -const PROVIDER_NAME: &str = "xAI"; +const API_KEY_ENV_VAR_NAME: &str = "XAI_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); #[derive(Default, Clone, Debug, PartialEq)] pub struct XAiSettings { @@ -49,103 +50,48 @@ pub struct XAiLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } -const XAI_API_KEY_VAR: &str = "XAI_API_KEY"; - impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() + self.api_key_state.has_key() } - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).x_ai; - let api_url = if settings.api_url.is_empty() { - x_ai::XAI_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = XAiLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).x_ai; - let api_url = if settings.api_url.is_empty() { - x_ai::XAI_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) - } - - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).x_ai; - let api_url = if settings.api_url.is_empty() { - x_ai::XAI_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(XAI_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = XAiLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl XAiLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_this: &mut State, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -160,6 +106,19 @@ impl XAiLanguageModelProvider { request_limiter: RateLimiter::new(4), }) } + + fn settings(cx: &App) -> &XAiSettings { + &crate::AllLanguageModelSettings::get_global(cx).x_ai + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + XAI_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } } impl LanguageModelProviderState for XAiLanguageModelProvider { @@ -172,11 +131,11 @@ impl LanguageModelProviderState for XAiLanguageModelProvider { impl LanguageModelProvider for XAiLanguageModelProvider { fn id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(PROVIDER_ID.into()) + PROVIDER_ID } fn name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn icon(&self) -> IconName { @@ -200,10 +159,7 @@ impl LanguageModelProvider for XAiLanguageModelProvider { } } - for model in &AllLanguageModelSettings::get_global(cx) - .x_ai - .available_models - { + for model in &Self::settings(cx).available_models { models.insert( model.name.clone(), x_ai::Model::Custom { @@ -241,7 +197,8 @@ impl LanguageModelProvider for XAiLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -261,20 +218,25 @@ impl XAiLanguageModel { ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).x_ai; - let api_url = if settings.api_url.is_empty() { - x_ai::XAI_API_URL.to_string() - } else { - settings.api_url.clone() - }; - (state.api_key.clone(), api_url) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + + let api_key_and_url = self.state.read_with(cx, |state, cx| { + let api_url = XAiLanguageModelProvider::api_url(cx); + let api_key = state.api_key_state.key(&api_url); + (api_key, api_url) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(err)).boxed(); + } }; let future = self.request_limiter.stream(async move { - let api_key = api_key.context("Missing xAI API Key")?; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + }); + }; let request = open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request); let response = request.await?; @@ -295,11 +257,11 @@ impl LanguageModel for XAiLanguageModel { } fn provider_id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(PROVIDER_ID.into()) + PROVIDER_ID } fn provider_name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn supports_tools(&self) -> bool { @@ -456,29 +418,18 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self - .api_key_editor - .read(cx) - .editor() - .read(cx) - .text(cx) - .trim() - .to_string(); - - // Don't proceed if no API key is provided and we're not authenticated - if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { return; } let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -490,11 +441,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn should_render_editor(&self, cx: &mut Context) -> bool { @@ -504,7 +455,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); let api_key_section = if self.should_render_editor(cx) { v_flex() @@ -524,7 +475,7 @@ impl Render for ConfigurationView { .child(self.api_key_editor.clone()) .child( Label::new(format!( - "You can also assign the {XAI_API_KEY_VAR} environment variable and restart Zed." + "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed." )) .size(LabelSize::Small) .color(Color::Muted), @@ -549,7 +500,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {XAI_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.") } else { "API key configured.".to_string() })), @@ -562,7 +513,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .layer(ElevationIndex::ModalSurface) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {XAI_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/zed_env_vars/Cargo.toml b/crates/zed_env_vars/Cargo.toml index 9abfc410e7e74774c4e9e7608e8c1c3824ebc3c1..f56e3dd529cc7a8001d0021e96902f55034f88e2 100644 --- a/crates/zed_env_vars/Cargo.toml +++ b/crates/zed_env_vars/Cargo.toml @@ -16,3 +16,4 @@ default = [] [dependencies] workspace-hack.workspace = true +gpui.workspace = true diff --git a/crates/zed_env_vars/src/zed_env_vars.rs b/crates/zed_env_vars/src/zed_env_vars.rs index d1679a0518f2bae857364b0035b6184350ffca55..53b9c22bb207e81831d1d9ae6087d1a297331d3f 100644 --- a/crates/zed_env_vars/src/zed_env_vars.rs +++ b/crates/zed_env_vars/src/zed_env_vars.rs @@ -1,6 +1,44 @@ +use gpui::SharedString; use std::sync::LazyLock; /// Whether Zed is running in stateless mode. /// When true, Zed will use in-memory databases instead of persistent storage. -pub static ZED_STATELESS: LazyLock = - LazyLock::new(|| std::env::var("ZED_STATELESS").is_ok_and(|v| !v.is_empty())); +pub static ZED_STATELESS: LazyLock = bool_env_var!("ZED_STATELESS"); + +pub struct EnvVar { + pub name: SharedString, + /// Value of the environment variable. Also `None` when set to an empty string. + pub value: Option, +} + +impl EnvVar { + pub fn new(name: SharedString) -> Self { + let value = std::env::var(name.as_str()).ok(); + if value.as_ref().is_some_and(|v| v.is_empty()) { + Self { name, value: None } + } else { + Self { name, value } + } + } + + pub fn or(self, other: EnvVar) -> EnvVar { + if self.value.is_some() { self } else { other } + } +} + +/// Creates a `LazyLock` expression for use in a `static` declaration. +#[macro_export] +macro_rules! env_var { + ($name:expr) => { + LazyLock::new(|| $crate::EnvVar::new(($name).into())) + }; +} + +/// Generates a `LazyLock` expression for use in a `static` declaration. Checks if the +/// environment variable exists and is non-empty. +#[macro_export] +macro_rules! bool_env_var { + ($name:expr) => { + LazyLock::new(|| $crate::EnvVar::new(($name).into()).value.is_some()) + }; +}