diff --git a/crates/agent_servers/src/custom.rs b/crates/agent_servers/src/custom.rs index 75928a26a8b4499d7b6ced8a8392191ac3ca2f32..0669e0a68ef019975fdbdcbe155fa3dc6aeb0b96 100644 --- a/crates/agent_servers/src/custom.rs +++ b/crates/agent_servers/src/custom.rs @@ -2,6 +2,7 @@ use crate::{AgentServerCommand, AgentServerSettings}; use acp_thread::AgentConnection; use anyhow::Result; use gpui::{App, Entity, SharedString, Task}; +use language_models::provider::anthropic::AnthropicLanguageModelProvider; use project::Project; use std::{path::Path, rc::Rc}; use ui::IconName; @@ -49,10 +50,22 @@ impl crate::AgentServer for CustomAgentServer { cx: &mut App, ) -> Task>> { let server_name = self.name(); - let command = self.command.clone(); + let mut command = self.command.clone(); let root_dir = root_dir.to_path_buf(); + // TODO: Remove this once we have Claude properly cx.spawn(async move |mut cx| { + if let Some(api_key) = cx + .update(AnthropicLanguageModelProvider::api_key)? + .await + .ok() + { + command + .env + .get_or_insert_default() + .insert("ANTHROPIC_API_KEY".to_owned(), api_key.key); + } + crate::acp::connect(server_name, command, &root_dir, &mut cx).await }) } diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 57d90734ef00d1160eb017d1e1257f63577cbebc..2b18ebcd1d72e721e57d91469c835cb70e7812f4 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -9,7 +9,7 @@ use agent_client_protocol::{self as acp, PromptCapabilities}; use agent_servers::{AgentServer, ClaudeCode}; use agent_settings::{AgentProfileId, AgentSettings, CompletionMode, NotifyWhenAgentWaiting}; use agent2::{DbThreadMetadata, HistoryEntry, HistoryEntryId, HistoryStore}; -use anyhow::bail; +use anyhow::{Result, anyhow, bail}; use audio::{Audio, Sound}; use buffer_diff::BufferDiff; use client::zed_urls; @@ -18,6 +18,7 @@ use editor::scroll::Autoscroll; use editor::{Editor, EditorEvent, EditorMode, MultiBuffer, PathKey, SelectionEffects}; use file_icons::FileIcons; use fs::Fs; +use futures::FutureExt as _; use gpui::{ Action, Animation, AnimationExt, AnyView, App, BorderStyle, ClickEvent, ClipboardItem, CursorStyle, EdgesRefinement, ElementId, Empty, Entity, FocusHandle, Focusable, Hsla, Length, @@ -39,6 +40,8 @@ use std::path::Path; use std::sync::Arc; use std::time::Instant; use std::{collections::BTreeMap, rc::Rc, time::Duration}; +use task::SpawnInTerminal; +use terminal_view::terminal_panel::TerminalPanel; use text::Anchor; use theme::ThemeSettings; use ui::{ @@ -93,6 +96,10 @@ impl ThreadError { error.downcast_ref::() { Self::ModelRequestLimitReached(error.plan) + } else if let Some(acp_error) = error.downcast_ref::() + && acp_error.code == acp::ErrorCode::AUTH_REQUIRED.code + { + Self::AuthenticationRequired(acp_error.message.clone().into()) } else { let string = error.to_string(); // TODO: we should have Gemini return better errors here. @@ -898,7 +905,7 @@ impl AcpThreadView { fn send_impl( &mut self, - contents: Task, Vec>)>>, + contents: Task, Vec>)>>, window: &mut Window, cx: &mut Context, ) { @@ -1234,6 +1241,31 @@ impl AcpThreadView { }); return; } + } else if method.0.as_ref() == "anthropic-api-key" { + let registry = LanguageModelRegistry::global(cx); + let provider = registry + .read(cx) + .provider(&language_model::ANTHROPIC_PROVIDER_ID) + .unwrap(); + if !provider.is_authenticated(cx) { + let this = cx.weak_entity(); + let agent = self.agent.clone(); + let connection = connection.clone(); + window.defer(cx, |window, cx| { + Self::handle_auth_required( + this, + AuthRequired { + description: Some("ANTHROPIC_API_KEY must be set".to_owned()), + provider_id: Some(language_model::ANTHROPIC_PROVIDER_ID), + }, + agent, + connection, + window, + cx, + ); + }); + return; + } } else if method.0.as_ref() == "vertex-ai" && std::env::var("GOOGLE_API_KEY").is_err() && (std::env::var("GOOGLE_CLOUD_PROJECT").is_err() @@ -1265,7 +1297,15 @@ impl AcpThreadView { self.thread_error.take(); configuration_view.take(); pending_auth_method.replace(method.clone()); - let authenticate = connection.authenticate(method, cx); + let authenticate = if method.0.as_ref() == "claude-login" { + if let Some(workspace) = self.workspace.upgrade() { + Self::spawn_claude_login(&workspace, window, cx) + } else { + Task::ready(Ok(())) + } + } else { + connection.authenticate(method, cx) + }; cx.notify(); self.auth_task = Some(cx.spawn_in(window, { @@ -1289,6 +1329,13 @@ impl AcpThreadView { this.update_in(cx, |this, window, cx| { if let Err(err) = result { + if let ThreadState::Unauthenticated { + pending_auth_method, + .. + } = &mut this.thread_state + { + pending_auth_method.take(); + } this.handle_thread_error(err, cx); } else { this.thread_state = Self::initial_state( @@ -1307,6 +1354,76 @@ impl AcpThreadView { })); } + fn spawn_claude_login( + workspace: &Entity, + window: &mut Window, + cx: &mut App, + ) -> Task> { + let Some(terminal_panel) = workspace.read(cx).panel::(cx) else { + return Task::ready(Ok(())); + }; + let project = workspace.read(cx).project().read(cx); + let cwd = project.first_project_directory(cx); + let shell = project.terminal_settings(&cwd, cx).shell.clone(); + + let terminal = terminal_panel.update(cx, |terminal_panel, cx| { + terminal_panel.spawn_task( + &SpawnInTerminal { + id: task::TaskId("claude-login".into()), + full_label: "claude /login".to_owned(), + label: "claude /login".to_owned(), + command: Some("claude".to_owned()), + args: vec!["/login".to_owned()], + command_label: "claude /login".to_owned(), + cwd, + use_new_terminal: true, + allow_concurrent_runs: true, + hide: task::HideStrategy::Always, + shell, + ..Default::default() + }, + window, + cx, + ) + }); + cx.spawn(async move |cx| { + let terminal = terminal.await?; + let mut exit_status = terminal + .read_with(cx, |terminal, cx| terminal.wait_for_completed_task(cx))? + .fuse(); + + let logged_in = cx + .spawn({ + let terminal = terminal.clone(); + async move |cx| { + loop { + cx.background_executor().timer(Duration::from_secs(1)).await; + let content = + terminal.update(cx, |terminal, _cx| terminal.get_content())?; + if content.contains("Login successful") { + return anyhow::Ok(()); + } + } + } + }) + .fuse(); + futures::pin_mut!(logged_in); + futures::select_biased! { + result = logged_in => { + if let Err(e) = result { + log::error!("{e}"); + return Err(anyhow!("exited before logging in")); + } + } + _ = exit_status => { + return Err(anyhow!("exited before logging in")); + } + } + terminal.update(cx, |terminal, _| terminal.kill_active_task())?; + Ok(()) + }) + } + fn authorize_tool_call( &mut self, tool_call_id: acp::ToolCallId, @@ -4024,7 +4141,7 @@ impl AcpThreadView { workspace: Entity, window: &mut Window, cx: &mut App, - ) -> Task> { + ) -> Task> { let markdown_language_task = workspace .read(cx) .app_state() diff --git a/crates/terminal_view/src/terminal_panel.rs b/crates/terminal_view/src/terminal_panel.rs index 45e36c199048f9699c920939f7c6f8921d25c5e9..848737aeb24ef52a6819e57882ab022edef94e25 100644 --- a/crates/terminal_view/src/terminal_panel.rs +++ b/crates/terminal_view/src/terminal_panel.rs @@ -485,7 +485,7 @@ impl TerminalPanel { .detach_and_log_err(cx); } - fn spawn_task( + pub fn spawn_task( &mut self, task: &SpawnInTerminal, window: &mut Window,