Detailed changes
@@ -80,12 +80,35 @@ pub trait AgentSessionResume {
}
#[derive(Debug)]
-pub struct AuthRequired;
+pub struct AuthRequired {
+ pub description: Option<String>,
+ /// A Task that resolves when authentication is updated
+ pub update_task: Option<Task<()>>,
+}
+
+impl AuthRequired {
+ pub fn new() -> Self {
+ Self {
+ description: None,
+ update_task: None,
+ }
+ }
+
+ pub fn with_description(mut self, description: String) -> Self {
+ self.description = Some(description);
+ self
+ }
+
+ pub fn with_update(mut self, update: Task<()>) -> Self {
+ self.update_task = Some(update);
+ self
+ }
+}
impl Error for AuthRequired {}
impl fmt::Display for AuthRequired {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "AuthRequired")
+ write!(f, "Authentication required")
}
}
@@ -437,7 +437,7 @@ impl AgentConnection for AcpConnection {
let result = acp_old::InitializeParams::response_from_any(result)?;
if !result.is_authenticated {
- anyhow::bail!(AuthRequired)
+ anyhow::bail!(AuthRequired::new())
}
cx.update(|cx| {
@@ -140,7 +140,13 @@ impl AgentConnection for AcpConnection {
.await
.map_err(|err| {
if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
- anyhow!(AuthRequired)
+ let mut error = AuthRequired::new();
+
+ if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
+ error = error.with_description(err.message);
+ }
+
+ anyhow!(error)
} else {
anyhow!(err)
}
@@ -3,6 +3,8 @@ pub mod tools;
use collections::HashMap;
use context_server::listener::McpServerTool;
+use language_model::LanguageModelRegistry;
+use language_models::provider::anthropic::AnthropicLanguageModelProvider;
use project::Project;
use settings::SettingsStore;
use smol::process::Child;
@@ -11,6 +13,7 @@ use std::cell::RefCell;
use std::fmt::Display;
use std::path::Path;
use std::rc::Rc;
+use std::sync::Arc;
use uuid::Uuid;
use agent_client_protocol as acp;
@@ -96,12 +99,49 @@ impl AgentConnection for ClaudeAgentConnection {
anyhow::bail!("Failed to find claude binary");
};
+ let anthropic: Arc<AnthropicLanguageModelProvider> = cx.update(|cx| {
+ let registry = LanguageModelRegistry::global(cx);
+ let provider: Arc<dyn Any + Send + Sync> = registry
+ .read(cx)
+ .provider(&language_model::ANTHROPIC_PROVIDER_ID)
+ .context("Failed to get Anthropic provider")?;
+
+ Arc::downcast::<AnthropicLanguageModelProvider>(provider)
+ .map_err(|_| anyhow!("Failed to downcast provider"))
+ })??;
+
let api_key = cx
- .update(|cx| language_models::provider::anthropic::ApiKey::get(cx))?
+ .update(|cx| AnthropicLanguageModelProvider::api_key(cx))?
.await
.map_err(|err| {
if err.is::<language_model::AuthenticateError>() {
- anyhow!(AuthRequired)
+ let (update_tx, update_rx) = oneshot::channel();
+ let mut update_tx = Some(update_tx);
+
+ let sub = cx
+ .update(|cx| {
+ anthropic.observe(
+ move |_cx| {
+ if let Some(update_tx) = update_tx.take() {
+ update_tx.send(()).ok();
+ }
+ },
+ cx,
+ )
+ })
+ .ok();
+
+ let update_task = cx.foreground_executor().spawn(async move {
+ update_rx.await.ok();
+ drop(sub)
+ });
+
+ anyhow!(
+ AuthRequired::new()
+ .with_description(
+ "To use Claude Code in Zed, you need an [Anthropic API key](https://console.anthropic.com/settings/keys)\n\nAdd one in [settings](zed:///agent/settings) or set the `ANTHROPIC_API_KEY` variable".into())
+ .with_update(update_task)
+ )
} else {
anyhow!(err)
}
@@ -137,6 +137,7 @@ enum ThreadState {
LoadError(LoadError),
Unauthenticated {
connection: Rc<dyn AgentConnection>,
+ description: Option<Entity<Markdown>>,
},
ServerExited {
status: ExitStatus,
@@ -269,15 +270,40 @@ impl AcpThreadView {
let result = match result.await {
Err(e) => {
let mut cx = cx.clone();
- if e.is::<acp_thread::AuthRequired>() {
- this.update(&mut cx, |this, cx| {
- this.thread_state = ThreadState::Unauthenticated { connection };
- cx.notify();
- })
- .ok();
- return;
- } else {
- Err(e)
+ match e.downcast::<acp_thread::AuthRequired>() {
+ Ok(mut err) => {
+ if let Some(update_task) = err.update_task.take() {
+ let this = this.clone();
+ let project = project.clone();
+ cx.spawn(async move |cx| {
+ update_task.await;
+ this.update_in(cx, |this, window, cx| {
+ this.thread_state = Self::initial_state(
+ agent,
+ this.workspace.clone(),
+ project.clone(),
+ window,
+ cx,
+ );
+ cx.notify();
+ })
+ .ok();
+ })
+ .detach();
+ }
+ this.update(&mut cx, |this, cx| {
+ this.thread_state = ThreadState::Unauthenticated {
+ connection,
+ description: err.description.clone().map(|desc| {
+ cx.new(|cx| Markdown::new(desc.into(), None, None, cx))
+ }),
+ };
+ cx.notify();
+ })
+ .ok();
+ return;
+ }
+ Err(err) => Err(err),
}
}
Ok(thread) => Ok(thread),
@@ -369,7 +395,7 @@ impl AcpThreadView {
ThreadState::Ready { thread, .. } => thread.read(cx).title(),
ThreadState::Loading { .. } => "Loadingβ¦".into(),
ThreadState::LoadError(_) => "Failed to load".into(),
- ThreadState::Unauthenticated { .. } => "Not authenticated".into(),
+ ThreadState::Unauthenticated { .. } => "Authentication Required".into(),
ThreadState::ServerExited { .. } => "Server exited unexpectedly".into(),
}
}
@@ -708,7 +734,7 @@ impl AcpThreadView {
window: &mut Window,
cx: &mut Context<Self>,
) {
- let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
+ let ThreadState::Unauthenticated { ref connection, .. } = self.thread_state else {
return;
};
@@ -1851,7 +1877,7 @@ impl AcpThreadView {
.mt_4()
.mb_1()
.justify_center()
- .child(Headline::new("Not Authenticated").size(HeadlineSize::Medium)),
+ .child(Headline::new("Authentication Required").size(HeadlineSize::Medium)),
)
.into_any()
}
@@ -2778,6 +2804,13 @@ impl AcpThreadView {
cx.open_url(url.as_str());
}
})
+ } else if url == "zed:///agent/settings" {
+ workspace.update(cx, |workspace, cx| {
+ if let Some(panel) = workspace.panel::<AgentPanel>(cx) {
+ workspace.focus_panel::<AgentPanel>(window, cx);
+ panel.update(cx, |panel, cx| panel.open_configuration(window, cx));
+ }
+ });
} else {
cx.open_url(&url);
}
@@ -3347,12 +3380,22 @@ impl Render for AcpThreadView {
.on_action(cx.listener(Self::toggle_burn_mode))
.bg(cx.theme().colors().panel_background)
.child(match &self.thread_state {
- ThreadState::Unauthenticated { connection } => v_flex()
+ ThreadState::Unauthenticated {
+ connection,
+ description,
+ } => v_flex()
.p_2()
+ .gap_2()
.flex_1()
.items_center()
.justify_center()
.child(self.render_pending_auth_state())
+ .text_ui(cx)
+ .text_center()
+ .text_color(cx.theme().colors().text_muted)
+ .children(description.clone().map(|desc| {
+ self.render_markdown(desc, default_markdown_style(false, window, cx))
+ }))
.child(h_flex().mt_1p5().justify_center().children(
connection.auth_methods().into_iter().map(|method| {
Button::new(
@@ -201,3 +201,9 @@ impl Drop for Subscription {
}
}
}
+
+impl std::fmt::Debug for Subscription {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("Subscription").finish()
+ }
+}
@@ -20,6 +20,7 @@ use icons::IconName;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
+use std::any::Any;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
@@ -620,7 +621,7 @@ pub enum AuthenticateError {
Other(#[from] anyhow::Error),
}
-pub trait LanguageModelProvider: 'static {
+pub trait LanguageModelProvider: Any + Send + Sync {
fn id(&self) -> LanguageModelProviderId;
fn name(&self) -> LanguageModelProviderName;
fn icon(&self) -> IconName {
@@ -108,6 +108,7 @@ pub enum Event {
CommitMessageModelChanged,
ThreadSummaryModelChanged,
ProviderStateChanged,
+ ProviderAuthUpdated,
AddedProvider(LanguageModelProviderId),
RemovedProvider(LanguageModelProviderId),
}
@@ -153,7 +153,7 @@ impl State {
return Task::ready(Ok(()));
}
- let key = ApiKey::get(cx);
+ let key = AnthropicLanguageModelProvider::api_key(cx);
cx.spawn(async move |this, cx| {
let key = key.await?;
@@ -174,8 +174,30 @@ pub struct ApiKey {
pub from_env: bool,
}
-impl ApiKey {
- pub fn get(cx: &mut App) -> Task<Result<Self>> {
+impl AnthropicLanguageModelProvider {
+ pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
+ let state = cx.new(|cx| State {
+ api_key: None,
+ api_key_from_env: false,
+ _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
+ cx.notify();
+ }),
+ });
+
+ Self { http_client, state }
+ }
+
+ fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
+ Arc::new(AnthropicModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ request_limiter: RateLimiter::new(4),
+ })
+ }
+
+ pub fn api_key(cx: &mut App) -> Task<Result<ApiKey>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = AllLanguageModelSettings::get_global(cx)
.anthropic
@@ -201,29 +223,13 @@ impl ApiKey {
})
}
}
-}
-
-impl AnthropicLanguageModelProvider {
- pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
- let state = cx.new(|cx| State {
- api_key: None,
- api_key_from_env: false,
- _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
- cx.notify();
- }),
- });
- Self { http_client, state }
- }
-
- fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
- Arc::new(AnthropicModel {
- id: LanguageModelId::from(model.id().to_string()),
- model,
- state: self.state.clone(),
- http_client: self.http_client.clone(),
- request_limiter: RateLimiter::new(4),
- })
+ pub fn observe(
+ &self,
+ mut on_notify: impl FnMut(&mut App) + 'static,
+ cx: &mut App,
+ ) -> Subscription {
+ cx.observe(&self.state, move |_, cx| on_notify(cx))
}
}