@@ -5,6 +5,7 @@ use crate::{AgentServer, AgentServerCommand};
use acp_thread::{AgentConnection, LoadError};
use anyhow::Result;
use gpui::{Entity, Task};
+use language_models::provider::google::GoogleLanguageModelProvider;
use project::Project;
use settings::SettingsStore;
use ui::App;
@@ -47,7 +48,7 @@ impl AgentServer for Gemini {
settings.get::<AllAgentServersSettings>(None).gemini.clone()
})?;
- let Some(command) =
+ let Some(mut command) =
AgentServerCommand::resolve("gemini", &[ACP_ARG], None, settings, &project, cx).await
else {
return Err(LoadError::NotInstalled {
@@ -57,6 +58,10 @@ impl AgentServer for Gemini {
}.into());
};
+ if let Some(api_key)= cx.update(GoogleLanguageModelProvider::api_key)?.await.ok() {
+ command.env.get_or_insert_default().insert("GEMINI_API_KEY".to_owned(), api_key.key);
+ }
+
let result = crate::acp::connect(server_name, command.clone(), &root_dir, cx).await;
if result.is_err() {
let version_fut = util::command::new_smol_command(&command.path)
@@ -278,6 +278,7 @@ enum ThreadState {
connection: Rc<dyn AgentConnection>,
description: Option<Entity<Markdown>>,
configuration_view: Option<AnyView>,
+ pending_auth_method: Option<acp::AuthMethodId>,
_subscription: Option<Subscription>,
},
}
@@ -563,6 +564,7 @@ impl AcpThreadView {
this.update(cx, |this, cx| {
this.thread_state = ThreadState::Unauthenticated {
+ pending_auth_method: None,
connection,
configuration_view,
description: err
@@ -999,12 +1001,74 @@ impl AcpThreadView {
window: &mut Window,
cx: &mut Context<Self>,
) {
- let ThreadState::Unauthenticated { ref connection, .. } = self.thread_state else {
+ let ThreadState::Unauthenticated {
+ connection,
+ pending_auth_method,
+ configuration_view,
+ ..
+ } = &mut self.thread_state
+ else {
return;
};
+ if method.0.as_ref() == "gemini-api-key" {
+ let registry = LanguageModelRegistry::global(cx);
+ let provider = registry
+ .read(cx)
+ .provider(&language_model::GOOGLE_PROVIDER_ID)
+ .unwrap();
+ if !provider.is_authenticated(cx) {
+ let this = cx.weak_entity();
+ let agent = self.agent.clone();
+ let connection = connection.clone();
+ window.defer(cx, |window, cx| {
+ Self::handle_auth_required(
+ this,
+ AuthRequired {
+ description: Some("GEMINI_API_KEY must be set".to_owned()),
+ provider_id: Some(language_model::GOOGLE_PROVIDER_ID),
+ },
+ agent,
+ connection,
+ window,
+ cx,
+ );
+ });
+ return;
+ }
+ } else if method.0.as_ref() == "vertex-ai"
+ && std::env::var("GOOGLE_API_KEY").is_err()
+ && (std::env::var("GOOGLE_CLOUD_PROJECT").is_err()
+ || (std::env::var("GOOGLE_CLOUD_PROJECT").is_err()))
+ {
+ let this = cx.weak_entity();
+ let agent = self.agent.clone();
+ let connection = connection.clone();
+
+ window.defer(cx, |window, cx| {
+ Self::handle_auth_required(
+ this,
+ AuthRequired {
+ description: Some(
+ "GOOGLE_API_KEY must be set in the environment to use Vertex AI authentication for Gemini CLI. Please export it and restart Zed."
+ .to_owned(),
+ ),
+ provider_id: None,
+ },
+ agent,
+ connection,
+ window,
+ cx,
+ )
+ });
+ return;
+ }
+
self.thread_error.take();
+ configuration_view.take();
+ pending_auth_method.replace(method.clone());
let authenticate = connection.authenticate(method, cx);
+ cx.notify();
self.auth_task = Some(cx.spawn_in(window, {
let project = self.project.clone();
let agent = self.agent.clone();
@@ -2425,6 +2489,7 @@ impl AcpThreadView {
connection: &Rc<dyn AgentConnection>,
description: Option<&Entity<Markdown>>,
configuration_view: Option<&AnyView>,
+ pending_auth_method: Option<&acp::AuthMethodId>,
window: &mut Window,
cx: &Context<Self>,
) -> Div {
@@ -2456,17 +2521,80 @@ impl AcpThreadView {
.cloned()
.map(|view| div().px_4().w_full().max_w_128().child(view)),
)
- .child(h_flex().mt_1p5().justify_center().children(
- connection.auth_methods().iter().map(|method| {
- Button::new(SharedString::from(method.id.0.clone()), method.name.clone())
- .on_click({
- let method_id = method.id.clone();
- cx.listener(move |this, _, window, cx| {
- this.authenticate(method_id.clone(), window, cx)
+ .when(
+ configuration_view.is_none()
+ && description.is_none()
+ && pending_auth_method.is_none(),
+ |el| {
+ el.child(
+ div()
+ .text_ui(cx)
+ .text_center()
+ .px_4()
+ .w_full()
+ .max_w_128()
+ .child(Label::new("Authentication required")),
+ )
+ },
+ )
+ .when_some(pending_auth_method, |el, _| {
+ let spinner_icon = div()
+ .px_0p5()
+ .id("generating")
+ .tooltip(Tooltip::text("Generating Changesβ¦"))
+ .child(
+ Icon::new(IconName::ArrowCircle)
+ .size(IconSize::Small)
+ .with_animation(
+ "arrow-circle",
+ Animation::new(Duration::from_secs(2)).repeat(),
+ |icon, delta| {
+ icon.transform(Transformation::rotate(percentage(delta)))
+ },
+ )
+ .into_any_element(),
+ )
+ .into_any();
+ el.child(
+ h_flex()
+ .text_ui(cx)
+ .text_center()
+ .justify_center()
+ .gap_2()
+ .px_4()
+ .w_full()
+ .max_w_128()
+ .child(Label::new("Authenticating..."))
+ .child(spinner_icon),
+ )
+ })
+ .child(
+ h_flex()
+ .mt_1p5()
+ .gap_1()
+ .flex_wrap()
+ .justify_center()
+ .children(connection.auth_methods().iter().enumerate().rev().map(
+ |(ix, method)| {
+ Button::new(
+ SharedString::from(method.id.0.clone()),
+ method.name.clone(),
+ )
+ .style(ButtonStyle::Outlined)
+ .when(ix == 0, |el| {
+ el.style(ButtonStyle::Tinted(ui::TintColor::Accent))
})
- })
- }),
- ))
+ .size(ButtonSize::Medium)
+ .label_size(LabelSize::Small)
+ .on_click({
+ let method_id = method.id.clone();
+ cx.listener(move |this, _, window, cx| {
+ this.authenticate(method_id.clone(), window, cx)
+ })
+ })
+ },
+ )),
+ )
}
fn render_load_error(&self, e: &LoadError, cx: &Context<Self>) -> AnyElement {
@@ -2551,6 +2679,8 @@ impl AcpThreadView {
let install_command = install_command.clone();
container = container.child(
Button::new("install", install_message)
+ .style(ButtonStyle::Tinted(ui::TintColor::Accent))
+ .size(ButtonSize::Medium)
.tooltip(Tooltip::text(install_command.clone()))
.on_click(cx.listener(move |this, _, window, cx| {
let task = this
@@ -4372,11 +4502,13 @@ impl Render for AcpThreadView {
connection,
description,
configuration_view,
+ pending_auth_method,
..
} => self.render_auth_required_state(
connection,
description.as_ref(),
configuration_view.as_ref(),
+ pending_auth_method.as_ref(),
window,
cx,
),
@@ -12,9 +12,9 @@ use gpui::{
};
use http_client::HttpClient;
use language_model::{
- AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
- LanguageModelToolChoice, LanguageModelToolSchemaFormat, LanguageModelToolUse,
- LanguageModelToolUseId, MessageContent, StopReason,
+ AuthenticateError, ConfigurationViewTargetAgent, LanguageModelCompletionError,
+ LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
+ LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -37,6 +37,8 @@ use util::ResultExt;
use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;
+use super::anthropic::ApiKey;
+
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
@@ -198,6 +200,33 @@ impl GoogleLanguageModelProvider {
request_limiter: RateLimiter::new(4),
})
}
+
+ pub fn api_key(cx: &mut App) -> Task<Result<ApiKey>> {
+ let credentials_provider = <dyn CredentialsProvider>::global(cx);
+ let api_url = AllLanguageModelSettings::get_global(cx)
+ .google
+ .api_url
+ .clone();
+
+ if let Ok(key) = std::env::var(GEMINI_API_KEY_VAR) {
+ Task::ready(Ok(ApiKey {
+ key,
+ from_env: true,
+ }))
+ } else {
+ cx.spawn(async move |cx| {
+ let (_, api_key) = credentials_provider
+ .read_credentials(&api_url, cx)
+ .await?
+ .ok_or(AuthenticateError::CredentialsNotFound)?;
+
+ Ok(ApiKey {
+ key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
+ from_env: false,
+ })
+ })
+ }
+ }
}
impl LanguageModelProviderState for GoogleLanguageModelProvider {
@@ -279,11 +308,11 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
fn configuration_view(
&self,
- _target_agent: language_model::ConfigurationViewTargetAgent,
+ target_agent: language_model::ConfigurationViewTargetAgent,
window: &mut Window,
cx: &mut App,
) -> AnyView {
- cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
+ cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
.into()
}
@@ -776,11 +805,17 @@ fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
struct ConfigurationView {
api_key_editor: Entity<Editor>,
state: gpui::Entity<State>,
+ target_agent: language_model::ConfigurationViewTargetAgent,
load_credentials_task: Option<Task<()>>,
}
impl ConfigurationView {
- fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
+ fn new(
+ state: gpui::Entity<State>,
+ target_agent: language_model::ConfigurationViewTargetAgent,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Self {
cx.observe(&state, |_, _, cx| {
cx.notify();
})
@@ -810,6 +845,7 @@ impl ConfigurationView {
editor.set_placeholder_text("AIzaSy...", cx);
editor
}),
+ target_agent,
state,
load_credentials_task,
}
@@ -885,7 +921,10 @@ impl Render for ConfigurationView {
v_flex()
.size_full()
.on_action(cx.listener(Self::save_api_key))
- .child(Label::new("To use Zed's agent with Google AI, you need to add an API key. Follow these steps:"))
+ .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match self.target_agent {
+ ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI",
+ ConfigurationViewTargetAgent::Other(agent) => agent,
+ })))
.child(
List::new()
.child(InstructionListItem::new(