Use anthropic provider key

Agus Zubiaga created

Change summary

crates/acp_thread/src/connection.rs              | 27 ++++++
crates/agent_servers/src/acp/v0.rs               |  2 
crates/agent_servers/src/acp/v1.rs               |  8 +
crates/agent_servers/src/claude.rs               | 44 ++++++++++
crates/agent_ui/src/acp/thread_view.rs           | 69 ++++++++++++++---
crates/gpui/src/subscription.rs                  |  6 +
crates/language_model/src/language_model.rs      |  3 
crates/language_model/src/registry.rs            |  1 
crates/language_models/src/provider/anthropic.rs | 56 ++++++++------
9 files changed, 171 insertions(+), 45 deletions(-)

Detailed changes

crates/acp_thread/src/connection.rs πŸ”—

@@ -80,12 +80,35 @@ pub trait AgentSessionResume {
 }
 
 #[derive(Debug)]
-pub struct AuthRequired;
+pub struct AuthRequired {
+    pub description: Option<String>,
+    /// A Task that resolves when authentication is updated
+    pub update_task: Option<Task<()>>,
+}
+
+impl AuthRequired {
+    pub fn new() -> Self {
+        Self {
+            description: None,
+            update_task: None,
+        }
+    }
+
+    pub fn with_description(mut self, description: String) -> Self {
+        self.description = Some(description);
+        self
+    }
+
+    pub fn with_update(mut self, update: Task<()>) -> Self {
+        self.update_task = Some(update);
+        self
+    }
+}
 
 impl Error for AuthRequired {}
 impl fmt::Display for AuthRequired {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "AuthRequired")
+        write!(f, "Authentication required")
     }
 }
 

crates/agent_servers/src/acp/v0.rs πŸ”—

