diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 48310f07ce3cb162111b6c88d7f39f36b39b1f77..9c64ae65b0f6ca822f2e7e4b258742d18ea2f2a1 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -80,12 +80,35 @@ pub trait AgentSessionResume { } #[derive(Debug)] -pub struct AuthRequired; +pub struct AuthRequired { + pub description: Option, + /// A Task that resolves when authentication is updated + pub update_task: Option>, +} + +impl AuthRequired { + pub fn new() -> Self { + Self { + description: None, + update_task: None, + } + } + + pub fn with_description(mut self, description: String) -> Self { + self.description = Some(description); + self + } + + pub fn with_update(mut self, update: Task<()>) -> Self { + self.update_task = Some(update); + self + } +} impl Error for AuthRequired {} impl fmt::Display for AuthRequired { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "AuthRequired") + write!(f, "Authentication required") } } diff --git a/crates/agent_servers/src/acp/v0.rs b/crates/agent_servers/src/acp/v0.rs index 74647f73133f23681f18da1d2bddb02675c55a22..551e9fa01a6c79c4a8c4d67b9af6996a0235086f 100644 --- a/crates/agent_servers/src/acp/v0.rs +++ b/crates/agent_servers/src/acp/v0.rs @@ -437,7 +437,7 @@ impl AgentConnection for AcpConnection { let result = acp_old::InitializeParams::response_from_any(result)?; if !result.is_authenticated { - anyhow::bail!(AuthRequired) + anyhow::bail!(AuthRequired::new()) } cx.update(|cx| { diff --git a/crates/agent_servers/src/acp/v1.rs b/crates/agent_servers/src/acp/v1.rs index b77b5ef36d26ebec9bae48cfe5c1a36c003e230b..93a5ae757a3fde11660db24e182b453e0fbb9850 100644 --- a/crates/agent_servers/src/acp/v1.rs +++ b/crates/agent_servers/src/acp/v1.rs @@ -140,7 +140,13 @@ impl AgentConnection for AcpConnection { .await .map_err(|err| { if err.code == acp::ErrorCode::AUTH_REQUIRED.code { - anyhow!(AuthRequired) + let mut error = AuthRequired::new(); + + if err.message != acp::ErrorCode::AUTH_REQUIRED.message { + error = error.with_description(err.message); + } + + anyhow!(error) } else { anyhow!(err) } diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 786fc118fc8ac260ed9a2ad6fa1bf0c00c606356..8fb4b898b14e58e62cd77be9b6cc79704f687491 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -3,6 +3,8 @@ pub mod tools; use collections::HashMap; use context_server::listener::McpServerTool; +use language_model::LanguageModelRegistry; +use language_models::provider::anthropic::AnthropicLanguageModelProvider; use project::Project; use settings::SettingsStore; use smol::process::Child; @@ -11,6 +13,7 @@ use std::cell::RefCell; use std::fmt::Display; use std::path::Path; use std::rc::Rc; +use std::sync::Arc; use uuid::Uuid; use agent_client_protocol as acp; @@ -96,12 +99,49 @@ impl AgentConnection for ClaudeAgentConnection { anyhow::bail!("Failed to find claude binary"); }; + let anthropic: Arc = cx.update(|cx| { + let registry = LanguageModelRegistry::global(cx); + let provider: Arc = registry + .read(cx) + .provider(&language_model::ANTHROPIC_PROVIDER_ID) + .context("Failed to get Anthropic provider")?; + + Arc::downcast::(provider) + .map_err(|_| anyhow!("Failed to downcast provider")) + })??; + let api_key = cx - .update(|cx| language_models::provider::anthropic::ApiKey::get(cx))? + .update(|cx| AnthropicLanguageModelProvider::api_key(cx))? .await .map_err(|err| { if err.is::() { - anyhow!(AuthRequired) + let (update_tx, update_rx) = oneshot::channel(); + let mut update_tx = Some(update_tx); + + let sub = cx + .update(|cx| { + anthropic.observe( + move |_cx| { + if let Some(update_tx) = update_tx.take() { + update_tx.send(()).ok(); + } + }, + cx, + ) + }) + .ok(); + + let update_task = cx.foreground_executor().spawn(async move { + update_rx.await.ok(); + drop(sub) + }); + + anyhow!( + AuthRequired::new() + .with_description( + "To use Claude Code in Zed, you need an [Anthropic API key](https://console.anthropic.com/settings/keys)\n\nAdd one in [settings](zed:///agent/settings) or set the `ANTHROPIC_API_KEY` variable".into()) + .with_update(update_task) + ) } else { anyhow!(err) } diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 4760677fa1e9ff270f69cdc1adbd49c0e86f799c..7ad71e1f46fac3196ad34b6b6ed85053b1ef09ac 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -137,6 +137,7 @@ enum ThreadState { LoadError(LoadError), Unauthenticated { connection: Rc, + description: Option>, }, ServerExited { status: ExitStatus, @@ -269,15 +270,40 @@ impl AcpThreadView { let result = match result.await { Err(e) => { let mut cx = cx.clone(); - if e.is::() { - this.update(&mut cx, |this, cx| { - this.thread_state = ThreadState::Unauthenticated { connection }; - cx.notify(); - }) - .ok(); - return; - } else { - Err(e) + match e.downcast::() { + Ok(mut err) => { + if let Some(update_task) = err.update_task.take() { + let this = this.clone(); + let project = project.clone(); + cx.spawn(async move |cx| { + update_task.await; + this.update_in(cx, |this, window, cx| { + this.thread_state = Self::initial_state( + agent, + this.workspace.clone(), + project.clone(), + window, + cx, + ); + cx.notify(); + }) + .ok(); + }) + .detach(); + } + this.update(&mut cx, |this, cx| { + this.thread_state = ThreadState::Unauthenticated { + connection, + description: err.description.clone().map(|desc| { + cx.new(|cx| Markdown::new(desc.into(), None, None, cx)) + }), + }; + cx.notify(); + }) + .ok(); + return; + } + Err(err) => Err(err), } } Ok(thread) => Ok(thread), @@ -369,7 +395,7 @@ impl AcpThreadView { ThreadState::Ready { thread, .. } => thread.read(cx).title(), ThreadState::Loading { .. } => "Loading…".into(), ThreadState::LoadError(_) => "Failed to load".into(), - ThreadState::Unauthenticated { .. } => "Not authenticated".into(), + ThreadState::Unauthenticated { .. } => "Authentication Required".into(), ThreadState::ServerExited { .. } => "Server exited unexpectedly".into(), } } @@ -708,7 +734,7 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - let ThreadState::Unauthenticated { ref connection } = self.thread_state else { + let ThreadState::Unauthenticated { ref connection, .. } = self.thread_state else { return; }; @@ -1851,7 +1877,7 @@ impl AcpThreadView { .mt_4() .mb_1() .justify_center() - .child(Headline::new("Not Authenticated").size(HeadlineSize::Medium)), + .child(Headline::new("Authentication Required").size(HeadlineSize::Medium)), ) .into_any() } @@ -2778,6 +2804,13 @@ impl AcpThreadView { cx.open_url(url.as_str()); } }) + } else if url == "zed:///agent/settings" { + workspace.update(cx, |workspace, cx| { + if let Some(panel) = workspace.panel::(cx) { + workspace.focus_panel::(window, cx); + panel.update(cx, |panel, cx| panel.open_configuration(window, cx)); + } + }); } else { cx.open_url(&url); } @@ -3347,12 +3380,22 @@ impl Render for AcpThreadView { .on_action(cx.listener(Self::toggle_burn_mode)) .bg(cx.theme().colors().panel_background) .child(match &self.thread_state { - ThreadState::Unauthenticated { connection } => v_flex() + ThreadState::Unauthenticated { + connection, + description, + } => v_flex() .p_2() + .gap_2() .flex_1() .items_center() .justify_center() .child(self.render_pending_auth_state()) + .text_ui(cx) + .text_center() + .text_color(cx.theme().colors().text_muted) + .children(description.clone().map(|desc| { + self.render_markdown(desc, default_markdown_style(false, window, cx)) + })) .child(h_flex().mt_1p5().justify_center().children( connection.auth_methods().into_iter().map(|method| { Button::new( diff --git a/crates/gpui/src/subscription.rs b/crates/gpui/src/subscription.rs index a584f1a45f82094ce9b867bc5f43805c48f93ebe..bd869f8d32cdfc81917fc2287b7dc62fac7d727d 100644 --- a/crates/gpui/src/subscription.rs +++ b/crates/gpui/src/subscription.rs @@ -201,3 +201,9 @@ impl Drop for Subscription { } } } + +impl std::fmt::Debug for Subscription { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Subscription").finish() + } +} diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 1637d2de8a3c14b910ea345c03a4eb5db13df28d..d74e8b7076bfac4e40504bf52f46bed0929ab0f7 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -20,6 +20,7 @@ use icons::IconName; use parking_lot::Mutex; use schemars::JsonSchema; use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use std::any::Any; use std::ops::{Add, Sub}; use std::str::FromStr; use std::sync::Arc; @@ -620,7 +621,7 @@ pub enum AuthenticateError { Other(#[from] anyhow::Error), } -pub trait LanguageModelProvider: 'static { +pub trait LanguageModelProvider: Any + Send + Sync { fn id(&self) -> LanguageModelProviderId; fn name(&self) -> LanguageModelProviderName; fn icon(&self) -> IconName { diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 7cf071808a2c0d95bf9aa5a41eaa260cff533d57..6b4f471b0f6fd5d9687811084da153313a289827 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -108,6 +108,7 @@ pub enum Event { CommitMessageModelChanged, ThreadSummaryModelChanged, ProviderStateChanged, + ProviderAuthUpdated, AddedProvider(LanguageModelProviderId), RemovedProvider(LanguageModelProviderId), } diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 76f8a9b71d063d1eca104c5ea43b802657bb962a..3f14841210678bbe90242d3b87dde29d24e12f12 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -153,7 +153,7 @@ impl State { return Task::ready(Ok(())); } - let key = ApiKey::get(cx); + let key = AnthropicLanguageModelProvider::api_key(cx); cx.spawn(async move |this, cx| { let key = key.await?; @@ -174,8 +174,30 @@ pub struct ApiKey { pub from_env: bool, } -impl ApiKey { - pub fn get(cx: &mut App) -> Task> { +impl AnthropicLanguageModelProvider { + pub fn new(http_client: Arc, cx: &mut App) -> Self { + let state = cx.new(|cx| State { + api_key: None, + api_key_from_env: false, + _subscription: cx.observe_global::(|_, cx| { + cx.notify(); + }), + }); + + Self { http_client, state } + } + + fn create_language_model(&self, model: anthropic::Model) -> Arc { + Arc::new(AnthropicModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + request_limiter: RateLimiter::new(4), + }) + } + + pub fn api_key(cx: &mut App) -> Task> { let credentials_provider = ::global(cx); let api_url = AllLanguageModelSettings::get_global(cx) .anthropic @@ -201,29 +223,13 @@ impl ApiKey { }) } } -} - -impl AnthropicLanguageModelProvider { - pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_, cx| { - cx.notify(); - }), - }); - Self { http_client, state } - } - - fn create_language_model(&self, model: anthropic::Model) -> Arc { - Arc::new(AnthropicModel { - id: LanguageModelId::from(model.id().to_string()), - model, - state: self.state.clone(), - http_client: self.http_client.clone(), - request_limiter: RateLimiter::new(4), - }) + pub fn observe( + &self, + mut on_notify: impl FnMut(&mut App) + 'static, + cx: &mut App, + ) -> Subscription { + cx.observe(&self.state, move |_, cx| on_notify(cx)) } }