acp: Handle Gemini Auth Better (#36631)

Conrad Irwin and Danilo Leal created

Release Notes:

- N/A

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>

Change summary

crates/agent_servers/src/gemini.rs            |   7 
crates/agent_ui/src/acp/thread_view.rs        | 154 +++++++++++++++++++-
crates/language_models/src/provider/google.rs |  53 ++++++
3 files changed, 195 insertions(+), 19 deletions(-)

Detailed changes

crates/agent_servers/src/gemini.rs πŸ”—

@@ -5,6 +5,7 @@ use crate::{AgentServer, AgentServerCommand};
 use acp_thread::{AgentConnection, LoadError};
 use anyhow::Result;
 use gpui::{Entity, Task};
+use language_models::provider::google::GoogleLanguageModelProvider;
 use project::Project;
 use settings::SettingsStore;
 use ui::App;
@@ -47,7 +48,7 @@ impl AgentServer for Gemini {
                 settings.get::<AllAgentServersSettings>(None).gemini.clone()
             })?;
 
-            let Some(command) =
+            let Some(mut command) =
                 AgentServerCommand::resolve("gemini", &[ACP_ARG], None, settings, &project, cx).await
             else {
                 return Err(LoadError::NotInstalled {
@@ -57,6 +58,10 @@ impl AgentServer for Gemini {
                 }.into());
             };
 
+            if let Some(api_key)= cx.update(GoogleLanguageModelProvider::api_key)?.await.ok() {
+                command.env.get_or_insert_default().insert("GEMINI_API_KEY".to_owned(), api_key.key);
+            }
+
             let result = crate::acp::connect(server_name, command.clone(), &root_dir, cx).await;
             if result.is_err() {
                 let version_fut = util::command::new_smol_command(&command.path)

crates/agent_ui/src/acp/thread_view.rs πŸ”—

@@ -278,6 +278,7 @@ enum ThreadState {
         connection: Rc<dyn AgentConnection>,
         description: Option<Entity<Markdown>>,
         configuration_view: Option<AnyView>,
+        pending_auth_method: Option<acp::AuthMethodId>,
         _subscription: Option<Subscription>,
     },
 }
@@ -563,6 +564,7 @@ impl AcpThreadView {
 
         this.update(cx, |this, cx| {
             this.thread_state = ThreadState::Unauthenticated {
+                pending_auth_method: None,
                 connection,
                 configuration_view,
                 description: err
@@ -999,12 +1001,74 @@ impl AcpThreadView {
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
-        let ThreadState::Unauthenticated { ref connection, .. } = self.thread_state else {
+        let ThreadState::Unauthenticated {
+            connection,
+            pending_auth_method,
+            configuration_view,
+            ..
+        } = &mut self.thread_state
+        else {
             return;
         };
 
+        if method.0.as_ref() == "gemini-api-key" {
+            let registry = LanguageModelRegistry::global(cx);
+            let provider = registry
+                .read(cx)
+                .provider(&language_model::GOOGLE_PROVIDER_ID)
+                .unwrap();
+            if !provider.is_authenticated(cx) {
+                let this = cx.weak_entity();
+                let agent = self.agent.clone();
+                let connection = connection.clone();
+                window.defer(cx, |window, cx| {
+                    Self::handle_auth_required(
+                        this,
+                        AuthRequired {
+                            description: Some("GEMINI_API_KEY must be set".to_owned()),
+                            provider_id: Some(language_model::GOOGLE_PROVIDER_ID),
+                        },
+                        agent,
+                        connection,
+                        window,
+                        cx,
+                    );
+                });
+                return;
+            }
+        } else if method.0.as_ref() == "vertex-ai"
+            && std::env::var("GOOGLE_API_KEY").is_err()
+            && (std::env::var("GOOGLE_CLOUD_PROJECT").is_err()
+                || (std::env::var("GOOGLE_CLOUD_PROJECT").is_err()))
+        {
+            let this = cx.weak_entity();
+            let agent = self.agent.clone();
+            let connection = connection.clone();
+
+            window.defer(cx, |window, cx| {
+                    Self::handle_auth_required(
+                        this,
+                        AuthRequired {
+                            description: Some(
+                                "GOOGLE_API_KEY must be set in the environment to use Vertex AI authentication for Gemini CLI. Please export it and restart Zed."
+                                    .to_owned(),
+                            ),
+                            provider_id: None,
+                        },
+                        agent,
+                        connection,
+                        window,
+                        cx,
+                    )
+                });
+            return;
+        }
+
         self.thread_error.take();
+        configuration_view.take();
+        pending_auth_method.replace(method.clone());
         let authenticate = connection.authenticate(method, cx);
+        cx.notify();
         self.auth_task = Some(cx.spawn_in(window, {
             let project = self.project.clone();
             let agent = self.agent.clone();
@@ -2425,6 +2489,7 @@ impl AcpThreadView {
         connection: &Rc<dyn AgentConnection>,
         description: Option<&Entity<Markdown>>,
         configuration_view: Option<&AnyView>,
+        pending_auth_method: Option<&acp::AuthMethodId>,
         window: &mut Window,
         cx: &Context<Self>,
     ) -> Div {
@@ -2456,17 +2521,80 @@ impl AcpThreadView {
                     .cloned()
                     .map(|view| div().px_4().w_full().max_w_128().child(view)),
             )
-            .child(h_flex().mt_1p5().justify_center().children(
-                connection.auth_methods().iter().map(|method| {
-                    Button::new(SharedString::from(method.id.0.clone()), method.name.clone())
-                        .on_click({
-                            let method_id = method.id.clone();
-                            cx.listener(move |this, _, window, cx| {
-                                this.authenticate(method_id.clone(), window, cx)
+            .when(
+                configuration_view.is_none()
+                    && description.is_none()
+                    && pending_auth_method.is_none(),
+                |el| {
+                    el.child(
+                        div()
+                            .text_ui(cx)
+                            .text_center()
+                            .px_4()
+                            .w_full()
+                            .max_w_128()
+                            .child(Label::new("Authentication required")),
+                    )
+                },
+            )
+            .when_some(pending_auth_method, |el, _| {
+                let spinner_icon = div()
+                    .px_0p5()
+                    .id("generating")
+                    .tooltip(Tooltip::text("Generating Changes…"))
+                    .child(
+                        Icon::new(IconName::ArrowCircle)
+                            .size(IconSize::Small)
+                            .with_animation(
+                                "arrow-circle",
+                                Animation::new(Duration::from_secs(2)).repeat(),
+                                |icon, delta| {
+                                    icon.transform(Transformation::rotate(percentage(delta)))
+                                },
+                            )
+                            .into_any_element(),
+                    )
+                    .into_any();
+                el.child(
+                    h_flex()
+                        .text_ui(cx)
+                        .text_center()
+                        .justify_center()
+                        .gap_2()
+                        .px_4()
+                        .w_full()
+                        .max_w_128()
+                        .child(Label::new("Authenticating..."))
+                        .child(spinner_icon),
+                )
+            })
+            .child(
+                h_flex()
+                    .mt_1p5()
+                    .gap_1()
+                    .flex_wrap()
+                    .justify_center()
+                    .children(connection.auth_methods().iter().enumerate().rev().map(
+                        |(ix, method)| {
+                            Button::new(
+                                SharedString::from(method.id.0.clone()),
+                                method.name.clone(),
+                            )
+                            .style(ButtonStyle::Outlined)
+                            .when(ix == 0, |el| {
+                                el.style(ButtonStyle::Tinted(ui::TintColor::Accent))
                             })
-                        })
-                }),
-            ))
+                            .size(ButtonSize::Medium)
+                            .label_size(LabelSize::Small)
+                            .on_click({
+                                let method_id = method.id.clone();
+                                cx.listener(move |this, _, window, cx| {
+                                    this.authenticate(method_id.clone(), window, cx)
+                                })
+                            })
+                        },
+                    )),
+            )
     }
 
     fn render_load_error(&self, e: &LoadError, cx: &Context<Self>) -> AnyElement {
@@ -2551,6 +2679,8 @@ impl AcpThreadView {
             let install_command = install_command.clone();
             container = container.child(
                 Button::new("install", install_message)
+                    .style(ButtonStyle::Tinted(ui::TintColor::Accent))
+                    .size(ButtonSize::Medium)
                     .tooltip(Tooltip::text(install_command.clone()))
                     .on_click(cx.listener(move |this, _, window, cx| {
                         let task = this
@@ -4372,11 +4502,13 @@ impl Render for AcpThreadView {
                     connection,
                     description,
                     configuration_view,
+                    pending_auth_method,
                     ..
                 } => self.render_auth_required_state(
                     connection,
                     description.as_ref(),
                     configuration_view.as_ref(),
+                    pending_auth_method.as_ref(),
                     window,
                     cx,
                 ),

crates/language_models/src/provider/google.rs πŸ”—

@@ -12,9 +12,9 @@ use gpui::{
 };
 use http_client::HttpClient;
 use language_model::{
-    AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
-    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LanguageModelToolUse,
-    LanguageModelToolUseId, MessageContent, StopReason,
+    AuthenticateError, ConfigurationViewTargetAgent, LanguageModelCompletionError,
+    LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
+    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
 };
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -37,6 +37,8 @@ use util::ResultExt;
 use crate::AllLanguageModelSettings;
 use crate::ui::InstructionListItem;
 
+use super::anthropic::ApiKey;
+
 const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
 const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
 
@@ -198,6 +200,33 @@ impl GoogleLanguageModelProvider {
             request_limiter: RateLimiter::new(4),
         })
     }
+
+    pub fn api_key(cx: &mut App) -> Task<Result<ApiKey>> {
+        let credentials_provider = <dyn CredentialsProvider>::global(cx);
+        let api_url = AllLanguageModelSettings::get_global(cx)
+            .google
+            .api_url
+            .clone();
+
+        if let Ok(key) = std::env::var(GEMINI_API_KEY_VAR) {
+            Task::ready(Ok(ApiKey {
+                key,
+                from_env: true,
+            }))
+        } else {
+            cx.spawn(async move |cx| {
+                let (_, api_key) = credentials_provider
+                    .read_credentials(&api_url, cx)
+                    .await?
+                    .ok_or(AuthenticateError::CredentialsNotFound)?;
+
+                Ok(ApiKey {
+                    key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
+                    from_env: false,
+                })
+            })
+        }
+    }
 }
 
 impl LanguageModelProviderState for GoogleLanguageModelProvider {
@@ -279,11 +308,11 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
 
     fn configuration_view(
         &self,
-        _target_agent: language_model::ConfigurationViewTargetAgent,
+        target_agent: language_model::ConfigurationViewTargetAgent,
         window: &mut Window,
         cx: &mut App,
     ) -> AnyView {
-        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
+        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
             .into()
     }
 
@@ -776,11 +805,17 @@ fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
 struct ConfigurationView {
     api_key_editor: Entity<Editor>,
     state: gpui::Entity<State>,
+    target_agent: language_model::ConfigurationViewTargetAgent,
     load_credentials_task: Option<Task<()>>,
 }
 
 impl ConfigurationView {
-    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
+    fn new(
+        state: gpui::Entity<State>,
+        target_agent: language_model::ConfigurationViewTargetAgent,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) -> Self {
         cx.observe(&state, |_, _, cx| {
             cx.notify();
         })
@@ -810,6 +845,7 @@ impl ConfigurationView {
                 editor.set_placeholder_text("AIzaSy...", cx);
                 editor
             }),
+            target_agent,
             state,
             load_credentials_task,
         }
@@ -885,7 +921,10 @@ impl Render for ConfigurationView {
             v_flex()
                 .size_full()
                 .on_action(cx.listener(Self::save_api_key))
-                .child(Label::new("To use Zed's agent with Google AI, you need to add an API key. Follow these steps:"))
+                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match self.target_agent {
+                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI",
+                    ConfigurationViewTargetAgent::Other(agent) => agent,
+                })))
                 .child(
                     List::new()
                         .child(InstructionListItem::new(