@@ -437,7 +437,7 @@ impl AgentConnection for AcpConnection {
             let result = acp_old::InitializeParams::response_from_any(result)?;
 
             if !result.is_authenticated {
-                anyhow::bail!(AuthRequired)
+                anyhow::bail!(AuthRequired::new())
             }
 
             cx.update(|cx| {

crates/agent_servers/src/acp/v1.rs πŸ”—

@@ -140,7 +140,13 @@ impl AgentConnection for AcpConnection {
                 .await
                 .map_err(|err| {
                     if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
-                        anyhow!(AuthRequired)
+                        let mut error = AuthRequired::new();
+
+                        if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
+                            error = error.with_description(err.message);
+                        }
+
+                        anyhow!(error)
                     } else {
                         anyhow!(err)
                     }

crates/agent_servers/src/claude.rs πŸ”—

@@ -3,6 +3,8 @@ pub mod tools;
 
 use collections::HashMap;
 use context_server::listener::McpServerTool;
+use language_model::LanguageModelRegistry;
+use language_models::provider::anthropic::AnthropicLanguageModelProvider;
 use project::Project;
 use settings::SettingsStore;
 use smol::process::Child;
@@ -11,6 +13,7 @@ use std::cell::RefCell;
 use std::fmt::Display;
 use std::path::Path;
 use std::rc::Rc;
+use std::sync::Arc;
 use uuid::Uuid;
 
 use agent_client_protocol as acp;
@@ -96,12 +99,49 @@ impl AgentConnection for ClaudeAgentConnection {
                 anyhow::bail!("Failed to find claude binary");
             };
 
+            let anthropic: Arc<AnthropicLanguageModelProvider> = cx.update(|cx| {
+                let registry = LanguageModelRegistry::global(cx);
+                let provider: Arc<dyn Any + Send + Sync> = registry
+                    .read(cx)
+                    .provider(&language_model::ANTHROPIC_PROVIDER_ID)
+                    .context("Failed to get Anthropic provider")?;
+
+                Arc::downcast::<AnthropicLanguageModelProvider>(provider)
+                    .map_err(|_| anyhow!("Failed to downcast provider"))
+            })??;
+
             let api_key = cx
-                .update(|cx| language_models::provider::anthropic::ApiKey::get(cx))?
+                .update(|cx| AnthropicLanguageModelProvider::api_key(cx))?
                 .await
                 .map_err(|err| {
                     if err.is::<language_model::AuthenticateError>() {
-                        anyhow!(AuthRequired)
+                        let (update_tx, update_rx) = oneshot::channel();
+                        let mut update_tx = Some(update_tx);
+
+                        let sub = cx
+                            .update(|cx| {
+                                anthropic.observe(
+                                    move |_cx| {
+                                        if let Some(update_tx) = update_tx.take() {
+                                            update_tx.send(()).ok();
+                                        }
+                                    },
+                                    cx,
+                                )
+                            })
+                            .ok();
+
+                        let update_task = cx.foreground_executor().spawn(async move {
+                            update_rx.await.ok();
+                            drop(sub)
+                        });
+
+                        anyhow!(
+                            AuthRequired::new()
+                                .with_description(
+                                    "To use Claude Code in Zed, you need an [Anthropic API key](https://console.anthropic.com/settings/keys)\n\nAdd one in [settings](zed:///agent/settings) or set the `ANTHROPIC_API_KEY` variable".into())
+                                .with_update(update_task)
+                        )
                     } else {
                         anyhow!(err)
                     }

crates/agent_ui/src/acp/thread_view.rs πŸ”—

@@ -137,6 +137,7 @@ enum ThreadState {
     LoadError(LoadError),
     Unauthenticated {
         connection: Rc<dyn AgentConnection>,
+        description: Option<Entity<Markdown>>,
     },
     ServerExited {
         status: ExitStatus,
@@ -269,15 +270,40 @@ impl AcpThreadView {
             let result = match result.await {
                 Err(e) => {
                     let mut cx = cx.clone();
-                    if e.is::<acp_thread::AuthRequired>() {
-                        this.update(&mut cx, |this, cx| {
-                            this.thread_state = ThreadState::Unauthenticated { connection };
-                            cx.notify();
-                        })
-                        .ok();
-                        return;
-                    } else {
-                        Err(e)
+                    match e.downcast::<acp_thread::AuthRequired>() {
+                        Ok(mut err) => {
+                            if let Some(update_task) = err.update_task.take() {
+                                let this = this.clone();
+                                let project = project.clone();
+                                cx.spawn(async move |cx| {
+                                    update_task.await;
+                                    this.update_in(cx, |this, window, cx| {
+                                        this.thread_state = Self::initial_state(
+                                            agent,
+                                            this.workspace.clone(),
+                                            project.clone(),
+                                            window,
+                                            cx,
+                                        );
+                                        cx.notify();
+                                    })
+                                    .ok();
+                                })
+                                .detach();
+                            }
+                            this.update(&mut cx, |this, cx| {
+                                this.thread_state = ThreadState::Unauthenticated {
+                                    connection,
+                                    description: err.description.clone().map(|desc| {
+                                        cx.new(|cx| Markdown::new(desc.into(), None, None, cx))
+                                    }),
+                                };
+                                cx.notify();
+                            })
+                            .ok();
+                            return;
+                        }
+                        Err(err) => Err(err),
                     }
                 }
                 Ok(thread) => Ok(thread),
@@ -369,7 +395,7 @@ impl AcpThreadView {
             ThreadState::Ready { thread, .. } => thread.read(cx).title(),
             ThreadState::Loading { .. } => "Loading…".into(),
             ThreadState::LoadError(_) => "Failed to load".into(),
-            ThreadState::Unauthenticated { .. } => "Not authenticated".into(),
+            ThreadState::Unauthenticated { .. } => "Authentication Required".into(),
             ThreadState::ServerExited { .. } => "Server exited unexpectedly".into(),
         }
     }
@@ -708,7 +734,7 @@ impl AcpThreadView {
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
-        let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
+        let ThreadState::Unauthenticated { ref connection, .. } = self.thread_state else {
             return;
         };
 
@@ -1851,7 +1877,7 @@ impl AcpThreadView {
                     .mt_4()
                     .mb_1()
                     .justify_center()
-                    .child(Headline::new("Not Authenticated").size(HeadlineSize::Medium)),
+                    .child(Headline::new("Authentication Required").size(HeadlineSize::Medium)),
             )
             .into_any()
     }
@@ -2778,6 +2804,13 @@ impl AcpThreadView {
                     cx.open_url(url.as_str());
                 }
             })
+        } else if url == "zed:///agent/settings" {
+            workspace.update(cx, |workspace, cx| {
+                if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
+                    workspace.focus_panel::<AgentPanel>(window, cx);
+                    panel.update(cx, |panel, cx| panel.open_configuration(window, cx));
+                }
+            });
         } else {
             cx.open_url(&url);
         }
@@ -3347,12 +3380,22 @@ impl Render for AcpThreadView {
             .on_action(cx.listener(Self::toggle_burn_mode))
             .bg(cx.theme().colors().panel_background)
             .child(match &self.thread_state {
-                ThreadState::Unauthenticated { connection } => v_flex()
+                ThreadState::Unauthenticated {
+                    connection,
+                    description,
+                } => v_flex()
                     .p_2()
+                    .gap_2()
                     .flex_1()
                     .items_center()
                     .justify_center()
                     .child(self.render_pending_auth_state())
+                    .text_ui(cx)
+                    .text_center()
+                    .text_color(cx.theme().colors().text_muted)
+                    .children(description.clone().map(|desc| {
+                        self.render_markdown(desc, default_markdown_style(false, window, cx))
+                    }))
                     .child(h_flex().mt_1p5().justify_center().children(
                         connection.auth_methods().into_iter().map(|method| {
                             Button::new(

crates/gpui/src/subscription.rs πŸ”—

@@ -201,3 +201,9 @@ impl Drop for Subscription {
         }
     }
 }
+
+impl std::fmt::Debug for Subscription {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Subscription").finish()
+    }
+}

crates/language_model/src/language_model.rs πŸ”—

@@ -20,6 +20,7 @@ use icons::IconName;
 use parking_lot::Mutex;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize, de::DeserializeOwned};
+use std::any::Any;
 use std::ops::{Add, Sub};
 use std::str::FromStr;
 use std::sync::Arc;
@@ -620,7 +621,7 @@ pub enum AuthenticateError {
     Other(#[from] anyhow::Error),
 }
 
-pub trait LanguageModelProvider: 'static {
+pub trait LanguageModelProvider: Any + Send + Sync {
     fn id(&self) -> LanguageModelProviderId;
     fn name(&self) -> LanguageModelProviderName;
     fn icon(&self) -> IconName {

crates/language_model/src/registry.rs πŸ”—

@@ -108,6 +108,7 @@ pub enum Event {
     CommitMessageModelChanged,
     ThreadSummaryModelChanged,
     ProviderStateChanged,
+    ProviderAuthUpdated,
     AddedProvider(LanguageModelProviderId),
     RemovedProvider(LanguageModelProviderId),
 }

crates/language_models/src/provider/anthropic.rs πŸ”—

@@ -153,7 +153,7 @@ impl State {
             return Task::ready(Ok(()));
         }
 
-        let key = ApiKey::get(cx);
+        let key = AnthropicLanguageModelProvider::api_key(cx);
 
         cx.spawn(async move |this, cx| {
             let key = key.await?;
@@ -174,8 +174,30 @@ pub struct ApiKey {
     pub from_env: bool,
 }
 
-impl ApiKey {
-    pub fn get(cx: &mut App) -> Task<Result<Self>> {
+impl AnthropicLanguageModelProvider {
+    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+        let state = cx.new(|cx| State {
+            api_key: None,
+            api_key_from_env: false,
+            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
+                cx.notify();
+            }),
+        });
+
+        Self { http_client, state }
+    }
+
+    fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(AnthropicModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        })
+    }
+
+    pub fn api_key(cx: &mut App) -> Task<Result<ApiKey>> {
         let credentials_provider = <dyn CredentialsProvider>::global(cx);
         let api_url = AllLanguageModelSettings::get_global(cx)
             .anthropic
@@ -201,29 +223,13 @@ impl ApiKey {
             })
         }
     }
-}
-
-impl AnthropicLanguageModelProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
-        let state = cx.new(|cx| State {
-            api_key: None,
-            api_key_from_env: false,
-            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
-                cx.notify();
-            }),
-        });
 
-        Self { http_client, state }
-    }
-
-    fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
-        Arc::new(AnthropicModel {
-            id: LanguageModelId::from(model.id().to_string()),
-            model,
-            state: self.state.clone(),
-            http_client: self.http_client.clone(),
-            request_limiter: RateLimiter::new(4),
-        })
+    pub fn observe(
+        &self,
+        mut on_notify: impl FnMut(&mut App) + 'static,
+        cx: &mut App,
+    ) -> Subscription {
+        cx.observe(&self.state, move |_, cx| on_notify(cx))
     }
 }