diff --git a/Cargo.lock b/Cargo.lock index db61ba7b121773c6e6941dd940201ec0a9e82f73..22589ee11a4ffb657238091ab85a3e76d9b6bf32 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9212,6 +9212,7 @@ dependencies = [ "vercel", "workspace-hack", "x_ai", + "zed_env_vars", ] [[package]] @@ -20677,6 +20678,7 @@ dependencies = [ name = "zed_env_vars" version = "0.1.0" dependencies = [ + "gpui", "workspace-hack", ] diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index 01f15557899e1c7826e91d1555320996eccd0f45..6a668785943239f28b4bc9aafce9d37fdfa386d2 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -44,8 +44,12 @@ impl AgentServer for Gemini { cx.spawn(async move |cx| { let mut extra_env = HashMap::default(); - if let Some(api_key) = cx.update(GoogleLanguageModelProvider::api_key)?.await.ok() { - extra_env.insert("GEMINI_API_KEY".into(), api_key.key); + if let Some(api_key) = cx + .update(GoogleLanguageModelProvider::api_key_for_gemini_cli)? + .await + .ok() + { + extra_env.insert("GEMINI_API_KEY".into(), api_key); } let (mut command, root_dir, login) = store .update(cx, |store, cx| { diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index b5bfb870f643452bd5be248c9910d99f16a8101e..7dc0988d23579c4d1ab1ac2dde1f1413c5e751b8 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -61,6 +61,7 @@ util.workspace = true vercel = { workspace = true, features = ["schemars"] } workspace-hack.workspace = true x_ai = { workspace = true, features = ["schemars"] } +zed_env_vars.workspace = true [dev-dependencies] editor = { workspace = true, features = ["test-support"] } diff --git a/crates/language_models/src/api_key.rs b/crates/language_models/src/api_key.rs new file mode 100644 index 0000000000000000000000000000000000000000..122234b6ced6d0bf1b7a0d684683c841824ccd2d --- /dev/null +++ b/crates/language_models/src/api_key.rs @@ -0,0 +1,295 @@ +use anyhow::{Result, anyhow}; +use credentials_provider::CredentialsProvider; +use futures::{FutureExt, future}; +use gpui::{AsyncApp, Context, SharedString, Task}; +use language_model::AuthenticateError; +use std::{ + fmt::{Display, Formatter}, + sync::Arc, +}; +use util::ResultExt as _; +use zed_env_vars::EnvVar; + +/// Manages a single API key for a language model provider. API keys either come from environment +/// variables or the system keychain. +/// +/// Keys from the system keychain are associated with a provider URL, and this ensures that they are +/// only used with that URL. +pub struct ApiKeyState { + url: SharedString, + load_status: LoadStatus, + load_task: Option>>, +} + +#[derive(Debug, Clone)] +pub enum LoadStatus { + NotPresent, + Error(String), + Loaded(ApiKey), +} + +#[derive(Debug, Clone)] +pub struct ApiKey { + source: ApiKeySource, + key: Arc, +} + +impl ApiKeyState { + pub fn new(url: SharedString) -> Self { + Self { + url, + load_status: LoadStatus::NotPresent, + load_task: None, + } + } + + pub fn has_key(&self) -> bool { + matches!(self.load_status, LoadStatus::Loaded { .. }) + } + + pub fn is_from_env_var(&self) -> bool { + match &self.load_status { + LoadStatus::Loaded(ApiKey { + source: ApiKeySource::EnvVar { .. }, + .. + }) => true, + _ => false, + } + } + + /// Get the stored API key, verifying that it is associated with the URL. Returns `None` if + /// there is no key or for URL mismatches, and the mismatch case is logged. + /// + /// To avoid URL mismatches, expects that `load_if_needed` or `handle_url_change` has been + /// called with this URL. + pub fn key(&self, url: &str) -> Option> { + let api_key = match &self.load_status { + LoadStatus::Loaded(api_key) => api_key, + _ => return None, + }; + if url == self.url.as_str() { + Some(api_key.key.clone()) + } else if let ApiKeySource::EnvVar(var_name) = &api_key.source { + log::warn!( + "{} is now being used with URL {}, when initially it was used with URL {}", + var_name, + url, + self.url + ); + Some(api_key.key.clone()) + } else { + // bug case because load_if_needed should be called whenever the url may have changed + log::error!( + "bug: Attempted to use API key associated with URL {} instead with URL {}", + self.url, + url + ); + None + } + } + + /// Set or delete the API key in the system keychain. + pub fn store( + &mut self, + url: SharedString, + key: Option, + get_this: impl Fn(&mut Ent) -> &mut Self + 'static, + cx: &Context, + ) -> Task> { + if self.is_from_env_var() { + return Task::ready(Err(anyhow!( + "bug: attempted to store API key in system keychain when API key is from env var", + ))); + } + let credentials_provider = ::global(cx); + cx.spawn(async move |ent, cx| { + if let Some(key) = &key { + credentials_provider + .write_credentials(&url, "Bearer", key.as_bytes(), cx) + .await + .log_err(); + } else { + credentials_provider + .delete_credentials(&url, cx) + .await + .log_err(); + } + ent.update(cx, |ent, cx| { + let this = get_this(ent); + this.url = url; + this.load_status = match &key { + Some(key) => LoadStatus::Loaded(ApiKey { + source: ApiKeySource::SystemKeychain, + key: key.as_str().into(), + }), + None => LoadStatus::NotPresent, + }; + cx.notify(); + }) + }) + } + + /// Reloads the API key if the current API key is associated with a different URL. + /// + /// Note that it is not efficient to use this or `load_if_needed` with multiple URLs + /// interchangeably - URL change should correspond to some user initiated change. + pub fn handle_url_change( + &mut self, + url: SharedString, + env_var: &EnvVar, + get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static, + cx: &mut Context, + ) { + if url != self.url { + if !self.is_from_env_var() { + // loading will continue even though this result task is dropped + let _task = self.load_if_needed(url, env_var, get_this, cx); + } + } + } + + /// If needed, loads the API key associated with the given URL from the system keychain. When a + /// non-empty environment variable is provided, it will be used instead. If called when an API + /// key was already loaded for a different URL, that key will be cleared before loading. + /// + /// Dropping the returned Task does not cancel key loading. + pub fn load_if_needed( + &mut self, + url: SharedString, + env_var: &EnvVar, + get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static, + cx: &mut Context, + ) -> Task> { + if let LoadStatus::Loaded { .. } = &self.load_status + && self.url == url + { + return Task::ready(Ok(())); + } + + if let Some(key) = &env_var.value + && !key.is_empty() + { + let api_key = ApiKey::from_env(env_var.name.clone(), key); + self.url = url; + self.load_status = LoadStatus::Loaded(api_key); + self.load_task = None; + cx.notify(); + return Task::ready(Ok(())); + } + + let task = if let Some(load_task) = &self.load_task { + load_task.clone() + } else { + let load_task = Self::load(url.clone(), get_this.clone(), cx).shared(); + self.url = url; + self.load_status = LoadStatus::NotPresent; + self.load_task = Some(load_task.clone()); + cx.notify(); + load_task + }; + + cx.spawn(async move |ent, cx| { + task.await; + ent.update(cx, |ent, _cx| { + get_this(ent).load_status.clone().into_authenticate_result() + }) + .ok(); + Ok(()) + }) + } + + fn load( + url: SharedString, + get_this: impl Fn(&mut Ent) -> &mut Self + 'static, + cx: &Context, + ) -> Task<()> { + let credentials_provider = ::global(cx); + cx.spawn({ + async move |ent, cx| { + let load_status = + ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx) + .await; + ent.update(cx, |ent, cx| { + let this = get_this(ent); + this.url = url; + this.load_status = load_status; + this.load_task = None; + cx.notify(); + }) + .ok(); + } + }) + } +} + +impl ApiKey { + pub fn key(&self) -> &str { + &self.key + } + + pub fn from_env(env_var_name: SharedString, key: &str) -> Self { + Self { + source: ApiKeySource::EnvVar(env_var_name), + key: key.into(), + } + } + + pub async fn load_from_system_keychain( + url: &str, + credentials_provider: &dyn CredentialsProvider, + cx: &AsyncApp, + ) -> Result { + Self::load_from_system_keychain_impl(url, credentials_provider, cx) + .await + .into_authenticate_result() + } + + async fn load_from_system_keychain_impl( + url: &str, + credentials_provider: &dyn CredentialsProvider, + cx: &AsyncApp, + ) -> LoadStatus { + if url.is_empty() { + return LoadStatus::NotPresent; + } + let read_result = credentials_provider.read_credentials(&url, cx).await; + let api_key = match read_result { + Ok(Some((_, api_key))) => api_key, + Ok(None) => return LoadStatus::NotPresent, + Err(err) => return LoadStatus::Error(err.to_string()), + }; + let key = match str::from_utf8(&api_key) { + Ok(key) => key, + Err(_) => return LoadStatus::Error(format!("API key for URL {url} is not utf8")), + }; + LoadStatus::Loaded(Self { + source: ApiKeySource::SystemKeychain, + key: key.into(), + }) + } +} + +impl LoadStatus { + fn into_authenticate_result(self) -> Result { + match self { + LoadStatus::Loaded(api_key) => Ok(api_key), + LoadStatus::NotPresent => Err(AuthenticateError::CredentialsNotFound), + LoadStatus::Error(err) => Err(AuthenticateError::Other(anyhow!(err))), + } + } +} + +#[derive(Debug, Clone)] +enum ApiKeySource { + EnvVar(SharedString), + SystemKeychain, +} + +impl Display for ApiKeySource { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + ApiKeySource::EnvVar(var) => write!(f, "environment variable {}", var), + ApiKeySource::SystemKeychain => write!(f, "system keychain"), + } + } +} diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index 738b72b0c9a6dbb7c9606cc72707b27e66abf09c..61e1a794695310421397469515a43a4d5bf5deb8 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -7,6 +7,7 @@ use gpui::{App, Context, Entity}; use language_model::{LanguageModelProviderId, LanguageModelRegistry}; use provider::deepseek::DeepSeekLanguageModelProvider; +mod api_key; pub mod provider; mod settings; pub mod ui; diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index ca7763e2c5cda3c07c5cb51389cb3173a55865e2..f9c664182426dfd9523254867ba2f95b8f3dc7c4 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -1,18 +1,15 @@ -use crate::AllLanguageModelSettings; use crate::ui::InstructionListItem; +use crate::{AllLanguageModelSettings, api_key::ApiKeyState}; use anthropic::{ AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, ToolResultContent, ToolResultPart, Usage, }; -use anyhow::{Context as _, Result, anyhow}; +use anyhow::Result; use collections::{BTreeMap, HashMap}; -use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::Stream; use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; -use gpui::{ - AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, -}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, FontStyle, Task, TextStyle, WhiteSpace}; use http_client::HttpClient; use language_model::{ AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, @@ -27,11 +24,12 @@ use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::{Icon, IconName, List, Tooltip, prelude::*}; use util::ResultExt; +use zed_env_vars::{EnvVar, env_var}; const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME; @@ -97,91 +95,52 @@ pub struct AnthropicLanguageModelProvider { state: gpui::Entity, } -const ANTHROPIC_API_KEY_VAR: &str = "ANTHROPIC_API_KEY"; +const API_KEY_ENV_VAR_NAME: &str = "ANTHROPIC_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } impl State { - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .anthropic - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .ok(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) - } - - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .anthropic - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .ok(); - - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) - } - fn is_authenticated(&self) -> bool { - self.api_key.is_some() + self.api_key_state.has_key() } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let key = AnthropicLanguageModelProvider::api_key(cx); - - cx.spawn(async move |this, cx| { - let key = key.await?; - - this.update(cx, |this, cx| { - this.api_key = Some(key.key); - this.api_key_from_env = key.from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } -} -pub struct ApiKey { - pub key: String, - pub from_env: bool, + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) + } } impl AnthropicLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -197,30 +156,16 @@ impl AnthropicLanguageModelProvider { }) } - pub fn api_key(cx: &mut App) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .anthropic - .api_url - .clone(); - - if let Ok(key) = std::env::var(ANTHROPIC_API_KEY_VAR) { - Task::ready(Ok(ApiKey { - key, - from_env: true, - })) + fn settings(cx: &App) -> &AnthropicSettings { + &AllLanguageModelSettings::get_global(cx).anthropic + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + anthropic::ANTHROPIC_API_URL.into() } else { - cx.spawn(async move |cx| { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - - Ok(ApiKey { - key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - from_env: false, - }) - }) + SharedString::new(api_url.as_str()) } } } @@ -327,7 +272,8 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -417,11 +363,16 @@ impl AnthropicModel { > { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).anthropic; - (state.api_key.clone(), settings.api_url.clone()) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed(); + let api_key_and_url = self.state.read_with(cx, |state, cx| { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + let api_key = state.api_key_state.key(&api_url); + (api_key, api_url) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(err.into())).boxed(); + } }; let beta_headers = self.model.beta_headers(); @@ -483,7 +434,10 @@ impl LanguageModel for AnthropicModel { } fn api_key(&self, cx: &App) -> Option { - self.state.read(cx).api_key.clone() + self.state.read_with(cx, |state, cx| { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + state.api_key_state.key(&api_url).map(|key| key.to_string()) + }) } fn max_token_count(&self) -> u64 { @@ -987,12 +941,10 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -1001,11 +953,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { @@ -1040,7 +992,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials...")).into_any() @@ -1079,7 +1031,7 @@ impl Render for ConfigurationView { ) .child( Label::new( - format!("You can also assign the {ANTHROPIC_API_KEY_VAR} environment variable and restart Zed."), + format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."), ) .size(LabelSize::Small) .color(Color::Muted), @@ -1099,7 +1051,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {ANTHROPIC_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.") } else { "API key configured.".to_string() })), @@ -1112,7 +1064,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .disabled(env_var_set) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {ANTHROPIC_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 82bf067cd475fe031630767da9e4302afa4d78ec..e00d8bbf4b34560bdc55c982114bf96675e22d99 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -1,12 +1,11 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use collections::{BTreeMap, HashMap}; -use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::Stream; use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{ - AnyView, AppContext as _, AsyncApp, Entity, FontStyle, Subscription, Task, TextStyle, - WhiteSpace, + AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace, + Window, }; use http_client::HttpClient; use language_model::{ @@ -21,16 +20,19 @@ use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use theme::ThemeSettings; use ui::{Icon, IconName, List, prelude::*}; use util::ResultExt; +use zed_env_vars::{EnvVar, env_var}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +use crate::{AllLanguageModelSettings, api_key::ApiKeyState, ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek"); -const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY"; + +const API_KEY_ENV_VAR_NAME: &str = "DEEPSEEK_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); #[derive(Default)] struct RawToolCall { @@ -59,95 +61,48 @@ pub struct DeepSeekLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() - } - - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .deepseek - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + self.api_key_state.has_key() } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .deepseek - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await?; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = SharedString::new(DeepSeekLanguageModelProvider::api_url(cx)); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .deepseek - .api_url - .clone(); - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(DEEPSEEK_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = SharedString::new(DeepSeekLanguageModelProvider::api_url(cx)); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl DeepSeekLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_this: &mut State, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = SharedString::new(Self::api_url(cx)); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(SharedString::new(Self::api_url(cx))), + } }); Self { http_client, state } @@ -160,7 +115,15 @@ impl DeepSeekLanguageModelProvider { state: self.state.clone(), http_client: self.http_client.clone(), request_limiter: RateLimiter::new(4), - }) as Arc + }) + } + + fn settings(cx: &App) -> &DeepSeekSettings { + &AllLanguageModelSettings::get_global(cx).deepseek + } + + fn api_url(cx: &App) -> &str { + &Self::settings(cx).api_url } } @@ -199,11 +162,7 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider { models.insert("deepseek-chat", deepseek::Model::Chat); models.insert("deepseek-reasoner", deepseek::Model::Reasoner); - for available_model in AllLanguageModelSettings::get_global(cx) - .deepseek - .available_models - .iter() - { + for available_model in &Self::settings(cx).available_models { models.insert( &available_model.name, deepseek::Model::Custom { @@ -240,7 +199,8 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -259,15 +219,25 @@ impl DeepSeekLanguageModel { cx: &AsyncApp, ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).deepseek; - (state.api_key.clone(), settings.api_url.clone()) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + + let api_key_and_url = self.state.read_with(cx, |state, cx| { + let api_url = DeepSeekLanguageModelProvider::api_url(cx); + let api_key = state.api_key_state.key(api_url); + (api_key, api_url.to_string()) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(err)).boxed(); + } }; let future = self.request_limiter.stream(async move { - let api_key = api_key.context("Missing DeepSeek API Key")?; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + }); + }; let request = deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request); let response = request.await?; @@ -610,7 +580,7 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context) { - let api_key = self.api_key_editor.read(cx).text(cx); + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); if api_key.is_empty() { return; } @@ -618,12 +588,10 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn(async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -631,10 +599,12 @@ impl ConfigurationView { .update(cx, |editor, cx| editor.set_text("", window, cx)); let state = self.state.clone(); - cx.spawn(async move |_, cx| state.update(cx, |state, cx| state.reset_api_key(cx))?.await) - .detach_and_log_err(cx); - - cx.notify(); + cx.spawn(async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await + }) + .detach_and_log_err(cx); } fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { @@ -672,7 +642,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials...")).into_any() @@ -706,8 +676,7 @@ impl Render for ConfigurationView { ) .child( Label::new(format!( - "Or set the {} environment variable.", - DEEPSEEK_API_KEY_VAR + "Or set the {API_KEY_ENV_VAR_NAME} environment variable." )) .size(LabelSize::Small) .color(Color::Muted), @@ -727,7 +696,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {}", DEEPSEEK_API_KEY_VAR) + format!("API key set in {API_KEY_ENV_VAR_NAME}") } else { "API key configured".to_string() })), diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 939cf0ca60d92d713b90a5d62e8ec7f6dac7ec46..677d5775b3854245ad441acf37b0e1ed02f6bab5 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -1,4 +1,4 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Context as _, Result}; use collections::BTreeMap; use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; @@ -8,7 +8,8 @@ use google_ai::{ ThinkingConfig, UsageMetadata, }; use gpui::{ - AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, + AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace, + Window, }; use http_client::HttpClient; use language_model::{ @@ -26,18 +27,18 @@ use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::sync::{ - Arc, + Arc, LazyLock, atomic::{self, AtomicU64}, }; use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::{Icon, IconName, List, Tooltip, prelude::*}; use util::ResultExt; +use zed_env_vars::EnvVar; -use crate::AllLanguageModelSettings; +use crate::api_key::ApiKey; use crate::ui::InstructionListItem; - -use super::anthropic::ApiKey; +use crate::{AllLanguageModelSettings, api_key::ApiKeyState}; const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME; @@ -91,101 +92,56 @@ pub struct GoogleLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } -const GEMINI_API_KEY_VAR: &str = "GEMINI_API_KEY"; -const GOOGLE_AI_API_KEY_VAR: &str = "GOOGLE_AI_API_KEY"; +const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY"; +const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY"; + +static API_KEY_ENV_VAR: LazyLock = LazyLock::new(|| { + // Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY + EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into())) +}); impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() + self.api_key_state.has_key() } - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .google - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = GoogleLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .google - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await?; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) - } - - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .google - .api_url - .clone(); - - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) { - (api_key, true) - } else if let Ok(api_key) = std::env::var(GEMINI_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = GoogleLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl GoogleLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -201,30 +157,28 @@ impl GoogleLanguageModelProvider { }) } - pub fn api_key(cx: &mut App) -> Task> { + pub fn api_key_for_gemini_cli(cx: &mut App) -> Task> { + if let Some(key) = API_KEY_ENV_VAR.value.clone() { + return Task::ready(Ok(key)); + } let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .google - .api_url - .clone(); - - if let Ok(key) = std::env::var(GEMINI_API_KEY_VAR) { - Task::ready(Ok(ApiKey { - key, - from_env: true, - })) - } else { - cx.spawn(async move |cx| { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) + let api_url = Self::api_url(cx).to_string(); + cx.spawn(async move |cx| { + Ok( + ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx) .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; + .key() + .to_string(), + ) + }) + } - Ok(ApiKey { - key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - from_env: false, - }) - }) + fn api_url(cx: &App) -> SharedString { + let api_url = &AllLanguageModelSettings::get_global(cx).google.api_url; + if api_url.is_empty() { + google_ai::API_URL.into() + } else { + SharedString::new(api_url.as_str()) } } } @@ -317,7 +271,8 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -340,11 +295,16 @@ impl GoogleLanguageModel { > { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).google; - (state.api_key.clone(), settings.api_url.clone()) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + let api_key_and_url = self.state.read_with(cx, |state, cx| { + let api_url = GoogleLanguageModelProvider::api_url(cx); + let api_key = state.api_key_state.key(&api_url); + (api_key, api_url) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(err)).boxed(); + } }; async move { @@ -418,13 +378,16 @@ impl LanguageModel for GoogleLanguageModel { let model_id = self.model.request_id().to_string(); let request = into_google(request, model_id, self.model.mode()); let http_client = self.http_client.clone(); - let api_key = self.state.read(cx).api_key.clone(); - - let settings = &AllLanguageModelSettings::get_global(cx).google; - let api_url = settings.api_url.clone(); + let api_url = GoogleLanguageModelProvider::api_url(cx); + let api_key = self.state.read(cx).api_key_state.key(&api_url); async move { - let api_key = api_key.context("Missing Google API key")?; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + } + .into()); + }; let response = google_ai::count_tokens( http_client.as_ref(), &api_url, @@ -852,7 +815,7 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self.api_key_editor.read(cx).text(cx); + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); if api_key.is_empty() { return; } @@ -860,12 +823,10 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -874,11 +835,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { @@ -913,7 +874,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials...")).into_any() @@ -950,7 +911,7 @@ impl Render for ConfigurationView { ) .child( Label::new( - format!("You can also assign the {GEMINI_API_KEY_VAR} environment variable and restart Zed."), + format!("You can also assign the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."), ) .size(LabelSize::Small).color(Color::Muted), ) @@ -969,7 +930,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {GEMINI_API_KEY_VAR} environment variable.") + format!("API key set in {GEMINI_API_KEY_VAR_NAME} environment variable.") } else { "API key configured.".to_string() })), @@ -982,7 +943,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .disabled(env_var_set) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR} and {GOOGLE_AI_API_KEY_VAR} environment variables are unset."))) + this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index c9824bf89ea7a919f4517f492a5091a2cda7b43b..d2de056489d5fe10bab5eaa41780c0a5f312a181 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -1,10 +1,10 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use collections::BTreeMap; -use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream}; use gpui::{ - AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, + AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace, + Window, }; use http_client::HttpClient; use language_model::{ @@ -21,17 +21,21 @@ use settings::{Settings, SettingsStore}; use std::collections::HashMap; use std::pin::Pin; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::{Icon, IconName, List, Tooltip, prelude::*}; use util::ResultExt; +use zed_env_vars::{EnvVar, env_var}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +use crate::{api_key::ApiKeyState, ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral"); +const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); + #[derive(Default, Clone, Debug, PartialEq)] pub struct MistralSettings { pub api_url: String, @@ -56,96 +60,48 @@ pub struct MistralLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } -const MISTRAL_API_KEY_VAR: &str = "MISTRAL_API_KEY"; - impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() - } - - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .mistral - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + self.api_key_state.has_key() } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .mistral - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await?; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = MistralLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .mistral - .api_url - .clone(); - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(MISTRAL_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = MistralLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl MistralLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_this: &mut State, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -160,6 +116,19 @@ impl MistralLanguageModelProvider { request_limiter: RateLimiter::new(4), }) } + + fn settings(cx: &App) -> &MistralSettings { + &crate::AllLanguageModelSettings::get_global(cx).mistral + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + mistral::MISTRAL_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } } impl LanguageModelProviderState for MistralLanguageModelProvider { @@ -202,10 +171,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider { } // Override with available models from settings - for model in &AllLanguageModelSettings::get_global(cx) - .mistral - .available_models - { + for model in &Self::settings(cx).available_models { models.insert( model.name.clone(), mistral::Model::Custom { @@ -254,7 +220,8 @@ impl LanguageModelProvider for MistralLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -276,15 +243,25 @@ impl MistralLanguageModel { Result>>, > { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).mistral; - (state.api_key.clone(), settings.api_url.clone()) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + + let api_key_and_url = self.state.read_with(cx, |state, cx| { + let api_url = MistralLanguageModelProvider::api_url(cx); + let api_key = state.api_key_state.key(&api_url); + (api_key, api_url) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(err)).boxed(); + } }; let future = self.request_limiter.stream(async move { - let api_key = api_key.context("Missing Mistral API Key")?; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + }); + }; let request = mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request); let response = request.await?; @@ -780,7 +757,7 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self.api_key_editor.read(cx).text(cx); + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); if api_key.is_empty() { return; } @@ -788,12 +765,10 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -802,11 +777,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { @@ -841,7 +816,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials...")).into_any() @@ -878,7 +853,7 @@ impl Render for ConfigurationView { ) .child( Label::new( - format!("You can also assign the {MISTRAL_API_KEY_VAR} environment variable and restart Zed."), + format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."), ) .size(LabelSize::Small).color(Color::Muted), ) @@ -897,7 +872,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {MISTRAL_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.") } else { "API key configured.".to_string() })), @@ -910,7 +885,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .disabled(env_var_set) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {MISTRAL_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index fca1cf977cb5e3b32dc6f2335fb0d9188979bc9f..a7e8f64d6614d95e4e892aba569d441320108ca4 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -1,10 +1,8 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use collections::{BTreeMap, HashMap}; -use credentials_provider::CredentialsProvider; - use futures::Stream; use futures::{FutureExt, StreamExt, future::BoxFuture}; -use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -20,18 +18,21 @@ use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::str::FromStr as _; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; - use ui::{ElevationIndex, List, Tooltip, prelude::*}; use ui_input::SingleLineInput; use util::ResultExt; +use zed_env_vars::{EnvVar, env_var}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +use crate::{AllLanguageModelSettings, api_key::ApiKeyState, ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME; +const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); + #[derive(Default, Clone, Debug, PartialEq)] pub struct OpenAiSettings { pub api_url: String, @@ -54,132 +55,48 @@ pub struct OpenAiLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - last_api_url: String, - _subscription: Subscription, + api_key_state: ApiKeyState, } -const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY"; - impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() - } - - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) - } - - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) + self.api_key_state.has_key() } - fn get_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = OpenAiLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - self.get_api_key(cx) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = OpenAiLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl OpenAiLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let initial_api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - last_api_url: initial_api_url.clone(), - _subscription: cx.observe_global::(|this: &mut State, cx| { - let current_api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - - if this.last_api_url != current_api_url { - this.last_api_url = current_api_url; - if !this.api_key_from_env { - this.api_key = None; - let spawn_task = cx.spawn(async move |handle, cx| { - if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) { - if let Err(_) = task.await { - handle - .update(cx, |this, _| { - this.api_key = None; - this.api_key_from_env = false; - }) - .ok(); - } - } - }); - spawn_task.detach(); - } - } + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -194,6 +111,19 @@ impl OpenAiLanguageModelProvider { request_limiter: RateLimiter::new(4), }) } + + fn settings(cx: &App) -> &OpenAiSettings { + &AllLanguageModelSettings::get_global(cx).openai + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + open_ai::OPEN_AI_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } } impl LanguageModelProviderState for OpenAiLanguageModelProvider { @@ -278,7 +208,8 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -298,11 +229,17 @@ impl OpenAiLanguageModel { ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).openai; - (state.api_key.clone(), settings.api_url.clone()) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + + let api_key_and_url = self.state.read_with(cx, |state, cx| { + let api_url = OpenAiLanguageModelProvider::api_url(cx); + let api_key = state.api_key_state.key(&api_url); + (api_key, api_url) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(err)).boxed(); + } }; let future = self.request_limiter.stream(async move { @@ -802,29 +739,18 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self - .api_key_editor - .read(cx) - .editor() - .read(cx) - .text(cx) - .trim() - .to_string(); - - // Don't proceed if no API key is provided and we're not authenticated - if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { return; } let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -836,11 +762,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn should_render_editor(&self, cx: &mut Context) -> bool { @@ -850,7 +776,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); let api_key_section = if self.should_render_editor(cx) { v_flex() @@ -872,10 +798,11 @@ impl Render for ConfigurationView { ) .child(self.api_key_editor.clone()) .child( - Label::new( - format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."), - ) - .size(LabelSize::Small).color(Color::Muted), + Label::new(format!( + "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed." + )) + .size(LabelSize::Small) + .color(Color::Muted), ) .child( Label::new( @@ -898,7 +825,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {OPENAI_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.") } else { "API key configured.".to_string() })), @@ -911,7 +838,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .layer(ElevationIndex::ModalSurface) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index 4ebb11a07b66ec7054ca65437ec887a415fa3f5c..2b5e372a1d6ccd9b3bb13feb777f1091f0336b58 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -1,9 +1,7 @@ -use anyhow::{Context as _, Result, anyhow}; -use credentials_provider::CredentialsProvider; - +use anyhow::Result; use convert_case::{Case, Casing}; use futures::{FutureExt, StreamExt, future::BoxFuture}; -use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -17,13 +15,13 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::sync::Arc; - use ui::{ElevationIndex, Tooltip, prelude::*}; use ui_input::SingleLineInput; use util::ResultExt; +use zed_env_vars::EnvVar; -use crate::AllLanguageModelSettings; use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai}; +use crate::{AllLanguageModelSettings, api_key::ApiKeyState}; #[derive(Default, Clone, Debug, PartialEq)] pub struct OpenAiCompatibleSettings { @@ -70,82 +68,30 @@ pub struct OpenAiCompatibleLanguageModelProvider { pub struct State { id: Arc, - env_var_name: Arc, - api_key: Option, - api_key_from_env: bool, + api_key_env_var: EnvVar, + api_key_state: ApiKeyState, settings: OpenAiCompatibleSettings, - _subscription: Subscription, } impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() - } - - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = self.settings.api_url.clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) - } - - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = self.settings.api_url.clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) + self.api_key_state.has_key() } - fn get_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let env_var_name = self.env_var_name.clone(); - let api_url = self.settings.api_url.clone(); - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(env_var_name.as_ref()) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = SharedString::new(self.settings.api_url.as_str()); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - self.get_api_key(cx) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = SharedString::new(self.settings.api_url.clone()); + self.api_key_state.load_if_needed( + api_url, + &self.api_key_env_var, + |this| &mut this.api_key_state, + cx, + ) } } @@ -157,37 +103,32 @@ impl OpenAiCompatibleLanguageModelProvider { .get(id) } - let state = cx.new(|cx| State { - id: id.clone(), - env_var_name: format!("{}_API_KEY", id).to_case(Case::Constant).into(), - settings: resolve_settings(&id, cx).cloned().unwrap_or_default(), - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|this: &mut State, cx| { + let api_key_env_var_name = format!("{}_API_KEY", id).to_case(Case::UpperSnake).into(); + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { let Some(settings) = resolve_settings(&this.id, cx).cloned() else { return; }; if &this.settings != &settings { - if settings.api_url != this.settings.api_url && !this.api_key_from_env { - let spawn_task = cx.spawn(async move |handle, cx| { - if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) { - if let Err(_) = task.await { - handle - .update(cx, |this, _| { - this.api_key = None; - this.api_key_from_env = false; - }) - .ok(); - } - } - }); - spawn_task.detach(); - } - + let api_url = SharedString::new(settings.api_url.as_str()); + this.api_key_state.handle_url_change( + api_url, + &this.api_key_env_var, + |this| &mut this.api_key_state, + cx, + ); this.settings = settings; cx.notify(); } - }), + }) + .detach(); + let settings = resolve_settings(&id, cx).cloned().unwrap_or_default(); + State { + id: id.clone(), + api_key_env_var: EnvVar::new(api_key_env_var_name), + api_key_state: ApiKeyState::new(SharedString::new(settings.api_url.as_str())), + settings, + } }); Self { @@ -274,7 +215,8 @@ impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -296,10 +238,17 @@ impl OpenAiCompatibleLanguageModel { ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, _| { - (state.api_key.clone(), state.settings.api_url.clone()) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + + let api_key_and_url = self.state.read_with(cx, |state, _cx| { + let api_url = &state.settings.api_url; + let api_key = state.api_key_state.key(api_url); + (api_key, state.settings.api_url.clone()) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(err)).boxed(); + } }; let provider = self.provider_name.clone(); @@ -469,29 +418,18 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self - .api_key_editor - .read(cx) - .editor() - .read(cx) - .text(cx) - .trim() - .to_string(); - - // Don't proceed if no API key is provided and we're not authenticated - if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { return; } let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -503,22 +441,23 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } - fn should_render_editor(&self, cx: &mut Context) -> bool { + fn should_render_editor(&self, cx: &Context) -> bool { !self.state.read(cx).is_authenticated() } } impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; - let env_var_name = self.state.read(cx).env_var_name.clone(); + let state = self.state.read(cx); + let env_var_set = state.api_key_state.is_from_env_var(); + let env_var_name = &state.api_key_env_var.name; let api_key_section = if self.should_render_editor(cx) { v_flex() diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 0ebf379b393791558cba6ff36ab31d278162386e..efa6a1eabe33fe8ecb52da3b5c61122cfc26e036 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -1,10 +1,9 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use collections::HashMap; -use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; use futures::{FutureExt, Stream, StreamExt, future::BoxFuture}; use gpui::{ - AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, + AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace, }; use http_client::HttpClient; use language_model::{ @@ -16,23 +15,26 @@ use language_model::{ }; use open_router::{ Model, ModelMode as OpenRouterModelMode, Provider, ResponseStreamEvent, list_models, - stream_completion, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::str::FromStr as _; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use theme::ThemeSettings; use ui::{Icon, IconName, List, Tooltip, prelude::*}; use util::ResultExt; +use zed_env_vars::{EnvVar, env_var}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +use crate::{AllLanguageModelSettings, api_key::ApiKeyState, ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter"); +const API_KEY_ENV_VAR_NAME: &str = "OPENROUTER_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); + #[derive(Default, Clone, Debug, PartialEq)] pub struct OpenRouterSettings { pub api_url: String, @@ -90,93 +92,38 @@ pub struct OpenRouterLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, + api_key_state: ApiKeyState, http_client: Arc, available_models: Vec, fetch_models_task: Option>>, settings: OpenRouterSettings, - _subscription: Subscription, } -const OPENROUTER_API_KEY_VAR: &str = "OPENROUTER_API_KEY"; - impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() - } - - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .open_router - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + self.api_key_state.has_key() } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .open_router - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.restart_fetch_models_task(cx); - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .open_router - .api_url - .clone(); + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + let task = self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - this.restart_fetch_models_task(cx); - cx.notify(); - })?; - - Ok(()) + let result = task.await; + this.update(cx, |this, cx| this.restart_fetch_models_task(cx)) + .ok(); + result }) } @@ -184,10 +131,9 @@ impl State { &mut self, cx: &mut Context, ) -> Task> { - let settings = &AllLanguageModelSettings::get_global(cx).open_router; let http_client = self.http_client.clone(); - let api_url = settings.api_url.clone(); - let Some(api_key) = self.api_key.clone() else { + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + let Some(api_key) = self.api_key_state.key(&api_url) else { return Task::ready(Err(LanguageModelCompletionError::NoApiKey { provider: PROVIDER_NAME, })); @@ -216,33 +162,45 @@ impl State { if self.is_authenticated() { let task = self.fetch_models(cx); self.fetch_models_task.replace(task); + } else { + self.available_models = Vec::new(); } } } impl OpenRouterLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - http_client: http_client.clone(), - available_models: Vec::new(), - fetch_models_task: None, - settings: OpenRouterSettings::default(), - _subscription: cx.observe_global::(|this: &mut State, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { let current_settings = &AllLanguageModelSettings::get_global(cx).open_router; let settings_changed = current_settings != &this.settings; if settings_changed { this.settings = current_settings.clone(); - this.restart_fetch_models_task(cx); + this.authenticate(cx).detach(); } cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + http_client: http_client.clone(), + available_models: Vec::new(), + fetch_models_task: None, + settings: OpenRouterSettings::default(), + } }); Self { http_client, state } } + fn settings(cx: &App) -> &OpenRouterSettings { + &AllLanguageModelSettings::get_global(cx).open_router + } + + fn api_url(cx: &App) -> SharedString { + SharedString::new(Self::settings(cx).api_url.as_str()) + } + fn create_language_model(&self, model: open_router::Model) -> Arc { Arc::new(OpenRouterLanguageModel { id: LanguageModelId::from(model.id().to_string()), @@ -287,10 +245,7 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider { let mut models_from_api = self.state.read(cx).available_models.clone(); let mut settings_models = Vec::new(); - for model in &AllLanguageModelSettings::get_global(cx) - .open_router - .available_models - { + for model in &Self::settings(cx).available_models { settings_models.push(open_router::Model { name: model.name.clone(), display_name: model.display_name.clone(), @@ -338,7 +293,8 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -366,14 +322,17 @@ impl OpenRouterLanguageModel { >, > { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).open_router; - (state.api_key.clone(), settings.api_url.clone()) - }) else { - return futures::future::ready(Err(LanguageModelCompletionError::Other(anyhow!( - "App state dropped" - )))) - .boxed(); + let api_key_and_url = self.state.read_with(cx, |state, cx| { + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + let api_key = state.api_key_state.key(&api_url); + (api_key, api_url) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(LanguageModelCompletionError::Other(err))) + .boxed(); + } }; async move { @@ -382,7 +341,8 @@ impl OpenRouterLanguageModel { provider: PROVIDER_NAME, }); }; - let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let request = + open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request); request.await.map_err(Into::into) } .boxed() @@ -830,7 +790,7 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self.api_key_editor.read(cx).text(cx); + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); if api_key.is_empty() { return; } @@ -838,12 +798,10 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -852,11 +810,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { @@ -891,7 +849,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials...")).into_any() @@ -928,7 +886,7 @@ impl Render for ConfigurationView { ) .child( Label::new( - format!("You can also assign the {OPENROUTER_API_KEY_VAR} environment variable and restart Zed."), + format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."), ) .size(LabelSize::Small).color(Color::Muted), ) @@ -947,7 +905,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {OPENROUTER_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.") } else { "API key configured.".to_string() })), @@ -960,7 +918,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .disabled(env_var_set) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENROUTER_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs index 84f3175d1e5493fd55cafd2ea9c4a0604d2a97b4..ad28946e3bc78078d9d6510fafacbba34c1f98ca 100644 --- a/crates/language_models/src/provider/vercel.rs +++ b/crates/language_models/src/provider/vercel.rs @@ -1,8 +1,7 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::Result; use collections::BTreeMap; -use credentials_provider::CredentialsProvider; use futures::{FutureExt, StreamExt, future::BoxFuture}; -use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -10,24 +9,26 @@ use language_model::{ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, RateLimiter, Role, }; -use menu; use open_ai::ResponseStreamEvent; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; -use vercel::Model; - use ui::{ElevationIndex, List, Tooltip, prelude::*}; use ui_input::SingleLineInput; use util::ResultExt; +use vercel::Model; +use zed_env_vars::{EnvVar, env_var}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +use crate::{api_key::ApiKeyState, ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Vercel"); +const API_KEY_ENV_VAR_NAME: &str = "VERCEL_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); + #[derive(Default, Clone, Debug, PartialEq)] pub struct VercelSettings { pub api_url: String, @@ -49,103 +50,48 @@ pub struct VercelLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } -const VERCEL_API_KEY_VAR: &str = "VERCEL_API_KEY"; - impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() + self.api_key_state.has_key() } - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).vercel; - let api_url = if settings.api_url.is_empty() { - vercel::VERCEL_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = VercelLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).vercel; - let api_url = if settings.api_url.is_empty() { - vercel::VERCEL_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) - } - - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).vercel; - let api_url = if settings.api_url.is_empty() { - vercel::VERCEL_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(VERCEL_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = VercelLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl VercelLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_this: &mut State, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -160,6 +106,19 @@ impl VercelLanguageModelProvider { request_limiter: RateLimiter::new(4), }) } + + fn settings(cx: &App) -> &VercelSettings { + &crate::AllLanguageModelSettings::get_global(cx).vercel + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + vercel::VERCEL_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } } impl LanguageModelProviderState for VercelLanguageModelProvider { @@ -200,10 +159,7 @@ impl LanguageModelProvider for VercelLanguageModelProvider { } } - for model in &AllLanguageModelSettings::get_global(cx) - .vercel - .available_models - { + for model in &Self::settings(cx).available_models { models.insert( model.name.clone(), vercel::Model::Custom { @@ -241,7 +197,8 @@ impl LanguageModelProvider for VercelLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -261,16 +218,17 @@ impl VercelLanguageModel { ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).vercel; - let api_url = if settings.api_url.is_empty() { - vercel::VERCEL_API_URL.to_string() - } else { - settings.api_url.clone() - }; - (state.api_key.clone(), api_url) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + + let api_key_and_url = self.state.read_with(cx, |state, cx| { + let api_url = VercelLanguageModelProvider::api_url(cx); + let api_key = state.api_key_state.key(&api_url); + (api_key, api_url) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(err)).boxed(); + } }; let future = self.request_limiter.stream(async move { @@ -466,29 +424,18 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self - .api_key_editor - .read(cx) - .editor() - .read(cx) - .text(cx) - .trim() - .to_string(); - - // Don't proceed if no API key is provided and we're not authenticated - if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { return; } let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -500,11 +447,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn should_render_editor(&self, cx: &mut Context) -> bool { @@ -514,7 +461,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); let api_key_section = if self.should_render_editor(cx) { v_flex() @@ -534,7 +481,7 @@ impl Render for ConfigurationView { .child(self.api_key_editor.clone()) .child( Label::new(format!( - "You can also assign the {VERCEL_API_KEY_VAR} environment variable and restart Zed." + "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed." )) .size(LabelSize::Small) .color(Color::Muted), @@ -559,7 +506,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {VERCEL_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.") } else { "API key configured.".to_string() })), @@ -572,7 +519,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .layer(ElevationIndex::ModalSurface) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {VERCEL_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index bb17f22c7f3fdbb0296b1e0bb290fbce9a979ddf..d5c5848293da5170b609bfd1526c26429865184c 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -1,8 +1,7 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::Result; use collections::BTreeMap; -use credentials_provider::CredentialsProvider; use futures::{FutureExt, StreamExt, future::BoxFuture}; -use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, Window}; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -10,23 +9,25 @@ use language_model::{ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, Role, }; -use menu; use open_ai::ResponseStreamEvent; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; -use x_ai::Model; - use ui::{ElevationIndex, List, Tooltip, prelude::*}; use ui_input::SingleLineInput; use util::ResultExt; +use x_ai::{Model, XAI_API_URL}; +use zed_env_vars::{EnvVar, env_var}; + +use crate::{api_key::ApiKeyState, ui::InstructionListItem}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai"); +const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI"); -const PROVIDER_ID: &str = "x_ai"; -const PROVIDER_NAME: &str = "xAI"; +const API_KEY_ENV_VAR_NAME: &str = "XAI_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); #[derive(Default, Clone, Debug, PartialEq)] pub struct XAiSettings { @@ -49,103 +50,48 @@ pub struct XAiLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } -const XAI_API_KEY_VAR: &str = "XAI_API_KEY"; - impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() + self.api_key_state.has_key() } - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).x_ai; - let api_url = if settings.api_url.is_empty() { - x_ai::XAI_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = XAiLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).x_ai; - let api_url = if settings.api_url.is_empty() { - x_ai::XAI_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) - } - - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).x_ai; - let api_url = if settings.api_url.is_empty() { - x_ai::XAI_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(XAI_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = XAiLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl XAiLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_this: &mut State, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -160,6 +106,19 @@ impl XAiLanguageModelProvider { request_limiter: RateLimiter::new(4), }) } + + fn settings(cx: &App) -> &XAiSettings { + &crate::AllLanguageModelSettings::get_global(cx).x_ai + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + XAI_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } } impl LanguageModelProviderState for XAiLanguageModelProvider { @@ -172,11 +131,11 @@ impl LanguageModelProviderState for XAiLanguageModelProvider { impl LanguageModelProvider for XAiLanguageModelProvider { fn id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(PROVIDER_ID.into()) + PROVIDER_ID } fn name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn icon(&self) -> IconName { @@ -200,10 +159,7 @@ impl LanguageModelProvider for XAiLanguageModelProvider { } } - for model in &AllLanguageModelSettings::get_global(cx) - .x_ai - .available_models - { + for model in &Self::settings(cx).available_models { models.insert( model.name.clone(), x_ai::Model::Custom { @@ -241,7 +197,8 @@ impl LanguageModelProvider for XAiLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -261,20 +218,25 @@ impl XAiLanguageModel { ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).x_ai; - let api_url = if settings.api_url.is_empty() { - x_ai::XAI_API_URL.to_string() - } else { - settings.api_url.clone() - }; - (state.api_key.clone(), api_url) - }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + + let api_key_and_url = self.state.read_with(cx, |state, cx| { + let api_url = XAiLanguageModelProvider::api_url(cx); + let api_key = state.api_key_state.key(&api_url); + (api_key, api_url) + }); + let (api_key, api_url) = match api_key_and_url { + Ok(api_key_and_url) => api_key_and_url, + Err(err) => { + return futures::future::ready(Err(err)).boxed(); + } }; let future = self.request_limiter.stream(async move { - let api_key = api_key.context("Missing xAI API Key")?; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + }); + }; let request = open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request); let response = request.await?; @@ -295,11 +257,11 @@ impl LanguageModel for XAiLanguageModel { } fn provider_id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(PROVIDER_ID.into()) + PROVIDER_ID } fn provider_name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn supports_tools(&self) -> bool { @@ -456,29 +418,18 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self - .api_key_editor - .read(cx) - .editor() - .read(cx) - .text(cx) - .trim() - .to_string(); - - // Don't proceed if no API key is provided and we're not authenticated - if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { return; } let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -490,11 +441,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn should_render_editor(&self, cx: &mut Context) -> bool { @@ -504,7 +455,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); let api_key_section = if self.should_render_editor(cx) { v_flex() @@ -524,7 +475,7 @@ impl Render for ConfigurationView { .child(self.api_key_editor.clone()) .child( Label::new(format!( - "You can also assign the {XAI_API_KEY_VAR} environment variable and restart Zed." + "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed." )) .size(LabelSize::Small) .color(Color::Muted), @@ -549,7 +500,7 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {XAI_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.") } else { "API key configured.".to_string() })), @@ -562,7 +513,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .layer(ElevationIndex::ModalSurface) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {XAI_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/zed_env_vars/Cargo.toml b/crates/zed_env_vars/Cargo.toml index 9abfc410e7e74774c4e9e7608e8c1c3824ebc3c1..f56e3dd529cc7a8001d0021e96902f55034f88e2 100644 --- a/crates/zed_env_vars/Cargo.toml +++ b/crates/zed_env_vars/Cargo.toml @@ -16,3 +16,4 @@ default = [] [dependencies] workspace-hack.workspace = true +gpui.workspace = true diff --git a/crates/zed_env_vars/src/zed_env_vars.rs b/crates/zed_env_vars/src/zed_env_vars.rs index d1679a0518f2bae857364b0035b6184350ffca55..53b9c22bb207e81831d1d9ae6087d1a297331d3f 100644 --- a/crates/zed_env_vars/src/zed_env_vars.rs +++ b/crates/zed_env_vars/src/zed_env_vars.rs @@ -1,6 +1,44 @@ +use gpui::SharedString; use std::sync::LazyLock; /// Whether Zed is running in stateless mode. /// When true, Zed will use in-memory databases instead of persistent storage. -pub static ZED_STATELESS: LazyLock = - LazyLock::new(|| std::env::var("ZED_STATELESS").is_ok_and(|v| !v.is_empty())); +pub static ZED_STATELESS: LazyLock = bool_env_var!("ZED_STATELESS"); + +pub struct EnvVar { + pub name: SharedString, + /// Value of the environment variable. Also `None` when set to an empty string. + pub value: Option, +} + +impl EnvVar { + pub fn new(name: SharedString) -> Self { + let value = std::env::var(name.as_str()).ok(); + if value.as_ref().is_some_and(|v| v.is_empty()) { + Self { name, value: None } + } else { + Self { name, value } + } + } + + pub fn or(self, other: EnvVar) -> EnvVar { + if self.value.is_some() { self } else { other } + } +} + +/// Creates a `LazyLock` expression for use in a `static` declaration. +#[macro_export] +macro_rules! env_var { + ($name:expr) => { + LazyLock::new(|| $crate::EnvVar::new(($name).into())) + }; +} + +/// Generates a `LazyLock` expression for use in a `static` declaration. Checks if the +/// environment variable exists and is non-empty. +#[macro_export] +macro_rules! bool_env_var { + ($name:expr) => { + LazyLock::new(|| $crate::EnvVar::new(($name).into()).value.is_some()) + }; +}